TorchScript 語言參考¶
TorchScript 是 Python 的靜態型別子集,可以直接編寫(使用 @torch.jit.script 裝飾器)或透過跟蹤從 Python 程式碼自動生成。使用跟蹤時,程式碼會自動轉換為 Python 的這個子集,僅記錄 tensor 上的實際運算元,而簡單地執行並丟棄周圍的其他 Python 程式碼。
使用 @torch.jit.script 裝飾器直接編寫 TorchScript 時,程式設計師必須只使用 TorchScript 支援的 Python 子集。本節文件介紹了 TorchScript 中支援的內容,就像它是獨立語言的參考一樣。本參考中未提及的任何 Python 特性都不是 TorchScript 的一部分。有關可用 PyTorch tensor 方法、模組和函式的完整參考,請參見 內建函式。
作為 Python 的子集,任何有效的 TorchScript 函式也是有效的 Python 函式。這使得停用 TorchScript 並使用標準的 Python 工具(例如 pdb)除錯函式成為可能。反之則不然:許多有效的 Python 程式並不是有效的 TorchScript 程式。相反,TorchScript 專門關注表示 PyTorch 中神經網路模型所需的 Python 特性。
型別¶
TorchScript 與完整 Python 語言最大的區別在於,TorchScript 只支援表達神經網路模型所需的一小部分型別。具體來說,TorchScript 支援以下型別:
型別 |
描述 |
|---|---|
|
任何 dtype、維度或後端的 PyTorch tensor |
|
包含子型別 |
|
布林值 |
|
標量整數 |
|
標量浮點數 |
|
字串 |
|
其中所有成員都是型別 |
|
一個值,可以是 None 或型別 |
|
鍵型別為 |
|
|
|
|
|
一個 |
|
子型別 |
與 Python 不同,TorchScript 函式中的每個變數都必須具有單一的靜態型別。這使得最佳化 TorchScript 函式更加容易。
示例(型別不匹配)
import torch
@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r
Traceback (most recent call last):
...
RuntimeError: ...
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
if x:
~~~~~
r = torch.rand(1)
~~~~~~~~~~~~~~~~~
else:
~~~~~
r = 4
~~~~~ <--- HERE
return r
and was used here:
else:
r = 4
return r
~ <--- HERE...
不支援的型別構造¶
TorchScript 不支援 typing 模組的所有特性和型別。其中一些是更基礎的東西,未來不太可能新增,而另一些則可能會在使用者需求足夠多且成為優先事項時新增。
TorchScript 中不支援 typing 模組的以下型別和特性。
專案 |
描述 |
|---|---|
|
|
未實現 |
|
未實現 |
|
未實現 |
|
未實現 |
|
未實現 |
|
對於模組屬性的類屬性註解支援此項,但對於函式不支援 |
|
TorchScript 不支援 |
|
|
|
類型別名 |
未實現 |
名義子型別 vs 結構子型別 |
名義型別正在開發中,但結構型別尚未開發 |
NewType |
不太可能實現 |
泛型 |
不太可能實現 |
本文件中未明確列出的 typing 模組中的任何其他功能均不受支援。
預設型別¶
預設情況下,TorchScript 函式的所有引數都假定為 Tensor。要指定 TorchScript 函式的引數是另一種型別,可以使用上面列出的 MyPy 風格的型別註解。
import torch
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
注意
也可以使用 typing 模組中的 Python 3 型別提示來註解型別。
import torch
from typing import Tuple
@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
空列表假定為 List[Tensor],空字典假定為 Dict[str, Tensor]。要例項化其他型別的空列表或字典,請使用Python 3 型別提示。
示例(Python 3 的型別註解)
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
# This annotates the list to be a `List[Tuple[int, float]]`
my_list: List[Tuple[int, float]] = []
for i in range(10):
my_list.append((i, x.item()))
my_dict: Dict[str, int] = {}
return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
Optional 型別細化¶
當在 if 語句的條件內進行與 None 的比較或在 assert 中檢查時,TorchScript 將細化型別為 Optional[T] 的變數的型別。編譯器可以處理與 and、or 和 not 組合的多個 None 檢查。對於未明確編寫的 if 語句的 else 塊,也會發生細化。
None 檢查必須在 if 語句的條件內;將 None 檢查賦值給變數並在 if 語句的條件中使用它不會細化檢查中變數的型別。只有區域性變數會被細化,像 self.x 這樣的屬性不會,必須賦值給區域性變數才能被細化。
示例(細化引數和區域性變數的型別)
import torch
import torch.nn as nn
from typing import Optional
class M(nn.Module):
z: Optional[int]
def __init__(self, z):
super().__init__()
# If `z` is None, its type cannot be inferred, so it must
# be specified (above)
self.z = z
def forward(self, x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
x = x + 1
# Refinement for an attribute by assigning it to a local
z = self.z
if y is not None and z is not None:
x = y + z
# Refinement via an `assert`
assert z is not None
x += z
return x
module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
TorchScript 類¶
警告
TorchScript 類支援是實驗性的。目前它最適合簡單的記錄類型別(可以將其視為帶有方法的 NamedTuple)。
如果使用 @torch.jit.script 進行註解,Python 類可以在 TorchScript 中使用,這類似於宣告 TorchScript 函式的方式
@torch.jit.script
class Foo:
def __init__(self, x, y):
self.x = x
def aug_add_x(self, inc):
self.x += inc
此子集受到限制
所有函式都必須是有效的 TorchScript 函式(包括
__init__())。類必須是 new-style 類,因為我們使用
__new__()和 pybind11 來構造它們。TorchScript 類是靜態型別的。成員只能透過在
__init__()方法中賦值給 self 來宣告。例如,在
__init__()方法之外賦值給self@torch.jit.script class Foo: def assign_x(self): self.x = torch.rand(2, 3)
將導致
RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
類主體中除了方法定義之外不允許有任何表示式。
不支援繼承或任何其他多型策略,除了繼承
object以指定 new-style 類。
定義類後,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 型別一樣互換使用
# Declare a TorchScript class
@torch.jit.script
class Pair:
def __init__(self, first, second):
self.first = first
self.second = second
@torch.jit.script
def sum_pair(p):
# type: (Pair) -> Tensor
return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
TorchScript 列舉¶
Python 列舉可以在 TorchScript 中使用,無需任何額外註解或程式碼
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
定義列舉後,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 型別一樣互換使用。列舉值的型別必須是 int、float 或 str。所有值必須具有相同的型別;不支援列舉值的異構型別。
具名元組¶
由 collections.namedtuple 產生的型別可以在 TorchScript 中使用。
import torch
import collections
Point = collections.namedtuple('Point', ['x', 'y'])
@torch.jit.script
def total(point):
# type: (Point) -> Tensor
return point.x + point.y
p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
可迭代物件¶
某些函式(例如 zip 和 enumerate)只能作用於可迭代型別。TorchScript 中的可迭代型別包括 Tensor、列表、元組、字典、字串、torch.nn.ModuleList 和 torch.nn.ModuleDict。
表示式¶
支援以下 Python 表示式。
字面量¶
True
False
None
'string literals'
"string literals"
3 # interpreted as int
3.4 # interpreted as a float
列表構建¶
空列表假定型別為 List[Tensor]。其他列表字面量的型別派生自成員的型別。有關更多詳細資訊,請參見預設型別。
[3, 4]
[]
[torch.rand(3), torch.rand(4)]
元組構建¶
(3, 4)
(3,)
算術運算子¶
a + b
a - b
a * b
a / b
a ^ b
a @ b
比較運算子¶
a == b
a != b
a < b
a > b
a <= b
a >= b
邏輯運算子¶
a and b
a or b
not b
下標和切片¶
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]
函式呼叫¶
呼叫內建函式
torch.rand(3, dtype=torch.int)
呼叫其他 script 函式
import torch
@torch.jit.script
def foo(x):
return x + 1
@torch.jit.script
def bar(x):
return foo(x)
方法呼叫¶
呼叫內建型別(如 tensor)的方法:x.mm(y)
對於模組,必須先編譯方法才能呼叫它們。TorchScript 編譯器在編譯其他方法時會遞迴編譯它看到的方法。預設情況下,編譯從 forward 方法開始。 forward 呼叫的任何方法都會被編譯,這些方法呼叫的任何方法也會被編譯,依此類推。要從 forward 以外的方法開始編譯,請使用 @torch.jit.export 裝飾器(forward 隱式標記為 @torch.jit.export)。
直接呼叫子模組(例如 self.resnet(input))等同於呼叫其 forward 方法(例如 self.resnet.forward(input))。
import torch
import torch.nn as nn
import torchvision
class MyModule(nn.Module):
def __init__(self):
super().__init__()
means = torch.tensor([103.939, 116.779, 123.68])
self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
resnet = torchvision.models.resnet18()
self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))
def helper(self, input):
return self.resnet(input - self.means)
def forward(self, input):
return self.helper(input)
# Since nothing in the model calls `top_level_method`, the compiler
# must be explicitly told to compile this method
@torch.jit.export
def top_level_method(self, input):
return self.other_helper(input)
def other_helper(self, input):
return input + 10
# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())
三元表示式¶
x if x > y else y
型別轉換¶
float(ten)
int(3.5)
bool(ten)
str(2)``
訪問模組引數¶
self.my_parameter
self.my_submodule.my_parameter
語句¶
TorchScript 支援以下型別的語句
簡單賦值¶
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b
列印語句¶
print("the result of an add:", a + b)
If 語句¶
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
除了布林值,浮點數、整數和 Tensor 也可以在條件語句中使用,並會被隱式轉換為布林值。
While 迴圈¶
a = 0
while a < 4:
print(a)
a += 1
帶 range 的 For 迴圈¶
x = 0
for i in range(10):
x *= i
遍歷元組的 For 迴圈¶
這些會展開迴圈,為元組的每個成員生成一個主體。主體必須對每個成員進行正確的型別檢查。
tup = (3, torch.rand(4))
for x in tup:
print(x)
遍歷常量 nn.ModuleList 的 For 迴圈¶
要在已編譯的方法中使用 nn.ModuleList,必須透過將屬性名稱新增到型別的 __constants__ 列表中來將其標記為常量。遍歷 nn.ModuleList 的 For 迴圈會在編譯時展開迴圈主體,並處理常量模組列表的每個成員。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
__constants__ = ['mods']
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
m = torch.jit.script(MyModule())
Break 和 Continue¶
for i in range(5):
if i == 1:
continue
if i == 3:
break
print(i)
Return¶
return a, b
變數解析¶
TorchScript 支援 Python 變數解析(即作用域)規則的一個子集。區域性變數的行為與 Python 中相同,但有一個限制:變數在透過函式的所有路徑上必須具有相同的型別。如果變數在 if 語句的不同分支上具有不同的型別,則在 if 語句結束之後使用它是錯誤的。
同樣,如果在函式中只沿某些路徑定義了變數,則不允許使用該變數。
示例
@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y)
Traceback (most recent call last):
...
RuntimeError: ...
y is not defined in the false branch...
@torch.jit.script...
def foo(x):
if x < 0:
~~~~~~~~~
y = 4
~~~~~ <--- HERE
print(y)
and was used here:
if x < 0:
y = 4
print(y)
~ <--- HERE...
非區域性變數在定義函式時在編譯時解析為 Python 值。然後使用Python 值的使用中描述的規則將這些值轉換為 TorchScript 值。
Python 值的使用¶
為了讓編寫 TorchScript 更方便,我們允許指令碼程式碼引用周圍作用域中的 Python 值。例如,每當引用 torch 時,TorchScript 編譯器實際上是在函式宣告時將其解析為 torch Python 模組。這些 Python 值不是 TorchScript 的一等公民。相反,它們在編譯時被“去糖化”為 TorchScript 支援的基本型別。這取決於編譯時引用的 Python 值的動態型別。本節描述了在 TorchScript 中訪問 Python 值時使用的規則。
函式¶
TorchScript 可以呼叫 Python 函式。在將模型逐步轉換為 TorchScript 時,此功能非常有用。可以逐個函式地將模型移動到 TorchScript,同時保留對 Python 函式的呼叫。這樣您就可以在轉換過程中逐步檢查模型的正確性。
- torch.jit.is_scripting()[source][source]¶
在編譯時返回 True,否則返回 False 的函式。這在使用 @unused 裝飾器時特別有用,可以將模型中尚未與 TorchScript 相容的程式碼保留下來。.. testcode
import torch @torch.jit.unused def unsupported_linear_op(x): return x def linear(x): if torch.jit.is_scripting(): return torch.linear(x) else: return unsupported_linear_op(x)
- 返回型別
在 Python 模組上查詢屬性¶
TorchScript 可以查詢模組上的屬性。透過這種方式訪問內建函式,例如 torch.add。這使得 TorchScript 可以呼叫其他模組中定義的函式。
Python 定義的常量¶
TorchScript 還提供了使用 Python 中定義的常量的方法。這些常量可以用於將超引數硬編碼到函式中,或定義通用常量。有兩種方法可以指定將 Python 值視為常量。
作為模組屬性查詢的值假定為常量
import math
import torch
@torch.jit.script
def fn():
return math.pi
ScriptModule 的屬性可以透過使用
Final[T]註解來標記為常量
import torch
import torch.nn as nn
class Foo(nn.Module):
# `Final` from the `typing_extensions` module can also be used
a : torch.jit.Final[int]
def __init__(self):
super().__init__()
self.a = 1 + 4
def forward(self, input):
return self.a + input
f = torch.jit.script(Foo())
支援的常量 Python 型別有
intfloatbooltorch.devicetorch.layouttorch.dtype包含支援型別的元組
可在 TorchScript for 迴圈中使用的
torch.nn.ModuleList
模組屬性¶
torch.nn.Parameter 包裝器和 register_buffer 可用於將 tensor 賦值給模組。如果其他賦值給已編譯模組的值的型別可以推斷,則這些值將被新增到已編譯的模組中。TorchScript 中所有可用的型別都可以用作模組屬性。Tensor 屬性在語義上與 buffer 相同。空列表和字典以及 None 值的型別無法推斷,必須透過 PEP 526 風格的類註解來指定。如果無法推斷型別且未明確註解,則不會將其作為屬性新增到生成的 ScriptModule 中。
示例
from typing import List, Dict
class Foo(nn.Module):
# `words` is initialized as an empty list, so its type must be specified
words: List[str]
# The type could potentially be inferred if `a_dict` (below) was not
# empty, but this annotation ensures `some_dict` will be made into the
# proper type
some_dict: Dict[str, int]
def __init__(self, a_dict):
super().__init__()
self.words = []
self.some_dict = a_dict
# `int`s can be inferred
self.my_int = 10
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))