快捷方式

torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor

返回一个张量,其中每一行包含从多项式(更严格的定义是多元的,请参阅 torch.distributions.multinomial.Multinomial 获取更多详细信息)概率分布中采样的 num_samples 个索引,该概率分布位于张量 input 的对应行中。

注意

input 的行不需要加起来等于 1(在这种情况下,我们将这些值用作权重),但必须是非负、有限且具有非零总和。

索引按从左到右的顺序排列,对应于每个索引采样的时间(第一个样本放置在第一列中)。

如果 input 是一个向量,则 out 是大小为 num_samples 的向量。

如果 input 是一个具有 m 行的矩阵,则 out 是一个形状为 (m×num_samples)(m \times \text{num\_samples}) 的矩阵。

如果 replacement 为 True,则样本将有放回地抽取。

如果不是,则它们将无放回地抽取,这意味着当为一行抽取样本索引时,该索引就不能再次为该行抽取。

注意

当无放回地抽取时,num_samples 必须小于 input 中非零元素的数量(或者,如果它是一个矩阵,则小于每一行中非零元素的最小数量)。

参数
  • input (Tensor) – 包含概率的输入张量

  • num_samples (int) – 要抽取的样本数量

  • replacement (bool, optional) – 是否有放回地抽取

关键字参数
  • generator (torch.Generator, optional) – 用于采样的伪随机数生成器

  • out (Tensor, optional) – 输出张量。

示例

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源