TorchRL trainer: A DQN example

Author: Vincent Moens

TorchRL provides a generic Trainer class to handle your training loop. The trainer executes a nested loop where the outer loop is the data collection and the inner loop consumes this data, or some data retrieved from the replay buffer, to train the model. At various points in this training loop, hooks can be attached and executed at given intervals.
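
Conceptually, that nested loop looks roughly like the sketch below. This is an illustrative pseudo-structure only, not the actual Trainer implementation; the hook names "post_optim" and "post_steps" are the ones used further down in this tutorial, the other names are placeholders.

from collections import defaultdict

hooks = defaultdict(list)  # hook location -> list of callables


def register_op(location, fn):
    hooks[location].append(fn)


def run_hooks(location, *args):
    for fn in hooks[location]:
        fn(*args)


def train(collector, loss_fn, optimizer, optim_steps_per_batch):
    for batch in collector:  # outer loop: data collection
        run_hooks("batch_process", batch)  # e.g. extend a replay buffer
        for _ in range(optim_steps_per_batch):  # inner loop: consume the data
            loss = loss_fn(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            run_hooks("post_optim")  # e.g. update the target network
        run_hooks("post_steps")  # e.g. anneal exploration, sync collector weights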

In this tutorial, we will be using the trainer class to train a DQN algorithm to solve the CartPole task from scratch.

Main takeaways:

  • Building a trainer with its essential components: data collector, loss module, replay buffer and optimizer.

  • Adding hooks to a trainer, such as loggers, target network updaters and such.

The trainer is fully customizable and offers a large set of functionalities. The tutorial is organized around its construction. We will first detail how to build each of the components of the library, and then put the pieces together using the Trainer class.

Along the road, we will also focus on some other aspects of the library:

  • how to build an environment in TorchRL, including transforms (e.g. data normalization, frame concatenation, resizing and turning to grayscale) and parallel execution. Unlike what we did in the DDPG tutorial, we will normalize the pixels and not the state vector;

  • how to design a QValueActor object, i.e. an actor that estimates the action values and picks the action with the highest estimated return;

  • how to collect data from your environment efficiently and store it in a replay buffer;

  • how to use multi-step, a simple preprocessing step for off-policy algorithms;

  • and finally how to evaluate your model.

Prerequisites: We encourage you to get familiar with torchrl through the PPO tutorial first.

DQN

DQN (Deep Q-Learning) was the founding work in deep reinforcement learning.

On a high level, the algorithm is quite simple: Q-learning consists in learning a table of state-action values such that, when encountering any particular state, we know which action to pick just by searching for the action with the highest value. This simple setting requires the actions and states to be discrete, otherwise a lookup table cannot be built.

DQN uses a neural network that encodes a map from the state-action space to a value (scalar) space, which amortizes the cost of storing and exploring all the possible state-action combinations: if a state has not been seen in the past, we can still pass it through our neural network along with the various actions available and get an interpolated value for each of them.
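
As a minimal, standalone illustration (independent of the TorchRL components used below), greedy action selection with such a Q-value network boils down to an argmax over the predicted action values:

import torch
from torch import nn

# A toy Q-network: maps an observation vector to one value per discrete action.
toy_q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

obs = torch.randn(1, 4)                   # a (batched) observation
action_values = toy_q_net(obs)            # shape [1, num_actions]
greedy_action = action_values.argmax(-1)  # the action with the highest estimated return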

We will solve the classic control problem of the cart-pole. From the Gymnasium documentation, from which this environment is taken:

A pole is attached by an un-actuated joint to a cart, which moves along a
frictionless track. The pendulum is placed upright on the cart and the goal
is to balance the pole by applying forces in the left and right direction
on the cart.
Cart Pole

Our aim is not to give a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL's features in the context of this algorithm.

import multiprocessing
import os
import tempfile
import uuid
import warnings

import torch
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor

from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)


def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter

Let's start with the various pieces we need for our algorithm:

  • an environment;

  • a policy (and related modules that we group under the "model" umbrella);

  • a data collector, which makes the policy play in the environment and delivers training data;

  • a replay buffer to store the training data;

  • a loss module, which computes the objective function to train our policy to maximize the return;

  • an optimizer, which performs the parameter updates given our loss.

Additional modules include a logger, a recorder (which executes the policy in "eval" mode) and a target network updater. With all these components, it is easy to see how one could misplace or misuse a component in the training script. The trainer is there to orchestrate everything for you!

Building the environment

First, let's write a helper function that will output an environment. As usual, the "raw" environment may be too simple to be used in practice and we will need some data transformations to expose its output to the policy.

We will be using the following transforms:

  • StepCounter to count the number of steps in each trajectory;

  • ToTensorImage will convert a [W, H, C] uint8 tensor into a floating-point tensor in the [0, 1] space with shape [C, W, H];

  • RewardScaling to reduce the scale of the return;

  • GrayScale will turn our image into grayscale;

  • Resize will resize the image into a 64x64 format;

  • CatFrames will concatenate an arbitrary number of successive frames (N=4) in a single tensor along the channel dimension. This is useful as a single image does not carry information about the motion of the cart-pole. Some memory about past observations and actions is needed, either via a recurrent neural network or by using a stack of frames;

  • ObservationNorm, which will normalize our observations given some custom summary statistics.

In practice, our environment builder has two arguments:

  • parallel: determines whether multiple environments have to be run in parallel. We stack the transforms after the ParallelEnv to take advantage of vectorization of the operations on device, although this would technically work with every single environment attached to its own set of transforms.

  • obs_norm_sd will contain the normalizing constants for the ObservationNorm transform.

def make_env(
    parallel=False,
    obs_norm_sd=None,
    num_workers=1,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:

        def maker():
            return GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        base_env = ParallelEnv(
            num_workers,
            EnvCreator(maker),
            # Don't create a sub-process if we have only one worker
            serial_for_single=True,
            mp_start_method=mp_context,  # defined alongside the hyperparameters below
        )
    else:
        base_env = GymEnv(
            "CartPole-v1",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )

    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),
            GrayScale(),
            Resize(64, 64),
            CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env

Computing normalizing constants

To normalize images, we do not want to normalize each pixel independently with a full [C, W, H] normalizing mask, but with simpler [C, 1, 1]-shaped sets of normalizing constants (loc and scale parameters). We will be using the reduce_dim argument of init_stats() to instruct which dimensions must be reduced, and the keep_dims parameter to make sure that not all dimensions disappear in the process:

def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    del test_env
    return obs_norm_sd
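
As a side note, with standard_normal=True the ObservationNorm transform applies (to the best of our understanding) a standardization of the form (obs - loc) / scale, with loc and scale broadcast from their [C, 1, 1] shape. A minimal sketch of that operation, with illustrative values:

import torch

# Sketch of the normalization performed with constants like the ones computed
# above (the real work is done by ObservationNorm; values are illustrative).
loc = torch.full((4, 1, 1), 0.99)
scale = torch.full((4, 1, 1), 0.07)
pixels = torch.rand(4, 64, 64)        # a stack of 4 grayscale frames
normalized = (pixels - loc) / scale   # broadcast over height and width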

Building the model (Deep Q-network)

The following function builds a DuelingCnnDQNet object, which is a simple CNN followed by a two-layer MLP. The only trick used here is that the action values (i.e. the left and right action values) are computed using

\[\mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)]\]

where \(\mathbb{v}\) is our vector of action values, \(b\) is a \(\mathbb{R}^n \rightarrow \mathbb{R}\) function and \(v\) is a \(\mathbb{R}^n \rightarrow \mathbb{R}^m\) function, with \(n = \# obs\) and \(m = \# actions\).
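
Written out in plain PyTorch, the aggregation above amounts to the following sketch (an illustration of the formula only, not the DuelingCnnDQNet code):

import torch

batch, num_actions = 8, 2
state_value = torch.randn(batch, 1)          # b(obs): one scalar per state
advantage = torch.randn(batch, num_actions)  # v(obs): one value per action
# re-center the advantage so that its mean is absorbed into the state value
action_values = state_value + advantage - advantage.mean(dim=-1, keepdim=True)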

Our network is wrapped in a QValueActor, which will read the state-action values, pick the one with the maximum value and write all these results in the input tensordict.TensorDict.

def make_model(dummy_env):
    cnn_kwargs = {
        "num_cells": [32, 64, 64],
        "kernel_sizes": [6, 4, 3],
        "strides": [2, 2, 1],
        "activation_class": nn.ELU,
        # This can be used to reduce the size of the last layer of the CNN
        # "squeeze_output": True,
        # "aggregator_class": nn.AdaptiveAvgPool2d,
        # "aggregator_kwargs": {"output_size": (1, 1)},
    }
    mlp_kwargs = {
        "depth": 2,
        "num_cells": [
            64,
            64,
        ],
        "activation_class": nn.ELU,
    }
    net = DuelingCnnDQNet(
        dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
    ).to(device)
    net.value[-1].bias.data.fill_(init_bias)

    actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
    # init actor: because the model is composed of lazy conv/linear layers,
    # we must pass a fake batch of data through it to instantiate them.
    tensordict = dummy_env.fake_tensordict()
    actor(tensordict)

    # we join our actor with an EGreedyModule for data collection
    exploration_module = EGreedyModule(
        spec=dummy_env.action_spec,
        annealing_num_steps=total_frames,
        eps_init=eps_greedy_val,
        eps_end=eps_greedy_val_env,
    )
    actor_explore = TensorDictSequential(actor, exploration_module)

    return actor, actor_explore

Collecting and storing data

Replay buffers

Replay buffers play a central role in off-policy RL algorithms such as DQN. They constitute the dataset we will be sampling from during training.

Here, we will use a regular sampling strategy, although a prioritized replay buffer could improve the performance significantly.

We place the storage on disk using the LazyMemmapStorage class. This storage is created in a lazy manner: it will only be instantiated once the first batch of data is passed to it.

The only requirement of this storage is that the data passed to it at write time must always have the same shape.

def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer
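
Outside of the trainer (where the ReplayBufferTrainer hook takes care of this), the buffer would be used roughly as in the following sketch: batches of tensordicts are written with extend() and transitions are read back with sample(). The content of the fake batch is illustrative only.

import torch
from tensordict import TensorDict

sketch_rb = get_replay_buffer(buffer_size=1000, n_optim=None, batch_size=16)
fake_batch = TensorDict(
    {"pixels": torch.rand(32, 4, 64, 64)},  # illustrative content only
    batch_size=[32],
)
sketch_rb.extend(fake_batch)   # write a collected batch
sample = sketch_rb.sample()    # read ``batch_size`` transitions back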

Data collector

As in PPO and DDPG, we will be using a data collector as a dataloader in the outer loop.

We choose the following configuration: we will be running a series of parallel environments synchronously in parallel in different collectors, themselves running in parallel but asynchronously.

Note

This feature is only available when running the code within the "spawn" start method of the Python multiprocessing library. If this tutorial is run directly as a script (thereby using the "fork" method), we will be using a regular SyncDataCollector.

This configuration allows us to balance the amount of compute executed in batch with the amount we want to execute asynchronously. We encourage the reader to experiment with how the collection speed is impacted by modifying the number of collectors (i.e. the number of environment constructors passed to the collector) and the number of environments executed in parallel in each collector (controlled by the num_workers hyperparameter).

The collector's devices are fully parametrizable through the device (general), policy_device, env_device and storing_device arguments. The storing_device argument will modify the location of the data being collected: if the batches we are gathering have a considerable size, we may want to store them in a different location than the device where the computation happens. For asynchronous data collectors such as ours, different storing devices mean that the data we collect will not sit on the same device each time, which is something our training loop must account for. For simplicity, we set the devices to the same value for all sub-collectors.

def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    # We can't use nested child processes with mp_start_method="fork"
    if is_fork:
        cls = SyncDataCollector
        env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
    else:
        cls = MultiaSyncDataCollector
        env_arg = [
            make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
        ] * num_collectors
    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behavior: the collector runs in ``"random"`` (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set the all the devices to be identical. Below is an example of
        # heterogeneous devices
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector
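
As hinted at in the comment above, the devices do not all have to match. The sketch below (not executed in this tutorial; it assumes a CUDA device and the actor, stats and hyperparameters built later on) shows what a heterogeneous assignment could look like, using the policy_device, env_device and storing_device arguments mentioned earlier:

def make_hetero_collector(stats, actor_explore, frames_per_batch, total_frames):
    # Illustrative sketch only: step the environments on CPU, run the policy
    # on GPU and store the collected batches back on CPU.
    return SyncDataCollector(
        make_env(parallel=False, obs_norm_sd=stats),
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        env_device="cpu",
        policy_device="cuda:0",
        storing_device="cpu",
    )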

Loss function

Building our loss function is straightforward: we only need to provide the model and a bunch of hyperparameters to the DQNLoss class.

Target parameters

Many off-policy RL algorithms use the concept of "target parameters" when it comes to estimating the value of the next state or state-action pair. The target parameters are lagged copies of the model parameters. Because their predictions mismatch those of the current model configuration, they help learning by putting a pessimistic bound on the value being estimated. This is a powerful trick (known as "Double Q-Learning") that is ubiquitous in similar algorithms.

def get_loss_module(actor, gamma):
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=gamma)
    target_updater = SoftUpdate(loss_module, eps=0.995)
    return loss_module, target_updater
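
Under the hood, each call to the target updater's step() method (registered with the trainer below) nudges the target parameters towards the online ones. A sketch of that exponential moving average, assuming eps plays the role of the decay factor:

# Sketch of a soft (Polyak) update for a single parameter pair; with
# eps=0.995 the target network lags far behind the online network.
def soft_update_sketch(online_param, target_param, eps=0.995):
    return eps * target_param + (1.0 - eps) * online_param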

Hyperparameters

Let's start with our hyperparameters. The following setting should work well in practice, and the performance of the algorithm should hopefully not be too sensitive to slight variations of these.

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
# start method passed to ParallelEnv in ``make_env`` above
# (assumption: reuse the current multiprocessing start method)
mp_context = multiprocessing.get_start_method()

Optimizer

# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8

DQN parameters

gamma decay factor

gamma = 0.99

Smoothed target network update decay parameter. This loosely corresponds to a 1/tau interval with a hard target network update.

tau = 0.02

Data collection and replay buffer

Note

Values to be used for a proper training run are given in the comments.

Total number of frames collected in the environment. In other implementations, the user defines a maximum number of episodes. This is harder to do with our data collectors, since they return batches of N collected frames, where N is a constant. However, one can easily get the same restriction on the number of episodes by breaking the training loop when a certain number of episodes has been collected (see the sketch below).

total_frames = 5_000  # 500000
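
If a cap on the number of episodes is preferred, the idea can be sketched as follows: count the completed episodes from the "done" flags of each collected batch and break out of the loop. This sketch lives outside of the Trainer and is only meant to illustrate the point.

# Sketch: stop collection after a fixed number of completed episodes.
def collect_n_episodes(collector, max_episodes):
    n_episodes = 0
    for batch in collector:
        n_episodes += batch["next", "done"].sum().item()
        if n_episodes >= max_episodes:
            break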

Random frames used to initialize the replay buffer.

init_random_frames = 100  # 1000

Frames in each collected batch.

frames_per_batch = 32  # 128

Frames sampled from the replay buffer at each optimization step.

batch_size = 32  # 256

Size of the replay buffer in terms of frames.

buffer_size = min(total_frames, 100000)

Number of environments run in parallel in each data collector.

num_workers = 2  # 8
num_collectors = 2  # 4

Environment and exploration

We set the initial and final value of the epsilon factor in the epsilon-greedy exploration. Since our policy is deterministic, exploration is crucial: without it, the only source of randomness would be the environment reset.

eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
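
The EGreedyModule built earlier anneals epsilon between these two values over annealing_num_steps calls to its step() method. Assuming a linear schedule, the value of epsilon at a given step would look like this sketch:

# Sketch of a linear epsilon schedule between eps_init and eps_end
# (the actual annealing is handled by EGreedyModule.step()).
def eps_schedule(step, annealing_num_steps, eps_init=0.1, eps_end=0.005):
    frac = min(step / annealing_num_steps, 1.0)
    return eps_init + frac * (eps_end - eps_init)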

To speed up learning, we set the bias of the last layer of our value network to a predefined value (this is not mandatory).

init_bias = 2.0

Note

For fast rendering of the tutorial, the total_frames hyperparameter was set to a very low number. To get a reasonable performance, use a greater value, e.g. 500000.

Building a Trainer

TorchRL's Trainer class constructor takes the following keyword-only arguments:

  • collector

  • loss_module

  • optimizer

  • logger: a logger (here, a CSVLogger)

  • total_frames: this parameter defines the lifespan of the trainer.

  • frame_skip: when a frame-skip is used, the collector must be made aware of it in order to accurately count the number of frames collected, etc. Making the trainer aware of this parameter is not mandatory but helps to have a fairer comparison between settings where the total number of frames (budget) is fixed but the frame-skip is variable.

stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)

collector = get_collector(
    stats=stats,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")
state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]],

        [[0.9895]],

        [[0.9895]],

        [[0.9895]]])), ('scale', tensor([[[0.0737]],

        [[0.0737]],

        [[0.0737]],

        [[0.0737]]]))])

We can control how often the scalars should be logged. Here we set this to a low value as our training loop is short:

log_interval = 500

trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)

Registering hooks

Registering hooks can be achieved in two separate ways:

  • If the hook has one, the register() method is the first choice. One just needs to provide the trainer as input and the hook will be registered with a default name at a default location. For some hooks, the registration can be quite complex: ReplayBufferTrainer requires 3 hooks (extend, sample and update_priority), which can be cumbersome to implement.

buffer_hook = ReplayBufferTrainer(
    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
    flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
    record_interval=100,  # log every 100 optimization steps
    record_frames=1000,  # maximum number of frames in the record
    frame_skip=1,
    policy_exploration=actor_explore,
    environment=test_env,
    exploration_type=ExplorationType.DETERMINISTIC,
    log_keys=[("next", "reward")],
    out_keys={("next", "reward"): "rewards"},
    log_pbar=True,
)
recorder.register(trainer)

The exploration module's epsilon factor is also annealed:

trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
  • Any callable (including TrainerHookBase subclasses) can be registered using register_op(). In this case, the location must be passed explicitly. This method gives more control over the placement of the hook but it also requires more understanding of the Trainer mechanism. Check the trainer documentation for a detailed description of the trainer hooks.

trainer.register_op("post_optim", target_net_updater.step)

We can also log the training rewards. Note that this is of limited interest with CartPole, as rewards are always 1. The discounted sum of rewards is maximized not by getting higher rewards but by keeping the cart-pole alive for longer. This will be reflected by the total_rewards value displayed in the progress bar.

log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)

Note

It is possible to link multiple optimizers to the trainer if needed. In this case, each optimizer will be tied to a field in the loss dictionary. Check the OptimizerHook to learn more.

Here we are, ready to train our algorithm! A simple call to trainer.train() and we will get our results logged in.

trainer.train()
  0%|          | 0/5000 [00:00<?, ?it/s]
  1%|          | 32/5000 [00:07<20:39,  4.01it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 0.9434:   1%|          | 32/5000 [00:07<20:39,  4.01it/s]
...
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]

We can now quickly check the CSVs with the results.

def print_csv_files_in_folder(folder_path):
    """
    Finds all CSV files in a folder and prints the first 10 lines of each file.

    Args:
        folder_path (str): The relative path to the folder.

    """
    csv_files = []
    output_str = ""
    for dirpath, _, filenames in os.walk(folder_path):
        for file in filenames:
            if file.endswith(".csv"):
                csv_files.append(os.path.join(dirpath, file))
    for csv_file in csv_files:
        output_str += f"File: {csv_file}\n"
        with open(csv_file, "r") as f:
            for i, line in enumerate(f):
                if i == 10:
                    break
                output_str += line.strip() + "\n"
        output_str += "\n"
    print(output_str)


print_csv_files_in_folder(logger.experiment.log_dir)
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/r_training.csv
512,0.3566153347492218
1024,0.39912936091423035
1536,0.39912936091423035
2048,0.39912936091423035
2560,0.42945271730422974
3072,0.40213119983673096
3584,0.39912933111190796
4096,0.42945271730422974
4608,0.42945271730422974

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/optim_steps.csv
512,128.0
1024,256.0
1536,384.0
2048,512.0
2560,640.0
3072,768.0
3584,896.0
4096,1024.0
4608,1152.0

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/loss.csv
512,0.47876793146133423
1024,0.18667784333229065
1536,0.1948033571243286
2048,0.22345909476280212
2560,0.2145865112543106
3072,0.47586697340011597
3584,0.28343674540519714
4096,0.3203103542327881
4608,0.3053428530693054

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/grad_norm_0.csv
512,5.5816755294799805
1024,2.9089717864990234
1536,3.4687838554382324
2048,2.8756051063537598
2560,2.7815587520599365
3072,6.685841083526611
3584,3.793360948562622
4096,3.469670295715332
4608,3.317387104034424

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/rewards.csv
3232,0.10000000894069672

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/total_rewards.csv
3232,5.555555820465088

Conclusion and possible improvements

In this tutorial we have learned:

  • how to write a Trainer, including building its components and registering them in the trainer;

  • how to code a DQN algorithm, including how to create a policy that picks up the action with the highest value with QValueNetwork;

  • how to build a multiprocessed data collector;

Possible improvements to this tutorial could include:

  • A prioritized replay buffer could also be used. This will give a higher priority to the samples with the worst value accuracy. Learn more in the replay buffer section of the documentation (a sketch is given after this list).

  • A distributional loss (see DistributionalDQNLoss for more information).

  • More fancy exploration techniques, such as NoisyLinear layers and such.
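
For the first point, a prioritized buffer could be swapped in for the uniform one along these lines (a sketch only; the alpha and beta values are illustrative, check the replay buffer documentation for the exact arguments):

from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer

prioritized_rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,   # how strongly sampling is biased towards high-priority items
    beta=0.5,    # importance-sampling correction exponent
    storage=LazyMemmapStorage(buffer_size),
    batch_size=batch_size,
)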

Total running time of the script: (2 minutes 40.957 seconds)

Estimated memory usage: 1267 MB
