快捷方式

terminated_or_truncated

torchrl.envs.utils.terminated_or_truncated(data: TensorDictBase, full_done_spec: TensorSpec | None = None, key: str = '_reset', write_full_false: bool = False) bool[源代码]

读取 tensordict 中的 done / terminated / truncated 键,并写入一个新张量,其中聚合了两个信号的值。

修改发生在提供的 TensorDict 实例中,是原地操作。此函数可用于计算批处理或多智能体设置中的 “_reset” 信号,因此输出键的默认名称如此。

参数:
  • data (TensorDictBase) – 输入数据,通常是调用 step() 的结果。

  • full_done_spec (TensorSpec, optional) – 环境的 done_spec,指示应在何处找到 done 叶子。如果未提供,则将在数据中搜索默认的 "done""terminated""truncated" 条目。

  • key (NestedKey, optional) –

    应写入聚合结果的位置。如果为 None,则函数不会写入任何键,而只会输出是否有任何 done 值为 true。 .. note:: 如果 key 条目已存在值,

    将保留之前的值,不会进行更新。

  • write_full_false (bool, optional) – 如果为 True,即使输出为 False(即,在提供的数据结构中没有 done 为 True),也会写入 reset 键。默认为 False

返回: 一个布尔值,指示数据中找到的任何 done 状态是否

包含 True

示例

>>> from torchrl.data.tensor_specs import Categorical
>>> from tensordict import TensorDict
>>> spec = Composite(
...     done=Categorical(2, dtype=torch.bool),
...     truncated=Categorical(2, dtype=torch.bool),
...     nested=Composite(
...         done=Categorical(2, dtype=torch.bool),
...         truncated=Categorical(2, dtype=torch.bool),
...     )
... )
>>> data = TensorDict({
...     "done": True, "truncated": False,
...     "nested": {"done": False, "truncated": True}},
...     batch_size=[]
... )
>>> data = _terminated_or_truncated(data, spec)
>>> print(data["_reset"])
tensor(True)
>>> print(data["nested", "_reset"])
tensor(True)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源