DecisionTransformerInferenceWrapper¶
- class torchrl.modules.tensordict_module.DecisionTransformerInferenceWrapper(*args, **kwargs)[source]¶
Inference action wrapper for the Decision Transformer.
A wrapper specifically designed for the Decision Transformer, which masks the input tensordict sequence down to the inference context. The output will be a TensorDict with the same keys as the input, but containing only the last action of the predicted action sequence and the last return-to-go.
This module creates and returns a modified copy of the tensordict, i.e. it does **not** modify the tensordict in place.
Note
If the action, observation or return-to-go keys are not standard, the
set_tensor_keys()
method should be used, e.g.:
>>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz")
The in_keys are the observation, action and return-to-go keys. The out-keys match the in-keys, with the addition of any other out-key from the policy (e.g. parameters of the distribution or hidden values).
- Parameters:
policy (TensorDictModule) – the policy module that takes observations and produces an action value
- Keyword Arguments:
inference_context (int) – the number of previous actions that will not be masked in the context. For example, for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries of the context will be masked. Defaults to 5.
spec (Optional[TensorSpec]) – the spec of the input TensorDict. If None, it will be inferred from the policy module.
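To make the inference_context arithmetic concrete, here is a minimal pure-Python sketch of the masking rule described above (an illustration only, not torchrl's actual implementation): with a context of 20 steps and inference_context=5, the first 20 - 5 = 15 steps are hidden and only the trailing 5 remain visible.

```python
# Illustration of the masking arithmetic described above; this is
# NOT the torchrl implementation, just the rule it documents.
context = 20           # time dimension of the [batch, context, obs_dim] input
inference_context = 5  # number of trailing steps left unmasked

# Boolean mask over the time dimension: False = masked out.
mask = [t >= context - inference_context for t in range(context)]

print(sum(mask))         # number of visible steps: 5
print(mask.index(True))  # index of the first visible step: 15
```

The last inference_context steps stay visible, matching the "first 15 entries masked" example in the parameter description.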
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import (
...     ProbabilisticActor,
...     TanhDelta,
...     DTActor,
...     DecisionTransformerInferenceWrapper,
... )
>>> dtactor = DTActor(state_dim=4, action_dim=2,
...     transformer_config=DTActor.default_config()
... )
>>> actor_module = TensorDictModule(
...     dtactor,
...     in_keys=["observation", "action", "return_to_go"],
...     out_keys=["param"])
>>> dist_class = TanhDelta
>>> dist_kwargs = {
...     "low": -1.0,
...     "high": 1.0,
... }
>>> actor = ProbabilisticActor(
...     in_keys=["param"],
...     out_keys=["action"],
...     module=actor_module,
...     distribution_class=dist_class,
...     distribution_kwargs=dist_kwargs)
>>> inference_actor = DecisionTransformerInferenceWrapper(actor)
>>> sequence_length = 20
>>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4),
...     "action": torch.randn(1, sequence_length, 2),
...     "return_to_go": torch.randn(1, sequence_length, 1)}, [1,])
>>> result = inference_actor(td)
>>> print(result)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)