CausalVariant¶
- class torch.nn.attention.bias.CausalVariant(value)[source][source]¶
用于注意力机制中的因果变体枚举。
定义了两种类型的因果偏置
UPPER_LEFT:表示标准因果注意力机制中的左上三角偏置。构建此偏置的等效 PyTorch 代码是
torch.tril(torch.ones(size, dtype=torch.bool))
例如,当 shape=(3,4) 时,具体化的偏置张量将是
[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
LOWER_RIGHT:表示右下三角偏置,包含的值与矩阵的右下角对齐。
构建此偏置的等效 PyTorch 代码是
diagonal_offset = size[1] - size[0] torch.tril( torch.ones(size, dtype=torch.bool), diagonal=diagonal_offset, )
例如,当 shape=(3,4) 时,具体化的偏置张量将是
[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
请注意,当查询和键/值张量的序列长度相等时,这些变体是等效的,因为此时三角矩阵是方阵。
警告
此枚举为原型,未来可能会发生变化。