TiledTokenPositionalEmbedding¶
- class torchtune.models.clip.TiledTokenPositionalEmbedding(max_num_tiles: int, embed_dim: int, tile_size: int, patch_size: int)[source]¶
用于平铺图像的令牌位置嵌入,每个平铺不同,每个令牌也不同。
此模块中有两种位置嵌入:
local_token_positional_embedding:对每个平铺相同,对每个令牌不同。等同于
torchtune.models.clip._position_embeddings.TokenPositionalEmbedding
,但带门控。global_token_positional_embedding:对每个平铺不同,对每个令牌也不同。
注意,平铺(tile)与块(patch)(即令牌 token)不同。有关详细信息,请查阅
torchtune.modules.vision_transformer.VisionTransformer
的文档。- 参数:
- forward(x: Tensor, aspect_ratio: Tensor) Tensor [source]¶
- 参数:
x (torch.Tensor) – 形状为 (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim) 的 torch.Tensor。
aspect_ratio (torch.Tensor) – 形状为 (bsz * n_imgs, 2) 的 torch.Tensor,其中 aspect_ratio[k] 表示批次中第 k 个图像在平铺裁剪前的宽高比,例如 aspect_ratio[k] = (2,1)。
- 返回:
添加了位置嵌入的输入张量。
- 返回类型: