• 文档 >
  • 导出 TorchRL 模块
快捷方式

导出 TorchRL 模块

作者: Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元格,包含

!pip install tensordict
!pip install torchrl
!pip install "gymnasium[atari,accept-rom-license]"<1.0.0

简介

如果在实际场景中无法部署所学到的策略,那么学习策略的价值就很小。正如其他教程所示,TorchRL 非常注重模块化和可组合性:得益于 tensordict,库的组件可以以最通用的方式编写,只需将其签名抽象为对输入 TensorDict 的一组操作即可。这可能会给人一种印象,即该库仅用于训练,因为典型的低级执行硬件(边缘设备、机器人、arduino、Raspberry Pi)不执行 python 代码,更不用说安装 pytorch、tensordict 或 torchrl 了。

幸运的是,PyTorch 提供了一整套生态系统解决方案,用于将代码和训练好的模型导出到设备和硬件上,并且 TorchRL 完全能够与其交互。可以选择多种后端,包括本教程中示例的 ONNX 或 AOTInductor。本教程简要介绍了如何将训练好的模型隔离并作为独立可执行文件分发,以便导出到硬件上。

要点回顾

  • 训练后导出任何 TorchRL 模块;

  • 使用各种后端;

  • 测试导出的模型。

快速回顾:一个简单的 TorchRL 训练循环

在本节中,我们重现了上一个入门教程中的训练循环,并稍作修改以用于由 gymnasium 库渲染的 Atari 游戏。我们将继续使用 DQN 示例,并稍后展示如何使用输出价值分布的策略来代替它。

import time
from pathlib import Path

import numpy as np

import torch

from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictSequential,
    TensorDictSequential as Seq,
)

from torch.optim import Adam

from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    Resize,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

from torchrl.modules import ConvNet, EGreedyModule, QValueModule

from torchrl.objectives import DQNLoss, SoftUpdate

torch.manual_seed(0)

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
    Compose(
        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
    ),
)
env.set_seed(0)

value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters())
updater = SoftUpdate(loss, eps=0.99)

total_count = 0
total_episodes = 0
t0 = time.time()
for data in collector:
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

导出基于 TensorDictModule 的策略

TensorDict 允许我们以极大的灵活性构建策略:从一个输出观察对应的动作价值的常规 Module,我们添加了一个 QValueModule 模块,该模块读取这些价值并使用某种启发式方法(例如 argmax 调用)计算动作。

然而,在我们的案例中有一个小小的技术细节:环境(实际的 Atari 游戏)不返回灰度、84x84 的图像,而是原始屏幕尺寸的彩色图像。我们附加到环境的 transforms 确保图像可以被模型读取。我们可以看到,从训练的角度来看,环境和模型之间的界限是模糊的,但在执行时事情就清晰多了:模型应该负责将输入数据(图像)转换成可以被我们的 CNN 处理的格式。

在这里,tensordict 的魔力再次为我们扫清障碍:碰巧大多数局部的(非递归的)TorchRL transforms 既可以用作环境 transforms,也可以用作 Module 实例内的预处理块。让我们看看如何将它们前置到我们的策略中

policy_transform = TensorDictSequential(
    env.transform[
        :-1
    ],  # the last transform is a step counter which we don't need for preproc
    policy_explore.requires_grad_(
        False
    ),  # Using the explorative version of the policy for didactic purposes, see below.
)

我们创建一个伪输入,并将其与策略一起传递给 export()。这将生成一个“原始”的 python 函数,该函数将读取我们的输入张量并输出一个动作,不包含任何对 TorchRL 或 tensordict 模块的引用。

一个好的做法是调用 select_out_keys() 来告知模型我们只需要特定的输出集合(以防策略返回多个张量)。

fake_td = env.base_env.fake_tensordict()
pixels = fake_td["pixels"]
with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        # Select only the "action" output key
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

可视化策略可能会非常有启发性:我们可以看到第一个操作是 permute、div、unsqueeze、resize,接着是卷积层和 MLP 层。

print("Deterministic policy")
exported_policy.graph_module.print_readable()

作为最后检查,我们可以使用一个虚拟输入执行策略。输出(对于单个图像)应该是一个从 0 到 6 的整数,表示在游戏中要执行的动作。

output = exported_policy.module()(pixels=pixels)
print("Exported module output", output)

关于导出 TensorDictModule 实例的更多详细信息,请参阅 tensordict 文档

注意

导出接受和输出嵌套键的模块是完全可以的。对应的 kwargs 将是键的 “_”.join(key) 版本,即 (“group0”, “agent0”, “obs”) 键将对应于 “group0_agent0_obs” 关键字参数。键冲突(例如 (“group0_agent0”, “obs”)(“group0”, “agent0_obs”))可能导致未定义的行为,应不惜一切代价避免。显然,键名也应始终产生有效的关键字参数,即它们不应包含空格或逗号等特殊字符。

torch.export 还有许多其他特性,我们将在下文进一步探讨。在此之前,让我们先简要探讨一下在测试时推理环境下的探索和随机策略,以及循环策略。

使用随机策略

您可能已经注意到,上面我们使用了 set_exploration_type 上下文管理器来控制策略的行为。如果策略是随机的(例如,策略输出动作空间的分布,就像 PPO 或其他类似的 on-policy 算法中那样)或具有探索性(附加了探索模块,如 E-Greedy、加性高斯或 Ornstein-Uhlenbeck),我们可能希望或不希望在导出的版本中使用该探索策略。幸运的是,导出工具可以理解该上下文管理器,并且只要导出发生在正确的上下文管理器中,策略的行为就应该与所示一致。为了演示这一点,让我们尝试另一种探索类型

with set_exploration_type("RANDOM"):
    exported_stochastic_policy = torch.export.export(
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

与之前的版本不同,我们导出的策略现在应该在调用栈的末尾包含一个随机模块。实际上,最后的三个操作是:生成一个介于 0 到 6 之间的随机整数,使用一个随机掩码,并根据掩码中的值选择网络输出或随机动作。

print("Stochastic policy")
exported_stochastic_policy.graph_module.print_readable()

使用循环策略

另一种典型的用例是循环策略,它将输出一个动作以及一个或多个循环状态。LSTM 和 GRU 是基于 CuDNN 的模块,这意味着它们的行为会与常规 Module 实例不同(导出工具可能无法很好地跟踪它们)。幸运的是,TorchRL 提供了这些模块的 Python 实现,可以在需要时与 CuDNN 版本互换使用。

为了展示这一点,让我们编写一个依赖于 RNN 的原型策略

from tensordict.nn import TensorDictModule
from torchrl.envs import BatchSizeTransform
from torchrl.modules import LSTMModule, MLP

lstm = LSTMModule(
    input_size=32,
    num_layers=2,
    hidden_size=256,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", "hidden0", "hidden1"],
)

如果 LSTM 模块不是基于 Python 而是 CuDNN 的(LSTM),则可以使用 make_python_based() 方法来使用 Python 版本。

lstm = lstm.make_python_based()

现在我们来创建策略。我们将两个修改输入形状的层(unsqueeze/squeeze 操作)与 LSTM 和一个 MLP 结合起来。

recurrent_policy = TensorDictSequential(
    # Unsqueeze the first dim of all tensors to make LSTMCell happy
    BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
    lstm,
    TensorDictModule(
        MLP(in_features=256, out_features=5, num_cells=[64, 64]),
        in_keys=["intermediate"],
        out_keys=["action"],
    ),
    # Squeeze the first dim of all tensors to get the original shape back
    BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
)

和之前一样,我们选择相关的键

recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
print("recurrent policy input keys:", recurrent_policy.in_keys)
print("recurrent policy output keys:", recurrent_policy.out_keys)

现在我们准备好导出了。为此,我们构建伪输入并将其传递给 export()

fake_obs = torch.randn(32)
fake_hidden0 = torch.randn(2, 256)
fake_hidden1 = torch.randn(2, 256)

# Tensor indicating whether the state is the first of a sequence
fake_is_init = torch.zeros((), dtype=torch.bool)

exported_recurrent_policy = torch.export.export(
    recurrent_policy,
    args=(),
    kwargs={
        "observation": fake_obs,
        "hidden0": fake_hidden0,
        "hidden1": fake_hidden1,
        "is_init": fake_is_init,
    },
    strict=False,
)
print("Recurrent policy graph:")
exported_recurrent_policy.graph_module.print_readable()

AOTInductor:将策略导出到不依赖 PyTorch 的 C++ 二进制文件

AOTInductor 是一个 PyTorch 模块,允许您将模型(策略或其他)导出到不依赖 PyTorch 的 C++ 二进制文件。当您需要在没有安装 PyTorch 的设备或平台上部署模型时,这特别有用。

这是一个如何使用 AOTInductor 导出策略的示例,灵感来自 AOTI 文档

from tempfile import TemporaryDirectory

from torch._inductor import aoti_compile_and_package, aoti_load_package

with TemporaryDirectory() as tmpdir:
    path = str(Path(tmpdir) / "model.pt2")
    with torch.no_grad():
        pkg_path = aoti_compile_and_package(
            exported_policy,
            # Specify the generated shared library path
            package_path=path,
        )
    print("pkg_path", pkg_path)

    compiled_module = aoti_load_package(pkg_path)

print(compiled_module(pixels=pixels))

使用 ONNX 导出 TorchRL 模型

注意

要执行脚本的这一部分,请确保已安装 pytorch onnx

!pip install onnx-pytorch
!pip install onnxruntime

您还可以在 PyTorch 生态系统中找到更多关于使用 ONNX 的信息 在此。以下示例基于此文档。

在本节中,我们将展示如何以不依赖 PyTorch 的方式导出我们的模型,使其可以在没有安装 PyTorch 的环境中执行。

网络上有很多资源解释了如何使用 ONNX 将 PyTorch 模型部署到各种硬件和设备上,包括 Raspberry PiNVIDIA TensorRTiOSAndroid

我们训练所用的 Atari 游戏可以使用 ALE 库 在没有 TorchRL 或 gymnasium 的情况下独立运行,因此为我们提供了一个关于使用 ONNX 可以实现什么的良好示例。

让我们看看这个 API 的样子

from ale_py import ALEInterface, roms

# Create the interface
ale = ALEInterface()
# Load the pong environment
ale.loadROM(roms.Pong)
ale.reset_game()

# Make a step in the simulator
action = 0
reward = ale.act(action)
screen_obs = ale.getScreenRGB()
print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)

from matplotlib import pyplot as plt

plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.imshow(screen_obs)
plt.title("Screen rendering of Pong game.")

导出到 ONNX 与上面的 Export/AOTI 非常相似

import onnxruntime

with set_exploration_type("DETERMINISTIC"):
    # We use torch.onnx.dynamo_export to capture the computation graph from our policy_explore model
    pixels = torch.as_tensor(screen_obs)
    onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)

现在我们可以将程序保存到磁盘并加载它

with TemporaryDirectory() as tmpdir:
    onnx_file_path = str(Path(tmpdir) / "policy.onnx")
    onnx_policy_export.save(onnx_file_path)

    ort_session = onnxruntime.InferenceSession(
        onnx_file_path, providers=["CPUExecutionProvider"]
    )

onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
onnx_policy = ort_session.run(None, onnxruntime_input)

使用 ONNX 运行 rollout

我们现在有了一个可以运行我们策略的 ONNX 模型。让我们将其与原始的 TorchRL 实例进行比较:由于 ONNX 版本更轻量,它应该比 TorchRL 版本更快。

def onnx_policy(screen_obs: np.ndarray) -> int:
    onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
    action = int(onnxruntime_outputs[0])
    return action


with timeit("ONNX rollout"):
    num_steps = 1000
    ale.reset_game()
    for _ in range(num_steps):
        screen_obs = ale.getScreenRGB()
        action = onnx_policy(screen_obs)
        reward = ale.act(action)

with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
    env.rollout(num_steps, policy_explore)

print(timeit.print())

请注意,ONNX 也提供了直接优化模型的可能性,但这超出了本教程的范围。

结论

在本教程中,我们学习了如何使用各种后端导出 TorchRL 模块,例如 PyTorch 内置的导出功能、AOTInductorONNX。我们演示了如何导出在 Atari 游戏上训练的策略,并使用 ALE 库在不依赖 PyTorch 的环境中运行它。我们还比较了原始 TorchRL 实例与导出的 ONNX 模型的性能。

主要收获

  • 导出 TorchRL 模块可以在未安装 PyTorch 的设备上进行部署。

  • AOTInductor 和 ONNX 提供了用于导出模型的替代后端。

  • 优化 ONNX 模型可以提高性能。

进一步阅读和学习步骤

  • 查阅 PyTorch 关于 导出功能AOTInductorONNX 的官方文档,了解更多信息。

  • 尝试在不同设备上部署导出的模型。

  • 探索 ONNX 模型的优化技术以提高性能。

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源