快捷方式

padded_collate_tiled_images_and_mask

torchtune.data.padded_collate_tiled_images_and_mask(batch: List[Dict[str, Any]], padding_idx: int = 0, ignore_idx: int = - 100, pad_direction: str = 'right', pad_max_images: Optional[int] = None) Dict[str, Tensor][source]

对一批文本序列、平铺图像张量、纵横比和交叉注意力掩码进行填充。这可用于训练和推理。

batch 预期为样本字典列表,包含以下内容:
  • “tokens”: 长度为 text_seq_len 的整数列表,在不同样本中变化

  • “labels”: 长度为 text_seq_len 的整数列表,在不同样本中变化

  • “encoder_input”: Dict[str, List[torch.Tensor]]
    • “images”: torch.Tensor 列表,每个张量的形状为 (n_tiles, c, h, w)

    • “aspect_ratio”: torch.Tensor 列表,每个张量的形状为 (2,),用于指示 h_ratio 和 w_ratio

  • “encoder_mask”: Tensor 列表,每个张量的形状为 (text_seq_len, image_seq_len)

形状表示
  • c = 通道维度

  • h = 高度维度

  • w = 宽度维度

注意

对于批次中的每个元素,len(images) == len(encoder_mask) == len(aspect_ratio)

此 collater 执行以下操作
  1. 将文本序列和编码器掩码填充到批次中最长的序列长度

  2. 将图像张量在平铺维度上用零填充到批次中最大的平铺数量

  3. 向样本添加空图像(全为零),直到批次中的最大图像数量

  4. 为所有添加的填充图像的纵横比填充 (1,1)

参数:
  • batch (List[Dict[str, Any]]) – 包含 tokens、labels、images、encoder_mask 和 aspect_ratio 的样本字典列表。

  • padding_idx (int) – 输入 token id 的填充索引。默认为 0。

  • ignore_idx (int) – 标签的填充索引。默认为 -100。

  • pad_direction (str) – 是否从左侧或右侧填充条目。如果 pad_direction="right",我们使用 torch.nn.utils.rnn.pad_sequence(),否则如果 pad_direction="left",我们使用 torchtune.data.left_pad_sequence()。对于训练,我们通常希望从右侧填充。对于推理,我们通常希望从左侧填充。默认为“right”。

  • pad_max_images (Optional[int]) – 要填充到的最大图像数量。如果为 None,则将填充到批次中最大的图像数量。默认为 None。

返回值:

合并后的 tokens、labels、images、encoder_mask 和 aspect_ratio 张量。
  • tokens: 形状为 (bsz, max_seq_len) 的张量

  • labels: 形状为 (bsz, max_seq_len) 的张量

  • images: 形状为 (bsz, max_num_images, max_num_tiles, c, h, w) 的张量

  • encoder_mask: 形状为 (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images) 的张量

  • aspect_ratio: 形状为 (bsz, max_num_images, 2) 的张量

返回类型:

Dict[str, Tensor]

引发:

ValueError – 如果 pad_direction 不是“left”或“right”之一。

示例

>>> image_id = 1
>>> tokens_per_tile = 5
>>> c, h, w = 1, 1, 1
>>> batch = [
...     {
...         "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
...         "encoder_input": {
...             # One image with two tiles, one image with three tiles
...             "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
...             "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
...         },
...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
...         "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
...     },
...     {
...         "tokens": [1, 4], "labels": [8, 9],
...         "encoder_input": {
...             # One image with four tiles
...             "images": [torch.ones(4, c, h, w)],
...             "aspect_ratio": [torch.tensor([2, 2])],
...         },
...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
...         "encoder_mask": [torch.ones(2, 5 * 4)],
...     },
... ]
>>> model_inputs = padded_collate_tiled_images_and_mask(batch=batch)
>>> print(model_inputs["tokens"])
tensor([[1, 2, 1, 3],
        [1, 4, 0, 0]])
>>> print(model_inputs["labels"])
tensor([[4, 5, 6, 7],
        [8, 9, -100, -100]])
>>> print(model_inputs["encoder_input"]["images"].shape)  # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape)  # (bsz, max_text_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
torch.Size([2, 4, 40])
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape)  # (bsz, max_num_images, 2)
torch.Size([2, 2, 2])
>>> print(model_inputs["encoder_input"]["images"][0, 0, ...])  # Image with two tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][0, 1, ...])  # Image with three tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 0, ...])  # Image with four tiles did not get padded
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 1, ...])  # Extra padding image was added to second sample
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获取问题的解答

查看资源