Skip to content

Commit 26ca89a

Browse files
authored
Ban onnx imports (#87)
In all places except for tests, serde, and selected passes and helpers. Also cleaned up src/onnx_ir/_version_utils.py --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 1269e8b commit 26ca89a

File tree

10 files changed

+19
-57
lines changed

10 files changed

+19
-57
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,16 @@ ignore = [
107107

108108
[tool.ruff.lint.flake8-tidy-imports.banned-api]
109109
"pathlib".msg = "Using pathlib can impact performance. Use os.path instead"
110+
"onnx.helper".msg = "onnx helpers tend to be protobuf-y and slow. Consider using ir.tensor, ir.DataType and related methods instead"
111+
"onnx.numpy_helper".msg = "onnx numpy helpers tend to be slow. Consider using ir.tensor, ir.DataType and related methods instead"
112+
"onnx".msg = "Use onnx_ir methods and classes instead, or create an exception with `# noqa: TID251` if you need to use onnx directly"
110113

111114
[tool.ruff.lint.per-file-ignores]
112115
"__init__.py" = ["TID252"] # Allow relative imports in init files
113116
"setup.py" = ["TID251"] # pathlib is allowed in supporting code
114117
"**/*_test.py" = ["TID251"] # pathlib is allowed in tests
118+
"tools/**/*.py" = ["TID251"] # OK to use in tools
119+
"src/onnx_ir/testing.py" = ["TID251"] # OK to use in tools
115120

116121
[tool.ruff.lint.flake8-tidy-imports]
117122
# Disallow all relative imports.

src/onnx_ir/_convenience/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from collections.abc import Mapping, Sequence
2020
from typing import Union
2121

22-
import onnx
22+
import onnx # noqa: TID251
2323

2424
from onnx_ir import _core, _enums, _protocols, serde, traversal
2525

src/onnx_ir/_convenience/_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import Mapping, Sequence
1414

1515
import numpy as np
16-
import onnx
16+
import onnx # noqa: TID251
1717

1818
from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
1919

src/onnx_ir/_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
from typing import Callable
1111

12-
import onnx
12+
import onnx # noqa: TID251
1313

1414
from onnx_ir import _core, _protocols, serde
1515
from onnx_ir import external_data as _external_data

src/onnx_ir/_version_utils.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
# SPDX-License-Identifier: Apache-2.0
33
"""Version utils for testing."""
44

5+
# pylint: disable=import-outside-toplevel
56
from __future__ import annotations
67

78
import packaging.version
89

910

1011
def onnx_older_than(version: str) -> bool:
1112
"""Returns True if the ONNX version is older than the given version."""
12-
import onnx # pylint: disable=import-outside-toplevel
13+
import onnx # noqa: TID251
1314

1415
return (
1516
packaging.version.parse(onnx.__version__).release
@@ -19,50 +20,17 @@ def onnx_older_than(version: str) -> bool:
1920

2021
def torch_older_than(version: str) -> bool:
2122
"""Returns True if the torch version is older than the given version."""
22-
import torch # pylint: disable=import-outside-toplevel
23+
import torch
2324

2425
return (
2526
packaging.version.parse(torch.__version__).release
2627
< packaging.version.parse(version).release
2728
)
2829

2930

30-
def transformers_older_than(version: str) -> bool | None:
31-
"""Returns True if the transformers version is older than the given version."""
32-
try:
33-
import transformers # pylint: disable=import-outside-toplevel
34-
except ImportError:
35-
return None
36-
37-
return (
38-
packaging.version.parse(transformers.__version__).release
39-
< packaging.version.parse(version).release
40-
)
41-
42-
43-
def is_onnxruntime_training() -> bool:
44-
"""Returns True if the onnxruntime is onnxruntime-training."""
45-
try:
46-
from onnxruntime import training # pylint: disable=import-outside-toplevel
47-
48-
assert training
49-
except ImportError:
50-
# onnxruntime not training
51-
return False
52-
53-
try:
54-
from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=import-outside-toplevel
55-
OrtValueVector,
56-
)
57-
except ImportError:
58-
return False
59-
60-
return hasattr(OrtValueVector, "push_back_batch")
61-
62-
6331
def onnxruntime_older_than(version: str) -> bool:
6432
"""Returns True if the onnxruntime version is older than the given version."""
65-
import onnxruntime # pylint: disable=import-outside-toplevel
33+
import onnxruntime
6634

6735
return (
6836
packaging.version.parse(onnxruntime.__version__).release
@@ -72,20 +40,9 @@ def onnxruntime_older_than(version: str) -> bool:
7240

7341
def numpy_older_than(version: str) -> bool:
7442
"""Returns True if the numpy version is older than the given version."""
75-
import numpy # pylint: disable=import-outside-toplevel
43+
import numpy
7644

7745
return (
7846
packaging.version.parse(numpy.__version__).release
7947
< packaging.version.parse(version).release
8048
)
81-
82-
83-
def has_transformers():
84-
"""Tells if transformers is installed."""
85-
try:
86-
import transformers # pylint: disable=import-outside-toplevel
87-
88-
assert transformers
89-
return True # noqa
90-
except ImportError:
91-
return False

src/onnx_ir/passes/common/_c_api_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import onnx_ir as ir
1111

1212
if TYPE_CHECKING:
13-
import onnx
13+
import onnx # noqa: TID251
1414

1515

1616
logger = logging.getLogger(__name__)

src/onnx_ir/passes/common/onnx_checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from typing import Literal
1212

13-
import onnx
13+
import onnx # noqa: TID251
1414

1515
import onnx_ir as ir
1616
from onnx_ir.passes.common import _c_api_utils

src/onnx_ir/passes/common/shape_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import logging
1313

14-
import onnx
14+
import onnx # noqa: TID251
1515

1616
import onnx_ir as ir
1717
from onnx_ir.passes.common import _c_api_utils

src/onnx_ir/passes/common/unused_removal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import logging
1212

13-
import onnx
13+
import onnx # noqa: TID251
1414

1515
import onnx_ir as ir
1616

src/onnx_ir/serde.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@
6767
from typing import Any, Callable
6868

6969
import numpy as np
70-
import onnx
71-
import onnx.external_data_helper
70+
import onnx # noqa: TID251
71+
import onnx.external_data_helper # noqa: TID251
7272

7373
from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
7474

0 commit comments

Comments
 (0)