• 文档 >
  • 开始你的第一个训练循环
快捷方式

开始你的第一个训练循环

作者: Vincent Moens

注意

要在 notebook 中运行本教程,请在开头添加一个安装单元,其中包含

!pip install tensordict
!pip install torchrl

是时候总结我们在本入门系列中学到的一切了!

在本教程中,我们将编写最基本的训练循环,仅使用我们在之前的课程中介绍的组件。

我们将使用带有 CartPole 环境的 DQN 作为原型示例。

我们将自愿地将冗长程度降至最低,仅将每个部分链接到相关的教程。

构建环境

我们将使用带有 StepCounter 转换的 gym 环境。 如果您需要复习,请查看 环境教程 中介绍的这些功能。

import torch

torch.manual_seed(0)

import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

设计策略

下一步是构建我们的策略。我们将制作一个常规的、确定性的 actor 版本,用于 损失模块 中和 评估 期间。接下来,我们将使用探索模块增强它,用于 推理

from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

数据收集器和回放缓冲区

数据部分来了:我们需要一个 数据收集器 来轻松获取数据批次,以及一个 回放缓冲区 来存储用于训练的数据。

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam

损失模块和优化器

我们按照 专用教程 中指示的方式构建损失函数,包括其优化器和目标参数更新器

from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

日志记录器

我们将使用 CSV 日志记录器来记录我们的结果,并保存渲染的视频。

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)

训练循环

与其固定要运行的特定迭代次数,不如继续训练网络,直到它达到一定的性能(任意定义为环境中的 200 步 - 对于 CartPole,成功被定义为拥有更长的轨迹)。

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

渲染

最后,我们尽可能多地运行环境步骤,并将视频本地保存(请注意,我们没有探索)。

record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

这是完整的训练循环后,渲染的 CartPole 视频的样子

../_images/cartpole.gif

这结束了我们的“TorchRL 入门”系列教程! 欢迎在 GitHub 上分享有关它的反馈。

脚本的总运行时间: (0 分钟 22.853 秒)

预计内存使用量: 323 MB

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源