(Beta) Building a Convolution/Batch Norm Fuser in FX
Created: Mar 4, 2021 | Last updated: Jan 16, 2024 | Last verified: Nov 5, 2024
Author: Horace He
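In this tutorial, we are going to use FX, a toolkit for composable function transformations of PyTorch, to do the following:
1) Find conv/batch norm patterns in the data dependencies.
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.
Note that this optimization only works for models in inference mode (i.e. model.eval()).
We will be building the fuser that exists here: https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py
First, let's get some imports out of the way (we will use them later in the code).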
from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn
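For this tutorial, we are going to create a model consisting of convolutions and batch norms. Note that this model has some tricky components: some of the conv/batch norm patterns are hidden within Sequentials, and one of the BatchNorms is wrapped in another Module.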
class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M()
model.eval()
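Fusing Convolution with Batch Norm
One of the primary challenges with trying to automatically fuse convolution and batch norm in PyTorch is that PyTorch does not provide an easy way of accessing the computational graph. FX resolves this problem by symbolically tracing the operations that are actually called, so that we can track the computations through the forward call, whether they are nested within Sequential modules or wrapped in a user-defined module.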
traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)
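This gives us a graph representation of our model. Note that both the modules hidden within the Sequential and the wrapped Module have been inlined into the graph. This is the default level of abstraction, but it can be configured by the pass writer. More information can be found in the FX overview: https://pytorch.com.tw/docs/master/fx.html#module-torch.fx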
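To make the inlining concrete, here is a minimal sketch that lists the `op` and `target` of every node in a traced graph. The `Tiny` module is a hypothetical stand-in for the tutorial's model (an assumption made only to keep the example self-contained); it hides a batch norm inside a Sequential.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class Tiny(nn.Module):
    """Hypothetical stand-in: a conv followed by a batch norm nested in a Sequential."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(nn.BatchNorm2d(1))

    def forward(self, x):
        return self.nested(self.conv(x))


traced = fx.symbolic_trace(Tiny().eval())

# Each node records what kind of operation it is (`op`) and what it calls (`target`).
summary = [(node.op, str(node.target)) for node in traced.graph.nodes]
for op, target in summary:
    print(f"{op:<12} {target}")
```

The nested batch norm shows up as its own `call_module` node with the qualified name `nested.0`, rather than as a single opaque call to `nested`.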
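Fusing Convolution with Batch Norm
Unlike some other fusions, fusing convolution with batch norm does not require any new operators. Instead, because batch norm during inference consists of a pointwise add and multiply, these operations can be "baked" into the weights of the preceding convolution. This allows us to remove the batch norm from our model entirely! Read https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The code here is copied from https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py for clarity.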
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
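The arithmetic in `fuse_conv_bn_weights` amounts to scaling each output channel of the weight by `gamma / sqrt(var + eps)` and setting the bias to `(b - mean) * gamma / sqrt(var + eps) + beta`. The following sketch checks that identity numerically with the functional API. The shapes and the random statistics are illustrative assumptions, not values from the tutorial.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes (assumptions): 3 input channels, 4 output channels, 3x3 kernel.
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
gamma, beta = torch.randn(4), torch.randn(4)
mean, var = torch.randn(4), torch.rand(4) + 0.5  # keep the variance positive
eps = 1e-5

x = torch.randn(2, 3, 8, 8)

# Reference: convolution followed by batch norm in inference mode.
reference = F.batch_norm(F.conv2d(x, w, b), mean, var, gamma, beta,
                         training=False, eps=eps)

# Folded: scale each output channel of the weight, then shift the bias.
scale = gamma * torch.rsqrt(var + eps)
w_folded = w * scale.reshape(-1, 1, 1, 1)
b_folded = (b - mean) * scale + beta
folded = F.conv2d(x, w_folded, b_folded)

max_abs_diff = (reference - folded).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two results should agree up to floating-point rounding, which is why the fused model can drop the batch norm without changing its inference output.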
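FX Fusion Pass
Now that we have our computational graph as well as a method for fusing convolution and batch norm, all that remains is to iterate over the FX graph and apply the desired fusions.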
def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a ``qualname`` into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)

def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    # The first step of most FX passes is to symbolically trace our model to
    # obtain a `GraphModule`. This is a representation of our original model
    # that is functionally identical to our original model, except that we now
    # also have a graph representation of our forward pass.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # The primary representations for working with FX are the `Graph` and the
    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
    # `Graph` is also what generates `GraphModule.code`.
    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
    # iterate through all of the operations in our graph, we iterate over each
    # `Node` in our `Graph`.
    for node in fx_model.graph.nodes:
        # The FX IR contains several types of nodes, which generally represent
        # call sites to modules, functions, or methods. The type of node is
        # determined by `Node.op`.
        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called. Here, we check `Node.target` to see if it's a
        # batch norm module, and then check `Node.args[0].target` to see if the
        # input `Node` is a convolution.
        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(node.args[0])
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model
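The role of `_parent_name` and `replace_node_module` is easiest to see in isolation: a qualified name such as `nested.1` is split into its parent path and final attribute, and the replacement is assigned on the parent module. The sketch below illustrates this with a hypothetical `split_qualname` helper (written with `str.rpartition`, an assumed equivalent of the splitting above) and a hypothetical `Outer` module.

```python
import torch.nn as nn


def split_qualname(target: str):
    """Hypothetical helper: 'foo.bar.baz' -> ('foo.bar', 'baz'); 'conv1' -> ('', 'conv1')."""
    parent, _, name = target.rpartition(".")
    return parent, name


class Outer(nn.Module):
    """Hypothetical module with one top-level and one nested submodule."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1))


outer = Outer()
# `named_modules()` maps the empty string to the root module itself,
# which is why a top-level name such as 'conv1' resolves to parent ''.
modules = dict(outer.named_modules())

for target in ["conv1", "nested.1"]:
    parent_name, name = split_qualname(target)
    setattr(modules[parent_name], name, nn.Identity())

print(outer)
```

Note that the `modules` dictionary is a snapshot: after the `setattr`, its entries for the replaced names still refer to the old modules.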
Note
We make some simplifications here for demonstration purposes, such as only matching 2D convolutions. See https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py for a more usable pass.
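The graph-editing steps of the pass (redirect the users of a node, erase the node, lint, recompile) can also be tried on a smaller case. This is a minimal sketch under stated assumptions: the `WithNoOp` module is hypothetical, and removing `nn.Identity` calls stands in for removing a folded batch norm.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class WithNoOp(nn.Module):
    """Hypothetical module whose Identity call can be removed without changing the output."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)
        self.noop = nn.Identity()

    def forward(self, x):
        return self.noop(self.lin(x))


gm = fx.symbolic_trace(WithNoOp())
modules = dict(gm.named_modules())

# Iterate over a copy of the node list because nodes are erased along the way.
for node in list(gm.graph.nodes):
    if node.op == "call_module" and isinstance(modules[node.target], nn.Identity):
        # Point every user of the Identity call at its input instead...
        node.replace_all_uses_with(node.args[0])
        # ...after which the node has no users and can be erased.
        gm.graph.erase_node(node)

gm.graph.lint()   # sanity-check the edited graph
gm.recompile()    # regenerate `gm.code` from the edited graph

remaining_targets = [n.target for n in gm.graph.nodes if n.op == "call_module"]
print(gm.code)
```

The `noop` submodule is still an attribute of `gm` afterwards; only the call to it has been removed from the graph.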
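Testing out our Fusion Pass
We can now run this fusion pass on our initial toy model and verify that the results match. In addition, we can print out the code for our fused model and verify that no batch norms remain.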
fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
# `torch.testing.assert_allclose` is deprecated; `assert_close` is its replacement.
torch.testing.assert_close(fused_model(inp), model(inp))
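Besides reading the printed code, the absence of batch norms can be checked programmatically by counting the module types that the graph calls. The sketch below uses a hypothetical helper, `called_module_types`, and demonstrates it on a small unfused Sequential so that it stays self-contained; applied to `fused_model` above, the `BatchNorm2d` count would be expected to be zero.

```python
from collections import Counter

import torch.fx as fx
import torch.nn as nn


def called_module_types(gm: fx.GraphModule) -> Counter:
    """Hypothetical helper: count the `call_module` nodes in a graph by module class name."""
    modules = dict(gm.named_modules())
    return Counter(
        type(modules[node.target]).__name__
        for node in gm.graph.nodes
        if node.op == "call_module"
    )


# Small unfused example (an assumption for illustration, not the tutorial's model).
net = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU()).eval()
counts = called_module_types(fx.symbolic_trace(net))
print(counts)
```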
Benchmarking our Fusion on ResNet18
We can test our fusion pass on a larger model like ResNet18 and see how much this pass improves inference performance.
import torchvision.models as models
import time

rn18 = models.resnet18()
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)

def benchmark(model, iters=20):
    for _ in range(10):
        model(inp)
    begin = time.time()
    for _ in range(iters):
        model(inp)
    return str(time.time()-begin)

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
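The `benchmark` function above reports the total wall-clock time as a string. As a possible variation, and not what the tutorial itself uses, the sketch below runs the model under `torch.no_grad()`, times with `time.perf_counter()`, and returns the mean seconds per iteration as a float. The helper name `benchmark_seconds`, its parameters, and the toy model are assumptions for illustration.

```python
import time

import torch
import torch.nn as nn


def benchmark_seconds(model, inp, warmup=3, iters=5):
    """Hypothetical helper: mean wall-clock seconds per forward pass."""
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        for _ in range(warmup):  # warm-up runs are not timed
            model(inp)
        start = time.perf_counter()
        for _ in range(iters):
            model(inp)
        elapsed = time.perf_counter() - start
    return elapsed / iters


# Toy model and input (assumptions) to keep the example quick to run.
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
per_iter = benchmark_seconds(toy, torch.randn(2, 3, 32, 32))
print(f"{per_iter * 1e3:.3f} ms per iteration")
```

Returning a number rather than a string makes it easier to compute a speedup ratio between the unfused and fused models.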
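As we saw earlier, the output of our FX transformation is ("torchscriptable") PyTorch code, so we can easily `jit.script` the output to try to improve performance even further. In this way, our FX model transformation composes with TorchScript without any issues.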
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))
Conclusion
As we can see, using FX we can easily write static graph transformations on PyTorch code.

Since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.