注意
单击 此处 下载完整示例代码
在 PyTorch 中加载数据¶
PyTorch 提供了广泛的神经网络构建块,并具有简单、直观且稳定的 API。PyTorch 包含用于为您的模型准备和加载常用数据集的包。
简介¶
PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。它表示一个数据集上的 Python 可迭代对象。PyTorch 中的库提供内置的高质量数据集,供您在 torch.utils.data.Dataset 中使用。这些数据集目前可在以下库中使用
并将继续增加。使用 yesno
数据集(来自 torchaudio.datasets.YESNO
),我们将演示如何有效且高效地将数据从 PyTorch Dataset
加载到 PyTorch DataLoader
中。
设置¶
在开始之前,我们需要安装 torchaudio
才能访问数据集。
# pip install torchaudio
要在 Google Colab 中运行,请取消以下行的注释
# !pip install torchaudio
步骤¶
导入加载数据所需的所有库
访问数据集中的数据
加载数据
遍历数据
[可选] 可视化数据
1. 导入加载数据所需的所有库¶
对于此食谱,我们将使用 torch
和 torchaudio
。根据您使用的内置数据集,您还可以安装并导入 torchvision
或 torchtext
。
import torch
import torchaudio
2. 访问数据集中的数据¶
在 torchaudio
中的 yesno
数据集包含 60 个录音,由一个人用希伯来语说“是”或“否”,每个录音包含 8 个词语(点击此处了解更多)。
torchaudio.datasets.YESNO
创建了 yesno
的数据集。
torchaudio.datasets.YESNO(
root='./',
url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
folder_in_archive='waves_yesno',
download=True)
数据集中的每个项目都是一个元组,形式为:(波形,采样率,标签)。
您必须为 yesno
数据集设置一个 root
,训练和测试数据集将存在于此位置。其他参数是可选的,它们的值将显示为默认值。以下是一些关于其他参数的有用信息
# * ``download``: If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
#
# Let’s access our ``yesno`` data:
#
# A data point in ``yesno`` is a tuple (waveform, sample_rate, labels) where labels
# is a list of integers with 1 for yes and 0 for no.
yesno_data = torchaudio.datasets.YESNO('./', download=True)
# Pick data point number 3 to see an example of the the ``yesno_data``:
n = 3
waveform, sample_rate, labels = yesno_data[n]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))
在实际使用这些数据时,最佳实践是将数据分配到“训练”数据集和“测试”数据集。这将确保您拥有测试模型性能的样本外数据。
3. 加载数据¶
现在我们已经可以访问数据集了,我们需要通过 torch.utils.data.DataLoader
传递它。 DataLoader
将数据集和采样器组合起来,返回一个可迭代的数据集。
data_loader = torch.utils.data.DataLoader(yesno_data,
batch_size=1,
shuffle=True)
4. 迭代数据¶
现在,我们可以使用 data_loader
迭代我们的数据。这在我们开始训练模型时将是必要的!您会注意到,现在 data_loader
对象中的每个数据条目都被转换为一个张量,该张量包含表示波形、采样率和标签的张量。
for data in data_loader:
print("Data: ", data)
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
break
5. [可选] 可视化数据¶
您可以选择可视化您的数据,以更好地了解 DataLoader
的输出。
import matplotlib.pyplot as plt
print(data[0][0].numpy())
plt.figure()
plt.plot(waveform.t().numpy())
恭喜!您已成功在 PyTorch 中加载数据。