快捷方式

tensordict.nn.dispatch

class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)

允许使用 kwargs 调用期望 TensorDict 的函数。

dispatch() 必须在具有 in_keys(或由 source 关键字参数指示的另一个键源)和 out_keys(或另一个 dest 键列表)属性的模块中使用,这些属性指示要从 tensordict 中读取和写入哪些键。包装后的函数也应该有一个 tensordict 作为第一个参数。

结果函数将返回单个张量(如果 out_keys 中只有一个元素),否则它将返回一个元组,其排序方式与模块的 out_keys 相同。

dispatch() 可以用作方法或类,当需要传递额外参数时。

参数:
  • separator (str, 可选) – 用于将 in_keys 中为字符串元组的子键组合在一起的分隔符。默认为 "_"

  • source (str键列表, 可选) – 如果提供字符串,则它指向包含要使用的输入键列表的模块属性。如果改为提供列表,则它将包含用作模块输入的键。默认为 "in_keys",它是 TensorDictModule 输入键列表的属性名称。

  • dest (str键列表, 可选) – 如果提供字符串,则它指向包含要使用的输出键列表的模块属性。如果改为提供列表,则它将包含用作模块输出的键。默认为 "out_keys",它是 TensorDictModule 输出键列表的属性名称。

  • auto_batch_size (bool, 可选) – 如果为 True,则输入 tensordict 的批大小将自动确定为所有输入张量中公共维度的最大数量。默认为 True

示例

>>> class MyModule(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # equivalently
>>> class MyModule(nn.Module):
...     keys_in = ["a"]
...     keys_out = ["b"]
...
...     @dispatch(source="keys_in", dest="keys_out")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # or this
>>> class MyModule(nn.Module):
...     @dispatch(source=["a"], dest=["b"])
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()

dispatch_kwargs() 也适用于使用默认 "_" 分隔符的嵌套键。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(a_c=torch.zeros(1, 2))
>>> assert (b == 1).all()

如果需要其他分隔符,可以在构造函数中使用 separator 参数指定。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch(separator="sep")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(asepc=torch.zeros(1, 2))
>>> assert (b == 1).all()

由于输入键是字符串的有序序列,因此 dispatch() 也可以与无名参数一起使用,其中顺序必须与输入键的顺序匹配。

注意

如果第一个参数是 TensorDictBase 实例,则假定未使用 dispatch,并且此 tensordict 包含运行模块所需的所有信息。换句话说,无法使用模块输入的第一个键指向 tensordict 实例来分解 tensordict。通常,最好将 dispatch() 与仅 tensordict 叶子节点一起使用。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c"), "d"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + tensordict["d"]
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> b, = module(torch.zeros(1, 2), torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> try:
...     b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2))  # fails
... except:
...     print("oopsy!")
...

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源