Skip to content

Commit e382f9d

Browse files
[DataType] Add is_integer and is_signed queries (#110)
These queries could be convenient for conditional checks in ONNX Script pattern rewrite rules. A simple example would be rewriting subtraction of a constant as addition of the negative (additive inverse) of the constant which should be restricted to signed numbers (as unsigned numbers do not have an additive inverse - at least within their data type bounds). See #109 --------- Signed-off-by: Christoph Berganski <christoph.berganski@gmail.com> Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 12047d2 commit e382f9d

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

src/onnx_ir/_enums.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,42 @@ def is_floating_point(self) -> bool:
169169
DataType.FLOAT4E2M1,
170170
}
171171

172+
def is_integer(self) -> bool:
173+
"""Returns True if the data type is an integer."""
174+
return self in {
175+
DataType.UINT8,
176+
DataType.INT8,
177+
DataType.UINT16,
178+
DataType.INT16,
179+
DataType.INT32,
180+
DataType.INT64,
181+
DataType.UINT32,
182+
DataType.UINT64,
183+
DataType.UINT4,
184+
DataType.INT4,
185+
}
186+
187+
def is_signed(self) -> bool:
188+
"""Returns True if the data type is a signed type."""
189+
return self in {
190+
DataType.FLOAT,
191+
DataType.INT8,
192+
DataType.INT16,
193+
DataType.INT32,
194+
DataType.INT64,
195+
DataType.FLOAT16,
196+
DataType.DOUBLE,
197+
DataType.COMPLEX64,
198+
DataType.COMPLEX128,
199+
DataType.BFLOAT16,
200+
DataType.FLOAT8E4M3FN,
201+
DataType.FLOAT8E4M3FNUZ,
202+
DataType.FLOAT8E5M2,
203+
DataType.FLOAT8E5M2FNUZ,
204+
DataType.INT4,
205+
DataType.FLOAT4E2M1,
206+
}
207+
172208
def __repr__(self) -> str:
173209
return self.name
174210

0 commit comments

Comments
 (0)