torch.func API 参考¶
函数变换¶
vmap 是向量化映射; |
|
|
|
返回一个函数,用于计算梯度和原始计算(或前向计算)组成的元组。 |
|
vjp 代表向量-雅可比乘积 (vector-Jacobian product),返回一个元组,其中包含应用于 |
|
jvp 代表雅可比-向量乘积 (Jacobian-vector product),返回一个元组,其中包含 func(*primals) 的输出以及“在 |
|
返回 |
|
使用反向模式自动微分计算 |
|
使用前向模式自动微分计算 |
|
通过前向-后向策略计算 |
|
functionalize 是一种变换,可用于从函数中去除(中间)修改和别名,同时保留函数的语义。 |
用于处理 torch.nn.Modules 的工具¶
通常,您可以对调用 torch.nn.Module
的函数进行变换。例如,下面是一个计算接受三个值并返回三个值的函数的雅可比的示例
model = torch.nn.Linear(3, 3)
def f(x):
return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
但是,如果您想对模型的参数计算雅可比之类的操作,则需要一种方法来构造一个将参数作为函数输入的函数。这就是 functional_call()
的作用:它接受一个 nn.Module、变换后的 parameters
以及 Module 前向传播的输入。它返回使用替换参数运行 Module 前向传播的值。
下面是如何计算参数的雅可比
model = torch.nn.Linear(3, 3)
def f(params, x):
return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
通过用提供的参数和缓冲区替换模块的参数和缓冲区,对模块执行函数式调用。 |
|
准备一个 torch.nn.Modules 列表,以便与 |
|
通过将 |
如果您正在寻找关于修复 Batch Norm 模块的信息,请遵循此处的指南
调试工具¶
展开一个 functorch tensor(例如 |