
(beta) Building a Simple CPU Performance Profiler with FX

Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Not Verified

Author: James Reed

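In this tutorial, we are going to use FX to do the following:

  1. Capture PyTorch Python code in a way that lets us inspect it and gather statistics about the structure and execution of the code.

  2. Build a small class that serves as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

For this tutorial, we use the torchvision ResNet18 model for demonstration purposes.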

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

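Now that we have our model, we want to dig deeper into its performance. That is, for the following invocation, which parts of the model take the longest?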

input = torch.randn(5, 3, 224, 224)
output = rn18(input)

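A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.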
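For comparison, here is a minimal sketch of that manual approach. It uses a hypothetical two-stage function `toy_model` (plain Python, not part of the tutorial) in place of real model code, and `time.perf_counter` for the timestamps. Note that every region we want to measure needs its own pair of timestamps edited directly into the source.

```python
import time

def stage_a(xs):
    # Hypothetical stand-in for an expensive computation
    return [x * x for x in xs]

def stage_b(xs):
    # Hypothetical stand-in for a cheaper follow-up step
    return sum(xs)

def toy_model(xs, timings):
    # Manually added instrumentation: one timestamp pair per region
    t0 = time.perf_counter()
    ys = stage_a(xs)
    t1 = time.perf_counter()
    out = stage_b(ys)
    t2 = time.perf_counter()
    timings["stage_a"] = t1 - t0
    timings["stage_b"] = t2 - t1
    return out

timings = {}
result = toy_model(list(range(100_000)), timings)
for region, seconds in timings.items():
    print(f"{region}: {seconds:.6f} s")
```

Any change to which regions are measured means editing the function body again, which is exactly what becomes awkward for code we did not write.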

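That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy the model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.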

First, let's get some imports out of the way (we will use all of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

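Note

tabulate is an external library and is not a dependency of PyTorch. We use it to make the performance data easier to read. Please make sure you have installed it from your preferred Python package source.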

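Capturing the Model with Symbolic Tracing

Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.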

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc

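This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.com.tw/docs/master/fx.html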
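To make the Node structure concrete, the sketch below traces a small hypothetical module, `TinyNet` (not part of the tutorial), that contains one module call, one function call, and one method call, then walks its graph. For each node it prints the op kind, the target, the args (the incoming edges), and the users (the downstream nodes that consume it). It assumes only that `torch` is installed.

```python
import torch
import torch.fx

class TinyNet(torch.nn.Module):
    # Hypothetical example module: one submodule call, one function call,
    # and one method call
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        y = self.linear(x)
        y = torch.relu(y)
        return y.sum(dim=1)

gm = torch.fx.symbolic_trace(TinyNet())

op_kinds = []
for node in gm.graph.nodes:
    op_kinds.append(node.op)
    # node.args holds upstream Nodes (the graph's edges) or constants;
    # node.users lists the downstream Nodes that consume this node's value
    print(node.op, node.target, node.args, node.kwargs, list(node.users))
```

The `num_users` field in the ResNet18 graph printout above is the size of this same `users` collection.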

Creating a Profiling Interpreter

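Next, we are going to create a class that inherits from torch.fx.Interpreter. Although the GraphModule that symbolic_trace produces compiles Python code that runs when you call the GraphModule, an alternative way to run a GraphModule is to execute each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the graph node by node.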
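A quick way to see the two execution paths side by side: `gm.code` shows the Python source FX generated for the `GraphModule`, while `Interpreter(gm).run(...)` executes the same graph one node at a time. The small module here, `SmallNet`, is a hypothetical stand-in for ResNet18; the two results should agree.

```python
import torch
import torch.fx
from torch.fx import Interpreter

class SmallNet(torch.nn.Module):
    # Hypothetical stand-in for a larger model such as ResNet18
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.flatten(torch.relu(self.conv(x)), 1)

gm = torch.fx.symbolic_trace(SmallNet().eval())
print(gm.code)  # the Python source generated from the graph

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    compiled_out = gm(x)                      # runs the generated code
    interpreted_out = Interpreter(gm).run(x)  # walks the graph node by node
print(torch.allclose(compiled_out, interpreted_out))
```

Running node by node is slower than the generated code, but it gives us a hook around every single node, which is what a profiler needs.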

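By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model, and each part of the model, took during those runs.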
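Before the full profiler, here is a minimal sketch of the override pattern. A hypothetical `CountingInterpreter` intercepts `run_node`, delegates to the base class so the computation is unchanged, and tallies how many nodes of each op kind ran. Because the tally lives on the object, it accumulates across repeated calls, which is the same "invoke one or more times, then inspect" shape the profiler uses with timestamps instead of counts.

```python
import collections

import torch
import torch.fx
from torch.fx import Interpreter

class CountingInterpreter(Interpreter):
    # Hypothetical example: counts executed nodes by op kind
    def __init__(self, mod: torch.nn.Module):
        super().__init__(torch.fx.symbolic_trace(mod))
        self.op_counts = collections.Counter()

    def run_node(self, n: torch.fx.Node):
        # Delegate to the base class so the model computes the same result
        result = super().run_node(n)
        self.op_counts[n.op] += 1
        return result

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
interp = CountingInterpreter(model)
for _ in range(3):  # counts accumulate across calls
    interp.run(torch.randn(2, 4))
print(dict(interp.op_counts))
```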

Let's define ProfilingInterpreter:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
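The arithmetic inside `summary()` can be checked by hand. This sketch repeats it on made-up timings for two hypothetical nodes over two profiled runs: average each node's runtimes, divide by the average total runtime, and sort from slowest to fastest. It uses only the standard library, so `tabulate` is not needed here.

```python
import statistics

# Hypothetical measurements (seconds) from two profiled runs
total_runtime_sec = [0.050, 0.070]
runtimes_sec = {
    "conv1": [0.010, 0.014],
    "relu": [0.001, 0.003],
}

mean_total = statistics.mean(total_runtime_sec)  # about 0.06 s
rows = []
for name, runtimes in runtimes_sec.items():
    mean_runtime = statistics.mean(runtimes)
    pct_total = mean_runtime / mean_total * 100
    rows.append([name, mean_runtime, pct_total])

# Same ordering as summary(should_sort=True): slowest node first
rows.sort(key=lambda r: r[1], reverse=True)
for name, mean_runtime, pct in rows:
    print(f"{name:6s} {mean_runtime:.4f} s {pct:5.1f} %")
```

Because `run()` also times the interpreter's own bookkeeping between nodes, the per-node percentages in a real profile will typically sum to somewhat less than 100.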

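Note

We use Python's time.time function to take wall-clock timestamps and compare them. This is not the most accurate way to measure performance and only gives us a first-order approximation. We use this simple technique here purely for demonstration.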
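If a somewhat better clock is wanted, one common choice is `time.perf_counter`, which Python documents as monotonic and as having the highest available resolution, whereas `time.time` reports wall-clock time that can jump if the system clock is adjusted. The `timed` helper below is a hypothetical sketch, not part of the tutorial's profiler; `time.get_clock_info` reports each clock's properties on the current platform, so the printed values will vary by machine.

```python
import time

# Compare the properties of the two clocks on this platform
for clock in ("time", "perf_counter"):
    info = time.get_clock_info(clock)
    print(f"{clock}: monotonic={info.monotonic}, "
          f"adjustable={info.adjustable}, resolution={info.resolution}")

def timed(fn, *args):
    # Hypothetical helper: run fn once and return (result, elapsed seconds)
    t_start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - t_start

value, elapsed = timed(sum, range(1_000_000))
print(f"sum took {elapsed:.6f} s")
```

Swapping `time.time()` for `time.perf_counter()` in `run` and `run_node` would be a drop-in change, although Python-level timing still includes interpreter overhead either way.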

Investigating the Performance of ResNet18

We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                          0.00606203            10.1653
call_module    conv1                            0.00542283             9.09342
call_module    layer4_0_conv2                   0.00363326             6.09253
call_module    layer1_0_conv1                   0.00321054             5.38369
call_module    layer1_0_conv2                   0.00315976             5.29853
call_module    layer4_1_conv1                   0.00309634             5.19218
call_module    layer4_1_conv2                   0.00293541             4.92232
call_module    layer1_1_conv2                   0.00277257             4.64926
call_module    layer1_1_conv1                   0.00243998             4.09154
call_module    layer3_1_conv2                   0.00232077             3.89164
call_module    layer2_1_conv1                   0.0022316              3.74211
call_module    layer2_1_conv2                   0.00211406             3.54501
call_module    layer3_0_conv2                   0.00210214             3.52502
call_module    layer3_1_conv1                   0.00208902             3.50303
call_module    layer2_0_conv2                   0.0020225              3.39149
call_module    layer4_0_conv1                   0.00196433             3.29394
call_module    bn1                              0.00139451             2.33842
call_module    layer2_0_conv1                   0.00133061             2.23128
call_module    layer3_0_conv1                   0.00124478             2.08735
call_module    layer2_0_downsample_0            0.00108624             1.82148
call_module    layer4_0_downsample_0            0.000471592            0.790801
call_module    layer3_0_downsample_0            0.000450134            0.75482
call_function  add                              0.000432968            0.726034
call_function  add_1                            0.000419855            0.704045
call_module    relu                             0.000311613            0.522537
call_module    layer1_0_bn1                     0.000288248            0.483356
call_module    layer1_0_bn2                     0.000271082            0.454571
call_module    layer1_1_bn2                     0.000260115            0.43618
call_module    fc                               0.000248432            0.41659
call_function  add_3                            0.000231981            0.389004
call_module    layer2_1_bn2                     0.000171661            0.287855
call_module    layer2_1_bn1                     0.000165224            0.27706
call_module    layer2_0_downsample_1            0.000153065            0.256671
call_module    layer1_1_bn1                     0.000150442            0.252273
call_module    avgpool                          0.000132084            0.221488
call_module    layer3_1_bn2                     0.000120401            0.201898
call_module    layer3_1_bn1                     0.000115395            0.193502
call_module    layer4_1_bn2                     0.000115395            0.193502
call_module    layer4_0_bn2                     0.000113249            0.189904
call_module    layer1_0_relu                    9.799e-05              0.164317
call_module    layer3_0_bn2                     9.60827e-05            0.161119
call_module    layer1_0_relu_1                  9.58443e-05            0.160719
call_module    layer2_0_bn1                     9.10759e-05            0.152723
call_module    layer2_0_bn2                     9.08375e-05            0.152323
call_module    layer4_1_bn1                     8.58307e-05            0.143927
call_function  add_2                            8.27312e-05            0.13873
call_module    layer1_1_relu_1                  8.2016e-05             0.137531
call_function  add_5                            7.93934e-05            0.133133
call_function  add_7                            7.67708e-05            0.128735
call_module    layer4_0_downsample_1            7.53403e-05            0.126336
call_module    layer4_0_bn1                     7.43866e-05            0.124737
call_module    layer1_1_relu                    7.29561e-05            0.122338
call_module    layer3_0_downsample_1            7.03335e-05            0.117941
call_module    layer3_0_bn1                     6.86646e-05            0.115142
call_function  add_6                            6.67572e-05            0.111944
call_module    layer4_0_relu                    6.36578e-05            0.106746
call_function  add_4                            5.74589e-05            0.0963514
call_module    layer4_1_relu                    5.34058e-05            0.0895549
call_module    layer2_0_relu_1                  5.22137e-05            0.0875559
call_module    layer4_0_relu_1                  5.05447e-05            0.0847573
call_module    layer2_1_relu                    4.69685e-05            0.0787603
call_module    layer2_1_relu_1                  4.69685e-05            0.0787603
call_module    layer2_0_relu                    4.673e-05              0.0783605
call_module    layer4_1_relu_1                  4.33922e-05            0.0727633
call_module    layer3_1_relu                    3.83854e-05            0.0643676
call_module    layer3_1_relu_1                  3.74317e-05            0.0627684
call_module    layer3_0_relu                    3.69549e-05            0.0619688
call_module    layer3_0_relu_1                  3.69549e-05            0.0619688
call_function  flatten                          2.88486e-05            0.0483756
placeholder    x                                1.57356e-05            0.0263867
output         output                           9.77516e-06            0.0163917

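There are two things we should call out here:

  • MaxPool2d takes up the most time. This is a known issue: https://github.com/pytorch/pytorch/issues/51393

  • BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize it in the Conv-BN Fusion with FX tutorial.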
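To check observations like these at the level of layer types rather than individual nodes, per-node timings can be grouped by the class of the submodule each `call_module` node invokes. The sketch below is self-contained, so it uses a small hypothetical conv/BN/ReLU model and its own minimal `NodeTimer` interpreter (also hypothetical) instead of `rn18` and `ProfilingInterpreter`; `get_submodule` resolves a node's target string, such as `"1"`, to the module object.

```python
import collections
import time

import torch
import torch.fx
from torch.fx import Interpreter

class NodeTimer(Interpreter):
    # Hypothetical minimal timer: total seconds spent in each node
    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.seconds = collections.defaultdict(float)

    def run_node(self, n: torch.fx.Node):
        t_start = time.perf_counter()
        result = super().run_node(n)
        self.seconds[n] += time.perf_counter() - t_start
        return result

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, 3, padding=1),
    torch.nn.BatchNorm2d(16),
).eval()
gm = torch.fx.symbolic_trace(model)
timer = NodeTimer(gm)
with torch.no_grad():
    timer.run(torch.randn(4, 3, 64, 64))

# Sum node times per module class, e.g. all BatchNorm2d nodes together
by_type = collections.defaultdict(float)
for node, secs in timer.seconds.items():
    if node.op == "call_module":
        by_type[type(gm.get_submodule(node.target)).__name__] += secs
for name, secs in sorted(by_type.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:12s} {secs:.6f} s")
```

The same grouping loop could be applied to `interp.runtimes_sec` from the ResNet18 run, using `interp.module.get_submodule(node.target)` to look up each module.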

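Conclusion

As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to share your feedback.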

Total running time of the script: ( 0 minutes 0.292 seconds)
