快捷方式

批处理归一化修补

发生了什么?

批处理归一化需要对与输入大小相同的 running_mean 和 running_var 进行就地更新。Functorch 不支持对接收批处理张量的普通张量进行就地更新(即 regular.add_(batched) 不允许)。因此,当对单个模块的一批输入进行 vmap 时,最终会出现此错误

如何修复

所有这些选项都假设您不需要运行统计信息。如果您正在使用模块,则表示假设您不会在评估模式下使用批处理归一化。如果您有在评估模式下使用 vmap 进行批处理归一化的用例,请提交问题

选项 1:更改批处理归一化

如果您自己构建了模块,则可以更改模块以不使用运行统计信息。换句话说,在任何存在 BatchNorm 模块的地方,都将 track_running_stats 标志设置为 False

BatchNorm2d(64, track_running_stats=False)

选项 2:torchvision 参数

一些 torchvision 模型,如 resnet 和 regnet,可以接收 norm_layer 参数。如果它们已设置默认值,则通常默认为 BatchNorm2d。相反,您可以将其设置为不使用运行统计信息的 BatchNorm

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

选项 3:functorch 的修补

functorch 添加了一些功能,以便快速、就地修补模块。如果您想要更改网络,可以运行 replace_all_batch_norm_modules_ 以就地更新模块以不使用运行统计信息

from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源