tensordict.nn.dispatch¶
- class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)¶
允许使用 kwargs 调用期望 TensorDict 的函数。
dispatch()
必须在具有in_keys
(或由source
关键字参数指示的另一个键源)和out_keys
(或另一个dest
键列表)属性的模块中使用,这些属性指示要从 tensordict 中读取和写入哪些键。包装后的函数也应该有一个tensordict
作为第一个参数。结果函数将返回单个张量(如果
out_keys
中只有一个元素),否则它将返回一个元组,其排序方式与模块的out_keys
相同。dispatch()
可以用作方法或类,当需要传递额外参数时。- 参数:
separator (str, 可选) – 用于将
in_keys
中为字符串元组的子键组合在一起的分隔符。默认为"_"
。source (str 或 键列表, 可选) – 如果提供字符串,则它指向包含要使用的输入键列表的模块属性。如果改为提供列表,则它将包含用作模块输入的键。默认为
"in_keys"
,它是TensorDictModule
输入键列表的属性名称。dest (str 或 键列表, 可选) – 如果提供字符串,则它指向包含要使用的输出键列表的模块属性。如果改为提供列表,则它将包含用作模块输出的键。默认为
"out_keys"
,它是TensorDictModule
输出键列表的属性名称。auto_batch_size (bool, 可选) – 如果为
True
,则输入 tensordict 的批大小将自动确定为所有输入张量中公共维度的最大数量。默认为True
。
示例
>>> class MyModule(nn.Module): ... in_keys = ["a"] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all() >>> # equivalently >>> class MyModule(nn.Module): ... keys_in = ["a"] ... keys_out = ["b"] ... ... @dispatch(source="keys_in", dest="keys_out") ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all() >>> # or this >>> class MyModule(nn.Module): ... @dispatch(source=["a"], dest=["b"]) ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all()
dispatch_kwargs()
也适用于使用默认"_"
分隔符的嵌套键。示例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c")] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + 1 ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(a_c=torch.zeros(1, 2)) >>> assert (b == 1).all()
如果需要其他分隔符,可以在构造函数中使用
separator
参数指定。示例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c")] ... out_keys = ["b"] ... ... @dispatch(separator="sep") ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + 1 ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(asepc=torch.zeros(1, 2)) >>> assert (b == 1).all()
由于输入键是字符串的有序序列,因此
dispatch()
也可以与无名参数一起使用,其中顺序必须与输入键的顺序匹配。注意
如果第一个参数是
TensorDictBase
实例,则假定未使用 dispatch,并且此 tensordict 包含运行模块所需的所有信息。换句话说,无法使用模块输入的第一个键指向 tensordict 实例来分解 tensordict。通常,最好将dispatch()
与仅 tensordict 叶子节点一起使用。示例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c"), "d"] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + tensordict["d"] ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2)) # works >>> assert (b == 1).all() >>> b, = module(torch.zeros(1, 2), torch.ones(1, 2)) # works >>> assert (b == 1).all() >>> try: ... b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2)) # fails ... except: ... print("oopsy!") ...