ParameterDict¶
- class torch.nn.ParameterDict(parameters=None)[source][source]¶
在字典中保存参数。
ParameterDict 可以像常规 Python 字典一样进行索引,但它包含的 Parameters 会被正确注册,并且对所有 Module 方法可见。其他对象将被视为常规 Python 字典中的对象。
ParameterDict
是一个有序字典。使用其他无序映射类型(例如,Python 的普通dict
)的update()
不会保留合并映射的顺序。另一方面,OrderedDict
或另一个ParameterDict
将保留其顺序。请注意,构造函数、分配字典的元素以及
update()
方法会将任何Tensor
转换为Parameter
。- 参数
values (iterable, optional) – (字符串 : 任意类型) 的映射(字典)或 (字符串, 任意类型) 类型的键值对的可迭代对象
示例
class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.params = nn.ParameterDict({ 'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10)) }) def forward(self, x, choice): x = self.params[choice].mm(x) return x
- copy()[source][source]¶
返回此
ParameterDict
实例的副本。- 返回类型
- fromkeys(keys, default=None)[source][source]¶
返回使用提供的键的新 ParameterDict。
- 参数
keys (iterable, string) – 用于创建新 ParameterDict 的键
default (Parameter, optional) – 为所有键设置的值
- 返回类型
- get(key, default=None)[source][source]¶
如果存在与键关联的参数,则返回该参数。否则,如果提供了 default,则返回 default,如果未提供,则返回 None。
- setdefault(key, default=None)[source][source]¶
在 Parameterdict 中为键设置默认值。
如果键在 ParameterDict 中,则返回其值。如果不在,则插入带有参数 default 的 key 并返回 default。default 默认为 None。
- update(parameters)[source][source]¶
使用来自
parameters
的键值对更新ParameterDict
,覆盖现有键。注意
如果
parameters
是OrderedDict
、ParameterDict
或键值对的可迭代对象,则保留其中新元素的顺序。- 参数
parameters (iterable) – 从字符串到
Parameter
的映射(字典),或 (字符串,Parameter
) 类型的键值对的可迭代对象