快捷方式

torch.hsplit

torch.hsplit(input, indices_or_sections) 张量列表

根据 indices_or_sections,将具有一个或多个维度的张量 input 水平拆分为多个张量。每个拆分都是 input 的视图。

如果 input 是一维的,则等效于调用 torch.tensor_split(input, indices_or_sections, dim=0)(拆分维度为零),如果 input 具有两个或多个维度,则等效于调用 torch.tensor_split(input, indices_or_sections, dim=1)(拆分维度为 1),但如果 indices_or_sections 是整数,则它必须均匀地划分拆分维度,否则会抛出运行时错误。

此函数基于 NumPy 的 numpy.hsplit()

参数
示例:
>>> t = torch.arange(16.0).reshape(4,4)
>>> t
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
>>> torch.hsplit(t, 2)
(tensor([[ 0.,  1.],
         [ 4.,  5.],
         [ 8.,  9.],
         [12., 13.]]),
 tensor([[ 2.,  3.],
         [ 6.,  7.],
         [10., 11.],
         [14., 15.]]))
>>> torch.hsplit(t, [3, 6])
(tensor([[ 0.,  1.,  2.],
         [ 4.,  5.,  6.],
         [ 8.,  9., 10.],
         [12., 13., 14.]]),
 tensor([[ 3.],
         [ 7.],
         [11.],
         [15.]]),
 tensor([], size=(4, 0)))

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源