快捷方式

is_tensor_collection

class tensordict.is_tensor_collection(datatype: type | Any)

检查数据对象或类型是否来自 tensordict 库的张量容器。

返回值:

如果输入是 TensorDictBase 子类、tensorclass 或这些类的实例,则返回 True。否则返回 False

示例

>>> is_tensor_collection(TensorDictBase)  # True
>>> is_tensor_collection(TensorDict({}, []))  # True
>>> @tensorclass
... class MyClass:
...     pass
...
>>> is_tensor_collection(MyClass)  # True
>>> is_tensor_collection(MyClass(batch_size=[]))  # True

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源