• 文件 >
  • 使用 TensorClass 處理資料集
快捷方式

使用 TensorClass 處理資料集

在本教程中,我們將演示如何使用 TensorClass 高效且透明地在訓練管道中載入和管理資料。本教程大量參考了 PyTorch 快速入門教程,但進行了修改以演示 TensorClass 的使用。請參閱相關的 TensorDict 使用教程。

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

torchvision.datasets 模組包含許多方便的預處理資料集。在本教程中,我們將使用相對簡單的 FashionMNIST 資料集。每張影像都是一件衣服,目標是分類影像中的服裝型別(例如,“包袋”、“運動鞋”等)。

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
  0%|          | 0.00/26.4M [00:00<?, ?B/s]
  0%|          | 65.5k/26.4M [00:00<01:12, 363kB/s]
  1%|          | 229k/26.4M [00:00<00:37, 693kB/s]
  4%|▎         | 950k/26.4M [00:00<00:11, 2.21MB/s]
  9%|▉         | 2.49M/26.4M [00:00<00:04, 5.73MB/s]
 13%|█▎        | 3.34M/26.4M [00:00<00:04, 5.36MB/s]
 18%|█▊        | 4.82M/26.4M [00:00<00:03, 6.40MB/s]
 24%|██▍       | 6.36M/26.4M [00:01<00:02, 7.13MB/s]
 30%|██▉       | 7.90M/26.4M [00:01<00:02, 7.60MB/s]
 36%|███▌      | 9.50M/26.4M [00:01<00:02, 8.01MB/s]
 42%|████▏     | 11.1M/26.4M [00:01<00:01, 8.35MB/s]
 47%|████▋     | 12.5M/26.4M [00:01<00:01, 9.49MB/s]
 52%|█████▏    | 13.6M/26.4M [00:01<00:01, 8.45MB/s]
 58%|█████▊    | 15.3M/26.4M [00:02<00:01, 8.78MB/s]
 63%|██████▎   | 16.8M/26.4M [00:02<00:00, 10.0MB/s]
 68%|██████▊   | 17.9M/26.4M [00:02<00:00, 8.85MB/s]
 74%|███████▍  | 19.7M/26.4M [00:02<00:00, 10.7MB/s]
 79%|███████▉  | 20.9M/26.4M [00:02<00:00, 9.39MB/s]
 85%|████████▍ | 22.4M/26.4M [00:02<00:00, 9.07MB/s]
 92%|█████████▏| 24.2M/26.4M [00:03<00:00, 9.43MB/s]
 99%|█████████▊| 26.1M/26.4M [00:03<00:00, 9.67MB/s]
100%|██████████| 26.4M/26.4M [00:03<00:00, 8.11MB/s]

  0%|          | 0.00/29.5k [00:00<?, ?B/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 328kB/s]

  0%|          | 0.00/4.42M [00:00<?, ?B/s]
  1%|▏         | 65.5k/4.42M [00:00<00:12, 360kB/s]
  5%|▌         | 229k/4.42M [00:00<00:06, 679kB/s]
 21%|██        | 918k/4.42M [00:00<00:01, 2.60MB/s]
 44%|████▎     | 1.93M/4.42M [00:00<00:00, 4.05MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.04MB/s]

  0%|          | 0.00/5.15k [00:00<?, ?B/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 66.8MB/s]

TensorClass 是一種資料類 (dataclass),它像 TensorDict 一樣,提供了專門的張量方法來操作其內容。當您想儲存的資料結構固定且可預測時,TensorClass 是一個不錯的選擇。

除了指定內容,我們還可以在定義類時,將相關邏輯封裝為自定義方法。在本例中,我們將編寫一個 from_dataset 類方法,該方法接受資料集作為輸入,並建立一個包含資料集資料的 TensorClass。我們建立記憶體對映張量 (memory-mapped tensors) 來儲存資料。這將使我們能夠高效地從磁碟載入批次轉換後的資料,而不是重複載入和轉換單個影像。

@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data

我們將建立兩個 TensorClass,分別用於訓練和測試資料。請注意,由於我們需要遍歷整個資料集,對其進行轉換並儲存到磁碟,因此這裡會產生一些開銷。

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)

資料載入器 (DataLoaders)

我們將從 torchvision 提供的資料集以及我們的記憶體對映 TensorDict 建立資料載入器 (DataLoaders)。

由於 TensorDict 實現了 __len____getitem__ (以及 __getitems__),我們可以像 map-style 資料集一樣使用它,並直接從中建立 DataLoader。請注意,由於 TensorDict 已經能夠處理批次索引,因此不需要 collate,所以我們將恆等函式 (identity function) 作為 collate_fn 傳遞。

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)

模型

我們使用與 快速入門教程 中相同的模型。

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

最佳化引數

我們將使用隨機梯度下降 (stochastic gradient descent) 和交叉熵損失 (cross-entropy loss) 來最佳化模型的引數。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

基於 TensorClass 的資料載入器 (DataLoader) 的訓練迴圈非常相似,我們只需調整如何解包資料,以適應 TensorClass 提供的更顯式的基於屬性的檢索方式。.contiguous() 方法載入儲存在記憶體對映張量 (memmap tensor) 中的資料。

def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_tc:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.303174 [    0/60000]
loss: 2.292315 [ 6400/60000]
loss: 2.276398 [12800/60000]
loss: 2.266935 [19200/60000]
loss: 2.243844 [25600/60000]
loss: 2.225336 [32000/60000]
loss: 2.217357 [38400/60000]
loss: 2.195517 [44800/60000]
loss: 2.190626 [51200/60000]
loss: 2.155467 [57600/60000]
Test Error:
 Accuracy: 50.9%, Avg loss: 2.152797

Epoch 2
-------------------------
loss: 2.161741 [    0/60000]
loss: 2.150906 [ 6400/60000]
loss: 2.099977 [12800/60000]
loss: 2.111458 [19200/60000]
loss: 2.054362 [25600/60000]
loss: 2.009193 [32000/60000]
loss: 2.014337 [38400/60000]
loss: 1.949405 [44800/60000]
loss: 1.948474 [51200/60000]
loss: 1.877130 [57600/60000]
Test Error:
 Accuracy: 54.6%, Avg loss: 1.879634

Epoch 3
-------------------------
loss: 1.910757 [    0/60000]
loss: 1.880123 [ 6400/60000]
loss: 1.772925 [12800/60000]
loss: 1.808776 [19200/60000]
loss: 1.692901 [25600/60000]
loss: 1.657541 [32000/60000]
loss: 1.662243 [38400/60000]
loss: 1.577471 [44800/60000]
loss: 1.601571 [51200/60000]
loss: 1.500768 [57600/60000]
Test Error:
 Accuracy: 59.0%, Avg loss: 1.518311

Epoch 4
-------------------------
loss: 1.585062 [    0/60000]
loss: 1.546511 [ 6400/60000]
loss: 1.407458 [12800/60000]
loss: 1.477034 [19200/60000]
loss: 1.352650 [25600/60000]
loss: 1.358121 [32000/60000]
loss: 1.363586 [38400/60000]
loss: 1.295179 [44800/60000]
loss: 1.331715 [51200/60000]
loss: 1.239219 [57600/60000]
Test Error:
 Accuracy: 62.6%, Avg loss: 1.260449

Epoch 5
-------------------------
loss: 1.337718 [    0/60000]
loss: 1.312946 [ 6400/60000]
loss: 1.158820 [12800/60000]
loss: 1.262185 [19200/60000]
loss: 1.131436 [25600/60000]
loss: 1.163135 [32000/60000]
loss: 1.178597 [38400/60000]
loss: 1.118884 [44800/60000]
loss: 1.159753 [51200/60000]
loss: 1.082628 [57600/60000]
Test Error:
 Accuracy: 64.1%, Avg loss: 1.099033

Tensorclass training done! time:  8.5422 s
Epoch 1
-------------------------
loss: 2.308042 [    0/60000]
loss: 2.298845 [ 6400/60000]
loss: 2.271829 [12800/60000]
loss: 2.261400 [19200/60000]
loss: 2.251647 [25600/60000]
loss: 2.210075 [32000/60000]
loss: 2.231766 [38400/60000]
loss: 2.188506 [44800/60000]
loss: 2.190940 [51200/60000]
loss: 2.149179 [57600/60000]
Test Error:
 Accuracy: 33.4%, Avg loss: 2.146284

Epoch 2
-------------------------
loss: 2.159867 [    0/60000]
loss: 2.151524 [ 6400/60000]
loss: 2.084647 [12800/60000]
loss: 2.098223 [19200/60000]
loss: 2.046152 [25600/60000]
loss: 1.983066 [32000/60000]
loss: 2.016088 [38400/60000]
loss: 1.934383 [44800/60000]
loss: 1.948664 [51200/60000]
loss: 1.854755 [57600/60000]
Test Error:
 Accuracy: 59.6%, Avg loss: 1.860862

Epoch 3
-------------------------
loss: 1.902126 [    0/60000]
loss: 1.867171 [ 6400/60000]
loss: 1.745260 [12800/60000]
loss: 1.780290 [19200/60000]
loss: 1.670075 [25600/60000]
loss: 1.632182 [32000/60000]
loss: 1.651149 [38400/60000]
loss: 1.564024 [44800/60000]
loss: 1.587803 [51200/60000]
loss: 1.465174 [57600/60000]
Test Error:
 Accuracy: 62.5%, Avg loss: 1.494086

Epoch 4
-------------------------
loss: 1.568314 [    0/60000]
loss: 1.531687 [ 6400/60000]
loss: 1.383616 [12800/60000]
loss: 1.444215 [19200/60000]
loss: 1.326217 [25600/60000]
loss: 1.338029 [32000/60000]
loss: 1.343920 [38400/60000]
loss: 1.286051 [44800/60000]
loss: 1.311489 [51200/60000]
loss: 1.203263 [57600/60000]
Test Error:
 Accuracy: 64.0%, Avg loss: 1.235341

Epoch 5
-------------------------
loss: 1.313958 [    0/60000]
loss: 1.298989 [ 6400/60000]
loss: 1.134868 [12800/60000]
loss: 1.230290 [19200/60000]
loss: 1.102433 [25600/60000]
loss: 1.143659 [32000/60000]
loss: 1.156451 [38400/60000]
loss: 1.113082 [44800/60000]
loss: 1.140398 [51200/60000]
loss: 1.050202 [57600/60000]
Test Error:
 Accuracy: 65.1%, Avg loss: 1.076012

Training done! time:  34.3460 s

指令碼總執行時間: (1 分 2.480 秒)

由 Sphinx-Gallery 生成的相簿

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發者資源並獲得解答

檢視資源