快捷方式

在 PyTorch 中加载数据

PyTorch 提供了广泛的神经网络构建块,并具有简单、直观且稳定的 API。PyTorch 包含用于为您的模型准备和加载常用数据集的包。

简介

PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。它表示一个数据集上的 Python 可迭代对象。PyTorch 中的库提供内置的高质量数据集,供您在 torch.utils.data.Dataset 中使用。这些数据集目前可在以下库中使用

并将继续增加。使用 yesno 数据集(来自 torchaudio.datasets.YESNO),我们将演示如何有效且高效地将数据从 PyTorch Dataset 加载到 PyTorch DataLoader 中。

设置

在开始之前,我们需要安装 torchaudio 才能访问数据集。

# pip install torchaudio

要在 Google Colab 中运行,请取消以下行的注释

# !pip install torchaudio

步骤

  1. 导入加载数据所需的所有库

  2. 访问数据集中的数据

  3. 加载数据

  4. 遍历数据

  5. [可选] 可视化数据

1. 导入加载数据所需的所有库

对于此食谱,我们将使用 torchtorchaudio。根据您使用的内置数据集,您还可以安装并导入 torchvisiontorchtext

import torch
import torchaudio

2. 访问数据集中的数据

torchaudio 中的 yesno 数据集包含 60 个录音,由一个人用希伯来语说“是”或“否”,每个录音包含 8 个词语(点击此处了解更多)。

torchaudio.datasets.YESNO 创建了 yesno 的数据集。

torchaudio.datasets.YESNO(
     root='./',
     url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
     folder_in_archive='waves_yesno',
     download=True)

数据集中的每个项目都是一个元组,形式为:(波形,采样率,标签)。

您必须为 yesno 数据集设置一个 root,训练和测试数据集将存在于此位置。其他参数是可选的,它们的值将显示为默认值。以下是一些关于其他参数的有用信息

# * ``download``: If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
#
# Let’s access our ``yesno`` data:
#

# A data point in ``yesno`` is a tuple (waveform, sample_rate, labels) where labels
# is a list of integers with 1 for yes and 0 for no.
yesno_data = torchaudio.datasets.YESNO('./', download=True)

# Pick data point number 3 to see an example of the the ``yesno_data``:
n = 3
waveform, sample_rate, labels = yesno_data[n]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))

在实际使用这些数据时,最佳实践是将数据分配到“训练”数据集和“测试”数据集。这将确保您拥有测试模型性能的样本外数据。

3. 加载数据

现在我们已经可以访问数据集了,我们需要通过 torch.utils.data.DataLoader 传递它。 DataLoader 将数据集和采样器组合起来,返回一个可迭代的数据集。

data_loader = torch.utils.data.DataLoader(yesno_data,
                                          batch_size=1,
                                          shuffle=True)

4. 迭代数据

现在,我们可以使用 data_loader 迭代我们的数据。这在我们开始训练模型时将是必要的!您会注意到,现在 data_loader 对象中的每个数据条目都被转换为一个张量,该张量包含表示波形、采样率和标签的张量。

for data in data_loader:
  print("Data: ", data)
  print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
  break

5. [可选] 可视化数据

您可以选择可视化您的数据,以更好地了解 DataLoader 的输出。

import matplotlib.pyplot as plt

print(data[0][0].numpy())

plt.figure()
plt.plot(waveform.t().numpy())

恭喜!您已成功在 PyTorch 中加载数据。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源