快捷方式

isin

class tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)

测试输入 key 中每个元素在 input dim 中是否也存在于参考中。

此函数返回一个长度为 input.batch_size[dim] 的布尔张量,对于在 reference 中也存在的条目 key 中的元素,该张量为 True。此函数假设 inputreference 具有相同的批次大小并包含指定的条目,否则将引发错误。

参数:
  • input (TensorDictBase) – 输入 TensorDict。

  • reference (TensorDictBase) – 要测试的目标 TensorDict。

  • key (Nestedkey) – 要测试的键。

  • dim (int, 可选) – 要测试的维度。默认为 0

返回值:

一个长度为 input.batch_size[dim] 的布尔张量,对于

reference 中也存在的 input key 张量中的元素,该张量为 True

返回类型:

out (Tensor)

示例

>>> td = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
...     },
...     batch_size=[4],
... )
>>> td_ref = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
...     },
...     batch_size=[3],
... )
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源