torch.nn.utils.parametrize.cached¶
- torch.nn.utils.parametrize.cached()[源代码][源代码]¶
上下文管理器,用于启用在通过
register_parametrization()
注册的参数化中使用的缓存系统。当此上下文管理器处于活动状态时,参数化对象的值在首次需要时计算并缓存。 离开上下文管理器时,缓存的值将被丢弃。
当在正向传播中多次使用参数化参数时,这非常有用。 一个例子是在参数化 RNN 的循环核时或共享权重时。
激活缓存的最简单方法是包装神经网络的正向传播
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在训练和评估中。 也可以包装多次使用参数化张量的模块部分。 例如,具有参数化循环核的 RNN 的循环
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)