Using tensorclasses for datasets
In this tutorial we demonstrate how tensorclasses can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart Tutorial, modified to demonstrate the use of tensorclasses. See also the related tutorial that uses TensorDict.
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
torchvision.datasets 模組包含許多方便的預處理資料集。在本教程中,我們將使用相對簡單的 FashionMNIST 資料集。每張影像都是一件衣服,目標是分類影像中的服裝型別(例如,“包袋”、“運動鞋”等)。
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
A tensorclass is a dataclass that, like a TensorDict, exposes dedicated tensor methods over its contents. Tensorclasses are a good choice when the structure of the data you want to store is fixed and predictable.
As well as specifying the contents, we can encapsulate related logic as custom methods when defining the class. Here we write a from_dataset classmethod that takes a dataset as input and creates a tensorclass containing its data. We use memory-mapped tensors to hold the data, which lets us efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.
@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data
We create two tensorclasses, one each for the training and test data. Note that we incur some overhead here, as we loop over the entire dataset, transform it, and save it to disk.

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)
DataLoaders
We create DataLoaders from the torchvision-provided datasets, as well as from our memory-mapped tensorclasses.
Since TensorDict implements __len__ and __getitem__ (and also __getitems__), we can use it like a map-style dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.
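The mechanism can be sketched without tensordict. Assuming a recent PyTorch (2.x, where the DataLoader fetcher checks for `__getitems__`), any map-style dataset exposing `__getitems__` receives the whole list of batch indices in one call, so the identity collate_fn simply passes the already-batched result through. The class below is a hypothetical stand-in, not part of the tutorial:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class BatchIndexable(Dataset):
    """Hypothetical map-style dataset supporting batched indexing."""

    def __init__(self, images, targets):
        self.images, self.targets = images, targets

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return self.images[i], self.targets[i]

    def __getitems__(self, indices):
        # receives the full list of indices for a batch in one call
        return self.images[indices], self.targets[indices]


ds = BatchIndexable(torch.randn(100, 28, 28), torch.randint(0, 10, (100,)))
loader = DataLoader(ds, batch_size=32, collate_fn=lambda x: x)  # identity collate
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)
```

Because `__getitems__` returns whole tensors rather than a list of per-sample tuples, no per-sample stacking is needed and collation becomes a no-op.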
batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
Model
We use the same model as in the Quickstart Tutorial.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))
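As a quick sanity check (a self-contained sketch that mirrors the Net definition above, not part of the tutorial), the layer stack maps a batch of 28x28 images to 10 logits per image:

```python
import torch
import torch.nn as nn

# same layer stack as Net above, built inline so the snippet is self-contained
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)
logits = net(torch.randn(64, 28, 28))
print(logits.shape)  # one row of 10 class scores per image
```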
Optimizing parameters
We optimize the parameters of the model using stochastic gradient descent and cross-entropy loss.
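A small sketch of what CrossEntropyLoss expects: raw (unnormalized) logits and integer class indices. It applies log-softmax internally, which is why the model above returns logits rather than probabilities:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

loss_fn = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 0.1, 3.0]])  # raw scores, not probabilities
targets = torch.tensor([0, 2])  # integer class indices

loss = loss_fn(logits, targets)
# equivalent by hand: mean negative log-softmax at each target index
manual = -F.log_softmax(logits, dim=1)[torch.arange(2), targets].mean()
print(loss, manual)
```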
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
The training loop for our tensorclass-based DataLoader is very similar; we just adjust how we unpack the data to use the more explicit attribute-based retrieval offered by the tensorclass. The .contiguous() method loads the data stored in the memmap tensor.
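Independent of memory-mapping, .contiguous() returns a tensor laid out densely in memory, copying only when the input is not already dense; for a memory-mapped tensor this materializes the data in RAM. A generic sketch of the copy behaviour:

```python
import torch

t = torch.arange(12).reshape(3, 4)
col = t[:, 1]             # strided view into t's storage (step of 4, not dense)
dense = col.contiguous()  # copies the three elements into fresh, dense memory

print(col.is_contiguous(), dense.is_contiguous())
```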
def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
for d in train_dataloader_tc:
    print(d)
    break
import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.303174 [ 0/60000]
loss: 2.292315 [ 6400/60000]
loss: 2.276398 [12800/60000]
loss: 2.266935 [19200/60000]
loss: 2.243844 [25600/60000]
loss: 2.225336 [32000/60000]
loss: 2.217357 [38400/60000]
loss: 2.195517 [44800/60000]
loss: 2.190626 [51200/60000]
loss: 2.155467 [57600/60000]
Test Error:
Accuracy: 50.9%, Avg loss: 2.152797
Epoch 2
-------------------------
loss: 2.161741 [ 0/60000]
loss: 2.150906 [ 6400/60000]
loss: 2.099977 [12800/60000]
loss: 2.111458 [19200/60000]
loss: 2.054362 [25600/60000]
loss: 2.009193 [32000/60000]
loss: 2.014337 [38400/60000]
loss: 1.949405 [44800/60000]
loss: 1.948474 [51200/60000]
loss: 1.877130 [57600/60000]
Test Error:
Accuracy: 54.6%, Avg loss: 1.879634
Epoch 3
-------------------------
loss: 1.910757 [ 0/60000]
loss: 1.880123 [ 6400/60000]
loss: 1.772925 [12800/60000]
loss: 1.808776 [19200/60000]
loss: 1.692901 [25600/60000]
loss: 1.657541 [32000/60000]
loss: 1.662243 [38400/60000]
loss: 1.577471 [44800/60000]
loss: 1.601571 [51200/60000]
loss: 1.500768 [57600/60000]
Test Error:
Accuracy: 59.0%, Avg loss: 1.518311
Epoch 4
-------------------------
loss: 1.585062 [ 0/60000]
loss: 1.546511 [ 6400/60000]
loss: 1.407458 [12800/60000]
loss: 1.477034 [19200/60000]
loss: 1.352650 [25600/60000]
loss: 1.358121 [32000/60000]
loss: 1.363586 [38400/60000]
loss: 1.295179 [44800/60000]
loss: 1.331715 [51200/60000]
loss: 1.239219 [57600/60000]
Test Error:
Accuracy: 62.6%, Avg loss: 1.260449
Epoch 5
-------------------------
loss: 1.337718 [ 0/60000]
loss: 1.312946 [ 6400/60000]
loss: 1.158820 [12800/60000]
loss: 1.262185 [19200/60000]
loss: 1.131436 [25600/60000]
loss: 1.163135 [32000/60000]
loss: 1.178597 [38400/60000]
loss: 1.118884 [44800/60000]
loss: 1.159753 [51200/60000]
loss: 1.082628 [57600/60000]
Test Error:
Accuracy: 64.1%, Avg loss: 1.099033
Tensorclass training done! time: 8.5422 s
Epoch 1
-------------------------
loss: 2.308042 [ 0/60000]
loss: 2.298845 [ 6400/60000]
loss: 2.271829 [12800/60000]
loss: 2.261400 [19200/60000]
loss: 2.251647 [25600/60000]
loss: 2.210075 [32000/60000]
loss: 2.231766 [38400/60000]
loss: 2.188506 [44800/60000]
loss: 2.190940 [51200/60000]
loss: 2.149179 [57600/60000]
Test Error:
Accuracy: 33.4%, Avg loss: 2.146284
Epoch 2
-------------------------
loss: 2.159867 [ 0/60000]
loss: 2.151524 [ 6400/60000]
loss: 2.084647 [12800/60000]
loss: 2.098223 [19200/60000]
loss: 2.046152 [25600/60000]
loss: 1.983066 [32000/60000]
loss: 2.016088 [38400/60000]
loss: 1.934383 [44800/60000]
loss: 1.948664 [51200/60000]
loss: 1.854755 [57600/60000]
Test Error:
Accuracy: 59.6%, Avg loss: 1.860862
Epoch 3
-------------------------
loss: 1.902126 [ 0/60000]
loss: 1.867171 [ 6400/60000]
loss: 1.745260 [12800/60000]
loss: 1.780290 [19200/60000]
loss: 1.670075 [25600/60000]
loss: 1.632182 [32000/60000]
loss: 1.651149 [38400/60000]
loss: 1.564024 [44800/60000]
loss: 1.587803 [51200/60000]
loss: 1.465174 [57600/60000]
Test Error:
Accuracy: 62.5%, Avg loss: 1.494086
Epoch 4
-------------------------
loss: 1.568314 [ 0/60000]
loss: 1.531687 [ 6400/60000]
loss: 1.383616 [12800/60000]
loss: 1.444215 [19200/60000]
loss: 1.326217 [25600/60000]
loss: 1.338029 [32000/60000]
loss: 1.343920 [38400/60000]
loss: 1.286051 [44800/60000]
loss: 1.311489 [51200/60000]
loss: 1.203263 [57600/60000]
Test Error:
Accuracy: 64.0%, Avg loss: 1.235341
Epoch 5
-------------------------
loss: 1.313958 [ 0/60000]
loss: 1.298989 [ 6400/60000]
loss: 1.134868 [12800/60000]
loss: 1.230290 [19200/60000]
loss: 1.102433 [25600/60000]
loss: 1.143659 [32000/60000]
loss: 1.156451 [38400/60000]
loss: 1.113082 [44800/60000]
loss: 1.140398 [51200/60000]
loss: 1.050202 [57600/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 1.076012
Training done! time: 34.3460 s
Total running time of the script: (1 minutes 2.480 seconds)