注意
跳到末尾 下载完整示例代码。
使用 TensorClass 处理数据集¶
在本教程中,我们将演示如何使用 TensorClass 高效且透明地在训练管道中加载和管理数据。本教程大量参考了 PyTorch 快速入门教程,但进行了修改以演示 TensorClass 的使用。请参阅相关的 TensorDict
使用教程。
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
torchvision.datasets
模块包含许多方便的预处理数据集。在本教程中,我们将使用相对简单的 FashionMNIST 数据集。每张图像都是一件衣服,目标是分类图像中的服装类型(例如,“包袋”、“运动鞋”等)。
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
0%| | 0.00/26.4M [00:00<?, ?B/s]
0%| | 65.5k/26.4M [00:00<01:12, 363kB/s]
1%| | 229k/26.4M [00:00<00:37, 693kB/s]
4%|▎ | 950k/26.4M [00:00<00:11, 2.21MB/s]
9%|▉ | 2.49M/26.4M [00:00<00:04, 5.73MB/s]
13%|█▎ | 3.34M/26.4M [00:00<00:04, 5.36MB/s]
18%|█▊ | 4.82M/26.4M [00:00<00:03, 6.40MB/s]
24%|██▍ | 6.36M/26.4M [00:01<00:02, 7.13MB/s]
30%|██▉ | 7.90M/26.4M [00:01<00:02, 7.60MB/s]
36%|███▌ | 9.50M/26.4M [00:01<00:02, 8.01MB/s]
42%|████▏ | 11.1M/26.4M [00:01<00:01, 8.35MB/s]
47%|████▋ | 12.5M/26.4M [00:01<00:01, 9.49MB/s]
52%|█████▏ | 13.6M/26.4M [00:01<00:01, 8.45MB/s]
58%|█████▊ | 15.3M/26.4M [00:02<00:01, 8.78MB/s]
63%|██████▎ | 16.8M/26.4M [00:02<00:00, 10.0MB/s]
68%|██████▊ | 17.9M/26.4M [00:02<00:00, 8.85MB/s]
74%|███████▍ | 19.7M/26.4M [00:02<00:00, 10.7MB/s]
79%|███████▉ | 20.9M/26.4M [00:02<00:00, 9.39MB/s]
85%|████████▍ | 22.4M/26.4M [00:02<00:00, 9.07MB/s]
92%|█████████▏| 24.2M/26.4M [00:03<00:00, 9.43MB/s]
99%|█████████▊| 26.1M/26.4M [00:03<00:00, 9.67MB/s]
100%|██████████| 26.4M/26.4M [00:03<00:00, 8.11MB/s]
0%| | 0.00/29.5k [00:00<?, ?B/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 328kB/s]
0%| | 0.00/4.42M [00:00<?, ?B/s]
1%|▏ | 65.5k/4.42M [00:00<00:12, 360kB/s]
5%|▌ | 229k/4.42M [00:00<00:06, 679kB/s]
21%|██ | 918k/4.42M [00:00<00:01, 2.60MB/s]
44%|████▎ | 1.93M/4.42M [00:00<00:00, 4.05MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.04MB/s]
0%| | 0.00/5.15k [00:00<?, ?B/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 66.8MB/s]
TensorClass 是一种数据类 (dataclass),它像 TensorDict
一样,提供了专门的张量方法来操作其内容。当您想存储的数据结构固定且可预测时,TensorClass 是一个不错的选择。
除了指定内容,我们还可以在定义类时,将相关逻辑封装为自定义方法。在本例中,我们将编写一个 from_dataset
类方法,该方法接受数据集作为输入,并创建一个包含数据集数据的 TensorClass。我们创建内存映射张量 (memory-mapped tensors) 来保存数据。这将使我们能够高效地从磁盘加载批次转换后的数据,而不是重复加载和转换单个图像。
@tensorclass
class FashionMNISTData:
images: torch.Tensor
targets: torch.Tensor
@classmethod
def from_dataset(cls, dataset, device=None):
data = cls(
images=MemoryMappedTensor.empty(
(len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
),
targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
batch_size=[len(dataset)],
device=device,
)
for i, (image, target) in enumerate(dataset):
data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
return data
我们将创建两个 TensorClass,分别用于训练和测试数据。请注意,由于我们需要遍历整个数据集,对其进行转换并保存到磁盘,因此这里会产生一些开销。
数据加载器 (DataLoaders)¶
我们将从 torchvision
提供的数据集以及我们的内存映射 TensorDict 创建数据加载器 (DataLoaders)。
由于 TensorDict
实现了 __len__
和 __getitem__
(以及 __getitems__
),我们可以像 map-style 数据集一样使用它,并直接从中创建 DataLoader
。请注意,由于 TensorDict
已经能够处理批量索引,因此不需要 collate,所以我们将恒等函数 (identity function) 作为 collate_fn
传递。
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size) # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size) # noqa: TOR401
train_dataloader_tc = DataLoader( # noqa: TOR401
training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader( # noqa: TOR401
test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
模型¶
我们使用与 快速入门教程 中相同的模型。
class Net(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
), Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
))
优化参数¶
我们将使用随机梯度下降 (stochastic gradient descent) 和交叉熵损失 (cross-entropy loss) 来优化模型的参数。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
基于 TensorClass 的数据加载器 (DataLoader) 的训练循环非常相似,我们只需调整如何解包数据,以适应 TensorClass 提供的更显式的基于属性的检索方式。.contiguous()
方法加载存储在内存映射张量 (memmap tensor) 中的数据。
def train_tc(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, data in enumerate(dataloader):
X, y = data.images.contiguous(), data.targets.contiguous()
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
)
def test_tc(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for batch in dataloader:
X, y = batch.images.contiguous(), batch.targets.contiguous()
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
)
for d in train_dataloader_tc:
print(d)
break
import time
t0 = time.time()
epochs = 5
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------")
train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
batch_size=torch.Size([64]),
device=cpu,
is_shared=False)
Epoch 1
-------------------------
loss: 2.303174 [ 0/60000]
loss: 2.292315 [ 6400/60000]
loss: 2.276398 [12800/60000]
loss: 2.266935 [19200/60000]
loss: 2.243844 [25600/60000]
loss: 2.225336 [32000/60000]
loss: 2.217357 [38400/60000]
loss: 2.195517 [44800/60000]
loss: 2.190626 [51200/60000]
loss: 2.155467 [57600/60000]
Test Error:
Accuracy: 50.9%, Avg loss: 2.152797
Epoch 2
-------------------------
loss: 2.161741 [ 0/60000]
loss: 2.150906 [ 6400/60000]
loss: 2.099977 [12800/60000]
loss: 2.111458 [19200/60000]
loss: 2.054362 [25600/60000]
loss: 2.009193 [32000/60000]
loss: 2.014337 [38400/60000]
loss: 1.949405 [44800/60000]
loss: 1.948474 [51200/60000]
loss: 1.877130 [57600/60000]
Test Error:
Accuracy: 54.6%, Avg loss: 1.879634
Epoch 3
-------------------------
loss: 1.910757 [ 0/60000]
loss: 1.880123 [ 6400/60000]
loss: 1.772925 [12800/60000]
loss: 1.808776 [19200/60000]
loss: 1.692901 [25600/60000]
loss: 1.657541 [32000/60000]
loss: 1.662243 [38400/60000]
loss: 1.577471 [44800/60000]
loss: 1.601571 [51200/60000]
loss: 1.500768 [57600/60000]
Test Error:
Accuracy: 59.0%, Avg loss: 1.518311
Epoch 4
-------------------------
loss: 1.585062 [ 0/60000]
loss: 1.546511 [ 6400/60000]
loss: 1.407458 [12800/60000]
loss: 1.477034 [19200/60000]
loss: 1.352650 [25600/60000]
loss: 1.358121 [32000/60000]
loss: 1.363586 [38400/60000]
loss: 1.295179 [44800/60000]
loss: 1.331715 [51200/60000]
loss: 1.239219 [57600/60000]
Test Error:
Accuracy: 62.6%, Avg loss: 1.260449
Epoch 5
-------------------------
loss: 1.337718 [ 0/60000]
loss: 1.312946 [ 6400/60000]
loss: 1.158820 [12800/60000]
loss: 1.262185 [19200/60000]
loss: 1.131436 [25600/60000]
loss: 1.163135 [32000/60000]
loss: 1.178597 [38400/60000]
loss: 1.118884 [44800/60000]
loss: 1.159753 [51200/60000]
loss: 1.082628 [57600/60000]
Test Error:
Accuracy: 64.1%, Avg loss: 1.099033
Tensorclass training done! time: 8.5422 s
Epoch 1
-------------------------
loss: 2.308042 [ 0/60000]
loss: 2.298845 [ 6400/60000]
loss: 2.271829 [12800/60000]
loss: 2.261400 [19200/60000]
loss: 2.251647 [25600/60000]
loss: 2.210075 [32000/60000]
loss: 2.231766 [38400/60000]
loss: 2.188506 [44800/60000]
loss: 2.190940 [51200/60000]
loss: 2.149179 [57600/60000]
Test Error:
Accuracy: 33.4%, Avg loss: 2.146284
Epoch 2
-------------------------
loss: 2.159867 [ 0/60000]
loss: 2.151524 [ 6400/60000]
loss: 2.084647 [12800/60000]
loss: 2.098223 [19200/60000]
loss: 2.046152 [25600/60000]
loss: 1.983066 [32000/60000]
loss: 2.016088 [38400/60000]
loss: 1.934383 [44800/60000]
loss: 1.948664 [51200/60000]
loss: 1.854755 [57600/60000]
Test Error:
Accuracy: 59.6%, Avg loss: 1.860862
Epoch 3
-------------------------
loss: 1.902126 [ 0/60000]
loss: 1.867171 [ 6400/60000]
loss: 1.745260 [12800/60000]
loss: 1.780290 [19200/60000]
loss: 1.670075 [25600/60000]
loss: 1.632182 [32000/60000]
loss: 1.651149 [38400/60000]
loss: 1.564024 [44800/60000]
loss: 1.587803 [51200/60000]
loss: 1.465174 [57600/60000]
Test Error:
Accuracy: 62.5%, Avg loss: 1.494086
Epoch 4
-------------------------
loss: 1.568314 [ 0/60000]
loss: 1.531687 [ 6400/60000]
loss: 1.383616 [12800/60000]
loss: 1.444215 [19200/60000]
loss: 1.326217 [25600/60000]
loss: 1.338029 [32000/60000]
loss: 1.343920 [38400/60000]
loss: 1.286051 [44800/60000]
loss: 1.311489 [51200/60000]
loss: 1.203263 [57600/60000]
Test Error:
Accuracy: 64.0%, Avg loss: 1.235341
Epoch 5
-------------------------
loss: 1.313958 [ 0/60000]
loss: 1.298989 [ 6400/60000]
loss: 1.134868 [12800/60000]
loss: 1.230290 [19200/60000]
loss: 1.102433 [25600/60000]
loss: 1.143659 [32000/60000]
loss: 1.156451 [38400/60000]
loss: 1.113082 [44800/60000]
loss: 1.140398 [51200/60000]
loss: 1.050202 [57600/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 1.076012
Training done! time: 34.3460 s
脚本总运行时间: (1 分 2.480 秒)