快捷方式

LazyTensorStorage

class torchrl.data.replay_buffers.LazyTensorStorage(max_size: int, *, device: device = 'cpu', ndim: int = 1, compilable: bool = False, consolidated: bool = False)[source]

一个用于张量(tensor)和张量字典(tensordict)的预分配张量存储。

参数:

max_size (int) – 存储的大小,即缓冲区中存储的最大元素数量。

关键字参数:
  • device (torch.device, optional) – 采样张量存储和发送到的设备。默认为 torch.device("cpu")。如果传入“auto”,设备将自动从传入的第一批数据中获取。默认不启用此功能,以避免错误地将数据放置在 GPU 上,从而导致 OOM(内存不足)问题。

  • ndim (int, optional) – 测量存储大小时需要考虑的维度数量。例如,形状为 [3, 4] 的存储,如果 ndim=1,其容量为 3;如果 ndim=2,其容量为 12。默认为 1

  • compilable (bool, optional) – 存储是否可编译。如果为 True,写入器不能在多个进程之间共享。默认为 False

  • consolidated (bool, optional) – 如果为 True,存储将在首次扩展后被整合。默认为 False

示例

>>> data = TensorDict({
...     "some data": torch.randn(10, 11),
...     ("some", "nested", "data"): torch.randn(10, 11, 12),
... }, batch_size=[10, 11])
>>> storage = LazyTensorStorage(100)
>>> storage.set(range(10), data)
>>> len(storage)  # only the first dimension is considered as indexable
10
>>> storage.get(0)
TensorDict(
    fields={
        some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([11]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([11]),
    device=cpu,
    is_shared=False)
>>> storage.set(0, storage.get(0).zero_()) # zeros the data along index ``0``

此类也支持 tensorclass 数据。

示例

>>> from tensordict import tensorclass
>>> @tensorclass
... class MyClass:
...     foo: torch.Tensor
...     bar: torch.Tensor
>>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
>>> storage = LazyTensorStorage(10)
>>> storage.set(range(10), data)
>>> storage.get(0)
MyClass(
    bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
    foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([11]),
    device=cpu,
    is_shared=False)
attach(buffer: Any) None

此函数将采样器附加到此存储。

从该存储读取的缓冲区必须通过调用此方法作为附加实体包含进来。这确保了当存储中的数据发生变化时,即使存储与其他缓冲区(例如,优先级采样器)共享,组件也能感知到变化。

参数:

buffer – 从此存储读取的对象。

dump(*args, **kwargs)

dumps() 的别名。

load(*args, **kwargs)

loads() 的别名。

save(*args, **kwargs)

dumps() 的别名。

文档

查阅 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源