VisionTransformer¶
- class torchtune.modules.VisionTransformer(patch_size: int, tile_size: int, num_layers: int, embed_dim: int, layer: Module, token_pos_embedding: Module, pre_tile_pos_embed: Optional[Module] = None, post_tile_pos_embed: Optional[Module] = None, cls_projection: Optional[Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3, append_cls_token: bool = False)[source]¶
ViT 架构的实现 (https://arxiv.org/abs/2010.11929),支持平铺裁剪图像、隐藏层输出和可选的 CLS 投影。
ViT 是一种 Transformer 架构,它接收图像并输出 N 个嵌入 tokens,这些 tokens 代表该图像。每个图像通过卷积被分成patches。这些 patches 被展平,随后被 Transformer 视为 tokens。
为了进一步提高 ViT 的性能并避免缩小图像,我们支持平铺裁剪图像,这些图像在预处理阶段被分成 tiles。例如,与其将 800x400 的图像缩小到 400x400,不如将其裁剪成两个 400x400 的 tiles,如果
tile_size=400
。有关预处理的详细信息,请参阅torchtune.models.clip._transforms.CLIPImageTransform
。每个 tiles 都会通过卷积操作进一步分解为 patches。例如,如果您的
patch_size=40
,那么每个 (400, 400) 的 tiles 将变成 10x10 patches 的网格,并且您的整个图像将具有 num_tiles * n_tokens -> num_tiles * (10x10 patches + 1 CLS token) -> num_tiles * 101。在 Transformer 层之前,CLS token 作为第一个 token 添加到每个 tiles。在 Transformer 中,名为 CLS 的 token 是添加到每个序列开头的特殊 token。此 token 可用于表示整个输入,而不是使用例如池化操作。
为了帮助模型“看到”整个图像,我们使用位置嵌入。如果您的图像被平铺裁剪,那么您需要使用 tiles 位置嵌入
token_pos_embedding (平铺):
torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding
pre_tile_pos_embed:
torchtune.models.clip._position_embeddings.TilePositionalEmbedding
post_tile_pos_embed:
torchtune.models.clip._position_embeddings.TilePositionalEmbedding
否则,pre 和 post tile_pos_embed 应该为 None,您只需要一个简单的 token 位置嵌入
token_pos_embedding (未平铺):
torchtune.models.clip._position_embeddings.TokenPositionalEmbedding
所有图像都将被视为 tiles 的堆叠,即使您的图像未被平铺裁剪。在这种情况下,您的图像将由单个 tiles 组成。
总结
图像在预处理期间被分解为 tiles。
在 ViT 中,tiles 将被分解为 patches。
patches 将被展平并转换。我们称它们为 tokens,因为 Transformer 就是这样看待它们的。
图像:形状 (8x8)
| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |
Tiles:形状 (4,4,4) # (num_tiles, tile_size, tile_size)
| 1 | 2 | 3 | 4 | | 5 | 6 | 7 | 8 | | 9 | 10 | 11 | 12 | | 13 | 14 | 15 | 16 | | 17 | 18 | 19 | 20 | | 21 | 22 | 23 | 24 | | 25 | 26 | 27 | 28 | | 29 | 30 | 31 | 32 | | 33 | 34 | 35 | 36 | | 37 | 38 | 39 | 40 | | 41 | 42 | 43 | 44 | | 45 | 46 | 47 | 48 | | 49 | 50 | 51 | 52 | | 53 | 54 | 55 | 56 | | 57 | 58 | 59 | 60 | | 61 | 62 | 63 | 64 |
Patches:形状 (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size)
| 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 |
token:形状 (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim)
| 1 | 2 | 9 | 10 | | 3 | 4 | 11 | 12 | | 17 | 18 | 25 | 26 | | 19 | 20 | 27 | 28 | | ... continuation of data ... | ... continuation of data ... | 37 | 38 | 45 | 46 | | 39 | 40 | 47 | 48 | | 53 | 54 | 61 | 62 | | 55 | 56 | 63 | 64 |
对于位置嵌入
每个 tiles 相同,每个 token 不同。
torchtune.models.clip._position_embeddings.TokenPositionalEmbedding
torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding
| 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 |
每个 tiles 不同,每个 token 也不同。
| 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 |
每个 tiles 不同,tiles 内的每个 token 相同。
| 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 |
- 参数:
patch_size (int) – 每个 patch 的大小。用于将 tiles 分成 patches。例如,对于
patch_size=40
,形状为 (400, 400) 的 tiles 将具有 10x10 的 patches 网格。tile_size (int) – 您的图像 tiles 的大小,如果图像已提前平铺裁剪。否则,为输入图像的大小。在这种情况下,该函数会将您的图像视为单个 tiles。每个形状为 (40, 40)。
num_layers (int) – Transformer 层的数量。
embed_dim (int) – 每个 patch 嵌入(token)的维度。
layer (nn.Module) – Transformer 层模块。
token_pos_embedding (nn.Module) – token 位置嵌入模块。
pre_tile_pos_embed (Optional[nn.Module]) – pre-tile 位置嵌入模块。如果您的图像未提前平铺裁剪,则应为 None。
post_tile_pos_embed (Optional[nn.Module]) – post-tile 位置嵌入模块。如果您的图像未提前平铺裁剪,则应为 None。
cls_projection (Optional[nn.Module]) – CLS 投影模块。它应接收形状为 (bsz * n_tiles, n_tokens, embed_dim) 的输入张量,并输出形状为 (bsz * n_tiles, cls_output_dim) 的张量。如果提供,则仅输出 CLS token 投影,而不是所有 tokens。
out_indices (Optional[List[int]]) – 要返回的隐藏层索引。如果提供,它将返回 Transformer 层的中间结果,然后再通过下一层。例如,
out_indices=[0,3]
将返回 tokens 在通过第一层和第四层之前的状态。in_channels (int) – 图像输入通道数。
append_cls_token (bool) – 如果为 True,则将 CLS token 添加到序列的末尾。默认值为 False,即在序列的开头添加 CLS token。
- Raises:
ValueError – 如果 tile_size 不大于 0。
ValueError – 如果 patch_size 不大于 0。
ValueError – 如果 len(out_indices) 大于 num_layers。
- forward(images: Tensor, aspect_ratio: Optional[Tensor] = None) Tuple[Tensor, List[Tensor]] [source]¶
处理图像并返回 tokens 和隐藏状态。
每个样本多张图像:我们向输入添加维度 n_imgs。当单个样本包含多张图像时,这非常有用,例如
样本 1:“<image> 这是什么动物?”
样本 2:“我喜欢 <image> 胜过 <image>”
在这种情况下,样本 1 有一张图像,样本 2 有两张图像。max_n_imgs = max(2,1) = 2。因此,您的输入应具有形状 (bsz=2, n_imgs=2, num_tiles, n_channels, tile_size, tile_size)。
请注意,为了进行批处理,您必须将 n_imgs 填充到 max_n_imgs 和 max_num_tiles。
- 参数:
images (torch.Tensor) – 形状为 (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size) 的 torch.Tensor。
aspect_ratio (Optional[torch.Tensor]) – 形状为 (bsz, n_imgs, 2) 的 torch.Tensor。如果所有图像都只有一个 tiles,即它们未被平铺裁剪,则应为 None。用于计算 tiles 的位置嵌入。
- Returns:
- 一个元组:(x, hidden_states),
其中 x 是形状为 (bsz, n_imgs, n_tiles, n_tokens, embed_dim) 的 torch.tensor,hidden_states 是长度为 len(out_indices) 的 torch.tensor 列表,形状为 (bsz, n_imgs, n_tiles, n_tokens, embed_dim)。
- Return type:
Tuple[torch.Tensor, List[torch.Tensor]]
- Raises:
ValueError – 如果 aspect_ratio 为 None,但批处理中 n_tiles > 1。
示例
>>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop >>> from torchtune.modules import VisionTransformer >>> >>> num_channels = 3 >>> image_size = (800,400) >>> tile_size = 400 >>> patch_size=40 >>> patch_grid_size = tile_size // patch_size >>> >>> # for details about preprocessing, please check >>> # torchtune.models.clip._transforms.CLIPImageTransform >>> >>> # create a random image >>> image = torch.rand(num_channels, image_size[0], image_size[1]) >>> >>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400) >>> tile_cropped_image = tile_crop(image, tile_size) >>> aspect_ratio = torch.tensor([2,1]) >>> >>> # make it a batch of 1 image >>> batch_image = tile_cropped_image.unsqueeze(0) >>> batch_aspect_ratio = aspect_ratio.unsqueeze(0) >>> >>> # make it have only 1 image per sample >>> batch_image = tile_cropped_image.unsqueeze(1) >>> batch_aspect_ratio = aspect_ratio.unsqueeze(1) >>> >>> # For a detailed example, please check >>> # torchtune.models.clip._position_embeddings.clip_vision_encoder >>> # model = VisionTransformer( ... # out_indices = [1,2,3,4,5], ... # patch_size=40, ... # patch_grid_size = patch_grid_size, ... # embed_dim = 32, ... # num_layers = 6, ... # in_channels = num_channels, ... # ...) >>> >>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio) >>> >>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) >>> print(x.shape) torch.Size([1, 1, 2, 101, 32]) >>> >>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) >>> print(len(hidden_states)) 5