Note
Go to the end to download the full example code.
Introduction to TorchRL¶
This demonstration was presented at ICML 2022 on the industrial demo day.
It gives a good overview of TorchRL's functionalities. Feel free to reach out to vmoens@fb.com or open an issue if you have questions or comments about it.
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
The PyTorch ecosystem team (Meta) has decided to invest in the library to provide a leading platform for developing RL solutions in research settings.
It provides pytorch- and python-first, low- and high-level abstractions for RL that are intended to be efficient, documented, and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them, or write new ones with little effort.
The repo attempts to align with the existing pytorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), and so on. TorchRL aims to have as few dependencies as possible (Python standard library, numpy, and pytorch). Common environment libraries (e.g. OpenAI gym) are only optional.
Unlike many other domains, RL is less about media than algorithms. As such, it is harder to make truly independent components.
What TorchRL is not:
- a collection of algorithms: we do not intend to provide SOTA (state-of-the-art) implementations of RL algorithms; we provide these algorithms only as examples of how to use the library.
- a research framework: modularity in TorchRL comes in two flavors. First, we try to build reusable components so that they can easily be swapped with each other. Second, we do our best to make components usable independently of the rest of the library.
TorchRL has very few core dependencies, predominantly PyTorch and numpy. All other dependencies (gym, torchvision, wandb / tensorboard) are optional.
Data¶
TensorDict¶
import torch
from tensordict import TensorDict
Let's create a TensorDict. The constructor accepts many different formats, such as passing a dict or using keyword arguments.
batch_size = 5
data = TensorDict(
key1=torch.zeros(batch_size, 3),
key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
batch_size=[batch_size],
)
print(data)
You can index a TensorDict along its batch_size, and you can also query its keys.
print(data[2])
print(data["key1"] is data.get("key1"))
Here is how to stack multiple TensorDicts. This is particularly useful when writing rollout loops!
data1 = TensorDict(
{
"key1": torch.zeros(batch_size, 1),
"key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
},
batch_size=[batch_size],
)
data2 = TensorDict(
{
"key1": torch.ones(batch_size, 1),
"key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
},
batch_size=[batch_size],
)
data = torch.stack([data1, data2], 0)
data.batch_size, data["key1"]
Here are some other functionalities of TensorDict: view, permute, shared memory, or expand.
print(
"view(-1): ",
data.view(-1).batch_size,
data.view(-1).get("key1").shape,
)
print("to device: ", data.to("cpu"))
# print("pin_memory: ", data.pin_memory())
print("share memory: ", data.share_memory_())
print(
"permute(1, 0): ",
data.permute(1, 0).batch_size,
data.permute(1, 0).get("key1").shape,
)
print(
"expand: ",
data.expand(3, *data.batch_size).batch_size,
data.expand(3, *data.batch_size).get("key1").shape,
)
You can also create nested data.
data = TensorDict(
source={
"key1": torch.zeros(batch_size, 3),
"key2": TensorDict(
source={"sub_key1": torch.zeros(batch_size, 2, 1)},
batch_size=[batch_size, 2],
),
},
batch_size=[batch_size],
)
data
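As a small supplementary example (not part of the original demo), nested entries can be retrieved with tuple-style keys:
print(data["key2", "sub_key1"].shape)  # index with a tuple key: torch.Size([5, 2, 1])
print(data.get(("key2", "sub_key1")).shape)  # equivalent getter form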
Replay Buffers¶
Replay buffers are a crucial component in many RL algorithms. TorchRL provides a range of replay buffer implementations. Most basic features will work with any data structure (lists, tuples, dicts), but to use the replay buffers to their full extent and with fast read and write access, the TensorDict API should be preferred.
from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer
rb = ReplayBuffer(collate_fn=lambda x: x)
Adding can be done with add() (n=1) or extend() (n>1).
rb.add(1)
rb.sample(1)
rb.extend([2, 3])
rb.sample(3)
A prioritized replay buffer can also be used:
rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)
Here is an example of using a replay buffer with stacked data (TensorDicts). Using them makes it easy to abstract away the replay buffer's behavior across multiple use cases.
collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())
torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer
rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
data_sample = rb.sample(2).contiguous()
print(data_sample)
print(data_sample["index"])
data_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(data_sample)
for i, val in enumerate(rb._sampler._sum_tree):
print(i, val)
if i == len(rb):
break
Environments¶
TorchRL provides a range of environment wrappers and utilities.
Gym Environments¶
try:
import gymnasium as gym
except ModuleNotFoundError:
import gym
from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend
gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")
data = env.reset()
env.rand_step(data)
Changing Environment Configs¶
env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()
env.close()
del env
from torchrl.envs import (
Compose,
NoopResetEnv,
ObservationNorm,
ToTensorImage,
TransformedEnv,
)
base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
Environment Transforms¶
Transforms act like Gym wrappers, but with an API closer to the transforms of torchvision and torch.distributions. There is a wide range of transforms to choose from.
from torchrl.envs import (
Compose,
NoopResetEnv,
ObservationNorm,
StepCounter,
ToTensorImage,
TransformedEnv,
)
base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print("env: ", env)
print("last transform parent: ", env.transform[2].parent)
Vectorized Environments¶
Vectorized / parallel environments can provide significant speed-ups.
from torchrl.envs import ParallelEnv
def make_env():
# You can control whether to use gym or gymnasium for your env
with set_gym_backend("gym"):
return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
base_env = ParallelEnv(
4,
make_env,
mp_start_method="fork", # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
base_env, Compose(StepCounter(), ToTensorImage())
) # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print(env.action_spec)
env.close()
del env
Modules¶
Multiple modules (utils, models, and wrappers) are available in the library.
Models¶
Example of an MLP model:
from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims
net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)
Example of a CNN model:
cnn = ConvNet(
num_cells=[32, 64],
kernel_sizes=[8, 4],
strides=[2, 1],
aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape) # last tensor is squashed
TensorDict Modules¶
Some modules are specifically designed to work with TensorDict inputs.
from tensordict.nn import TensorDictModule
data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
td_module(data)
print(data)
Sequences of Modules¶
TensorDictSequential makes it easy to build sequences of modules.
from tensordict.nn import TensorDictSequential
backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])
sequence = TensorDictSequential(backbone, actor, value)
print(sequence)
print(sequence.in_keys, sequence.out_keys)
data = TensorDict(
{"observation": torch.randn(3, 5)},
[3],
)
backbone(data)
actor(data)
value(data)
data = TensorDict(
{"observation": torch.randn(3, 5)},
[3],
)
sequence(data)
print(data)
Functional Programming (Ensembling / Meta-RL)¶
Functional calls have never been easier. Extract the parameters with from_module(), then replace them with to_module().
from tensordict import from_module
params = from_module(sequence)
print("extracted params", params)
Functional call using a TensorDict:
with params.to_module(sequence):
data = sequence(data)
VMAP¶
Executing multiple copies of a similar architecture quickly is key to training your models fast. vmap() is tailored to do exactly that.
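Below is a minimal, illustrative sketch (not part of the original demo) of combining vmap() with the parameters extracted above: the params TensorDict is expanded into 3 identical copies, and each copy runs the same sequence on a shared batch of observations. The exec_sequence helper, the number of copies, and the shapes are arbitrary choices for illustration.
from torch import vmap
params_expand = params.expand(3)  # 3 identical copies of the parameters
obs = torch.randn(3, 5)  # a shared batch of observations
def exec_sequence(params, obs):
    # build a fresh TensorDict per call so no shared state is modified
    td = TensorDict({"observation": obs}, batch_size=[obs.shape[0]])
    with params.to_module(sequence):
        return sequence(td)["value"]
values = vmap(exec_sequence, (0, None))(params_expand, obs)
print(values.shape)  # expected shape: [3, 3, 1], i.e. 3 parameter sets x 3 observations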
Specialized Classes¶
TorchRL also provides some specialized modules that run checks on the output values.
torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule
spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(data)["action"]
data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(data)["action"] # safe=True projects the result within the set
The Actor class has a predefined output key ("action").
from torchrl.modules import Actor
base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(data) # action is the default value
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
Working with probabilistic models is also made easy thanks to the tensordict.nn API.
from torchrl.modules import NormalParamExtractor, TanhNormal
td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
nn.Linear(5, 4), NormalParamExtractor()
) # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
module,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=False,
),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
module,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=True,
),
)
td_module(td)
print(td)
Controlling randomness and the sampling strategy is done via the set_exploration_type context manager.
from torchrl.envs.utils import ExplorationType, set_exploration_type
td = TensorDict({"input": torch.randn(3, 5)}, [3])
torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
td_module(td)
print("random:", td["action"])
with set_exploration_type(ExplorationType.DETERMINISTIC):
td_module(td)
print("mode:", td["action"])
Using Environments and Modules¶
Let's see how environments and modules can be combined.
from torchrl.envs.utils import step_mdp
env = GymEnv("Pendulum-v1")
action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
data = env.reset()
data_stack = TensorDict(batch_size=[max_steps])
for i in range(max_steps):
actor(data)
data_stack[i] = env.step(data)
if data["done"].any():
break
data = step_mdp(data) # roughly equivalent to obs = next_obs
tensordicts_prealloc = data_stack.clone()
print("total steps:", i)
print(data_stack)
# equivalent
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
data = env.reset()
data_stack = []
for i in range(max_steps):
actor(data)
data_stack.append(env.step(data))
if data["done"].any():
break
data = step_mdp(data) # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(data_stack, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout
(tensordict_rollout == tensordicts_prealloc).all()
from tensordict.nn import TensorDictModule
Collectors¶
We also provide a set of data collectors that automatically gather as many frames per batch as required. They work in settings ranging from single-node, single-worker to multi-node, multi-worker.
from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv
EnvCreator makes sure that we can send a lambda function from process to process. We use a SerialEnv for simplicity (single worker), but for larger jobs a ParallelEnv (multiple workers) would be better suited.
Note
Multiprocessed environments and multiprocessed collectors can be combined!
parallel_env = SerialEnv(
3,
EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
Synchronous Multiprocessed Data Collector¶
devices = ["cpu", "cpu"]
collector = MultiSyncDataCollector(
create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv
policy=actor,
total_frames=240,
max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early
frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector)
device=devices,
)
for i, d in enumerate(collector):
if i == 0:
print(d) # trajectories are split automatically in [6 workers x 10 steps]
collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
Asynchronous Multiprocessed Data Collector¶
This class allows you to collect data while the model is training. This is particularly useful in off-policy settings, as it decouples inference from model training. Data is delivered on a first-ready-first-served basis (workers will queue their results).
collector = MultiaSyncDataCollector(
create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv
policy=actor,
total_frames=240,
max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early
frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector)
device=devices,
)
for i, d in enumerate(collector):
if i == 0:
print(d) # trajectories are split automatically in [6 workers x 10 steps]
collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env
Objectives¶
Objectives are the main entry point when coding up a new algorithm.
from torchrl.objectives import DDPGLoss
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
class ConcatModule(nn.Linear):
def forward(self, obs, action):
return super().forward(torch.cat([obs, action], -1))
value_module = ConcatModule(4, 1)
value = TensorDictModule(
value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)
loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
data = TensorDict(
{
"observation": torch.randn(10, 3),
"next": {
"observation": torch.randn(10, 3),
"reward": torch.randn(10, 1),
"done": torch.zeros(10, 1, dtype=torch.bool),
},
"action": torch.randn(10, 1),
},
batch_size=[10],
device="cpu",
)
loss_td = loss_fn(data)
print(loss_td)
print(data)
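For completeness, here is a minimal optimization-step sketch (not part of the original example). It assumes the default DDPGLoss output, in which the terms to optimize are prefixed with "loss_" (other entries are diagnostics); the learning rate is an arbitrary illustrative value.
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
# sum only the "loss_*" entries; other keys (e.g. diagnostics) are ignored
total_loss = sum(value for key, value in loss_td.items() if key.startswith("loss_"))
total_loss.backward()
optim.step()
optim.zero_grad()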
Installing the Library¶
The library is published on PyPI: pip install torchrl. See the README for more information.
Contributing¶
We are actively looking for contributors and early users. If you're working in RL (or just curious), give it a try! Give us feedback: what will make TorchRL successful is how well it addresses researchers' needs. To do that, we need their input! Since the library is in its infancy, it is a great time for you to shape it the way you want!
For more information, see the Contributing guide.