torch.autograd.Function.vmap¶
- static Function.vmap(info, in_dims, *args)[源代码]¶
定义此 autograd.Function 在
torch.vmap()下的执行行为。要使
torch.autograd.Function()支持torch.vmap(),您必须覆盖此静态方法,或将generate_vmap_rule设置为True(您不能同时执行这两项操作)。如果您选择覆盖此静态方法:它必须接受
一个
info对象作为第一个参数。info.batch_size指定正在进行 vmap 的维度的尺寸,而info.randomness是传递给torch.vmap()的随机性选项。一个
in_dims元组作为第二个参数。对于args中的每个参数,in_dims都有一个对应的Optional[int]。如果参数不是张量或参数未被 vmap,则为None,否则,它是一个整数,指定正在进行 vmap 的张量的哪个维度。*args,它与forward()的参数相同。
vmap 静态方法的返回值是
(output, out_dims)的元组。与in_dims类似,out_dims应该与output的结构相同,并且每个输出包含一个out_dim,用于指定输出是否具有 vmap 维度以及该维度在输出中的索引。有关更多详细信息,请参阅 使用 autograd.Function 扩展 torch.func。