is_tensor_collection¶
- class tensordict.is_tensor_collection(datatype: type | Any)¶
检查数据对象或类型是否来自 tensordict 库的张量容器。
- 返回值:
如果输入是 TensorDictBase 子类、tensorclass 或这些类的实例,则返回
True
。否则返回False
。
示例
>>> is_tensor_collection(TensorDictBase) # True >>> is_tensor_collection(TensorDict({}, [])) # True >>> @tensorclass ... class MyClass: ... pass ... >>> is_tensor_collection(MyClass) # True >>> is_tensor_collection(MyClass(batch_size=[])) # True