sample¶
- torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) Tensor [源]¶
从概率分布中进行泛型采样。支持 Top-K 采样和温度。
- 参数:
logits (torch.Tensor) – 用于采样的 logits
temperature (float) – 用于缩放预测 logits 的值,默认为 1.0。
top_k (Optional[int]) – 如果指定,我们将采样范围限制在 top_k 概率内的 token ID。
q (Optional[torch.Tensor]) – 用于 softmax 采样技巧的随机采样张量。如果为 None,则使用默认的 softmax 采样技巧。默认为 None。
示例
>>> from torchtune.generation import sample >>> logits = torch.empty(3, 3).uniform_(0, 1) >>> sample(logits) tensor([[1], [2], [0]], dtype=torch.int32)
- 返回值:
采样的 token ID
- 返回类型: