torchrl.collectors 包¶
数据收集器在某种程度上等同于 pytorch 数据加载器,但不同的是 (1) 它们在非静态数据源上收集数据,以及 (2) 数据是使用模型 (可能是正在训练的模型的版本) 收集的。
TorchRL 的数据收集器接受两个主要参数:一个环境(或环境构造函数列表)和一个策略。它们将迭代执行环境步骤和策略查询,在将收集的数据堆栈传递给用户之前,执行定义的步骤数。环境将在它们达到 done 状态或在预定义的步骤数之后重置。
由于数据收集是一个潜在的计算密集型过程,因此适当地配置执行超参数至关重要。首先要考虑的参数是数据收集是否应与优化步骤串行发生或并行发生。这SyncDataCollector
类将在训练工作器上执行数据收集。这MultiSyncDataCollector
类将跨多个工作器拆分工作负载,并汇总将传递给训练工作器的结果。最后,这MultiaSyncDataCollector
类将在多个工作器上执行数据收集,并传递它可以收集的第一批结果。此执行将连续发生,并与网络训练同时进行:这意味着用于数据收集的策略的权重可能略微滞后于训练工作器上策略的配置。因此,虽然此类可能是收集数据的最快方式,但它以仅适用于异步收集数据的设置 (例如,离策略 RL 或课程 RL) 为代价。对于远程执行的回滚 (MultiSyncDataCollector
或 MultiaSyncDataCollector
),必须使用 collector.update_policy_weights_() 或在构造函数中设置 update_at_each_batch=True 来将远程策略的权重与训练工作器的权重同步。
第二个要考虑的参数(在远程设置中)是将收集数据的设备以及执行环境和策略操作的设备。例如,在 CPU 上执行的策略可能比在 CUDA 上执行的策略慢。当多个推理工作器同时运行时,跨可用设备调度计算工作负载可以加速收集或避免 OOM 错误。最后,批次大小的选择以及传递设备(即存储数据的设备,在等待传递给收集工作器时)也可能影响内存管理。要控制的关键参数是 devices
,它控制执行设备(即策略的设备)和 storing_device
,它将控制在回滚期间存储环境和数据的设备。一个很好的启发式方法通常是使用相同的设备进行存储和计算,这是仅传递 devices 参数时的默认行为。
除了这些计算参数之外,用户可以选择配置以下参数
max_frames_per_traj:调用
env.reset()
后的帧数frames_per_batch:每次迭代在收集器上传递的帧数
init_random_frames:随机步骤数(调用
env.rand_step()
的步骤)reset_at_each_iter:如果为
True
,则环境将在每次批次收集后重置split_trajs:如果为
True
,则轨迹将被拆分并以填充的 tensordict 形式传递,以及一个"mask"
键,该键将指向一个布尔掩码,表示有效值。exploration_type:要与策略一起使用的探索策略。
reset_when_done:是否应在达到 done 状态时重置环境。
收集器和批次大小¶
因为每个收集器都有自己的组织运行在其中的环境的方式,所以数据将具有不同的批次大小,具体取决于收集器的具体特性。下表总结了收集数据时应预期的内容
SyncDataCollector |
MultiSyncDataCollector (n=B) |
MultiaSyncDataCollector (n=B) |
|||
---|---|---|---|---|---|
cat_results |
NA |
“stack” |
0 |
-1 |
NA |
单个环境 |
[T] |
[B, T] |
[B*(T//B) |
[B*(T//B)] |
[T] |
批处理环境 (n=P) |
[P, T] |
[B, P, T] |
[B * P, T] |
[P, T * B] |
[P, T] |
在所有这些情况下,最后一个维度 (T
表示 time
) 会进行调整,以使批次大小等于传递给收集器的 frames_per_batch
参数。
警告
不应将 MultiSyncDataCollector
与 cat_results=0
一起使用,因为数据将沿着批次维度与批处理环境堆叠,或者对于单个环境而言是时间维度,这在交换两者时可能会造成一些混淆。 cat_results="stack"
是一种更好的、更一致的与环境交互的方式,因为它将保持每个维度分离,并提供配置、收集器类和其他组件之间更好的可互换性。
与 MultiSyncDataCollector
具有与正在运行的子收集器数量相对应的维度 (B
) 不同,MultiaSyncDataCollector
没有。当考虑到 MultiaSyncDataCollector
按先到先得的方式提供数据批次,而 MultiSyncDataCollector
在交付数据之前从每个子收集器中收集数据时,这一点很容易理解。
收集器和回放缓冲区互操作性¶
在最简单的场景中,需要从回放缓冲区中采样单个转换,需要很少关注收集器的构建方式。在填充存储之前,对收集后的数据进行扁平化将是一个足够的前处理步骤。
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N),
... transform=lambda data: data.reshape(-1))
>>> for data in collector:
... memory.extend(data)
如果需要收集轨迹片段,建议使用多维缓冲区并使用 SliceSampler
采样器类进行采样。必须确保传递给缓冲区的数据形状正确,time
和 batch
维度清晰地分开。在实践中,以下配置将有效
>>> # Single environment: no need for a multi-dimensional buffer
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
... memory.extend(data)
>>> # Batched environments: a multi-dim buffer is required
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=2),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> env = ParallelEnv(4, make_env)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
... memory.extend(data)
>>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv iif cat_results="stack"
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=2),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([make_env] * 4,
... policy,
... frames_per_batch=N,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
... memory.extend(data)
>>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=3),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4,
... policy,
... frames_per_batch=N,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
... memory.extend(data)
目前尚不支持使用 MultiSyncDataCollector
采样轨迹的回放缓冲区,因为数据批次可以来自任何工作器,并且在大多数情况下,写入缓冲区的连续批次不会来自同一个源(从而中断轨迹)。
单节点数据收集器¶
|
数据收集器的基类。 |
|
用于 RL 问题的通用数据收集器。 |
|
在单独的进程上同步运行给定数量的 DataCollectors。 |
|
在单独的进程上异步运行给定数量的 DataCollectors。 |
|
在单独的进程上运行单个 DataCollector。 |
分布式数据收集器¶
TorchRL 提供了一组分布式数据收集器。这些工具支持多种后端 ('gloo'
、'nccl'
、'mpi'
,使用 DistributedDataCollector
或 PyTorch RPC 使用 RPCDataCollector
) 和启动器 ('ray'
、submitit
或 torch.multiprocessing
)。它们可以在单个节点或跨多个节点的同步或异步模式下有效使用。
资源:在 专用文件夹 中查找这些收集器的示例。
注意
选择子收集器:所有分布式收集器都支持各种单机收集器。人们可能想知道为什么使用 MultiSyncDataCollector
或 ParallelEnv
。通常,多进程收集器比并行环境具有更低的 IO 占用,并行环境需要在每一步都进行通信。然而,模型规格在相反的方向上发挥作用,因为使用并行环境将导致策略(和/或转换)的更快执行,因为这些操作将被向量化。
注意
选择收集器(或并行环境)的设备:使用并行环境和在 CPU 上执行的多进程环境,通过共享内存缓冲区在进程之间共享数据。根据正在使用的机器的功能,与使用 cuda 驱动程序原生支持的 GPU 上共享数据相比,这可能非常慢。在实践中,这意味着在构建并行环境或收集器时使用 device="cpu"
关键字参数可能会导致比使用 device="cuda"
(如果可用)更慢的收集。
注意
鉴于库的许多可选依赖项(例如,Gym、Gymnasium 和许多其他依赖项),警告在多进程/分布式设置中可能很快变得非常烦人。默认情况下,TorchRL 在子进程中过滤掉这些警告。如果仍然希望看到这些警告,可以通过设置 torchrl.filter_warnings_subprocess=False
来显示它们。
|
具有 torch.distributed 后端的分布式数据收集器。 |
|
基于 RPC 的分布式数据收集器。 |
|
具有 torch.distributed 后端的分布式同步数据收集器。 |
|
用于 submitit 的延迟启动器。 |
|
具有 Ray 后端的分布式数据收集器。 |
辅助函数¶
|
用于轨迹分离的实用程序函数。 |