編寫自己的量化張量¶
torchao 中的量化構建在張量子類的基礎上。它們是 torchao 的主要擴充套件點,用於使用低精度計算提供靈活的推理和訓練支援,同時與重要的 PyTorch 功能(如 torch.compile、autograd 和分散式原語)結合使用。
在本教程中,我們將強調利用張量子類相對於模組替換的優勢,並透過一個簡單的示例來演示如何使用這種方法表達量化。
什麼是張量子類?¶
張量子類只是繼承自 torch.Tensor 的類。它們允許使用者在模型中現有操作之間插入自定義計算邏輯,以便像頂層 torch 名稱空間中的 torch.add 等函式能夠繼續無縫工作。
張量子類方法的一個顯而易見的替代方案是模組替換:例如,簡單地將模型中的所有 nn.Linear 模組替換為您自定義的 Int8QuantizedLinear 模組。與這種方法相比,使用張量子類有幾個重要的優勢:
更細粒度的整合點。 模組替換在模組級別攔截計算,因此不適用於依賴 torch 函式或原生模組變體(例如,略微修改過的 nn.Linear 版本)的模型。相比之下,由於張量子類在函式/操作級別攔截計算,因此只要使用相同的函式/操作,我們就能夠對模型進行量化。
更好的可組合性。 使用模組替換組合多個功能很笨拙。例如,組合兩個現有的 Int8QuantizedLinear 和 DistributedLinear 模組需要使用者建立另一個線性類來複制這些功能。張量子類透過簡單地將一個子類包裝在另一個子類中來繞過這個問題。如果外部張量(例如 DTensor)知道內部張量已量化,這也可以提供效能優勢,從而可以使用更少的網路和記憶體頻寬執行昂貴的全收集操作。
重用 PyTorch 元件。 使用張量子類來表達量化是很自然的,因為量化張量只是具有不同 dtype 的 torch.Tensors。模型結構不會改變(nn.Linears 仍然是 nn.Linears),因此後續的最佳化過程也可以保持與之前完全相同。
在本教程的其餘部分,我們將透過一個示例來演示如何使用這兩種方法表達量化。有關張量子類的更多閱讀資料,請參閱:
使用模組替換進行量化¶
我們首先以一個簡單的示例開始,演示如何使用模組替換實現 int8 對稱權重唯一量化。所有程式碼都可以在此示例指令碼中找到。我們將使用以下函式將 float32 張量量化為 int8 張量
from typing import Tuple
import torch
def int8_symmetric_quantize(
fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Symmetrically quantize the torch.float32 tensor into torch.int8.
Return a 2-tuple of (quantized value, scale).
input: dimensions=[M, N], dtype=torch.float32
output: dimensions=[M, N], dtype=torch.int8
scale: dimensions=[M, 1], dtype=torch.float32
"""
quant_min = -128
quant_max = 127
min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = scale.view(fp32_tensor.shape[0], -1)
out = torch.round(fp32_tensor * (1.0 / scale))
out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
return out, scale
接下來,我們將建立一個新的 QuantizedLinear 模組,該模組呼叫此函式來動態量化權重
class QuantizedLinear(torch.nn.Linear):
"""
Linear module that performs dynamic and symmetric weight-only
int8 quantization.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
w_int8, scale = int8_symmetric_quantize(self.weight)
return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t()
@classmethod
def from_float(cls, mod: torch.nn.Linear):
new_linear = cls(mod.in_features, mod.out_features, mod.bias)
new_linear.weight = mod.weight
return new_linear
然後,剩下要做的就是將模型中的所有 nn.Linear 模組替換為我們新的 QuantizedLinear 模組。讓我們使用以下玩具模型進行演示
import copy
class ToyModel(torch.nn.Module):
def __init__(self, m: int, n: int, k: int):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
float_model = ToyModel(64, 128, 32).cuda()
quantized_model = copy.deepcopy(float_model)
# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model.named_children():
if type(child) == torch.nn.Linear:
new_linear = QuantizedLinear.from_float(child)
setattr(quantized_model, name, new_linear)
驗證模型現在使用了我們的 QuantizedLinear 模組。這個模型現在可以使用了!
>>> print(float_model)
ToyModel(
(linear1): Linear(in_features=64, out_features=128, bias=False)
(linear2): Linear(in_features=128, out_features=32, bias=False)
)
>>> print(quantized_model)
ToyModel(
(linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
(linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
)
這種簡單方法的一個重要缺點是靈活性。目前,這僅適用於原生的 PyTorch 模組,但如果模型有略微修改的線性模組(例如,支援分散式訓練)怎麼辦?它也無法用於直接呼叫線性函式版本(torch.nn.functional.linear)的模型。
此外,假設我們想將此功能與分散式訓練結合,分散式訓練也是透過模組替換實現的。除了建立另一個結合這兩個功能的模組之外,沒有其他簡潔的方法可以做到這一點。這些限制可以透過張量子類解決,張量子類是一種更優雅的方式,可以在模型中插入自定義計算,例如量化。
使用張量子類進行量化¶
在這裡,我們將使用基於 __torch_dispatch__ 的張量子類重新實現上述量化技術。
張量子類(通常利用 __torch_dispatch__)是 PyTorch 中一個非常強大/靈活的擴充套件點。它們作為擴充套件點主要有兩個目的:
張量子類允許您覆蓋(幾乎)每個 PyTorch API 的實現,並且在實現其他 PyTorch 產品時使用得很頻繁
張量子類允許您將張量資料與附加元資料耦合。一些例子:
[量化] scale/zero_point 元資料(AffineQuantizedTensor)
[不規則性] 關於不規則結構的元資料(NestedTensor, 文件)
對張量子類感興趣的讀者可以參考其他一些資源:
__torch_dispatch__ 文件(連結)
什麼是 __torch_dispatch__ (及其作用)(連結)
使用 __torch_dispatch__ 實現 FlopCounter 和 MemoryTracker 的 Google Colab(連結)
言歸正傳,讓我們首先定義用於對稱量化的基本張量子類
class Int8SymmetricTensor(torch.Tensor):
"""
Our subclass represents a tensor that has been quantized to int8
It will hold two inner tensors:
int_data: int8[M, N]
scale: fp32[M, 1]
"""
@staticmethod
@torch._dynamo.disable
def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
return torch.Tensor._make_wrapper_subclass(
cls,
int_data.shape,
strides=int_data.stride(),
storage_offset=int_data.storage_offset(),
dtype=scale.dtype,
device=int_data.device,
)
@torch._dynamo.disable
def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
# inner data expected to be quantized already
assert int_data.dtype is torch.int8
# we could do more work to support ndim > 2!
assert int_data.ndim == 2
assert scale.ndim == 2
self.int_data = int_data
self.scale = scale
def __tensor_flatten__(self) -> Tuple[List[str], Any]:
"""
Returns a tuple of:
names of all inner tensor attributes (two in our case)
any other additional, non-tensor metadata.
Needed for PT2 support.
"""
return ["int_data", "scale"], None
@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
"""
__tensor_unflatten__ should effectively undo __tensor_flatten__.
inputs:
a dict mapping names of inner tensor attributes back to the tensors
the constant metadata from __tensor_flatten__
output:
a new instance of your subclass
Needed for PT2 support.
"""
assert extra_metadata is None
int_data = tensor_data_dict["int_data"]
scale = tensor_data_dict["scale"]
return Int8SymmetricTensor(int_data, scale)
def __repr__(self):
return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})'
@staticmethod
def from_float(float_tensor):
"""
Actually performs the symmetric quantization.
In our simple inference example we will quantize weights "ahead-of-time",
although later in a training example we can quantize/dequantize
during model execution, inside of our __torch_dispatch__
input:
float32 torch.Tensor
output:
Int8SymmetricTensor
"""
int8_tensor, scale = int8_symmetric_quantize(float_tensor)
return Int8SymmetricTensor(int8_tensor, scale)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
"""
Called for each ATen operator that our subclass is passed as an input to.
We need to define our own implementation for every operator here.
"""
if kwargs is None:
kwargs = {}
if func not in op_implementations_dict:
raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}')
return op_implementations_dict[func](func, *args, **kwargs)
# Convenience function for registering our own implementation
# to every ATen operator in PyTorch
op_implementations_dict = {}
def register_op(ops: List[torch._ops.OpOverload]):
def impl_decorator(op_impl):
global op_implementations_dict
for op in ops:
op_implementations_dict[op] = op_impl
return op_impl
return impl_decorator
在上面的程式碼中,我們做了幾件事
定義了一個基本的“包裝器”張量子類 - 它實際上是一個容器物件,用於儲存一些內部資料(特別是對應於我們的 int8 資料和 scale 的兩個張量)
定義了 __torch_dispatch__ 的實現,當模型對我們的任何子類輸入呼叫任何 ATen 運算子時,都會呼叫此實現
(為了支援 PT2)定義了 __tensor_flatten__/__tensor_unflatten__ 方法。這是為了使我們的子類與 torch.compile 一起工作所需的幾個最大要求之一(稍後會詳細介紹)。它有效地告訴 torch.compile 如何將我們的子類“解糖”成其內部元件。
(為了支援 PT2)向兩個構造方法(__new__ 和 __init__)添加了 torch._dynamo.disable 裝飾器(稍後會詳細介紹)。
我們應該實現哪些運算子?¶
PyTorch 有一個相當大的運算子集合。我們不打算讓新的張量子類實現 100% 的覆蓋率,而是隻關注上面玩具模型所需的那些操作。
但是,我們的模型中呼叫了哪些運算子,以便我們知道應該先實現什麼?最笨的方法是反覆執行模型,檢視子類中出現哪些運算子錯誤。一個更優雅的方法是記錄模型在執行期間遇到的所有運算子。這可以透過另一個 LoggingTensor 子類實現,例如此示例。
讓我們在下面實現必要的運算子
from torch.utils._python_dispatch import return_and_correct_aliasing
@register_op([torch.ops.aten.mm.default])
def int8_mm(func, x, weight):
assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!"
return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale
@register_op([
torch.ops.aten.detach.default,
torch.ops.aten.t.default,
])
def int8_view_ops(func, *args, **kwargs):
assert isinstance(args[0], Int8SymmetricTensor)
out_data = func(args[0].int_data, *args[1:], **kwargs)
out_scale = func(args[0].scale, *args[1:], **kwargs)
out = Int8SymmetricTensor(out_data, out_scale)
return return_and_correct_aliasing(func, args, kwargs, out)
您很快會注意到一件事:我們的模型本身包含幾個線性層,但我們看到一些運算子(如 aten.t 和 aten.mm)命中了我們的子類。一些背景知識:
我們有許多存在於 C++ 中的運算子分解,它們執行在張量子類“之上”。linear 就是這樣一種運算子(分解程式碼在此)
分解的好處在於它們減少了作為子類作者需要實現的 API 數量。但如果您寧願覆蓋“更高層級”的運算子而不是其分解中的底層操作,那麼它們可能會很麻煩。
如果您希望在更高級別覆蓋某些操作(如 Linear),可以使用 __torch_function__ (示例)。值得注意的是,如果您需要 autograd 支援,那麼在 __torch_function__ 層進行的任何覆蓋都需要以可微分的方式編寫,而在 __torch_dispatch__ 中進行的任何覆蓋將自動可微分。
我們的實現中有一些值得指出的細微之處
您會注意到,在我們的 mm 實現中,我們不再需要在內部對權重/scale 進行轉置。這是因為在我們到達 aten.mm 操作之前,轉置“已經發生”了。
我們的 aten.mm 實現不返回張量子類輸出。從這個意義上說,我們的量化子類的“傳播”在矩陣乘法處結束。這反映了我們的權重是低精度的,但我們需要在高精度下執行矩陣乘法本身。一般來說,子類作者可以自由選擇他們的子類對哪些操作進行傳播或不傳播。如果您希望模型中的每個函式(包括所有逐點操作和歸約操作)都進行量化,您可以編寫子類實現,對每個操作的輸出進行量化,並始終返回一個子類。
我們能夠對 4 個檢視操作重用相同的實現。一般來說,許多操作可以透過相當通用的實現來處理:解包裝任何子類輸入,在內部張量上執行底層運算子,然後將輸出重新包裝回子類中。
然而,您是否總能重用實現取決於您嘗試做什麼。例如,我們在子類上透過對內部資料和內部 scale 張量呼叫相同的轉置來實現 transpose(dim0, dim1)。如果我們的 scale 和資料張量具有不同的維度數,這將不起作用,因此在這種情況下,轉置將需要自定義實現。
比較輸出¶
話不多說,讓我們用這兩種量化版本執行我們的模型,並確認它們給出相同的輸出!
float_model = ToyModel(64, 128, 32).cuda()
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)
# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
if type(child) == torch.nn.Linear:
new_linear = QuantizedLinear.from_float(child)
setattr(quantized_model_module_swap, name, new_linear)
# Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses
for name, child in quantized_model_subclass.named_children():
if type(child) == torch.nn.Linear:
subclass_param = Int8SymmetricTensor.from_float(child.weight)
child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)
with torch.no_grad():
x = torch.randn(64, 64, 64, device='cuda')
out_module_swap = quantized_model_module_swap(x)
out = quantized_model_subclass(x)
print(torch.allclose(out, out_module_swap)) # prints True
# We can also use torch.compile to fuse some of our quantized logic
out_compiled = torch.compile(quantized_model_subclass)(x)
print(torch.allclose(out, out_compiled)) # prints True
後續步驟¶
在本教程中,我們演示瞭如何構建一個簡單的量化張量子類。這是本系列兩個教程中的第一部分。下一篇文章將討論如何向張量子類新增更高階的功能,例如使其可訓練、與 DTensor 組合以及新增張量並行性支援。有關 torchao 中如何使用張量子類構建 AffineQuantizedTensor 的更詳細示例,請參閱此示例。
如果您在實現子類時有任何疑問,請隨時在此提交問題。