torch.nn.utils.parametrize.cached¶
- torch.nn.utils.parametrize.cached()[源代码]¶
上下文管理器,可在使用
register_parametrization()
注册的参数化中启用缓存系统。当此上下文管理器处于活动状态时,首次需要参数化对象的值时,将计算并缓存该值。离开上下文管理器时,将丢弃缓存的值。
当在正向传递中多次使用参数化参数时,这很有用。一个例子是在参数化 RNN 的循环内核或共享权重时。
激活缓存的最简单方法是包装神经网络的前向传递
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在训练和评估中。也可以包装模块中多次使用参数化张量的部分。例如,具有参数化循环内核的 RNN 的循环
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)