模型并行¶
DistributedModelParallel
是使用 TorchRec 优化进行分布式训练的主要 API。
- class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)¶
模型并行的入口点。
- 参数:
module (nn.Module) – 要包装的模块。
env (Optional[ShardingEnv]) – 包含进程组的切分环境。
device (Optional[torch.device]) – 计算设备,默认为 cpu。
plan (Optional[ShardingPlan]) – 切分时使用的计划,默认为 EmbeddingShardingPlanner.collective_plan()。
sharders (Optional[List[ModuleSharder[nn.Module]]]) – 可用于切分的 ModuleSharders,默认为 EmbeddingBagCollectionSharder()。
init_data_parallel (bool) – 数据并行模块可以是惰性的,即它们延迟参数初始化直到第一次前向传递。传递 True 以延迟数据并行模块的初始化。先进行第一次前向传递,然后调用 DistributedModelParallel.init_data_parallel()。
init_parameters (bool) – 初始化仍在元设备上的模块的参数。
data_parallel_wrapper (Optional[DataParallelWrapper]) – 数据并行模块的自定义包装器。
示例
@torch.no_grad() def init_weights(m): if isinstance(m, nn.Linear): m.weight.fill_(1.0) elif isinstance(m, EmbeddingBagCollection): for param in m.parameters(): init.kaiming_normal_(param) m = MyModel(device='meta') m = DistributedModelParallel(m) m.apply(init_weights)
- copy(device: device) DistributedModelParallel ¶
通过调用每个模块的自定义复制过程,递归地将子模块复制到新设备,因为某些模块需要使用原始引用(如用于推理的 ShardedModule)。
- forward(*args, **kwargs) Any ¶
定义每次调用时执行的计算。
应由所有子类覆盖。
注意
虽然前向传递的配方需要在此函数中定义,但之后应该调用
Module
实例而不是此函数,因为前者负责运行注册的钩子,而后者则静默地忽略它们。
- init_data_parallel() None ¶
请参阅 init_data_parallel c-tor 参数以了解用法。可以安全地多次调用此方法。
- load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys ¶
将参数和缓冲区从
state_dict
复制到此模块及其后代。如果
strict
为True
,则state_dict
的键必须与该模块的state_dict()
函数返回的键完全匹配。警告
如果
assign
为True
,则必须在调用load_state_dict
后创建优化器,除非get_swap_module_params_on_conversion()
为True
。- 参数:
state_dict (dict) – 包含参数和持久缓冲区的字典。
strict (bool, optional) – 是否严格执行
state_dict
中的键与该模块的state_dict()
函数返回的键匹配。默认值:True
assign (bool, optional) – 当
False
时,当前模块中张量的属性将被保留,而当True
时,将保留状态字典中张量的属性。唯一的例外是requires_grad
字段的Default: ``False`
- 返回值:
- missing_keys 是一个包含任何预期键的字符串列表
但此模块在提供的
state_dict
中缺少。
- unexpected_keys 是一个包含此模块未预期的键的字符串列表,
但在提供的
state_dict
中存在。
- 返回类型:
NamedTuple
,包含missing_keys
和unexpected_keys
字段
注意
如果参数或缓冲区注册为
None
,并且其对应的键存在于state_dict
中,则load_state_dict()
将引发RuntimeError
。
- property module: Module¶
属性可直接访问分片模块,该模块不会包装在 DDP、FSDP、DMP 或任何其他并行包装器中。
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
返回一个模块缓冲区的迭代器,同时生成缓冲区的名称和缓冲区本身。
- 参数:
prefix (str) – 要添加到所有缓冲区名称之前的 prefix。
recurse (bool, optional) – 如果为 True,则生成此模块和所有子模块的缓冲区。否则,仅生成作为此模块直接成员的缓冲区。默认为 True。
remove_duplicate (bool, optional) – 是否删除结果中重复的缓冲区。默认为 True。
- 产量:
(str, torch.Tensor) – 包含名称和缓冲区的元组
示例
>>> # xdoctest: +SKIP("undefined vars") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
返回一个模块参数的迭代器,同时生成参数的名称和参数本身。
- 参数:
prefix (str) – 要添加到所有参数名称之前的 prefix。
recurse (bool) – 如果为 True,则生成此模块和所有子模块的参数。否则,仅生成作为此模块直接成员的参数。
remove_duplicate (bool, optional) – 是否删除结果中重复的参数。默认为 True。
- 产量:
(str, Parameter) – 包含名称和参数的元组
示例
>>> # xdoctest: +SKIP("undefined vars") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
返回一个包含模块完整状态的引用字典。
包含参数和持久缓冲区(例如,运行平均值)。键是对应的参数和缓冲区名称。设置为
None
的参数和缓冲区不包括在内。注意
返回的对象是浅拷贝。它包含对模块的参数和缓冲区的引用。
警告
目前
state_dict()
还接受destination
、prefix
和keep_vars
的位置参数。但是,这已被弃用,并且在将来的版本中将强制使用关键字参数。警告
请避免使用参数
destination
,因为它不是为最终用户设计的。- 参数:
destination (dict, optional) – 如果提供,则模块的状态将更新到字典中,并返回相同的对象。否则,将创建一个
OrderedDict
并返回。默认值:None
。prefix (str, optional) – 添加到参数和缓冲区名称之前的 prefix,用于组合 state_dict 中的键。默认值:
''
。keep_vars (bool, optional) – 默认情况下,state_dict 中返回的
Tensor
与自动梯度分离。如果将其设置为True
,则不会执行分离。默认值:False
。
- 返回值:
包含模块完整状态的字典
- 返回类型:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']