torch.func API 参考¶
函数转换¶
| vmap 是向量化映射; | |
| 
 | |
| 返回一个函数,用于计算梯度和原始(或正向)计算的元组。 | |
| 代表向量-雅可比乘积,返回一个元组,其中包含应用于  | |
| 代表雅可比-向量乘积,返回一个元组,其中包含 func(*primals) 的输出,以及“在  | |
| 返回  | |
| 使用反向模式自动微分计算  | |
| 使用前向模式自动微分计算  | |
| 使用前向-反向策略计算  | |
| functionalize 是一种转换,可以用来移除函数中的(中间)变异和别名,同时保留函数的语义。 | 
用于处理 torch.nn.Modules 的实用程序¶
一般来说,您可以对调用 torch.nn.Module 的函数进行转换。例如,以下是一个计算接受三个值并返回三个值的函数的雅可比矩阵的示例
model = torch.nn.Linear(3, 3)
def f(x):
    return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
但是,如果您想做一些类似于计算模型参数的雅可比矩阵的事情,那么需要有一种方法来构建一个函数,其中参数是函数的输入。这就是 functional_call() 的作用:它接受一个 nn.Module、转换后的 parameters 以及 Module 前向传递的输入。它返回使用替换后的参数运行 Module 前向传递的值。
以下是我们如何计算参数的雅可比矩阵
model = torch.nn.Linear(3, 3)
def f(params, x):
    return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
| 通过用提供的参数替换模块参数和缓冲区,对模块执行函数调用。 | |
| 准备一个 torch.nn.Modules 列表,用于与  | |
| 通过将  | 
如果您正在寻找有关修复 Batch Norm 模块的信息,请遵循此处的指南