快捷方式

torch.optim.Optimizer.state_dict

Optimizer.state_dict()[源代码][源代码]

dict 的形式返回优化器的状态。

它包含两个条目:

  • state: 一个 Dict,包含当前的优化状态。其内容因优化器类

    而异,但有一些共同的特点。例如,状态是按参数保存的,参数本身则不保存。state 是一个将参数 ID 映射到包含对应每个参数状态的 Dict 的字典。

  • param_groups: 一个 List,包含所有参数组,其中每个

    参数组都是一个 Dict。每个参数组包含优化器特有的元数据,例如学习率和权重衰减,以及组中参数的参数 ID 列表。如果一个参数组是使用 named_parameters() 初始化的,则名称内容也会保存在状态字典中。

注意:参数 ID 可能看起来像索引,但它们仅仅是将状态与 param_group 相关联的 ID。从 state_dict 加载时,优化器将打包 param_group 的 params (整数 ID) 和优化器的 param_groups (实际的 nn.Parameter s),以便匹配状态,而无需进行额外验证。

返回的状态字典可能看起来像:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
            'param_names' ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
返回类型

dict[str, Any]

文档

查看 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

寻找开发资源并获得问题解答

查看资源