快捷方式

torch.nn.utils.parametrize.cached

torch.nn.utils.parametrize.cached()[源代码][源代码]

上下文管理器,用于在通过 register_parametrization() 注册的参数化中启用缓存系统。

当此上下文管理器处于活动状态时,参数化对象的值在首次需要时计算并缓存。离开上下文管理器时,缓存的值将被丢弃。

这在使用参数化参数在正向传播中多次出现时非常有用。例如,当对 RNN 的循环核进行参数化或共享权重时。

激活缓存的最简单方法是包裹神经网络的正向传播

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在训练和评估中。也可以包裹多次使用参数化张量的模块部分。例如,带有参数化循环核的 RNN 的循环部分

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源