快捷方式

TiledTokenPositionalEmbedding

class torchtune.models.clip.TiledTokenPositionalEmbedding(max_num_tiles: int, embed_dim: int, tile_size: int, patch_size: int)[source]

用于平铺图像的令牌位置嵌入,每个平铺不同,每个令牌也不同。

此模块中有两种位置嵌入:

注意,平铺(tile)与块(patch)(即令牌 token)不同。有关详细信息,请查阅 torchtune.modules.vision_transformer.VisionTransformer 的文档。

参数:
  • max_num_tiles (int) – 图像可被划分成的最大平铺数。

  • embed_dim (int) – 每个令牌嵌入的维度。

  • tile_size (int) – 图像平铺的大小(如果图像已提前进行平铺裁剪)。否则,为输入图像的大小。在这种情况下,函数会将图像视为单个平铺。

  • patch_size (int) – 每个块的大小。用于将平铺划分为块。例如,对于 patch_size=40,一个形状为 (400, 400) 的平铺将拥有 10x10 的块网格,每个块的形状为 (40, 40)。

forward(x: Tensor, aspect_ratio: Tensor) Tensor[source]
参数:
  • x (torch.Tensor) – 形状为 (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim) 的 torch.Tensor。

  • aspect_ratio (torch.Tensor) – 形状为 (bsz * n_imgs, 2) 的 torch.Tensor,其中 aspect_ratio[k] 表示批次中第 k 个图像在平铺裁剪前的宽高比,例如 aspect_ratio[k] = (2,1)。

返回:

添加了位置嵌入的输入张量。

返回类型:

torch.Tensor

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源