generate_next_token¶
- torchtune.generation.generate_next_token(model: TransformerDecoder, input_pos: Tensor, x: Tensor, q: Tensor, *, mask: Optional[Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None) Tuple[Tensor, Tensor] [源代码]¶
给定提示生成下一个标记,并返回相应的 logits。
- 参数:
model (TransformerDecoder) – 用于生成的模型
input_pos (torch.Tensor) – 包含与给定提示关联的位置编码的张量,形状为 [bsz x seq_length]。
x (torch.Tensor) – 包含与给定提示关联的标记 ID 的张量,形状为 [bsz x seq_length]。
q (torch.Tensor) – 用于 softmax 采样技巧的随机采样张量。请参阅 https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
mask (Optional[torch.Tensor]) – 形状为 [bsz x seq_length x seq_length] 的注意力掩码,默认为 None。
temperature (float) – 用于缩放预测 logits 的值,默认为 1.0。
top_k (Optional[int]) – 用于采样的 Top-k 值,默认为 None。
- 返回值:
- 两个张量的元组
- tokens (torch.Tensor):包含生成的标记的张量,
形状为 [bsz x 1]。
- logits (torch.Tensor):包含与生成的标记关联的 logits 的张量,
形状为 [bsz x seq_length x vocab_size]。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]