SafeProbabilisticModule¶
- class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[源]¶
是
tensordict.nn.ProbabilisticTensorDictModule
的子类,接受一个TensorSpec
参数来控制输出域。SafeProbabilisticModule 是一个非参数模块,内嵌一个概率分布构造器。它使用指定的 in_keys 从输入 TensorDict 读取分布参数,并输出该分布的一个样本(广义上)。
输出的“样本”是根据某个规则生成的,该规则由输入的
default_interaction_type
参数和interaction_type()
全局函数指定。SafeProbabilisticModule 可用于构造分布(通过
get_dist()
方法)和/或从该分布中采样(通过对模块的常规__call__()
调用)。一个 SafeProbabilisticModule 实例有两个主要特性
它读写 TensorDict 对象;
它使用一个实数映射 R^n -> R^m 来创建 R^d 中的分布,可以从中采样或计算值。
当调用
__call__()
和forward()
方法时,会创建一个分布并计算一个值(取决于interaction_type
值,可以使用 ‘dist.mean’、‘dist.mode’、‘dist.median’ 属性,以及 ‘dist.rsample’、‘dist.sample’ 方法)。如果提供的 TensorDict 中已包含所有期望的键值对,则跳过采样步骤。默认情况下,SafeProbabilisticModule 的分布类是一个
Delta
分布,这使得 SafeProbabilisticModule 成为一个确定性映射函数的简单包装器。此类与
tensordict.nn.ProbabilisticTensorDictModule
不同之处在于它接受一个spec
关键字参数,可用于控制样本是否属于该分布。`safe` 关键字参数控制是否应根据 spec 检查样本值。- 参数:
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 将从输入 TensorDict 中读取并用于构建分布的键。重要提示:如果它是 NestedKey 列表或 NestedKey,则这些键的叶子(最后一个元素)必须与感兴趣的分布类使用的关键字匹配,例如
"loc"
和"scale"
用于Normal
分布等。如果 in_keys 是字典,则键是分布的键,值是 tensordict 中将与相应分布键匹配的键。out_keys (NestedKey | List[NestedKey] | None) – 写入采样值的键。重要提示:如果在输入 TensorDict 中找到了这些键,则会跳过采样步骤。
spec (TensorSpec) – 第一个输出张量的规范。在调用 td_module.random() 时用于在目标空间生成随机值。
- 关键字参数:
safe (bool, optional) – 如果为
True
,则会根据输入 spec 检查样本的值。由于探索策略或数值下溢/上溢问题,可能会发生超出域的采样。与spec
参数一样,此检查仅针对分布样本进行,而不针对输入模块返回的其他张量。如果样本超出界限,则使用 TensorSpec.project 方法将其投影回期望的空间。默认值为False
。default_interaction_type (InteractionType, optional) –
仅关键字参数。用于获取输出值的默认方法。应为 InteractionType 之一:MODE、MEDIAN、MEAN 或 RANDOM(在这种情况下,值是从分布中随机采样的)。默认值为 MODE。
注
抽取样本时,
ProbabilisticTensorDictModule
实例将首先查找由interaction_type()
全局函数指定的交互模式。如果返回 None(其默认值),则将使用 ProbabilisticTDModule 实例的 default_interaction_type。请注意,DataCollectorBase
实例默认将 set_interaction_type 设置为tensordict.nn.InteractionType.RANDOM
。注
在某些情况下,mode、median 或 mean 值可能无法通过相应的属性直接获得。为了弥补这一点,
ProbabilisticTensorDictModule
将首先尝试通过调用get_mode()
、get_median()
或get_mean()
来获取值,如果方法存在。distribution_class (Type or Callable[[Any], Distribution], optional) –
仅关键字参数。一个
torch.distributions.Distribution
类,用于采样。默认值为Delta
。注
如果分布类是
CompositeDistribution
类型,则out_keys
可以直接从通过此类distribution_kwargs
关键字参数提供的"distribution_map"
或"name_map"
关键字参数推断出来,在这种情况下out_keys
是可选的。distribution_kwargs (dict, optional) –
仅关键字参数。要传递给分布的关键字参数对。
注
如果您的 kwargs 包含您想随模块一起转移到设备上的张量,或者在调用 module.to(dtype) 时应修改其 dtype 的张量,您可以将 kwargs 包装在
TensorDictParams
中以自动完成此操作。return_log_prob (bool, optional) – 仅关键字参数。如果为
True
,则分布样本的对数概率将写入 tensordict 中,键为 log_prob_key。默认值为False
。log_prob_keys (List[NestedKey], optional) –
如果
return_log_prob=True
,写入 log_prob 的键。默认值为 ‘<sample_key_name>_log_prob’,其中 <sample_key_name> 是每个out_keys
。注
仅当
composite_lp_aggregate()
设置为False
时可用。log_prob_key (NestedKey, optional) –
如果
return_log_prob=True
,写入 log_prob 的键。当composite_lp_aggregate()
设置为 True 时,默认值为 ‘sample_log_prob’,否则为 ‘<sample_key_name>_log_prob’。注
当存在多个样本时,仅当
composite_lp_aggregate()
设置为True
时可用。cache_dist (bool, optional) – 仅关键字参数。实验性功能:如果为
True
,则分布的参数(即模块的输出)将与样本一起写入 tensordict 中。这些参数可以在之后用于重新计算原始分布(例如,在 PPO 中计算用于采样动作的分布与更新后的分布之间的散度)。默认值为False
。n_empirical_estimate (int, optional) – 仅关键字参数。计算经验平均值时使用的样本数量,当经验平均值不可用时。默认值为 1000。
警告
运行检查会花费时间!使用 safe=True 将保证样本在
project()
中编码的启发式方法给定的 spec 界限内,但这需要检查值是否在 spec 空间内,这将引入一些开销。另请参阅
:class:`tensordict 中的组合分布 <~tensordict.nn.CompositeDistribution>` 可用于创建多头策略。
示例
>>> from torchrl.modules import SafeProbabilisticModule >>> from torchrl.data import Bounded >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import InteractionType >>> mod = SafeProbabilisticModule( ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=torch.distributions.Normal, ... safe=True, ... spec=Bounded(low=-1, high=1, shape=()), ... default_interaction_type=InteractionType.RANDOM ... ) >>> _ = torch.manual_seed(0) >>> data = TensorDict( ... loc=torch.zeros(10, requires_grad=True), ... scale=torch.full((10,), 10.0), ... batch_size=(10,)) >>> data = mod(data) >>> print(data["action"]) # All actions are within bound tensor([ 1., -1., -1., 1., -1., -1., 1., 1., -1., -1.], grad_fn=<ClampBackward0>) >>> data["action"].mean().backward() >>> print(data["loc"].grad) # clamp anihilates gradients tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])