快捷方式

get_dataloader

class torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: device, dataset_name: Optional[str] = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: Optional[str] = None, from_disk: bool = False, num_workers: Optional[int] = None)[source]

创建一个数据集并从中返回一个数据加载器。

参数:
  • batch_size (int) – 数据加载器样本的批大小。

  • block_size (int) – 数据加载器中序列的最大长度。

  • tensorclass_type (tensorclass 类) – 一个 tensorclass,带有一个 from_dataset() 方法,该方法必须接受三个关键字参数: split (见下文), max_length,它是用于训练的块大小,以及 dataset_name,一个指示数据集的字符串。 还应支持 root_dirfrom_disk 参数。

  • device (torch.device等效设备) – 样本应被转换到的设备。

  • dataset_name (str, 可选) – 数据集名称。 如果未提供且 tensorclass 支持,则将为正在使用的 tensorclass 收集默认数据集名称。

  • infinite (bool, 可选) – 如果 True,则迭代将是无限的,以便 next(iterator) 将始终返回一个值。 默认为 True

  • prefetch (int, 可选) – 如果使用多线程数据加载,则要预取的项目数。

  • split (str, 可选) – 数据分割。 "train""valid" 之一。 默认为 "train"

  • root_dir (路径, 可选) – 数据集存储的路径。 默认为 "$HOME/.cache/torchrl/data"

  • from_disk (bool, 可选) – 如果 True,将使用 datasets.load_from_disk()。 否则,将使用 datasets.load_dataset()。 默认为 False

  • num_workers (int, 可选) – datasets.dataset.map() 的工作进程数,该方法在标记化期间被调用。 默认为 max(os.cpu_count() // 2, 1)

示例

>>> from torchrl.data.rlhf.reward import PairwiseDataset
>>> dataloader = get_dataloader(
...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
>>> for d in dataloader:
...     print(d)
...     break
PairwiseDataset(
    chosen_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    rejected_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    batch_size=torch.Size([256]),
    device=cpu,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源