快捷方式

批归一化补丁

发生了什么?

批归一化需要对与输入大小相同的 running_mean 和 running_var 进行就地更新。Functorch 不支持对接收批处理张量的普通张量进行就地更新(即 regular.add_(batched) 不允许)。因此,当对单个模块的一批输入进行 vmap 时,我们最终会遇到此错误

如何修复

最受支持的方法之一是将批归一化切换为组归一化。选项 1 和 2 支持此功能

所有这些选项都假设您不需要运行统计信息。如果您使用的是模块,则意味着假设您不会在评估模式下使用批归一化。如果您有一个使用案例涉及在评估模式下使用 vmap 运行批归一化,请提交一个问题

选项 1:更改批归一化

如果您想更改为组归一化,在任何包含批归一化的地方,用以下代码替换它

BatchNorm2d(C, G, track_running_stats=False)

这里 C 与原始批归一化中的 C 相同。 G 是将 C 分割成的组数。因此,C % G == 0,作为备用,您可以设置 C == G,这意味着每个通道将被单独处理。

如果您必须使用批归一化并且自己构建了模块,则可以更改模块以不使用运行统计信息。换句话说,在任何包含批归一化模块的地方,将 track_running_stats 标志设置为 False

BatchNorm2d(64, track_running_stats=False)

选项 2:torchvision 参数

一些 torchvision 模型(如 resnet 和 regnet)可以接收 norm_layer 参数。如果它们已设置默认值,则通常默认为 BatchNorm2d。

您可以改为将其设置为组归一化。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

这里,再次,c % g == 0,因此作为备用,请设置 g = c

如果您坚持使用批归一化,请务必使用不使用运行统计信息的版本

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

选项 3:functorch 的补丁

functorch 添加了一些功能,允许对模块进行快速的就地补丁,以不使用运行统计信息。更改规范化层更脆弱,因此我们没有提供这种功能。如果您有一个您希望批归一化不使用运行统计信息的网络,您可以运行 replace_all_batch_norm_modules_ 对模块进行就地更新,以不使用运行统计信息

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

选项 4:评估模式

在评估模式下运行时,running_mean 和 running_var 不会更新。因此,vmap 可以支持此模式

model.eval()
vmap(model)(x)
model.train()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源