快捷方式

torch.split

torch.split(tensor, split_size_or_sections, dim=0)[源代码]

将张量分割成块。每个块都是原始张量的视图。

如果split_size_or_sections是整数类型,则tensor将被分割成大小相等的块(如果可能)。如果张量沿给定维度dim的大小不能被split_size整除,则最后一个块将更小。

如果split_size_or_sections是列表,则tensor将被分割成len(split_size_or_sections)个块,其在dim维度上的大小根据split_size_or_sections确定。

参数
  • tensor (张量) – 要分割的张量。

  • split_size_or_sections (int) 或 (list(int)) – 单个块的大小或每个块的大小列表

  • dim (int) – 要沿其分割张量的维度。

返回类型

元组[张量, …]

示例

>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源