Note
Go to the end to download the full example code.
TorchRL trainer: A DQN example¶
Author: Vincent Moens
TorchRL provides a generic Trainer
class to handle your training loop. The trainer executes a nested loop where the outer loop is the data collection and the inner loop uses this data, or some data retrieved from the replay buffer, to train the model. At various points in this training loop, hooks can be attached and executed at given intervals.
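Schematically, that nested structure looks like the following hypothetical, self-contained sketch (this is not TorchRL's actual implementation; the hook locations and the "loss" key are assumptions made purely for illustration):

def training_loop(collector, replay_buffer, loss_module, optimizer, optim_steps_per_batch, hooks):
    # Hypothetical sketch of the loop a Trainer orchestrates; ``hooks`` is
    # assumed to be a dict mapping locations ("pre_steps", "post_optim",
    # "post_steps") to lists of callables.
    for batch in collector:  # outer loop: data collection
        for hook in hooks.get("pre_steps", []):
            hook(batch)  # e.g. extend the replay buffer with the new batch
        for _ in range(optim_steps_per_batch):  # inner loop: optimization
            sample = replay_buffer.sample()
            loss = loss_module(sample)["loss"]  # assumes a "loss" entry
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            for hook in hooks.get("post_optim", []):
                hook()  # e.g. soft-update of the target network
        for hook in hooks.get("post_steps", []):
            hook()  # e.g. logging, syncing policy weights, evaluation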
In this tutorial, we will be using the trainer class to train a DQN algorithm to solve the CartPole task from scratch.
Main takeaways:
Building a trainer with its essential components: data collector, loss module, replay buffer and optimizer.
Adding hooks to a trainer, such as loggers, target network updaters and such.
The trainer is fully customisable and offers a large set of functionalities. This tutorial is organised around its construction. We will first detail how to build each component of the library, and then put the pieces together using the Trainer
class.
Along the road, we will also focus on some other aspects of the library:
how to build an environment in TorchRL, including transforms (e.g. data normalization, frame concatenation, resizing and turning to grayscale) and parallel execution. Unlike what we did in the DDPG tutorial, we will normalize the pixels and not the state vector;
how to design a
QValueActor
object, i.e. an actor that estimates the action values and picks the action with the highest estimated return; how to collect data from your environment efficiently and store it in a replay buffer;
how to use multi-step, a simple preprocessing step for off-policy algorithms;
and finally, how to evaluate your model.
Prerequisite: we encourage you to get familiar with torchrl through the PPO tutorial first.
DQN¶
DQN (Deep Q-Learning) was the founding work in deep reinforcement learning.
At a high level, the algorithm is quite simple: Q-learning consists in learning a table of state-action values such that, when encountering any particular state, we know which action to pick simply by searching for the one with the highest value. This simple setting requires the actions and states to be discrete, otherwise a lookup table cannot be built.
DQN uses a neural network that encodes a map from the state-action space to a value (scalar) space, which amortizes the cost of storing and exploring all the possible state-action combinations: if a state has not been seen in the past, we can still pass it through our neural network in conjunction with the various actions available and get an interpolated value for each of them.
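As a toy illustration of this idea (plain PyTorch, with made-up observation and action sizes, unrelated to the network used later in this tutorial), picking an action from such a network boils down to an argmax over its outputs:

import torch
from torch import nn

# Toy, hypothetical Q-network: one estimated value per discrete action.
toy_q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
obs = torch.randn(4)  # a fake 4-dimensional state
action_values = toy_q_net(obs)  # one value per action
greedy_action = action_values.argmax(dim=-1)  # pick the highest-value action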
We will solve the classic control problem of the cart-pole, bundled in Gymnasium (see the Gymnasium documentation, from which this environment is retrieved, for a full description).
We do not aim at giving a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL features in the context of this algorithm.
import multiprocessing
import os
import tempfile
import uuid
import warnings

import torch
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)
def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter
Let us start with creating the various pieces we need for our algorithm:
an environment;
a policy (and associated modules that we group under the "model" umbrella);
a data collector, which makes the policy play in the environment and delivers training data;
a replay buffer to store the training data;
a loss module, which computes the objective function to train our policy to maximise the return;
an optimizer, which performs the parameter updates given our loss.
Additional modules include a logger, a recorder (which executes the policy in "eval" mode) and a target network updater. With all these components in place, it is easy to see how one could misplace or misuse one component in the training script. The trainer is there to orchestrate everything for you!
Building the environment¶
First, let's write a helper function that will output an environment. As usual, the "raw" environment may be too simple to be used in practice and we will need some data transformations to expose its output to the policy.
We will be using seven transforms:
StepCounter to count the number of steps in each trajectory;
ToTensorImage will convert a [W, H, C] uint8 tensor into a floating-point tensor of shape [C, W, H] in the [0, 1] space;
RewardScaling to reduce the scale of the return;
GrayScale will turn our image into grayscale;
Resize will resize the image into a 64x64 format;
CatFrames will concatenate an arbitrary number of successive frames (N=4) into a single tensor along the channel dimension. This is useful because a single image does not carry information about the motion of the cart-pole. Some memory about past observations and actions is needed, either via a recurrent neural network or using a stack of frames;
ObservationNorm will normalize our observations given some custom summary statistics.
In practice, our environment builder has two arguments:
parallel: determines whether multiple environments have to be run in parallel. We stack the transforms after the ParallelEnv to take advantage of vectorization of the operations on device, although this would technically work with every single environment attached to its own set of transforms.
obs_norm_sd will contain the normalizing constants for the ObservationNorm transform.
def make_env(
    parallel=False,
    obs_norm_sd=None,
    num_workers=1,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:

        def maker():
            return GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        base_env = ParallelEnv(
            num_workers,
            EnvCreator(maker),
            # Don't create a sub-process if we have only one worker
            serial_for_single=True,
            mp_start_method=mp_context,
        )
    else:
        base_env = GymEnv(
            "CartPole-v1",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )
    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),
            GrayScale(),
            Resize(64, 64),
            CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env
Compute normalizing constants¶
To normalize images, we do not want to normalize each pixel independently with a full [C, W, H] normalizing mask, but with simpler sets of normalizing constants (loc and scale parameters) of shape [C, 1, 1]. We will be using the reduce_dim argument of init_stats() to indicate which dimensions must be reduced, and the keep_dims argument to ensure that not all dimensions disappear in the process:
def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    del test_env
    return obs_norm_sd
Building the model (Deep Q-network)¶
The following function builds a DuelingCnnDQNet
object, which is a simple CNN followed by a two-layer MLP. The only trick used here is that the action values (i.e. the left and right action values) are computed using
\[\mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)]\]
where \(\mathbb{v}\) is our vector of action values, \(b\) is a \(\mathbb{R}^n \rightarrow 1\) function and \(v\) is a \(\mathbb{R}^n \rightarrow \mathbb{R}^m\) function, for \(n = \# obs\) and \(m = \# actions\).
Our network is wrapped in a
QValueActor
, which will read the state-action values, pick the one with the maximum value, and write all those results in the input tensordict.TensorDict
.
def make_model(dummy_env):
    cnn_kwargs = {
        "num_cells": [32, 64, 64],
        "kernel_sizes": [6, 4, 3],
        "strides": [2, 2, 1],
        "activation_class": nn.ELU,
        # This can be used to reduce the size of the last layer of the CNN
        # "squeeze_output": True,
        # "aggregator_class": nn.AdaptiveAvgPool2d,
        # "aggregator_kwargs": {"output_size": (1, 1)},
    }
    mlp_kwargs = {
        "depth": 2,
        "num_cells": [
            64,
            64,
        ],
        "activation_class": nn.ELU,
    }
    net = DuelingCnnDQNet(
        dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
    ).to(device)
    net.value[-1].bias.data.fill_(init_bias)

    actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
    # init actor: because the model is composed of lazy conv/linear layers,
    # we must pass a fake batch of data through it to instantiate them.
    tensordict = dummy_env.fake_tensordict()
    actor(tensordict)

    # we join our actor with an EGreedyModule for data collection
    exploration_module = EGreedyModule(
        spec=dummy_env.action_spec,
        annealing_num_steps=total_frames,
        eps_init=eps_greedy_val,
        eps_end=eps_greedy_val_env,
    )
    actor_explore = TensorDictSequential(actor, exploration_module)

    return actor, actor_explore
Collecting and storing data¶
Replay buffers¶
Replay buffers play a central role in off-policy RL algorithms such as DQN. They constitute the dataset we will be sampling from during training.
Here, we will use a regular sampling strategy, although a prioritized replay buffer could improve the performance significantly.
We place the storage on disk using the LazyMemmapStorage class. This storage is created in a lazy manner: it will only be instantiated once the first batch of data is passed to it.
The only requirement of this storage is that the data passed to it at write time must always have the same shape.
def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer
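Outside of the trainer, such a buffer is used by extending it with collected data and sampling batches from it. A minimal usage sketch follows (purely illustrative; in this tutorial the ReplayBufferTrainer hook registered further below takes care of these calls, and the fake pixel batch here is an assumption):

import torch
from tensordict import TensorDict

rb = get_replay_buffer(buffer_size=1000, n_optim=1, batch_size=16)
# stand-in for a flattened batch of 32 collected frames of stacked pixels
fake_batch = TensorDict({"pixels": torch.zeros(32, 4, 64, 64)}, batch_size=[32])
rb.extend(fake_batch)  # write the batch to the (memory-mapped) storage
sample = rb.sample()  # draw ``batch_size`` frames for an optimization step
print(sample.batch_size)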
Data collector¶
As in PPO and DDPG, we will be using a data collector as a dataloader in the outer loop.
We choose the following configuration: we will be running a series of parallel environments synchronously in parallel in different collectors, themselves running in parallel but asynchronously.
Note
This feature is only available when running the code within the "spawn" start method of Python's multiprocessing library. If this tutorial is run directly as a script (and therefore uses the "fork" method), we will be using a regular SyncDataCollector
.
The advantage of this configuration is that we can balance the amount of compute executed in batch with what we want to be executed asynchronously. We encourage the reader to experiment with how the collection speed is impacted by modifying the number of collectors (i.e. the number of environment constructors passed to the collector) and the number of environments executed in parallel in each collector (controlled by the num_workers
hyperparameter).
The collector's devices are fully parametrizable through the device
(general), policy_device
, env_device
and storing_device
arguments. The storing_device
argument will modify the location of the data being collected: if the batches that we are gathering have a considerable size, we may want to store them in a different location than the device where the computation is happening. For asynchronous data collectors such as ours, different storing devices mean that the data we collect won't sit on the same device each time, which is something our training loop must account for. For simplicity, we set the devices to the same value for all sub-collectors.
def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    # We can't use nested child processes with mp_start_method="fork"
    if is_fork:
        cls = SyncDataCollector
        env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
    else:
        cls = MultiaSyncDataCollector
        env_arg = [
            make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
        ] * num_collectors

    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set all the devices to be identical. Below is an example of
        # heterogeneous devices
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector
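A word on the MultiStep postprocessing used above: it replaces each single-step transition with an n-step one, so the reward stored for a transition becomes the discounted partial sum (the standard n-step return, written here for reference)
\[R_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k},\]
and bootstrapping is then performed from the observation \(n\) steps ahead (with an effective discount of \(\gamma^n\)), which typically speeds up the propagation of rewards in off-policy algorithms such as DQN.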
Loss function¶
Building our loss function is straightforward: we only need to provide the model and a bunch of hyperparameters to the DQNLoss class.
Target parameters¶
Many off-policy RL algorithms use the concept of "target parameters" when it comes to estimating the value of the next state or state-action pair. The target parameters are lagged copies of the model parameters. Because their predictions mismatch those of the current model configuration, they help learning by putting a pessimistic bound on the value being estimated. This is a powerful trick (known as "Double Q-Learning") that is ubiquitous in similar algorithms.
def get_loss_module(actor, gamma):
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=gamma)
    target_updater = SoftUpdate(loss_module, eps=0.995)
    return loss_module, target_updater
Hyperparameters¶
Let's start with our hyperparameters. The following setting should work well in practice, and the performance of the algorithm should hopefully not be too sensitive to slight variations of these values.
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
Optimizer¶
# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8
DQN parameters¶
gamma decay factor
gamma = 0.99
Smooth target network update decay parameter. This loosely corresponds to a 1/tau interval with hard target network updates.
tau = 0.02
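For reference, soft updaters such as the SoftUpdate used in get_loss_module above apply, at every call to their step() method, a Polyak update of the form
\[\theta_{\text{target}} \leftarrow \epsilon\,\theta_{\text{target}} + (1 - \epsilon)\,\theta,\]
where \(\epsilon\) is the eps decay argument: the closer it is to 1, the slower the target parameters track the online parameters, which is loosely equivalent to a hard copy every \(1/(1-\epsilon)\) steps.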
Data collection and replay buffer¶
Note
Values to be used for a proper training have been commented out.
Total frames collected in the environment. In other implementations, the user defines a maximum number of episodes. This is harder to do with our data collectors, since they return batches of N collected frames, where N is a constant. However, one can easily get the same restriction on the number of episodes by breaking the training loop once a certain number of episodes has been collected.
total_frames = 5_000 # 500000
Random frames used to initialize the replay buffer.
init_random_frames = 100 # 1000
Frames in each batch collected.
frames_per_batch = 32 # 128
Frames sampled from the replay buffer at each optimization step.
batch_size = 32 # 256
Size of the replay buffer in terms of frames.
buffer_size = min(total_frames, 100000)
Number of environments run in parallel in each data collector.
num_workers = 2 # 8
num_collectors = 2 # 4
Environment and exploration¶
We set the initial and final value of the epsilon factor in the epsilon-greedy exploration. Since our policy is deterministic, exploration is crucial: without it, the only source of randomness would be the environment reset.
eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
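The EGreedyModule built in make_model anneals epsilon from eps_init to eps_end over annealing_num_steps calls to its step() method. Assuming a linear schedule (a common choice, stated here as an assumption rather than a guarantee about the module's internals), the value after \(t\) steps is approximately
\[\epsilon(t) = \max\left(\epsilon_{\mathrm{end}},\; \epsilon_{\mathrm{init}} - \frac{t}{N}\,(\epsilon_{\mathrm{init}} - \epsilon_{\mathrm{end}})\right),\]
with \(N\) the annealing_num_steps value. The trainer triggers these step() calls through the post_steps hook registered further below.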
To speed up learning, we set the bias of the last layer of our value network to a predefined value (this is not mandatory).
init_bias = 2.0
Note
In order to render the tutorial quickly, the total_frames hyperparameter was set to a very low number. To get a reasonable performance, use a greater value, e.g. 500000.
Building the trainer¶
TorchRL's Trainer
class constructor takes the following keyword-only arguments:
collector
loss_module
optimizer
logger: a logger used to record the training metrics (here, a CSVLogger)
total_frames: this argument defines the lifespan of the trainer
frame_skip: when a frame-skip is used, the collector must be made aware of it in order to accurately count the number of frames collected, etc. Making the trainer aware of this parameter is not mandatory, but helps to have a fairer comparison between settings where the total number of frames (budget) is fixed but the frame-skip is variable.
stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)
collector = get_collector(
    stats=stats,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")
state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]],
[[0.9895]],
[[0.9895]],
[[0.9895]]])), ('scale', tensor([[[0.0737]],
[[0.0737]],
[[0.0737]],
[[0.0737]]]))])
We can control how often the scalars should be logged. Here we set this to a low value as our training loop is short:
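The cell defining log_interval and constructing the trainer does not appear on this page, although the trainer object is used below; a sketch consistent with the keyword arguments listed above would be (the value 500 for log_interval is an assumption):

# Sketch of the trainer construction; the keyword arguments follow the
# list given above. ``log_interval=500`` is an assumed value controlling
# how often scalars are logged.
log_interval = 500

trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)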
Registering hooks¶
Registering hooks can be achieved in two separate ways:
If the hook has one, the register()
method is the first choice. One just needs to provide the trainer as input and the hook will be registered with a default name at a default location. For some hooks, the registration can be quite complex: ReplayBufferTrainer
requires 3 hooks (extend
, sample
and update_priority
), which can be cumbersome to implement by hand.
buffer_hook = ReplayBufferTrainer(
    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
    flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
    record_interval=100,  # log every 100 optimization steps
    record_frames=1000,  # maximum number of frames in the record
    frame_skip=1,
    policy_exploration=actor_explore,
    environment=test_env,
    exploration_type=ExplorationType.MODE,
    log_keys=[("next", "reward")],
    out_keys={("next", "reward"): "rewards"},
    log_pbar=True,
)
recorder.register(trainer)
The epsilon factor of the exploration module is also annealed:
trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
Any callable (including
TrainerHookBase
subclasses) can be registered using register_op()
. In this case, a location must be explicitly passed. This method gives more control over the location of the hook, but it also requires more understanding of the Trainer mechanism. Check the trainer documentation for a detailed description of the trainer hooks.
trainer.register_op("post_optim", target_net_updater.step)
We can log the training rewards too. Note that this is of limited interest with CartPole, as rewards are always 1. The discounted sum of rewards is maximised not by getting higher rewards but by keeping the cart-pole alive for longer. This will be reflected by the total_rewards value displayed in the progress bar.
log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)
Note
Several optimizers can be tied to the trainer if needed. In this case, each optimizer will be tied to a field in the loss dictionary. Check the OptimizerHook
to learn more.
Here we are, ready to train our algorithm! A simple call to trainer.train()
and we'll be getting our results logged.
trainer.train()
  0%|          | 0/5000 [00:00<?, ?it/s]
  1%|          | 32/5000 [00:07<19:22,  4.27it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 0.9434:   1%|          | 32/5000 [00:07<19:22,  4.27it/s]
...
r_training: 0.3991, rewards: 0.1000, total_rewards: 6.6667: 100%|█████████▉| 4992/5000 [01:08<00:00, 90.27it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 6.6667: : 5024it [01:08, 90.73it/s]
We can now quickly check the CSV files with the results.
def print_csv_files_in_folder(folder_path):
    """
    Find all CSV files in a folder and prints the first 10 lines of each file.

    Args:
        folder_path (str): The relative path to the folder.
    """
    csv_files = []
    output_str = ""
    for dirpath, _, filenames in os.walk(folder_path):
        for file in filenames:
            if file.endswith(".csv"):
                csv_files.append(os.path.join(dirpath, file))
    for csv_file in csv_files:
        output_str += f"File: {csv_file}\n"
        with open(csv_file, "r") as f:
            for i, line in enumerate(f):
                if i == 10:
                    break
                output_str += line.strip() + "\n"
        output_str += "\n"
    print(output_str)
print_csv_files_in_folder(logger.experiment.log_dir)
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/r_training.csv
512,0.38084372878074646
1024,0.37784188985824585
1536,0.41726210713386536
2048,0.36880600452423096
2560,0.39912933111190796
3072,0.39912936091423035
3584,0.42945271730422974
4096,0.42945271730422974
4608,0.39912933111190796
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/optim_steps.csv
512,128.0
1024,256.0
1536,384.0
2048,512.0
2560,640.0
3072,768.0
3584,896.0
4096,1024.0
4608,1152.0
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/loss.csv
512,0.13857470452785492
1024,0.15926682949066162
1536,0.1994696408510208
2048,0.21946831047534943
2560,0.25826987624168396
3072,0.30737149715423584
3584,0.24386540055274963
4096,0.34079253673553467
4608,0.2449716478586197
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/grad_norm_0.csv
512,1.89058256149292
1024,2.0029563903808594
1536,3.236938714981079
2048,2.1101794242858887
2560,2.259946823120117
3072,2.8765692710876465
3584,3.375800609588623
4096,3.7260398864746094
4608,2.8490850925445557
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/rewards.csv
3232,0.10000000894069672
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/total_rewards.csv
3232,6.666667461395264
Conclusion and possible improvements¶
In this tutorial we have learned:
How to write a Trainer, including building its components and registering them in the trainer;
How to code a DQN algorithm, including how to create a policy that picks the action with the highest value using
QValueNetwork
; how to build a multiprocessed data collector.
Possible improvements to this tutorial could include:
A prioritized replay buffer could also be used. This will give a higher priority to samples that have the worst value accuracy. Learn more in the replay buffer section of the documentation (see the sketch after this list);
A distributional loss (see
DistributionalDQNLoss
for more information);
More fancy exploration techniques, such as
NoisyLinear
layers and such.
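Following up on the first improvement above, a minimal sketch of swapping in a prioritized buffer could look as follows (the alpha/beta values are illustrative only, and the rest of the script would stay unchanged):

from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer


def get_prioritized_replay_buffer(buffer_size, n_optim, batch_size):
    # Hypothetical drop-in replacement for ``get_replay_buffer``;
    # ``alpha``/``beta`` are illustrative values, not tuned settings.
    return TensorDictPrioritizedReplayBuffer(
        alpha=0.7,  # how strongly priorities skew the sampling
        beta=0.5,  # importance-sampling correction strength
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )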
Total running time of the script: (2 minutes 37.294 seconds)
Estimated memory usage: 1018 MB