• 文档 >
  • 使用回放缓冲区
快捷方式

使用回放缓冲区

作者: Vincent Moens

回放缓冲区是任何 RL 或控制算法的核心部分。监督学习方法通常以一个训练循环为特征,在这个循环中,数据从静态数据集随机提取并连续馈送到模型和损失函数。在 RL 中,事情通常略有不同:数据是使用模型收集的,然后暂时存储在动态结构(经验回放缓冲区)中,该结构充当损失模块的数据集。

与往常一样,缓冲区使用的上下文极大地影响了它的构建方式:有些人可能希望存储轨迹,而另一些人则希望存储单个转换。特定采样策略在上下文中可能更可取:某些项目可能比其他项目具有更高的优先级,或者对有/无替换进行采样可能很重要。计算因素也可能发挥作用,例如缓冲区的大小,它可能超过可用的 RAM 存储。

出于这些原因,TorchRL 的回放缓冲区是完全可组合的:尽管它们附带“电池”,需要最少的努力才能构建,但它们也支持许多自定义,例如存储类型、采样策略或数据转换。

在本教程中,您将学习

基础知识:构建一个普通的回放缓冲区

TorchRL 的回放缓冲区旨在优先考虑模块化、可组合性、效率和简单性。例如,创建一个基本的回放缓冲区是一个简单的过程,如以下示例所示

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

默认情况下,此回放缓冲区的大小为 1000。让我们使用 extend() 方法填充我们的缓冲区来检查这一点

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))
length before adding elements: 0
length after adding elements: 1000

我们使用了 extend() 方法,该方法旨在一次添加多个项目。如果传递给 extend 的对象具有多个维度,则其第一个维度将被视为在缓冲区中拆分为单独元素的维度。

这实际上意味着,当将多维张量或 tensordict 添加到缓冲区时,缓冲区在计算其内存中保存的元素数量时只会查看第一个维度。如果传递的对象不可迭代,则会抛出异常。

要逐个添加项目,应使用 add() 方法。

自定义存储

我们看到缓冲区已被限制在我们传递给它的前 1000 个元素。要更改大小,我们需要自定义存储。

TorchRL 提供了三种类型的存储

  • ListStorage 在列表中独立存储元素。它支持任何数据类型,但这种灵活性是以效率为代价的;

  • LazyTensorStorage 连续存储张量数据结构。它自然地与 TensorDict(或 tensorclass)对象一起使用。存储在每个张量基础上是连续的,这意味着采样比使用列表时更有效,但隐式限制是传递给它的任何数据必须具有与用于实例化缓冲区的第一个数据批次相同的基本属性(例如形状和数据类型)。传递不符合此要求的数据将引发异常或导致一些未定义的行为。

  • LazyMemmapStorage 的工作方式与 LazyTensorStorage 相同,因为它很懒(即它期望第一个数据批次被实例化),并且它需要每个存储批次中形状和数据类型匹配的数据。使这种存储独特的是,它指向磁盘文件(或使用文件系统存储),这意味着它可以支持非常大的数据集,同时仍然以连续的方式访问数据。

让我们看看如何使用这些存储中的每一个

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

具有列表存储缓冲区的缓冲区可以存储任何类型的数据(但我们必须更改 collate_fn,因为默认值期望数值数据)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))
[0, 'b', 'a']

由于它对假设的程度最低,因此 ListStorage 是 TorchRL 中的默认存储。

LazyTensorStorage 可以连续存储数据。这应该是处理中等大小的复杂但不变的数据结构的首选方法

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

让我们创建一个大小为 ``torch.Size([3])` 并存储了 2 个张量的批次数据

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

第一次调用 extend() 将实例化存储。数据的第一个维度将解绑为单独的数据点

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")
The buffer has 3 elements

让我们从缓冲区中采样,并打印数据

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])
samples tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 8,  9, 10, 11],
        [ 0,  1,  2,  3]]) tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [ 0,  1,  2,  3,  4],
        [10, 11, 12, 13, 14],
        [ 0,  1,  2,  3,  4]])

LazyMemmapStorage 的创建方式相同

buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size))
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazytensor.sample(5)
print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
The buffer has 3 elements
samples: a= tensor([[ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [ 8,  9, 10, 11]])
('b', 'c'): tensor([[ 5,  6,  7,  8,  9],
        [ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [10, 11, 12, 13, 14]])

我们还可以自定义磁盘上的存储位置

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir))
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename)
print(
    "the ('b', 'c') tensor is stored in",
    buffer_lazymemmap._storage._storage["b", "c"].filename,
)
The buffer has 3 elements
the 'a' tensor is stored in /pytorch/rl/docs/source/reference/generated/tutorials/<TemporaryDirectory '/tmp/tmpcavfrtwk'>/a.memmap
the ('b', 'c') tensor is stored in /pytorch/rl/docs/source/reference/generated/tutorials/<TemporaryDirectory '/tmp/tmpcavfrtwk'>/b/c.memmap

与 TensorDict 集成

张量位置遵循与包含它们的 TensorDict 相同的结构:这使得在训练期间保存和加载缓冲区变得容易。

要充分利用 TensorDict 作为数据载体,可以使用 TensorDictReplayBuffer 类。其主要优势之一是能够处理采样数据的组织,以及可能需要的任何其他信息(例如样本索引)。

它可以像标准的 ReplayBuffer 一样构建,并且通常可以互换使用。

from torchrl.data import TensorDictReplayBuffer

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazymemmap.sample()
print("sample:", sample)
The buffer has 3 elements
sample: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([12, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([12, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([12]),
            device=cpu,
            is_shared=False),
        index: Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

现在我们的样本有一个额外的 "index" 键,指示采样的索引。让我们看一下这些索引

print(sample["index"])
tensor([1, 2, 0, 1, 1, 1, 2, 1, 1, 2, 1, 1])

与 tensorclass 集成

ReplayBuffer 类和相关的子类也与 tensorclass 类原生兼容,这些类可以方便地用于以更明确的方式编码数据集

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazymemmap.sample()
print("sample:", sample)
The buffer has 10 elements
sample: MyData(
    images=Tensor(shape=torch.Size([12, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

正如预期,数据具有正确的类和形状!

与其他张量结构(PyTrees)集成

TorchRL 的重放缓冲区也适用于任何 pytree 数据结构。PyTree 是由字典、列表和/或元组组成的任意深度的嵌套结构,其中叶子是张量。这意味着可以将任何此类树结构存储在连续内存中!可以使用各种存储:TensorStorageLazyMemmapStorageLazyTensorStorage 都接受这种数据。

以下是此功能的简要演示

from torch.utils._pytree import tree_map

让我们在磁盘上构建我们的重放缓冲区

rb = ReplayBuffer(storage=LazyMemmapStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

使用 pytrees,任何可调用对象都可以用作转换

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

让我们检查一下我们的转换是否完成了它的工作

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)
{'a': None, 'b': {'c': (None, [None])}, 30: None}

从缓冲区中采样和迭代

重放缓冲区支持多种采样策略

  • 如果批次大小是固定的并且可以在构建时定义,则可以将其作为关键字参数传递给缓冲区;

  • 使用固定批次大小,可以遍历重放缓冲区以收集样本;

  • 如果批次大小是动态的,则可以将其传递给 sample 方法以动态方式。

可以使用多线程进行采样,但这与最后一个选项不兼容(因为它要求缓冲区提前知道下一个批次的大小)。

让我们看几个例子

固定批次大小

如果在构建期间传递了批次大小,则在采样时应该省略它

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128)
buffer_lazymemmap.extend(data)
buffer_lazymemmap.sample()
MyData(
    images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)

此批次数据具有我们希望它具有的大小(128)。

要启用多线程采样,只需在构建期间将正整数传递给 prefetch 关键字参数。只要采样很耗时(例如,使用优先采样器时),这应该会显著加快采样速度

buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazymemmap.extend(data)
print(buffer_lazymemmap.sample())
MyData(
    images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)

使用固定批次大小遍历缓冲区

我们也可以像使用常规数据加载器一样遍历缓冲区,只要批次大小是预定义的

for i, data in enumerate(buffer_lazymemmap):
    if i == 3:
        print(data)
        break
MyData(
    images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)

由于我们的采样技术完全随机,并且不会阻止替换,因此所讨论的迭代器是无限的。但是,我们可以使用 SamplerWithoutReplacement,它将把我们的缓冲区转换为有限迭代器

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我们创建一个足够大的数据以获得一些样本

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazymemmap.extend(data)
for _i, _ in enumerate(buffer_lazymemmap):
    continue
print(f"A total of {_i+1} batches have been collected")
A total of 1 batches have been collected

动态批次大小

与我们之前看到的相反,batch_size 关键字参数可以省略,并直接传递给 sample 方法

buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazymemmap.extend(data)
print("sampling 3 elements:", buffer_lazymemmap.sample(3))
print("sampling 5 elements:", buffer_lazymemmap.sample(5))
sampling 3 elements: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
sampling 5 elements: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([5, 8]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

优先重放缓冲区

TorchRL 还提供了一个用于 优先重放缓冲区 的接口。此缓冲区类根据通过数据传递的优先级信号采样数据。

虽然此工具与非 tensordict 数据兼容,但我们鼓励使用 TensorDict,因为它可以使在缓冲区中进出元数据变得容易。

让我们首先看看如何在一般情况下构建优先重放缓冲区。必须手动设置 \(\alpha\)\(\beta\) 超参数

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

扩展重放缓冲区会返回项目索引,我们将在稍后需要这些索引来更新优先级

indices = rb.extend([1, "foo", None])

采样器期望每个元素都有一个优先级。添加到缓冲区时,优先级被设置为默认值 1。一旦计算出优先级(通常通过损失),就必须在缓冲区中更新它。

这是通过 update_priority() 方法完成的,该方法需要索引和优先级。我们将数据集中的第二个样本分配一个人为的高优先级以观察它对采样的影响

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我们观察到从缓冲区中采样主要返回第二个样本 ("foo")

sample, info = rb.sample(10, return_info=True)
print(sample)
['foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo']

信息包含项目的相对权重以及索引。

print(info)
{'_weight': tensor([2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10,
        2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10]), 'index': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

我们看到,与常规缓冲区相比,使用优先重放缓冲区需要在训练循环中执行一系列额外的步骤

  • 收集数据并扩展缓冲区后,必须更新项目的优先级;

  • 计算损失并从中获得“优先级信号”后,我们必须再次更新缓冲区中项目的优先级。这要求我们跟踪索引。

这极大地阻碍了缓冲区的可重用性:如果要编写一个训练脚本,其中可以创建优先缓冲区和常规缓冲区,则必须添加大量的控制流以确保在适当的位置调用适当的方法,如果且仅当使用优先缓冲区时。

让我们看看如何用 TensorDict 改进这一点。我们看到 TensorDictReplayBuffer 返回带有其相对存储索引的增强数据。我们没有提到的一个功能是,此类还确保在扩展期间存在时,优先级信号会自动解析为优先级采样器。

这些功能的组合以多种方式简化了事情:- 在扩展缓冲区时,优先级信号将自动

解析(如果存在),并且优先级将被准确分配;

  • 索引将存储在采样的 tensordict 中,这使得在损失计算后更新优先级变得容易。

  • 在计算损失时,优先级信号将被注册在传递给损失模块的 tensordict 中,这使得无需任何努力即可更新权重

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下代码说明了这些概念。我们使用优先级采样器构建一个重放缓冲区,并在构造函数中指示应该从哪个条目中获取优先级信号

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

让我们选择一个与存储索引成比例的优先级信号

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

更高的索引应该更频繁地出现

from matplotlib import pyplot as plt

plt.hist(sample["index"].numpy())
rb tutorial
(array([132.,  55., 122.,  55., 113.,  66., 124.,  66., 142., 149.]), array([ 0. ,  1.5,  3. ,  4.5,  6. ,  7.5,  9. , 10.5, 12. , 13.5, 15. ]), <BarContainer object of 10 artists>)

一旦我们处理了样本,我们就会使用 torchrl.data.TensorDictReplayBuffer.update_tensordict_priority() 方法更新优先级键。为了展示其工作原理,让我们还原采样项目的优先级

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

现在,更高的索引应该出现的频率更低

sample = rb.sample()
from matplotlib import pyplot as plt

plt.hist(sample["index"].numpy())
rb tutorial
(array([200., 100., 160.,  82., 167.,  70., 103.,  41.,  71.,  30.]), array([ 0. ,  1.5,  3. ,  4.5,  6. ,  7.5,  9. , 10.5, 12. , 13.5, 15. ]), <BarContainer object of 10 artists>)

使用转换

存储在重放缓冲区中的数据可能还没有准备好呈现给损失模块。在某些情况下,收集器产生的数据可能过于庞大,无法按原样保存。例如,将图像从 uint8 转换为浮点数张量,或在使用决策转换器时连接连续帧。

只需将适当的转换附加到缓冲区,即可在缓冲区内处理数据。以下是一些示例

保存原始图像

uint8 类型的张量比我们通常馈送到模型中的浮点数张量在内存上便宜得多。出于这个原因,保存原始图像可能很有用。以下脚本展示了如何构建一个仅返回原始图像但使用转换后的图像进行推理的收集器,以及如何在重放缓冲区中回收这些转换

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

让我们看一下回滚

print(env.rollout(3))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([3, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([3, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([3, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([3, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

我们刚刚创建了一个生成像素的环境。处理这些图像以馈送到策略。我们希望存储原始图像,而不是它们的转换。为此,我们将一个转换附加到收集器以选择我们要查看的键

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

让我们看一下一批数据,并控制 "pixels_trsf" 键是否已被丢弃

for data in collector:
    print(data)
    break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

我们使用与环境相同的转换创建了一个重放缓冲区。但是,有一个细节需要解决:在没有环境的情况下使用的转换不了解数据结构。将转换附加到环境时,首先会转换 "next" 嵌套 tensordict 中的数据,然后在回滚执行期间将其复制到根目录。在处理静态数据时,情况并非如此。尽管如此,我们的数据带有嵌套的“next”tensordict,如果我们没有明确指示它来处理它,我们的转换将忽略它。我们手动将这些键添加到转换中

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
rb.extend(data)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

我们可以检查 sample 方法是否看到转换后的图像重新出现

print(rb.sample())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([16, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([16, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([16, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([16]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([16, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([16, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([16]),
    device=cpu,
    is_shared=False)

更复杂的示例:使用 CatFrames

The CatFrames transform unfolds the observations through time, creating a n-back memory of past events that allow the model to take the past events into account (in the case of POMDPs or with recurrent policies such as Decision Transformers). Storing these concatenated frames can consume a considerable amount of memory. It can also be problematic when the n-back window needs to be different (usually longer) during training and inference. We solve this problem by executing the CatFrames transform separately in the two phases.

from torchrl.envs import CatFrames, UnsqueezeTransform

我们为返回基于像素的观测的环境创建了一个标准的转换列表

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

缓冲区转换看起来与环境转换非常相似,但具有额外的 ("next", ...) 键,如前所述

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)
0

让我们从缓冲区中采样一个批次。转换后的像素键的形状应该从末尾开始在第 4 维度上具有长度为 4 的长度

s = rb.sample(1)  # the buffer has only one element
print(s)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([1, 10]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([1, 10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([1, 10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([1, 10]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([1, 10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([1, 10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([1, 10]),
    device=cpu,
    is_shared=False)

经过一些处理(排除未使用的键等)后,我们看到在线和离线生成的数据匹配!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

存储轨迹

在许多情况下,希望从缓冲区访问轨迹,而不是简单的转换。TorchRL 提供了多种实现此目的的方法。

目前首选的方法是将轨迹存储在缓冲区的第一个维度上,并使用 SliceSampler 来采样这些数据批次。此类只需要有关数据结构的少量信息来执行其工作(目前它只与 tensordict 结构化数据兼容):切片的数量或它们的长度,以及有关可以在哪里找到片段之间分离的信息(例如,回想一下 使用 DataCollector,轨迹 ID 存储在 ("collector", "traj_ids") 中)。在这个简单的示例中,我们构建了一个包含 4 个连续的短轨迹的数据,并从中采样 4 个切片,每个切片的长度为 2(因为批次大小为 8,并且 8 个项目 // 4 个切片 = 2 个时间步长)。我们也标记了这些步骤。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])
episode are grouped tensor([2, 2, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
steps are successive tensor([0, 1, 1, 2, 0, 1, 0, 1])

结论

我们已经了解了如何在 TorchRL 中使用回放缓冲区,从最简单的用法到更高级的用法,其中数据需要以特定方式进行转换或存储。现在您应该能够

  • 创建回放缓冲区,自定义其存储、采样器和转换;

  • 为您的问题选择最佳的存储类型(列表、内存或基于磁盘);

  • 最大限度地减少缓冲区的大小。

后续步骤

  • 查看数据 API 参考以了解 TorchRL 中的离线数据集,这些数据集基于我们的回放缓冲区 API;

  • 检查其他采样器,例如 SamplerWithoutReplacementPrioritizedSliceSamplerSliceSamplerWithoutReplacement,或其他编写器,例如 TensorDictMaxValueWriter

  • 查看如何在 文档中 检查点回放缓冲区。

脚本的总运行时间:(2 分钟 53.110 秒)

估计内存使用量:522 MB

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源