• 教程 >
  • 編寫自定義資料集(Dataset)、資料載入器(DataLoader)和轉換(Transform)
快捷方式

編寫自定義資料集(Dataset)、資料載入器(DataLoader)和轉換(Transform)

創建於: 2017年6月10日 | 最後更新於: 2025年3月11日 | 最後驗證於: 2024年11月5日

作者: Sasank Chilamkurthy

解決任何機器學習問題都需要花費大量精力來準備資料。PyTorch 提供了許多工具來簡化資料載入,並希望能使你的程式碼更具可讀性。在本教程中,我們將學習如何從一個非平凡的資料集載入和預處理/增強資料。

要執行本教程,請確保已安裝以下軟體包

  • scikit-image: 用於影像 I/O 和轉換

  • pandas: 用於更輕鬆地解析 CSV

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
<contextlib.ExitStack object at 0x7f5673c7ebf0>

我們要處理的資料集是關於人臉姿態的。這意味著人臉的標註是這樣的

../_images/landmarked_face2.png

總共為每張人臉標註了 68 個不同的特徵點。

注意

這裡下載資料集,將影像放在名為“data/faces/”的目錄中。這個資料集實際上是透過在 imagenet 中標記為“face”的幾張圖片上應用出色的dlib 姿態估計生成的。

資料集附帶一個包含標註的 .csv 檔案,其內容如下所示

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

我們以 CSV 中的一個影像名稱及其標註為例,例如行索引 65 的 person-7.jpg。讀取它,將影像名稱儲存在 img_name 中,並將其標註儲存在一個 (L, 2) 的陣列 landmarks 中,其中 L 是該行中的特徵點數量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

我們來編寫一個簡單的輔助函式,用於顯示影像及其特徵點,並用它來展示一個樣本。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()
data loading tutorial

Dataset 類

torch.utils.data.Dataset 是表示資料集的抽象類。你的自定義資料集應該繼承 Dataset 並覆蓋以下方法

  • __len__ 方法,使得 len(dataset) 返回資料集的大小。

  • __getitem__ 方法,以支援索引,使得可以使用 dataset[i] 獲取第 \(i\) 個樣本。

讓我們為我們的人臉特徵點資料集建立一個 dataset 類。我們將在 __init__ 中讀取 csv 檔案,但將影像的讀取留在 __getitem__ 中。這樣做可以節省記憶體,因為所有影像不會同時儲存在記憶體中,而是按需讀取。

我們資料集的樣本將是一個字典 {'image': image, 'landmarks': landmarks}。我們的資料集將接受一個可選引數 transform,以便可以對樣本應用任何所需的處理。我們將在下一節中看到 transform 的有用性。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

讓我們例項化這個類並迭代資料樣本。我們將列印前 4 個樣本的大小並顯示它們的特徵點。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i, sample in enumerate(face_dataset):
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
Sample #0, Sample #1, Sample #2, Sample #3
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

轉換(Transforms)

從上面我們可以看到一個問題,即樣本大小不一致。大多數神經網路期望固定大小的影像。因此,我們需要編寫一些預處理程式碼。讓我們建立三個轉換

  • Rescale: 縮放影像

  • RandomCrop: 從影像中隨機裁剪。這是資料增強。

  • ToTensor: 將 numpy 影像轉換為 torch 影像(我們需要交換軸)。

我們將它們寫成可呼叫類而不是簡單的函式,這樣就無需每次呼叫時都傳遞轉換的引數。為此,我們只需實現 __call__ 方法,如果需要,再實現 __init__ 方法。然後我們可以像這樣使用一個轉換

tsfm = Transform(params)
transformed_sample = tsfm(sample)

請注意下面這些轉換必須同時應用於影像和特徵點。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

注意

在上面的例子中,RandomCrop 使用外部庫的隨機數生成器(在本例中是 Numpy 的 np.random.int)。這可能會導致 DataLoader 出現意外行為(參見此處)。實際上,更安全的方法是堅持使用 PyTorch 的隨機數生成器,例如使用 torch.randint

組合轉換(Compose transforms)

現在,我們將轉換應用於一個樣本。

假設我們想將影像的較短邊縮放到 256,然後從中隨機裁剪一個大小為 224 的正方形。換句話說,我們想組合 RescaleRandomCrop 轉換。torchvision.transforms.Compose 是一個簡單的可呼叫類,它允許我們這樣做。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
Rescale, RandomCrop, Compose

迭代資料集

讓我們將這一切放在一起,建立一個帶有組合轉換的資料集。總結一下,每次從這個資料集中取樣時

  • 影像會從檔案按需讀取

  • 轉換會應用於讀取的影像

  • 由於其中一個轉換是隨機的,資料在取樣時得到增強

我們可以像之前一樣使用 for i in range 迴圈遍歷建立的資料集。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,透過使用簡單的 for 迴圈來迭代資料,我們丟失了許多功能。特別是,我們錯過了以下幾點

  • 批次處理資料

  • 打亂資料

  • 使用 multiprocessing 工作程序並行載入資料。

torch.utils.data.DataLoader 是一個迭代器,它提供了所有這些功能。下面使用的引數應該很清楚。一個值得關注的引數是 collate_fn。你可以使用 collate_fn 指定樣本如何精確地進行批次處理。但是,預設的 collate 對於大多數用例來說應該可以正常工作。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
Batch from dataloader
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

後話:torchvision

在本教程中,我們學習瞭如何編寫和使用資料集(datasets)、轉換(transforms)和資料載入器(dataloader)。torchvision 包提供了一些常見的資料集和轉換。你甚至可能不必編寫自定義類。torchvision 中提供的一個更通用的資料集是 ImageFolder。它假設影像按以下方式組織

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中“ants”、“bees”等是類別標籤。類似地,操作 PIL.Image 的通用轉換,如 RandomHorizontalFlipScale 等也可用。你可以使用這些來編寫一個數據載入器,如下所示

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

對於包含訓練程式碼的示例,請參閱計算機視覺的遷移學習教程

指令碼總執行時間: ( 0 分鐘 1.910 秒)

由 Sphinx-Gallery 生成的相簿


評價本教程

© 版權所有 2024, PyTorch。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源