快捷方式

PromptTensorDictTokenizer

class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)[source]

用于提示数据集的标记化配方。

返回一个标记器函数,该函数读取包含提示和标签的示例并对其进行标记化。

参数::
  • tokenizer (来自 transformers 库的标记器) – 要使用的标记器。

  • max_length (int) – 序列的最大长度。

  • key (str, optional) – 查找文本的键。默认为 "prompt"

  • padding (str, optional) – 填充类型。默认为 "max_length"

  • truncation (bool, optional) – 是否应将序列截断为 max_length。

  • return_tensordict (bool, optional) – 如果为 True,则返回 TensoDict。否则,将返回原始数据。

  • device (torch.device, optional) – 存储数据的设备。如果 return_tensordict=False,则忽略此选项。

此类的 __call__() 方法将执行以下操作

  • 读取与 label 字符串连接的 prompt 字符串,并对其进行标记化。结果将存储在 "input_ids" TensorDict 条目中。

  • 使用提示的最后一个有效标记的索引编写 "prompt_rindex" 条目。

  • 编写一个 "valid_sample",它标识 tensordict 中哪些条目具有足够的标记来满足 max_length 条件。

  • 返回一个 tensordict.TensorDict 实例,其中包含标记化的输入。

tensordict 批次大小将与输入的批次大小匹配。

示例

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> example = {
...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
...     "label": ["Indeed it is.", 'It might as well be.'],
... }
>>> fn = PromptTensorDictTokenizer(tokenizer, 50)
>>> print(fn(example))
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获得面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源