快捷方式

VisionCrossAttentionMask

class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int)[source]

计算文本 + 图像输入的交叉注意力掩码。与图像 token 参与交叉注意力的文本 token 将在掩码中显示 True,并遵循 Flamingo 论文图 7 中布局的交错结构 (https://arxiv.org/pdf/2204.14198)

  1. 紧随图像 token 之后的文本 token,直到下一个图像 token

  2. 连续的图像 token 注意后续的文本 token

     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
    <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.

结果掩码是针对每个图像构建的,形状为 (text_seq_len, image_seq_len),其中 True 表示来自图像编码器的 token 输出在交叉注意力中注意文本序列中的 token。返回这些掩码的列表,其长度等于样本中图像的数量。

参数:
  • tile_size (int) – 来自图像变换的图像瓦片的大小

  • patch_size (int) – 每个补丁的大小。用于将瓦片划分为补丁。例如,对于 patch_size = 40,形状为 (400, 400) 的瓦片将具有 10x10 的补丁网格,每个补丁的形状为 (40, 40)。

  • image_token_id (int) – 图像特殊 token 的 Token ID。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源