VisionCrossAttentionMask¶
- class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int)[source]¶
计算文本 + 图像输入的交叉注意力掩码。与图像 token 参与交叉注意力的文本 token 将在掩码中显示 True,并遵循 Flamingo 论文图 7 中布局的交错结构 (https://arxiv.org/pdf/2204.14198)
紧随图像 token 之后的文本 token,直到下一个图像 token
连续的图像 token 注意后续的文本 token
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ <img1> <img2>These are two dogs. <img3> This is a cat.
结果掩码是针对每个图像构建的,形状为 (text_seq_len, image_seq_len),其中 True 表示来自图像编码器的 token 输出在交叉注意力中注意文本序列中的 token。返回这些掩码的列表,其长度等于样本中图像的数量。