get_dataloader¶
- class torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: device, dataset_name: Optional[str] = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: Optional[str] = None, from_disk: bool = False, num_workers: Optional[int] = None)[source]¶
创建一个数据集并从中返回一个数据加载器。
- 参数:
batch_size (int) – 数据加载器样本的批大小。
block_size (int) – 数据加载器中序列的最大长度。
tensorclass_type (tensorclass 类) – 一个 tensorclass,带有一个
from_dataset()
方法,该方法必须接受三个关键字参数:split
(见下文),max_length
,它是用于训练的块大小,以及dataset_name
,一个指示数据集的字符串。 还应支持root_dir
和from_disk
参数。device (torch.device 或 等效设备) – 样本应被转换到的设备。
dataset_name (str, 可选) – 数据集名称。 如果未提供且 tensorclass 支持,则将为正在使用的 tensorclass 收集默认数据集名称。
infinite (bool, 可选) – 如果
True
,则迭代将是无限的,以便next(iterator)
将始终返回一个值。 默认为True
。prefetch (int, 可选) – 如果使用多线程数据加载,则要预取的项目数。
split (str, 可选) – 数据分割。
"train"
或"valid"
之一。 默认为"train"
。root_dir (路径, 可选) – 数据集存储的路径。 默认为
"$HOME/.cache/torchrl/data"
from_disk (bool, 可选) – 如果
True
,将使用datasets.load_from_disk()
。 否则,将使用datasets.load_dataset()
。 默认为False
。num_workers (int, 可选) –
datasets.dataset.map()
的工作进程数,该方法在标记化期间被调用。 默认为max(os.cpu_count() // 2, 1)
。
示例
>>> from torchrl.data.rlhf.reward import PairwiseDataset >>> dataloader = get_dataloader( ... batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu") >>> for d in dataloader: ... print(d) ... break PairwiseDataset( chosen_data=RewardData( attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([256]), device=cpu, is_shared=False), rejected_data=RewardData( attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([256]), device=cpu, is_shared=False), batch_size=torch.Size([256]), device=cpu, is_shared=False)