快捷方式

SafeProbabilisticModule

class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[source]

tensordict.nn.ProbabilisticTensorDictModule 子类,接受一个 TensorSpec 作为参数来控制输出域。

SafeProbabilisticModule 是一个非参数化模块,表示一个概率分布。它使用指定的 in_keys 从输入 TensorDict 读取分布参数。输出是在给定规则的情况下进行采样的,该规则由输入 default_interaction_type 参数和 interaction_type() 全局函数指定。

SafeProbabilisticModule 可用于构建分布(通过 get_dist() 方法)和/或从该分布中采样(通过对模块进行常规的 __call__())。

一个 SafeProbabilisticModule 实例具有两个主要特征: - 它读取和写入 TensorDict 对象 - 它使用一个实际映射 R^n -> R^m 从 R^d 中创建一个分布,从中可以采样或计算值。

当调用 __call__ / forward 方法时,会创建一个分布,并计算一个值(使用 'mean'、'mode'、'median' 属性或 'rsample'、'sample' 方法)。如果提供的 TensorDict 已经包含所有所需的键值对,则会跳过采样步骤。

默认情况下,SafeProbabilisticModule 分布类是 Delta 分布,这使得 SafeProbabilisticModule 成为一个简单包装器,围绕着确定性映射函数。

参数:
  • in_keys (NestedKeylist of NestedKeydict) – 将从输入 TensorDict 读取并用于构建分布的键。重要的是,如果它是 NestedKey 的列表或 NestedKey,则这些键的叶子(最后一个元素)必须与目标分布类使用的关键字匹配,例如 "loc""scale" 用于正态分布,以及类似的情况。如果 in_keys 是字典,则键是分布的键,而值是 tensordict 中的键,这些键将与相应的分布键匹配。

  • out_keys (NestedKeylist of NestedKey) – 采样值将写入的键。重要的是,如果在输入 TensorDict 中找到这些键,则会跳过采样步骤。

  • spec (TensorSpec) – 第一个输出张量的规范。在调用 td_module.random() 生成目标空间中的随机值时使用。

  • safe (bool, optional) – 如果为 True,则会将样本的值与输入规范进行检查。由于探索策略或数值下溢/上溢问题,可能会出现域外采样。至于 spec 参数,此检查仅会针对分布样本进行,而不会针对输入模块返回的其他张量进行。如果样本超出范围,则会使用 TensorSpec.project 方法将其投影回所需的空間。默认值为 False

  • default_interaction_type (str, optional) – 用于检索输出值的默认方法。应该是以下之一:'mode'、'median'、'mean' 或 'random'(在这种情况下,值将从分布中随机采样)。默认值为 'mode'。注意:当绘制样本时,ProbabilisticTDModule 实例将首先查找由 interaction_typ() 全局函数指示的交互模式。如果返回的是 None(其默认值),则将使用 default_interaction_typeProbabilisticTDModule 实例。请注意,DataCollector 实例默认情况下会使用 tensordict.nn.set_interaction_type()tensordict.nn.InteractionType.RANDOM

  • distribution_class (Type, optional) – 用于采样的 torch.distributions.Distribution 类。默认值为 Delta。

  • distribution_kwargs (dict, optional) – 要传递给分布的关键字参数。

  • return_log_prob (bool, optional) – 如果为 True,则会将分布样本的对数概率写入 tensordict,键为 ‘sample_log_prob’。默认值为 False

  • log_prob_key (NestedKey, optional) – 如果 return_log_prob = True,则写入 log_prob 的键。默认值为 ‘sample_log_prob’

  • cache_dist (bool, optional) – 实验性:如果为 True,则会将分布的参数(即模块的输出)与样本一起写入 tensordict。这些参数可用于稍后重新计算原始分布(例如,计算用于采样动作的分布与 PPO 中更新的分布之间的差异)。默认值为 False

  • n_empirical_estimate (int, optional) – 计算经验均值时的样本数量,如果经验均值不可用。默认值为 1000

random(tensordict: TensorDictBase) TensorDictBase[source]

从目标空间中随机采样一个元素,与任何输入无关。

如果存在多个输出键,则仅将第一个输出键写入输入 tensordict 中。

参数:

tensordict (TensorDictBase) – 用于写入输出值的 tensordict。

返回值:

包含更新后的输出键值的原始 tensordict。

random_sample(tensordict: TensorDictBase) TensorDictBase[source]

参见 SafeModule.random(...)

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源