OptimizerInBackwardWrapper
- class torchtune.training.OptimizerInBackwardWrapper(optim_map: Dict[str, Optimizer])[source]
A bare-bones class for saving and loading checkpoints for optimizers that run in the backward pass. Usage is limited to the following:
Note
This wrapper is only meant for single-device use cases. Distributed use cases such as FSDP, which require specialized optimizer state checkpointing, are not supported.
- Parameters:
optim_map (Dict[str, torch.optim.Optimizer]) – Mapping from parameter names to optimizers.
Example
>>> optim_dict = {
>>>     p: config.instantiate(cfg_optimizer, [p])
>>>     for p in self._model.parameters()
>>> }
>>>
>>> # Save checkpoint
>>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict()
>>> torch.save(ckpt, "/tmp/optim_ckpt")
>>>
>>> # Load checkpoint
>>> placeholder_optim_dict = {
>>>     p: config.instantiate(cfg_optimizer, [p])
>>>     for p in self._model.parameters()
>>> }
>>>
>>> wrapper = OptimizerInBackwardWrapper(placeholder_optim_dict)
>>>
>>> # load_state_dict expects a dict produced by this class's
>>> # state_dict method.
>>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt"))
>>> # placeholder_optim_dict now has updated optimizer states.
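The example above only shows save and load; it assumes each per-parameter optimizer is stepped during the backward pass. A minimal sketch of that setup, using PyTorch's Tensor.register_post_accumulate_grad_hook (the model and learning rate here are hypothetical, not part of this API):
>>> import torch
>>> from torch import nn
>>> from torchtune.training import OptimizerInBackwardWrapper
>>>
>>> model = nn.Linear(16, 16)  # hypothetical model
>>>
>>> # One optimizer per parameter, keyed by parameter name to match
>>> # the Dict[str, Optimizer] type of optim_map.
>>> optim_dict = {
>>>     n: torch.optim.AdamW([p], lr=1e-4)
>>>     for n, p in model.named_parameters()
>>> }
>>>
>>> # Step (and zero) each parameter's optimizer as soon as that
>>> # parameter's gradient is accumulated, instead of after the
>>> # full backward pass.
>>> def make_hook(name):
>>>     def hook(param):
>>>         optim_dict[name].step()
>>>         optim_dict[name].zero_grad()
>>>     return hook
>>>
>>> for n, p in model.named_parameters():
>>>     p.register_post_accumulate_grad_hook(make_hook(n))
>>>
>>> wrapper = OptimizerInBackwardWrapper(optim_dict)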
- get_optim_key(key: str) → Any[source]
Returns the value of key from an arbitrary optimizer running in backward. Note that this assumes all optimizers in backward have the same value for that key, i.e. that they were initialized with the same hyperparameters.
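For instance, reading a shared hyperparameter (a hypothetical usage sketch, assuming wrapper was built as in the sketch above with every optimizer sharing the same learning rate):
>>> # Every per-parameter optimizer was created with lr=1e-4, so
>>> # reading the value from any one of them is safe.
>>> wrapper.get_optim_key("lr")
0.0001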