快捷方式

VisionTransformer

class torchtune.modules.VisionTransformer(patch_size: int, tile_size: int, num_layers: int, embed_dim: int, layer: Module, token_pos_embedding: Module, pre_tile_pos_embed: Optional[Module] = None, post_tile_pos_embed: Optional[Module] = None, cls_projection: Optional[Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3, append_cls_token: bool = False)[source]

ViT 架构的实现 (https://arxiv.org/abs/2010.11929),支持平铺裁剪图像、隐藏层输出和可选的 CLS 投影。

ViT 是一种 Transformer 架构,它接收图像并输出 N 个嵌入 tokens,这些 tokens 代表该图像。每个图像通过卷积被分成patches。这些 patches 被展平,随后被 Transformer 视为 tokens

为了进一步提高 ViT 的性能并避免缩小图像,我们支持平铺裁剪图像,这些图像在预处理阶段被分成 tiles。例如,与其将 800x400 的图像缩小到 400x400,不如将其裁剪成两个 400x400 的 tiles,如果 tile_size=400。有关预处理的详细信息,请参阅 torchtune.models.clip._transforms.CLIPImageTransform

每个 tiles 都会通过卷积操作进一步分解为 patches。例如,如果您的 patch_size=40,那么每个 (400, 400) 的 tiles 将变成 10x10 patches 的网格,并且您的整个图像将具有 num_tiles * n_tokens -> num_tiles * (10x10 patches + 1 CLS token) -> num_tiles * 101。

在 Transformer 层之前,CLS token 作为第一个 token 添加到每个 tiles。在 Transformer 中,名为 CLS 的 token 是添加到每个序列开头的特殊 token。此 token 可用于表示整个输入,而不是使用例如池化操作。

为了帮助模型“看到”整个图像,我们使用位置嵌入。如果您的图像被平铺裁剪,那么您需要使用 tiles 位置嵌入

否则,pre 和 post tile_pos_embed 应该为 None,您只需要一个简单的 token 位置嵌入

所有图像都将被视为 tiles 的堆叠,即使您的图像未被平铺裁剪。在这种情况下,您的图像将由单个 tiles 组成。

总结

  1. 图像在预处理期间被分解为 tiles。

  2. 在 ViT 中,tiles 将被分解为 patches。

  3. patches 将被展平并转换。我们称它们为 tokens,因为 Transformer 就是这样看待它们的。

图像:形状 (8x8)

|  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |
|  9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
| 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 |
| 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 |
| 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
| 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 |
| 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |

Tiles:形状 (4,4,4) # (num_tiles, tile_size, tile_size)

|  1 |  2 |  3 |  4 |    |  5 |  6 |  7 |  8 |
|  9 | 10 | 11 | 12 |    | 13 | 14 | 15 | 16 |
| 17 | 18 | 19 | 20 |    | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 |    | 29 | 30 | 31 | 32 |

| 33 | 34 | 35 | 36 |    | 37 | 38 | 39 | 40 |
| 41 | 42 | 43 | 44 |    | 45 | 46 | 47 | 48 |
| 49 | 50 | 51 | 52 |    | 53 | 54 | 55 | 56 |
| 57 | 58 | 59 | 60 |    | 61 | 62 | 63 | 64 |

Patches:形状 (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size)

|  1 |  2 |    |  3 |  4 |    |  5 |  6 |    |  7 |  8 |
|  9 | 10 |    | 11 | 12 |    | 13 | 14 |    | 15 | 16 |

| 17 | 18 |    | 19 | 20 |    | 21 | 22 |    | 23 | 24 |
| 25 | 26 |    | 27 | 28 |    | 29 | 30 |    | 31 | 32 |

| 33 | 34 |    | 35 | 36 |    | 37 | 38 |    | 39 | 40 |
| 41 | 42 |    | 43 | 44 |    | 45 | 46 |    | 47 | 48 |

| 49 | 50 |    | 51 | 52 |    | 53 | 54 |    | 55 | 56 |
| 57 | 58 |    | 59 | 60 |    | 61 | 62 |    | 63 | 64 |

token:形状 (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim)

|  1 |  2 |  9 |  10 |    |  3 |  4 |  11 |  12 |    |  17 |  18 |  25 |  26 |    | 19 | 20 |  27 |  28 |
| ... continuation of data ...
| ... continuation of data ...
| 37 | 38 | 45 |  46 |    | 39 |  40 | 47 |  48 |    | 53 | 54 |  61 |  62 |    | 55 | 56 |  63 |  64 |

对于位置嵌入

每个 tiles 相同,每个 token 不同。

|  1 |  2 |  3 |  4 |    |  1 |  2 |  3 |  4 |
|  9 | 10 | 11 | 12 |    |  9 | 10 | 11 | 12 |
| 17 | 18 | 19 | 20 |    | 17 | 18 | 19 | 20 |
| 25 | 26 | 27 | 28 |    | 25 | 26 | 27 | 28 |

|  1 |  2 |  3 |  4 |    |  1 |  2 |  3 |  4 |
|  9 | 10 | 11 | 12 |    |  9 | 10 | 11 | 12 |
| 17 | 18 | 19 | 20 |    | 17 | 18 | 19 | 20 |
| 25 | 26 | 27 | 28 |    | 25 | 26 | 27 | 28 |

每个 tiles 不同,每个 token 也不同。

|  1 |  2 |    |  3 |  4 |    |  5 |  6 |    |  7 |  8 |
|  9 | 10 |    | 11 | 12 |    | 13 | 14 |    | 15 | 16 |

| 17 | 18 |    | 19 | 20 |    | 21 | 22 |    | 23 | 24 |
| 25 | 26 |    | 27 | 28 |    | 29 | 30 |    | 31 | 32 |

| 33 | 34 |    | 35 | 36 |    | 37 | 38 |    | 39 | 40 |
| 41 | 42 |    | 43 | 44 |    | 45 | 46 |    | 47 | 48 |

| 49 | 50 |    | 51 | 52 |    | 53 | 54 |    | 55 | 56 |
| 57 | 58 |    | 59 | 60 |    | 61 | 62 |    | 63 | 64 |

每个 tiles 不同,tiles 内的每个 token 相同。

|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  3 |

|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
参数:
  • patch_size (int) – 每个 patch 的大小。用于将 tiles 分成 patches。例如,对于 patch_size=40,形状为 (400, 400) 的 tiles 将具有 10x10 的 patches 网格。

  • tile_size (int) – 您的图像 tiles 的大小,如果图像已提前平铺裁剪。否则,为输入图像的大小。在这种情况下,该函数会将您的图像视为单个 tiles。每个形状为 (40, 40)。

  • num_layers (int) – Transformer 层的数量。

  • embed_dim (int) – 每个 patch 嵌入(token)的维度。

  • layer (nn.Module) – Transformer 层模块。

  • token_pos_embedding (nn.Module) – token 位置嵌入模块。

  • pre_tile_pos_embed (Optional[nn.Module]) – pre-tile 位置嵌入模块。如果您的图像未提前平铺裁剪,则应为 None。

  • post_tile_pos_embed (Optional[nn.Module]) – post-tile 位置嵌入模块。如果您的图像未提前平铺裁剪,则应为 None。

  • cls_projection (Optional[nn.Module]) – CLS 投影模块。它应接收形状为 (bsz * n_tiles, n_tokens, embed_dim) 的输入张量,并输出形状为 (bsz * n_tiles, cls_output_dim) 的张量。如果提供,则仅输出 CLS token 投影,而不是所有 tokens。

  • out_indices (Optional[List[int]]) – 要返回的隐藏层索引。如果提供,它将返回 Transformer 层的中间结果,然后再通过下一层。例如,out_indices=[0,3] 将返回 tokens 在通过第一层和第四层之前的状态。

  • in_channels (int) – 图像输入通道数。

  • append_cls_token (bool) – 如果为 True,则将 CLS token 添加到序列的末尾。默认值为 False,即在序列的开头添加 CLS token。

Raises:
  • ValueError – 如果 tile_size 不大于 0。

  • ValueError – 如果 patch_size 不大于 0。

  • ValueError – 如果 len(out_indices) 大于 num_layers

forward(images: Tensor, aspect_ratio: Optional[Tensor] = None) Tuple[Tensor, List[Tensor]][source]

处理图像并返回 tokens 和隐藏状态。

每个样本多张图像:我们向输入添加维度 n_imgs。当单个样本包含多张图像时,这非常有用,例如

  • 样本 1:“<image> 这是什么动物?”

  • 样本 2:“我喜欢 <image> 胜过 <image>”

在这种情况下,样本 1 有一张图像,样本 2 有两张图像。max_n_imgs = max(2,1) = 2。因此,您的输入应具有形状 (bsz=2, n_imgs=2, num_tiles, n_channels, tile_size, tile_size)。

请注意,为了进行批处理,您必须将 n_imgs 填充到 max_n_imgs 和 max_num_tiles。

参数:
  • images (torch.Tensor) – 形状为 (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size) 的 torch.Tensor。

  • aspect_ratio (Optional[torch.Tensor]) – 形状为 (bsz, n_imgs, 2) 的 torch.Tensor。如果所有图像都只有一个 tiles,即它们未被平铺裁剪,则应为 None。用于计算 tiles 的位置嵌入。

Returns:

一个元组:(x, hidden_states),

其中 x 是形状为 (bsz, n_imgs, n_tiles, n_tokens, embed_dim) 的 torch.tensor,hidden_states 是长度为 len(out_indices) 的 torch.tensor 列表,形状为 (bsz, n_imgs, n_tiles, n_tokens, embed_dim)。

Return type:

Tuple[torch.Tensor, List[torch.Tensor]]

Raises:

ValueError – 如果 aspect_ratio 为 None,但批处理中 n_tiles > 1。

示例

>>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop
>>> from torchtune.modules import VisionTransformer
>>>
>>> num_channels = 3
>>> image_size = (800,400)
>>> tile_size = 400
>>> patch_size=40
>>> patch_grid_size = tile_size // patch_size
>>>
>>> # for details about preprocessing, please check
>>> # torchtune.models.clip._transforms.CLIPImageTransform
>>>
>>> # create a random image
>>> image = torch.rand(num_channels, image_size[0], image_size[1])
>>>
>>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400)
>>> tile_cropped_image = tile_crop(image, tile_size)
>>> aspect_ratio = torch.tensor([2,1])
>>>
>>> # make it a batch of 1 image
>>> batch_image = tile_cropped_image.unsqueeze(0)
>>> batch_aspect_ratio = aspect_ratio.unsqueeze(0)
>>>
>>> # make it have only 1 image per sample
>>> batch_image = tile_cropped_image.unsqueeze(1)
>>> batch_aspect_ratio = aspect_ratio.unsqueeze(1)
>>>
>>> # For a detailed example, please check
>>> # torchtune.models.clip._position_embeddings.clip_vision_encoder
>>> # model = VisionTransformer(
... #           out_indices = [1,2,3,4,5],
... #           patch_size=40,
... #           patch_grid_size = patch_grid_size,
... #           embed_dim = 32,
... #           num_layers = 6,
... #           in_channels = num_channels,
... #           ...)
>>>
>>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio)
>>>
>>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim)
>>> print(x.shape)
torch.Size([1, 1, 2, 101, 32])
>>>
>>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim)
>>> print(len(hidden_states))
5

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源