快捷方式

torch.optim.Optimizer.state_dict

Optimizer.state_dict()[源代码][源代码]

返回优化器的状态,类型为 dict

它包含两个条目

  • state: 一个 Dict,保存当前的优化状态。其内容

    因优化器类而异,但有一些共同的特征。例如,状态是按参数保存的,并且参数本身不被保存。state 是一个字典,将参数 ID 映射到与每个参数对应的状态的字典。

  • param_groups: 一个列表,包含所有参数组,其中每个

    参数组是一个 Dict。每个参数组包含特定于优化器的元数据,例如学习率和权重衰减,以及组中参数的参数 ID 列表。如果参数组使用 named_parameters() 初始化,则名称内容也将保存在状态字典中。

注意:参数 ID 可能看起来像索引,但它们只是将状态与 param_group 关联的 ID。当从 state_dict 加载时,优化器将压缩 param_group 的 params (int ID) 和优化器的 param_groups (实际的 nn.Parameter s) 以匹配状态,而无需额外的验证。

返回的状态字典可能如下所示

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
            'param_names' ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
返回类型

Dict[str, Any]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源