生成¶
- torchtune.generation.generate(model: TransformerDecoder, prompt: Tensor, *, max_generated_tokens: int, pad_id: int = 0, temperature: float = 1.0, top_k: Optional[int] = None, stop_tokens: Optional[List[int]] = None, rng: Optional[Generator] = None, custom_generate_next_token: Optional[Callable] = None) Tuple[Tensor, Tensor] [source]¶
根据提示条件从模型中生成令牌,并返回生成的令牌的 logits。
- 参数:
model (TransformerDecoder) – 用于生成的模型
prompt (torch.Tensor) – 包含与给定提示关联的令牌 ID 的张量,形状为 [seq_length] 或 [bsz x seq_length]。
max_generated_tokens (int) – 要生成的令牌数量
pad_id (int) – 用于填充的令牌 ID,默认值为 0。
temperature (float) – 用于缩放预测 logits 的值,默认值为 1.0。
top_k (Optional[int]) – 如果指定,我们将采样修剪为仅包含 top_k 概率内的令牌 ID,默认值为 None。
stop_tokens (Optional[List[int]]) – 如果指定,当生成任何这些令牌时,生成将停止,默认值为 None。
rng (Optional[torch.Generator]) – 随机数生成器,默认值为 None。
custom_generate_next_token (Optional[Callable]) – 如果指定,我们将使用
custom_generate_next_token function
。这通常仅在您想要为性能原因指定torch.compile
版本的生成下一个令牌时有用。如果为 None,我们将使用默认的generate_next_token()
。默认值为 None。
注意
此函数仅在解码器专用模型上经过测试。
示例
>>> model = torchtune.models.llama3.llama3_8b() >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() >>> prompt = tokenizer.encode("Hi my name is") >>> rng.manual_seed(42) >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0) >>> print(tokenizer.decode(output[0].tolist())) Hi my name is Jeremy and I'm a friendly language model assistant!
- 返回:
- 两个张量的元组
- tokens (torch.Tensor):包含生成的令牌的张量,
形状为
[bsz x seq_len + num_generated_tokens]
,其中num_generated_tokens
可能小于max_generated_tokens
(如果提供了stop_tokens
)。
- logits (torch.Tensor):包含与生成的令牌关联的 logits 的张量,
形状为
[bsz x seq_len + num_generated_tokens x vocab_size]
。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]