快捷方式

torch.jit.trace_module

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[源][源]

追蹤一個模組並返回一個可執行的 ScriptModule,該模組將使用即時編譯進行最佳化。

當一個模組被傳遞給 torch.jit.trace 時,只會執行和追蹤 forward 方法。使用 trace_module,你可以指定一個方法名到示例輸入的字典來追蹤(見下面的 inputs 引數)。

有關追蹤的更多資訊,請參閱 torch.jit.trace

引數
  • mod (torch.nn.Module) – 一個包含 inputs 中指定方法名稱的 torch.nn.Module。給定的方法將作為單個 ScriptModule 的一部分進行編譯。

  • inputs (dict) – 一個字典,包含按 mod 中方法名稱索引的示例輸入。這些輸入在追蹤時將傳遞給與輸入鍵對應的方法。例如:{ 'forward' : example_forward_input, 'method2': example_method2_input}

關鍵字引數
  • check_trace (bool, 可選) – 檢查透過追蹤程式碼執行的相同輸入是否產生相同的輸出。預設值:True。如果您的網路包含非確定性操作,或者您確定網路正確儘管檢查器失敗,您可能希望停用此選項。

  • check_inputs (list of dicts, optional) – 一個包含輸入引數字典的列表,用於對照預期結果檢查追蹤。每個元組等同於在 inputs 中指定的一組輸入引數。為了獲得最佳結果,請傳入一組具有代表性的形狀和型別輸入的檢查輸入,這些輸入代表您期望網路看到的輸入空間。如果未指定,則使用原始 inputs 進行檢查。

  • check_tolerance (float, optional) – 在檢查過程中使用的浮點比較容差。這可用於在結果由於已知原因(例如運算子融合)出現數值差異時放寬檢查器嚴格性。

  • example_inputs_is_kwarg (bool, 可選) – 此引數指示示例輸入是否為關鍵字引數包。預設值:False

返回

一個 ScriptModule 物件,其中包含一個具有追蹤程式碼的 forward 方法。當 func 是一個 torch.nn.Module 時,返回的 ScriptModule 將具有與 func 相同的子模組和引數集。

示例 (追蹤具有多個方法的模組)

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
module = torch.jit.trace_module(n, inputs)

文件

查閱 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源