快捷方式

torch.frombuffer

torch.frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) Tensor

从实现 Python 缓冲区协议的对象创建一维 Tensor

跳过缓冲区中的前 offset 字节,并将剩余的原始字节解释为类型为 dtype、具有 count 个元素的一维张量。

请注意,以下两种情况之一必须为真

1. count 是一个正的非零数,并且缓冲区中的总字节数大于 offset 加上 count 乘以 dtype 的大小(以字节为单位)。

2. count 为负数,并且缓冲区的长度(字节数)减去 offsetdtype 的大小(以字节为单位)的倍数。

返回的张量和缓冲区共享相同的内存。对张量的修改将反映在缓冲区中,反之亦然。返回的张量不可调整大小。

注意

此函数会增加拥有共享内存的对象的引用计数。因此,此类内存在返回的张量超出范围之前不会被释放。

警告

当传递一个实现缓冲区协议且数据不在 CPU 上的对象时,此函数的行为是未定义的。这样做很可能会导致段错误。

警告

此函数不会尝试推断 dtype(因此,它不是可选的)。传递与源不同的 dtype 可能会导致意外行为。

参数

buffer (object) – 公开缓冲区接口的 Python 对象。

关键字参数
  • dtype (torch.dtype) – 返回张量的所需数据类型。

  • count (int, 可选) – 要读取的所需元素数量。如果为负数,将读取所有元素(直到缓冲区的末尾)。默认值:-1。

  • offset (int, 可选) – 在缓冲区开头跳过的字节数。默认值:0。

  • requires_grad (bool, 可选) – autograd 是否应该记录对返回张量的操作。默认值:False

示例

>>> import array
>>> a = array.array('i', [1, 2, 3])
>>> t = torch.frombuffer(a, dtype=torch.int32)
>>> t
tensor([ 1,  2,  3])
>>> t[0] = -1
>>> a
array([-1,  2,  3])

>>> # Interprets the signed char bytes as 32-bit integers.
>>> # Each 4 signed char elements will be interpreted as
>>> # 1 signed 32-bit integer.
>>> import array
>>> a = array.array('b', [-1, 0, 0, 0])
>>> torch.frombuffer(a, dtype=torch.int32)
tensor([255], dtype=torch.int32)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适用于初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源