快捷方式

TorchScript 語言參考

TorchScript 是 Python 的靜態型別子集,可以直接編寫(使用 @torch.jit.script 裝飾器)或透過跟蹤從 Python 程式碼自動生成。使用跟蹤時,程式碼會自動轉換為 Python 的這個子集,僅記錄 tensor 上的實際運算元,而簡單地執行並丟棄周圍的其他 Python 程式碼。

使用 @torch.jit.script 裝飾器直接編寫 TorchScript 時,程式設計師必須只使用 TorchScript 支援的 Python 子集。本節文件介紹了 TorchScript 中支援的內容,就像它是獨立語言的參考一樣。本參考中未提及的任何 Python 特性都不是 TorchScript 的一部分。有關可用 PyTorch tensor 方法、模組和函式的完整參考,請參見 內建函式

作為 Python 的子集,任何有效的 TorchScript 函式也是有效的 Python 函式。這使得停用 TorchScript 並使用標準的 Python 工具(例如 pdb)除錯函式成為可能。反之則不然:許多有效的 Python 程式並不是有效的 TorchScript 程式。相反,TorchScript 專門關注表示 PyTorch 中神經網路模型所需的 Python 特性。

型別

TorchScript 與完整 Python 語言最大的區別在於,TorchScript 只支援表達神經網路模型所需的一小部分型別。具體來說,TorchScript 支援以下型別:

型別

描述

Tensor

任何 dtype、維度或後端的 PyTorch tensor

Tuple[T0, T1, ..., TN]

包含子型別 T0T1 等的元組(例如 Tuple[Tensor, Tensor]

bool

布林值

int

標量整數

float

標量浮點數

str

字串

List[T]

其中所有成員都是型別 T 的列表

Optional[T]

一個值,可以是 None 或型別 T

Dict[K, V]

鍵型別為 K、值型別為 V 的字典。只允許將 strintfloat 作為鍵型別。

T

一個 TorchScript 類

E

一個 TorchScript 列舉

NamedTuple[T0, T1, ...]

一個 collections.namedtuple 元組型別

Union[T0, T1, ...]

子型別 T0T1 等之一

與 Python 不同,TorchScript 函式中的每個變數都必須具有單一的靜態型別。這使得最佳化 TorchScript 函式更加容易。

示例(型別不匹配)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...

不支援的型別構造

TorchScript 不支援 typing 模組的所有特性和型別。其中一些是更基礎的東西,未來不太可能新增,而另一些則可能會在使用者需求足夠多且成為優先事項時新增。

TorchScript 中不支援 typing 模組的以下型別和特性。

專案

描述

typing.Any

typing.Any 目前正在開發中,尚未釋出

typing.NoReturn

未實現

typing.Sequence

未實現

typing.Callable

未實現

typing.Literal

未實現

typing.ClassVar

未實現

typing.Final

對於模組屬性的類屬性註解支援此項,但對於函式不支援

typing.AnyStr

TorchScript 不支援 bytes,因此不使用此型別

typing.overload

typing.overload 目前正在開發中,尚未釋出

類型別名

未實現

名義子型別 vs 結構子型別

名義型別正在開發中,但結構型別尚未開發

NewType

不太可能實現

泛型

不太可能實現

本文件中未明確列出的 typing 模組中的任何其他功能均不受支援。

預設型別

預設情況下,TorchScript 函式的所有引數都假定為 Tensor。要指定 TorchScript 函式的引數是另一種型別,可以使用上面列出的 MyPy 風格的型別註解。

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用 typing 模組中的 Python 3 型別提示來註解型別。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

空列表假定為 List[Tensor],空字典假定為 Dict[str, Tensor]。要例項化其他型別的空列表或字典,請使用Python 3 型別提示

示例(Python 3 的型別註解)

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

Optional 型別細化

當在 if 語句的條件內進行與 None 的比較或在 assert 中檢查時,TorchScript 將細化型別為 Optional[T] 的變數的型別。編譯器可以處理與 andornot 組合的多個 None 檢查。對於未明確編寫的 if 語句的 else 塊,也會發生細化。

None 檢查必須在 if 語句的條件內;將 None 檢查賦值給變數並在 if 語句的條件中使用它不會細化檢查中變數的型別。只有區域性變數會被細化,像 self.x 這樣的屬性不會,必須賦值給區域性變數才能被細化。

示例(細化引數和區域性變數的型別)

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 類

警告

TorchScript 類支援是實驗性的。目前它最適合簡單的記錄類型別(可以將其視為帶有方法的 NamedTuple)。

如果使用 @torch.jit.script 進行註解,Python 類可以在 TorchScript 中使用,這類似於宣告 TorchScript 函式的方式

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

此子集受到限制

  • 所有函式都必須是有效的 TorchScript 函式(包括 __init__())。

  • 類必須是 new-style 類,因為我們使用 __new__() 和 pybind11 來構造它們。

  • TorchScript 類是靜態型別的。成員只能透過在 __init__() 方法中賦值給 self 來宣告。

    例如,在 __init__() 方法之外賦值給 self

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    將導致

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
  • 類主體中除了方法定義之外不允許有任何表示式。

  • 不支援繼承或任何其他多型策略,除了繼承 object 以指定 new-style 類。

定義類後,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 型別一樣互換使用

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 列舉

Python 列舉可以在 TorchScript 中使用,無需任何額外註解或程式碼

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定義列舉後,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 型別一樣互換使用。列舉值的型別必須是 intfloatstr。所有值必須具有相同的型別;不支援列舉值的異構型別。

具名元組

collections.namedtuple 產生的型別可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

可迭代物件

某些函式(例如 zipenumerate)只能作用於可迭代型別。TorchScript 中的可迭代型別包括 Tensor、列表、元組、字典、字串、torch.nn.ModuleListtorch.nn.ModuleDict

表示式

支援以下 Python 表示式。

字面量

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表構建

空列表假定型別為 List[Tensor]。其他列表字面量的型別派生自成員的型別。有關更多詳細資訊,請參見預設型別

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元組構建

(3, 4)
(3,)

字典構建

空字典假定型別為 Dict[str, Tensor]。其他字典字面量的型別派生自成員的型別。有關更多詳細資訊,請參見預設型別

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

變數

有關變數如何解析,請參見變數解析

my_variable_name

算術運算子

a + b
a - b
a * b
a / b
a ^ b
a @ b

比較運算子

a == b
a != b
a < b
a > b
a <= b
a >= b

邏輯運算子

a and b
a or b
not b

下標和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函式呼叫

呼叫內建函式

torch.rand(3, dtype=torch.int)

呼叫其他 script 函式

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法呼叫

呼叫內建型別(如 tensor)的方法:x.mm(y)

對於模組,必須先編譯方法才能呼叫它們。TorchScript 編譯器在編譯其他方法時會遞迴編譯它看到的方法。預設情況下,編譯從 forward 方法開始。 forward 呼叫的任何方法都會被編譯,這些方法呼叫的任何方法也會被編譯,依此類推。要從 forward 以外的方法開始編譯,請使用 @torch.jit.export 裝飾器(forward 隱式標記為 @torch.jit.export)。

直接呼叫子模組(例如 self.resnet(input))等同於呼叫其 forward 方法(例如 self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表示式

x if x > y else y

型別轉換

float(ten)
int(3.5)
bool(ten)
str(2)``

訪問模組引數

self.my_parameter
self.my_submodule.my_parameter

語句

TorchScript 支援以下型別的語句

簡單賦值

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配賦值

a, b = tuple_or_list
a, b, *c = a_tuple

多重賦值

a = b, c = tup

If 語句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除了布林值,浮點數、整數和 Tensor 也可以在條件語句中使用,並會被隱式轉換為布林值。

While 迴圈

a = 0
while a < 4:
    print(a)
    a += 1

帶 range 的 For 迴圈

x = 0
for i in range(10):
    x *= i

遍歷元組的 For 迴圈

這些會展開迴圈,為元組的每個成員生成一個主體。主體必須對每個成員進行正確的型別檢查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

遍歷常量 nn.ModuleList 的 For 迴圈

要在已編譯的方法中使用 nn.ModuleList,必須透過將屬性名稱新增到型別的 __constants__ 列表中來將其標記為常量。遍歷 nn.ModuleList 的 For 迴圈會在編譯時展開迴圈主體,並處理常量模組列表的每個成員。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

Break 和 Continue

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

Return

return a, b

變數解析

TorchScript 支援 Python 變數解析(即作用域)規則的一個子集。區域性變數的行為與 Python 中相同,但有一個限制:變數在透過函式的所有路徑上必須具有相同的型別。如果變數在 if 語句的不同分支上具有不同的型別,則在 if 語句結束之後使用它是錯誤的。

同樣,如果在函式中只沿某些路徑定義了變數,則不允許使用該變數。

示例

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

非區域性變數在定義函式時在編譯時解析為 Python 值。然後使用Python 值的使用中描述的規則將這些值轉換為 TorchScript 值。

Python 值的使用

為了讓編寫 TorchScript 更方便,我們允許指令碼程式碼引用周圍作用域中的 Python 值。例如,每當引用 torch 時,TorchScript 編譯器實際上是在函式宣告時將其解析為 torch Python 模組。這些 Python 值不是 TorchScript 的一等公民。相反,它們在編譯時被“去糖化”為 TorchScript 支援的基本型別。這取決於編譯時引用的 Python 值的動態型別。本節描述了在 TorchScript 中訪問 Python 值時使用的規則。

函式

TorchScript 可以呼叫 Python 函式。在將模型逐步轉換為 TorchScript 時,此功能非常有用。可以逐個函式地將模型移動到 TorchScript,同時保留對 Python 函式的呼叫。這樣您就可以在轉換過程中逐步檢查模型的正確性。

torch.jit.is_scripting()[source][source]

在編譯時返回 True,否則返回 False 的函式。這在使用 @unused 裝飾器時特別有用,可以將模型中尚未與 TorchScript 相容的程式碼保留下來。.. testcode

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
    if torch.jit.is_scripting():
        return torch.linear(x)
    else:
        return unsupported_linear_op(x)
返回型別

bool

torch.jit.is_tracing()[source][source]

返回布林值。

在跟蹤時(如果在使用 torch.jit.trace 跟蹤程式碼期間呼叫了某個函式)返回 True,否則返回 False

在 Python 模組上查詢屬性

TorchScript 可以查詢模組上的屬性。透過這種方式訪問內建函式,例如 torch.add。這使得 TorchScript 可以呼叫其他模組中定義的函式。

Python 定義的常量

TorchScript 還提供了使用 Python 中定義的常量的方法。這些常量可以用於將超引數硬編碼到函式中,或定義通用常量。有兩種方法可以指定將 Python 值視為常量。

  1. 作為模組屬性查詢的值假定為常量

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. ScriptModule 的屬性可以透過使用 Final[T] 註解來標記為常量

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支援的常量 Python 型別有

  • int

  • float

  • bool

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支援型別的元組

  • 可在 TorchScript for 迴圈中使用的 torch.nn.ModuleList

模組屬性

torch.nn.Parameter 包裝器和 register_buffer 可用於將 tensor 賦值給模組。如果其他賦值給已編譯模組的值的型別可以推斷,則這些值將被新增到已編譯的模組中。TorchScript 中所有可用的型別都可以用作模組屬性。Tensor 屬性在語義上與 buffer 相同。空列表和字典以及 None 值的型別無法推斷,必須透過 PEP 526 風格的類註解來指定。如果無法推斷型別且未明確註解,則不會將其作為屬性新增到生成的 ScriptModule 中。

示例

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源