快捷方式

torch.jit.freeze

torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)[][]

凍結 ScriptModule,將子模組和屬性內聯為常量。

凍結 ScriptModule 將克隆它,並嘗試將克隆模組的子模組、引數和屬性內聯為 TorchScript IR Graph 中的常量。預設情況下,將保留 forward 方法,以及 preserved_attrs 中指定的屬性和方法。此外,在保留方法中修改的任何屬性也將被保留。

當前凍結只接受處於評估模式(eval mode)的 ScriptModule。

凍結應用通用最佳化,無論機器如何,都能加速您的模型。要使用特定於伺服器的設定進一步最佳化,請在凍結後執行 optimize_for_inference

引數
  • mod (ScriptModule) – 要凍結的模組

  • preserved_attrs (Optional[List[str]]) – 除 forward 方法外要保留的屬性列表。在保留方法中修改的屬性也將被保留。

  • optimize_numerics (bool) – 如果為 True,將執行一組不嚴格保留數值的最佳化過程。最佳化的完整詳細資訊可在 torch.jit.run_frozen_optimizations 中找到。

返回

凍結的 ScriptModule

示例(凍結帶有 Parameter 的簡單模組)

    def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)

示例(凍結帶有保留屬性的模組)

    def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)

注意

也支援凍結子模組屬性:frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])

注意

如果您不確定為什麼某個屬性沒有被內聯為常量,您可以在 frozen_module.forward.graph 上執行 dump_alias_db,檢視凍結是否檢測到該屬性正在被修改。

注意

由於凍結將權重變為常量並移除模組層級結構,因此 to 和其他用於操作裝置或 dtype 的 nn.Module 方法不再起作用。作為一種變通方法,您可以透過在 torch.jit.load 中指定 map_location 來重新對映裝置,但特定於裝置的邏輯可能已融入模型中。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源