• 文件 >
  • 建立 TorchScript 模組
快捷方式

建立 TorchScript 模組

TorchScript 是一種從 PyTorch 程式碼建立可序列化和可最佳化模型的方法。PyTorch 有詳細的文件說明如何實現這一點:https://pytorch.com.tw/tutorials/beginner/Intro_to_TorchScript_tutorial.html。這裡簡要介紹一下關鍵背景資訊和過程。

PyTorch 程式基於 ModuleModule 可用於組合更高級別的模組。Module 包含一個用於設定模組、引數和子模組的建構函式,以及一個描述模組呼叫時如何使用引數和子模組的前向函式 (forward function)。

例如,我們可以這樣定義一個 LeNet 模組

 1import torch.nn as nn
 2import torch.nn.functional as F
 3
 4
 5class LeNetFeatExtractor(nn.Module):
 6    def __init__(self):
 7        super(LeNetFeatExtractor, self).__init__()
 8        self.conv1 = nn.Conv2d(1, 6, 3)
 9        self.conv2 = nn.Conv2d(6, 16, 3)
10
11    def forward(self, x):
12        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
13        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
14        return x
15
16
17class LeNetClassifier(nn.Module):
18    def __init__(self):
19        super(LeNetClassifier, self).__init__()
20        self.fc1 = nn.Linear(16 * 6 * 6, 120)
21        self.fc2 = nn.Linear(120, 84)
22        self.fc3 = nn.Linear(84, 10)
23
24    def forward(self, x):
25        x = torch.flatten(x, 1)
26        x = F.relu(self.fc1(x))
27        x = F.relu(self.fc2(x))
28        x = self.fc3(x)
29        return x
30
31
32class LeNet(nn.Module):
33    def __init__(self):
34        super(LeNet, self).__init__()
35        self.feat = LeNetFeatExtractor()
36        self.classifier = LeNetClassifier()
37
38    def forward(self, x):
39        x = self.feat(x)
40        x = self.classifier(x)
41        return x

.

顯然,您可能希望將如此簡單的模型整合到一個模組中,但這裡我們可以看到 PyTorch 的可組合性。

從 PyTorch Python 程式碼到 TorchScript 程式碼有兩種途徑:追蹤 (Tracing) 和指令碼化 (Scripting)。

追蹤 (Tracing) 在模組被呼叫時跟蹤執行路徑並記錄發生的情況。要追蹤我們的 LeNet 模組例項,我們可以使用一個示例輸入呼叫 torch.jit.trace

import torch

model = LeNet()
input_data = torch.empty([1, 1, 32, 32])
traced_model = torch.jit.trace(model, input_data)

指令碼化 (Scripting) 實際上是使用編譯器檢查您的程式碼並生成等效的 TorchScript 程式。區別在於,由於追蹤 (tracing) 跟蹤的是模組的執行過程,它無法捕捉控制流等資訊。而指令碼化 (scripting) 透過分析 Python 程式碼,編譯器可以包含這些元件。我們可以透過呼叫 torch.jit.script 在我們的 LeNet 模組上執行指令碼編譯器。

import torch

model = LeNet()
script_model = torch.jit.script(model)

選擇哪種途徑都有其原因,PyTorch 文件提供瞭如何選擇的資訊。從 Torch-TensorRT 的角度來看,對追蹤 (traced) 模組的支援更好(即您的模組更有可能被編譯),因為它不包含完整程式語言的所有複雜性,儘管兩種途徑都受支援。

在對模組進行指令碼化 (scripting) 或追蹤 (tracing) 後,您將獲得一個 TorchScript 模組。它包含用於執行模組的程式碼和引數,儲存在一個 Torch-TensorRT 可以使用的中間表示 (intermediate representation) 中。

以下是 LeNet 追蹤 (traced) 模組的 IR 外觀

graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
    %input.1 : Float(1, 1, 32, 32)):
    %129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifier"](%self.1)
    %119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
    %137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
    %138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
    return (%138)

以及 LeNet 指令碼化 (scripted) 模組的 IR

graph(%self : __torch__.LeNet,
    %x.1 : Tensor):
    %2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
    %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # x.py:38:12
    %5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifier"](%self)
    %x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) # x.py:39:12
    return (%x.5)

您可以看到 IR 保留了我們在 python 程式碼中的模組結構。

在 Python 中使用 TorchScript

TorchScript 模組的執行方式與普通 PyTorch 模組相同。您可以使用 forward 方法或直接呼叫模組(如 torch_script_module(in_tensor))來執行前向傳播 (forward pass)。JIT 編譯器將即時編譯和最佳化模組,然後返回結果。

將 TorchScript 模組儲存到磁碟

對於追蹤 (traced) 或指令碼化 (scripted) 模組,您可以使用以下命令將模組儲存到磁碟

import torch

model = LeNet()
script_model = torch.jit.script(model)
script_model.save("lenet_scripted.ts")

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源