注意
转到末尾 下载完整的示例代码。
开始你的第一个训练循环¶
作者: Vincent Moens
注意
要在 notebook 中运行本教程,请在开头添加一个安装单元,其中包含
!pip install tensordict !pip install torchrl
是时候总结我们在本入门系列中学到的一切了!
在本教程中,我们将编写最基本的训练循环,仅使用我们在之前的课程中介绍的组件。
我们将使用带有 CartPole 环境的 DQN 作为原型示例。
我们将自愿地将冗长程度降至最低,仅将每个部分链接到相关的教程。
构建环境¶
我们将使用带有 StepCounter
转换的 gym 环境。 如果您需要复习,请查看 环境教程 中介绍的这些功能。
import torch
torch.manual_seed(0)
import time
from torchrl.envs import GymEnv, StepCounter, TransformedEnv
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
设计策略¶
下一步是构建我们的策略。我们将制作一个常规的、确定性的 actor 版本,用于 损失模块 中和 评估 期间。接下来,我们将使用探索模块增强它,用于 推理。
from torchrl.modules import EGreedyModule, MLP, QValueModule
value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)
数据收集器和回放缓冲区¶
数据部分来了:我们需要一个 数据收集器 来轻松获取数据批次,以及一个 回放缓冲区 来存储用于训练的数据。
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
env,
policy_explore,
frames_per_batch=frames_per_batch,
total_frames=-1,
init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))
from torch.optim import Adam
损失模块和优化器¶
我们按照 专用教程 中指示的方式构建损失函数,包括其优化器和目标参数更新器
from torchrl.objectives import DQNLoss, SoftUpdate
loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)
日志记录器¶
我们将使用 CSV 日志记录器来记录我们的结果,并保存渲染的视频。
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder
path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)
训练循环¶
与其固定要运行的特定迭代次数,不如继续训练网络,直到它达到一定的性能(任意定义为环境中的 200 步 - 对于 CartPole,成功被定义为拥有更长的轨迹)。
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
# Write data in replay buffer
rb.extend(data)
max_length = rb[:]["next", "step_count"].max()
if len(rb) > init_rand_steps:
# Optim loop (we do several optim steps
# per batch collected for efficiency)
for _ in range(optim_steps):
sample = rb.sample(128)
loss_vals = loss(sample)
loss_vals["loss"].backward()
optim.step()
optim.zero_grad()
# Update exploration factor
exploration_module.step(data.numel())
# Update target params
updater.step()
if i % 10:
torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
total_count += data.numel()
total_episodes += data["next", "done"].sum()
if max_length > 200:
break
t1 = time.time()
torchrl_logger.info(
f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)
渲染¶
最后,我们尽可能多地运行环境步骤,并将视频本地保存(请注意,我们没有探索)。
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()
这是完整的训练循环后,渲染的 CartPole 视频的样子

这结束了我们的“TorchRL 入门”系列教程! 欢迎在 GitHub 上分享有关它的反馈。
脚本的总运行时间: (0 分钟 22.853 秒)
预计内存使用量: 323 MB