terminated_or_truncated¶
- torchrl.envs.utils.terminated_or_truncated(data: TensorDictBase, full_done_spec: TensorSpec | None = None, key: str = '_reset', write_full_false: bool = False) bool [源代码]¶
读取 tensordict 中的 done / terminated / truncated 键,并写入一个新张量,其中聚合了两个信号的值。
修改发生在提供的 TensorDict 实例中,是原地操作。此函数可用于计算批处理或多智能体设置中的 “_reset” 信号,因此输出键的默认名称如此。
- 参数:
data (TensorDictBase) – 输入数据,通常是调用
step()
的结果。full_done_spec (TensorSpec, optional) – 环境的 done_spec,指示应在何处找到 done 叶子。如果未提供,则将在数据中搜索默认的
"done"
、"terminated"
和"truncated"
条目。key (NestedKey, optional) –
应写入聚合结果的位置。如果为
None
,则函数不会写入任何键,而只会输出是否有任何 done 值为 true。 .. note:: 如果key
条目已存在值,将保留之前的值,不会进行更新。
write_full_false (bool, optional) – 如果为
True
,即使输出为False
(即,在提供的数据结构中没有 done 为True
),也会写入 reset 键。默认为False
。
- 返回: 一个布尔值,指示数据中找到的任何 done 状态是否
包含
True
。
示例
>>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict >>> spec = Composite( ... done=Categorical(2, dtype=torch.bool), ... truncated=Categorical(2, dtype=torch.bool), ... nested=Composite( ... done=Categorical(2, dtype=torch.bool), ... truncated=Categorical(2, dtype=torch.bool), ... ) ... ) >>> data = TensorDict({ ... "done": True, "truncated": False, ... "nested": {"done": False, "truncated": True}}, ... batch_size=[] ... ) >>> data = _terminated_or_truncated(data, spec) >>> print(data["_reset"]) tensor(True) >>> print(data["nested", "_reset"]) tensor(True)