快捷方式

VisionCrossAttentionMask

class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int, max_num_tiles: Optional[int] = None)[源代码]

计算文本 + 图像输入的交叉注意力掩码。参与与图像标记进行交叉注意力的文本标记将在掩码中显示 True,并遵循 Flamingo 论文(https://arxiv.org/pdf/2204.14198)图 7 中规定的交错结构。

  1. 紧跟在图像标记之后的文本标记,直到下一个图像标记。

  2. 连续的图像标记会关注后续的文本标记。

     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
    <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.

生成的掩码是针对每个图像构建的,形状为 (text_seq_len, image_seq_len),其中 True 表示图像编码器输出的标记在交叉注意力中关注文本序列中的标记。返回这些掩码的列表,长度等于样本中图像的数量。

参数:
  • tile_size (int) – 图像转换中图像块的大小

  • patch_size (int) – 每个补丁的大小。用于将块划分为补丁。例如,对于 patch_size = 40,形状为 (400, 400) 的块将具有 10x10 网格的补丁,每个补丁的形状为 (40, 40)。

  • image_token_id (int) – 图像特殊标记的标记 ID。

  • max_num_tiles (Optional[int]) – 图像中的最大块数,用于在推理期间填充掩码。默认为 None

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源