• 文档 >
  • 使用回放缓冲区
快捷方式

使用回放缓冲区

作者: Vincent Moens

回放缓冲区是任何 RL 或控制算法的核心组成部分。监督学习方法通常通过一个训练循环来表征,其中数据从静态数据集中随机提取并依次馈送到模型和损失函数。在 RL 中,情况通常略有不同:数据使用模型收集,然后临时存储在动态结构(经验回放缓冲区)中,该结构用作损失模块的数据集。

一如既往,缓冲区的用途极大地影响了它的构建方式:有些人可能希望存储轨迹,而另一些人则希望存储单个转换。特定的采样策略可能在某些上下文中更受欢迎:某些项目可能比其他项目具有更高的优先级,或者有放回或无放回抽样可能很重要。计算因素也可能发挥作用,例如缓冲区的大小可能超出可用的 RAM 存储空间。

由于这些原因,TorchRL 的回放缓冲区是完全可组合的:虽然它们自带“电池”,只需最少的努力即可构建,但它们也支持许多自定义,例如存储类型、采样策略或数据 transforms。

在本教程中,你将学习

基础知识:构建一个普通回放缓冲区

TorchRL 的回放缓冲区旨在优先考虑模块化、可组合性、效率和简单性。例如,创建一个基本的回放缓冲区是一个简单的过程,如下例所示

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

默认情况下,此回放缓冲区的大小为 1000。我们通过使用 extend() 方法填充缓冲区来检查这一点

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))

我们使用了旨在一次添加多个项目的 extend() 方法。如果传递给 extend 的对象具有多个维度,则其第一个维度将被视为要在缓冲区中拆分成单独元素的部分。

这实质上意味着,当将多维 tensors 或 tensordicts 添加到缓冲区时,缓冲区在计算其内存中保存的元素时,只会查看第一个维度。如果传递的对象不可迭代,则会抛出异常。

要逐个添加项目,应改用 add() 方法。

自定义存储

我们看到缓冲区已被限制为我们传递给它的前 1000 个元素。要更改大小,我们需要自定义我们的存储。

TorchRL 提供三种类型的存储

  • ListStorage 将元素独立存储在列表中。它支持任何数据类型,但这种灵活性是以效率为代价的;

  • LazyTensorStorage 连续存储 tensors 数据结构。它自然地与 TensorDict(或 tensorclass)对象一起工作。存储是按 tensor 连续的,这意味着采样将比使用列表时更高效,但隐含的限制是传递给它的任何数据必须与用于实例化缓冲区的第一个数据批次具有相同基本属性(如 shape 和 dtype)。传递不符合此要求的数据将引发异常或导致某些未定义的行为。

  • LazyMemmapStorage 的工作方式与 LazyTensorStorage 类似,它也是 lazy 的(即,它期望实例化第一个数据批次),并且它要求存储的每个批次数据在 shape 和 dtype 上匹配。使这种存储独特之处在于它指向磁盘文件(或使用文件系统存储),这意味着它可以支持非常大的数据集,同时仍以连续的方式访问数据。

让我们看看如何使用这些存储中的每一种

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

带有列表存储的缓冲区可以存储任何类型的数据(但我们必须更改 collate_fn,因为默认期望数值数据)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))

由于它具有最少量的假设,ListStorage 是 TorchRL 中的默认存储。

LazyTensorStorage 可以连续存储数据。在处理复杂但不变的中等大小数据结构时,这应该是首选选项

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

让我们创建一个大小为 torch.Size([3]) 的数据批次,其中存储了 2 个 tensors

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)

第一次调用 extend() 将实例化存储。数据的第一维度被解开为单独的数据点

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")

让我们从缓冲区中采样,并打印数据

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])

LazyMemmapStorage 也以同样的方式创建。我们还可以自定义磁盘上的存储位置

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = ReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir)
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    print(
        "the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename
    )
    print(
        "the ('b', 'c') tensor is stored in",
        buffer_lazymemmap._storage._storage["b", "c"].filename,
    )
    sample = buffer_lazytensor.sample(5)
    print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
    del buffer_lazymemmap

与 TensorDict 集成

tensor 位置遵循包含它们的 TensorDict 的相同结构:这使得在训练期间轻松保存和加载缓冲区成为可能。

为了充分发挥 TensorDict 作为数据载体的潜力,可以使用 TensorDictReplayBuffer 类。它的一个主要优点是能够处理采样数据的组织,以及可能需要的任何附加信息(例如样本索引)。

它可以像标准的 ReplayBuffer 一样构建,并且通常可以互换使用。

from torchrl.data import TensorDictReplayBuffer

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    sample = buffer_lazymemmap.sample()
    print("sample:", sample)
    del buffer_lazymemmap

我们的样本现在有一个额外的 "index" 键,指示采样了哪些索引。让我们看看这些索引

print(sample["index"])

与 tensorclass 集成

ReplayBuffer 类及其相关的子类也原生支持 tensorclass 类,这些类可以方便地用于以更明确的方式编码数据集

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
buffer_lazy.extend(data)
print(f"The buffer has {len(buffer_lazy)} elements")
sample = buffer_lazy.sample()
print("sample:", sample)

正如所料,数据具有正确的类和 shape!

与其他 tensor 结构(PyTrees)集成

TorchRL 的回放缓冲区也支持任何 pytree 数据结构。一个 PyTree 是由 dicts、lists 和/或 tuples 组成的任意深度的嵌套结构,其中叶子是 tensors。这意味着可以在连续内存中存储任何此类树结构!可以使用各种存储:TensorStorageLazyMemmapStorageLazyTensorStorage 都接受这种数据。

这里是对此功能的简要演示

from torch.utils._pytree import tree_map

让我们在 RAM 上构建回放缓冲区

rb = ReplayBuffer(storage=LazyTensorStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

对于 pytrees,任何可调用对象都可以用作 transform

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

让我们检查一下我们的 transform 是否正常工作

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)

采样和遍历缓冲区

回放缓冲区支持多种采样策略

  • 如果 batch-size 是固定的并且可以在构建时定义,则可以将其作为关键字参数传递给缓冲区;

  • 使用固定的 batch-size,可以遍历回放缓冲区以收集样本;

  • 如果 batch-size 是动态的,则可以在运行时将其传递给 sample 方法。

可以使用多线程进行采样,但这与最后一种选择不兼容(因为它要求缓冲区预先知道下一个批次的大小)。

让我们看几个例子

固定 batch-size

如果在构建期间传递了 batch-size,则在采样时应省略它

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128)
buffer_lazy.extend(data)
buffer_lazy.sample()

此数据批次的大小是我们想要的大小 (128)。

要启用多线程采样,只需在构建期间将正整数传递给 prefetch 关键字参数。这应显着加快采样速度,尤其是在采样耗时的情况下(例如,使用优先采样器时)

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazy.extend(data)
print(buffer_lazy.sample())

以固定 batch-size 遍历缓冲区

只要 batch-size 是预定义的,我们也可以像使用常规 dataloader 一样遍历缓冲区

for i, data in enumerate(buffer_lazy):
    if i == 3:
        print(data)
        break

del buffer_lazy

由于我们的采样技术是完全随机的并且不阻止有放回采样,因此该迭代器是无限的。但是,我们可以改用 SamplerWithoutReplacement(无放回采样器),它将把我们的缓冲区转换为一个有限迭代器

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我们创建一个足够大的数据来获取几个样本

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazy.extend(data)
for _i, _ in enumerate(buffer_lazy):
    continue
print(f"A total of {_i+1} batches have been collected")

del buffer_lazy

动态 batch-size

与我们之前看到的不同,batch_size 关键字参数可以省略,并直接传递给 sample 方法

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazy.extend(data)
print("sampling 3 elements:", buffer_lazy.sample(3))
print("sampling 5 elements:", buffer_lazy.sample(5))

del buffer_lazy

优先回放缓冲区

TorchRL 还提供了 优先回放缓冲区 的接口。此类缓冲区根据通过数据传递的优先级信号进行采样。

虽然此工具兼容非 TensorDict 数据,但我们鼓励改用 TensorDict,因为它使得在缓冲区内外携带元数据变得容易。

让我们首先看看如何在一般情况下构建一个优先回放缓冲区。\(\alpha\) 和 \(\beta\) 超参数必须手动设置

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

扩展回放缓冲区会返回项目索引,我们稍后将需要这些索引来更新优先级

indices = rb.extend([1, "foo", None])

采样器期望每个元素都有一个优先级。当添加到缓冲区时,优先级被设置为默认值 1。优先级计算后(通常通过损失函数),必须在缓冲区中更新它。

这是通过 update_priority() 方法完成的,该方法需要索引和优先级。我们将数据集中的第二个样本分配一个人为的高优先级,以观察其对采样的影响

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我们观察到从缓冲区采样的结果主要是第二个样本("foo"

sample, info = rb.sample(10, return_info=True)
print(sample)

info 包含项目的相对权重以及索引。

print(info)

我们看到,使用优先回放缓冲区与使用常规缓冲区相比,在训练循环中需要一系列额外的步骤

  • 收集数据并扩展缓冲区后,必须更新项目的优先级;

  • 计算损失并从中获取“优先级信号”后,我们必须再次更新缓冲区中项目的优先级。这需要我们跟踪索引。

这极大地阻碍了缓冲区的可重用性:如果要编写一个训练脚本,其中既可以创建优先缓冲区也可以创建常规缓冲区,则她必须添加大量的控制流,以确保仅当使用优先缓冲区时,才在适当的位置调用适当的方法。

让我们看看如何使用 TensorDict 来改进这一点。我们看到 TensorDictReplayBuffer 返回的数据通过其相对存储索引进行了增强。我们没有提到的一个特性是,如果优先级信号在扩展期间存在,此类还会确保将其自动解析到优先采样器。

这些功能的结合在几个方面简化了事情:- 扩展缓冲区时,优先级信号将自动

如果存在则被解析,并且优先级将被准确分配;

  • 索引将存储在采样的 tensordicts 中,使得在损失计算后易于更新优先级。

  • 计算损失时,优先级信号将注册到传递给损失模块的 tensordict 中,使得无需努力即可更新权重

    ..code - block::Python

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下代码阐述了这些概念。我们构建了一个带有优先采样器的回放缓冲区,并在构造函数中指明了应该获取优先级信号的入口

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

让我们选择一个与存储索引成比例的优先级信号

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

较高的索引应该更频繁地出现

from matplotlib import pyplot as plt

fig = plt.hist(sample["index"].numpy())
plt.show()

处理完样本后,我们使用 torchrl.data.TensorDictReplayBuffer.update_tensordict_priority() 方法更新优先级键。为了演示其工作原理,我们反转采样项目的优先级

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

现在,较高的索引应该更少地出现

sample = rb.sample()

fig = plt.hist(sample["index"].numpy())
plt.show()

使用 transforms

存储在回放缓冲区中的数据可能尚未准备好呈现给损失模块。在某些情况下,collector 生成的数据可能太重而无法按原样保存。例如,将图像从 uint8 转换为浮点 tensors,或在使用 decision transformers 时连接连续的帧。

只需向缓冲区附加适当的 transform,即可处理进出缓冲区的数据。这里有一些例子

保存原始图像

uint8 类型的 tensors 比我们通常馈送到模型的浮点 tensors 在内存方面便宜得多。因此,保存原始图像会很有用。以下脚本展示了如何构建一个仅返回原始图像但使用 transformed 图像进行推理的 collector,以及如何将这些 transformations 在回放缓冲区中重复使用

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

让我们看看一个 rollout

print(env.rollout(3))

我们刚刚创建了一个生成像素的环境。这些图像经过处理以馈送给策略。我们希望存储原始图像,而不是它们的 transforms。为此,我们将一个 transform 附加到 collector,以选择我们希望出现的键

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

让我们看看一批数据,并确认 "pixels_trsf" 键已被丢弃

for data in collector:
    print(data)
    break

collector.shutdown()

我们创建一个回放缓冲区,其 transform 与环境相同。然而,有一个细节需要注意:在没有环境的情况下使用的 transforms 对数据结构一无所知。将 transform 附加到环境时,嵌套 tensordict 中 "next" 的数据首先被转换,然后在 rollout 执行期间复制到根目录。使用静态数据时,情况并非如此。尽管如此,我们的数据带有一个嵌套的 “next” tensordict,如果我们不明确指示 transform 处理它,它将被忽略。我们手动将这些键添加到 transform 中

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16)
rb.extend(data)

我们可以检查一个 sample 方法是否能看到 transformed 图像重新出现

print(rb.sample())

一个更复杂的示例:使用 CatFrames

CatFrames transform 随时间展开 observations,创建一个过去事件的 n-back 内存,使模型能够考虑过去事件(在 POMDPs 或使用像 Decision Transformers 这样的循环策略的情况下)。存储这些连接的帧会消耗大量的内存。当训练和推理期间 n-back 窗口需要不同(通常更长)时,这也可能成为问题。我们通过在两个阶段中分别执行 CatFrames transform 来解决这个问题。

from torchrl.envs import CatFrames, UnsqueezeTransform

我们为返回像素observations 的环境创建了一个标准的 transforms 列表

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break

collector.shutdown()

缓冲区 transform 看起来与环境的 transform 非常相似,但像之前一样带有额外的 ("next", ...)

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)

让我们从缓冲区中采样一个批次。transformed 像素键的 shape 在倒数第四个维度上应该长度为 4

s = rb.sample(1)  # the buffer has only one element
print(s)

经过一些处理(排除未使用键等)后,我们看到在线和离线生成的数据匹配!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

存储轨迹

在许多情况下,最好从缓冲区访问轨迹而不是简单的转换。TorchRL 提供了多种实现方式。

目前,首选方法是将轨迹沿着缓冲区的第一维存储,并使用 SliceSampler 对这些数据批次进行采样。此类只需要一些关于你的数据结构的信息即可完成其工作(注意,截至目前,它仅与 tensordict 结构的数据兼容):slices 的数量或其长度,以及关于 episode 之间在哪里找到分离的信息(例如,回想一下,使用 DataCollector 时,轨迹 ID 存储在 ("collector", "traj_ids") 中)。在这个简单的例子中,我们构建了一个包含 4 个连续短轨迹的数据,并从中采样了 4 个 slices,每个 slices 的长度为 2(因为 batch size 是 8,并且 8 items // 4 slices = 2 time steps)。我们还标记了 steps。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

gc.collect()

结论

我们已经了解了如何在 TorchRL 中使用回放缓冲区,从最简单的用法到需要转换或以特定方式存储数据的更高级用法。现在你应该能够

  • 创建回放缓冲区,自定义其存储、采样器和 transforms;

  • 为你的问题选择最佳存储类型(list、内存或基于磁盘的);

  • 最小化缓冲区的内存占用。

下一步

  • 查阅数据 API 参考,了解 TorchRL 中的离线数据集,这些数据集基于我们的回放缓冲区 API;

  • 查阅其他采样器,例如 SamplerWithoutReplacementPrioritizedSliceSamplerSliceSamplerWithoutReplacement,或查看其他 writer,例如 TensorDictMaxValueWriter

  • 查阅 文档,了解如何检查点 ReplayBuffers。

由 Sphinx-Gallery 生成的图库

文档

查阅 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源