• 文档 >
  • 使用回放缓冲区
快捷方式

使用回放缓冲区

作者: Vincent Moens

回放缓冲区是任何 RL 或控制算法的核心组成部分。监督学习方法的特点通常是训练循环,其中数据从静态数据集中随机抽取,并连续地馈送到模型和损失函数。在 RL 中,情况通常略有不同:数据是使用模型收集的,然后临时存储在动态结构(经验回放缓冲区)中,该结构充当损失模块的数据集。

与往常一样,使用缓冲区的上下文极大地影响了它的构建方式:有些人可能希望存储轨迹,而另一些人则希望存储单个转换。特定的采样策略在某些情况下可能是更可取的:某些项目可能比其他项目具有更高的优先级,或者在有或没有替换的情况下进行采样可能很重要。计算因素也可能发挥作用,例如缓冲区的大小可能超过可用的 RAM 存储。

由于这些原因,TorchRL 的回放缓冲区是完全可组合的:尽管它们“内置电池”,只需最少的努力即可构建,但它们也支持许多自定义,例如存储类型、采样策略或数据转换。

在本教程中,您将学习

基础知识:构建 Vanilla 回放缓冲区

TorchRL 的回放缓冲区旨在优先考虑模块化、可组合性、效率和简洁性。例如,创建一个基本的回放缓冲区是一个简单的过程,如下例所示

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

默认情况下,此回放缓冲区的大小为 1000。让我们通过使用 extend() 方法填充缓冲区来检查这一点

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))
length before adding elements: 0
length after adding elements: 1000

我们使用了 extend() 方法,该方法旨在一次添加多个项目。如果传递给 extend 的对象的维度超过一个,则其第一个维度被认为是要在缓冲区中拆分为单独元素的维度。

这基本上意味着,当向缓冲区添加多维张量或 TensorDict 时,缓冲区在计算其内存中保存的元素时,只会查看第一个维度。如果传递的对象不可迭代,则会抛出异常。

要一次添加一个项目,应改用 add() 方法。

自定义存储

我们看到缓冲区已被限制为我们传递给它的前 1000 个元素。要更改大小,我们需要自定义存储。

TorchRL 提出了三种类型的存储

  • ListStorage 将元素独立存储在列表中。它支持任何数据类型,但这种灵活性是以效率为代价的;

  • LazyTensorStorage 连续存储张量数据结构。它自然适用于 TensorDict(或 tensorclass)对象。存储在每个张量的基础上是连续的,这意味着采样将比使用列表时更有效,但隐含的限制是,传递给它的任何数据都必须具有与用于实例化缓冲区的第一个数据批次相同的基本属性(例如形状和 dtype)。传递不符合此要求的数据要么会引发异常,要么会导致一些未定义的行为。

  • LazyMemmapStorage 的工作方式与 LazyTensorStorage 相同,因为它也是惰性的(即,它期望第一个数据批次被实例化),并且它要求每个存储的批次的数据在形状和 dtype 上都匹配。这种存储的独特之处在于它指向磁盘文件(或使用文件系统存储),这意味着它可以支持非常大的数据集,同时仍然以连续的方式访问数据。

让我们看看如何使用这些存储中的每一个

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

带有列表存储缓冲区的缓冲区可以存储任何类型的数据(但我们必须更改 collate_fn,因为默认值期望数值数据)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))
['a', 'a', 0]

由于它是假设最少的存储,因此 ListStorage 是 TorchRL 中的默认存储。

LazyTensorStorage 可以连续存储数据。当处理中等大小的复杂但不变的数据结构时,这应该是首选选项

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

让我们创建一个大小为 ``torch.Size([3])` 的数据批次,其中存储了 2 个张量

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

第一次调用 extend() 将实例化存储。数据的第一个维度被解绑为单独的数据点

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")
The buffer has 3 elements

让我们从缓冲区采样,并打印数据

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])
samples tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [0, 1, 2, 3],
        [0, 1, 2, 3]]) tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])

LazyMemmapStorage 以相同的方式创建

buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size))
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazytensor.sample(5)
print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
The buffer has 3 elements
samples: a= tensor([[ 0,  1,  2,  3],
        [ 0,  1,  2,  3],
        [ 0,  1,  2,  3],
        [ 8,  9, 10, 11],
        [ 8,  9, 10, 11]])
('b', 'c'): tensor([[ 0,  1,  2,  3,  4],
        [ 0,  1,  2,  3,  4],
        [ 0,  1,  2,  3,  4],
        [10, 11, 12, 13, 14],
        [10, 11, 12, 13, 14]])

我们还可以自定义磁盘上的存储位置

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir))
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename)
print(
    "the ('b', 'c') tensor is stored in",
    buffer_lazymemmap._storage._storage["b", "c"].filename,
)
The buffer has 3 elements
the 'a' tensor is stored in /pytorch/rl/docs/source/reference/generated/tutorials/<TemporaryDirectory '/tmp/tmpzxmvjdim'>/a.memmap
the ('b', 'c') tensor is stored in /pytorch/rl/docs/source/reference/generated/tutorials/<TemporaryDirectory '/tmp/tmpzxmvjdim'>/b/c.memmap

与 TensorDict 集成

张量位置遵循与包含它们的 TensorDict 相同的结构:这使得在训练期间轻松保存和加载缓冲区。

要充分利用 TensorDict 作为数据载体的潜力,可以使用 TensorDictReplayBuffer 类。它的主要优势之一是能够处理采样数据的组织,以及可能需要的任何其他信息(例如样本索引)。

它可以像标准的 ReplayBuffer 一样构建,并且通常可以互换使用。

from torchrl.data import TensorDictReplayBuffer

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazymemmap.sample()
print("sample:", sample)
The buffer has 3 elements
sample: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([12, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([12, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([12]),
            device=cpu,
            is_shared=False),
        index: Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

我们的样本现在有一个额外的 "index" 键,指示采样了哪些索引。让我们看看这些索引

print(sample["index"])
tensor([2, 1, 2, 1, 0, 0, 1, 2, 0, 0, 0, 2])

与 tensorclass 集成

ReplayBuffer 类和相关的子类也原生支持 tensorclass 类,它可以方便地用于以更明确的方式编码数据集

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazymemmap.sample()
print("sample:", sample)
The buffer has 10 elements
sample: MyData(
    images=Tensor(shape=torch.Size([12, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

正如预期的那样。数据具有正确的类和形状!

与其他张量结构(PyTrees)集成

TorchRL 的回放缓冲区也适用于任何 PyTree 数据结构。PyTree 是由字典、列表和/或元组组成的任意深度的嵌套结构,其中叶子是张量。这意味着可以将任何此类树结构存储在连续内存中!可以使用各种存储:TensorStorageLazyMemmapStorageLazyTensorStorage 都接受这种类型的数据。

以下是此功能的外观的简要演示

from torch.utils._pytree import tree_map

让我们在磁盘上构建我们的回放缓冲区

rb = ReplayBuffer(storage=LazyMemmapStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

对于 pytree,任何可调用对象都可以用作转换

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

让我们检查一下我们的转换是否完成了它的工作

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)
{'a': None, 'b': {'c': (None, [None])}, 30: None}

缓冲区的采样和迭代

回放缓冲区支持多种采样策略

  • 如果批大小是固定的并且可以在构建时定义,则可以将其作为关键字参数传递给缓冲区;

  • 对于固定的批大小,可以迭代回放缓冲区以收集样本;

  • 如果批大小是动态的,则可以将其传递给 sample 方法。

采样可以使用多线程完成,但这与最后一个选项不兼容(因为它要求缓冲区预先知道下一个批次的大小)。

让我们看几个例子

固定批大小

如果在构建期间传递了批大小,则在采样时应省略它

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128)
buffer_lazymemmap.extend(data)
buffer_lazymemmap.sample()
MyData(
    images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)

此批数据的大小是我们希望它具有的大小 (128)。

为了启用多线程采样,只需在构造期间将一个正整数传递给 prefetch 关键字参数。这应该在采样耗时的情况下(例如,当使用优先级采样器时)显著加快采样速度。

buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazymemmap.extend(data)
print(buffer_lazymemmap.sample())
MyData(
    images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)

使用固定批大小迭代缓冲区

我们也可以像使用常规数据加载器一样迭代缓冲区,只要预先定义了批大小。

for i, data in enumerate(buffer_lazymemmap):
    if i == 3:
        print(data)
        break
MyData(
    images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False),
    labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)

由于我们的采样技术是完全随机的,并且不阻止替换,因此所讨论的迭代器是无限的。但是,我们可以使用 SamplerWithoutReplacement 来代替,这将把我们的缓冲区转换为有限的迭代器。

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我们创建一个足够大的数据,以便获得几个样本。

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazymemmap.extend(data)
for _i, _ in enumerate(buffer_lazymemmap):
    continue
print(f"A total of {_i+1} batches have been collected")
A total of 1 batches have been collected

动态批大小

与我们之前看到的相反,可以省略 batch_size 关键字参数,并直接将其传递给 sample 方法。

buffer_lazymemmap = ReplayBuffer(
    storage=LazyMemmapStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazymemmap.extend(data)
print("sampling 3 elements:", buffer_lazymemmap.sample(3))
print("sampling 5 elements:", buffer_lazymemmap.sample(5))
sampling 3 elements: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
sampling 5 elements: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([5, 8]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

优先级回放缓冲区

TorchRL 还为 优先级回放缓冲区 提供了接口。这个缓冲区类根据通过数据传递的优先级信号对数据进行采样。

尽管此工具与非 TensorDict 数据兼容,但我们鼓励使用 TensorDict,因为它使元数据能够轻松地在缓冲区内外传递。

让我们首先看看如何在通用情况下构建优先级回放缓冲区。 \(\alpha\)\(\beta\) 超参数必须手动设置。

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

扩展回放缓冲区会返回项目索引,我们稍后需要这些索引来更新优先级。

indices = rb.extend([1, "foo", None])

采样器期望每个元素都有一个优先级。添加到缓冲区时,优先级设置为默认值 1。一旦计算出优先级(通常通过损失),就必须在缓冲区中更新它。

这是通过 update_priority() 方法完成的,该方法需要索引以及优先级。我们为数据集中的第二个样本分配人为的高优先级,以观察其对采样的影响。

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我们观察到,从缓冲区采样主要返回第二个样本 ("foo")。

sample, info = rb.sample(10, return_info=True)
print(sample)
['foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo']

信息包含项目的相对权重以及索引。

print(info)
{'_weight': tensor([2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10,
        2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10]), 'index': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

我们看到,与常规缓冲区相比,使用优先级回放缓冲区需要在训练循环中执行一系列额外的步骤。

  • 在收集数据并扩展缓冲区之后,必须更新项目的优先级;

  • 在计算损失并从中获得“优先级信号”之后,我们必须再次更新缓冲区中项目的优先级。这要求我们跟踪索引。

这大大妨碍了缓冲区的可重用性:如果要编写一个可以同时创建优先级缓冲区和常规缓冲区的训练脚本,则必须添加大量的控制流,以确保在适当的位置调用适当的方法,仅当使用优先级缓冲区时才这样做。

让我们看看如何使用 TensorDict 来改进这一点。我们看到 TensorDictReplayBuffer 返回的数据会使用其相对存储索引进行扩充。我们没有提到的一个功能是,如果扩展期间存在优先级信号,此类还会确保将优先级信号自动解析到优先级采样器。

这些功能的组合在几个方面简化了事情:- 扩展缓冲区时,优先级信号将自动被

解析(如果存在),并且将准确地分配优先级;

  • 索引将存储在采样的 tensordict 中,从而可以轻松地在损失计算后更新优先级。

  • 在计算损失时,优先级信号将注册到传递给损失模块的 tensordict 中,从而可以轻松地更新权重。

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下代码说明了这些概念。我们使用优先级采样器构建了一个回放缓冲区,并在构造函数中指示应从中获取优先级信号的条目。

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

让我们选择一个与存储索引成比例的优先级信号。

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

较高的索引应该更频繁地出现。

from matplotlib import pyplot as plt

plt.hist(sample["index"].numpy())
rb tutorial
(array([160.,  61., 118.,  64., 129.,  62., 111.,  61., 139., 119.]), array([ 0. ,  1.5,  3. ,  4.5,  6. ,  7.5,  9. , 10.5, 12. , 13.5, 15. ]), <BarContainer object of 10 artists>)

一旦我们处理完样本,我们就使用 torchrl.data.TensorDictReplayBuffer.update_tensordict_priority() 方法更新优先级键。为了展示其工作原理,让我们还原采样项目的优先级。

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

现在,较高的索引应该不太频繁地出现。

sample = rb.sample()
from matplotlib import pyplot as plt

plt.hist(sample["index"].numpy())
rb tutorial
(array([211., 103., 178.,  66., 148.,  62., 119.,  40.,  63.,  34.]), array([ 0. ,  1.5,  3. ,  4.5,  6. ,  7.5,  9. , 10.5, 12. , 13.5, 15. ]), <BarContainer object of 10 artists>)

使用变换

存储在回放缓冲区中的数据可能尚未准备好呈现给损失模块。在某些情况下,收集器生成的数据可能过于庞大,无法按原样保存。这方面的示例包括将图像从 uint8 转换为浮点张量,或者在使用决策转换器时连接连续帧。

只需将适当的变换附加到缓冲区,就可以在缓冲区内外处理数据。以下是一些示例。

保存原始图像

uint8 类型的张量比我们通常提供给模型的浮点张量在内存开销上要小得多。因此,保存原始图像可能很有用。以下脚本展示了如何构建一个收集器,该收集器仅返回原始图像,但使用转换后的图像进行推理,以及如何在回放缓冲区中回收这些转换。

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

让我们看一下 rollout。

print(env.rollout(3))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([3, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([3, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([3, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([3, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

我们刚刚创建了一个生成像素的环境。这些图像经过处理以馈送到策略。我们希望存储原始图像,而不是其变换。为此,我们将向收集器附加一个变换,以选择我们要看到的键。

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

让我们看一下一批数据,并检查是否已丢弃 "pixels_trsf" 键。

for data in collector:
    print(data)
    break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

我们创建了一个回放缓冲区,其变换与环境相同。但是,有一个细节需要解决:在没有环境的情况下使用的变换会忽略数据结构。当将变换附加到环境时,"next" 嵌套 tensordict 中的数据会先被变换,然后在 rollout 执行期间复制到根目录。当处理静态数据时,情况并非如此。然而,我们的数据带有一个嵌套的 “next” tensordict,如果我们不明确指示变换处理它,它将被变换忽略。我们手动将这些键添加到变换中。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
rb.extend(data)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

我们可以检查 sample 方法是否重新显示了转换后的图像。

print(rb.sample())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([16, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([16, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([16, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([16]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([16, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([16, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([16]),
    device=cpu,
    is_shared=False)

更复杂的示例:使用 CatFrames

CatFrames 变换随时间展开观察结果,创建过去事件的 n-back 记忆,使模型能够将过去事件考虑在内(在 POMDP 的情况下或使用循环策略(如决策转换器)时)。存储这些连接的帧可能会消耗大量内存。当 n-back 窗口在训练和推理期间需要不同(通常更长)时,这也会成为问题。我们通过在两个阶段分别执行 CatFrames 变换来解决此问题。

from torchrl.envs import CatFrames, UnsqueezeTransform

我们为返回基于像素的观察结果的环境创建了一个标准的变换列表。

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

缓冲区变换看起来与环境变换非常相似,但像以前一样带有额外的 ("next", ...) 键。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)
0

让我们从缓冲区中采样一个批次。转换后的像素键的形状应在从末尾开始的第 4 维上具有长度 4。

s = rb.sample(1)  # the buffer has only one element
print(s)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([1, 10]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([1, 10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                pixels_trsf: Tensor(shape=torch.Size([1, 10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([1, 10]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([1, 10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        pixels_trsf: Tensor(shape=torch.Size([1, 10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([1, 10]),
    device=cpu,
    is_shared=False)

经过一些处理(排除未使用的键等)后,我们看到在线和离线生成的数据匹配!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

存储轨迹

在许多情况下,希望从缓冲区访问轨迹而不是简单的转换。TorchRL 提供了多种实现此目的的方法。

目前首选的方法是将轨迹存储在缓冲区的第一个维度上,并使用 SliceSampler 对这些数据批次进行采样。此类只需要有关您的数据结构的几个信息即可完成其工作(注意,截至目前,它仅与 tensordict 结构化数据兼容):切片的数量或其长度,以及有关可以在何处找到 episode 之间分隔的信息(例如,回想一下,对于 DataCollector,轨迹 ID 存储在 ("collector", "traj_ids") 中)。在这个简单的例子中,我们构建了一个包含 4 个连续短轨迹的数据,并从中采样 4 个切片,每个切片的长度为 2(因为批大小为 8,并且 8 个项目 // 4 个切片 = 2 个时间步长)。我们也标记了步骤。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])
episode are grouped tensor([2, 2, 3, 3, 1, 1, 3, 3], dtype=torch.int32)
steps are successive tensor([0, 1, 0, 1, 1, 2, 0, 1])

结论

我们已经了解了如何在 TorchRL 中使用回放缓冲区,从其最简单的用法到更高级的用法,在这些用法中,数据需要以特定方式进行转换或存储。您现在应该能够

  • 创建回放缓冲区,自定义其存储、采样器和变换;

  • 为您的问题选择最佳存储类型(列表、内存或基于磁盘);

  • 最小化缓冲区的内存占用。

后续步骤

  • 查看数据 API 参考,了解 TorchRL 中的离线数据集,这些数据集基于我们的回放缓冲区 API;

  • 查看其他采样器,例如 SamplerWithoutReplacementPrioritizedSliceSamplerSliceSamplerWithoutReplacement,或其他写入器,例如 TensorDictMaxValueWriter

  • 查看如何在 文档 中检查点 ReplayBuffer。

脚本总运行时间: (2 分 55.642 秒)

估计内存使用量: 491 MB

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源