torch.jit.trace_module¶
- torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[源][源]¶
追蹤一個模組並返回一個可執行的
ScriptModule,該模組將使用即時編譯進行最佳化。當一個模組被傳遞給
torch.jit.trace時,只會執行和追蹤forward方法。使用trace_module,你可以指定一個方法名到示例輸入的字典來追蹤(見下面的inputs引數)。有關追蹤的更多資訊,請參閱
torch.jit.trace。- 引數
mod (torch.nn.Module) – 一個包含
inputs中指定方法名稱的torch.nn.Module。給定的方法將作為單個 ScriptModule 的一部分進行編譯。inputs (dict) – 一個字典,包含按
mod中方法名稱索引的示例輸入。這些輸入在追蹤時將傳遞給與輸入鍵對應的方法。例如:{ 'forward' : example_forward_input, 'method2': example_method2_input}
- 關鍵字引數
check_trace (
bool, 可選) – 檢查透過追蹤程式碼執行的相同輸入是否產生相同的輸出。預設值:True。如果您的網路包含非確定性操作,或者您確定網路正確儘管檢查器失敗,您可能希望停用此選項。check_inputs (list of dicts, optional) – 一個包含輸入引數字典的列表,用於對照預期結果檢查追蹤。每個元組等同於在
inputs中指定的一組輸入引數。為了獲得最佳結果,請傳入一組具有代表性的形狀和型別輸入的檢查輸入,這些輸入代表您期望網路看到的輸入空間。如果未指定,則使用原始inputs進行檢查。check_tolerance (float, optional) – 在檢查過程中使用的浮點比較容差。這可用於在結果由於已知原因(例如運算子融合)出現數值差異時放寬檢查器嚴格性。
example_inputs_is_kwarg (
bool, 可選) – 此引數指示示例輸入是否為關鍵字引數包。預設值:False。
- 返回
一個
ScriptModule物件,其中包含一個具有追蹤程式碼的forward方法。當func是一個torch.nn.Module時,返回的ScriptModule將具有與func相同的子模組和引數集。
示例 (追蹤具有多個方法的模組)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight} module = torch.jit.trace_module(n, inputs)