generate_next_token¶
- torchtune.generation.generate_next_token(model: TransformerDecoder, input_pos: Tensor, x: Tensor, q: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None) Tuple[Tensor, Tensor] [source]¶
根据提示生成下一个 token,并返回相应的 logits。
- 参数:
model (TransformerDecoder) – 用于生成的模型
input_pos (torch.Tensor) – 张量,包含与给定提示相关联的位置编码,形状为 [bsz x seq_length]。
x (torch.Tensor) – 张量,包含与给定提示相关联的 token ID,形状为 [bsz x seq_length]。
q (Optional[torch.Tensor]) – 随机采样的张量,用于 softmax 采样技巧。参见 https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
mask (Optional[torch.Tensor]) – 注意力掩码,形状为 [bsz x seq_length x seq_length],默认为 None。
temperature (float) – 用于缩放预测 logits 的值,默认为 1.0。
top_k (Optional[int]) – 用于采样的 Top-k 值,默认为 None。
- 返回:
- 两个张量的元组
- tokens (torch.Tensor): 包含生成的 token 的张量,
形状为 [bsz x 1]。
- logits (torch.Tensor): 包含与生成的 token 相关联的 logits 的张量,
形状为 [bsz x seq_length x vocab_size]。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]