ProbabilisticTensorDictModule¶
- 类 tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)¶
一个概率 TD 模块。
ProbabilisticTensorDictModule 是一个表示概率分布的非参数模块。它使用指定的 in_keys 从输入 TensorDict 中读取分布参数。输出根据一定的规则进行采样,该规则由输入的
default_interaction_type
参数和interaction_type()
全局函数指定。ProbabilisticTensorDictModule
可用于构建分布(通过get_dist()
方法)和/或从此分布中采样(通过对模块进行常规__call__()
)。一个
ProbabilisticTensorDictModule
实例有两个主要功能: - 它读取和写入 TensorDict 对象 - 它使用实数映射 R^n -> R^m 来创建 R^d 中的分布,可以从中采样或计算值。当调用
__call__
/forward
方法时,将创建一个分布,并计算一个值(使用 'mean'、'mode'、'median' 属性或 'rsample'、'sample' 方法)。如果提供的 TensorDict 已经具有所有期望的键值对,则会跳过采样步骤。默认情况下,ProbabilisticTensorDictModule 分布类是 Delta 分布,这使得 ProbabilisticTensorDictModule 成为确定性映射函数的简单包装器。
- 参数:
in_keys (NestedKey 或 NestedKey 列表 或 dict) – 将从输入 TensorDict 中读取并用于构建分布的键。重要的是,如果它是 NestedKey 列表或 NestedKey,则这些键的叶(最后一个元素)必须与感兴趣的分布类使用的关键字匹配,例如 Normal 分布的
"loc"
和"scale"
等。如果 in_keys 是字典,则键是分布的键,值是将与相应分布键匹配的 tensordict 中的键。out_keys (NestedKey 或 NestedKey 列表) – 将写入采样值的键。重要的是,如果在输入 TensorDict 中找到这些键,则会跳过采样步骤。
default_interaction_mode (str, 可选) – 已弃用 仅关键字参数。请改用 default_interaction_type。
default_interaction_type (InteractionType, 可选) –
仅关键字参数。用于检索输出值的默认方法。应为 InteractionType 之一:MODE、MEDIAN、MEAN 或 RANDOM(在这种情况下,值是从分布中随机采样的)。默认为 MODE。
注意
当绘制样本时,
ProbabilisticTensorDictModule
实例将首先查找由interaction_type()
全局函数指示的交互模式。如果这返回 None(其默认值),则将使用 ProbabilisticTDModule 实例的 default_interaction_type。请注意,DataCollectorBase
实例默认将使用 set_interaction_type 设置为tensordict.nn.InteractionType.RANDOM
。注意
在某些情况下,模式、中值或平均值可能无法通过相应的属性轻松获得。为了缓解这种情况,
ProbabilisticTensorDictModule
将首先尝试通过调用get_mode()
、get_median()
或get_mean()
(如果该方法存在)来获取值。distribution_class (Type, 可选) –
仅关键字参数。用于采样的
torch.distributions.Distribution
类。默认为Delta
。注意
如果分布类是
CompositeDistribution
类型,则可以直接从通过此类的distribution_kwargs
关键字参数提供的"distribution_map"
或"name_map"
关键字参数推断out_keys
,从而使out_keys
在这种情况下成为可选的。distribution_kwargs (dict, 可选) – 仅关键字参数。要传递给分布的关键字参数对。
return_log_prob (bool, 可选) – 仅关键字参数。如果为
True
,则分布样本的对数概率将使用键 log_prob_key 写入 tensordict 中。默认为False
。log_prob_key (NestedKey, 可选) – 如果 return_log_prob = True,则将对数概率写入的键。默认为 ‘sample_log_prob’。
cache_dist (bool, 可选) – 仅关键字参数。实验性:如果为
True
,则分布的参数(即模块的输出)将与样本一起写入 tensordict 中。这些参数可用于稍后重新计算原始分布(例如,计算用于采样动作的分布与 PPO 中更新的分布之间的差异)。默认为False
。n_empirical_estimate (int, 可选) – 仅关键字参数。计算经验平均值的样本数(当它不可用时)。默认为 1000。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import ( ... ProbabilisticTensorDictModule, ... ProbabilisticTensorDictSequential, ... TensorDictModule, ... ) >>> from tensordict.nn.distributions import NormalParamExtractor >>> from tensordict.nn.functional_modules import make_functional >>> from torch.distributions import Normal, Independent >>> td = TensorDict( ... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3] ... ) >>> net = torch.nn.GRUCell(4, 8) >>> module = TensorDictModule( ... net, in_keys=["input", "hidden"], out_keys=["params"] ... ) >>> normal_params = TensorDictModule( ... NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"] ... ) >>> def IndepNormal(**kwargs): ... return Independent(Normal(**kwargs), 1) >>> prob_module = ProbabilisticTensorDictModule( ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=IndepNormal, ... return_log_prob=True, ... ) >>> td_module = ProbabilisticTensorDictSequential( ... module, normal_params, prob_module ... ) >>> params = TensorDict.from_module(td_module) >>> with params.to_module(td_module): ... _ = td_module(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> with params.to_module(td_module): ... dist = td_module.get_dist(td) >>> print(dist) Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1) >>> # we can also apply the module to the TensorDict with vmap >>> from torch import vmap >>> params = params.expand(4) >>> def func(td, params): ... with params.to_module(td_module): ... return td_module(td) >>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 3]), device=None, is_shared=False)
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) TensorDictBase ¶
定义每次调用时执行的计算。
应由所有子类重写。
注意
尽管需要在该函数中定义前向传递的配方,但是之后应该调用
Module
实例而不是此函数,因为前者负责运行注册的钩子,而后者则会静默地忽略它们。
- get_dist(tensordict: TensorDictBase) Distribution ¶
使用输入 tensordict 中提供的参数创建
torch.distribution.Distribution
实例。
- log_prob(tensordict)¶
写入分布样本的对数概率。