快捷方式

TokenizedDatasetLoader

class torchrl.data.TokenizedDatasetLoader(split, max_length, dataset_name, tokenizer_fn: Type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, valid_size: int = 2000, num_workers: Optional[int] = None, tokenizer_class=None, tokenizer_model_name=None)[source]

加载一个分词后的数据集,并缓存一个内存映射的副本。

参数:
  • split (str) – 取值可以是 "train""valid"

  • max_length (int) – 最大序列长度。

  • dataset_name (str) – 数据集的名称。

  • tokenizer_fn (callable) – 分词方法的构造函数,例如 torchrl.data.rlhf.TensorDictTokenizer。调用时,应返回一个 tensordict.TensorDict 实例或一个包含分词数据的字典状结构。

  • pre_tokenization_hook (callable, optional) – 在分词之前在数据集上调用。它应返回一个修改后的 Dataset 对象。其预期用途是执行需要修改整个数据集而不是修改单个数据点的任务,例如根据特定条件丢弃某些数据点。数据的分词和其他“按元素”操作由映射到数据集上的处理函数执行。

  • root_dir (path, optional) – 数据集存储的路径。默认为 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果为 True,则使用 datasets.load_from_disk()。否则,使用 datasets.load_dataset()。默认为 False

  • valid_size (int, optional) – 验证数据集的大小(如果 split 以 "valid" 开头)将被截断为此值。默认为 2000 项。

  • num_workers (int, optional) – datasets.dataset.map() 的工作进程数,该方法在分词期间调用。默认为 max(os.cpu_count() // 2, 1)

  • tokenizer_class (Type, optional) – 一个分词器类,例如 AutoTokenizer(默认)。

  • tokenizer_model_name (str, optional) – 从中收集词汇表的模型。默认为 "gpt2"

数据集将存储在 <root_dir>/<split>/<max_length>/ 中。

示例

>>> from torchrl.data.rlhf import TensorDictTokenizer
>>> from torchrl.data.rlhf.reward import  pre_tokenization_hook
>>> split = "train"
>>> max_length = 550
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
>>> loader = TokenizedDatasetLoader(
...     split,
...     max_length,
...     dataset_name,
...     TensorDictTokenizer,
...     pre_tokenization_hook=pre_tokenization_hook,
... )
>>> dataset = loader.load()
>>> print(dataset)
TensorDict(
    fields={
        attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([185068]),
    device=None,
    is_shared=False)
static dataset_to_tensordict(dataset: 'datasets.Dataset' | TensorDict, data_dir: Path, prefix: NestedKey = None, features: Sequence[str] =None, batch_dims=1, valid_mask_key=None)[source]

将数据集转换为内存映射的 TensorDict。

如果数据集已经是 TensorDict 实例,则直接将其转换为内存映射的 TensorDict。否则,数据集应具有一个 features 属性,该属性是一个字符串序列,指示数据集中可以找到的特征。如果没有该属性,则必须显式地将 features 传递给此函数。

参数:
  • dataset (datasets.Dataset, TensorDict or equivalent) – 要转换为内存映射 TensorDict 的数据集。如果 featuresNone,则必须具有一个 features 属性,其中包含要写入 tensordict 的键列表。

  • data_dir (Path or equivalent) – 应写入数据的目录。

  • prefix (NestedKey, optional) – 数据集位置的前缀。这可用于区分经过不同预处理的同一数据集的多个副本。

  • features (sequence of str, optional) – 一个字符串序列,指示可以在数据集中找到的特征。

  • batch_dims (int, optional) – 数据的批处理维度数(即 tensordict 可以沿其索引的维度数)。默认为 1。

  • valid_mask_key (NestedKey, optional) – 如果提供,将尝试收集此条目并用于过滤数据。默认为 None(即没有过滤键)。

返回值: 一个包含数据集内存映射张量的 TensorDict。

示例

>>> from datasets import Dataset
>>> import tempfile
>>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
...     )
...     print(data_memmap)
TensorDict(
    fields={
        some: TensorDict(
            fields={
                prefix: TensorDict(
                    fields={
                        labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                        tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
load()[source]

如果存在预处理的内存映射数据集,则加载它,否则创建它。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源