generate_next_token¶
- torchtune.generation.generate_next_token(model: TransformerDecoder, input_pos: Tensor, x: Tensor, q: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None) Tuple[Tensor, Tensor] [source]¶
根据提示生成下一个 token,并返回相应的 logits。
- 参数:
model (TransformerDecoder) – 用于生成的模型
input_pos (torch.Tensor) – 包含与给定提示相关的位置编码的 tensor,形状为 [bsz x seq_length]。
x (torch.Tensor) – 包含与给定提示相关的 token ID 的 tensor,形状为 [bsz x seq_length]。
q (Optional[torch.Tensor]) – 用于 softmax 采样技巧的随机采样 tensor。详见 https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
mask (Optional[torch.Tensor]) – 注意力 mask,形状为 [bsz x seq_length x seq_length],默认为 None。
temperature (float) – 用于缩放预测 logits 的值,默认为 1.0。
top_k (Optional[int]) – 用于采样的 Top-k 值,默认为 None。
- 返回值:
- 包含两个 tensor 的元组
- tokens (torch.Tensor): 包含生成的 token 的 tensor,
形状为 [bsz x 1]。
- logits (torch.Tensor): 包含与生成的 token 相关的 logits 的 tensor,
形状为 [bsz x 1 x vocab_size]。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]