tensordict.nn.set_skip_existing¶
- class tensordict.nn.set_skip_existing(mode: bool | None = True, in_key_attr='in_keys', out_key_attr='out_keys')¶
用于跳过 TensorDict 图中现有节点的上下文管理器。
当用作上下文管理器时,它会将 skip_existing() 值设置为指示的
mode
,使用户能够编写检查全局值并相应地执行代码的方法。当用作方法装饰器时,它将检查 tensordict 输入键,并且如果
skip_existing()
调用返回True
,则当所有输出键都已存在时,它将跳过该方法。 这不应被用作不遵循以下签名的方法的装饰器:def fun(self, tensordict, *args, **kwargs)
。- 参数:
mode (bool, 可选) – 如果
True
,则表示不会覆盖图中已存在的条目,除非它们仅部分存在。skip_existing()
将返回True
。 如果False
,则不执行任何检查。 如果None
,则skip_existing()
的值不会更改。 这旨在专门用于装饰方法,并允许其行为依赖于用作上下文管理器时的同一类(请参阅下面的示例)。 默认为True
。in_key_attr (str, 可选) – 模块的被装饰方法中输入键列表属性的名称。 默认为
in_keys
。out_key_attr (str, 可选) – 模块的被装饰方法中输出键列表属性的名称。 默认为
out_keys
。
示例
>>> with set_skip_existing(): ... if skip_existing(): ... print("True") ... else: ... print("False") ... True >>> print("calling from outside:", skip_existing()) calling from outside: False
此类也可以用作装饰器
示例
>>> from tensordict import TensorDict >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase >>> class MyModule(TensorDictModuleBase): ... in_keys = [] ... out_keys = ["out"] ... @set_skip_existing() ... def forward(self, tensordict): ... print("hello") ... tensordict.set("out", torch.zeros(())) ... return tensordict >>> module = MyModule() >>> module(TensorDict({"out": torch.zeros(())}, [])) # does not print anything TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> module(TensorDict()) # prints hello hello TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
每当希望让上下文管理器从外部处理跳过操作时,使用设置为
None
模式装饰方法都很有用示例
>>> from tensordict import TensorDict >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase >>> class MyModule(TensorDictModuleBase): ... in_keys = [] ... out_keys = ["out"] ... @set_skip_existing(None) ... def forward(self, tensordict): ... print("hello") ... tensordict.set("out", torch.zeros(())) ... return tensordict >>> module = MyModule() >>> _ = module(TensorDict({"out": torch.zeros(())}, [])) # prints "hello" hello >>> with set_skip_existing(True): ... _ = module(TensorDict({"out": torch.zeros(())}, [])) # no print
注意
为了允许模块具有相同的输入和输出键,并且不会错误地忽略子图,当输出键也是输入键时,
@set_skip_existing(True)
将被停用>>> class MyModule(TensorDictModuleBase): ... in_keys = ["out"] ... out_keys = ["out"] ... @set_skip_existing() ... def forward(self, tensordict): ... print("calling the method!") ... return tensordict ... >>> module = MyModule() >>> module(TensorDict({"out": torch.zeros(())}, [])) # does not print anything calling the method! TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)