快捷方式

CommDebugMode 入门

创建日期:2024 年 8 月 19 日 | 最后更新:2024 年 10 月 8 日 | 最后验证:2024 年 11 月 5 日

作者: Anshul Sinha

在本教程中,我们将探讨如何在分布式训练环境中使用 CommDebugMode 配合 PyTorch 的 DistributedTensor (DTensor) 来跟踪集合操作进行调试。

先决条件

  • Python 3.8 - 3.11

  • PyTorch 2.2 或更高版本

CommDebugMode 是什么以及为何有用

随着模型规模不断增大,用户正寻求利用各种并行策略组合来扩展分布式训练。然而,现有解决方案之间的互操作性不足带来了巨大挑战,这主要是由于缺乏能够桥接这些不同并行策略的统一抽象。为了解决这个问题,PyTorch 提出了 DistributedTensor(DTensor),它抽象了分布式训练中张量通信的复杂性,提供了无缝的用户体验。然而,在使用现有并行解决方案以及利用 DTensor 等统一抽象开发并行解决方案时,底层集合通信的内容和发生时间缺乏透明度,这可能导致高级用户难以识别和解决问题。为了应对这一挑战,CommDebugMode(一个 Python 上下文管理器)将作为 DTensor 的主要调试工具之一,使用户能够查看在使用 DTensor 时集合操作的发生时间和原因,从而有效解决此问题。

使用 CommDebugMode

您可以按如下方式使用 CommDebugMode

# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
    noise_level=1, file_name="transformer_operation_log.txt"
)

# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)

这是 MLPModule 在噪声级别 0 时的输出示例

Expected Output:
    Global
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

要使用 CommDebugMode,您必须将运行模型的代码包装在 CommDebugMode 中,并调用您想要用来显示数据的 API。您还可以使用 noise_level 参数来控制显示信息的详细程度。以下是每个噪声级别显示的内容:

0. 打印模块级别的集合计数。
1. 打印 DTensor 操作(不包括微不足道的操作)、模块分片信息。
2. 打印张量操作(不包括微不足道的操作)。
3. 打印所有操作。

在上面的示例中,您可以看到集合操作 all_reduce 在 MLPModule 的前向传播中发生了一次。此外,您可以使用 CommDebugMode 精确定位到 all-reduce 操作发生在 MLPModule 的第二个线性层中。

下面是交互式模块树可视化工具,您可以使用它上传自己的 JSON dump 文件

CommDebugMode 模块树 - PyTorch 深度学习库
将文件拖到此处

结论

在本 Recipe 中,我们学习了如何使用 CommDebugMode 来调试 Distributed Tensors 以及使用 PyTorch 中通信集合的并行解决方案。您可以在嵌入的可视化浏览器中使用您自己的 JSON 输出。

有关 CommDebugMode 的更多详细信息,请参阅 comm_mode_features_example.py

文档

查看 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源