is_tensor_collection¶
- class tensordict.is_tensor_collection(datatype: Union[type, Any])¶
检查数据对象或类型是否为来自 tensordict 库的张量容器。
- 返回:
如果输入是 TensorDictBase 子类、tensorclass 或它们的实例,则返回
True
。否则返回False
。
示例
>>> is_tensor_collection(TensorDictBase) # True >>> is_tensor_collection(TensorDict()) # True >>> @tensorclass ... class MyClass: ... pass ... >>> is_tensor_collection(MyClass) # True >>> is_tensor_collection(MyClass(batch_size=[])) # True