修补批次规范化¶
发生了什么?¶
批次规范化需要对与输入大小相同的 running_mean 和 running_var 进行就地更新。Functorch 不支持对采用批次张量的常规张量进行就地更新(即不允许 regular.add_(batched)
)。因此,当对单个模块的批次输入进行 vmap 时,我们最终会遇到此错误
如何修复¶
最受支持的最佳方式之一是将 BatchNorm 切换为 GroupNorm。选项 1 和 2 支持此操作
所有这些选项都假设您不需要运行统计信息。如果您正在使用模块,则这意味着假设您不会在评估模式下使用批次规范化。如果您有在评估模式下使用 vmap 运行批次规范化的用例,请提交问题
选项 1:更改 BatchNorm¶
如果您想更改为 GroupNorm,在您拥有 BatchNorm 的任何地方,用以下内容替换它
BatchNorm2d(C, G, track_running_stats=False)
此处 C
与原始 BatchNorm 中的 C
相同。 G
是将 C
分解为的组数。因此,C % G == 0
并且作为后备,您可以设置 C == G
,这意味着每个通道将被单独处理。
如果您必须使用 BatchNorm 并且您自己构建了模块,则可以更改模块以不使用运行统计信息。换句话说,在任何有 BatchNorm 模块的地方,将 track_running_stats
标志设置为 False
BatchNorm2d(64, track_running_stats=False)
选项 2:torchvision 参数¶
一些 torchvision 模型(如 resnet 和 regnet)可以采用 norm_layer
参数。如果它们已设为默认值,则通常默认为 BatchNorm2d。
相反,您可以将其设置为 GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
在此处,再次,c % g == 0
因此作为后备,设置 g = c
。
如果您依附于 BatchNorm,请务必使用不使用运行统计信息的版本
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
选项 3:functorch 的修补¶
functorch 添加了一些功能,允许快速、就地修补模块以不使用运行统计信息。更改范数层更脆弱,因此我们没有提供。如果您有一个网络,您希望 BatchNorm 不使用运行统计信息,则可以运行 replace_all_batch_norm_modules_
以就地更新模块以不使用运行统计信息
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
选项 4:评估模式¶
在评估模式下运行时,running_mean 和 running_var 将不会更新。因此,vmap 可以支持此模式
model.eval()
vmap(model)(x)
model.train()