快捷方式

學習基礎知識 || 快速入門 || 張量 || 資料集與資料載入器 || 變換 || 構建模型 || Autograd || 最佳化 || 儲存與載入模型

變換

建立日期:2021 年 2 月 9 日 | 最後更新:2021 年 8 月 11 日 | 最後驗證:未驗證

資料並不總是以機器學習演算法訓練所需的最終處理形式出現。我們使用 變換(transforms) 對資料進行一些處理,使其適合訓練。

所有 TorchVision 資料集都有兩個引數 -transform 用於修改特徵,target_transform 用於修改標籤 - 它們接受包含變換邏輯的可呼叫物件。 torchvision.transforms 模組提供了幾個常用的現成變換。

FashionMNIST 特徵採用 PIL Image 格式,標籤是整數。為了訓練,我們需要將特徵轉換為歸一化張量,將標籤轉換為獨熱編碼張量。為了實現這些變換,我們使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
  0%|          | 0.00/26.4M [00:00<?, ?B/s]
  0%|          | 65.5k/26.4M [00:00<01:12, 364kB/s]
  1%|          | 229k/26.4M [00:00<00:38, 685kB/s]
  4%|3         | 950k/26.4M [00:00<00:11, 2.20MB/s]
 15%|#4        | 3.83M/26.4M [00:00<00:02, 7.64MB/s]
 38%|###7      | 10.0M/26.4M [00:00<00:00, 17.3MB/s]
 61%|######1   | 16.1M/26.4M [00:01<00:00, 22.9MB/s]
 83%|########2 | 21.9M/26.4M [00:01<00:00, 30.3MB/s]
 96%|#########5| 25.4M/26.4M [00:01<00:00, 26.8MB/s]
100%|##########| 26.4M/26.4M [00:01<00:00, 19.5MB/s]

  0%|          | 0.00/29.5k [00:00<?, ?B/s]
100%|##########| 29.5k/29.5k [00:00<00:00, 326kB/s]

  0%|          | 0.00/4.42M [00:00<?, ?B/s]
  1%|1         | 65.5k/4.42M [00:00<00:11, 363kB/s]
  5%|5         | 229k/4.42M [00:00<00:06, 684kB/s]
 19%|#9        | 852k/4.42M [00:00<00:01, 2.37MB/s]
 44%|####3     | 1.93M/4.42M [00:00<00:00, 4.17MB/s]
100%|##########| 4.42M/4.42M [00:00<00:00, 6.11MB/s]

  0%|          | 0.00/5.15k [00:00<?, ?B/s]
100%|##########| 5.15k/5.15k [00:00<00:00, 65.0MB/s]

ToTensor()

ToTensor 將 PIL 影像或 NumPy ndarray 轉換為 FloatTensor,並將影像的畫素強度值縮放到 [0., 1.] 範圍內。

Lambda 變換

Lambda 變換應用任何使用者定義的 lambda 函式。在這裡,我們定義了一個函式將整數轉換為獨熱編碼張量。它首先建立一個大小為 10 的零張量(資料集中標籤的數量),然後呼叫 scatter_ 函式,根據標籤 y 給定的索引位置賦值 value=1

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源