快捷方式

TensorDictTokenizer

class torchrl.data.TensorDictTokenizer(tokenizer, max_length, key='text', padding='max_length', truncation=True, return_tensordict=True, device=None)[源代码]

用于应用分词器处理文本示例的过程函数的工厂。

参数:
  • tokenizer (来自 transformers 库的分词器) – 要使用的分词器。

  • max_length (int) – 序列的最大长度。

  • key (str, 可选) – 查找文本的键。默认为 "text"

  • padding (str, 可选) – 填充类型。默认为 "max_length"

  • truncation (bool, 可选) – 序列是否应截断为 max_length。

  • return_tensordict (bool, 可选) – 如果为 True,则返回 TensoDict。否则,将返回原始数据。

  • device (torch.device, 可选) – 存储数据的设备。如果 return_tensordict=False,则忽略此选项。

有关分词器的更多信息,请参阅 transformers 库

填充和截断: https://hugging-face.cn/docs/transformers/pad_truncation

返回:一个 tensordict.TensorDict 实例,其批次大小与输入数据相同。

示例

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = 100
>>> process = TensorDictTokenizer(tokenizer, max_length=10)
>>> # example with a single input
>>> example = {"text": "I am a little worried"}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # example with a multiple inputs
>>> example = {"text": ["Let me reassure you", "It will be ok"]}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源