快捷方式

學習基礎知識 || 快速入門 || 張量 || 資料集與 DataLoaders || Transforms || 構建模型 || Autograd || 最佳化 || 儲存與載入模型

儲存和載入模型

建立日期: 2021 年 2 月 9 日 | 最後更新: 2024 年 10 月 15 日 | 最後驗證: 2024 年 11 月 5 日

在本節中,我們將探討如何透過儲存、載入和執行模型預測來持久化模型狀態。

import torch
import torchvision.models as models

儲存和載入模型權重

PyTorch 模型將其學習到的引數儲存在一個內部狀態字典中,稱為 state_dict。這些引數可以透過 torch.save 方法進行持久化。

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vgg16-397923af.pth

  0%|          | 0.00/528M [00:00<?, ?B/s]
  6%|5         | 30.5M/528M [00:00<00:01, 318MB/s]
 13%|#2        | 66.5M/528M [00:00<00:01, 353MB/s]
 19%|#8        | 100M/528M [00:00<00:01, 317MB/s]
 25%|##4       | 131M/528M [00:00<00:01, 268MB/s]
 30%|##9       | 158M/528M [00:00<00:01, 265MB/s]
 35%|###5      | 187M/528M [00:00<00:01, 278MB/s]
 41%|####1     | 219M/528M [00:00<00:01, 294MB/s]
 49%|####8     | 257M/528M [00:00<00:00, 327MB/s]
 55%|#####4    | 289M/528M [00:01<00:00, 305MB/s]
 62%|######1   | 325M/528M [00:01<00:00, 324MB/s]
 67%|######7   | 356M/528M [00:01<00:00, 300MB/s]
 73%|#######3  | 386M/528M [00:01<00:00, 271MB/s]
 79%|#######8  | 415M/528M [00:01<00:00, 280MB/s]
 84%|########3 | 442M/528M [00:01<00:00, 254MB/s]
 89%|########8 | 468M/528M [00:01<00:00, 258MB/s]
 95%|#########4| 499M/528M [00:01<00:00, 274MB/s]
100%|##########| 528M/528M [00:01<00:00, 288MB/s]

要載入模型權重,首先需要建立相同模型的例項,然後使用 load_state_dict() 方法載入引數。

在下面的程式碼中,我們將 weights_only 設定為 True,以限制反序列化過程中執行的函式,僅保留載入權重所需的函式。載入權重時,使用 weights_only=True 被認為是最佳實踐。

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意

在進行推理之前,請務必呼叫 model.eval() 方法,將 dropout 和批次歸一化層設定為評估模式。若未能這樣做,推理結果將不穩定。

儲存和載入帶模型結構的模型

載入模型權重時,我們需要先例項化模型類,因為類定義了網路的結構。我們可能希望將這個類的結構與模型一起儲存,在這種情況下,我們可以將 model(而不是 model.state_dict())傳遞給儲存函式。

torch.save(model, 'model.pth')

然後,我們可以按如下所示載入模型。

Saving and loading torch.nn.Modules 中所述,儲存 state_dict 被認為是最佳實踐。然而,下面我們使用 weights_only=False 是因為這涉及載入整個模型,這是 torch.save 的一個遺留用例。

model = torch.load('model.pth', weights_only=False),

注意

這種方法在序列化模型時使用了 Python pickle 模組,因此在載入模型時需要實際的類定義可用。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源