load_memmap¶
- class tensordict.load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None)¶
从磁盘加载内存映射的 tensordict。
- 参数:
prefix (字符串 或 文件夹路径) – 保存的 tensordict 应该被获取的文件夹路径。
device (torch.device 或 等效设备, 可选) – 如果提供,数据将被异步转换为该设备。支持 “meta” 设备,在这种情况下,数据不会被加载,但会创建一组空的 “meta” 张量。这有助于了解模型的总体大小和结构,而无需实际打开任何文件。
non_blocking (bool, 可选) – 如果为
True
,则在设备上加载张量后不会调用 synchronize。默认为False
。out (TensorDictBase, 可选) – 可选的 tensordict,数据应写入其中。
示例
>>> from tensordict import TensorDict, load_memmap >>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0) >>> td.memmap("./saved_td") >>> td_load = TensorDict.load_memmap("./saved_td") >>> assert (td == td_load).all()
此方法还允许加载嵌套的 tensordict。
示例
>>> nested = TensorDict.load_memmap("./saved_td/nested") >>> assert nested["e"] == 0
tensordict 也可以在 “meta” 设备上加载,或者作为假张量加载。
示例
>>> import tempfile >>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}}) >>> with tempfile.TemporaryDirectory() as path: ... td.save(path) ... td_load = load_memmap(path, device="meta") ... print("meta:", td_load) ... from torch._subclasses import FakeTensorMode ... with FakeTensorMode(): ... td_load = load_memmap(path) ... print("fake:", td_load) meta: TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False) fake: TensorDict( fields={ a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)