快捷方式

OptimizerInBackwardWrapper

class torchtune.training.OptimizerInBackwardWrapper(optim_map: Dict[str, Optimizer])[source]

一个简单的类,用于在反向传播过程中保存和加载优化器的检查点。用法仅限于以下情况

注意

此包装器仅适用于单设备用例。不支持分布式用例,例如 FSDP,这些用例需要专门的优化器状态检查点。

参数:

optim_map (Dict[str, torch.optim.Optimizer]) – 从参数名称到优化器的映射。

示例

>>> optim_dict = {
>>>     p: config.instantiate(cfg_optimizer, [p])
>>>     for p in self._model.parameters()
>>> }
>>>
>>> # Save checkpoint
>>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict()
>>> torch.save("/tmp/optim_ckpt", ckpt)
>>>
>>> # Load checkpoint
>>> placeholder_optim_dict = {
>>>     p: config.instantiate(cfg_optimizer, [p])
>>>     for p in self._model.parameters()
>>> }
>>>
>>> wrapper = OptimInBackwardWrapper(placeholder_optim_dict)
>>>
>>> # load_state_dict expects a dict produced by this class's
>>> # state_dict method.
>>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt"))
>>> # placeholder_optim_dict now has updated optimizer states.
get_optim_key(key: str) Any[source]

返回在反向传播中运行的任意优化器的键值。请注意,这假设在反向传播中的所有优化器对于该键具有相同的值,即使用相同的超参数初始化。

load_state_dict(optim_ckpt_map: Dict[str, Any])[source]

从此类的 state_dict 方法生成的 state dict 加载优化器状态。

参数:

optim_ckpt_map (Dict[str, Any]) – 状态 dict,将参数名称映射到优化器状态。

引发:

RuntimeError – 如果优化器状态 dict 不包含所有预期的参数。

state_dict() Dict[str, Any][source]

返回一个状态 dict,将参数名称映射到优化器状态。此 state_dict 只能由同一类加载。

返回:

将参数名称映射到优化器状态的 state dict。

返回类型:

Dict[str, Any]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源