快捷方式

标记化数据集加载器

class torchrl.data.TokenizedDatasetLoader(split, max_length, dataset_name, tokenizer_fn: Type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, valid_size: int = 2000, num_workers: Optional[int] = None, tokenizer_class=None, tokenizer_model_name=None)[source]

加载标记化数据集,并缓存其内存映射副本。

参数:
  • split (str) – "train""valid" 之一。

  • max_length (int) – 最大序列长度。

  • dataset_name (str) – 数据集的名称。

  • tokenizer_fn (callable) – 标记方法构造函数,例如 torchrl.data.rlhf.TensorDictTokenizer。当被调用时,它应该返回一个 tensordict.TensorDict 实例或包含标记化数据的类似字典的结构。

  • pre_tokenization_hook (callable, optional) – 在标记化之前在数据集上调用。它应该返回一个修改后的数据集对象。其预期用途是在执行需要修改整个数据集而不是修改单个数据点的任务时,例如根据特定条件丢弃某些数据点。标记化和其他“逐元素”数据操作由映射到数据集的处理函数执行。

  • root_dir (path, optional) – 存储数据集的路径。默认为 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果为 True,则将使用 datasets.load_from_disk()。否则,将使用 datasets.load_dataset()。默认为 False

  • valid_size (int, optional) – 如果 "valid" 的开头是 split,则验证数据集的大小将被截断到此值。默认为 2000 项。

  • num_workers (int, optional) – datasets.dataset.map() 的工作进程数量,在标记化期间被调用。默认为 max(os.cpu_count() // 2, 1)

  • tokenizer_class (type, optional) – 标记器类,例如 AutoTokenizer(默认)。

  • tokenizer_model_name (str, optional) – 应该从中收集词汇表的模型。默认为 "gpt2"

该数据集将存储在 <root_dir>/<split>/<max_length>/ 中。

示例

>>> from torchrl.data.rlhf import TensorDictTokenizer
>>> from torchrl.data.rlhf.reward import  pre_tokenization_hook
>>> split = "train"
>>> max_length = 550
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
>>> loader = TokenizedDatasetLoader(
...     split,
...     max_length,
...     dataset_name,
...     TensorDictTokenizer,
...     pre_tokenization_hook=pre_tokenization_hook,
... )
>>> dataset = loader.load()
>>> print(dataset)
TensorDict(
    fields={
        attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([185068]),
    device=None,
    is_shared=False)
static dataset_to_tensordict(dataset: 'datasets.Dataset' | TensorDict, data_dir: Path, prefix: NestedKey = None, features: Sequence[str] = None, batch_dims=1, valid_mask_key=None)[source]

将数据集转换为内存映射的 TensorDict。

如果数据集已经是 TensorDict 实例,则它只是简单地转换为内存映射的 TensorDict。否则,数据集应具有 features 属性,它是一个字符串序列,指示可以在数据集中找到的特征。如果没有,则必须将 features 明确地传递给此函数。

参数:
  • dataset (datasets.Dataset, TensorDict等效) – 要转换为内存映射的 TensorDict 的数据集。如果 featuresNone,则它必须具有一个 features 属性,其中包含要在 tensordict 中写入的键的列表。

  • data_dir (路径等效项) – 数据写入的目录。

  • prefix (嵌套键, 可选) – 数据集位置的前缀。这可用于区分经过不同预处理的相同数据集的多个副本。

  • features (字符串序列, 可选) – 指示数据集中可以找到的特征的字符串序列。

  • batch_dims (整数, 可选) – 数据的批次维度数量(即张量字典可以沿其索引的维度数量)。默认为 1。

  • valid_mask_key (嵌套键, 可选) – 如果提供,此条目将被尝试性地收集并用于过滤数据。默认为 None(即,没有过滤器键)。

返回:包含具有数据集的内存映射张量的 TensorDict。

示例

>>> from datasets import Dataset
>>> import tempfile
>>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
...     )
...     print(data_memmap)
TensorDict(
    fields={
        some: TensorDict(
            fields={
                prefix: TensorDict(
                    fields={
                        labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                        tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
load()[source]

如果存在,则加载预处理的内存映射数据集,否则创建它。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源