快捷方式

torchrl.objectives 包

TorchRL 提供了一系列用于训练脚本的损失函数。目的是拥有易于重用/交换且具有简单签名的损失函数。

TorchRL 损失函数的主要特征是

  • 它们是有状态的对象:它们包含可训练参数的副本,因此 loss_module.parameters() 会返回训练算法所需的一切。

  • 它们遵循 tensordict 约定:torch.nn.Module.forward() 方法会接收一个 tensordict 作为输入,其中包含返回损失值所需的所有信息。

  • 它们输出一个 tensordict.TensorDict 实例,其中损失值写入 "loss_<smth>" 下,其中 smth 是描述损失的字符串。tensordict 中的其他键可能是训练期间有用的指标。

注意

我们返回独立损失的原因是为了让用户例如可以对不同的参数集使用不同的优化器。损失的求和可以通过以下方式简单地完成

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

注意

损失中的参数初始化可以通过查询 get_stateful_net() 来完成,它将返回网络的有状态版本,该版本可以像任何其他模块一样进行初始化。如果修改是就地完成的,它将下游传播到使用相同参数集的任何其他模块(在损失内部和外部):例如,修改损失中的 actor_network 参数也将修改收集器中的 actor。如果参数是异地修改的,则可以使用 from_stateful_net() 将损失中的参数重置为新值。

torch.vmap 和随机性

TorchRL 损失模块有很多对 vmap() 的调用,以摊销在循环中调用多个相似模型的成本,并将其操作向量化。vmap 需要被明确告知在调用中需要生成随机数时该怎么做。为此,需要设置随机性模式,并且必须是 “error”(默认值,在处理伪随机函数时出错)、“same”(在批次中复制结果)或 “different”(批次的每个元素都被单独处理)之一。依赖于默认值通常会导致类似于此的错误

>>> RuntimeError: vmap: called random operation while in randomness error mode.

由于对 vmap 的调用隐藏在损失模块中,因此 TorchRL 提供了一个接口,可以通过 loss.vmap_randomness = str_value 从外部设置该 vmap 模式,有关更多信息,请参见 vmap_randomness()

LossModule.vmap_randomness 如果未检测到随机模块,则默认为 “error”,否则默认为 “different”。默认情况下,只有有限数量的模块被列为随机模块,但可以使用 add_random_module() 函数扩展该列表。

训练价值函数

TorchRL 提供了一系列**价值估计器**,例如 TD(0)、TD(1)、TD(\(\lambda\)) 和 GAE。简而言之,价值估计器是数据的函数(主要是奖励和完成状态)和状态值(即估计状态值的函数返回的值)。要详细了解价值估计器,请查看来自 Sutton 和 Barto 的强化学习介绍,特别是关于值迭代和 TD 学习的章节。它基于数据和代理映射,对某个状态或状态-动作对之后获得的折扣回报进行某种有偏差的估计。这些估计器在两种情况下使用

  • 为了训练价值网络以学习“真实”状态值(或状态-动作值)映射,需要一个目标值来拟合它。估计器越好(偏差越小,方差越小),价值网络就越好,这反过来可以显着加快策略训练的速度。通常,价值网络损失将如下所示

    >>> value = value_network(states)
    >>> target_value = value_estimator(rewards, done, value_network(next_state))
    >>> value_net_loss = (value - target_value).pow(2).mean()
    
  • 为策略优化计算“优势”信号。优势是价值估计(来自估计器,即来自“真实”数据)和价值网络输出(即该价值的代理)之间的差值。正优势可以被视为一个信号,表明策略的实际表现优于预期,从而表明如果将该轨迹作为示例,则仍有改进的空间。相反,负优势表明策略的表现低于预期。

情况并不总是像上面示例中那样简单,计算价值估计器或优势的公式可能比这稍微复杂一些。为了帮助用户灵活地使用一个或另一个价值估计器,我们提供了一个简单的 API 来动态更改它。这是一个 DQN 示例,但所有模块都将遵循类似的结构

>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> loss_module = DQNLoss(actor)
>>> kwargs = {"gamma": 0.9, "lmbda": 0.9}
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)

ValueEstimators 类枚举了可供选择的价值估计器。这使得用户可以轻松依靠自动完成来做出选择。

LossModule(*args, **kwargs)

强化学习损失的父类。

DQN

DQNLoss(*args, **kwargs)

DQN 损失类。

DistributionalDQNLoss(*args, **kwargs)

分布式 DQN 损失类。

DDPG

DDPGLoss(*args, **kwargs)

DDPG 损失类。

SAC

SACLoss(*args, **kwargs)

TorchRL 实现的 SAC 损失函数。

DiscreteSACLoss(*args, **kwargs)

离散 SAC 损失模块。

REDQ

REDQLoss(*args, **kwargs)

REDQ 损失模块。

CrossQ

IQL

IQLLoss(*args, **kwargs)

TorchRL 实现的 IQL 损失函数。

DiscreteIQLLoss(*args, **kwargs)

TorchRL 实现的离散 IQL 损失函数。

CQL

CQLLoss(*args, **kwargs)

TorchRL 实现的连续 CQL 损失函数。

DiscreteCQLLoss(*args, **kwargs)

TorchRL 实现的离散 CQL 损失函数。

DT

DTLoss(*args, **kwargs)

TorchRL 实现的在线决策 Transformer 损失函数。

OnlineDTLoss(*args, **kwargs)

TorchRL 实现的在线决策 Transformer 损失函数。

TD3

TD3Loss(*args, **kwargs)

TD3 损失模块。

TD3+BC

TD3BCLoss(*args, **kwargs)

TD3+BC 损失模块。

PPO

PPOLoss(*args, **kwargs)

PPO 损失函数的父类。

ClipPPOLoss(*args, **kwargs)

裁剪后的 PPO 损失函数。

KLPENPPOLoss(*args, **kwargs)

KL 惩罚 PPO 损失函数。

A2C

A2CLoss(*args, **kwargs)

TorchRL 实现的 A2C 损失函数。

Reinforce

ReinforceLoss(*args, **kwargs)

Reinforce 损失模块。

Dreamer

DreamerActorLoss(*args, **kwargs)

Dreamer Actor 损失函数。

DreamerModelLoss(*args, **kwargs)

Dreamer 模型损失函数。

DreamerValueLoss(*args, **kwargs)

Dreamer 值函数损失函数。

多智能体目标函数

这些目标函数特定于多智能体算法。

QMixer

QMixerLoss(*args, **kwargs)

QMixer 损失函数类。

返回值

ValueEstimatorBase(*args, **kwargs)

值函数模块的抽象父类。

TD0Estimator(*args, **kwargs)

优势函数的时序差分 (TD(0)) 估计。

TD1Estimator(*args, **kwargs)

优势函数的\(\infty\)-时序差分 (TD(1)) 估计。

TDLambdaEstimator(*args, **kwargs)

优势函数的 TD(\(\lambda\)) 估计。

GAE(*args, **kwargs)

广义优势估计函数的类包装器。

functional.td0_return_estimate(gamma, ...[, ...])

轨迹的 TD(0) 折扣回报估计。

functional.td0_advantage_estimate(gamma, ...)

轨迹的 TD(0) 优势估计。

functional.td1_return_estimate(gamma, ...[, ...])

TD(1) 返回估计。

functional.vec_td1_return_estimate(gamma, ...)

矢量化 TD(1) 返回估计。

functional.td1_advantage_estimate(gamma, ...)

TD(1) 优势估计。

functional.vec_td1_advantage_estimate(gamma, ...)

矢量化 TD(1) 优势估计。

functional.td_lambda_return_estimate(gamma, ...)

TD(\(\lambda\)) 返回估计。

functional.vec_td_lambda_return_estimate(...)

矢量化 TD(\(\lambda\)) 返回估计。

functional.td_lambda_advantage_estimate(...)

TD(\(\lambda\)) 优势估计。

functional.vec_td_lambda_advantage_estimate(...)

矢量化 TD(\(\lambda\)) 优势估计。

functional.generalized_advantage_estimate(...)

轨迹的广义优势估计。

functional.vec_generalized_advantage_estimate(...)

轨迹的矢量化广义优势估计。

functional.reward2go(reward, done, gamma, *)

计算给定多个轨迹和情节结束时的奖励的折扣累积和。

实用程序

distance_loss(v1, v2, loss_function[, ...])

计算两个张量之间的距离损失。

hold_out_net(network)

用于将网络从计算图中剔除的上下文管理器。

hold_out_params(params)

用于将参数列表从计算图中剔除的上下文管理器。

next_state_value(tensordict[, operator, ...])

计算下一个状态值(无梯度)以计算目标值。

SoftUpdate(loss_module, *[, eps, tau])

用于双重 DQN/DDPG 中目标网络更新的软更新类。

HardUpdate(loss_module, *[, ...])

用于双重 DQN/DDPG 中目标网络更新的硬更新类(与软更新形成对比)。

ValueEstimators(value)

用于自定义构建的估计器的值函数枚举器。

default_value_kwargs(value_type)

默认值函数关键字参数生成器。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源