快捷方式

torch.nn.utils.parametrize.cached

torch.nn.utils.parametrize.cached()[源代码]

上下文管理器,可在使用 register_parametrization() 注册的参数化中启用缓存系统。

当此上下文管理器处于活动状态时,首次需要参数化对象的值时,将计算并缓存该值。离开上下文管理器时,将丢弃缓存的值。

当在正向传递中多次使用参数化参数时,这很有用。一个例子是在参数化 RNN 的循环内核或共享权重时。

激活缓存的最简单方法是包装神经网络的前向传递

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在训练和评估中。也可以包装模块中多次使用参数化张量的部分。例如,具有参数化循环内核的 RNN 的循环

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源