快捷方式

ConcatDataset

class torchtune.datasets.ConcatDataset(datasets: List[Dataset])[源代码]

一个用于将多个子数据集连接成单个数据集的数据集类。此类允许将不同数据集统一处理,就好像它们是一个数据集一样,从而简化了诸如在多个数据源上同时训练模型的任务。

该类在内部管理不同数据集的聚合,并允许跨数据集透明地进行索引。但是,它要求所有组成数据集都完全加载到内存中,这对于非常大的数据集来说可能不是最佳选择。

在初始化时,此类计算所有数据集的累积长度,并维护一个将索引映射到相应数据集的内部映射。这种方法允许 ConcatDataset 在访问特定索引时透明地将数据检索委托给适当的子数据集。

注意

将此类与非常大的数据集一起使用会导致高内存消耗,因为它要求所有数据集都加载到内存中。对于大规模场景,请考虑其他可能按需流式传输数据的策略。

参数::

datasets (List[Dataset]) – 要连接的数据集列表。每个数据集必须是派生自 Dataset 的类的实例。

示例

>>> dataset1 = MyCustomDataset(params1)
>>> dataset2 = MyCustomDataset(params2)
>>> concat_dataset = ConcatDataset([dataset1, dataset2])
>>> print(len(concat_dataset))  # Total length of both datasets
>>> data_point = concat_dataset[1500]  # Accesses an element from the appropriate dataset

这也可以通过将数据集列表传递到 YAML 配置中来实现

dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: torchtune.data.AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: samsum
    template: torchtune.data.SummarizeTemplate
    column_map: {"output": "summary"}
    output: summary
    split: train
    train_on_input: False

此类主要侧重于提供一个统一的界面来访问多个数据集的元素,从而增强了在处理用于训练机器学习模型的不同数据源时的灵活性。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源