注意
點選這裡下載完整的示例程式碼
編寫自定義資料集(Dataset)、資料載入器(DataLoader)和轉換(Transform)¶
創建於: 2017年6月10日 | 最後更新於: 2025年3月11日 | 最後驗證於: 2024年11月5日
解決任何機器學習問題都需要花費大量精力來準備資料。PyTorch 提供了許多工具來簡化資料載入,並希望能使你的程式碼更具可讀性。在本教程中,我們將學習如何從一個非平凡的資料集載入和預處理/增強資料。
要執行本教程,請確保已安裝以下軟體包
scikit-image: 用於影像 I/O 和轉換pandas: 用於更輕鬆地解析 CSV
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
<contextlib.ExitStack object at 0x7f5673c7ebf0>
我們要處理的資料集是關於人臉姿態的。這意味著人臉的標註是這樣的
總共為每張人臉標註了 68 個不同的特徵點。
資料集附帶一個包含標註的 .csv 檔案,其內容如下所示
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
我們以 CSV 中的一個影像名稱及其標註為例,例如行索引 65 的 person-7.jpg。讀取它,將影像名稱儲存在 img_name 中,並將其標註儲存在一個 (L, 2) 的陣列 landmarks 中,其中 L 是該行中的特徵點數量。
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)
print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]
我們來編寫一個簡單的輔助函式,用於顯示影像及其特徵點,並用它來展示一個樣本。
def show_landmarks(image, landmarks):
"""Show image with landmarks"""
plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
landmarks)
plt.show()

Dataset 類¶
torch.utils.data.Dataset 是表示資料集的抽象類。你的自定義資料集應該繼承 Dataset 並覆蓋以下方法
__len__方法,使得len(dataset)返回資料集的大小。__getitem__方法,以支援索引,使得可以使用dataset[i]獲取第 \(i\) 個樣本。
讓我們為我們的人臉特徵點資料集建立一個 dataset 類。我們將在 __init__ 中讀取 csv 檔案,但將影像的讀取留在 __getitem__ 中。這樣做可以節省記憶體,因為所有影像不會同時儲存在記憶體中,而是按需讀取。
我們資料集的樣本將是一個字典 {'image': image, 'landmarks': landmarks}。我們的資料集將接受一個可選引數 transform,以便可以對樣本應用任何所需的處理。我們將在下一節中看到 transform 的有用性。
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Arguments:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
讓我們例項化這個類並迭代資料樣本。我們將列印前 4 個樣本的大小並顯示它們的特徵點。
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/')
fig = plt.figure()
for i, sample in enumerate(face_dataset):
print(i, sample['image'].shape, sample['landmarks'].shape)
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)
轉換(Transforms)¶
從上面我們可以看到一個問題,即樣本大小不一致。大多數神經網路期望固定大小的影像。因此,我們需要編寫一些預處理程式碼。讓我們建立三個轉換
Rescale: 縮放影像RandomCrop: 從影像中隨機裁剪。這是資料增強。ToTensor: 將 numpy 影像轉換為 torch 影像(我們需要交換軸)。
我們將它們寫成可呼叫類而不是簡單的函式,這樣就無需每次呼叫時都傳遞轉換的引數。為此,我們只需實現 __call__ 方法,如果需要,再實現 __init__ 方法。然後我們可以像這樣使用一個轉換
tsfm = Transform(params)
transformed_sample = tsfm(sample)
請注意下面這些轉換必須同時應用於影像和特徵點。
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
"""Crop randomly the image in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h + 1)
left = np.random.randint(0, w - new_w + 1)
image = image[top: top + new_h,
left: left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
# swap color axis because
# numpy image: H x W x C
# torch image: C x H x W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),
'landmarks': torch.from_numpy(landmarks)}
注意
在上面的例子中,RandomCrop 使用外部庫的隨機數生成器(在本例中是 Numpy 的 np.random.int)。這可能會導致 DataLoader 出現意外行為(參見此處)。實際上,更安全的方法是堅持使用 PyTorch 的隨機數生成器,例如使用 torch.randint。
組合轉換(Compose transforms)¶
現在,我們將轉換應用於一個樣本。
假設我們想將影像的較短邊縮放到 256,然後從中隨機裁剪一個大小為 224 的正方形。換句話說,我們想組合 Rescale 和 RandomCrop 轉換。torchvision.transforms.Compose 是一個簡單的可呼叫類,它允許我們這樣做。
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()

迭代資料集¶
讓我們將這一切放在一起,建立一個帶有組合轉換的資料集。總結一下,每次從這個資料集中取樣時
影像會從檔案按需讀取
轉換會應用於讀取的影像
由於其中一個轉換是隨機的,資料在取樣時得到增強
我們可以像之前一樣使用 for i in range 迴圈遍歷建立的資料集。
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))
for i, sample in enumerate(transformed_dataset):
print(i, sample['image'].size(), sample['landmarks'].size())
if i == 3:
break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
然而,透過使用簡單的 for 迴圈來迭代資料,我們丟失了許多功能。特別是,我們錯過了以下幾點
批次處理資料
打亂資料
使用
multiprocessing工作程序並行載入資料。
torch.utils.data.DataLoader 是一個迭代器,它提供了所有這些功能。下面使用的引數應該很清楚。一個值得關注的引數是 collate_fn。你可以使用 collate_fn 指定樣本如何精確地進行批次處理。但是,預設的 collate 對於大多數用例來說應該可以正常工作。
dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=0)
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
images_batch, landmarks_batch = \
sample_batched['image'], sample_batched['landmarks']
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid_border_size = 2
grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
for i in range(batch_size):
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
landmarks_batch[i, :, 1].numpy() + grid_border_size,
s=10, marker='.', c='r')
plt.title('Batch from dataloader')
# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.
# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size())
# observe 4th batch and stop.
if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
後話:torchvision¶
在本教程中,我們學習瞭如何編寫和使用資料集(datasets)、轉換(transforms)和資料載入器(dataloader)。torchvision 包提供了一些常見的資料集和轉換。你甚至可能不必編寫自定義類。torchvision 中提供的一個更通用的資料集是 ImageFolder。它假設影像按以下方式組織
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
其中“ants”、“bees”等是類別標籤。類似地,操作 PIL.Image 的通用轉換,如 RandomHorizontalFlip、Scale 等也可用。你可以使用這些來編寫一個數據載入器,如下所示
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
對於包含訓練程式碼的示例,請參閱計算機視覺的遷移學習教程。
指令碼總執行時間: ( 0 分鐘 1.910 秒)
