• 教程 >
  • (Beta) PyTorch 中的 Channels Last 内存格式
快捷方式

(Beta) PyTorch 中的 Channels Last 内存格式

创建日期:2020 年 4 月 20 日 | 最后更新:2023 年 10 月 04 日 | 最后验证:2024 年 11 月 05 日

作者: Vitaly Fedyunin

什么是 Channels Last

Channels Last 内存格式是 NCHW 张量在内存中排序的另一种方式,同时保留维度顺序。Channels Last 张量的排序方式使得通道成为最密集的维度(也就是逐像素存储图像)。

例如,经典的(连续的)NCHW 张量存储方式(在本例中是两个 4x4 图像,每个图像有 3 个颜色通道)如下所示:

classic_memory_format

Channels Last 内存格式以不同的方式排序数据

channels_last_memory_format

PyTorch 通过利用现有的步长结构来支持内存格式(并为现有模型提供向后兼容性,包括 eager、JIT 和 TorchScript)。例如,在 Channels Last 格式中,一个 10x3x16x16 的批处理将具有 (768, 1, 48, 3) 的步长。

Channels Last 内存格式仅针对 4D NCHW 张量实现。

内存格式 API

以下是如何在连续内存格式和 Channels Last 内存格式之间转换张量。

经典的 PyTorch 连续张量

import torch

N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)

转换算子

x = x.to(memory_format=torch.channels_last)
print(x.shape)  # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3)
torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3)

回到连续格式

x = x.to(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)

替代选项

x = x.contiguous(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

格式检查

print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
True

两个 API tocontiguous 之间存在细微差别。我们建议在明确转换张量内存格式时使用 to

在一般情况下,这两个 API 的行为相同。然而,对于大小为 NCHW 的 4D 张量,在以下特殊情况之一:C==1H==1 && W==1 时,只有 to 才能生成表示 Channels Last 内存格式的正确步长。

这是因为在上述两种情况下的任一种,张量的内存格式都是模糊的,即大小为 N1HW 的连续张量在内存存储上既是 contiguous 也是 channels last。因此,对于给定的内存格式,它们已经被认为是 is_contiguous,因此调用 contiguous 会变成空操作,并且不会更新步长。相反,to 会在大小为 1 的维度上重新设置张量的步长,以便正确表示预期的内存格式。

special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format))  # Outputs: True
True
True

同样的情况也适用于显式置换 API permute。在可能出现模糊性的特殊情况下,permute 不保证生成能够正确承载预期内存格式的步长。我们建议使用带有显式内存格式的 to 来避免意外行为。

另外需要注意的是,在极端情况下,如果三个非批处理维度都等于 1C==1 && H==1 && W==1),当前实现无法将张量标记为 Channels Last 内存格式。

创建为 channels last

x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

clone 保留内存格式

y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

to, cuda, float 等 … 保留内存格式

if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

empty_like, *_like 算子保留内存格式

y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

逐点算子保留内存格式

z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

使用 cudnn 后端的 ConvBatchnorm 模块支持 Channels Last(仅适用于 cuDNN >= 7.6)。与二元逐点算子不同,卷积模块以 Channels Last 作为主导内存格式。如果所有输入都是连续内存格式,则算子生成连续内存格式的输出。否则,输出将是 Channels Last 内存格式。

if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
    model = torch.nn.Conv2d(8, 4, 3).cuda().half()
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last

    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
    input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
True

当输入张量到达一个不支持 Channels Last 的算子时,内核中会自动应用一个置换来恢复输入张量的连续性。这会引入开销并停止 Channels Last 内存格式的传播。尽管如此,它保证了正确的输出。

性能提升

Channels Last 内存格式优化在 GPU 和 CPU 上均可用。在 GPU 上,使用 Channels Last 在支持 Tensor Core 的 NVIDIA 硬件上以降低精度(torch.float16)运行时,观察到了最显著的性能提升。与连续格式相比,我们在 Channels Last 上实现了超过 22% 的性能提升,这两种情况下都使用了 'AMP (自动混合精度)' 训练脚本。我们的脚本使用了 NVIDIA 提供的 AMP,地址为 https://github.com/NVIDIA/apex

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2  ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

传递 --channels-last true 可以在 Channels Last 格式下运行模型,并观察到 22% 的性能提升。

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

以下模型列表完全支持 Channels Last 格式,并在 Volta 设备上显示 8%-35% 的性能提升:alexnet, mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3, mobilenet_v2, resnet101, resnet152, resnet18, resnet34, resnet50, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2

以下模型列表完全支持 Channels Last 格式,并在英特尔(R) 至强(R) Ice Lake(或更新)CPU 上显示 26%-76% 的性能提升:alexnet, densenet121, densenet161, densenet169, googlenet, inception_v3, mnasnet0_5, mnasnet1_0, resnet101, resnet152, resnet18, resnet34, resnet50, resnext101_32x8d, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2

转换现有模型

Channels Last 支持不限于现有模型,因为任何模型都可以转换为 Channels Last 格式,并且只要输入(或某些权重)的格式正确,就可以在计算图中传播该格式。

# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last)  # Replace with your model

# Need to be done for every input
input = input.to(memory_format=torch.channels_last)  # Replace with your input
output = model(input)

然而,并非所有算子都完全转换为支持 Channels Last(通常返回连续输出)。在上面发布的示例中,不支持 Channels Last 的层将停止内存格式的传播。尽管如此,由于我们已将模型转换为 Channels Last 格式,这意味着每个卷积层(其 4 维权重采用 Channels Last 内存格式)将恢复 Channels Last 内存格式并受益于更快的内核。

但不支持 Channels Last 的算子会通过置换引入开销。如果您想提高转换后模型的性能,可以选择调查并识别模型中不支持 Channels Last 的算子。

这意味着您需要对照支持 Channels Last 的算子列表 https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support 验证您使用的算子列表,或者在 eager 执行模式中引入内存格式检查并运行您的模型。

运行以下代码后,如果算子的输出与输入的内存格式不匹配,算子将引发异常。

def contains_cl(args):
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False


def print_inputs(args, indent=""):
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + "    ")
        else:
            print(indent, t)


def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print("-------------------")
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print(
                        "`{}` got channels_last input, but output is not channels_last:".format(name),
                        result.shape,
                        result.stride(),
                        result.device,
                        result.dtype,
                    )
                    failed = True
        if failed and True:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception("Operator `{}` lost channels_last property".format(name))
        return result

    return check_cl


old_attrs = dict()


def attribute(m):
    old_attrs[m] = dict()
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
        if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
            try:
                old_attrs[m][i] = e
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)


attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)

如果您发现一个不支持 Channels Last 张量的算子并想贡献力量,请随时使用以下开发者指南 https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators

以下代码用于恢复 torch 的属性。

for (m, attrs) in old_attrs.items():
    for (k, v) in attrs.items():
        setattr(m, k, v)

待完成的工作

还有许多工作要做,例如:

  • 解决 N1HWNC11 张量的模糊性;

  • 测试分布式训练支持;

  • 提高算子覆盖率。

如果您有反馈或改进建议,请通过创建问题告知我们。

脚本总运行时间: ( 0 分钟 0.026 秒)

Gallery 由 Sphinx-Gallery 生成

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源