• 文档 >
  • 使用 TensorDict 处理数据集
快捷方式

使用 TensorDict 处理数据集

在本教程中,我们将演示如何使用 TensorDict 在训练管道中高效且透明地加载和管理数据。本教程主要基于 PyTorch 快速入门教程,但进行了修改以演示 TensorDict 的用法。

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
Using device: cpu

torchvision.datasets 模块包含许多方便的预准备数据集。在本教程中,我们将使用相对简单的 FashionMNIST 数据集。每个图像都是一件服装,目标是分类图像中的服装类型(例如,“包”、“运动鞋”等)。

# Build the training split first and the test split second. The two calls
# differ only in the ``train`` flag; each split gets its own ToTensor
# transform instance and both are stored under the "data" directory.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )
    for is_train in (True, False)
)

我们将创建两个 tensordict,分别用于训练数据和测试数据。我们创建内存映射张量来保存数据。这样我们就能高效地从磁盘成批加载已转换好的数据,而不必反复加载并转换单张图像。

首先,我们创建 MemoryMappedTensor 容器。

# Allocate one empty memory-mapped container per split (training first, then
# test). "images" holds one float32 entry per sample, shaped like the first
# sample's image after ``squeeze()``; "targets" holds one int64 label per
# sample. The batch size of each tensordict is the number of samples.
training_data_td, test_data_td = (
    TensorDict(
        {
            "images": MemoryMappedTensor.empty(
                (len(split), *split[0][0].squeeze().shape),
                dtype=torch.float32,
            ),
            "targets": MemoryMappedTensor.empty((len(split),), dtype=torch.int64),
        },
        batch_size=[len(split)],
        device=device,
    )
    for split in (training_data, test_data)
)

然后,我们可以遍历数据来填充内存映射张量。这需要一些时间,但预先完成转换可以避免在之后的训练过程中重复执行这些工作。

# Copy every (image, label) sample into row ``i`` of the matching tensordict:
# the training split is filled first, then the test split.
for split, split_td in zip(
    (training_data, test_data), (training_data_td, test_data_td)
):
    for i, (img, label) in enumerate(split):
        # A single sample has no batch dimension, hence the empty batch size.
        split_td[i] = TensorDict({"images": img, "targets": label}, batch_size=[])

数据加载器

我们将从 torchvision 提供的数据集以及我们的内存映射 TensorDict 创建数据加载器。

由于 TensorDict 实现了 __len__ 和 __getitem__(以及 __getitems__),因此我们可以像使用映射风格(map-style)的数据集一样使用它,并直接从中创建 DataLoader。请注意,由于 TensorDict 本身已经可以处理批量索引,因此不需要再做整理(collation),所以我们将恒等函数作为 collate_fn 传入。

# Number of samples per batch, shared by all four loaders below.
batch_size = 64

# Loaders over the torchvision datasets; no collate_fn is given, so
# DataLoader's default collation is used.
train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

# Loaders over the tensordicts. ``collate_fn`` is the identity function, so
# whatever the DataLoader fetches from the tensordict is yielded unchanged.
train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)

模型

我们使用与快速入门教程中相同的模型。

class Net(nn.Module):
    """Three-layer fully connected classifier.

    The input is flattened from dimension 1 onwards into vectors of
    ``28 * 28`` features, passed through two 512-unit hidden layers with ReLU
    activations, and projected to 10 output logits.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden, num_classes = 28 * 28, 512, 10

        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return the logits produced for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))


# Two separately constructed instances of the same architecture: ``model`` is
# trained with the regular DataLoader and ``model_td`` with the
# TensorDict-backed one.
model = Net().to(device)
model_td = Net().to(device)
# Bare expression: has no effect in a plain script; presumably kept so that
# an interactive/notebook cell displays both module reprs.
model, model_td
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

优化参数

我们将使用随机梯度下降和交叉熵损失来优化模型的参数。

# Cross-entropy loss object, reused by both training runs.
loss_fn = nn.CrossEntropyLoss()
# One SGD optimizer per model, both with the same learning rate (1e-3), so
# each model's parameters are updated independently.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Each batch is unpacked as an ``(inputs, labels)`` pair, moved to the
    module-level ``device`` and used for one optimizer step. The loss is
    printed on every 100th batch, starting with the first.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches that exposes a
            sized ``dataset`` attribute.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable mapping ``(predictions, labels)`` to a scalar loss.
        optimizer: Optimizer that is zeroed and stepped once per batch.
    """
    size = len(dataloader.dataset)
    model.train()

    for step, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        batch_loss = loss_fn(model(inputs), labels)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100:
            continue
        # Progress report: batch index times the size of the current batch.
        loss, current = batch_loss.item(), step * len(inputs)
        print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

基于 TensorDict 的 DataLoader 的训练循环与之非常相似,我们只是把数据的解包方式改成了 TensorDict 提供的、更明确的基于键的检索方式。.contiguous() 方法会加载存储在内存映射张量中的数据。

def train_td(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over a loader that yields key-indexed batches.

    Each batch is read through its ``"images"`` and ``"targets"`` entries,
    with ``.contiguous()`` called on both, and used for one optimizer step.
    The loss is printed on every 100th batch, starting with the first.

    Args:
        dataloader: Iterable of batches supporting ``batch["images"]`` and
            ``batch["targets"]``, exposing a sized ``dataset`` attribute.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable mapping ``(predictions, labels)`` to a scalar loss.
        optimizer: Optimizer that is zeroed and stepped once per batch.
    """
    size = len(dataloader.dataset)
    model.train()

    for step, data in enumerate(dataloader):
        images = data["images"].contiguous()
        labels = data["targets"].contiguous()

        batch_loss = loss_fn(model(images), labels)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100:
            continue
        # Progress report: batch index times the size of the current batch.
        loss, current = batch_loss.item(), step * len(images)
        print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    Batches are unpacked as ``(inputs, labels)`` pairs and moved to the
    module-level ``device``. Accuracy is the number of correct argmax
    predictions divided by ``len(dataloader.dataset)``; the reported loss is
    the sum of per-batch losses divided by the number of batches.

    Args:
        dataloader: Sized iterable of ``(inputs, labels)`` batches exposing a
            sized ``dataset`` attribute.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_fn: Callable mapping ``(predictions, labels)`` to a scalar loss.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    loss_total, hits = 0, 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)

            loss_total += loss_fn(logits, labels).item()
            hits += (logits.argmax(1) == labels).float().sum().item()

    test_loss = loss_total / num_batches
    correct = hits / size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


# Peek at the first batch yielded by the TensorDict-backed loader.
for d in train_dataloader_td:
    print(d)
    break

import time

# Time both training runs with time.perf_counter(): it is monotonic and meant
# for measuring elapsed intervals, whereas the wall clock returned by
# time.time() can jump if the system clock is adjusted mid-run.

# Run 1: TensorDict-backed loaders.
t0 = time.perf_counter()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.perf_counter() - t0: 4.4f} s")

# Run 2: regular torchvision-dataset loaders, same number of epochs.
t0 = time.perf_counter()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.perf_counter() - t0: 4.4f} s")
TensorDict(
    fields={
        images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.299852 [    0/60000]
loss: 2.292439 [ 6400/60000]
loss: 2.268839 [12800/60000]
loss: 2.262053 [19200/60000]
loss: 2.247478 [25600/60000]
loss: 2.213343 [32000/60000]
loss: 2.225055 [38400/60000]
loss: 2.192145 [44800/60000]
loss: 2.194392 [51200/60000]
loss: 2.160515 [57600/60000]
Test Error:
 Accuracy: 49.2%, Avg loss: 2.154577

Epoch 2
-------------------------
loss: 2.171067 [    0/60000]
loss: 2.158847 [ 6400/60000]
loss: 2.099328 [12800/60000]
loss: 2.108190 [19200/60000]
loss: 2.054464 [25600/60000]
loss: 2.002670 [32000/60000]
loss: 2.027874 [38400/60000]
loss: 1.954475 [44800/60000]
loss: 1.964767 [51200/60000]
loss: 1.876978 [57600/60000]
Test Error:
 Accuracy: 54.4%, Avg loss: 1.879585

Epoch 3
-------------------------
loss: 1.924535 [    0/60000]
loss: 1.888327 [ 6400/60000]
loss: 1.768981 [12800/60000]
loss: 1.797174 [19200/60000]
loss: 1.680208 [25600/60000]
loss: 1.649336 [32000/60000]
loss: 1.665631 [38400/60000]
loss: 1.577420 [44800/60000]
loss: 1.605635 [51200/60000]
loss: 1.486382 [57600/60000]
Test Error:
 Accuracy: 60.4%, Avg loss: 1.508037

Epoch 4
-------------------------
loss: 1.586136 [    0/60000]
loss: 1.546187 [ 6400/60000]
loss: 1.395918 [12800/60000]
loss: 1.455341 [19200/60000]
loss: 1.334505 [25600/60000]
loss: 1.347199 [32000/60000]
loss: 1.356066 [38400/60000]
loss: 1.290384 [44800/60000]
loss: 1.322476 [51200/60000]
loss: 1.219003 [57600/60000]
Test Error:
 Accuracy: 63.2%, Avg loss: 1.243664

Epoch 5
-------------------------
loss: 1.327081 [    0/60000]
loss: 1.306844 [ 6400/60000]
loss: 1.139891 [12800/60000]
loss: 1.237875 [19200/60000]
loss: 1.110981 [25600/60000]
loss: 1.149890 [32000/60000]
loss: 1.169009 [38400/60000]
loss: 1.112532 [44800/60000]
loss: 1.147224 [51200/60000]
loss: 1.064276 [57600/60000]
Test Error:
 Accuracy: 64.3%, Avg loss: 1.081395

TensorDict training done! time:  8.4377 s
Epoch 1
-------------------------
loss: 2.316761 [    0/60000]
loss: 2.298437 [ 6400/60000]
loss: 2.284247 [12800/60000]
loss: 2.269306 [19200/60000]
loss: 2.255049 [25600/60000]
loss: 2.231190 [32000/60000]
loss: 2.229480 [38400/60000]
loss: 2.200073 [44800/60000]
loss: 2.197099 [51200/60000]
loss: 2.164955 [57600/60000]
Test Error:
 Accuracy: 47.7%, Avg loss: 2.159026

Epoch 2
-------------------------
loss: 2.171586 [    0/60000]
loss: 2.163805 [ 6400/60000]
loss: 2.108787 [12800/60000]
loss: 2.121653 [19200/60000]
loss: 2.070410 [25600/60000]
loss: 2.013198 [32000/60000]
loss: 2.043009 [38400/60000]
loss: 1.962271 [44800/60000]
loss: 1.960111 [51200/60000]
loss: 1.905308 [57600/60000]
Test Error:
 Accuracy: 59.3%, Avg loss: 1.894151

Epoch 3
-------------------------
loss: 1.918365 [    0/60000]
loss: 1.898802 [ 6400/60000]
loss: 1.781930 [12800/60000]
loss: 1.828506 [19200/60000]
loss: 1.715667 [25600/60000]
loss: 1.665848 [32000/60000]
loss: 1.698011 [38400/60000]
loss: 1.590174 [44800/60000]
loss: 1.609489 [51200/60000]
loss: 1.523170 [57600/60000]
Test Error:
 Accuracy: 61.4%, Avg loss: 1.527936

Epoch 4
-------------------------
loss: 1.583255 [    0/60000]
loss: 1.557144 [ 6400/60000]
loss: 1.408821 [12800/60000]
loss: 1.489391 [19200/60000]
loss: 1.365822 [25600/60000]
loss: 1.360586 [32000/60000]
loss: 1.382201 [38400/60000]
loss: 1.299545 [44800/60000]
loss: 1.336339 [51200/60000]
loss: 1.245674 [57600/60000]
Test Error:
 Accuracy: 63.0%, Avg loss: 1.262910

Epoch 5
-------------------------
loss: 1.334715 [    0/60000]
loss: 1.318874 [ 6400/60000]
loss: 1.160190 [12800/60000]
loss: 1.264431 [19200/60000]
loss: 1.142190 [25600/60000]
loss: 1.166963 [32000/60000]
loss: 1.186403 [38400/60000]
loss: 1.120692 [44800/60000]
loss: 1.164268 [51200/60000]
loss: 1.083194 [57600/60000]
Test Error:
 Accuracy: 64.5%, Avg loss: 1.097534

Training done! time:  34.8765 s

脚本的总运行时间:(0 分 53.316 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源