torch.multinomial¶
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor ¶
返回一个张量,其中每一行包含从多项式(更严格的定义是多元的,请参阅
torch.distributions.multinomial.Multinomial
获取更多详细信息)概率分布中采样的num_samples
个索引,该概率分布位于张量input
的对应行中。注意
input
的行不需要加起来等于 1(在这种情况下,我们将这些值用作权重),但必须是非负、有限且具有非零总和。索引按从左到右的顺序排列,对应于每个索引采样的时间(第一个样本放置在第一列中)。
如果
input
是一个向量,则out
是大小为num_samples
的向量。如果
input
是一个具有 m 行的矩阵,则out
是一个形状为 的矩阵。如果 replacement 为
True
,则样本将有放回地抽取。如果不是,则它们将无放回地抽取,这意味着当为一行抽取样本索引时,该索引就不能再次为该行抽取。
注意
当无放回地抽取时,
num_samples
必须小于input
中非零元素的数量(或者,如果它是一个矩阵,则小于每一行中非零元素的最小数量)。- 参数
- 关键字参数
generator (
torch.Generator
, optional) – 用于采样的伪随机数生成器out (Tensor, optional) – 输出张量。
示例
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])