VisionTransformer¶
- class torchtune.modules.VisionTransformer(patch_size: int, tile_size: int, num_layers: int, embed_dim: int, layer: Module, token_pos_embedding: Module, pre_tile_pos_embed: Optional[Module] = None, post_tile_pos_embed: Optional[Module] = None, cls_projection: Optional[Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3)[source]¶
ViT 架构 (https://arxiv.org/abs/2010.11929) 的实现,支持分块裁剪图像,输出隐藏层以及可选的 CLS 投影。
ViT 是一种 Transformer 架构,它接收图像并输出 N 个嵌入标记,这些标记表示该图像。每个图像通过卷积被划分为**块**。这些块被展平,随后由 Transformer 视为**标记**。
为了进一步提高 ViT 的性能并避免缩小图像,我们支持分块裁剪图像,即在预处理阶段将图像划分为**块**。例如,如果
tile_size=400
,则可以将 800x400 的图像裁剪成两个 400x400 的块,而不是将其缩小到 400x400。有关预处理的详细信息,请参阅torchtune.models.clip._transforms.CLIPImageTransform
。每个块进一步通过卷积运算分解成块。例如,如果你的
patch_size=40
,那么每个 (400, 400) 块将变成 10x10 块的网格,你的整个图像将有 num_tiles * n_tokens -> num_tiles * (10x10 块 + 1 个 CLS 标记) -> num_tiles * 101。在 Transformer 层之前,将 CLS 标记作为第一个标记添加到每个块中。在 Transformer 中,称为 CLS 的标记是一个特殊标记,它被添加到每个序列的开头。例如,此标记可用于表示整个输入,而不是使用池化操作。
为了帮助模型“看到”整个图像,我们使用位置嵌入。如果你的图像被分块裁剪,则需要使用分块位置嵌入。
token_pos_embedding(分块):
torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding
pre_tile_pos_embed:
torchtune.models.clip._position_embeddings.TilePositionalEmbedding
post_tile_pos_embed:
torchtune.models.clip._position_embeddings.TilePositionalEmbedding
否则,pre 和 post tile_pos_embed 应该为 None,你只需要一个简单的标记位置嵌入即可。
所有图像都将被视为图块的堆叠,即使您的图像没有进行图块裁剪。在这种情况下,您的图像将由单个图块组成。
总结
图像在预处理过程中被分解成图块。
在 ViT 中,图块将被分解成 patch。
patch 将被扁平化并进行变换。我们称之为 token,因为 Transformer 就是这样看待它们的。
图像:形状 (8x8)
| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |
图块:形状 (4,4,4) # (图块数量, 图块大小, 图块大小)
| 1 | 2 | 3 | 4 | | 5 | 6 | 7 | 8 | | 9 | 10 | 11 | 12 | | 13 | 14 | 15 | 16 | | 17 | 18 | 19 | 20 | | 21 | 22 | 23 | 24 | | 25 | 26 | 27 | 28 | | 29 | 30 | 31 | 32 | | 33 | 34 | 35 | 36 | | 37 | 38 | 39 | 40 | | 41 | 42 | 43 | 44 | | 45 | 46 | 47 | 48 | | 49 | 50 | 51 | 52 | | 53 | 54 | 55 | 56 | | 57 | 58 | 59 | 60 | | 61 | 62 | 63 | 64 |
patch:形状 (4,4,2,2) # (图块数量, 每个图块的 patch 数量, patch 大小, patch 大小)
| 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 |
token:形状 (4, 4, 4) # (图块数量, 每个图块的 patch 数量, 嵌入维度)
| 1 | 2 | 9 | 10 | | 3 | 4 | 11 | 12 | | 17 | 18 | 25 | 26 | | 19 | 20 | 27 | 28 | | ... continuation of data ... | ... continuation of data ... | 37 | 38 | 45 | 46 | | 39 | 40 | 47 | 48 | | 53 | 54 | 61 | 62 | | 55 | 56 | 63 | 64 |
对于位置嵌入
每个图块都相同,每个 token 都不同。
torchtune.models.clip._position_embeddings.TokenPositionalEmbedding
torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding
| 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 |
每个图块都不同,每个 token 都不同。
| 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 |
每个图块都不同,每个图块内的 token 都相同。
| 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 |
- 参数:
patch_size (int) – 每个 patch 的大小。用于将图块划分为 patch。例如,对于
patch_size=40
,形状为 (400, 400) 的图块将具有 10x10 的 patch 网格。tile_size (int) – 如果图像事先进行了图块裁剪,则为图像图块的大小。否则,为输入图像的大小。在这种情况下,该函数将您的图像视为单个图块,每个图块的形状为 (40, 40)。
num_layers (int) – Transformer 层的数量。
embed_dim (int) – 每个 patch 嵌入(token)的维度。
layer (nn.Module) – Transformer 层模块。
token_pos_embedding (nn.Module) – token 位置嵌入模块。
pre_tile_pos_embed (Optional[nn.Module]) – 图块前位置嵌入模块。如果您的图像事先没有进行图块裁剪,则应为 None。
post_tile_pos_embed (Optional[nn.Module]) – 图块后位置嵌入模块。如果您的图像事先没有进行图块裁剪,则应为 None。
cls_projection (Optional[nn.Module]) – CLS 投影模块。它应该接收形状为 (bsz * n_tiles, n_tokens, embed_dim) 的输入张量,并输出形状为 (bsz * n_tiles, cls_output_dim) 的张量。如果提供,则只会输出 CLS token 投影,而不是所有 token。
out_indices (Optional[List[int]]) – 要返回的隐藏层索引。如果提供,它将返回 Transformer 层在进入下一层之前的中间结果。例如,
out_indices=[0,3]
将返回在经过第一层和第四层之前的 token。in_channels (int) – 图像输入通道的数量。
- 引发:
ValueError – 如果 tile_size 不大于 0。
ValueError – 如果 patch_size 不大于 0。
ValueError – 如果 len(out_indices) 大于 num_layers。
- forward(images: Tensor, aspect_ratio: Optional[Tensor] = None) Tuple[Tensor, List[Tensor]] [source]¶
处理图像并返回 token 和隐藏状态。
每个样本中的多个图像:我们在输入中添加了一个维度 n_imgs。当单个样本包含多个图像时,这很有用,例如
样本 1:“<image> 这是什么动物?”
样本 2:“我比 <image> 更喜欢 <image>。”
在这种情况下,样本 1 有一个图像,样本 2 有两个图像。max_n_imgs = max(2,1) = 2。因此,您的输入应该具有形状 (bsz=2, n_imgs=2, 图块数量, 通道数量, 图块大小, 图块大小)。
请注意,要进行批处理,您需要将 n_imgs 填充到 max_n_imgs 和 max_num_tiles。
- 参数:
images (torch.Tensor) – 形状为 (bsz, n_imgs, 图块数量, 通道数量, 图块大小, 图块大小) 的 torch.Tensor。
aspect_ratio (Optional[torch.Tensor]) – 形状为 (bsz, n_imgs, 2) 的 torch.Tensor。如果所有图像都只有一个图块,即它们没有进行图块裁剪,则应为 None。用于计算图块的位置嵌入。
- 返回值:
- 一个元组:(x, hidden_states),
其中 x 是形状为 (bsz, n_imgs, 图块数量, token 数量, 嵌入维度) 的 torch.tensor,而 hidden_states 的形状是长度为 len(out_indices) 的列表,其中每个 torch.tensor 的形状为 (bsz, n_imgs, 图块数量, token 数量, 嵌入维度)。
- 返回类型:
Tuple[torch.Tensor, List[torch.Tensor]]
- 引发:
ValueError – 如果 aspect_ratio 为 None,但在批处理中 n_tiles > 1。
示例
>>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop >>> from torchtune.modules import VisionTransformer >>> >>> num_channels = 3 >>> image_size = (800,400) >>> tile_size = 400 >>> patch_size=40 >>> patch_grid_size = tile_size // patch_size >>> >>> # for details about preprocessing, please check >>> # torchtune.models.clip._transforms.CLIPImageTransform >>> >>> # create a random image >>> image = torch.rand(num_channels, image_size[0], image_size[1]) >>> >>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400) >>> tile_cropped_image = tile_crop(image, tile_size) >>> aspect_ratio = torch.tensor([2,1]) >>> >>> # make it a batch of 1 image >>> batch_image = tile_cropped_image.unsqueeze(0) >>> batch_aspect_ratio = aspect_ratio.unsqueeze(0) >>> >>> # make it have only 1 image per sample >>> batch_image = tile_cropped_image.unsqueeze(1) >>> batch_aspect_ratio = aspect_ratio.unsqueeze(1) >>> >>> # For a detailed example, please check >>> # torchtune.models.clip._position_embeddings.clip_vision_encoder >>> # model = VisionTransformer( ... # out_indices = [1,2,3,4,5], ... # patch_size=40, ... # patch_grid_size = patch_grid_size, ... # embed_dim = 32, ... # num_layers = 6, ... # in_channels = num_channels, ... # ...) >>> >>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio) >>> >>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) >>> print(x.shape) torch.Size([1, 1, 2, 101, 32]) >>> >>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) >>> print(len(hidden_states)) 5