快捷方式

is_tensor_collection

class tensordict.is_tensor_collection(datatype: Union[type, Any])

检查数据对象或类型是否为来自 tensordict 库的张量容器。

返回:

如果输入是 TensorDictBase 子类、tensorclass 或它们的实例,则返回 True。否则返回 False

示例

>>> is_tensor_collection(TensorDictBase)  # True
>>> is_tensor_collection(TensorDict())  # True
>>> @tensorclass
... class MyClass:
...     pass
...
>>> is_tensor_collection(MyClass)  # True
>>> is_tensor_collection(MyClass(batch_size=[]))  # True

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源