快捷方式

torch.set_float32_matmul_precision

torch.set_float32_matmul_precision(precision)[原始碼][原始碼]

設定 float32 矩陣乘法的內部精度。

以較低精度執行 float32 矩陣乘法可以顯著提升效能,在某些程式中,精度的損失影響微乎其微。

支援三種設定:

  • “highest”,float32 矩陣乘法使用 float32 資料型別(尾數位 24 位,其中 23 位顯式儲存)進行內部計算。

  • “high”,如果可用的快速矩陣乘法演算法支援,float32 矩陣乘法要麼使用 TensorFloat32 資料型別(尾數位 10 位顯式儲存),要麼將每個 float32 數字視為兩個 bfloat16 數字之和(尾數位約 16 位,其中 14 位顯式儲存)。否則,float32 矩陣乘法將按照“highest”精度進行計算。有關 bfloat16 方法的更多資訊,請參見下文。

  • “medium”,如果內部使用 bfloat16 資料型別的快速矩陣乘法演算法可用,float32 矩陣乘法將使用 bfloat16 資料型別(尾數位 8 位,其中 7 位顯式儲存)進行內部計算。否則,float32 矩陣乘法將按照“high”精度進行計算。

使用“high”精度時,float32 乘法可能會使用基於 bfloat16 的演算法,該演算法比簡單地截斷到較少尾數位(例如 TensorFloat32 的 10 位,bfloat16 顯式儲存的 7 位)更復雜。有關此演算法的完整描述,請參閱 [Henry2019]。在此簡要解釋一下,第一步是意識到我們可以將單個 float32 數字完美地編碼為三個 bfloat16 數字之和(因為 float32 有 23 個尾數位,而 bfloat16 有 7 個顯式儲存位,並且兩者具有相同的指數位數)。這意味著兩個 float32 數字的乘積可以精確地表示為九個 bfloat16 數字乘積之和。然後,我們可以透過丟棄其中一些乘積來權衡精度和速度。“high”精度演算法特別只保留了三個最重要的乘積,這方便地排除了涉及任一輸入最後 8 個尾數位的所有乘積。這意味著我們可以將輸入表示為兩個 bfloat16 數字之和,而不是三個。由於 bfloat16 乘加融合 (FMA) 指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度進行三次乘法和兩次加法比使用 float32 精度進行一次乘法更快。

Henry2019

http://arxiv.org/abs/1904.06376

注意

這不會改變 float32 矩陣乘法的輸出資料型別(dtype),它控制的是矩陣乘法的內部計算方式。

注意

這不會改變卷積操作的精度。其他標誌,例如 torch.backends.cudnn.allow_tf32,可能會控制卷積操作的精度。

注意

當前,此標誌僅影響一種原生裝置型別:CUDA。如果設定為“high”或“medium”,則在計算 float32 矩陣乘法時將使用 TensorFloat32 資料型別,這等同於設定 torch.backends.cuda.matmul.allow_tf32 = True。當設定為“highest”(預設值)時,內部計算使用 float32 資料型別,這等同於設定 torch.backends.cuda.matmul.allow_tf32 = False

引數

precision (str) – 可以設定為“highest”(預設)、“high”或“medium”(參見上文)。

文件

獲取 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源