• 文档 >
  • 开始使用您自己的第一个训练循环
快捷方式

开始使用您自己的第一个训练循环

**作者**:Vincent Moens

注意

要在笔记本中运行本教程,请在开头添加一个包含以下内容的安装单元格

!pip install tensordict
!pip install torchrl

是时候总结我们在本入门系列中所学到的所有内容了!

在本教程中,我们将使用仅包含我们在先前课程中介绍过的组件来编写最基本的训练循环。

我们将使用 DQN 和 CartPole 环境作为原型示例。

我们将自愿将详细程度降至最低,仅将每个部分链接到相关教程。

构建环境

我们将使用带有 StepCounter 转换的 gym 环境。如果您需要复习,请查看 环境教程 中介绍的这些功能。

import torch

torch.manual_seed(0)

import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

设计策略

下一步是构建我们的策略。我们将制作一个常规的、确定性的演员版本,将在 损失模块 中和 评估 期间使用。接下来,我们将使用探索模块对其进行增强以进行 推理

from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

数据收集器和回放缓冲区

数据部分来了:我们需要一个 数据收集器 来轻松获取数据批次,以及一个 回放缓冲区 来存储这些数据以进行训练。

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam

损失模块和优化器

我们按照 专用教程 中的指示构建我们的损失,并使用其优化器和目标参数更新器

from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

日志记录器

我们将使用 CSV 日志记录器来记录我们的结果并保存渲染的视频。

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)

训练循环

我们将继续训练网络,直到它达到一定性能(任意定义为环境中的 200 步,在 CartPole 中,成功定义为具有更长的轨迹),而不是固定运行的特定迭代次数。

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

渲染

最后,我们运行环境尽可能多的步数,并将视频保存在本地(请注意,我们没有探索)。

record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

完成完整的训练循环后,您的渲染的 CartPole 视频将如下所示

../_images/cartpole.gif

这结束了我们“TorchRL 入门”教程系列!随时在 GitHub 上分享您的反馈。

**脚本的总运行时间:**(0 分钟 20.760 秒)

**估计内存使用量:** 321 MB

Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源