• 教程 >
  • 使用 torchaudio 进行语音命令分类
快捷方式

使用 torchaudio 进行语音命令分类

本教程将向您展示如何正确格式化音频数据集,然后在该数据集上训练/测试音频分类器网络。

Colab 提供了 GPU 选项。在菜单选项卡中,选择“运行时”,然后选择“更改运行时类型”。在随后弹出的窗口中,您可以选择 GPU。更改后,您的运行时应自动重启(这意味着已执行单元格中的信息将消失)。

首先,让我们导入常用的 torch 包,例如 torchaudio,它可以通过按照网站上的说明进行安装。

# Uncomment the line corresponding to your "runtime type" to run in Google Colab

# CPU:
# !pip install pydub torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# GPU:
# !pip install pydub torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

让我们检查 CUDA GPU 是否可用并选择我们的设备。在 GPU 上运行网络将大大减少训练/测试运行时间。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda

导入数据集

我们使用 torchaudio 下载和表示数据集。这里我们使用 SpeechCommands,它是一个由不同人说的 35 个命令的数据集。数据集 SPEECHCOMMANDS 是该数据集的 torch.utils.data.Dataset 版本。在这个数据集中,所有音频文件大约 1 秒长(因此大约 16000 个时间帧长)。

实际的加载和格式化步骤发生在访问数据点时,torchaudio 会负责将音频文件转换为张量。如果想直接加载音频文件,可以使用 torchaudio.load()。它返回一个元组,包含新创建的张量以及音频文件的采样频率(SpeechCommands 为 16kHz)。

回到数据集,这里我们创建一个子类,将其拆分为标准的训练、验证、测试子集。

from torchaudio.datasets import SPEECHCOMMANDS
import os


class SubsetSC(SPEECHCOMMANDS):
    def __init__(self, subset: str = None):
        super().__init__("./", download=True)

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [os.path.normpath(os.path.join(self._path, line.strip())) for line in fileobj]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]


# Create training and testing split of the data. We do not use validation in this tutorial.
train_set = SubsetSC("training")
test_set = SubsetSC("testing")

waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]
  0%|          | 0.00/2.26G [00:00<?, ?B/s]
  1%|          | 13.1M/2.26G [00:00<00:17, 137MB/s]
  1%|1         | 30.4M/2.26G [00:00<00:14, 163MB/s]
  2%|2         | 47.5M/2.26G [00:00<00:13, 171MB/s]
  3%|2         | 64.8M/2.26G [00:00<00:13, 175MB/s]
  4%|3         | 82.0M/2.26G [00:00<00:13, 177MB/s]
  4%|4         | 98.9M/2.26G [00:00<00:13, 177MB/s]
  5%|5         | 116M/2.26G [00:00<00:12, 178MB/s]
  6%|5         | 133M/2.26G [00:00<00:12, 177MB/s]
  6%|6         | 150M/2.26G [00:00<00:12, 175MB/s]
  7%|7         | 167M/2.26G [00:01<00:12, 175MB/s]
  8%|7         | 184M/2.26G [00:01<00:12, 176MB/s]
  9%|8         | 201M/2.26G [00:01<00:12, 177MB/s]
  9%|9         | 218M/2.26G [00:01<00:12, 176MB/s]
 10%|#         | 236M/2.26G [00:01<00:12, 179MB/s]
 11%|#         | 254M/2.26G [00:01<00:11, 184MB/s]
 12%|#1        | 273M/2.26G [00:01<00:11, 188MB/s]
 13%|#2        | 291M/2.26G [00:01<00:11, 189MB/s]
 13%|#3        | 310M/2.26G [00:01<00:11, 191MB/s]
 14%|#4        | 329M/2.26G [00:01<00:10, 194MB/s]
 15%|#5        | 348M/2.26G [00:02<00:10, 195MB/s]
 16%|#5        | 367M/2.26G [00:02<00:10, 198MB/s]
 17%|#6        | 386M/2.26G [00:02<00:10, 199MB/s]
 18%|#7        | 406M/2.26G [00:02<00:10, 199MB/s]
 18%|#8        | 425M/2.26G [00:02<00:09, 199MB/s]
 19%|#9        | 444M/2.26G [00:02<00:09, 199MB/s]
 20%|#9        | 463M/2.26G [00:02<00:09, 199MB/s]
 21%|##        | 482M/2.26G [00:02<00:09, 197MB/s]
 22%|##1       | 501M/2.26G [00:02<00:09, 196MB/s]
 22%|##2       | 519M/2.26G [00:02<00:09, 195MB/s]
 23%|##3       | 538M/2.26G [00:03<00:09, 196MB/s]
 24%|##4       | 557M/2.26G [00:03<00:09, 196MB/s]
 25%|##4       | 576M/2.26G [00:03<00:09, 192MB/s]
 26%|##5       | 594M/2.26G [00:03<00:09, 188MB/s]
 26%|##6       | 612M/2.26G [00:03<00:09, 188MB/s]
 27%|##7       | 630M/2.26G [00:03<00:09, 187MB/s]
 28%|##7       | 648M/2.26G [00:03<00:09, 188MB/s]
 29%|##8       | 668M/2.26G [00:03<00:08, 194MB/s]
 30%|##9       | 687M/2.26G [00:03<00:08, 196MB/s]
 30%|###       | 706M/2.26G [00:03<00:08, 197MB/s]
 31%|###1      | 725M/2.26G [00:04<00:08, 198MB/s]
 32%|###2      | 744M/2.26G [00:04<00:08, 198MB/s]
 33%|###2      | 763M/2.26G [00:04<00:08, 197MB/s]
 34%|###3      | 781M/2.26G [00:04<00:08, 197MB/s]
 35%|###4      | 801M/2.26G [00:04<00:08, 198MB/s]
 35%|###5      | 820M/2.26G [00:04<00:07, 198MB/s]
 36%|###6      | 838M/2.26G [00:04<00:07, 197MB/s]
 37%|###7      | 857M/2.26G [00:04<00:07, 197MB/s]
 38%|###7      | 876M/2.26G [00:04<00:07, 197MB/s]
 39%|###8      | 895M/2.26G [00:04<00:07, 197MB/s]
 39%|###9      | 914M/2.26G [00:05<00:07, 195MB/s]
 40%|####      | 933M/2.26G [00:05<00:07, 196MB/s]
 41%|####1     | 952M/2.26G [00:05<00:07, 198MB/s]
 42%|####1     | 971M/2.26G [00:05<00:07, 199MB/s]
 43%|####2     | 990M/2.26G [00:05<00:06, 200MB/s]
 44%|####3     | 0.99G/2.26G [00:05<00:07, 186MB/s]
 44%|####4     | 1.00G/2.26G [00:05<00:07, 190MB/s]
 45%|####5     | 1.02G/2.26G [00:05<00:06, 194MB/s]
 46%|####6     | 1.04G/2.26G [00:05<00:06, 195MB/s]
 47%|####6     | 1.06G/2.26G [00:05<00:06, 195MB/s]
 48%|####7     | 1.08G/2.26G [00:06<00:06, 197MB/s]
 48%|####8     | 1.10G/2.26G [00:06<00:06, 196MB/s]
 49%|####9     | 1.12G/2.26G [00:06<00:06, 198MB/s]
 50%|#####     | 1.13G/2.26G [00:06<00:06, 195MB/s]
 51%|#####     | 1.15G/2.26G [00:06<00:06, 190MB/s]
 52%|#####1    | 1.17G/2.26G [00:06<00:06, 183MB/s]
 52%|#####2    | 1.19G/2.26G [00:06<00:06, 182MB/s]
 53%|#####3    | 1.21G/2.26G [00:06<00:06, 186MB/s]
 54%|#####4    | 1.22G/2.26G [00:06<00:05, 190MB/s]
 55%|#####4    | 1.24G/2.26G [00:06<00:05, 188MB/s]
 56%|#####5    | 1.26G/2.26G [00:07<00:05, 191MB/s]
 56%|#####6    | 1.28G/2.26G [00:07<00:05, 192MB/s]
 57%|#####7    | 1.30G/2.26G [00:07<00:05, 192MB/s]
 58%|#####8    | 1.31G/2.26G [00:07<00:05, 187MB/s]
 59%|#####8    | 1.33G/2.26G [00:07<00:05, 175MB/s]
 60%|#####9    | 1.35G/2.26G [00:07<00:05, 175MB/s]
 60%|######    | 1.36G/2.26G [00:07<00:05, 172MB/s]
 61%|######1   | 1.38G/2.26G [00:07<00:05, 169MB/s]
 62%|######1   | 1.40G/2.26G [00:07<00:05, 167MB/s]
 62%|######2   | 1.41G/2.26G [00:08<00:05, 166MB/s]
 63%|######3   | 1.43G/2.26G [00:08<00:05, 168MB/s]
 64%|######3   | 1.44G/2.26G [00:08<00:05, 166MB/s]
 64%|######4   | 1.46G/2.26G [00:08<00:05, 166MB/s]
 65%|######5   | 1.47G/2.26G [00:08<00:06, 141MB/s]
 66%|######5   | 1.49G/2.26G [00:08<00:05, 149MB/s]
 67%|######6   | 1.51G/2.26G [00:08<00:05, 159MB/s]
 67%|######7   | 1.52G/2.26G [00:08<00:04, 165MB/s]
 68%|######8   | 1.54G/2.26G [00:08<00:04, 164MB/s]
 69%|######8   | 1.56G/2.26G [00:09<00:04, 164MB/s]
 69%|######9   | 1.57G/2.26G [00:09<00:04, 164MB/s]
 70%|#######   | 1.59G/2.26G [00:09<00:04, 166MB/s]
 71%|#######   | 1.60G/2.26G [00:09<00:04, 167MB/s]
 72%|#######1  | 1.62G/2.26G [00:09<00:04, 167MB/s]
 72%|#######2  | 1.63G/2.26G [00:09<00:04, 166MB/s]
 73%|#######2  | 1.65G/2.26G [00:09<00:03, 165MB/s]
 74%|#######3  | 1.67G/2.26G [00:09<00:03, 166MB/s]
 74%|#######4  | 1.68G/2.26G [00:09<00:03, 169MB/s]
 75%|#######5  | 1.70G/2.26G [00:09<00:03, 169MB/s]
 76%|#######5  | 1.71G/2.26G [00:10<00:03, 169MB/s]
 76%|#######6  | 1.73G/2.26G [00:10<00:03, 171MB/s]
 77%|#######7  | 1.75G/2.26G [00:10<00:03, 170MB/s]
 78%|#######7  | 1.76G/2.26G [00:10<00:03, 168MB/s]
 79%|#######8  | 1.78G/2.26G [00:10<00:03, 167MB/s]
 79%|#######9  | 1.79G/2.26G [00:10<00:03, 166MB/s]
 80%|#######9  | 1.81G/2.26G [00:10<00:02, 167MB/s]
 81%|########  | 1.82G/2.26G [00:10<00:02, 168MB/s]
 81%|########1 | 1.84G/2.26G [00:10<00:02, 169MB/s]
 82%|########2 | 1.86G/2.26G [00:10<00:02, 168MB/s]
 83%|########2 | 1.87G/2.26G [00:11<00:02, 169MB/s]
 83%|########3 | 1.89G/2.26G [00:11<00:02, 170MB/s]
 84%|########4 | 1.90G/2.26G [00:11<00:02, 171MB/s]
 85%|########4 | 1.92G/2.26G [00:11<00:02, 172MB/s]
 86%|########5 | 1.94G/2.26G [00:11<00:02, 171MB/s]
 86%|########6 | 1.95G/2.26G [00:11<00:01, 171MB/s]
 87%|########6 | 1.97G/2.26G [00:11<00:01, 169MB/s]
 88%|########7 | 1.98G/2.26G [00:11<00:01, 168MB/s]
 88%|########8 | 2.00G/2.26G [00:11<00:01, 168MB/s]
 89%|########9 | 2.01G/2.26G [00:11<00:01, 168MB/s]
 90%|########9 | 2.03G/2.26G [00:12<00:01, 170MB/s]
 90%|######### | 2.05G/2.26G [00:12<00:01, 170MB/s]
 91%|#########1| 2.06G/2.26G [00:12<00:01, 169MB/s]
 92%|#########1| 2.08G/2.26G [00:12<00:01, 168MB/s]
 93%|#########2| 2.09G/2.26G [00:12<00:01, 169MB/s]
 93%|#########3| 2.11G/2.26G [00:12<00:00, 171MB/s]
 94%|#########4| 2.13G/2.26G [00:12<00:00, 171MB/s]
 95%|#########4| 2.14G/2.26G [00:12<00:00, 169MB/s]
 95%|#########5| 2.16G/2.26G [00:12<00:00, 167MB/s]
 96%|#########6| 2.17G/2.26G [00:12<00:00, 167MB/s]
 97%|#########6| 2.19G/2.26G [00:13<00:00, 169MB/s]
 98%|#########7| 2.21G/2.26G [00:13<00:00, 169MB/s]
 98%|#########8| 2.22G/2.26G [00:13<00:00, 169MB/s]
 99%|#########8| 2.24G/2.26G [00:13<00:00, 170MB/s]
100%|#########9| 2.25G/2.26G [00:13<00:00, 146MB/s]
100%|##########| 2.26G/2.26G [00:13<00:00, 179MB/s]

SPEECHCOMMANDS 数据集中的一个数据点是一个元组,由波形(音频信号)、采样率、话语(标签)、说话者 ID 和话语编号组成。

print("Shape of waveform: {}".format(waveform.size()))
print("Sample rate of waveform: {}".format(sample_rate))

plt.plot(waveform.t().numpy());
speech command classification with torchaudio tutorial
Shape of waveform: torch.Size([1, 16000])
Sample rate of waveform: 16000

[<matplotlib.lines.Line2D object at 0x7fcd4fcfd150>]

让我们找到数据集中可用的标签列表。

labels = sorted(list(set(datapoint[2] for datapoint in train_set)))
labels
['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']

这 35 个音频标签是用户所说的命令。前几个文件是人们说“marvin”的声音。

waveform_first, *_ = train_set[0]
ipd.Audio(waveform_first.numpy(), rate=sample_rate)

waveform_second, *_ = train_set[1]
ipd.Audio(waveform_second.numpy(), rate=sample_rate)


最后一个文件是有人说“visual”的声音。

waveform_last, *_ = train_set[-1]
ipd.Audio(waveform_last.numpy(), rate=sample_rate)


格式化数据

这是一个对数据应用转换的好地方。对于波形,我们对音频进行降采样以加快处理速度,而不会损失太多分类能力。

我们不需要在这里应用其他转换。不过,对于某些数据集来说,通常需要通过沿通道维度取平均值或仅保留其中一个通道来减少通道数(例如,从立体声到单声道)。由于 SpeechCommands 使用单个通道进行音频,因此这里不需要这样做。

new_sample_rate = 8000
transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)
transformed = transform(waveform)

ipd.Audio(transformed.numpy(), rate=new_sample_rate)


我们使用标签列表中的索引对每个单词进行编码。

def label_to_index(word):
    # Return the position of the word in labels
    return torch.tensor(labels.index(word))


def index_to_label(index):
    # Return the word corresponding to the index in labels
    # This is the inverse of label_to_index
    return labels[index]


word_start = "yes"
index = label_to_index(word_start)
word_recovered = index_to_label(index)

print(word_start, "-->", index, "-->", word_recovered)
yes --> tensor(33) --> yes

为了将由音频录制和话语组成的数据点列表转换为模型的两个批处理张量,我们实现了一个 collate 函数,该函数由 PyTorch DataLoader 使用,使我们能够按批次迭代数据集。有关使用 collate 函数的更多信息,请参阅 文档

在 collate 函数中,我们还应用重采样和文本编码。

def pad_sequence(batch):
    # Make all tensor in a batch the same length by padding with zeros
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    return batch.permute(0, 2, 1)


def collate_fn(batch):

    # A data tuple has the form:
    # waveform, sample_rate, label, speaker_id, utterance_number

    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for waveform, _, label, *_ in batch:
        tensors += [waveform]
        targets += [label_to_index(label)]

    # Group the list of tensors into a batched tensor
    tensors = pad_sequence(tensors)
    targets = torch.stack(targets)

    return tensors, targets


batch_size = 256

if device == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

定义网络

在本教程中,我们将使用卷积神经网络(CNN)来处理原始音频数据。通常,会对音频数据应用更高级的变换,但 CNN 也可以用于精确地处理原始数据。具体的架构是根据这篇论文中描述的 M5 网络架构进行建模的。处理原始音频数据的模型的一个重要方面是其第一层滤波器的感受野。我们模型的第一层滤波器长度为 80,因此在处理以 8kHz 采样的音频时,感受野约为 10ms(在 4kHz 时,约为 20ms)。此大小类似于语音处理应用程序,这些应用程序通常使用 20ms 到 40ms 的感受野。

class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)


model = M5(n_input=transformed.shape[0], n_output=len(labels))
model.to(device)
print(model)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


n = count_parameters(model)
print("Number of parameters: %s" % n)
M5(
  (conv1): Conv1d(1, 32, kernel_size=(80,), stride=(16,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=35, bias=True)
)
Number of parameters: 26915

我们将使用论文中使用的相同优化技术,即带有权重衰减(设置为 0.0001)的 Adam 优化器。首先,我们将使用 0.01 的学习率进行训练,但我们将在训练 20 个 epoch 后使用一个scheduler将其降低到 0.001。

optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)  # reduce the learning after 20 epochs by a factor of 10

训练和测试网络

现在让我们定义一个训练函数,该函数将我们的训练数据馈送到模型中,并执行反向传播和优化步骤。对于训练,我们将使用的损失函数是负对数似然。然后,网络将在每个 epoch 后进行测试,以查看准确性在训练期间是如何变化的。

def train(model, epoch, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        data = transform(data)
        output = model(data)

        # negative log-likelihood for a tensor of size (batch x 1 x n_output)
        loss = F.nll_loss(output.squeeze(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        losses.append(loss.item())

现在我们有了训练函数,我们需要为测试网络的准确性创建一个函数。我们将模型设置为eval()模式,然后在测试数据集上运行推理。调用eval()会将网络中所有模块的训练变量设置为 false。某些层(如批归一化和 dropout 层)在训练期间的行为有所不同,因此此步骤对于获得正确的结果至关重要。

def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    return tensor.argmax(dim=-1)


def test(model, epoch):
    model.eval()
    correct = 0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        data = transform(data)
        output = model(data)

        pred = get_likely_index(output)
        correct += number_of_correct(pred, target)

        # update progress bar
        pbar.update(pbar_update)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")

最后,我们可以训练和测试网络。我们将训练网络 10 个 epoch,然后降低学习率并再训练 10 个 epoch。网络将在每个 epoch 后进行测试,以查看准确性在训练期间是如何变化的。

log_interval = 20
n_epoch = 2

pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []

# The transform needs to live on the same device as the model and the data.
transform = transform.to(device)
with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        train(model, epoch, log_interval)
        test(model, epoch)
        scheduler.step()

# Let's plot the training loss versus the number of iteration.
# plt.plot(losses);
# plt.title("training loss");
  0%|          | 0/2 [00:00<?, ?it/s]Train Epoch: 1 [0/84843 (0%)]      Loss: 3.813234

  0%|          | 0.0026666666666666666/2 [00:02<35:06, 1054.66s/it]
  0%|          | 0.005333333333333333/2 [00:02<15:14, 458.61s/it]
  1%|          | 0.010666666666666666/2 [00:03<06:24, 193.47s/it]
  1%|          | 0.016/2 [00:03<03:56, 119.08s/it]
  1%|1         | 0.021333333333333336/2 [00:03<02:47, 84.82s/it]
  1%|1         | 0.026666666666666672/2 [00:03<02:10, 66.28s/it]
  2%|1         | 0.03200000000000001/2 [00:03<01:49, 55.42s/it]
  2%|1         | 0.037333333333333336/2 [00:04<01:34, 48.32s/it]
  2%|2         | 0.042666666666666665/2 [00:04<01:26, 44.01s/it]
  2%|2         | 0.047999999999999994/2 [00:04<01:19, 40.72s/it]
  3%|2         | 0.05333333333333332/2 [00:04<01:14, 38.51s/it] Train Epoch: 1 [5120/84843 (6%)]        Loss: 3.041089

  3%|2         | 0.05866666666666665/2 [00:04<01:12, 37.28s/it]
  3%|3         | 0.06399999999999999/2 [00:04<01:09, 35.69s/it]
  3%|3         | 0.06933333333333332/2 [00:05<01:05, 34.18s/it]
  4%|3         | 0.07466666666666665/2 [00:05<01:04, 33.40s/it]
  4%|3         | 0.07999999999999997/2 [00:05<01:02, 32.65s/it]
  4%|4         | 0.0853333333333333/2 [00:05<01:01, 32.33s/it]
  5%|4         | 0.09066666666666663/2 [00:05<01:01, 32.00s/it]
  5%|4         | 0.09599999999999996/2 [00:05<01:00, 31.76s/it]
  5%|5         | 0.10133333333333329/2 [00:06<01:00, 31.79s/it]
  5%|5         | 0.10666666666666662/2 [00:06<00:59, 31.63s/it]Train Epoch: 1 [10240/84843 (12%)]       Loss: 2.553542

  6%|5         | 0.11199999999999995/2 [00:06<00:59, 31.61s/it]
  6%|5         | 0.11733333333333328/2 [00:06<00:59, 31.68s/it]
  6%|6         | 0.1226666666666666/2 [00:06<00:59, 31.80s/it]
  6%|6         | 0.12799999999999995/2 [00:06<00:59, 32.00s/it]
  7%|6         | 0.1333333333333333/2 [00:07<01:00, 32.18s/it]
  7%|6         | 0.13866666666666666/2 [00:07<01:00, 32.38s/it]
  7%|7         | 0.14400000000000002/2 [00:07<00:59, 31.99s/it]
  7%|7         | 0.14933333333333337/2 [00:07<00:59, 31.94s/it]
  8%|7         | 0.15466666666666673/2 [00:07<00:58, 31.82s/it]
  8%|8         | 0.1600000000000001/2 [00:07<00:58, 31.83s/it] Train Epoch: 1 [15360/84843 (18%)]       Loss: 2.254806

  8%|8         | 0.16533333333333344/2 [00:08<00:58, 31.74s/it]
  9%|8         | 0.1706666666666668/2 [00:08<00:57, 31.45s/it]
  9%|8         | 0.17600000000000016/2 [00:08<00:56, 31.23s/it]
  9%|9         | 0.1813333333333335/2 [00:08<00:56, 31.16s/it]
  9%|9         | 0.18666666666666687/2 [00:08<00:56, 31.21s/it]
 10%|9         | 0.19200000000000023/2 [00:08<00:56, 31.32s/it]
 10%|9         | 0.19733333333333358/2 [00:09<00:57, 31.75s/it]
 10%|#         | 0.20266666666666694/2 [00:09<00:56, 31.59s/it]
 10%|#         | 0.2080000000000003/2 [00:09<00:56, 31.46s/it]
 11%|#         | 0.21333333333333365/2 [00:09<00:56, 31.56s/it]Train Epoch: 1 [20480/84843 (24%)]       Loss: 2.029420

 11%|#         | 0.218666666666667/2 [00:09<00:56, 31.62s/it]
 11%|#1        | 0.22400000000000037/2 [00:10<00:58, 32.74s/it]
 11%|#1        | 0.22933333333333372/2 [00:10<00:59, 33.57s/it]
 12%|#1        | 0.23466666666666708/2 [00:10<00:59, 33.96s/it]
 12%|#2        | 0.24000000000000044/2 [00:10<01:00, 34.59s/it]
 12%|#2        | 0.2453333333333338/2 [00:10<01:00, 34.48s/it]
 13%|#2        | 0.25066666666666715/2 [00:10<00:58, 33.66s/it]
 13%|#2        | 0.25600000000000045/2 [00:11<00:58, 33.29s/it]
 13%|#3        | 0.26133333333333375/2 [00:11<00:57, 33.15s/it]
 13%|#3        | 0.26666666666666705/2 [00:11<00:56, 32.75s/it]Train Epoch: 1 [25600/84843 (30%)]       Loss: 1.780852

 14%|#3        | 0.27200000000000035/2 [00:11<00:55, 32.39s/it]
 14%|#3        | 0.27733333333333365/2 [00:11<00:55, 32.33s/it]
 14%|#4        | 0.28266666666666695/2 [00:11<00:55, 32.18s/it]
 14%|#4        | 0.28800000000000026/2 [00:12<00:54, 31.90s/it]
 15%|#4        | 0.29333333333333356/2 [00:12<00:54, 31.80s/it]
 15%|#4        | 0.29866666666666686/2 [00:12<00:53, 31.72s/it]
 15%|#5        | 0.30400000000000016/2 [00:12<00:53, 31.61s/it]
 15%|#5        | 0.30933333333333346/2 [00:12<00:53, 31.57s/it]
 16%|#5        | 0.31466666666666676/2 [00:12<00:53, 31.86s/it]
 16%|#6        | 0.32000000000000006/2 [00:13<00:54, 32.20s/it]Train Epoch: 1 [30720/84843 (36%)]       Loss: 1.574860

 16%|#6        | 0.32533333333333336/2 [00:13<00:53, 31.77s/it]
 17%|#6        | 0.33066666666666666/2 [00:13<00:52, 31.46s/it]
 17%|#6        | 0.33599999999999997/2 [00:13<00:52, 31.58s/it]
 17%|#7        | 0.34133333333333327/2 [00:13<00:52, 31.46s/it]
 17%|#7        | 0.34666666666666657/2 [00:13<00:52, 31.48s/it]
 18%|#7        | 0.35199999999999987/2 [00:14<00:52, 31.57s/it]
 18%|#7        | 0.35733333333333317/2 [00:14<00:51, 31.63s/it]
 18%|#8        | 0.36266666666666647/2 [00:14<00:51, 31.71s/it]
 18%|#8        | 0.36799999999999977/2 [00:14<00:51, 31.56s/it]
 19%|#8        | 0.3733333333333331/2 [00:14<00:51, 31.65s/it] Train Epoch: 1 [35840/84843 (42%)]       Loss: 1.676274

 19%|#8        | 0.3786666666666664/2 [00:14<00:51, 31.64s/it]
 19%|#9        | 0.3839999999999997/2 [00:15<00:50, 31.42s/it]
 19%|#9        | 0.389333333333333/2 [00:15<00:50, 31.18s/it]
 20%|#9        | 0.3946666666666663/2 [00:15<00:50, 31.24s/it]
 20%|#9        | 0.3999999999999996/2 [00:15<00:50, 31.39s/it]
 20%|##        | 0.4053333333333329/2 [00:15<00:50, 31.55s/it]
 21%|##        | 0.4106666666666662/2 [00:15<00:50, 31.57s/it]
 21%|##        | 0.4159999999999995/2 [00:16<00:49, 31.45s/it]
 21%|##1       | 0.4213333333333328/2 [00:16<00:49, 31.46s/it]
 21%|##1       | 0.4266666666666661/2 [00:16<00:49, 31.66s/it]Train Epoch: 1 [40960/84843 (48%)]        Loss: 1.567081

 22%|##1       | 0.4319999999999994/2 [00:16<00:50, 31.97s/it]
 22%|##1       | 0.4373333333333327/2 [00:16<00:49, 31.72s/it]
 22%|##2       | 0.442666666666666/2 [00:17<00:49, 31.71s/it]
 22%|##2       | 0.4479999999999993/2 [00:17<00:49, 31.63s/it]
 23%|##2       | 0.4533333333333326/2 [00:17<00:48, 31.67s/it]
 23%|##2       | 0.4586666666666659/2 [00:17<00:49, 31.90s/it]
 23%|##3       | 0.4639999999999992/2 [00:17<00:48, 31.70s/it]
 23%|##3       | 0.4693333333333325/2 [00:17<00:48, 31.73s/it]
 24%|##3       | 0.4746666666666658/2 [00:18<00:48, 31.60s/it]
 24%|##3       | 0.4799999999999991/2 [00:18<00:47, 31.46s/it]Train Epoch: 1 [46080/84843 (54%)]        Loss: 1.373785

 24%|##4       | 0.4853333333333324/2 [00:18<00:47, 31.47s/it]
 25%|##4       | 0.4906666666666657/2 [00:18<00:47, 31.38s/it]
 25%|##4       | 0.495999999999999/2 [00:18<00:47, 31.66s/it]
 25%|##5       | 0.5013333333333323/2 [00:18<00:48, 32.09s/it]
 25%|##5       | 0.5066666666666657/2 [00:19<00:47, 31.92s/it]
 26%|##5       | 0.5119999999999991/2 [00:19<00:47, 31.78s/it]
 26%|##5       | 0.5173333333333325/2 [00:19<00:46, 31.69s/it]
 26%|##6       | 0.522666666666666/2 [00:19<00:46, 31.56s/it]
 26%|##6       | 0.5279999999999994/2 [00:19<00:46, 31.48s/it]
 27%|##6       | 0.5333333333333328/2 [00:19<00:46, 31.37s/it]Train Epoch: 1 [51200/84843 (60%)]        Loss: 1.345279

 27%|##6       | 0.5386666666666662/2 [00:20<00:46, 31.53s/it]
 27%|##7       | 0.5439999999999996/2 [00:20<00:46, 31.67s/it]
 27%|##7       | 0.549333333333333/2 [00:20<00:45, 31.69s/it]
 28%|##7       | 0.5546666666666664/2 [00:20<00:45, 31.75s/it]
 28%|##7       | 0.5599999999999998/2 [00:20<00:45, 31.73s/it]
 28%|##8       | 0.5653333333333332/2 [00:20<00:45, 31.94s/it]
 29%|##8       | 0.5706666666666667/2 [00:21<00:46, 32.51s/it]
 29%|##8       | 0.5760000000000001/2 [00:21<00:45, 32.22s/it]
 29%|##9       | 0.5813333333333335/2 [00:21<00:45, 32.06s/it]
 29%|##9       | 0.5866666666666669/2 [00:21<00:45, 32.21s/it]Train Epoch: 1 [56320/84843 (66%)]        Loss: 1.342250

 30%|##9       | 0.5920000000000003/2 [00:21<00:45, 32.14s/it]
 30%|##9       | 0.5973333333333337/2 [00:21<00:44, 31.89s/it]
 30%|###       | 0.6026666666666671/2 [00:22<00:44, 31.83s/it]
 30%|###       | 0.6080000000000005/2 [00:22<00:44, 31.90s/it]
 31%|###       | 0.613333333333334/2 [00:22<00:44, 31.96s/it]
 31%|###       | 0.6186666666666674/2 [00:22<00:44, 31.91s/it]
 31%|###1      | 0.6240000000000008/2 [00:22<00:44, 32.04s/it]
 31%|###1      | 0.6293333333333342/2 [00:22<00:43, 31.95s/it]
 32%|###1      | 0.6346666666666676/2 [00:23<00:43, 31.67s/it]
 32%|###2      | 0.640000000000001/2 [00:23<00:42, 31.41s/it] Train Epoch: 1 [61440/84843 (72%)]        Loss: 1.286719

 32%|###2      | 0.6453333333333344/2 [00:23<00:42, 31.22s/it]
 33%|###2      | 0.6506666666666678/2 [00:23<00:42, 31.32s/it]
 33%|###2      | 0.6560000000000012/2 [00:23<00:42, 31.41s/it]
 33%|###3      | 0.6613333333333347/2 [00:23<00:42, 31.45s/it]
 33%|###3      | 0.6666666666666681/2 [00:24<00:41, 31.46s/it]
 34%|###3      | 0.6720000000000015/2 [00:24<00:42, 31.79s/it]
 34%|###3      | 0.6773333333333349/2 [00:24<00:42, 31.96s/it]
 34%|###4      | 0.6826666666666683/2 [00:24<00:42, 32.07s/it]
 34%|###4      | 0.6880000000000017/2 [00:24<00:41, 31.82s/it]
 35%|###4      | 0.6933333333333351/2 [00:24<00:41, 31.90s/it]Train Epoch: 1 [66560/84843 (78%)]        Loss: 1.136720

 35%|###4      | 0.6986666666666685/2 [00:25<00:41, 31.76s/it]
 35%|###5      | 0.704000000000002/2 [00:25<00:41, 31.71s/it]
 35%|###5      | 0.7093333333333354/2 [00:25<00:41, 32.04s/it]
 36%|###5      | 0.7146666666666688/2 [00:25<00:41, 32.20s/it]
 36%|###6      | 0.7200000000000022/2 [00:25<00:41, 32.09s/it]
 36%|###6      | 0.7253333333333356/2 [00:25<00:40, 31.89s/it]
 37%|###6      | 0.730666666666669/2 [00:26<00:40, 31.78s/it]
 37%|###6      | 0.7360000000000024/2 [00:26<00:40, 31.77s/it]
 37%|###7      | 0.7413333333333358/2 [00:26<00:40, 31.89s/it]
 37%|###7      | 0.7466666666666693/2 [00:26<00:39, 31.67s/it]Train Epoch: 1 [71680/84843 (84%)]        Loss: 1.120914

 38%|###7      | 0.7520000000000027/2 [00:26<00:39, 31.68s/it]
 38%|###7      | 0.7573333333333361/2 [00:27<00:39, 31.85s/it]
 38%|###8      | 0.7626666666666695/2 [00:27<00:39, 31.57s/it]
 38%|###8      | 0.7680000000000029/2 [00:27<00:38, 31.49s/it]
 39%|###8      | 0.7733333333333363/2 [00:27<00:39, 32.02s/it]
 39%|###8      | 0.7786666666666697/2 [00:27<00:38, 31.83s/it]
 39%|###9      | 0.7840000000000031/2 [00:27<00:38, 31.82s/it]
 39%|###9      | 0.7893333333333366/2 [00:28<00:38, 31.93s/it]
 40%|###9      | 0.79466666666667/2 [00:28<00:38, 31.83s/it]
 40%|####      | 0.8000000000000034/2 [00:28<00:38, 31.89s/it]Train Epoch: 1 [76800/84843 (90%)]        Loss: 1.195492

 40%|####      | 0.8053333333333368/2 [00:28<00:38, 31.86s/it]
 41%|####      | 0.8106666666666702/2 [00:28<00:37, 31.56s/it]
 41%|####      | 0.8160000000000036/2 [00:28<00:37, 31.83s/it]
 41%|####1     | 0.821333333333337/2 [00:29<00:37, 31.72s/it]
 41%|####1     | 0.8266666666666704/2 [00:29<00:37, 31.57s/it]
 42%|####1     | 0.8320000000000038/2 [00:29<00:36, 31.64s/it]
 42%|####1     | 0.8373333333333373/2 [00:29<00:36, 31.43s/it]
 42%|####2     | 0.8426666666666707/2 [00:29<00:36, 31.45s/it]
 42%|####2     | 0.8480000000000041/2 [00:29<00:36, 31.54s/it]
 43%|####2     | 0.8533333333333375/2 [00:30<00:36, 31.72s/it]Train Epoch: 1 [81920/84843 (96%)]        Loss: 1.064560

 43%|####2     | 0.8586666666666709/2 [00:30<00:36, 32.06s/it]
 43%|####3     | 0.8640000000000043/2 [00:30<00:36, 32.07s/it]
 43%|####3     | 0.8693333333333377/2 [00:30<00:36, 32.00s/it]
 44%|####3     | 0.8746666666666711/2 [00:30<00:35, 31.88s/it]
 44%|####4     | 0.8800000000000046/2 [00:30<00:35, 31.82s/it]
 44%|####4     | 0.885333333333338/2 [00:31<00:33, 29.78s/it]
 45%|####4     | 0.8906666666666714/2 [00:31<00:33, 29.85s/it]
 45%|####4     | 0.8960000000000048/2 [00:31<00:32, 29.65s/it]
 45%|####5     | 0.9013333333333382/2 [00:31<00:32, 29.68s/it]
 45%|####5     | 0.9066666666666716/2 [00:31<00:32, 29.84s/it]
 46%|####5     | 0.912000000000005/2 [00:31<00:32, 29.89s/it]
 46%|####5     | 0.9173333333333384/2 [00:31<00:32, 29.99s/it]
 46%|####6     | 0.9226666666666719/2 [00:32<00:32, 30.10s/it]
 46%|####6     | 0.9280000000000053/2 [00:32<00:32, 29.88s/it]
 47%|####6     | 0.9333333333333387/2 [00:32<00:31, 29.91s/it]
 47%|####6     | 0.9386666666666721/2 [00:32<00:32, 30.24s/it]
 47%|####7     | 0.9440000000000055/2 [00:32<00:31, 30.20s/it]
 47%|####7     | 0.9493333333333389/2 [00:32<00:31, 30.17s/it]
 48%|####7     | 0.9546666666666723/2 [00:33<00:31, 30.11s/it]
 48%|####8     | 0.9600000000000057/2 [00:33<00:31, 30.18s/it]
 48%|####8     | 0.9653333333333391/2 [00:33<00:30, 29.91s/it]
 49%|####8     | 0.9706666666666726/2 [00:33<00:30, 30.06s/it]
 49%|####8     | 0.976000000000006/2 [00:33<00:31, 31.04s/it]
 49%|####9     | 0.9813333333333394/2 [00:33<00:31, 30.67s/it]
 49%|####9     | 0.9866666666666728/2 [00:34<00:30, 30.50s/it]
 50%|####9     | 0.9920000000000062/2 [00:34<00:30, 30.40s/it]
 50%|####9     | 0.9973333333333396/2 [00:34<00:30, 30.27s/it]
Test Epoch: 1   Accuracy: 6052/11005 (55%)

Train Epoch: 2 [0/84843 (0%)]   Loss: 1.176855

 50%|#####     | 1.0026666666666728/2 [00:34<00:30, 30.71s/it]
 50%|#####     | 1.008000000000006/2 [00:34<00:30, 30.75s/it]
 51%|#####     | 1.0133333333333392/2 [00:34<00:30, 30.96s/it]
 51%|#####     | 1.0186666666666724/2 [00:35<00:30, 31.20s/it]
 51%|#####1    | 1.0240000000000056/2 [00:35<00:30, 31.49s/it]
 51%|#####1    | 1.0293333333333388/2 [00:35<00:30, 31.82s/it]
 52%|#####1    | 1.034666666666672/2 [00:35<00:30, 31.77s/it]
 52%|#####2    | 1.0400000000000051/2 [00:35<00:30, 31.86s/it]
 52%|#####2    | 1.0453333333333383/2 [00:35<00:30, 31.54s/it]
 53%|#####2    | 1.0506666666666715/2 [00:36<00:29, 31.41s/it]Train Epoch: 2 [5120/84843 (6%)]  Loss: 1.009675

 53%|#####2    | 1.0560000000000047/2 [00:36<00:30, 31.94s/it]
 53%|#####3    | 1.061333333333338/2 [00:36<00:29, 31.71s/it]
 53%|#####3    | 1.066666666666671/2 [00:36<00:29, 31.66s/it]
 54%|#####3    | 1.0720000000000043/2 [00:36<00:29, 31.47s/it]
 54%|#####3    | 1.0773333333333375/2 [00:36<00:28, 31.37s/it]
 54%|#####4    | 1.0826666666666707/2 [00:37<00:28, 31.34s/it]
 54%|#####4    | 1.0880000000000039/2 [00:37<00:28, 31.36s/it]
 55%|#####4    | 1.093333333333337/2 [00:37<00:28, 31.58s/it]
 55%|#####4    | 1.0986666666666702/2 [00:37<00:28, 31.46s/it]
 55%|#####5    | 1.1040000000000034/2 [00:37<00:28, 31.33s/it]Train Epoch: 2 [10240/84843 (12%)]        Loss: 0.983963

 55%|#####5    | 1.1093333333333366/2 [00:37<00:28, 31.52s/it]
 56%|#####5    | 1.1146666666666698/2 [00:38<00:27, 31.40s/it]
 56%|#####6    | 1.120000000000003/2 [00:38<00:27, 31.33s/it]
 56%|#####6    | 1.1253333333333362/2 [00:38<00:27, 31.24s/it]
 57%|#####6    | 1.1306666666666694/2 [00:38<00:27, 31.35s/it]
 57%|#####6    | 1.1360000000000026/2 [00:38<00:27, 31.60s/it]
 57%|#####7    | 1.1413333333333358/2 [00:38<00:27, 31.82s/it]
 57%|#####7    | 1.146666666666669/2 [00:39<00:27, 31.71s/it]
 58%|#####7    | 1.1520000000000021/2 [00:39<00:26, 31.78s/it]
 58%|#####7    | 1.1573333333333353/2 [00:39<00:26, 31.55s/it]Train Epoch: 2 [15360/84843 (18%)]        Loss: 0.854694

 58%|#####8    | 1.1626666666666685/2 [00:39<00:26, 32.00s/it]
 58%|#####8    | 1.1680000000000017/2 [00:39<00:26, 32.06s/it]
 59%|#####8    | 1.173333333333335/2 [00:39<00:26, 32.09s/it]
 59%|#####8    | 1.178666666666668/2 [00:40<00:26, 32.00s/it]
 59%|#####9    | 1.1840000000000013/2 [00:40<00:26, 32.00s/it]
 59%|#####9    | 1.1893333333333345/2 [00:40<00:25, 31.85s/it]
 60%|#####9    | 1.1946666666666677/2 [00:40<00:25, 31.60s/it]
 60%|######    | 1.2000000000000008/2 [00:40<00:25, 31.99s/it]
 60%|######    | 1.205333333333334/2 [00:41<00:25, 31.91s/it]
 61%|######    | 1.2106666666666672/2 [00:41<00:25, 31.86s/it]Train Epoch: 2 [20480/84843 (24%)]        Loss: 0.923911

 61%|######    | 1.2160000000000004/2 [00:41<00:25, 31.89s/it]
 61%|######1   | 1.2213333333333336/2 [00:41<00:24, 31.84s/it]
 61%|######1   | 1.2266666666666668/2 [00:41<00:24, 31.66s/it]
 62%|######1   | 1.232/2 [00:41<00:24, 31.95s/it]
 62%|######1   | 1.2373333333333332/2 [00:42<00:24, 31.97s/it]
 62%|######2   | 1.2426666666666664/2 [00:42<00:24, 31.91s/it]
 62%|######2   | 1.2479999999999996/2 [00:42<00:23, 31.84s/it]
 63%|######2   | 1.2533333333333327/2 [00:42<00:23, 31.87s/it]
 63%|######2   | 1.258666666666666/2 [00:42<00:23, 31.79s/it]
 63%|######3   | 1.2639999999999991/2 [00:42<00:23, 31.64s/it]Train Epoch: 2 [25600/84843 (30%)]        Loss: 1.085688

 63%|######3   | 1.2693333333333323/2 [00:43<00:23, 32.04s/it]
 64%|######3   | 1.2746666666666655/2 [00:43<00:22, 31.66s/it]
 64%|######3   | 1.2799999999999987/2 [00:43<00:22, 31.72s/it]
 64%|######4   | 1.2853333333333319/2 [00:43<00:22, 31.45s/it]
 65%|######4   | 1.290666666666665/2 [00:43<00:22, 31.38s/it]
 65%|######4   | 1.2959999999999983/2 [00:43<00:22, 31.40s/it]
 65%|######5   | 1.3013333333333315/2 [00:44<00:21, 31.40s/it]
 65%|######5   | 1.3066666666666646/2 [00:44<00:21, 31.34s/it]
 66%|######5   | 1.3119999999999978/2 [00:44<00:21, 31.12s/it]
 66%|######5   | 1.317333333333331/2 [00:44<00:21, 31.31s/it] Train Epoch: 2 [30720/84843 (36%)]        Loss: 0.894427

 66%|######6   | 1.3226666666666642/2 [00:44<00:21, 31.47s/it]
 66%|######6   | 1.3279999999999974/2 [00:44<00:20, 30.94s/it]
 67%|######6   | 1.3333333333333306/2 [00:45<00:20, 31.15s/it]
 67%|######6   | 1.3386666666666638/2 [00:45<00:20, 31.16s/it]
 67%|######7   | 1.343999999999997/2 [00:45<00:20, 31.32s/it]
 67%|######7   | 1.3493333333333302/2 [00:45<00:20, 31.36s/it]
 68%|######7   | 1.3546666666666634/2 [00:45<00:20, 31.26s/it]
 68%|######7   | 1.3599999999999965/2 [00:45<00:19, 31.17s/it]
 68%|######8   | 1.3653333333333297/2 [00:46<00:19, 31.11s/it]
 69%|######8   | 1.370666666666663/2 [00:46<00:19, 31.05s/it] Train Epoch: 2 [35840/84843 (42%)]        Loss: 0.971906

 69%|######8   | 1.3759999999999961/2 [00:46<00:19, 31.32s/it]
 69%|######9   | 1.3813333333333293/2 [00:46<00:19, 31.34s/it]
 69%|######9   | 1.3866666666666625/2 [00:46<00:19, 31.50s/it]
 70%|######9   | 1.3919999999999957/2 [00:46<00:19, 31.31s/it]
 70%|######9   | 1.3973333333333289/2 [00:47<00:18, 31.31s/it]
 70%|#######   | 1.402666666666662/2 [00:47<00:18, 31.34s/it]
 70%|#######   | 1.4079999999999953/2 [00:47<00:18, 31.55s/it]
 71%|#######   | 1.4133333333333284/2 [00:47<00:18, 31.50s/it]
 71%|#######   | 1.4186666666666616/2 [00:47<00:18, 31.43s/it]
 71%|#######1  | 1.4239999999999948/2 [00:47<00:18, 31.36s/it]Train Epoch: 2 [40960/84843 (48%)]        Loss: 1.002216

 71%|#######1  | 1.429333333333328/2 [00:48<00:17, 31.47s/it]
 72%|#######1  | 1.4346666666666612/2 [00:48<00:17, 31.02s/it]
 72%|#######1  | 1.4399999999999944/2 [00:48<00:17, 30.94s/it]
 72%|#######2  | 1.4453333333333276/2 [00:48<00:17, 31.24s/it]
 73%|#######2  | 1.4506666666666608/2 [00:48<00:17, 31.25s/it]
 73%|#######2  | 1.455999999999994/2 [00:48<00:17, 31.42s/it]
 73%|#######3  | 1.4613333333333272/2 [00:49<00:16, 31.33s/it]
 73%|#######3  | 1.4666666666666603/2 [00:49<00:16, 31.19s/it]
 74%|#######3  | 1.4719999999999935/2 [00:49<00:16, 31.24s/it]
 74%|#######3  | 1.4773333333333267/2 [00:49<00:16, 31.23s/it]Train Epoch: 2 [46080/84843 (54%)]        Loss: 0.820399

 74%|#######4  | 1.48266666666666/2 [00:49<00:16, 31.30s/it]
 74%|#######4  | 1.487999999999993/2 [00:49<00:16, 31.30s/it]
 75%|#######4  | 1.4933333333333263/2 [00:50<00:15, 31.31s/it]
 75%|#######4  | 1.4986666666666595/2 [00:50<00:15, 31.14s/it]
 75%|#######5  | 1.5039999999999927/2 [00:50<00:15, 31.20s/it]
 75%|#######5  | 1.5093333333333259/2 [00:50<00:15, 31.04s/it]
 76%|#######5  | 1.514666666666659/2 [00:50<00:15, 31.13s/it]
 76%|#######5  | 1.5199999999999922/2 [00:50<00:14, 31.16s/it]
 76%|#######6  | 1.5253333333333254/2 [00:51<00:14, 31.59s/it]
 77%|#######6  | 1.5306666666666586/2 [00:51<00:14, 31.87s/it]Train Epoch: 2 [51200/84843 (60%)]        Loss: 0.811157

 77%|#######6  | 1.5359999999999918/2 [00:51<00:14, 32.27s/it]
 77%|#######7  | 1.541333333333325/2 [00:51<00:14, 31.63s/it]
 77%|#######7  | 1.5466666666666582/2 [00:51<00:14, 31.63s/it]
 78%|#######7  | 1.5519999999999914/2 [00:51<00:14, 31.85s/it]
 78%|#######7  | 1.5573333333333246/2 [00:52<00:13, 31.59s/it]
 78%|#######8  | 1.5626666666666578/2 [00:52<00:13, 31.77s/it]
 78%|#######8  | 1.567999999999991/2 [00:52<00:13, 31.89s/it]
 79%|#######8  | 1.5733333333333241/2 [00:52<00:13, 32.17s/it]
 79%|#######8  | 1.5786666666666573/2 [00:52<00:13, 32.04s/it]
 79%|#######9  | 1.5839999999999905/2 [00:52<00:13, 31.83s/it]Train Epoch: 2 [56320/84843 (66%)]        Loss: 0.875054

 79%|#######9  | 1.5893333333333237/2 [00:53<00:13, 32.36s/it]
 80%|#######9  | 1.594666666666657/2 [00:53<00:12, 32.01s/it]
 80%|#######9  | 1.59999999999999/2 [00:53<00:12, 31.93s/it]
 80%|########  | 1.6053333333333233/2 [00:53<00:12, 32.12s/it]
 81%|########  | 1.6106666666666565/2 [00:53<00:12, 31.99s/it]
 81%|########  | 1.6159999999999897/2 [00:53<00:12, 31.66s/it]
 81%|########1 | 1.6213333333333229/2 [00:54<00:11, 31.45s/it]
 81%|########1 | 1.626666666666656/2 [00:54<00:11, 31.64s/it]
 82%|########1 | 1.6319999999999892/2 [00:54<00:11, 31.58s/it]
 82%|########1 | 1.6373333333333224/2 [00:54<00:11, 31.76s/it]Train Epoch: 2 [61440/84843 (72%)]        Loss: 0.923683

 82%|########2 | 1.6426666666666556/2 [00:54<00:11, 31.94s/it]
 82%|########2 | 1.6479999999999888/2 [00:54<00:11, 31.99s/it]
 83%|########2 | 1.653333333333322/2 [00:55<00:11, 31.86s/it]
 83%|########2 | 1.6586666666666552/2 [00:55<00:10, 31.85s/it]
 83%|########3 | 1.6639999999999884/2 [00:55<00:10, 31.66s/it]
 83%|########3 | 1.6693333333333216/2 [00:55<00:10, 31.57s/it]
 84%|########3 | 1.6746666666666548/2 [00:55<00:10, 31.46s/it]
 84%|########3 | 1.679999999999988/2 [00:55<00:10, 31.60s/it]
 84%|########4 | 1.6853333333333211/2 [00:56<00:09, 31.72s/it]
 85%|########4 | 1.6906666666666543/2 [00:56<00:10, 33.01s/it]Train Epoch: 2 [66560/84843 (78%)]        Loss: 0.889358

 85%|########4 | 1.6959999999999875/2 [00:56<00:10, 33.06s/it]
 85%|########5 | 1.7013333333333207/2 [00:56<00:09, 32.27s/it]
 85%|########5 | 1.706666666666654/2 [00:56<00:09, 32.09s/it]
 86%|########5 | 1.711999999999987/2 [00:57<00:09, 31.90s/it]
 86%|########5 | 1.7173333333333203/2 [00:57<00:08, 31.69s/it]
 86%|########6 | 1.7226666666666535/2 [00:57<00:08, 31.81s/it]
 86%|########6 | 1.7279999999999867/2 [00:57<00:08, 31.87s/it]
 87%|########6 | 1.7333333333333198/2 [00:57<00:08, 31.85s/it]
 87%|########6 | 1.738666666666653/2 [00:57<00:08, 31.93s/it]
 87%|########7 | 1.7439999999999862/2 [00:58<00:08, 31.82s/it]Train Epoch: 2 [71680/84843 (84%)]        Loss: 0.826419

 87%|########7 | 1.7493333333333194/2 [00:58<00:08, 32.33s/it]
 88%|########7 | 1.7546666666666526/2 [00:58<00:07, 31.73s/it]
 88%|########7 | 1.7599999999999858/2 [00:58<00:07, 31.62s/it]
 88%|########8 | 1.765333333333319/2 [00:58<00:07, 31.65s/it]
 89%|########8 | 1.7706666666666522/2 [00:58<00:07, 31.88s/it]
 89%|########8 | 1.7759999999999854/2 [00:59<00:07, 31.85s/it]
 89%|########9 | 1.7813333333333186/2 [00:59<00:06, 31.80s/it]
 89%|########9 | 1.7866666666666517/2 [00:59<00:06, 31.78s/it]
 90%|########9 | 1.791999999999985/2 [00:59<00:06, 31.60s/it]
 90%|########9 | 1.7973333333333181/2 [00:59<00:06, 31.79s/it]Train Epoch: 2 [76800/84843 (90%)]        Loss: 1.038377

 90%|######### | 1.8026666666666513/2 [00:59<00:06, 31.90s/it]
 90%|######### | 1.8079999999999845/2 [01:00<00:06, 31.64s/it]
 91%|######### | 1.8133333333333177/2 [01:00<00:05, 31.96s/it]
 91%|######### | 1.8186666666666509/2 [01:00<00:05, 31.57s/it]
 91%|#########1| 1.823999999999984/2 [01:00<00:05, 31.51s/it]
 91%|#########1| 1.8293333333333173/2 [01:00<00:05, 31.72s/it]
 92%|#########1| 1.8346666666666505/2 [01:00<00:05, 31.34s/it]
 92%|#########1| 1.8399999999999836/2 [01:01<00:04, 31.20s/it]
 92%|#########2| 1.8453333333333168/2 [01:01<00:04, 31.33s/it]
 93%|#########2| 1.85066666666665/2 [01:01<00:04, 31.29s/it]  Train Epoch: 2 [81920/84843 (96%)]        Loss: 0.769750

 93%|#########2| 1.8559999999999832/2 [01:01<00:04, 31.47s/it]
 93%|#########3| 1.8613333333333164/2 [01:01<00:04, 31.12s/it]
 93%|#########3| 1.8666666666666496/2 [01:01<00:04, 31.24s/it]
 94%|#########3| 1.8719999999999828/2 [01:02<00:03, 31.21s/it]
 94%|#########3| 1.877333333333316/2 [01:02<00:03, 31.37s/it]
 94%|#########4| 1.8826666666666492/2 [01:02<00:03, 31.47s/it]
 94%|#########4| 1.8879999999999824/2 [01:02<00:03, 29.55s/it]
 95%|#########4| 1.8933333333333155/2 [01:02<00:03, 29.54s/it]
 95%|#########4| 1.8986666666666487/2 [01:02<00:02, 29.52s/it]
 95%|#########5| 1.903999999999982/2 [01:03<00:02, 29.42s/it]
 95%|#########5| 1.9093333333333151/2 [01:03<00:02, 29.38s/it]
 96%|#########5| 1.9146666666666483/2 [01:03<00:02, 29.33s/it]
 96%|#########5| 1.9199999999999815/2 [01:03<00:02, 29.38s/it]
 96%|#########6| 1.9253333333333147/2 [01:03<00:02, 29.27s/it]
 97%|#########6| 1.9306666666666479/2 [01:03<00:02, 29.16s/it]
 97%|#########6| 1.935999999999981/2 [01:03<00:01, 29.07s/it]
 97%|#########7| 1.9413333333333143/2 [01:04<00:01, 29.09s/it]
 97%|#########7| 1.9466666666666474/2 [01:04<00:01, 29.45s/it]
 98%|#########7| 1.9519999999999806/2 [01:04<00:01, 29.40s/it]
 98%|#########7| 1.9573333333333138/2 [01:04<00:01, 29.40s/it]
 98%|#########8| 1.962666666666647/2 [01:04<00:01, 29.27s/it]
 98%|#########8| 1.9679999999999802/2 [01:04<00:00, 29.33s/it]
 99%|#########8| 1.9733333333333134/2 [01:05<00:00, 29.50s/it]
 99%|#########8| 1.9786666666666466/2 [01:05<00:00, 29.39s/it]
 99%|#########9| 1.9839999999999798/2 [01:05<00:00, 29.41s/it]
 99%|#########9| 1.989333333333313/2 [01:05<00:00, 29.40s/it]
100%|#########9| 1.9946666666666462/2 [01:05<00:00, 29.43s/it]
100%|#########9| 1.9999999999999793/2 [01:05<00:00, 29.19s/it]
Test Epoch: 2   Accuracy: 7852/11005 (71%)


100%|#########9| 1.9999999999999793/2 [01:05<00:00, 32.91s/it]

网络在 2 个 epoch 后在测试集上的准确率应该超过 65%,在 21 个 epoch 后超过 85%。让我们看看训练集中最后的几个单词,看看模型在这些单词上的表现如何。

def predict(tensor):
    # Use the model to predict the label of the waveform
    tensor = tensor.to(device)
    tensor = transform(tensor)
    tensor = model(tensor.unsqueeze(0))
    tensor = get_likely_index(tensor)
    tensor = index_to_label(tensor.squeeze())
    return tensor


waveform, sample_rate, utterance, *_ = train_set[-1]
ipd.Audio(waveform.numpy(), rate=sample_rate)

print(f"Expected: {utterance}. Predicted: {predict(waveform)}.")
Expected: zero. Predicted: zero.

让我们找一个没有被正确分类的示例(如果存在的话)。

for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set):
    output = predict(waveform)
    if output != utterance:
        ipd.Audio(waveform.numpy(), rate=sample_rate)
        print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
        break
else:
    print("All examples in this dataset were correctly classified!")
    print("In this case, let's just look at the last data point")
    ipd.Audio(waveform.numpy(), rate=sample_rate)
    print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
Data point #1. Expected: right. Predicted: seven.

随意尝试使用您自己录制的一个标签的音频!例如,在使用 Colab 时,在执行下面的单元格时说“Go”。这将录制一秒钟的音频并尝试对其进行分类。

def record(seconds=1):

    from google.colab import output as colab_output
    from base64 import b64decode
    from io import BytesIO
    from pydub import AudioSegment

    RECORD = (
        b"const sleep  = time => new Promise(resolve => setTimeout(resolve, time))\n"
        b"const b2text = blob => new Promise(resolve => {\n"
        b"  const reader = new FileReader()\n"
        b"  reader.onloadend = e => resolve(e.srcElement.result)\n"
        b"  reader.readAsDataURL(blob)\n"
        b"})\n"
        b"var record = time => new Promise(async resolve => {\n"
        b"  stream = await navigator.mediaDevices.getUserMedia({ audio: true })\n"
        b"  recorder = new MediaRecorder(stream)\n"
        b"  chunks = []\n"
        b"  recorder.ondataavailable = e => chunks.push(e.data)\n"
        b"  recorder.start()\n"
        b"  await sleep(time)\n"
        b"  recorder.onstop = async ()=>{\n"
        b"    blob = new Blob(chunks)\n"
        b"    text = await b2text(blob)\n"
        b"    resolve(text)\n"
        b"  }\n"
        b"  recorder.stop()\n"
        b"})"
    )
    RECORD = RECORD.decode("ascii")

    print(f"Recording started for {seconds} seconds.")
    display(ipd.Javascript(RECORD))
    s = colab_output.eval_js("record(%d)" % (seconds * 1000))
    print("Recording ended.")
    b = b64decode(s.split(",")[1])

    fileformat = "wav"
    filename = f"_audio.{fileformat}"
    AudioSegment.from_file(BytesIO(b)).export(filename, format=fileformat)
    return torchaudio.load(filename)


# Detect whether notebook runs in google colab
if "google.colab" in sys.modules:
    waveform, sample_rate = record()
    print(f"Predicted: {predict(waveform)}.")
    ipd.Audio(waveform.numpy(), rate=sample_rate)

结论

在本教程中,我们使用 torchaudio 加载数据集并对信号进行重采样。然后,我们定义了一个神经网络,并对其进行了训练以识别给定的命令。还有一些其他的数据预处理方法,例如查找梅尔频率倒谱系数 (MFCC),可以减少数据集的大小。此变换在 torchaudio 中也可用,作为torchaudio.transforms.MFCC

脚本的总运行时间:(2 分 29.974 秒)

Gallery generated by Sphinx-Gallery

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源