2929from __future__ import annotations
3030
3131__all__ = [
32+ "from_torch_dtype" ,
33+ "to_torch_dtype" ,
3234 "TorchTensor" ,
3335]
3436
4446 import torch
4547
4648
47- class TorchTensor (_core .Tensor ):
48- def __init__ (
49- self , tensor : torch .Tensor , name : str | None = None , doc_string : str | None = None
50- ):
51- # Pass the tensor as the raw data to ir.Tensor's constructor
49+ _TORCH_DTYPE_TO_ONNX : dict [torch .dtype , ir .DataType ] | None = None
50+ _ONNX_DTYPE_TO_TORCH : dict [ir .DataType , torch .dtype ] | None = None
51+
52+
53+ def from_torch_dtype (dtype : torch .dtype ) -> ir .DataType :
54+ """Convert a PyTorch dtype to an ONNX IR DataType."""
55+ global _TORCH_DTYPE_TO_ONNX
56+ if _TORCH_DTYPE_TO_ONNX is None :
5257 import torch
5358
54- _TORCH_DTYPE_TO_ONNX : dict [ torch . dtype , ir . DataType ] = {
59+ _TORCH_DTYPE_TO_ONNX = {
5560 torch .bfloat16 : ir .DataType .BFLOAT16 ,
5661 torch .bool : ir .DataType .BOOL ,
5762 torch .complex128 : ir .DataType .COMPLEX128 ,
@@ -72,8 +77,58 @@ def __init__(
7277 torch .uint32 : ir .DataType .UINT32 ,
7378 torch .uint64 : ir .DataType .UINT64 ,
7479 }
80+ if dtype not in _TORCH_DTYPE_TO_ONNX :
81+ raise TypeError (
82+ f"Unsupported PyTorch dtype '{ dtype } '. "
83+ "Please use a supported dtype from the list: "
84+ f"{ list (_TORCH_DTYPE_TO_ONNX .keys ())} "
85+ )
86+ return _TORCH_DTYPE_TO_ONNX [dtype ]
87+
88+
89+ def to_torch_dtype (dtype : ir .DataType ) -> torch .dtype :
90+ """Convert an ONNX IR DataType to a PyTorch dtype."""
91+ global _ONNX_DTYPE_TO_TORCH
92+ if _ONNX_DTYPE_TO_TORCH is None :
93+ import torch
94+
95+ _ONNX_DTYPE_TO_TORCH = {
96+ ir .DataType .BFLOAT16 : torch .bfloat16 ,
97+ ir .DataType .BOOL : torch .bool ,
98+ ir .DataType .COMPLEX128 : torch .complex128 ,
99+ ir .DataType .COMPLEX64 : torch .complex64 ,
100+ ir .DataType .FLOAT16 : torch .float16 ,
101+ ir .DataType .FLOAT : torch .float32 ,
102+ ir .DataType .DOUBLE : torch .float64 ,
103+ ir .DataType .FLOAT8E4M3FN : torch .float8_e4m3fn ,
104+ ir .DataType .FLOAT8E4M3FNUZ : torch .float8_e4m3fnuz ,
105+ ir .DataType .FLOAT8E5M2 : torch .float8_e5m2 ,
106+ ir .DataType .FLOAT8E5M2FNUZ : torch .float8_e5m2fnuz ,
107+ ir .DataType .INT16 : torch .int16 ,
108+ ir .DataType .INT32 : torch .int32 ,
109+ ir .DataType .INT64 : torch .int64 ,
110+ ir .DataType .INT8 : torch .int8 ,
111+ ir .DataType .UINT8 : torch .uint8 ,
112+ ir .DataType .UINT16 : torch .uint16 ,
113+ ir .DataType .UINT32 : torch .uint32 ,
114+ ir .DataType .UINT64 : torch .uint64 ,
115+ }
116+ if dtype not in _ONNX_DTYPE_TO_TORCH :
117+ raise TypeError (
118+ f"Unsupported conversion from ONNX dtype '{ dtype } ' to torch. "
119+ "Please use a supported dtype from the list: "
120+ f"{ list (_ONNX_DTYPE_TO_TORCH .keys ())} "
121+ )
122+ return _ONNX_DTYPE_TO_TORCH [dtype ]
123+
124+
125+ class TorchTensor (_core .Tensor ):
126+ def __init__ (
127+ self , tensor : torch .Tensor , name : str | None = None , doc_string : str | None = None
128+ ):
129+ # Pass the tensor as the raw data to ir.Tensor's constructor
75130 super ().__init__ (
76- tensor , dtype = _TORCH_DTYPE_TO_ONNX [ tensor .dtype ] , name = name , doc_string = doc_string
131+ tensor , dtype = from_torch_dtype ( tensor .dtype ) , name = name , doc_string = doc_string
77132 )
78133
79134 def numpy (self ) -> npt .NDArray :
0 commit comments