JitScalarType¶
- class torch.onnx.JitScalarType(value)¶
在 torch 中定義的標量型別。
使用
JitScalarType將 torch 和 JIT 標量型別轉換為 ONNX 標量型別。示例
>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() TensorProtoDataType.FLOAT
>>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type() TensorProtoDataType.FLOAT
>>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type() TensorProtoDataType.FLOAT
- classmethod from_dtype(dtype)[source][source]¶
將 torch dtype 轉換為 JitScalarType。
- 注意:當 dtype 來自 torch._C.Value.type() 呼叫時,請勿使用此 API。
在 shape 資訊不存在的多種情況下,可能會引發“RuntimeError: INTERNAL ASSERT FAILED at “../aten/src/ATen/core/jit_type_base.h” 錯誤。請改用更安全的 from_value API。
- 引數
dtype (torch.dtype | None) – 用於建立 JitScalarType 的 torch.dtype
- 返回值
JitScalarType
- 丟擲異常
OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或為 None。
- 返回型別
- classmethod from_onnx_type(onnx_type)[source][source]¶
將 ONNX 資料型別轉換為 JitScalarType。
- 引數
onnx_type (int | _C_onnx.TensorProtoDataType | None) – 用於建立 JitScalarType 的 torch._C._onnx.TensorProtoDataType
- 返回值
JitScalarType
- 丟擲異常
OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或為 None。
- 返回型別
- classmethod from_value(value, default=None)[source][source]¶
從值的標量型別建立 JitScalarType。
- 引數
value (None | torch._C.Value | torch.Tensor) – 從中獲取標量型別的物件。
default – 如果無法從值中獲取有效標量,則返回的 JitScalarType。
- 返回值
JitScalarType.
- 丟擲異常
OnnxExporterError – 如果值沒有有效的標量型別且 default 為 None。
SymbolicValueError – 當 value.type() 的資訊為空且 default 為 None 時。
- 返回型別
- scalar_name()[source][source]¶
將 JitScalarType 轉換為 JIT 標量型別名稱。
- 返回型別
Literal[‘Byte’, ‘Char’, ‘Double’, ‘Float’, ‘Half’, ‘Int’, ‘Long’, ‘Short’, ‘Bool’, ‘ComplexHalf’, ‘ComplexFloat’, ‘ComplexDouble’, ‘QInt8’, ‘QUInt8’, ‘QInt32’, ‘BFloat16’, ‘Float8E5M2’, ‘Float8E4M3FN’, ‘Float8E5M2FNUZ’, ‘Float8E4M3FNUZ’, ‘Undefined’]
- torch_name()[source][source]¶
將 JitScalarType 轉換為 torch 型別名稱。
- 返回型別
Literal[‘bool’, ‘uint8_t’, ‘int8_t’, ‘double’, ‘float’, ‘half’, ‘int’, ‘int64_t’, ‘int16_t’, ‘complex32’, ‘complex64’, ‘complex128’, ‘qint8’, ‘quint8’, ‘qint32’, ‘bfloat16’, ‘float8_e5m2’, ‘float8_e4m3fn’, ‘float8_e5m2fnuz’, ‘float8_e4m3fnuz’]