torch.autograd.Function.vmap¶
- 静态 Function.vmap(info, in_dims, *args)[源代码][源代码]¶
定义此 autograd.Function 在
torch.vmap()
下的行为。对于一个
torch.autograd.Function()
要支持torch.vmap()
,您必须覆盖此静态方法,或将generate_vmap_rule
设置为True
(两者不可同时进行)。如果您选择覆盖此静态方法:它必须接受
一个
info
对象作为第一个参数。info.batch_size
指定正在 vmap 化的维度的尺寸,而info.randomness
是传递给torch.vmap()
的随机性选项。一个
in_dims
元组作为第二个参数。对于args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是 Tensor 或参数未被 vmap 化,则为None
,否则,它是一个整数,指定 Tensor 的哪个维度正在被 vmap 化。*args
,它与forward()
的参数相同。
vmap 静态方法的返回值是一个
(output, out_dims)
元组。与in_dims
类似,out_dims
应与output
具有相同的结构,并且每个输出包含一个out_dim
,用于指定输出是否具有 vmap 化的维度以及它在哪个索引中。请参阅 使用 autograd.Function 扩展 torch.func 以获取更多详细信息。