快捷方式

示例

torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) Tensor[source]

从概率分布中进行通用采样。包括对 Top-K 采样和温度的支持。

参数:
  • logits (torch.Tensor) – 要从中采样的 logits

  • temperature (float) – 用于缩放预测 logits 的值,默认值为 1.0。

  • top_k (Optional[int]) – 如果指定,我们将采样修剪为仅保留前 k 个概率内的 token id

  • q (Optional[torch.Tensor]) – 用于 softmax 采样技巧的随机采样张量。如果为 None,我们将使用默认的 softmax 采样技巧。默认值 None。

示例

>>> from torchtune.generation import sample
>>> logits = torch.empty(3, 3).uniform_(0, 1)
>>> sample(logits)
tensor([[1],
        [2],
        [0]], dtype=torch.int32)
返回值:

采样的 token id

返回类型:

torch.Tensor

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源