快捷方式

VisionTransformer

class torchtune.modules.VisionTransformer(patch_size: int, tile_size: int, num_layers: int, embed_dim: int, layer: Module, token_pos_embedding: Module, pre_tile_pos_embed: Optional[Module] = None, post_tile_pos_embed: Optional[Module] = None, cls_projection: Optional[Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3)[source]

ViT 架构 (https://arxiv.org/abs/2010.11929) 的实现,支持分块裁剪图像,输出隐藏层以及可选的 CLS 投影。

ViT 是一种 Transformer 架构,它接收图像并输出 N 个嵌入标记,这些标记表示该图像。每个图像通过卷积被划分为**块**。这些块被展平,随后由 Transformer 视为**标记**。

为了进一步提高 ViT 的性能并避免缩小图像,我们支持分块裁剪图像,即在预处理阶段将图像划分为**块**。例如,如果 tile_size=400,则可以将 800x400 的图像裁剪成两个 400x400 的块,而不是将其缩小到 400x400。有关预处理的详细信息,请参阅 torchtune.models.clip._transforms.CLIPImageTransform

每个块进一步通过卷积运算分解成块。例如,如果你的 patch_size=40,那么每个 (400, 400) 块将变成 10x10 块的网格,你的整个图像将有 num_tiles * n_tokens -> num_tiles * (10x10 块 + 1 个 CLS 标记) -> num_tiles * 101。

在 Transformer 层之前,将 CLS 标记作为第一个标记添加到每个块中。在 Transformer 中,称为 CLS 的标记是一个特殊标记,它被添加到每个序列的开头。例如,此标记可用于表示整个输入,而不是使用池化操作。

为了帮助模型“看到”整个图像,我们使用位置嵌入。如果你的图像被分块裁剪,则需要使用分块位置嵌入。

否则,pre 和 post tile_pos_embed 应该为 None,你只需要一个简单的标记位置嵌入即可。

所有图像都将被视为图块的堆叠,即使您的图像没有进行图块裁剪。在这种情况下,您的图像将由单个图块组成。

总结

  1. 图像在预处理过程中被分解成图块。

  2. 在 ViT 中,图块将被分解成 patch。

  3. patch 将被扁平化并进行变换。我们称之为 token,因为 Transformer 就是这样看待它们的。

图像:形状 (8x8)

|  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |
|  9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
| 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 |
| 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 |
| 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
| 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 |
| 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |

图块:形状 (4,4,4) # (图块数量, 图块大小, 图块大小)

|  1 |  2 |  3 |  4 |    |  5 |  6 |  7 |  8 |
|  9 | 10 | 11 | 12 |    | 13 | 14 | 15 | 16 |
| 17 | 18 | 19 | 20 |    | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 |    | 29 | 30 | 31 | 32 |

| 33 | 34 | 35 | 36 |    | 37 | 38 | 39 | 40 |
| 41 | 42 | 43 | 44 |    | 45 | 46 | 47 | 48 |
| 49 | 50 | 51 | 52 |    | 53 | 54 | 55 | 56 |
| 57 | 58 | 59 | 60 |    | 61 | 62 | 63 | 64 |

patch:形状 (4,4,2,2) # (图块数量, 每个图块的 patch 数量, patch 大小, patch 大小)

|  1 |  2 |    |  3 |  4 |    |  5 |  6 |    |  7 |  8 |
|  9 | 10 |    | 11 | 12 |    | 13 | 14 |    | 15 | 16 |

| 17 | 18 |    | 19 | 20 |    | 21 | 22 |    | 23 | 24 |
| 25 | 26 |    | 27 | 28 |    | 29 | 30 |    | 31 | 32 |

| 33 | 34 |    | 35 | 36 |    | 37 | 38 |    | 39 | 40 |
| 41 | 42 |    | 43 | 44 |    | 45 | 46 |    | 47 | 48 |

| 49 | 50 |    | 51 | 52 |    | 53 | 54 |    | 55 | 56 |
| 57 | 58 |    | 59 | 60 |    | 61 | 62 |    | 63 | 64 |

token:形状 (4, 4, 4) # (图块数量, 每个图块的 patch 数量, 嵌入维度)

|  1 |  2 |  9 |  10 |    |  3 |  4 |  11 |  12 |    |  17 |  18 |  25 |  26 |    | 19 | 20 |  27 |  28 |
| ... continuation of data ...
| ... continuation of data ...
| 37 | 38 | 45 |  46 |    | 39 |  40 | 47 |  48 |    | 53 | 54 |  61 |  62 |    | 55 | 56 |  63 |  64 |

对于位置嵌入

每个图块都相同,每个 token 都不同。

|  1 |  2 |  3 |  4 |    |  1 |  2 |  3 |  4 |
|  9 | 10 | 11 | 12 |    |  9 | 10 | 11 | 12 |
| 17 | 18 | 19 | 20 |    | 17 | 18 | 19 | 20 |
| 25 | 26 | 27 | 28 |    | 25 | 26 | 27 | 28 |

|  1 |  2 |  3 |  4 |    |  1 |  2 |  3 |  4 |
|  9 | 10 | 11 | 12 |    |  9 | 10 | 11 | 12 |
| 17 | 18 | 19 | 20 |    | 17 | 18 | 19 | 20 |
| 25 | 26 | 27 | 28 |    | 25 | 26 | 27 | 28 |

每个图块都不同,每个 token 都不同。

|  1 |  2 |    |  3 |  4 |    |  5 |  6 |    |  7 |  8 |
|  9 | 10 |    | 11 | 12 |    | 13 | 14 |    | 15 | 16 |

| 17 | 18 |    | 19 | 20 |    | 21 | 22 |    | 23 | 24 |
| 25 | 26 |    | 27 | 28 |    | 29 | 30 |    | 31 | 32 |

| 33 | 34 |    | 35 | 36 |    | 37 | 38 |    | 39 | 40 |
| 41 | 42 |    | 43 | 44 |    | 45 | 46 |    | 47 | 48 |

| 49 | 50 |    | 51 | 52 |    | 53 | 54 |    | 55 | 56 |
| 57 | 58 |    | 59 | 60 |    | 61 | 62 |    | 63 | 64 |

每个图块都不同,每个图块内的 token 都相同。

|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |

|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
参数:
  • patch_size (int) – 每个 patch 的大小。用于将图块划分为 patch。例如,对于 patch_size=40,形状为 (400, 400) 的图块将具有 10x10 的 patch 网格。

  • tile_size (int) – 如果图像事先进行了图块裁剪,则为图像图块的大小。否则,为输入图像的大小。在这种情况下,该函数将您的图像视为单个图块,每个图块的形状为 (40, 40)。

  • num_layers (int) – Transformer 层的数量。

  • embed_dim (int) – 每个 patch 嵌入(token)的维度。

  • layer (nn.Module) – Transformer 层模块。

  • token_pos_embedding (nn.Module) – token 位置嵌入模块。

  • pre_tile_pos_embed (Optional[nn.Module]) – 图块前位置嵌入模块。如果您的图像事先没有进行图块裁剪,则应为 None。

  • post_tile_pos_embed (Optional[nn.Module]) – 图块后位置嵌入模块。如果您的图像事先没有进行图块裁剪,则应为 None。

  • cls_projection (Optional[nn.Module]) – CLS 投影模块。它应该接收形状为 (bsz * n_tiles, n_tokens, embed_dim) 的输入张量,并输出形状为 (bsz * n_tiles, cls_output_dim) 的张量。如果提供,则只会输出 CLS token 投影,而不是所有 token。

  • out_indices (Optional[List[int]]) – 要返回的隐藏层索引。如果提供,它将返回 Transformer 层在进入下一层之前的中间结果。例如,out_indices=[0,3] 将返回在经过第一层和第四层之前的 token。

  • in_channels (int) – 图像输入通道的数量。

引发:
  • ValueError – 如果 tile_size 不大于 0。

  • ValueError – 如果 patch_size 不大于 0。

  • ValueError – 如果 len(out_indices) 大于 num_layers

forward(images: Tensor, aspect_ratio: Optional[Tensor] = None) Tuple[Tensor, List[Tensor]][source]

处理图像并返回 token 和隐藏状态。

每个样本中的多个图像:我们在输入中添加了一个维度 n_imgs。当单个样本包含多个图像时,这很有用,例如

  • 样本 1:“<image> 这是什么动物?”

  • 样本 2:“我比 <image> 更喜欢 <image>。”

在这种情况下,样本 1 有一个图像,样本 2 有两个图像。max_n_imgs = max(2,1) = 2。因此,您的输入应该具有形状 (bsz=2, n_imgs=2, 图块数量, 通道数量, 图块大小, 图块大小)。

请注意,要进行批处理,您需要将 n_imgs 填充到 max_n_imgs 和 max_num_tiles。

参数:
  • images (torch.Tensor) – 形状为 (bsz, n_imgs, 图块数量, 通道数量, 图块大小, 图块大小) 的 torch.Tensor。

  • aspect_ratio (Optional[torch.Tensor]) – 形状为 (bsz, n_imgs, 2) 的 torch.Tensor。如果所有图像都只有一个图块,即它们没有进行图块裁剪,则应为 None。用于计算图块的位置嵌入。

返回值:

一个元组:(x, hidden_states),

其中 x 是形状为 (bsz, n_imgs, 图块数量, token 数量, 嵌入维度) 的 torch.tensor,而 hidden_states 的形状是长度为 len(out_indices) 的列表,其中每个 torch.tensor 的形状为 (bsz, n_imgs, 图块数量, token 数量, 嵌入维度)。

返回类型:

Tuple[torch.Tensor, List[torch.Tensor]]

引发:

ValueError – 如果 aspect_ratio 为 None,但在批处理中 n_tiles > 1。

示例

>>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop
>>> from torchtune.modules import VisionTransformer
>>>
>>> num_channels = 3
>>> image_size = (800,400)
>>> tile_size = 400
>>> patch_size=40
>>> patch_grid_size = tile_size // patch_size
>>>
>>> # for details about preprocessing, please check
>>> # torchtune.models.clip._transforms.CLIPImageTransform
>>>
>>> # create a random image
>>> image = torch.rand(num_channels, image_size[0], image_size[1])
>>>
>>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400)
>>> tile_cropped_image = tile_crop(image, tile_size)
>>> aspect_ratio = torch.tensor([2,1])
>>>
>>> # make it a batch of 1 image
>>> batch_image = tile_cropped_image.unsqueeze(0)
>>> batch_aspect_ratio = aspect_ratio.unsqueeze(0)
>>>
>>> # make it have only 1 image per sample
>>> batch_image = tile_cropped_image.unsqueeze(1)
>>> batch_aspect_ratio = aspect_ratio.unsqueeze(1)
>>>
>>> # For a detailed example, please check
>>> # torchtune.models.clip._position_embeddings.clip_vision_encoder
>>> # model = VisionTransformer(
... #           out_indices = [1,2,3,4,5],
... #           patch_size=40,
... #           patch_grid_size = patch_grid_size,
... #           embed_dim = 32,
... #           num_layers = 6,
... #           in_channels = num_channels,
... #           ...)
>>>
>>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio)
>>>
>>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim)
>>> print(x.shape)
torch.Size([1, 1, 2, 101, 32])
>>>
>>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim)
>>> print(len(hidden_states))
5

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源