VisionCrossAttentionMask¶
- class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int, max_num_tiles: Optional[int] = None)[源代码]¶
计算文本 + 图像输入的交叉注意力掩码。参与与图像标记进行交叉注意力的文本标记将在掩码中显示 True,并遵循 Flamingo 论文(https://arxiv.org/pdf/2204.14198)图 7 中规定的交错结构。
紧跟在图像标记之后的文本标记,直到下一个图像标记。
连续的图像标记会关注后续的文本标记。
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ <img1> <img2>These are two dogs. <img3> This is a cat.
生成的掩码是针对每个图像构建的,形状为 (text_seq_len, image_seq_len),其中 True 表示图像编码器输出的标记在交叉注意力中关注文本序列中的标记。返回这些掩码的列表,长度等于样本中图像的数量。