快捷方式

tensordict.nn.set_skip_existing

class tensordict.nn.set_skip_existing(mode: bool | None = True, in_key_attr='in_keys', out_key_attr='out_keys')

用于跳过 TensorDict 图中现有节点的上下文管理器。

当用作上下文管理器时,它会将 skip_existing() 值设置为指示的 mode,使用户能够编写检查全局值并相应地执行代码的方法。

当用作方法装饰器时,它将检查 tensordict 输入键,并且如果 skip_existing() 调用返回 True,则当所有输出键都已存在时,它将跳过该方法。 这不应被用作不遵循以下签名的方法的装饰器:def fun(self, tensordict, *args, **kwargs)

参数:
  • mode (bool, 可选) – 如果 True,则表示不会覆盖图中已存在的条目,除非它们仅部分存在。 skip_existing() 将返回 True。 如果 False,则不执行任何检查。 如果 None,则 skip_existing() 的值不会更改。 这旨在专门用于装饰方法,并允许其行为依赖于用作上下文管理器时的同一类(请参阅下面的示例)。 默认为 True

  • in_key_attr (str, 可选) – 模块的被装饰方法中输入键列表属性的名称。 默认为 in_keys

  • out_key_attr (str, 可选) – 模块的被装饰方法中输出键列表属性的名称。 默认为 out_keys

示例

>>> with set_skip_existing():
...     if skip_existing():
...         print("True")
...     else:
...         print("False")
...
True
>>> print("calling from outside:", skip_existing())
calling from outside: False

此类也可以用作装饰器

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
>>> class MyModule(TensorDictModuleBase):
...     in_keys = []
...     out_keys = ["out"]
...     @set_skip_existing()
...     def forward(self, tensordict):
...         print("hello")
...         tensordict.set("out", torch.zeros(()))
...         return tensordict
>>> module = MyModule()
>>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> module(TensorDict())  # prints hello
hello
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

每当希望让上下文管理器从外部处理跳过操作时,使用设置为 None 模式装饰方法都很有用

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
>>> class MyModule(TensorDictModuleBase):
...     in_keys = []
...     out_keys = ["out"]
...     @set_skip_existing(None)
...     def forward(self, tensordict):
...         print("hello")
...         tensordict.set("out", torch.zeros(()))
...         return tensordict
>>> module = MyModule()
>>> _ = module(TensorDict({"out": torch.zeros(())}, []))  # prints "hello"
hello
>>> with set_skip_existing(True):
...     _ = module(TensorDict({"out": torch.zeros(())}, []))  # no print

注意

为了允许模块具有相同的输入和输出键,并且不会错误地忽略子图,当输出键也是输入键时,@set_skip_existing(True) 将被停用

>>> class MyModule(TensorDictModuleBase):
...     in_keys = ["out"]
...     out_keys = ["out"]
...     @set_skip_existing()
...     def forward(self, tensordict):
...         print("calling the method!")
...         return tensordict
...
>>> module = MyModule()
>>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
calling the method!
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源