快捷方式

torch.func API 参考

函数变换

vmap

vmap 是向量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的某个维度上。

grad

grad 运算符有助于计算 func 相对于 argnums 指定的输入(或多个输入)的梯度。

grad_and_value

返回一个函数,用于计算梯度和原始计算(或前向计算)组成的元组。

vjp

vjp 代表向量-雅可比乘积 (vector-Jacobian product),返回一个元组,其中包含应用于 primalsfunc 的结果,以及一个给定 cotangents 后计算 func 相对于 primals 的反向模式雅可比(再乘以 cotangents)的函数。

jvp

jvp 代表雅可比-向量乘积 (Jacobian-vector product),返回一个元组,其中包含 func(*primals) 的输出以及“在 primals 处计算的 func 的雅可比”乘以 tangents 的结果。

linearize

返回 funcprimals 处的值以及在 primals 处的线性近似值。

jacrev

使用反向模式自动微分计算 func 相对于索引 argnum 处的参数(或多个参数)的雅可比。

jacfwd

使用前向模式自动微分计算 func 相对于索引 argnum 处的参数(或多个参数)的雅可比。

hessian

通过前向-后向策略计算 func 相对于索引 argnum 处的参数(或多个参数)的 Hessian 矩阵。

functionalize

functionalize 是一种变换,可用于从函数中去除(中间)修改和别名,同时保留函数的语义。

用于处理 torch.nn.Modules 的工具

通常,您可以对调用 torch.nn.Module 的函数进行变换。例如,下面是一个计算接受三个值并返回三个值的函数的雅可比的示例

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

但是,如果您想对模型的参数计算雅可比之类的操作,则需要一种方法来构造一个将参数作为函数输入的函数。这就是 functional_call() 的作用:它接受一个 nn.Module、变换后的 parameters 以及 Module 前向传播的输入。它返回使用替换参数运行 Module 前向传播的值。

下面是如何计算参数的雅可比

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

通过用提供的参数和缓冲区替换模块的参数和缓冲区,对模块执行函数式调用。

stack_module_state

准备一个 torch.nn.Modules 列表,以便与 vmap() 一起进行集成。

replace_all_batch_norm_modules_

通过将 running_meanrunning_var 设置为 None 并将 track_running_stats 设置为 False,原地更新 root 中的任何 nn.BatchNorm 模块。

如果您正在寻找关于修复 Batch Norm 模块的信息,请遵循此处的指南

调试工具

debug_unwrap

展开一个 functorch tensor(例如

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源