torch.nn.utils.convert_conv2d_weight_memory_format¶
- torch.nn.utils.convert_conv2d_weight_memory_format(module, memory_format)[源代码]¶
将
nn.Conv2d.weight
的memory_format
转换为memory_format
。转换递归地应用于嵌套的
nn.Module
,包括module
。请注意,它只更改 memory_format,而不是每个维度的语义。此函数用于促进计算以采用 NHWC 内核,这为具有计算能力 >= 7.0 的 CUDA 设备上的 fp16 数据提供了相当大的加速注意
调用
model.to(memory_format=torch.channels_last)
比实用函数convert_conv2d_weight_memory_format
更激进。任何具有 4d 权重的层都将受到model.to
的影响,这并不一定从转换为指定的memory_format
中受益。我们确信的一个地方是 cuDNN 中卷积的 NHWC(channels_last)转换,因为它有利于以 NHWC 运行卷积,即使在必须对输入张量应用排列的情况下也是如此。因此,我们在这里的策略是仅将卷积的权重转换为 channels_last。这确保了:1. 将使用快速卷积内核,其好处可能超过排列的开销(如果输入格式不同)。2. 对不从 memory_format 转换中受益的层不应用不必要的排列。
最佳情况是,卷积层之间的层与 channels_last 兼容。当遇到第一个卷积层时,输入张量将被排列为 channels_last 并保持该内存格式。因此,后续卷积不需要排列其输入张量。
如果卷积层之间存在与 channels_last 不兼容的层,我们需要将输入张量排列回连续格式以供该层使用。输入张量将以连续格式遍历其余层,并在遇到另一个卷积层时被排列为 channels_last。将该排列传播到较早的层没有意义,因为大多数层对
memory_format
非常不敏感。当 PyTorch 支持排列融合时,此断言可能会发生变化,因为可能存在比在卷积之前立即融合排列更好的位置。
- 参数
module (nn.Module) –
nn.Conv2d
和nn.ConvTranspose2d
或容器nn.Module
memory_format – 用户指定的
memory_format
,例如torch.channels_last
或torch.contiguous_format
- 返回值
具有更新的
nn.Conv2d
的原始模块
示例
>>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda") >>> model = nn.Sequential( >>> nn.Conv2d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) >>> out = model(input)