Skip to content

Commit 5bdc32f

Browse files
zrr1999SigureMo
andauthored
✨ feat: add apply_method_to_operator (#5)
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
1 parent 3326852 commit 5bdc32f

File tree

4 files changed

+189
-1
lines changed

4 files changed

+189
-1
lines changed

src/expr_simplifier/transforms/auto_simplify.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
from expr_simplifier.transforms.constant_folding import apply_constant_folding
66
from expr_simplifier.transforms.cse import apply_cse
77
from expr_simplifier.transforms.logical_simplification import apply_logical_simplification
8+
from expr_simplifier.transforms.method_to_operator import apply_method_to_operator
89
from expr_simplifier.utils import loop_until_stable
910

1011

1112
def auto_simplify(expr: ast.AST) -> ast.AST:
12-
return loop_until_stable(expr, [apply_constant_folding, apply_logical_simplification, apply_cse], max_iter=100)
13+
return loop_until_stable(
14+
expr, [apply_constant_folding, apply_logical_simplification, apply_cse, apply_method_to_operator], max_iter=100
15+
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from ast import Add, BitAnd, BitOr, BitXor, Div, FloorDiv, Invert, LShift, Mod, Mult, Pow, RShift, Sub, UAdd, USub
5+
6+
MAGIC_NAMES_TO_BINARY_OPS: dict[str, ast.operator] = {
7+
"__add__": Add(),
8+
"__sub__": Sub(),
9+
"__mul__": Mult(),
10+
"__truediv__": Div(),
11+
"__floordiv__": FloorDiv(),
12+
"__mod__": Mod(),
13+
"__pow__": Pow(),
14+
"__and__": BitAnd(),
15+
"__or__": BitOr(),
16+
"__xor__": BitXor(),
17+
"__lshift__": LShift(),
18+
"__rshift__": RShift(),
19+
}
20+
21+
MAGIC_NAMES_TO_REVERSE_BINARY_OPS: dict[str, ast.operator] = {
22+
"__radd__": Add(),
23+
"__rsub__": Sub(),
24+
"__rmul__": Mult(),
25+
"__rtruediv__": Div(),
26+
"__rfloordiv__": FloorDiv(),
27+
"__rmod__": Mod(),
28+
"__rpow__": Pow(),
29+
"__rand__": BitAnd(),
30+
"__ror__": BitOr(),
31+
"__rxor__": BitXor(),
32+
"__rlshift__": LShift(),
33+
"__rrshift__": RShift(),
34+
}
35+
36+
MAGIC_NAMES_TO_UNARY_OPS: dict[str, ast.unaryop] = {
37+
"__neg__": USub(),
38+
"__pos__": UAdd(),
39+
"__invert__": Invert(),
40+
}
41+
42+
43+
class MethodToOperator(ast.NodeTransformer):
44+
def visit_Call(self, node: ast.Call) -> ast.AST:
45+
if not isinstance(node.func, ast.Attribute):
46+
return self.generic_visit(node)
47+
if node.func.attr in MAGIC_NAMES_TO_BINARY_OPS and len(node.args) == 1:
48+
return ast.BinOp(
49+
left=MethodToOperator().visit(node.func.value),
50+
op=MAGIC_NAMES_TO_BINARY_OPS[node.func.attr],
51+
right=MethodToOperator().visit(node.args[0]),
52+
)
53+
if node.func.attr in MAGIC_NAMES_TO_REVERSE_BINARY_OPS and len(node.args) == 1:
54+
return ast.BinOp(
55+
left=MethodToOperator().visit(node.args[0]),
56+
op=MAGIC_NAMES_TO_REVERSE_BINARY_OPS[node.func.attr],
57+
right=MethodToOperator().visit(node.func.value),
58+
)
59+
if node.func.attr in MAGIC_NAMES_TO_UNARY_OPS and len(node.args) == 0:
60+
return ast.UnaryOp(
61+
op=MAGIC_NAMES_TO_UNARY_OPS[node.func.attr],
62+
operand=MethodToOperator().visit(node.func.value),
63+
)
64+
return self.generic_visit(node)
65+
66+
67+
def apply_method_to_operator(expr: ast.AST) -> ast.AST:
68+
return MethodToOperator().visit(expr)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
import pytest
6+
7+
from expr_simplifier.transforms.method_to_operator import apply_method_to_operator
8+
9+
from .utils import check_expr_at_runtime
10+
11+
12+
@pytest.mark.parametrize(
13+
["expr", "expected"],
14+
[
15+
("a.__add__(b)", "a + b"),
16+
("a.__sub__(b)", "a - b"),
17+
("a.__mul__(b)", "a * b"),
18+
("a.__truediv__(b)", "a / b"),
19+
("a.__floordiv__(b)", "a // b"),
20+
("a.__mod__(b)", "a % b"),
21+
("a.__pow__(b)", "a ** b"),
22+
("a.__and__(b)", "a & b"),
23+
("a.__or__(b)", "a | b"),
24+
("a.__xor__(b)", "a ^ b"),
25+
("a.__lshift__(b)", "a << b"),
26+
("a.__rshift__(b)", "a >> b"),
27+
("a.b.__add__(c.d)", "a.b + c.d"),
28+
("a.__add__(b).__mul__(c)", "(a + b) * c"),
29+
("a.__add__(b.__mul__(c))", "a + b * c"),
30+
("a.method(b).__add__(c)", "a.method(b) + c"),
31+
("a.__add__(b.method(c))", "a + b.method(c)"),
32+
("a.b.c.__add__(d.e.f)", "a.b.c + d.e.f"),
33+
("a.__add__(b).__sub__(c.__mul__(d))", "a + b - c * d"),
34+
],
35+
)
36+
def test_method_to_operator(expr: str, expected: str):
37+
tree = ast.parse(expr, mode="eval")
38+
transformed_tree = apply_method_to_operator(tree)
39+
transformed_expr = ast.unparse(transformed_tree)
40+
assert transformed_expr == expected
41+
check_expr_at_runtime(tree, transformed_tree)
42+
43+
44+
@pytest.mark.parametrize(
45+
["expr", "expected"],
46+
[
47+
("a.__neg__()", "-a"),
48+
("a.__pos__()", "+a"),
49+
("a.__invert__()", "~a"),
50+
],
51+
)
52+
def test_method_to_unary_operator(expr: str, expected: str):
53+
tree = ast.parse(expr, mode="eval")
54+
transformed_tree = apply_method_to_operator(tree)
55+
transformed_expr = ast.unparse(transformed_tree)
56+
assert transformed_expr == expected
57+
check_expr_at_runtime(tree, transformed_tree)
58+
59+
60+
@pytest.mark.parametrize(
61+
["expr", "expected"],
62+
[
63+
("a.method(b)", "a.method(b)"),
64+
("a.__unknown__(b)", "a.__unknown__(b)"),
65+
("a.__add__", "a.__add__"),
66+
],
67+
)
68+
def test_no_change(expr: str, expected: str):
69+
tree = ast.parse(expr, mode="eval")
70+
transformed_tree = apply_method_to_operator(tree)
71+
transformed_expr = ast.unparse(transformed_tree)
72+
assert transformed_expr == expected

tests/test_transforms/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ def __radd__(self, other: Any):
2020
other = AnyObject.to_any_object(other)
2121
return AnyObject(f"{other.name} + {self.name}")
2222

23+
def __sub__(self, other: Any):
24+
other = AnyObject.to_any_object(other)
25+
return AnyObject(f"{self.name} - {other.name}")
26+
27+
def __rsub__(self, other: Any):
28+
other = AnyObject.to_any_object(other)
29+
return AnyObject(f"{other.name} - {self.name}")
30+
2331
def __mul__(self, other: Any):
2432
other = AnyObject.to_any_object(other)
2533
return AnyObject(f"{self.name} * {other.name}")
@@ -44,6 +52,34 @@ def __rfloordiv__(self, other: Any):
4452
other = AnyObject.to_any_object(other)
4553
return AnyObject(f"{other.name} // {self.name}")
4654

55+
def __lshift__(self, other: Any):
56+
other = AnyObject.to_any_object(other)
57+
return AnyObject(f"{self.name} << {other.name}")
58+
59+
def __rshift__(self, other: Any):
60+
other = AnyObject.to_any_object(other)
61+
return AnyObject(f"{self.name} >> {other.name}")
62+
63+
def __or__(self, other: Any):
64+
other = AnyObject.to_any_object(other)
65+
return AnyObject(f"{self.name} | {other.name}")
66+
67+
def __xor__(self, other: Any):
68+
other = AnyObject.to_any_object(other)
69+
return AnyObject(f"{other.name} ^ {self.name}")
70+
71+
def __pow__(self, other: Any):
72+
other = AnyObject.to_any_object(other)
73+
return AnyObject(f"{self.name} ** {other.name}")
74+
75+
def __mod__(self, other: Any):
76+
other = AnyObject.to_any_object(other)
77+
return AnyObject(f"{self.name} % {other.name}")
78+
79+
def __and__(self, other: Any):
80+
other = AnyObject.to_any_object(other)
81+
return AnyObject(f"{self.name} & {other.name}")
82+
4783
def __getattr__(self, name: str):
4884
return AnyObject(f"{self.name}.{name}")
4985

@@ -81,6 +117,15 @@ def __le__(self, other: Any):
81117
def __repr__(self):
82118
return f"AnyObject({self.name})"
83119

120+
def __neg__(self):
121+
return AnyObject(f"-{self.name}")
122+
123+
def __pos__(self):
124+
return AnyObject(f"+{self.name}")
125+
126+
def __invert__(self):
127+
return AnyObject(f"~{self.name}")
128+
84129
@staticmethod
85130
def to_any_object(value: Any):
86131
if isinstance(value, AnyObject):

0 commit comments

Comments
 (0)