torch.ao.ns._numeric_suite¶
警告
此模組為早期原型,可能會有所變動。
- torch.ao.ns._numeric_suite.compare_weights(float_dict, quantized_dict)[原始碼]¶
比較浮點數模組与其對應的量化模組的權重。返回一個字典,鍵對應於模組名稱,每個條目都是一個包含兩個鍵「float」和「quantized」的字典,分別包含浮點數和量化權重。此字典可用於比較和計算浮點數和量化模型權重的量化誤差。
範例用法
wt_compare_dict = compare_weights( float_model.state_dict(), qmodel.state_dict()) for key in wt_compare_dict: print( key, compute_error( wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize() ) )
- torch.ao.ns._numeric_suite.get_logger_dict(mod, prefix='')[原始碼]¶
遍歷模組並將所有記錄器統計資訊儲存到目標字典中。這主要用於量化精度偵錯。
- 支援的記錄器類型
ShadowLogger:用於記錄量化模組及其匹配的浮點數陰影模組的輸出,OutputLogger:用於記錄模組的輸出
- class torch.ao.ns._numeric_suite.Shadow(q_module, float_module, logger_cls)[原始碼]¶
Shadow 模組會將浮點數模組附加到其匹配的量化模組作為陰影。然後它使用 Logger 模組來處理兩個模組的輸出。
- 參數
q_module – 從我們要建立陰影的 float_module 量化的模組
float_module – 用於建立 q_module 陰影的浮點數模組
logger_cls – 用於處理 q_module 和 float_module 輸出的記錄器類型。可以使用 ShadowLogger 或自訂記錄器。
- torch.ao.ns._numeric_suite.prepare_model_with_stubs(float_module, q_module, module_swap_list, logger_cls)[原始碼]¶
如果浮點數模組類型在 module_swap_list 中,則透過將浮點數模組附加到其匹配的量化模組作為陰影來準備模型。
範例用法
prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) q_model(data) ob_dict = get_logger_dict(q_model)
- torch.ao.ns._numeric_suite.compare_model_stub(float_model, q_model, module_swap_list, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.ShadowLogger'>)[原始碼]¶
比較模型中的量化模組及其浮點數對應模組,將相同的輸入饋送到兩者。傳回一個字典,其中鍵對應於模組名稱,每個條目都是一個包含兩個鍵「float」和「quantized」的字典,分別包含量化模組及其匹配的浮點數陰影模組的輸出張量。這個字典可用於比較和計算模組級別的量化誤差。
這個函式首先呼叫 prepare_model_with_stubs() 來交換我們要與 Shadow 模組比較的量化模組,該模組將量化模組、對應的浮點數模組和記錄器作為輸入,並在內部建立一條前向路徑,使浮點數模組能夠建立與量化模組共享相同輸入的陰影。記錄器可以自訂,預設記錄器是 ShadowLogger,它將儲存量化模組和浮點數模組的輸出,可用於計算模組級別的量化誤差。
範例用法
module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock] ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data) for key in ob_dict: print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
- torch.ao.ns._numeric_suite.get_matching_activations(float_module, q_module)[原始碼]¶
在浮點數和量化模組之間尋找匹配的激活函式。
- torch.ao.ns._numeric_suite.prepare_model_outputs(float_module, q_module, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[原始碼]¶
如果浮點數模組和量化模組在 allow_list 中,則透過將記錄器附加到它們來準備模型。
- torch.ao.ns._numeric_suite.compare_model_outputs(float_model, q_model, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[原始碼]¶
針對相同的輸入,比較浮點數和量化模型在對應位置的輸出激活函式。傳回一個字典,其中鍵對應於量化模組名稱,每個條目都是一個包含兩個鍵「float」和「quantized」的字典,分別包含量化模型和浮點數模型在匹配位置的激活函式。這個字典可用於比較和計算傳播量化誤差。
範例用法
act_compare_dict = compare_model_outputs(float_model, qmodel, data) for key in act_compare_dict: print( key, compute_error( act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize() ) )