示例¶
- torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) Tensor [source]¶
从概率分布中进行通用采样。包括对 Top-K 采样和温度的支持。
- 参数:
logits (torch.Tensor) – 要从中采样的 logits
temperature (float) – 用于缩放预测 logits 的值,默认值为 1.0。
top_k (Optional[int]) – 如果指定,我们将采样修剪为仅保留前 k 个概率内的 token id
q (Optional[torch.Tensor]) – 用于 softmax 采样技巧的随机采样张量。如果为 None,我们将使用默认的 softmax 采样技巧。默认值 None。
示例
>>> from torchtune.generation import sample >>> logits = torch.empty(3, 3).uniform_(0, 1) >>> sample(logits) tensor([[1], [2], [0]], dtype=torch.int32)
- 返回值:
采样的 token id
- 返回类型: