快捷方式

VisionCrossAttentionMask

class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int)[source]

计算文本 + 图像输入的交叉注意力掩码。参与与图像 token 交叉注意力的文本 token 在掩码中将显示 True,并遵循 Flamingo 论文图 7 (https://arxiv.org/pdf/2204.14198) 中描述的交错结构。

  1. 紧跟图像 token 之后直到下一个图像 token 的文本 token

  2. 连续的图像 token 关注随后的文本 token

     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
    <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.

结果掩码按每张图像构建,形状为 (text_seq_len, image_seq_len),其中 True 表示图像编码器输出的 token 在交叉注意力中关注文本序列中的 token。返回的掩码列表长度等于样本中的图像数量。

参数:
  • tile_size (int) – 图像变换产生的图像块大小

  • patch_size (int) – 每个 patch 的大小。用于将块分割成 patch。例如,对于 patch_size = 40,形状为 (400, 400) 的块将包含 10x10 的 patch 网格,每个 patch 的形状为 (40, 40)。

  • image_token_id (int) – 图像特殊 token 的 Token ID。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源