isin¶
- class tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)¶
测试 input 中
key
在dim
维度上的每个元素是否存在于 reference 中。此函数返回一个布尔张量,其长度为
input.batch_size[dim]
,对于key
入口中也存在于reference
中的元素,其值为True
。此函数假定input
和reference
具有相同的批处理大小并包含指定的入口,否则将引发错误。- 参数:
input (TensorDictBase) – 输入 TensorDict。
reference (TensorDictBase) – 用于测试的目标 TensorDict。
key (Nestedkey) – 要测试的键。
dim (int, 可选) – 要测试的维度。默认为
0
。
- 返回值:
- 一个布尔张量,其长度为
input.batch_size[dim]
,对于位于 的
input
key
张量中且也存在于reference
中的元素,其值为True
。
- 一个布尔张量,其长度为
- 返回类型:
out (Tensor)
示例
>>> td = TensorDict( ... { ... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]), ... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]), ... }, ... batch_size=[4], ... ) >>> td_ref = TensorDict( ... { ... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]), ... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]), ... }, ... batch_size=[3], ... ) >>> in_reference = isin(td, td_ref, key="tensor1") >>> expected_in_reference = torch.tensor([True, True, True, False]) >>> torch.testing.assert_close(in_reference, expected_in_reference)