Skip to content

Commit f03c4ad

Browse files
authored
Create adapters for torch (#94)
Fix #93 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent a833ab1 commit f03c4ad

File tree

2 files changed

+83
-7
lines changed

2 files changed

+83
-7
lines changed

docs/api/ir_tensor_adapters.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# onnx_ir.tensor_adapters
2+
3+
```{eval-rst}
4+
.. automodule:: onnx_ir.tensor_adapters
5+
```
6+
7+
## Adapters for PyTorch
8+
9+
```{eval-rst}
10+
.. autosummary::
11+
:toctree: generated
12+
:template: classtemplate.rst
13+
:nosignatures:
14+
15+
onnx_ir.tensor_adapters.TorchTensor
16+
```
17+
18+
```{eval-rst}
19+
.. autofunction:: onnx_ir.tensor_adapters.from_torch_dtype
20+
.. autofunction:: onnx_ir.tensor_adapters.to_torch_dtype
21+
```

src/onnx_ir/tensor_adapters.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from __future__ import annotations
3030

3131
__all__ = [
32+
"from_torch_dtype",
33+
"to_torch_dtype",
3234
"TorchTensor",
3335
]
3436

@@ -44,14 +46,17 @@
4446
import torch
4547

4648

47-
class TorchTensor(_core.Tensor):
48-
def __init__(
49-
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
50-
):
51-
# Pass the tensor as the raw data to ir.Tensor's constructor
49+
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
50+
_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
51+
52+
53+
def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
54+
"""Convert a PyTorch dtype to an ONNX IR DataType."""
55+
global _TORCH_DTYPE_TO_ONNX
56+
if _TORCH_DTYPE_TO_ONNX is None:
5257
import torch
5358

54-
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
59+
_TORCH_DTYPE_TO_ONNX = {
5560
torch.bfloat16: ir.DataType.BFLOAT16,
5661
torch.bool: ir.DataType.BOOL,
5762
torch.complex128: ir.DataType.COMPLEX128,
@@ -72,8 +77,58 @@ def __init__(
7277
torch.uint32: ir.DataType.UINT32,
7378
torch.uint64: ir.DataType.UINT64,
7479
}
80+
if dtype not in _TORCH_DTYPE_TO_ONNX:
81+
raise TypeError(
82+
f"Unsupported PyTorch dtype '{dtype}'. "
83+
"Please use a supported dtype from the list: "
84+
f"{list(_TORCH_DTYPE_TO_ONNX.keys())}"
85+
)
86+
return _TORCH_DTYPE_TO_ONNX[dtype]
87+
88+
89+
def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
90+
"""Convert an ONNX IR DataType to a PyTorch dtype."""
91+
global _ONNX_DTYPE_TO_TORCH
92+
if _ONNX_DTYPE_TO_TORCH is None:
93+
import torch
94+
95+
_ONNX_DTYPE_TO_TORCH = {
96+
ir.DataType.BFLOAT16: torch.bfloat16,
97+
ir.DataType.BOOL: torch.bool,
98+
ir.DataType.COMPLEX128: torch.complex128,
99+
ir.DataType.COMPLEX64: torch.complex64,
100+
ir.DataType.FLOAT16: torch.float16,
101+
ir.DataType.FLOAT: torch.float32,
102+
ir.DataType.DOUBLE: torch.float64,
103+
ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
104+
ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
105+
ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
106+
ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
107+
ir.DataType.INT16: torch.int16,
108+
ir.DataType.INT32: torch.int32,
109+
ir.DataType.INT64: torch.int64,
110+
ir.DataType.INT8: torch.int8,
111+
ir.DataType.UINT8: torch.uint8,
112+
ir.DataType.UINT16: torch.uint16,
113+
ir.DataType.UINT32: torch.uint32,
114+
ir.DataType.UINT64: torch.uint64,
115+
}
116+
if dtype not in _ONNX_DTYPE_TO_TORCH:
117+
raise TypeError(
118+
f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
119+
"Please use a supported dtype from the list: "
120+
f"{list(_ONNX_DTYPE_TO_TORCH.keys())}"
121+
)
122+
return _ONNX_DTYPE_TO_TORCH[dtype]
123+
124+
125+
class TorchTensor(_core.Tensor):
126+
def __init__(
127+
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
128+
):
129+
# Pass the tensor as the raw data to ir.Tensor's constructor
75130
super().__init__(
76-
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
131+
tensor, dtype=from_torch_dtype(tensor.dtype), name=name, doc_string=doc_string
77132
)
78133

79134
def numpy(self) -> npt.NDArray:

0 commit comments

Comments
 (0)