TorchScript¶
TorchScript 是一種從 PyTorch 程式碼建立可序列化和可最佳化模型的方法。任何 TorchScript 程式都可以從 Python 程序中儲存,並在沒有 Python 依賴項的程序中載入。
我們提供了一些工具,可以將模型從純 Python 程式逐步轉換為可以獨立於 Python 執行的 TorchScript 程式,例如獨立的 C++ 程式。這使得可以使用 Python 中熟悉的工具在 PyTorch 中訓練模型,然後透過 TorchScript 將模型匯出到生產環境中,在這些環境中,Python 程式可能會因為效能和多執行緒的原因而處於劣勢。
如需 TorchScript 的入門介紹,請參閱TorchScript 簡介教學。
如需將 PyTorch 模型轉換為 TorchScript 並在 C++ 中執行的端到端範例,請參閱在 C++ 中載入 PyTorch 模型教學。
建立 TorchScript 程式碼¶
| 將函數腳本化。 | |
| 追蹤函數並返回一個可執行檔或 | |
| 在追蹤過程中第一次呼叫 | |
| 追蹤模組並返回一個可執行的 | |
| 建立一個執行func的非同步任務,以及對該執行結果值的參考。 | |
| 強制完成torch.jit.Future[T]非同步任務,返回任務的結果。 | |
| C++ torch::jit::Module 的包裝器,具有方法、屬性和參數。 | |
| 在功能上等同於 | |
| 凍結 ScriptModule,將子模組和屬性內嵌為常數。 | |
| 執行一組最佳化過程,以最佳化模型以用於推論。 | |
| 根據參數enabled啟用或停用 onednn JIT 融合。 | |
| 返回是否啟用了 onednn JIT 融合。 | |
| 設定融合過程中可能發生的特殊化類型和數量。 | |
| 如果並非所有節點都在推論中融合,或在訓練中進行符號微分,則發出錯誤。 | |
| 儲存此模組的離線版本,以供在另一個程序中使用。 | |
| 載入先前使用 | |
| 此裝飾器向編譯器指示應忽略某個函數或方法,並將其保留為 Python 函數。 | |
| 此裝飾器向編譯器指示應忽略某個函數或方法,並將其替換為引發異常。 | |
| 裝飾以註釋不同類型的類別或模組。 | |
| 在 TorchScript 中提供容器類型細化。 | |
| 此方法是一個傳遞函數,返回value,主要用於向 TorchScript 編譯器指示左側表達式是一個類別實例屬性,其類型為type。 | |
| 用於在 TorchScript 編譯器中指定the_value的類型。 | 
混合追蹤和腳本編寫¶
在許多情況下,追蹤或腳本編寫是將模型轉換為 TorchScript 的一種更簡單的方法。可以組合追蹤和腳本編寫以滿足模型一部分的特定需求。
腳本函式可以呼叫追蹤函式。當您需要在簡單的前饋模型周圍使用控制流程時,這一點特別有用。例如,序列到序列模型的集束搜尋通常會以腳本形式編寫,但可以呼叫使用追蹤產生的編碼器模組。
範例(在腳本中呼叫追蹤函式)
import torch
def foo(x, y):
    return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
    return traced_foo(x, x)
追蹤函式可以呼叫腳本函式。當模型的一小部分需要一些控制流程,而模型的大部分只是一個前饋網路時,這一點非常有用。追蹤函式呼叫的腳本函式內部的控制流程會被正確保留。
範例(在追蹤函式中呼叫腳本函式)
import torch
@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r
def bar(x, y, z):
    return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
這種組合也適用於 nn.Module,其中它可以用於使用追蹤產生一個子模組,該子模組可以從腳本模組的方法中呼叫。
範例(使用追蹤模組)
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))
    def forward(self, input):
        return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
TorchScript 語言¶
TorchScript 是 Python 的靜態類型子集,因此許多 Python 特性可以直接應用於 TorchScript。有關詳細資訊,請參閱完整的 TorchScript 語言參考。
內建函式和模組¶
TorchScript 支援使用大多數 PyTorch 函式和許多 Python 內建函式。有關支援函式的完整參考,請參閱 TorchScript 內建函式。
PyTorch 函式和模組¶
TorchScript 支援 PyTorch 提供的張量和神經網路函式的子集。張量上的大多數方法以及 torch 命名空間中的函式、torch.nn.functional 中的所有函式以及 torch.nn 中的大多數模組都在 TorchScript 中受支援。
有關不受支援的 PyTorch 函式和模組的清單,請參閱 TorchScript 不支援的 PyTorch 構造。
Python 函式和模組¶
TorchScript 中支援許多 Python 的 內建函式。也支援 math 模組(有關詳細資訊,請參閱 math 模組),但不支援其他 Python 模組(內建或協力廠商)。
Python 語言參考比較¶
有關支援的 Python 特性的完整清單,請參閱 Python 語言參考涵蓋範圍。
除錯¶
停用 JIT 以進行除錯¶
- PYTORCH_JIT¶
設定環境變數 PYTORCH_JIT=0 將停用所有腳本和追蹤註釋。如果您的其中一個 TorchScript 模型中存在難以除錯的錯誤,則可以使用此旗標強制所有內容都使用原生 Python 執行。由於使用此旗標停用了 TorchScript(腳本編寫和追蹤),因此您可以使用 pdb 之類的工具來除錯模型代碼。例如
@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x
def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)
traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
使用 pdb 除錯此腳本是可行的,除非我們呼叫 @torch.jit.script 函式。我們可以全域停用 JIT,以便我們可以將 @torch.jit.script 函式作為普通的 Python 函式呼叫,而不是編譯它。如果上述腳本被稱為 disable_jit_example.py,我們可以像這樣呼叫它
$ PYTORCH_JIT=0 python disable_jit_example.py
並且我們將能夠作為普通的 Python 函式逐步執行 @torch.jit.script 函式。要停用特定函式的 TorchScript 編譯器,請參閱 @torch.jit.ignore。
檢查代碼¶
TorchScript 為所有 ScriptModule 實例提供了一個代碼美化程式。這個美化程式將腳本方法的代碼解釋為有效的 Python 語法。例如
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv
print(foo.code)
具有單個 forward 方法的 ScriptModule 將具有一個屬性 code,您可以使用它來檢查 ScriptModule 的代碼。如果 ScriptModule 具有多個方法,則需要存取方法本身上的 .code,而不是模組上的 .code。我們可以通過存取 .foo.code 來檢查 ScriptModule 上名為 foo 的方法的代碼。上面的示例產生以下輸出
def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0
這是 TorchScript 對 forward 方法的代碼的編譯。您可以使用它來確保 TorchScript(追蹤或腳本編寫)正確捕獲了您的模型代碼。
解讀圖表¶
TorchScript 也具有比代碼美化程式更低級別的表示形式,即 IR 圖表的形式。
TorchScript 使用靜態單一賦值 (SSA) 中間表示 (IR) 來表示計算。這種格式的指令包括 ATen(PyTorch 的 C++ 後端)運算符和其他基本運算符,包括用於迴圈和條件式的控制流程運算符。舉例來說
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv
print(foo.graph)
graph 遵循 檢查代碼 部分中描述的關於 forward 方法查找的相同規則。
上面的示例腳本產生以下圖表
graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)
以指令 %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 為例。
- %rv.1 : Tensor表示我們將輸出賦值給一個名為- rv.1的(唯一)值,該值屬於- Tensor類型,並且我們不知道它的具體形狀。
- aten::zeros是運算符(等效於- torch.zeros),輸入清單- (%4, %6, %6, %10, %12)指定應將範圍內的哪些值作為輸入傳遞。可以在 內建函式 中找到內建函式(如- aten::zeros)的架構。
- # test.py:9:10是生成此指令的原始原始碼檔案中的位置。在這種情況下,它是一個名為 test.py 的檔案,位於第 9 行,第 10 個字元。
請注意,運算符也可以有關聯的 blocks,即 prim::Loop 和 prim::If 運算符。在圖表列印輸出中,這些運算符的格式設定為反映其等效的原始碼形式,以便於除錯。
可以如圖所示檢查圖表,以確認 ScriptModule 描述的計算是正確的,如下所述,可以自動和手動方式進行。
追蹤器¶
追蹤邊緣案例¶
存在一些邊緣案例,其中給定 Python 函式/模組的追蹤不能代表底層代碼。這些案例可能包括
- 追蹤依賴於輸入(例如張量形狀)的控制流程 
- 追蹤張量檢視的原地操作(例如,賦值左側的索引) 
請注意,這些案例在將來實際上可能是可以追蹤的。
自動追蹤檢查¶
自動捕獲追蹤中的許多錯誤的一種方法是在 torch.jit.trace() API 上使用 check_inputs。check_inputs 採用輸入元組清單,這些輸入將用於重新追蹤計算並驗證結果。例如
def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
為我們提供以下診斷資訊
ERROR: Graphs differed across invocations!
Graph diff:
            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }
此消息向我們表明,計算在我們第一次追蹤它時和我們使用 check_inputs 追蹤它時有所不同。實際上,loop_in_traced_fn 主體內的迴圈取決於輸入 x 的形狀,因此當我們嘗試另一個形狀不同的 x 時,追蹤就會有所不同。
在這種情況下,可以使用 torch.jit.script() 捕獲像這樣的資料依賴性控制流程
def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))
產生
graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}
追蹤器警告¶
追蹤器會針對追蹤計算中的幾種有問題的模式產生警告。例如,讓我們來看一個函式的追蹤,該函式包含對張量的切片(檢視)的原地賦值
def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
產生數個警告以及一個僅回傳輸入值的圖表
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}
我們可以透過修改程式碼來修正此問題,不要使用原地更新,而是在外部使用 torch.cat 建立結果張量
def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
常見問題¶
問:我想在 GPU 上訓練模型並在 CPU 上進行推論。最佳實務是什麼?
首先將您的模型從 GPU 轉換為 CPU,然後儲存它,如下所示
cpu_model = gpu_model.cpu() sample_input_cpu = sample_input_gpu.cpu() traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) torch.jit.save(traced_cpu, "cpu.pt") traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) torch.jit.save(traced_gpu, "gpu.pt") # ... later, when using the model: if use_gpu: model = torch.jit.load("gpu.pt") else: model = torch.jit.load("cpu.pt") model(input)建議這樣做是因為追蹤器可能會看到在特定裝置上建立張量,因此轉換已載入的模型可能會產生意外的影響。在儲存模型*之前*轉換模型可確保追蹤器具有正確的裝置資訊。
問:如何在 ScriptModule 上儲存屬性?
假設我們有一個像這樣的模型
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.x = 2 def forward(self): return self.x m = torch.jit.script(Model())如果
Model被實例化,它將導致編譯錯誤,因為編譯器不知道x。有 4 種方法可以通知編譯器ScriptModule上的屬性1.
nn.Parameter- 包裝在nn.Parameter中的值將像在nn.Modules 上一樣工作2.
register_buffer- 包裝在register_buffer中的值將像在nn.Modules 上一樣工作。這相當於Tensor類型的屬性(參見 4)。3. 常數 - 將類別成員標註為
Final(或在類別定義層級將其新增到名為__constants__的清單中)會將包含的名稱標記為常數。常數直接儲存在模型的程式碼中。詳情請參閱內建常數。4. 屬性 - 支援的類型 的值可以新增為可變屬性。大多數類型都可以推斷,但有些類型可能需要指定,詳情請參閱模組屬性。
問:我想追蹤模組的方法,但我一直收到這個錯誤
RuntimeError: 無法 插入 需要 grad 的 Tensor 作為 常數。 請考慮 將其 設為 參數 或 輸入, 或 分離 梯度
這個錯誤通常表示您正在追蹤的方法使用了模組的參數,並且您傳遞的是模組的方法而不是模組實例(例如
my_module_instance.forward與my_module_instance)。
使用模組的方法呼叫
trace會將模組參數(可能需要梯度)擷取為**常數**。
另一方面,使用模組的實例(例如
my_module)呼叫trace會建立一個新的模組,並將參數正確地複製到新的模組中,以便它們在需要時可以累積梯度。要追蹤模組上的特定方法,請參閱
torch.jit.trace_module
已知問題¶
如果您在 TorchScript 中使用 Sequential,則某些 Sequential 子模組的輸入可能會被錯誤地推斷為 Tensor,即使它們被標註為其他類型。標準的解決方案是繼承 nn.Sequential 並使用正確類型的輸入重新宣告 forward。
附錄¶
遷移到 PyTorch 1.2 遞迴腳本 API¶
本節詳細說明 PyTorch 1.2 中 TorchScript 的變化。如果您是 TorchScript 的新手,可以跳過本節。PyTorch 1.2 的 TorchScript API 有兩個主要變化。
1. torch.jit.script 現在將嘗試遞迴地編譯它遇到的函數、方法和類別。一旦您呼叫 torch.jit.script,編譯就變成「選擇退出」,而不是「選擇加入」。
2. torch.jit.script(nn_module_instance) 現在是建立 ScriptModules 的首選方法,而不是繼承自 torch.jit.ScriptModule。這些變化結合起來提供了一個更簡單、更容易使用的 API,用於將您的 nn.Modules 轉換為 ScriptModules,準備在非 Python 環境中進行最佳化和執行。
新的用法如下所示
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
my_model = Model()
my_scripted_model = torch.jit.script(my_model)
- 預設情況下,模組的 - forward會被編譯。從- forward呼叫的方法會按照它們在- forward中的使用順序被延遲編譯。
- 要編譯不是從 - forward呼叫的- forward以外的方法,請新增- @torch.jit.export。
- 要阻止編譯器編譯方法,請新增 - @torch.jit.ignore或- @torch.jit.unused。- @ignore會保留
- 方法作為對 Python 的呼叫,而 - @unused會將其替換為例外。- @ignored無法匯出;- @unused可以。
- 大多數屬性類型都可以推斷,因此不需要 - torch.jit.Attribute。對於空的容器類型,請使用 PEP 526 風格 的類別註解來標註其類型。
- 可以使用 - Final類別註解來標記常數,而不是將成員的名稱新增到- __constants__中。
- 可以使用 Python 3 類型提示來代替 - torch.jit.annotate
- 由於這些變化,以下項目被視為已棄用,不應出現在新的程式碼中
- @torch.jit.script_method裝飾器
- 繼承自 - torch.jit.ScriptModule的類別
- torch.jit.Attribute包裝類別
- __constants__陣列
- torch.jit.annotate函數
 
模組¶
警告
@torch.jit.ignore 註解的行為在 PyTorch 1.2 中有所變化。在 PyTorch 1.2 之前,@ignore 裝飾器用於使函數或方法可以從匯出的程式碼中呼叫。要恢復此功能,請使用 @torch.jit.unused()。 @torch.jit.ignore 現在相當於 @torch.jit.ignore(drop=False)。詳情請參閱 @torch.jit.ignore 和 @torch.jit.unused。
當傳遞給 torch.jit.script 函數時,torch.nn.Module 的資料會被複製到 ScriptModule 中,並且 TorchScript 編譯器會編譯該模組。預設情況下,模組的 forward 會被編譯。從 forward 呼叫的方法會按照它們在 forward 中的使用順序被延遲編譯,任何 @torch.jit.export 方法也是如此。
- torch.jit.export(fn)[原始碼]¶
- 這個裝飾器表示 - nn.Module上的一個方法被用作- ScriptModule的進入點,應該被編譯。- forward隱含地被假定為一個進入點,因此它不需要這個裝飾器。從- forward呼叫的函數和方法會在編譯器看到它們時被編譯,因此它們也不需要這個裝飾器。- 範例(在方法上使用 - @torch.jit.export)- import torch import torch.nn as nn class MyModule(nn.Module): def implicitly_compiled_method(self, x): return x + 99 # `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effect def forward(self, x): return x + 10 @torch.jit.export def another_forward(self, x): # When the compiler sees this call, it will compile # `implicitly_compiled_method` return self.implicitly_compiled_method(x) def unused_method(self, x): return x - 20 # `m` will contain compiled methods: # `forward` # `another_forward` # `implicitly_compiled_method` # `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export` m = torch.jit.script(MyModule()) 
函數¶
函數沒有太大變化,如果需要,它們可以用 @torch.jit.ignore 或 @torch.jit.unused 來裝飾。
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2
# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2
# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4
# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2
TorchScript 類別¶
警告
TorchScript 類別支援尚在實驗階段。目前它最適合簡單的記錄類型(類似於附加了方法的 NamedTuple)。
使用者定義的 TorchScript 類別 中的所有內容都會預設匯出,如果需要,可以使用 @torch.jit.ignore 裝飾函數。
屬性¶
TorchScript 編譯器需要知道 模組屬性 的類型。大多數類型可以從成員的值推斷出來。空串列和字典無法推斷其類型,因此必須使用 PEP 526 風格 的類別註釋來註釋其類型。如果無法推斷類型且未明確註釋,則不會將其作為屬性添加到結果 ScriptModule 中。
舊 API
from typing import Dict
import torch
class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)
m = MyModule()
新 API
from typing import Dict
class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]
    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}
        # The attribute type here is inferred to be `int`
        self.my_int = 20
    def forward(self):
        pass
m = torch.jit.script(MyModule())
常數¶
Final 類型建構函數可用於將成員標記為 常數。如果成員未標記為常數,則會將其作為屬性複製到結果 ScriptModule。如果已知值是固定的,則使用 Final 可以開啟最佳化的機會,並提供額外的類型安全性。
舊 API
class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']
    def __init__(self):
        super().__init__()
        self.my_constant = 2
    def forward(self):
        pass
m = MyModule()
新 API
from typing import Final
class MyModule(torch.nn.Module):
    my_constant: Final[int]
    def __init__(self):
        super().__init__()
        self.my_constant = 2
    def forward(self):
        pass
m = torch.jit.script(MyModule())
變數¶
假設容器的類型為 Tensor 且不可為空(有關詳細資訊,請參閱 預設類型)。以前,使用 torch.jit.annotate 來告知 TorchScript 編譯器類型應該是什麼。現在支援 Python 3 風格的類型提示。
import torch
from typing import Dict, Optional
@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b
融合後端¶
有幾個融合後端可用於最佳化 TorchScript 執行。CPU 上的預設融合器是 NNC,它可以執行 CPU 和 GPU 的融合。GPU 上的預設融合器是 NVFuser,它支援更廣泛的運算符,並且已展示了具有更高輸送量的已產生核心。如需使用和除錯的詳細資訊,請參閱 NVFuser 文件。