修补 Batch Norm¶
发生了什么?¶
Batch Norm 需要对与输入具有相同尺寸的 running_mean
和 running_var
进行原地更新。Functorch 不支持对接收批处理张量(即 regular.add_(batched)
不允许)的常规张量进行原地更新。因此,当对单个模块的输入批次进行 vmapping 时,我们会遇到此错误
如何修复¶
一种最佳支持的方式是将 BatchNorm 替换为 GroupNorm。选项 1 和 2 支持此方法
所有这些选项都假设您不需要运行统计量(running stats)。如果您正在使用模块,这意味着假设您不会在评估模式下使用 BatchNorm。如果您需要在评估模式下使用 vmap 运行 BatchNorm,请提交一个 issue
选项 1:更改 BatchNorm¶
如果您想将其更改为 GroupNorm,请将所有 BatchNorm 替换为
BatchNorm2d(C, G, track_running_stats=False)
这里 C
与原始 BatchNorm 中的 C
相同。G
是将 C
划分成的组数。因此,C % G == 0
成立,作为备用方案,您可以设置 C == G
,这意味着每个通道都将被单独处理。
如果您必须使用 BatchNorm 并且是自己构建的模块,则可以更改模块以不使用运行统计量。换句话说,在任何 BatchNorm 模块处,将 track_running_stats
标志设置为 False
BatchNorm2d(64, track_running_stats=False)
选项 2:torchvision 参数¶
一些 torchvision 模型,如 resnet 和 regnet,可以接受 norm_layer
参数。如果使用了默认值,这些参数通常会默认为 BatchNorm2d。
您可以将其设置为 GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
这里,同样,c % g == 0
成立,因此作为备用方案,设置 g = c
。
如果您坚持使用 BatchNorm,请务必使用不使用运行统计量的版本
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
选项 3:functorch 的修补¶
functorch 添加了一些功能,可以快速、原地修补模块以不使用运行统计量。更改规范层更脆弱,因此我们没有提供该选项。如果您有一个网络,希望其中的 BatchNorm 不使用运行统计量,可以运行 replace_all_batch_norm_modules_
来原地更新模块以不使用运行统计量。
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
选项 4:评估模式¶
在评估模式下运行时,running_mean
和 running_var
将不会更新。因此,vmap 可以支持此模式。
model.eval()
vmap(model)(x)
model.train()