快捷方式

generate

torchtune.generation.generate(model: TransformerDecoder, prompt: Tensor, *, max_generated_tokens: int, pad_id: int = 0, temperature: float = 1.0, top_k: Optional[int] = None, stop_tokens: Optional[List[int]] = None, rng: Optional[Generator] = None, custom_generate_next_token: Optional[Callable] = None) Tuple[Tensor, Tensor][源码]

根据提示从模型生成 token,并返回生成的 logits。

参数:
  • model (TransformerDecoder) – 用于生成的模型

  • prompt (torch.Tensor) – 包含给定提示对应的 token ID 的张量,形状为 [seq_length] 或 [bsz x seq_length]。

  • max_generated_tokens (int) – 要生成的 token 数量

  • pad_id (int) – 用于填充的 token ID,默认为 0。

  • temperature (float) – 用于缩放预测 logits 的值,默认为 1.0。

  • top_k (Optional[int]) – 如果指定,我们将采样修剪为仅包含 top_k 概率范围内的 token ID,默认为 None。

  • stop_tokens (Optional[List[int]]) – 如果指定,当生成任何这些 token 时停止生成,默认为 None。

  • rng (Optional[torch.Generator]) – 随机数生成器,默认为 None。

  • custom_generate_next_token (Optional[Callable]) – 如果指定,我们将使用 custom_generate_next_token 函数。这通常只在出于性能原因想要指定 torch.compile 版本的 generate next token 时有用。如果为 None,我们使用默认的 generate_next_token()。默认为 None。

注意

此函数仅在 decoder-only 模型上进行过测试。

示例

>>> model = torchtune.models.llama3.llama3_8b()
>>> tokenizer = torchtune.models.llama3.llama3_tokenizer()
>>> prompt = tokenizer.encode("Hi my name is")
>>> rng.manual_seed(42)
>>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0)
>>> print(tokenizer.decode(output[0].tolist()))
Hi my name is Jeremy and I'm a friendly language model assistant!
返回值:

两个张量的元组
  • tokens (torch.Tensor): 包含生成的 token 的张量,

    形状为 [bsz x seq_len + num_generated_tokens],其中如果提供了 stop_tokensnum_generated_tokens 可能小于 max_generated_tokens

  • logits (torch.Tensor): 包含与生成的 token 相关的 logits 的张量,

    形状为 [bsz x num_generated_tokens x vocab_size]

返回类型:

Tuple[torch.Tensor, torch.Tensor]

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源