TensorDictTokenizer
- class torchrl.data.TensorDictTokenizer(tokenizer, max_length, key='text', padding='max_length', truncation=True, return_tensordict=True, device=None)[source]
Factory for a process function that applies a tokenizer over a text example.
- Parameters:
  - tokenizer (tokenizer from the transformers library) – the tokenizer to use.
  - max_length (int) – maximum length of the sequence.
  - key (str, optional) – the key where to find the text. Defaults to "text".
  - padding (str, optional) – type of padding. Defaults to "max_length".
  - truncation (bool, optional) – whether the sequences should be truncated to max_length.
  - return_tensordict (bool, optional) – if True, a TensorDict is returned. Otherwise, the raw data is returned.
  - device (torch.device, optional) – the device where to store the data. This option is ignored if return_tensordict=False.
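A minimal usage sketch (not part of the original reference; the field name "prompt" is hypothetical) showing text being read from a non-default key, assuming the same GPT-2 tokenizer as in the Examples below:

>>> from transformers import AutoTokenizer
>>> from torchrl.data import TensorDictTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
>>> process = TensorDictTokenizer(tokenizer, max_length=10, key="prompt")
>>> td = process({"prompt": "I am a little worried"})  # text looked up under "prompt"
>>> td["input_ids"].shape
torch.Size([10])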
- See the transformers library for more information about tokenizers.
  Padding and truncation: https://hugging-face.cn/docs/transformers/pad_truncation
Returns: a tensordict.TensorDict instance with the same batch size as the input data.
Examples
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = 100
>>> process = TensorDictTokenizer(tokenizer, max_length=10)
>>> # example with a single input
>>> example = {"text": "I am a little worried"}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # example with multiple inputs
>>> example = {"text": ["Let me reassure you", "It will be ok"]}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
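As a complementary sketch (assumed behavior, not part of the original example): with return_tensordict=False the parameter description above states that the raw tokenizer output is returned instead of a TensorDict, so the result can be used as a plain dict-like batch. Reusing the tokenizer and example defined above:

>>> process_raw = TensorDictTokenizer(tokenizer, max_length=10, return_tensordict=False)
>>> raw = process_raw({"text": "I am a little worried"})  # raw, non-TensorDict output
>>> "input_ids" in raw
True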