快捷方式

torch.set_default_dtype

torch.set_default_dtype(d, /)[源][源]

將預設浮點 dtype 設定為 d。支援浮點 dtype 作為輸入。其他 dtype 將導致 torch 引發異常。

當 PyTorch 初始化時,其預設浮點 dtype 為 torch.float32,set_default_dtype(torch.float64) 的目的是為了促進 NumPy 風格的型別推斷。預設浮點 dtype 用於:

  1. 隱式確定預設複數 dtype。當預設浮點型別為 float16 時,預設複數 dtype 為 complex32。對於 float32,預設複數 dtype 為 complex64。對於 float64,它為 complex128。對於 bfloat16,將引發異常,因為 bfloat16 沒有對應的複數型別。

  2. 推斷使用 Python 浮點數或複數 Python 數字構造的張量的 dtype。參見下面的示例。

  3. 確定布林和整數張量與 Python 浮點數和複數 Python 數字之間型別提升的結果。

引數

d (torch.dtype) – 作為預設的浮點 dtype。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> # Python floats are now interpreted as float16
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float16
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex32

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源