Using tensorclasses for datasets

In this tutorial we demonstrate how tensorclasses can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart Tutorial, but modified to demonstrate the use of a tensorclass. See the related tutorial that uses TensorDict.

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we'll use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the goal is to classify the type of clothing in the image (e.g. "Bag", "Sneaker" etc.).

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:01<00:00, 19333118.23it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 29515/29515 [00:00<00:00, 324628.70it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 4422102/4422102 [00:00<00:00, 6087320.72it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 5148/5148 [00:00<00:00, 44065871.41it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
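As a quick sanity check (this snippet is not part of the original tutorial), each sample yielded by the dataset is a (1, 28, 28) float tensor together with an integer class index; the human-readable class names live in the dataset's classes attribute.

# Illustrative only: inspect one sample and its label.
image, target = training_data[0]
print(image.shape)                    # torch.Size([1, 28, 28])
print(target)                         # an integer class index, e.g. 9
print(training_data.classes[target])  # e.g. "Ankle boot"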

Tensorclasses are dataclasses that, much like TensorDict, expose dedicated tensor methods over their contents. They are a good choice when the structure of the data you want to store is fixed and predictable.

As well as specifying the contents, we can encapsulate related logic as custom methods when defining the class. In this case we'll write a from_dataset classmethod that takes a dataset as input and creates a tensorclass containing the data from the dataset. We create memory-mapped tensors to hold the data. This will allow us to efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.

@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data

We'll create two tensorclasses, one each for the training and test data. Note that we incur some overhead here as we are looping over the entire dataset, transforming it and saving to disk.

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)
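To illustrate what the decorator buys us (these calls are not in the original tutorial, but follow standard tensorclass behaviour), the resulting objects act as batched containers: fields are reached by attribute, and indexing along the batch dimension returns new FashionMNISTData instances.

# Illustrative: attribute access and batch-dimension slicing.
print(training_data_tc.images.shape)   # torch.Size([60000, 28, 28])
print(training_data_tc.targets.shape)  # torch.Size([60000])
subset = training_data_tc[:5]          # a FashionMNISTData with batch_size [5]
print(subset.batch_size)               # torch.Size([5])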

DataLoaders

We'll create DataLoaders from the torchvision-provided datasets, as well as from our memory-mapped tensorclasses.

Since a TensorDict implements __len__ and __getitem__ (as well as __getitems__), we can use it like a map-style dataset and create a DataLoader directly from it. Note that because a TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
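To make the no-collation point concrete, here is a small sketch (not from the original tutorial) of the batched indexing the DataLoader relies on: passing several indices at once returns a single, already-stacked batch, so there is nothing left for a collate function to do.

# Illustrative: one call with several indices yields one stacked batch.
batch = training_data_tc[torch.tensor([0, 1, 2, 3])]
print(batch.batch_size)  # torch.Size([4])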

Model

We use the same model from the Quickstart Tutorial.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))
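As a quick smoke test (not part of the original tutorial), we can push one tensorclass batch through the network; .contiguous() first materialises the memory-mapped slice as a regular tensor.

# Illustrative: forward one 64-sample batch through the untrained model.
with torch.no_grad():
    batch = training_data_tc[:64]
    logits = model_tc(batch.images.contiguous().to(device))
    print(logits.shape)  # torch.Size([64, 10])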

Optimizing parameters

We'll optimise the parameters of the model using stochastic gradient descent and cross-entropy loss.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

The training loop for our tensorclass-based DataLoader is very similar; we just adjust how we unpack the data, using the more explicit attribute-based retrieval offered by the tensorclass. The .contiguous() method loads the data stored in the memmap tensor.

def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_tc:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.288561 [    0/60000]
loss: 2.280895 [ 6400/60000]
loss: 2.260488 [12800/60000]
loss: 2.261184 [19200/60000]
loss: 2.244328 [25600/60000]
loss: 2.213650 [32000/60000]
loss: 2.227577 [38400/60000]
loss: 2.185061 [44800/60000]
loss: 2.176095 [51200/60000]
loss: 2.154058 [57600/60000]
Test Error:
 Accuracy: 45.2%, Avg loss: 2.143079

Epoch 2
-------------------------
loss: 2.145867 [    0/60000]
loss: 2.134690 [ 6400/60000]
loss: 2.077907 [12800/60000]
loss: 2.099887 [19200/60000]
loss: 2.028121 [25600/60000]
loss: 1.979612 [32000/60000]
loss: 2.014188 [38400/60000]
loss: 1.926303 [44800/60000]
loss: 1.923347 [51200/60000]
loss: 1.852802 [57600/60000]
Test Error:
 Accuracy: 58.4%, Avg loss: 1.850140

Epoch 3
-------------------------
loss: 1.879460 [    0/60000]
loss: 1.846868 [ 6400/60000]
loss: 1.731937 [12800/60000]
loss: 1.779482 [19200/60000]
loss: 1.638177 [25600/60000]
loss: 1.615429 [32000/60000]
loss: 1.645688 [38400/60000]
loss: 1.544954 [44800/60000]
loss: 1.557245 [51200/60000]
loss: 1.453836 [57600/60000]
Test Error:
 Accuracy: 61.8%, Avg loss: 1.477857

Epoch 4
-------------------------
loss: 1.541883 [    0/60000]
loss: 1.510409 [ 6400/60000]
loss: 1.365352 [12800/60000]
loss: 1.439685 [19200/60000]
loss: 1.297079 [25600/60000]
loss: 1.318537 [32000/60000]
loss: 1.338518 [38400/60000]
loss: 1.265177 [44800/60000]
loss: 1.287562 [51200/60000]
loss: 1.188529 [57600/60000]
Test Error:
 Accuracy: 63.3%, Avg loss: 1.221189

Epoch 5
-------------------------
loss: 1.295978 [    0/60000]
loss: 1.282051 [ 6400/60000]
loss: 1.118315 [12800/60000]
loss: 1.224885 [19200/60000]
loss: 1.083846 [25600/60000]
loss: 1.126744 [32000/60000]
loss: 1.154716 [38400/60000]
loss: 1.093256 [44800/60000]
loss: 1.121146 [51200/60000]
loss: 1.036791 [57600/60000]
Test Error:
 Accuracy: 64.5%, Avg loss: 1.064607

Tensorclass training done! time:  8.3674 s
Epoch 1
-------------------------
loss: 2.304507 [    0/60000]
loss: 2.288714 [ 6400/60000]
loss: 2.274979 [12800/60000]
loss: 2.271725 [19200/60000]
loss: 2.259416 [25600/60000]
loss: 2.237974 [32000/60000]
loss: 2.235101 [38400/60000]
loss: 2.205691 [44800/60000]
loss: 2.196805 [51200/60000]
loss: 2.176510 [57600/60000]
Test Error:
 Accuracy: 52.8%, Avg loss: 2.164095

Epoch 2
-------------------------
loss: 2.169890 [    0/60000]
loss: 2.154640 [ 6400/60000]
loss: 2.102692 [12800/60000]
loss: 2.120185 [19200/60000]
loss: 2.078767 [25600/60000]
loss: 2.023822 [32000/60000]
loss: 2.041349 [38400/60000]
loss: 1.964761 [44800/60000]
loss: 1.962162 [51200/60000]
loss: 1.902772 [57600/60000]
Test Error:
 Accuracy: 53.7%, Avg loss: 1.894101

Epoch 3
-------------------------
loss: 1.924948 [    0/60000]
loss: 1.887992 [ 6400/60000]
loss: 1.775490 [12800/60000]
loss: 1.819449 [19200/60000]
loss: 1.725564 [25600/60000]
loss: 1.675467 [32000/60000]
loss: 1.698622 [38400/60000]
loss: 1.600037 [44800/60000]
loss: 1.624712 [51200/60000]
loss: 1.531494 [57600/60000]
Test Error:
 Accuracy: 59.1%, Avg loss: 1.541386

Epoch 4
-------------------------
loss: 1.607657 [    0/60000]
loss: 1.564384 [ 6400/60000]
loss: 1.421606 [12800/60000]
loss: 1.498601 [19200/60000]
loss: 1.396880 [25600/60000]
loss: 1.383857 [32000/60000]
loss: 1.402771 [38400/60000]
loss: 1.327729 [44800/60000]
loss: 1.359327 [51200/60000]
loss: 1.267020 [57600/60000]
Test Error:
 Accuracy: 62.7%, Avg loss: 1.287844

Epoch 5
-------------------------
loss: 1.365369 [    0/60000]
loss: 1.336154 [ 6400/60000]
loss: 1.178401 [12800/60000]
loss: 1.284765 [19200/60000]
loss: 1.172344 [25600/60000]
loss: 1.189937 [32000/60000]
loss: 1.211484 [38400/60000]
loss: 1.151805 [44800/60000]
loss: 1.183774 [51200/60000]
loss: 1.102122 [57600/60000]
Test Error:
 Accuracy: 64.5%, Avg loss: 1.121291

Training done! time:  34.7098 s

Total running time of the script: (0 minutes 58.406 seconds)
