TED2Flat
- class torchrl.data.TED2Flat(done_key=('next', 'done'), shift_key='shift', is_full_key='is_full', done_keys=('done', 'truncated', 'terminated'), reward_keys=('reward',))
A storage saving hook to serialize TED data in a compact format.
- Parameters:
  - done_key (NestedKey, optional) – the key where the done state should be read. Defaults to ("next", "done").
  - shift_key (NestedKey, optional) – the key where the shift will be written. Defaults to "shift".
  - is_full_key (NestedKey, optional) – the key where the is_full attribute will be written. Defaults to "is_full".
  - done_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the done entries. Defaults to ("done", "truncated", "terminated").
  - reward_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the reward entries. Defaults to ("reward",).
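These keys can be overridden when constructing the hook if the data does not follow the default single-agent layout. A minimal sketch, in which the nested key names ("agents", "done"), ("agents", "truncated"), ("agents", "terminated") and ("agents", "reward") are hypothetical placeholders for whatever keys your tensordicts actually use:

>>> from torchrl.data import ReplayBuffer, LazyMemmapStorage, TED2Flat
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> # Hypothetical nested keys: substitute the keys present in your own data.
>>> rb.register_save_hook(
...     TED2Flat(
...         done_key=("next", "agents", "done"),
...         done_keys=(("agents", "done"), ("agents", "truncated"), ("agents", "terminated")),
...         reward_keys=(("agents", "reward"),),
...     )
... )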
Examples
>>> import tempfile
>>>
>>> from tensordict import TensorDict
>>>
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage
>>> from torchrl.envs import GymEnv
>>> import torch
>>>
>>> env = GymEnv("CartPole-v1")
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb.register_save_hook(TED2Flat())
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     for i, data in enumerate(collector):
...         rb.extend(data)
...         rb.dumps(tmpdir)
...     # load the data to represent it
...     td = TensorDict.load(tmpdir + "/storage/")
...     print(td)
TensorDict(
    fields={
        action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True),
        collector: TensorDict(
            fields={
                traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
        observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True),
        reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True),
        terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
        truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
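To read such a flat dump back as regular TED data, a load-side counterpart can be registered on the buffer that restores the checkpoint. This is a minimal sketch continuing the example above; it assumes the companion torchrl.data.Flat2TED load hook and the ReplayBuffer.register_load_hook / loads methods, and must run while tmpdir (the directory dumped above) still exists, i.e. inside the same with block.

>>> from torchrl.data import Flat2TED
>>> # A fresh buffer that converts the flat storage back to the TED layout on load.
>>> rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb_load.register_load_hook(Flat2TED())
>>> rb_load.loads(tmpdir)  # restores the data dumped with the TED2Flat hook above
>>> sample = rb_load.sample(32)  # samples come back in the regular TED format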