PromptTensorDictTokenizer¶
- class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)[source]¶
用于提示数据集的标记化配方。
返回一个标记器函数,该函数读取包含提示和标签的示例并对其进行标记化。
- 参数::
tokenizer (来自 transformers 库的标记器) – 要使用的标记器。
max_length (int) – 序列的最大长度。
key (str, optional) – 查找文本的键。默认为
"prompt"
。padding (str, optional) – 填充类型。默认为
"max_length"
。truncation (bool, optional) – 是否应将序列截断为 max_length。
return_tensordict (bool, optional) – 如果为
True
,则返回 TensoDict。否则,将返回原始数据。device (torch.device, optional) – 存储数据的设备。如果
return_tensordict=False
,则忽略此选项。
此类的
__call__()
方法将执行以下操作读取与
label
字符串连接的prompt
字符串,并对其进行标记化。结果将存储在"input_ids"
TensorDict 条目中。使用提示的最后一个有效标记的索引编写
"prompt_rindex"
条目。编写一个
"valid_sample"
,它标识 tensordict 中哪些条目具有足够的标记来满足max_length
条件。返回一个
tensordict.TensorDict
实例,其中包含标记化的输入。
tensordict 批次大小将与输入的批次大小匹配。
示例
>>> from transformers import AutoTokenizer >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> tokenizer.pad_token = tokenizer.eos_token >>> example = { ... "prompt": ["This prompt is long enough to be tokenized.", "this one too!"], ... "label": ["Indeed it is.", 'It might as well be.'], ... } >>> fn = PromptTensorDictTokenizer(tokenizer, 50) >>> print(fn(example)) TensorDict( fields={ attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False), input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False), prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False), valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False)