• 教程 >
  • 探索 TorchRec 分片
快捷方式

探索 TorchRec 分片

创建时间:2022 年 5 月 10 日 | 最后更新:2022 年 5 月 13 日 | 最后验证:2024 年 11 月 5 日

本教程主要介绍如何通过 EmbeddingPlannerDistributedModelParallel API 实现嵌入表的分片,并通过显式配置不同的分片方案来探讨它们对于嵌入表的益处。

安装

要求: - python >= 3.7

强烈推荐使用 CUDA 来使用 torchRec。如果使用 CUDA: - cuda >= 11.0

# install conda to make installying pytorch with cudatoolkit 11.3 easier.
!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local
# install pytorch with cudatoolkit 11.3
!sudo conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y

安装 torchRec 也会安装 FBGEMM,这是一个包含 CUDA 内核和 GPU 启用操作的集合

# install torchrec
!pip3 install torchrec-nightly

安装 multiprocess 以与 ipython 配合在 colab 中进行多进程编程

!pip3 install multiprocess

对于 Colab 运行时检测添加的共享库,需要执行以下步骤。运行时会在 /usr/lib 中搜索共享库,因此我们将安装在 /usr/local/lib/ 中的库复制过去。这是非常必要的步骤,仅在 colab 运行时中需要

!sudo cp /usr/local/lib/lib* /usr/lib/

此时重启运行时,以便识别新安装的包。重启后立即运行下面的步骤,以便 python 知道去哪里查找包。重启运行时后务必运行此步骤。

import sys
sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages', './.local/lib/python3.7/site-packages']

分布式设置

由于是 Notebook 环境,我们无法在这里运行 SPMD 程序,但我们可以在 Notebook 内部进行多进程处理来模拟设置。在使用 Torchrec 时,用户应负责设置自己的 SPMD 启动器。我们设置环境,以便基于 torch distributed 的通信后端可以工作。

import os
import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

构建我们的嵌入模型

这里我们使用 TorchRec 提供的 EmbeddingBagCollection 来构建包含嵌入表的嵌入包模型。

在这里,我们创建一个包含四个嵌入包的 EmbeddingBagCollection (EBC)。我们有两种类型的表:大表和小表,它们通过行大小差异来区分:4096 vs 1024。每个表仍然由 64 维嵌入表示。

我们为表配置了 ParameterConstraints 数据结构,它为模型并行 API 提供了提示,以帮助决定表的切分和放置策略。在 TorchRec 中,我们支持: * table-wise:将整个表放置在一个设备上; * row-wise:按行维度均匀地切分表,并将一个分片放置在通信世界的每个设备上; * column-wise:按嵌入维度均匀地切分表,并将一个分片放置在通信世界的每个设备上; * table-row-wise:为可用的快速机器内部设备互连(例如 NVLink)优化的特殊分片; * data_parallel:为每个设备复制表;

注意我们最初是如何在设备“meta”上分配 EBC 的。这会告诉 EBC 暂不分配内存。

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict

large_table_cnt = 2
small_table_cnt = 2
large_tables=[
  torchrec.EmbeddingBagConfig(
    name="large_table_" + str(i),
    embedding_dim=64,
    num_embeddings=4096,
    feature_names=["large_table_feature_" + str(i)],
    pooling=torchrec.PoolingType.SUM,
  ) for i in range(large_table_cnt)
]
small_tables=[
  torchrec.EmbeddingBagConfig(
    name="small_table_" + str(i),
    embedding_dim=64,
    num_embeddings=1024,
    feature_names=["small_table_feature_" + str(i)],
    pooling=torchrec.PoolingType.SUM,
  ) for i in range(small_table_cnt)
]

def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
  large_table_constraints = {
    "large_table_" + str(i): ParameterConstraints(
      sharding_types=[sharding_type.value],
    ) for i in range(large_table_cnt)
  }
  small_table_constraints = {
    "small_table_" + str(i): ParameterConstraints(
      sharding_types=[sharding_type.value],
    ) for i in range(small_table_cnt)
  }
  constraints = {**large_table_constraints, **small_table_constraints}
  return constraints
ebc = torchrec.EmbeddingBagCollection(
    device="cuda",
    tables=large_tables + small_tables
)

多进程中的 DistributedModelParallel

现在,我们有一个单进程执行函数,用于模拟 SPMD 执行期间一个 Rank 的工作。

此代码将与其他进程协同切分模型并相应地分配内存。它首先设置进程组,然后使用规划器进行嵌入表放置,并使用 DistributedModelParallel 生成切分模型。

def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> None:
    import os
    import torch
    import torch.distributed as dist
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv
    from typing import cast

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device="cuda")
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model

多进程执行

现在让我们在代表多个 GPU Rank 的多进程中执行代码。

import multiprocess

def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size = 2,
):
  ctx = multiprocess.get_context("spawn")
  processes = []
  for rank in range(world_size):
      p = ctx.Process(
          target=single_rank_execution,
          args=(
              rank,
              world_size,
              gen_constraints(sharding_type),
              ebc,
              "nccl"
          ),
      )
      p.start()
      processes.append(p)

  for p in processes:
      p.join()
      assert 0 == p.exitcode

表格维度分片

现在让我们在两个进程中执行代码,对应于 2 个 GPU。我们可以在规划输出中看到我们的表是如何跨 GPU 进行分片的。每个节点将有一个大表和一个小表,这表明我们的规划器试图在嵌入表之间进行负载均衡。表格维度分片对于许多中小型表来说是实现设备间负载均衡的事实上的首选分片方案。

spmd_sharing_simulation(ShardingType.TABLE_WISE)
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}

探索其他分片模式

我们初步探讨了表格维度分片的样子以及它如何平衡表的放置。现在我们探讨更侧重于负载均衡的分片模式:行维度分片 (row-wise)。行维度分片专门针对由于嵌入行数较大而单个设备无法容纳的大表。它可以解决模型中超大表的放置问题。用户可以在输出计划日志的 shard_sizes 部分看到,这些表按行维度减半后分布到两个 GPU 上。

spmd_sharing_simulation(ShardingType.ROW_WISE)
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}

另一方面,列维度分片 (column-wise) 解决了嵌入维度较大的表的负载不均衡问题。我们将垂直切分表。用户可以在输出计划日志的 shard_sizes 部分看到,这些表按嵌入维度减半后分布到两个 GPU 上。

spmd_sharing_simulation(ShardingType.COLUMN_WISE)
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}

对于 table-row-wise 分片,不幸的是我们无法模拟它,因为它需要在多主机设置下运行。未来我们将提供一个 Python SPMD 示例来演示如何使用 table-row-wise 训练模型。

使用数据并行,我们将为所有设备复制表。

spmd_sharing_simulation(ShardingType.DATA_PARALLEL)
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源