TorchScript 語言參考¶
TorchScript 是 Python 的靜態類型子集,可以直接撰寫(使用 @torch.jit.script 裝飾器)或透過追蹤從 Python 程式碼自動產生。使用追蹤時,程式碼會自動轉換為此 Python 子集,方法是僅記錄張量上的實際運算子,並簡單地執行和捨棄其他周圍的 Python 程式碼。
使用 @torch.jit.script 裝飾器直接撰寫 TorchScript 時,程式設計師只能使用 TorchScript 中支援的 Python 子集。本節說明 TorchScript 中支援的內容,就好像它是獨立語言的語言參考一樣。本參考中未提及的任何 Python 功能都不是 TorchScript 的一部分。如需可用 PyTorch 張量方法、模組和函數的完整參考,請參閱內建函數。
作為 Python 的子集,任何有效的 TorchScript 函數也是有效的 Python 函數。這使得停用 TorchScript並使用標準 Python 工具(如 pdb)偵錯函數成為可能。反之則不然:有許多有效的 Python 程式不是有效的 TorchScript 程式。相反地,TorchScript 特別專注於在 PyTorch 中表示神經網路模型所需的 Python 功能。
類型¶
TorchScript 與完整 Python 語言之間最大的區別在於,TorchScript 僅支援表達神經網路模型所需的一小部分類型。特別是,TorchScript 支援
| 類型 | 說明 | 
|---|---|
| 
 | 任何 dtype、維度或後端的 PyTorch 張量 | 
| 
 | 包含子類型  | 
| 
 | 布林值 | 
| 
 | 純量整數 | 
| 
 | 純量浮點數 | 
| 
 | 字串 | 
| 
 | 所有成員都是類型  | 
| 
 | 值為 None 或類型  | 
| 
 | 具有鍵類型  | 
| 
 | |
| 
 | |
| 
 | |
| 
 | 子類型  | 
與 Python 不同,TorchScript 函數中的每個變數都必須具有單一靜態類型。這使得最佳化 TorchScript 函數變得更加容易。
範例(類型不符)
import torch
@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...
不支援的 Typing 建構
TorchScript 不支援 typing 模組的所有功能和類型。其中一些是更基礎的東西,未來不太可能加入,而另一些則可能會在有足夠的使用者需求使其成為優先事項時加入。
typing 模組中的這些類型和功能在 TorchScript 中不可用。
| 項目 | 說明 | 
|---|---|
| 
 | |
| 未實作 | |
| 未實作 | |
| 未實作 | |
| 未實作 | |
| 未實作 | |
| 這適用於 模組屬性 類別屬性註釋,但不適用於函數 | |
| TorchScript 不支援  | |
| 
 | |
| 類型別名 | 未實作 | 
| 標名與結構子類型 | 標名類型正在開發中,但結構類型則尚未 | 
| NewType | 不太可能實作 | 
| 泛型 | 不太可能實作 | 
任何其他未在本文檔中明確列出的 typing 模組功能皆不受支援。
預設類型¶
根據預設,TorchScript 函式的所有參數都假設為 Tensor。若要指定 TorchScript 函式的引數為其他類型,可以使用上面列出的類型,以 MyPy 樣式類型註解來指定。
import torch
@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
注意
也可以使用 typing 模組中的 Python 3 類型提示來註解類型。
import torch
from typing import Tuple
@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
空清單假設為 List[Tensor],空字典則假設為 Dict[str, Tensor]。若要將空清單或字典實例化為其他類型,請使用 Python 3 類型提示.
範例(Python 3 的類型註解)
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))
        my_dict: Dict[str, int] = {}
        return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
可選類型細化¶
當在 if 陳述式的條件式內將類型為 Optional[T] 的變數與 None 進行比較,或在 assert 中檢查時,TorchScript 會細化該變數的類型。編譯器可以推斷多個以 and、or 和 not 組合的 None 檢查。未明確寫入的 if 陳述式 else 區塊也會進行細化。
None 檢查必須在 if 陳述式的條件式內;將 None 檢查指派給變數並在 if 陳述式的條件式中使用,將不會細化檢查中變數的類型。只有區域變數會被細化,像 self.x 這樣的屬性則不會,必須指派給區域變數才能被細化。
範例(細化參數和區域變數的類型)
import torch
import torch.nn as nn
from typing import Optional
class M(nn.Module):
    z: Optional[int]
    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z
    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1
        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z
        # Refinement via an `assert`
        assert z is not None
        x += z
        return x
module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
TorchScript 類別¶
警告
TorchScript 類別支援尚在實驗階段。目前最適合用於簡單的記錄類型(例如附加了方法的 NamedTuple)。
如果 Python 類別使用 @torch.jit.script 進行註解,則可以在 TorchScript 中使用,類似於宣告 TorchScript 函式的方式
@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x
  def aug_add_x(self, inc):
    self.x += inc
此子集受到限制
- 所有函式都必須是有效的 TorchScript 函式(包括 - __init__())。
- 類別必須是新式類別,因為我們使用 - __new__()搭配 pybind11 來建構它們。
- TorchScript 類別是靜態類型的。成員只能透過在 - __init__()方法中指派給 self 來宣告。- 例如,在 - __init__()方法之外指派給- self- @torch.jit.script class Foo: def assign_x(self): self.x = torch.rand(2, 3) - 將會導致 - RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE 
- 除了方法定義之外,類別主體中不允許出現任何運算式。 
- 除了繼承自 - object以指定新式類別之外,不支援繼承或任何其他多型策略。
定義類別後,就可以像任何其他 TorchScript 類型一樣,在 TorchScript 和 Python 中互換使用
# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second
@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
TorchScript 列舉¶
Python 列舉可以在 TorchScript 中使用,無需任何額外註解或程式碼
from enum import Enum
class Color(Enum):
    RED = 1
    GREEN = 2
@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True
    return x == y
定義列舉後,就可以像任何其他 TorchScript 類型一樣,在 TorchScript 和 Python 中互換使用。列舉值的類型必須是 int、float 或 str。所有值都必須是相同的類型;不支援列舉值的異質類型。
具名元組¶
collections.namedtuple 生成的類型可以在 TorchScript 中使用。
import torch
import collections
Point = collections.namedtuple('Point', ['x', 'y'])
@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y
p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
可迭代物件¶
某些函式(例如,zip 和 enumerate)只能對可迭代類型進行操作。TorchScript 中的可迭代類型包括 Tensor、清單、元組、字典、字串、torch.nn.ModuleList 和 torch.nn.ModuleDict。
運算式¶
支援以下 Python 運算式。
字面值¶
True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float
清單建構¶
空清單的類型假設為 List[Tensor]。其他清單字面值的類型是從成員的類型衍生而來。如需更多詳細資訊,請參閱 預設類型。
[3, 4]
[]
[torch.rand(3), torch.rand(4)]
元組建構¶
(3, 4)
(3,)
算術運算子¶
a + b
a - b
a * b
a / b
a ^ b
a @ b
比較運算子¶
a == b
a != b
a < b
a > b
a <= b
a >= b
邏輯運算子¶
a and b
a or b
not b
下標和切片¶
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]
函式呼叫¶
對 內建函式 的呼叫
torch.rand(3, dtype=torch.int)
對其他腳本函式的呼叫
import torch
@torch.jit.script
def foo(x):
    return x + 1
@torch.jit.script
def bar(x):
    return foo(x)
方法呼叫¶
對內建類型(例如 tensor)的方法呼叫:x.mm(y)
在模組上,方法必須先經過編譯才能呼叫。TorchScript 編譯器會在編譯其他方法時,遞迴地編譯它看到的方法。根據預設,編譯從 forward 方法開始。任何由 forward 呼叫的方法都會被編譯,而任何由這些方法呼叫的方法也會被編譯,依此類推。若要從 forward 以外的方法開始編譯,請使用 @torch.jit.export 裝飾器(forward 隱含標記為 @torch.jit.export)。
直接呼叫子模組(例如 self.resnet(input))等同於呼叫其 forward 方法(例如 self.resnet.forward(input))。
import torch
import torch.nn as nn
import torchvision
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))
    def helper(self, input):
        return self.resnet(input - self.means)
    def forward(self, input):
        return self.helper(input)
    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)
    def other_helper(self, input):
        return input + 10
# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())
三元運算式¶
x if x > y else y
轉型¶
float(ten)
int(3.5)
bool(ten)
str(2)``
存取模組參數¶
self.my_parameter
self.my_submodule.my_parameter
陳述式¶
TorchScript 支援以下類型的陳述式
簡單指派¶
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b
列印陳述式¶
print("the result of an add:", a + b)
If 陳述式¶
if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a
除了布林值之外,浮點數、整數和 Tensor 都可以在條件式中使用,並且會被隱式轉換為布林值。
While 迴圈¶
a = 0
while a < 4:
    print(a)
    a += 1
使用 range 的 For 迴圈¶
x = 0
for i in range(10):
    x *= i
迭代元組的 For 迴圈¶
這些會展開迴圈,為元組的每個成員產生一個主體。主體必須針對每個成員進行正確的類型檢查。
tup = (3, torch.rand(4))
for x in tup:
    print(x)
迭代常數 nn.ModuleList 的 For 迴圈¶
若要在已編譯的方法中使用 nn.ModuleList,必須透過將屬性名稱新增至類型的 __constants__ 清單中,將其標記為常數。迭代 nn.ModuleList 的 For 迴圈會在編譯時展開迴圈主體,其中包含常數模組清單的每個成員。
class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))
    def forward(self, input):
        return self.weight + input
class MyModule(torch.nn.Module):
    __constants__ = ['mods']
    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v
m = torch.jit.script(MyModule())
Break 和 Continue¶
for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)
Return¶
return a, b
變數解析¶
TorchScript 支援 Python 變數解析(即作用域)規則的子集。局部變數的行為與 Python 中的行為相同,但有一個限制,即變數在函數的所有路徑中必須具有相同的類型。如果一個變數在 if 語句的不同分支上具有不同的類型,則在 if 語句結束後使用它是錯誤的。
同樣地,如果一個變數僅在函數的某些路徑上被*定義*,則不允許使用該變數。
範例
@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...
y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...
在定義函數時,非局部變數會在編譯時解析為 Python 值。然後,使用使用 Python 值中描述的規則將這些值轉換為 TorchScript 值。
使用 Python 值¶
為了使編寫 TorchScript 更方便,我們允許腳本程式碼引用周圍作用域中的 Python 值。例如,每當引用torch時,TorchScript 編譯器實際上是在聲明函數時將其解析為torch Python 模組。這些 Python 值不是 TorchScript 的第一類組成部分。相反,它們在編譯時被解糖為 TorchScript 支持的原始類型。這取決於編譯時引用的 Python 值的動態類型。本節描述在 TorchScript 中訪問 Python 值時使用的規則。
函數¶
TorchScript 可以呼叫 Python 函數。當逐步將模型轉換為 TorchScript 時,此功能非常有用。可以將模型逐個函數地移動到 TorchScript,並保留對 Python 函數的呼叫。這樣,您就可以在進行過程中逐步檢查模型的正確性。
Python 模組上的屬性查找¶
TorchScript 可以查找模組上的屬性。內建函數(如torch.add)就是通過這種方式訪問的。這允許 TorchScript 呼叫在其他模組中定義的函數。
Python 定義的常數¶
TorchScript 還提供了一種使用 Python 中定義的常數的方法。這些常數可用於將超參數硬編碼到函數中,或定義通用常數。有兩種方法可以指定應將 Python 值視為常數。
- 作為模組的屬性查找的值假定為常數 
import math
import torch
@torch.jit.script
def fn():
    return math.pi
- 可以使用 - Final[T]註釋 ScriptModule 的屬性,將其標記為常數
import torch
import torch.nn as nn
class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]
    def __init__(self):
        super().__init__()
        self.a = 1 + 4
    def forward(self, input):
        return self.a + input
f = torch.jit.script(Foo())
支持的常數 Python 類型有
- 整數
- 浮點數
- 布林值
- torch.device
- torch.layout
- torch.dtype
- 包含支持類型的元組 
- torch.nn.ModuleList,可以在 TorchScript for 迴圈中使用
模組屬性¶
torch.nn.Parameter包裝器和register_buffer可用於將張量分配給模組。分配給已編譯模組的其他值,如果可以推斷出其類型,則將添加到已編譯模組中。TorchScript 中可用的所有類型都可以用作模組屬性。張量屬性在語義上與緩衝區相同。無法推斷空串列和字典以及None值的類型,必須通過PEP 526 風格的類註釋來指定。如果無法推斷類型且未明確註釋,則不會將其作為屬性添加到生成的ScriptModule中。
範例
from typing import List, Dict
class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]
    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]
    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict
        # `int`s can be inferred
        self.my_int = 10
    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))