torch.nn.utils.convert_conv3d_weight_memory_format¶
- torch.nn.utils.convert_conv3d_weight_memory_format(module, memory_format)[源代码][源代码]¶
将
nn.Conv3d.weight
的memory_format
转换为指定的memory_format
。此转换递归应用于嵌套的nn.Module
,包括module
本身。请注意,它只改变 memory_format,而不改变每个维度的语义。此函数用于促使计算采用 NHWC 核,这在计算能力 >= 7.0 的 CUDA 设备上为 fp16 数据提供了显著的速度提升。注意
调用
model.to(memory_format=torch.channels_last_3d)
比 utility 函数convert_conv3d_weight_memory_format
更具侵略性。任何带有 4d 权重的层都会受到model.to
的影响,而这些层不一定能从转换为指定的memory_format
中受益。我们有信心的一点是 cuDNN 中卷积的 NDHWC(channels_last_3d) 转换,因为即使在必须对输入张量应用排列的情况下,在 NDHWC 中运行卷积也是有益的。因此,我们的策略是仅将卷积的权重转换为 channels_last_3d。这确保了:1. 将使用快速卷积核,其好处可能超过排列的开销(如果输入不是相同格式)。2. 不会对不从 memory_format 转换中受益的层应用不必要的排列。
最优情况是,卷积层之间的层与 channels last 格式兼容。输入张量在遇到第一个卷积层时会被排列为 channels last 格式并保持该内存格式。因此,后续的卷积不需要对其输入张量进行排列。
在卷积层之间存在 channels last 不兼容层的情况下,我们需要将该层的输入张量重新排列回 contiguous 格式。输入张量将以 contiguous 格式通过剩余的层,并在遇到另一个卷积层时被排列为 channels last 格式。将该排列传播到较早的层没有意义,因为大多数层对
memory_format
相当不敏感。当 PyTorch 支持排列融合时,这种说法可能会改变,因为可能存在比紧接在卷积之前更好的排列融合位置。
- 参数
module (nn.Module) –
nn.Conv3d
和nn.ConvTranspose3d
或容器nn.Module
memory_format – 用户指定的
memory_format
,例如torch.channels_last
或torch.contiguous_format
- 返回值
更新了
nn.Conv3d
的原始模块
示例
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") >>> model = nn.Sequential( >>> nn.Conv3d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> out = model(input)