快捷方式

RewardData

torchrl.data.RewardData(input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', rewards: 'Optional[torch.Tensor]' = None, end_scores: 'Optional[torch.Tensor]' = None, *, batch_size, device=None, names=None)[源代码]
属性 device: device

检索 tensorclass 的设备类型。

dumps(prefix: str | None = None, copy_existing: bool =False, *, num_threads: int =0, return_early: bool =False, share_non_tensor: bool =False) T

将 tensordict 保存到磁盘。

此函数是 memmap() 的代理。

类方法 fields()

返回描述此数据类的字段的元组。

接受数据类或其一个实例。元组元素类型为 Field。

类方法 from_tensordict(tensordict, non_tensordict=None, safe=True)

用于实例化新的 tensor class 对象的 tensor class 包装器。

参数:
  • tensordict (TensorDict) – 张量类型的字典

  • non_tensordict (dict) – 包含非张量和嵌套 tensor class 对象的字典

get(key: NestedKey, *args, **kwargs)

获取与输入键关联的值。

参数:
  • key (str, tuple of str) – 要查询的键。如果是字符串元组,则相当于链式调用 getattr。

  • default – 如果在 tensorclass 中找不到键,则返回默认值。

返回值:

与输入键关联的值

类方法 load(prefix: str | Path, *args, **kwargs) T

从磁盘加载 tensordict。

此类方法是 load_memmap() 的代理。

load_(prefix: str | Path, *args, **kwargs)

在当前 tensordict 中从磁盘加载 tensordict。

此方法是 load_memmap_() 的代理。

类方法 load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool =False, *, out: TensorDictBase | None = None) T

从磁盘加载内存映射 tensordict。

参数:
  • prefix (str文件夹路径) – 应从中获取已保存 tensordict 的文件夹路径。

  • device (torch.device等效类型, 可选) – 如果提供,数据将异步转换到该设备。支持 “meta” 设备,在此情况下,数据不会被加载,而是创建一组空的“meta”张量。这有助于了解模型整体大小和结构,而无需实际打开任何文件。

  • non_blocking (bool, 可选) – 如果为 True,则在将张量加载到设备后不会调用 synchronize。默认为 False

  • out (TensorDictBase, 可选) – 可选的 tensordict,数据应写入其中。

示例

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

此方法还允许加载嵌套的 tensordict。

示例

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

tensordict 也可以加载到“meta”设备上,或者作为伪张量加载。

示例

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)

尝试就地加载 state_dict 到目标 tensorclass。

memmap(prefix: str | None = None, copy_existing: bool =False, *, num_threads: int =0, return_early: bool =False, share_non_tensor: bool =False, existsok: bool =True) T

将所有张量写入新 tensordict 中的对应内存映射张量。

参数:
  • prefix (str) – 内存映射张量将被存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果 False(默认),则如果 tensordict 中的某个条目已是存储在磁盘上且有关联文件的张量,但未按照 prefix 保存到正确位置,则将引发异常。如果为 True,则任何现有张量将被复制到新位置。

关键字参数:
  • num_threads (int, 可选) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, 可选) – 如果为 Truenum_threads>0,此方法将返回 tensordict 的 Future。

  • share_non_tensor (bool, 可选) – 如果为 True,则非张量数据将在进程之间共享,并且在单节点内任何 worker 上的写入操作(例如就地更新或 set)将更新所有其他 worker 上的值。如果非张量叶子数量很高(例如,共享大量非张量数据堆栈),这可能导致 OOM 或类似错误。默认为 False

  • existsok (bool, 可选) – 如果为 False,则如果同一路径下已存在张量,将引发异常。默认为 True

然后 TensorDict 将被锁定,这意味着任何非就地写入操作(例如,重命名、设置或移除条目)都将抛出异常。一旦 tensordict 解锁,内存映射属性将变为 False,因为不再保证跨进程的身份一致性。

返回值:

如果 return_early=False,则返回一个张量存储在磁盘上的新 tensordict,否则返回一个 TensorDictFuture 实例。

注意

以这种方式序列化可能对深度嵌套的 tensordict 较慢,因此不建议在训练循环内部调用此方法。

memmap_(prefix: str | None = None, copy_existing: bool =False, *, num_threads: int =0, return_early: bool =False, share_non_tensor: bool =False, existsok: bool =True) T

将所有张量就地写入对应内存映射张量。

参数:
  • prefix (str) – 内存映射张量将被存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果 False(默认),则如果 tensordict 中的某个条目已是存储在磁盘上且有关联文件的张量,但未按照 prefix 保存到正确位置,则将引发异常。如果为 True,则任何现有张量将被复制到新位置。

关键字参数:
  • num_threads (int, 可选) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, 可选) – 如果为 Truenum_threads>0,此方法将返回 tensordict 的 Future。可以使用 future.result() 查询结果 tensordict。

  • share_non_tensor (bool, 可选) – 如果为 True,则非张量数据将在进程之间共享,并且在单节点内任何 worker 上的写入操作(例如就地更新或 set)将更新所有其他 worker 上的值。如果非张量叶子数量很高(例如,共享大量非张量数据堆栈),这可能导致 OOM 或类似错误。默认为 False

  • existsok (bool, 可选) – 如果为 False,则如果同一路径下已存在张量,将引发异常。默认为 True

然后 TensorDict 将被锁定,这意味着任何非就地写入操作(例如,重命名、设置或移除条目)都将抛出异常。一旦 tensordict 解锁,内存映射属性将变为 False,因为不再保证跨进程的身份一致性。

返回值:

如果 return_early=False,则返回自身,否则返回一个 TensorDictFuture 实例。

注意

以这种方式序列化可能对深度嵌套的 tensordict 较慢,因此不建议在训练循环内部调用此方法。

memmap_like(prefix: str | None = None, copy_existing: bool =False, *, existsok: bool =True, num_threads: int =0, return_early: bool =False, share_non_tensor: bool =False) T

创建一个与原始 tensordict 形状相同但不含内容的内存映射 tensordict。

参数:
  • prefix (str) – 内存映射张量将被存储的目录前缀。目录树结构将模仿 tensordict 的结构。

  • copy_existing (bool) – 如果 False(默认),则如果 tensordict 中的某个条目已是存储在磁盘上且有关联文件的张量,但未按照 prefix 保存到正确位置,则将引发异常。如果为 True,则任何现有张量将被复制到新位置。

关键字参数:
  • num_threads (int, 可选) – 用于写入 memmap 张量的线程数。默认为 0

  • return_early (bool, 可选) – 如果为 Truenum_threads>0,此方法将返回 tensordict 的 Future。

  • share_non_tensor (bool, 可选) – 如果为 True,则非张量数据将在进程之间共享,并且在单节点内任何 worker 上的写入操作(例如就地更新或 set)将更新所有其他 worker 上的值。如果非张量叶子数量很高(例如,共享大量非张量数据堆栈),这可能导致 OOM 或类似错误。默认为 False

  • existsok (bool, 可选) – 如果为 False,则如果同一路径下已存在张量,将引发异常。默认为 True

然后 TensorDict 将被锁定,这意味着任何非就地写入操作(例如,重命名、设置或移除条目)都将抛出异常。一旦 tensordict 解锁,内存映射属性将变为 False,因为不再保证跨进程的身份一致性。

返回值:

如果 return_early=False,则返回一个数据存储为内存映射张量的新 TensorDict 实例,否则返回一个 TensorDictFuture 实例。

注意

这是将一组大型缓冲区写入磁盘的推荐方法,因为 memmap_() 会复制信息,这对于大内容来说可能很慢。

示例

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
memmap_refresh_()

如果内存映射 tensordict 具有 saved_path,则刷新其内容。

如果没有与其关联的路径,此方法将引发异常。

save(prefix: str | None = None, copy_existing: bool =False, *, num_threads: int =0, return_early: bool =False, share_non_tensor: bool =False) T

将 tensordict 保存到磁盘。

此函数是 memmap() 的代理。

set(key: NestedKey, value: Any, inplace: bool =False, non_blocking: bool =False)

设置新的键值对。

参数:
  • key (str, 字符串元组) – 要设置的键的名称。如果是字符串元组,则相当于链式调用 getattr,后跟最终的 setattr。

  • value (Any) – 要存储在 tensorclass 中的值

  • inplace (bool, 可选) – 如果为 True,set 将尝试就地更新值。如果为 False 或者键不存在,值将简单地写入其目标位置。

返回值:

自身

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]

返回一个 state_dict 字典,可用于保存和从 tensorclass 加载数据。

to_tensordict(*, retain_none: bool | None =None) TensorDict

将 tensorclass 转换为常规 TensorDict。

复制所有条目。内存映射张量和共享内存张量将转换为常规张量。

参数:

retain_none (bool) –

如果为 True,则 None 值将被写入 tensordict。否则将被丢弃。默认值:True

注意

从 v0.8 开始,默认值将切换为 False

返回值:

包含与 tensorclass 相同值的新 TensorDict 对象。

unbind(dim: int)

返回一个沿指定维度解绑的索引 tensorclass 实例元组。

结果 tensorclass 实例将共享初始 tensorclass 实例的存储。

文档

查阅 PyTorch 完整的开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源