SafeProbabilisticModule¶
- class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[source]¶
tensordict.nn.ProbabilisticTensorDictModule
的子类,接受TensorSpec
作为参数来控制输出域。SafeProbabilisticModule 是一个非参数模块,表示概率分布。它使用指定的 in_keys 从输入 TensorDict 中读取分布参数。输出根据输入
default_interaction_type
参数和interaction_type()
全局函数指定的规则进行采样。SafeProbabilisticModule
可用于构建分布(通过get_dist()
方法)和/或从此分布中采样(通过对模块的常规__call__()
调用)。SafeProbabilisticModule
实例有两个主要特点: - 它读取和写入 TensorDict 对象 - 它使用实数映射 R^n -> R^m 来创建 R^d 中的分布,从中可以采样或计算值。当调用
__call__
/forward
方法时,会创建一个分布,并计算一个值(使用 ‘mean’、‘mode’、‘median’ 属性或 ‘rsample’、‘sample’ 方法)。如果提供的 TensorDict 已经拥有所有期望的键值对,则会跳过采样步骤。默认情况下,SafeProbabilisticModule 的分布类是 Delta 分布,这使得 SafeProbabilisticModule 成为确定性映射函数的简单包装器。
- 参数:
in_keys (NestedKey 或 list of NestedKey 或 dict) – 将从输入 TensorDict 中读取并用于构建分布的键。重要的是,如果是 NestedKey 列表或 NestedKey,这些键的叶子(最后一个元素)必须与感兴趣的分布类使用的关键字匹配,例如 Normal 分布的
"loc"
和"scale"
以及类似的关键字。如果 in_keys 是字典,则键是分布的键,值是 tensordict 中将与相应分布键匹配的键。out_keys (NestedKey 或 list of NestedKey) – 将写入采样值的键。重要的是,如果在输入 TensorDict 中找到这些键,则会跳过采样步骤。
spec (TensorSpec) – 第一个输出张量的规格。在调用 td_module.random() 以在目标空间中生成随机值时使用。
safe (bool, optional) – 如果为
True
,则会根据输入 spec 检查样本的值。由于探索策略或数值下溢/溢出问题,可能会发生域外采样。与 spec 参数一样,此检查仅对分布样本执行,而不对输入模块返回的其他张量执行。如果样本超出范围,则使用 TensorSpec.project 方法将其投影回期望的空间。默认为False
。default_interaction_type (str, optional) – 用于检索输出值的默认方法。应为以下之一:‘mode’、‘median’、‘mean’ 或 ‘random’(在这种情况下,值将从分布中随机采样)。默认为 ‘mode’。注意:当绘制样本时,
ProbabilisticTDModule
实例将首先查找由 interaction_typ() 全局函数指定的交互模式。如果这返回 None(其默认值),则将使用ProbabilisticTDModule
实例的 default_interaction_type。请注意,DataCollector 实例默认情况下将使用tensordict.nn.set_interaction_type()
设置为tensordict.nn.InteractionType.RANDOM
。distribution_class (Type, optional) – 用于采样的 torch.distributions.Distribution 类。默认为 Delta。
distribution_kwargs (dict, optional) – 要传递给分布的 kwargs。
return_log_prob (bool, optional) – 如果为
True
,则分布样本的对数概率将写入 tensordict,键为 ‘sample_log_prob’。默认为False
。log_prob_key (NestedKey, optional) – 如果 return_log_prob = True,则将对数概率写入的键。默认为 ‘sample_log_prob’。
cache_dist (bool, optional) – 实验性功能:如果为
True
,则分布的参数(即模块的输出)将与样本一起写入 tensordict。这些参数可用于稍后重新计算原始分布(例如,计算用于采样动作的分布与 PPO 中更新的分布之间的散度)。默认为False
。n_empirical_estimate (int, optional) – 当经验均值不可用时,用于计算经验均值的样本数。默认为 1000