快捷方式

torch.autograd.Function.vmap

静态 Function.vmap(info, in_dims, *args)[源代码][源代码]

定义此 autograd.Function 在 torch.vmap() 下的行为。

对于一个 torch.autograd.Function() 要支持 torch.vmap(),您必须覆盖此静态方法,或将 generate_vmap_rule 设置为 True(两者不可同时进行)。

如果您选择覆盖此静态方法:它必须接受

  • 一个 info 对象作为第一个参数。info.batch_size 指定正在 vmap 化的维度的尺寸,而 info.randomness 是传递给 torch.vmap() 的随机性选项。

  • 一个 in_dims 元组作为第二个参数。对于 args 中的每个参数,in_dims 都有一个对应的 Optional[int]。如果参数不是 Tensor 或参数未被 vmap 化,则为 None,否则,它是一个整数,指定 Tensor 的哪个维度正在被 vmap 化。

  • *args,它与 forward() 的参数相同。

vmap 静态方法的返回值是一个 (output, out_dims) 元组。与 in_dims 类似,out_dims 应与 output 具有相同的结构,并且每个输出包含一个 out_dim,用于指定输出是否具有 vmap 化的维度以及它在哪个索引中。

请参阅 使用 autograd.Function 扩展 torch.func 以获取更多详细信息。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源