Reward2GoTransform
- class torchrl.envs.transforms.Reward2GoTransform(gamma: Optional[Union[float, torch.Tensor]] = 1.0, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, done_key: Optional[NestedKey] = 'done')
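Calculates the reward-to-go based on the episode reward and a discount factor.

As Reward2GoTransform is only an inverse transform, the in_keys will be used directly for the in_keys_inv. The reward-to-go can only be computed once an episode has finished. The transform should therefore be applied to the replay buffer, not inside the collector or the environment.

- Parameters: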
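gamma (float or torch.Tensor) – the discount factor. Defaults to 1.0.
in_keys (sequence of NestedKey) – the reward entries to read. Defaults to ("next", "reward") if none is provided.
out_keys (sequence of NestedKey) – the entries where the reward-to-go is written. Defaults to the values of in_keys if none is provided.
done_key (NestedKey) – the done entry. Defaults to "done".
truncated_key (NestedKey) – the truncated entry. Defaults to "truncated". If no truncated entry is found, only "done" is used.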
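To make the quantity concrete, the following is a minimal plain-Python sketch of a discounted reward-to-go with episode resets. It is not the library's implementation: the helper name `reward_to_go` and its list-based signature are assumptions made for illustration, and it handles a single flat sequence of steps rather than batched tensors.

```python
from typing import List, Sequence


def reward_to_go(
    rewards: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 1.0,
) -> List[float]:
    """Discounted reward-to-go for one flat sequence of steps (illustrative helper).

    dones[t] is True when step t ends an episode, so the running sum is
    reset there and rewards never leak across episode boundaries.
    """
    out = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards: R_t = r_t + gamma * R_{t+1}, with R_{t+1} = 0 after a done.
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0
        running = rewards[t] + gamma * running
        out[t] = running
    return out


values = reward_to_go([1.0] * 5, [False, False, False, False, True], gamma=0.99)
print([round(v, 4) for v in values])  # [4.901, 3.9404, 2.9701, 1.99, 1.0]
```

The backward pass is why the computation needs complete episodes: the value at step t depends on every later reward up to the next done flag.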
Examples
>>> # Using this transform as part of a replay buffer
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.envs.transforms import Reward2GoTransform
>>> torch.manual_seed(0)
>>> r2g = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> rb = ReplayBuffer(storage=LazyTensorStorage(100), transform=r2g)
>>> batch, timesteps = 4, 5
>>> done = torch.zeros(batch, timesteps, 1, dtype=torch.bool)
>>> # Draw random done flags until each trajectory has at least one
>>> for i in range(batch):
...     while not done[i].any():
...         done[i] = done[i].bernoulli_(0.1)
>>> reward = torch.ones(batch, timesteps, 1)
>>> td = TensorDict(
...     {"next": {"done": done, "reward": reward}},
...     [batch, timesteps],
... )
>>> rb.extend(td)
>>> sample = rb.sample(1)
>>> print(sample["next", "reward"])
tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.]]])
>>> print(sample["reward_to_go"])
tensor([[[4.9010],
         [3.9404],
         [2.9701],
         [1.9900],
         [1.0000]]])
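The printed reward-to-go values can be checked by hand. The short derivation below assumes that the sampled trajectory has no done flag before its last step, which the output suggests even though the flags are drawn at random. The symbols R_t (reward-to-go at step t), r_t (reward at step t) and T (trajectory length, here 5) are introduced here for illustration.

```latex
R_t = \sum_{k=0}^{T-1-t} \gamma^{k}\, r_{t+k}
    = \frac{1-\gamma^{\,T-t}}{1-\gamma}
    \quad \text{when } r_{t+k} = 1,
\qquad
R_0 = \frac{1-0.99^{5}}{1-0.99} \approx 4.9010,
\quad
R_3 = 1 + 0.99 = 1.99,
\quad
R_4 = 1 .
```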
One can also use this transform directly with a collector: make sure to pass the inv method of the transform.
Examples
>>> from torchrl.envs.utils import RandomPolicy
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> env = GymEnv("Pendulum-v1")
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=200,
...     total_frames=-1,
...     postproc=t.inv
... )
>>> for data in collector:
...     break
>>> print(data)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        reward_to_go: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
Using this transform as part of an env will raise an exception.
Examples
>>> from torchrl.envs import TransformedEnv
>>> t = Reward2GoTransform(gamma=0.99)
>>> TransformedEnv(GymEnv("Pendulum-v1"), t)  # crashes
Note
In the presence of multiple done entries, one should build a separate Reward2GoTransform for each done-reward pair.