Flat2TED

class torchrl.data.Flat2TED(done_key='done', shift_key='shift', is_full_key='is_full', done_keys=('done', 'truncated', 'terminated'), reward_keys=('reward',))[source]

A storage loading hook that deserializes flattened TED data back into the TED format.

Parameters:
  • done_key (NestedKey, optional) – the key where the done states should be read. Defaults to ("next", "done").

  • shift_key (NestedKey, optional) – the key where the shift will be written. Defaults to "shift".

  • is_full_key (NestedKey, optional) – the key where the is_full attribute will be written. Defaults to "is_full".

  • done_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the done entries. Defaults to ("done", "truncated", "terminated").

  • reward_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the reward entries. Defaults to ("reward",).

Examples

>>> import tempfile
>>>
>>> from tensordict import TensorDict
>>>
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage, Flat2TED
>>> from torchrl.envs import GymEnv
>>> import torch
>>>
>>> env = GymEnv("CartPole-v1")
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb.register_save_hook(TED2Flat())  # save hook: serialize TED data in a compact (flat) format when dumping
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     for i, data in enumerate(collector):
...         rb.extend(data)
...         rb.dumps(tmpdir)
...     # load the data to represent it
...     td = TensorDict.load(tmpdir + "/storage/")
...
...     rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
...     rb_load.register_load_hook(Flat2TED())  # load hook: rebuild the TED format when loading
...     rb_load.load(tmpdir)
...     print("storage after loading", rb_load[:])
...     assert (rb[:] == rb_load[:]).all()
storage after loading TensorDict(
    fields={
        action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
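
The example above relies on the default keys. As a minimal sketch (using only the constructor arguments listed above), the hook can also be built with the keys spelled out explicitly and registered on a buffer before loading:

>>> from torchrl.data import ReplayBuffer, LazyMemmapStorage, Flat2TED
>>>
>>> # Spell out the defaults explicitly; substitute your own keys if the
>>> # stored data uses different entries.
>>> hook = Flat2TED(
...     done_key=("next", "done"),
...     done_keys=("done", "truncated", "terminated"),
...     reward_keys=("reward",),
... )
>>> rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb_load.register_load_hook(hook)

Registering the hook before loading ensures the flattened storage is converted back to the TED layout when the buffer is read from disk, as in the example above.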
