• 文档 >
  • 数据收集和存储入门
快捷方式

数据收集和存储入门

作者: Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,其中包含

!pip install tensordict
!pip install torchrl
import tempfile

没有数据就没有学习。在监督学习中,用户习惯于使用 DataLoader 等工具将数据整合到训练循环中。Dataloader 是可迭代对象,它们为你提供用于训练模型的数据。

TorchRL 以类似的方式处理数据加载问题,尽管这在强化学习库生态系统中是出乎意料的独特之处。TorchRL 的数据加载器被称为 DataCollectors。大多数时候,数据收集并不仅仅是原始数据的收集,因为在被损失模块消耗之前,数据需要暂时存储在缓冲区(或用于 on-policy 算法的等效结构)中。本教程将探讨这两个类。

数据收集器

此处讨论的主要数据收集器是 SyncDataCollector,这是本文档的重点。在基础层面,收集器是一个简单的类,负责在环境中执行你的策略,必要时重置环境,并提供预定义大小的批次数据。与环境教程中演示的 rollout() 方法不同,收集器在连续的数据批次之间不会重置。因此,连续的两个数据批次可能包含来自同一轨迹的元素。

你需要传递给收集器的基本参数是你想收集的批次大小(frames_per_batch),迭代器的长度(可能无限),策略和环境。为简单起见,在此示例中我们将使用一个虚拟的随机策略。

import torch

torch.manual_seed(0)

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

env = GymEnv("CartPole-v1")
env.set_seed(0)

policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

现在我们期望收集器无论在收集过程中发生什么,都将交付大小为 200 的批次数据。换句话说,这个批次中可能包含多个轨迹!total_frames 指示收集器应该运行多长时间。值为 -1 将生成一个永不停止的收集器。

让我们迭代收集器,了解这些数据看起来像什么

for data in collector:
    print(data)
    break

如你所见,我们的数据增加了一些收集器特有的元数据,这些元数据被分组到一个 "collector"tensordict 中,这是我们在环境 rollout 期间没有看到的。这对于跟踪轨迹 ID 很有用。在下面的列表中,每个条目标记了相应 transition 所属的轨迹编号

print(data["collector", "traj_ids"])

在编写最先进的算法时,数据收集器非常有用,因为性能通常是通过特定技术在给定数量的环境交互次数(收集器中的 total_frames 参数)内解决问题的能力来衡量的。因此,我们示例中的大多数训练循环都像这样

..code - block::Python

>>> for data in collector:
...     # your algorithm here

回放缓冲区

既然我们已经探索了如何收集数据,我们想知道如何存储它。在强化学习中,典型的设置是收集数据,临时存储,并在一段时间后根据某种启发式方法清除:先入先出或其他。一个典型的伪代码看起来像这样

..code - block::Python

>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() # etc

TorchRL 中存储数据的父类被称为 ReplayBuffer。TorchRL 的回放缓冲区是可组合的:你可以编辑存储类型、它们的采样技术、写入启发式方法或应用于它们的 transforms。我们将把更高级的内容留给专门的深入教程。通用的回放缓冲区只需要知道它要使用什么存储。通常,我们推荐使用 TensorStorage 子类,这在大多数情况下都能很好地工作。在本教程中,我们将使用 LazyMemmapStorage,它具有两个很好的特性:首先,它很“lazy”,你无需提前明确告知它你的数据是什么样子。其次,它使用 MemoryMappedTensor 作为后端,以高效的方式将你的数据保存在磁盘上。你唯一需要知道的是你希望缓冲区有多大。

from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name

buffer = ReplayBuffer(
    storage=LazyMemmapStorage(max_size=1000, scratch_dir=buffer_scratch_dir)
)

填充缓冲区可以通过 add()(单个元素)或 extend()(多个元素)方法完成。使用我们刚刚收集的数据,我们可以一步初始化并填充缓冲区

indices = buffer.extend(data)

我们可以检查缓冲区现在具有与我们从收集器中获得的数据相同的元素数量

assert len(buffer) == collector.frames_per_batch

唯一需要知道的是如何从缓冲区中获取数据。自然地,这依赖于 sample() 方法。由于我们没有指定采样必须是不重复的,因此不能保证从缓冲区获取的样本是唯一的

sample = buffer.sample(batch_size=30)
print(sample)

再次,我们的样本看起来与我们从收集器收集的数据完全相同!

下一步

  • 你可以查看其他多进程收集器,例如 MultiSyncDataCollectorMultiaSyncDataCollector

  • 如果你有多个节点用于推理,TorchRL 还提供了分布式收集器。请在API 参考中查看它们。

  • 查阅专门的回放缓冲区教程,了解构建缓冲区时的更多选项,或查阅API 参考,其中涵盖所有详细功能。回放缓冲区具有无数功能,例如多线程采样、优先经验回放等等……

  • 为简单起见,我们没有介绍回放缓冲区的迭代能力。你可以自己尝试一下:构建一个缓冲区并在构造函数中指定其 batch-size,然后尝试对其进行迭代。这相当于在循环中调用 rb.sample()

由 Sphinx-Gallery 生成的图库

文档

获取 PyTorch 全面的开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源