Note
Go to the end to download the full example code.
Using TensorDict for datasets
In this tutorial we demonstrate how TensorDict can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart tutorial, modified to demonstrate the use of TensorDict.
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we will use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the goal is to classify the type of clothing in the image (e.g. "Bag", "Sneaker", etc.).
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
We will create two tensordicts, one each for the training and test data. We create memory-mapped tensors to hold the data. This will allow us to efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.
First we create the MemoryMappedTensor containers.
training_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(training_data), *training_data[0][0].squeeze().shape),
            dtype=torch.float32,
        ),
        "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
    },
    batch_size=[len(training_data)],
    device=device,
)
test_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
        ),
        "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
    },
    batch_size=[len(test_data)],
    device=device,
)
We can then iterate over the data to populate the memory-mapped tensors. This takes a little while, but performing the transforms up front will save repeated effort during training later.
DataLoaders
We will create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped TensorDicts.
Since TensorDict implements __len__ and __getitem__ (and also __getitems__) we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size) # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size) # noqa: TOR401
train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
Model
We use the same model as in the Quickstart tutorial.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = Net().to(device)
model_td = Net().to(device)
model, model_td
(Net(
   (flatten): Flatten(start_dim=1, end_dim=-1)
   (linear_relu_stack): Sequential(
     (0): Linear(in_features=784, out_features=512, bias=True)
     (1): ReLU()
     (2): Linear(in_features=512, out_features=512, bias=True)
     (3): ReLU()
     (4): Linear(in_features=512, out_features=10, bias=True)
   )
 ), Net(
   (flatten): Flatten(start_dim=1, end_dim=-1)
   (linear_relu_stack): Sequential(
     (0): Linear(in_features=784, out_features=512, bias=True)
     (1): ReLU()
     (2): Linear(in_features=512, out_features=512, bias=True)
     (3): ReLU()
     (4): Linear(in_features=512, out_features=10, bias=True)
   )
 ))
Optimizing the parameters
We will optimize the parameters of the model using stochastic gradient descent and cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
The training loop for our TensorDict-based DataLoader is very similar; we simply adjust how we unpack the data, using the more explicit key-based retrieval offered by TensorDict. The .contiguous() method loads the data stored in the memmap tensor.
def train_td(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data in enumerate(dataloader):
        X, y = data["images"].contiguous(), data["targets"].contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
for d in train_dataloader_td:
    print(d)
    break
import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
TensorDict(
    fields={
        images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.288576 [ 0/60000]
loss: 2.279963 [ 6400/60000]
loss: 2.268667 [12800/60000]
loss: 2.270308 [19200/60000]
loss: 2.233671 [25600/60000]
loss: 2.215914 [32000/60000]
loss: 2.214695 [38400/60000]
loss: 2.178674 [44800/60000]
loss: 2.172899 [51200/60000]
loss: 2.149208 [57600/60000]
Test Error:
Accuracy: 48.3%, Avg loss: 2.142256
Epoch 2
-------------------------
loss: 2.145386 [ 0/60000]
loss: 2.135478 [ 6400/60000]
loss: 2.081977 [12800/60000]
loss: 2.100613 [19200/60000]
loss: 2.030724 [25600/60000]
loss: 1.981814 [32000/60000]
loss: 1.995256 [38400/60000]
loss: 1.916826 [44800/60000]
loss: 1.928639 [51200/60000]
loss: 1.847330 [57600/60000]
Test Error:
Accuracy: 54.9%, Avg loss: 1.853513
Epoch 3
-------------------------
loss: 1.888883 [ 0/60000]
loss: 1.856458 [ 6400/60000]
loss: 1.744976 [12800/60000]
loss: 1.781419 [19200/60000]
loss: 1.656465 [25600/60000]
loss: 1.631539 [32000/60000]
loss: 1.631200 [38400/60000]
loss: 1.546029 [44800/60000]
loss: 1.584942 [51200/60000]
loss: 1.466830 [57600/60000]
Test Error:
Accuracy: 61.1%, Avg loss: 1.489801
Epoch 4
-------------------------
loss: 1.562606 [ 0/60000]
loss: 1.525372 [ 6400/60000]
loss: 1.381523 [12800/60000]
loss: 1.445276 [19200/60000]
loss: 1.316936 [25600/60000]
loss: 1.337760 [32000/60000]
loss: 1.331586 [38400/60000]
loss: 1.266051 [44800/60000]
loss: 1.314883 [51200/60000]
loss: 1.210139 [57600/60000]
Test Error:
Accuracy: 64.0%, Avg loss: 1.230520
Epoch 5
-------------------------
loss: 1.308549 [ 0/60000]
loss: 1.290072 [ 6400/60000]
loss: 1.127975 [12800/60000]
loss: 1.229955 [19200/60000]
loss: 1.095840 [25600/60000]
loss: 1.140236 [32000/60000]
loss: 1.147135 [38400/60000]
loss: 1.088268 [44800/60000]
loss: 1.143719 [51200/60000]
loss: 1.056760 [57600/60000]
Test Error:
Accuracy: 65.5%, Avg loss: 1.069039
TensorDict training done! time: 8.5605 s
Epoch 1
-------------------------
loss: 2.299359 [ 0/60000]
loss: 2.285545 [ 6400/60000]
loss: 2.273665 [12800/60000]
loss: 2.269405 [19200/60000]
loss: 2.254834 [25600/60000]
loss: 2.229681 [32000/60000]
loss: 2.230662 [38400/60000]
loss: 2.202860 [44800/60000]
loss: 2.191047 [51200/60000]
loss: 2.169984 [57600/60000]
Test Error:
Accuracy: 49.4%, Avg loss: 2.163469
Epoch 2
-------------------------
loss: 2.173521 [ 0/60000]
loss: 2.155031 [ 6400/60000]
loss: 2.110353 [12800/60000]
loss: 2.122467 [19200/60000]
loss: 2.073355 [25600/60000]
loss: 2.024080 [32000/60000]
loss: 2.040011 [38400/60000]
loss: 1.973018 [44800/60000]
loss: 1.969435 [51200/60000]
loss: 1.899309 [57600/60000]
Test Error:
Accuracy: 56.6%, Avg loss: 1.901211
Epoch 3
-------------------------
loss: 1.941732 [ 0/60000]
loss: 1.895479 [ 6400/60000]
loss: 1.793313 [12800/60000]
loss: 1.823587 [19200/60000]
loss: 1.725540 [25600/60000]
loss: 1.682630 [32000/60000]
loss: 1.695444 [38400/60000]
loss: 1.609786 [44800/60000]
loss: 1.630640 [51200/60000]
loss: 1.519337 [57600/60000]
Test Error:
Accuracy: 61.0%, Avg loss: 1.542196
Epoch 4
-------------------------
loss: 1.619711 [ 0/60000]
loss: 1.563892 [ 6400/60000]
loss: 1.428948 [12800/60000]
loss: 1.484869 [19200/60000]
loss: 1.386341 [25600/60000]
loss: 1.376361 [32000/60000]
loss: 1.382083 [38400/60000]
loss: 1.319620 [44800/60000]
loss: 1.352009 [51200/60000]
loss: 1.241564 [57600/60000]
Test Error:
Accuracy: 63.6%, Avg loss: 1.274075
Epoch 5
-------------------------
loss: 1.360628 [ 0/60000]
loss: 1.321853 [ 6400/60000]
loss: 1.170190 [12800/60000]
loss: 1.258210 [19200/60000]
loss: 1.151774 [25600/60000]
loss: 1.169670 [32000/60000]
loss: 1.181415 [38400/60000]
loss: 1.131277 [44800/60000]
loss: 1.170511 [51200/60000]
loss: 1.072559 [57600/60000]
Test Error:
Accuracy: 64.7%, Avg loss: 1.101794
Training done! time: 34.6420 s
Total running time of the script: (0 minutes 56.008 seconds)