VisionTransformer¶
- class torchtune.modules.VisionTransformer(patch_size: int, tile_size: int, num_layers: int, embed_dim: int, layer: Module, token_pos_embedding: Module, pre_tile_pos_embed: Optional[Module] = None, post_tile_pos_embed: Optional[Module] = None, cls_projection: Optional[Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3, append_cls_token: bool = False)[source]¶
ViT 架构的实现 (https://arxiv.org/abs/2010.11929),支持 tile 裁剪的图像、隐藏层输出和可选的 CLS 投影。
ViT 是一种 Transformer 架构,它接收图像并输出代表该图像的 N 个嵌入 tokens。每个图像通过卷积被分割成 **patches**。这些 patches 被展平,随后被 Transformer 视为 **tokens**。
为了进一步增强 ViT 的性能并避免图像缩小,我们支持 tile 裁剪的图像,这些图像在预处理阶段被分割成 **tiles**。例如,如果 `tile_size=400`,我们可以将其裁剪成两个 400x400 的 tile,而不是将 800x400 的图像缩小以适应 400x400。有关预处理的详细信息,请参阅
torchtune.models.clip._transforms.CLIPImageTransform
。这些 tiles 中的每一个都通过卷积操作进一步分解成 patches。例如,如果您的
patch_size=40
,那么每个 (400, 400) 的 tile 将成为一个 10x10 的 patches 网格,您的整个图像将具有 num_tiles * n_tokens -> num_tiles * (10x10 patches + 1 CLS token) -> num_tiles * 101。在 Transformer 层之前,CLS token 作为第一个 token 被添加到每个 tile 中。在 Transformer 中,称为 CLS 的 token 是一种特殊的 token,它被添加到每个序列的开头。例如,可以使用此 token 代表整个输入,而不是使用池化操作。
为了帮助模型“看到”整个图像,我们使用位置嵌入。如果您的图像是 tile 裁剪的,那么您需要使用 tile 位置嵌入
token_pos_embedding (tile 裁剪的):
torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding
pre_tile_pos_embed:
torchtune.models.clip._position_embeddings.TilePositionalEmbedding
post_tile_pos_embed:
torchtune.models.clip._position_embeddings.TilePositionalEmbedding
否则,pre 和 post tile_pos_embed 应为 None,您只需要一个简单的 token 位置嵌入
token_pos_embedding (未 tile 裁剪的):
torchtune.models.clip._position_embeddings.TokenPositionalEmbedding
所有图像将被视为 tile 的堆栈,即使您的图像未进行 tile 裁剪。在这种情况下,您的图像将由单个 tile 组成。
总结
图像在预处理过程中被分解成 tiles。
在 ViT 中,tiles 将被分解成 patches。
patches 将被展平并转换。我们将它们称为 tokens,因为 Transformer 就是这样看待它们的。
图像: 形状 (8x8)
| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |
Tiles: 形状 (4,4,4) # (num_tiles, tile_size, tile_size)
| 1 | 2 | 3 | 4 | | 5 | 6 | 7 | 8 | | 9 | 10 | 11 | 12 | | 13 | 14 | 15 | 16 | | 17 | 18 | 19 | 20 | | 21 | 22 | 23 | 24 | | 25 | 26 | 27 | 28 | | 29 | 30 | 31 | 32 | | 33 | 34 | 35 | 36 | | 37 | 38 | 39 | 40 | | 41 | 42 | 43 | 44 | | 45 | 46 | 47 | 48 | | 49 | 50 | 51 | 52 | | 53 | 54 | 55 | 56 | | 57 | 58 | 59 | 60 | | 61 | 62 | 63 | 64 |
Patches: 形状 (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size)
| 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 |
token: 形状 (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim)
| 1 | 2 | 9 | 10 | | 3 | 4 | 11 | 12 | | 17 | 18 | 25 | 26 | | 19 | 20 | 27 | 28 | | ... continuation of data ... | ... continuation of data ... | 37 | 38 | 45 | 46 | | 39 | 40 | 47 | 48 | | 53 | 54 | 61 | 62 | | 55 | 56 | 63 | 64 |
关于位置嵌入
每个 tile 相同,每个 token 不同。
torchtune.models.clip._position_embeddings.TokenPositionalEmbedding
torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding
| 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 |
每个 tile 不同,每个 token 不同。
| 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 |
每个 tile 不同,同一 tile 内的每个 token 相同。
| 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 |
- 参数:
patch_size (int) – 每个 patch 的大小。用于将 tiles 分割成 patches。例如,对于
patch_size=40
,形状为 (400, 400) 的 tile 将具有 10x10 的 patches 网格。tile_size (int) – 图像 tile 的大小,如果图像已预先进行 tile 裁剪。否则,为输入图像的大小。在这种情况下,函数将把图像视为单个 tile。每个 tile 的形状为 (40, 40)。
num_layers (int) – Transformer 层的数量。
embed_dim (int) – 每个 patch 嵌入(token)的维度。
layer (nn.Module) – Transformer 层模块。
token_pos_embedding (nn.Module) – token 位置嵌入模块。
pre_tile_pos_embed (Optional[nn.Module]) – 预 tile 位置嵌入模块。如果图像未预先进行 tile 裁剪,则应为 None。
post_tile_pos_embed (Optional[nn.Module]) – 后 tile 位置嵌入模块。如果图像未预先进行 tile 裁剪,则应为 None。
cls_projection (Optional[nn.Module]) – CLS 投影模块。它应接受形状为 (bsz * n_tiles, n_tokens, embed_dim) 的输入 tensor,并输出形状为 (bsz * n_tiles, cls_output_dim) 的 tensor。如果提供,将只输出 CLS token 投影,而不是所有 tokens。
out_indices (Optional[List[int]]) – 要返回的隐藏层的索引。如果提供,将返回 transformer 层在进入下一层之前的中间结果。例如,
out_indices=[0,3]
将返回进入第一层和第四层之前的 tokens。in_channels (int) – 图像输入通道的数量。
append_cls_token (bool) – 如果为 True,将 CLS token 添加到序列末尾。默认为 False,将 CLS token 添加到序列开头。
- 引发:
ValueError – 如果
tile_size
不大于 0,**或**如果patch_size
不大于 0,**或**如果len(out_indices)
大于num_layers
。
- forward(images: Tensor, aspect_ratio: Optional[Tensor] = None) Tuple[Tensor, List[Tensor]] [source]¶
处理图像并返回 tokens 和隐藏状态。
每个样本多张图像:我们在输入中添加一个维度 n_imgs。当单个样本包含多张图像时,这非常有用,例如
样本 1: “<image> 这是什么动物?”
样本 2: “我喜欢 <image> 多于 <image>”
在这种情况下,样本 1 有一张图像,样本 2 有两张图像。max_n_imgs = max(2,1) = 2。因此您的输入应具有形状 (bsz=2, n_imgs=2, num_tiles, n_channels, tile_size, tile_size)。
请注意,为了进行批量处理,您需要将 n_imgs 填充到 max_n_imgs 和 max_num_tiles。
- 参数:
images (torch.Tensor) – 形状为 (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size) 的 torch.Tensor。
aspect_ratio (Optional[torch.Tensor]) – 形状为 (bsz, n_imgs, 2) 的 torch.Tensor。如果所有图像都只有一个 tile,即它们未进行 tile 裁剪,则应为 None。用于计算 tile 的位置嵌入。
- 返回:
- 一个元组:(x, hidden_states),
其中 x 是形状为 (bsz, n_imgs, n_tiles, n_tokens, embed_dim) 的 torch.tensor,hidden_states 是一个包含 len(out_indices) 个形状为 (bsz, n_imgs, n_tiles, n_tokens, embed_dim) 的 torch.tensor 的列表。
- 返回类型:
Tuple[torch.Tensor, List[torch.Tensor]]
- 引发:
ValueError – 如果 aspect_ratio 为 None,但批次中 n_tiles > 1。
示例
>>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop >>> from torchtune.modules import VisionTransformer >>> >>> num_channels = 3 >>> image_size = (800,400) >>> tile_size = 400 >>> patch_size=40 >>> patch_grid_size = tile_size // patch_size >>> >>> # for details about preprocessing, please check >>> # torchtune.models.clip._transforms.CLIPImageTransform >>> >>> # create a random image >>> image = torch.rand(num_channels, image_size[0], image_size[1]) >>> >>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400) >>> tile_cropped_image = tile_crop(image, tile_size) >>> aspect_ratio = torch.tensor([2,1]) >>> >>> # make it a batch of 1 image >>> batch_image = tile_cropped_image.unsqueeze(0) >>> batch_aspect_ratio = aspect_ratio.unsqueeze(0) >>> >>> # make it have only 1 image per sample >>> batch_image = tile_cropped_image.unsqueeze(1) >>> batch_aspect_ratio = aspect_ratio.unsqueeze(1) >>> >>> # For a detailed example, please check >>> # torchtune.models.clip._position_embeddings.clip_vision_encoder >>> # model = VisionTransformer( ... # out_indices = [1,2,3,4,5], ... # patch_size=40, ... # patch_grid_size = patch_grid_size, ... # embed_dim = 32, ... # num_layers = 6, ... # in_channels = num_channels, ... # ...) >>> >>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio) >>> >>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) >>> print(x.shape) torch.Size([1, 1, 2, 101, 32]) >>> >>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) >>> print(len(hidden_states)) 5