Skip to content

Commit 71357d8

Browse files
committed
working loop manager for single type
1 parent 134fdc0 commit 71357d8

File tree

2 files changed

+121
-167
lines changed

2 files changed

+121
-167
lines changed

grape/automaton/loop_manager.py

Lines changed: 115 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,38 @@
1-
from collections import defaultdict
21
from enum import StrEnum
32
import itertools
43

4+
from grape import types
55
from grape.automaton.tree_automaton import DFTA
66
from grape.dsl import DSL
7+
from grape.program import Function, Primitive, Program, Variable
78

89

910
class LoopStrategy(StrEnum):
1011
NO_LOOP = "none"
1112
STATE = "state"
1213

1314

14-
def __find_unbounded_types(
15-
dfta: DFTA[str, str], state_to_type: dict[str, str]
16-
) -> set[str]:
17-
unbounded_types = set()
18-
added = True
19-
while added:
20-
added = False
21-
for (P, args), dst in dfta.rules.items():
22-
prod_type = state_to_type[dst]
23-
if prod_type not in unbounded_types and any(
24-
state_to_type[arg_state] in unbounded_types
25-
or prod_type == state_to_type[arg_state]
26-
for arg_state in args
27-
):
28-
unbounded_types.add(prod_type)
29-
added = True
30-
return unbounded_types
31-
32-
33-
def __find_unconsumed_states(dfta: DFTA[str, str]) -> set[str]:
34-
not_consumed = dfta.all_states
35-
for P, args in dfta.rules:
36-
for arg_state in args:
37-
if arg_state in not_consumed:
38-
not_consumed.remove(arg_state)
39-
return not_consumed
40-
41-
42-
def __prod_types_by_states(
43-
dfta: DFTA[str, str], state_to_type: dict[str, str]
44-
) -> dict[str, set[str]]:
45-
# Compute transitive closure
46-
reachable_from: dict[str, set[str]] = defaultdict(set)
47-
for (P, args), dst in dfta.rules.items():
48-
reachable_from[dst].update(args)
49-
updated = True
50-
while updated:
51-
updated = False
52-
for dst, reachables in reachable_from.copy().items():
53-
before = len(reachables)
54-
for S in reachables.copy():
55-
reachables.update(reachable_from[S])
56-
if len(reachables) != before:
57-
updated = True
58-
return {
59-
s: set(state_to_type[v] for v in reachables)
60-
for s, reachables in reachable_from.items()
61-
}
62-
63-
64-
def __compute_outbound(dfta: DFTA[str, str], unconsumed: set[str]) -> dict[str, int]:
65-
outbound: dict[str, int] = {}
66-
for x in unconsumed:
67-
outbound[x] = 1
68-
queue = list(dfta.all_states)
69-
while queue:
70-
x = queue.pop()
71-
if x in outbound:
72-
continue
73-
total = 0
74-
has_missed = False
75-
for (P, args), dst in dfta.rules.items():
76-
if x in args:
77-
if dst not in outbound:
78-
has_missed = True
79-
break
80-
else:
81-
total += outbound[dst]
82-
if has_missed:
83-
queue.insert(0, x)
84-
else:
85-
outbound[x] = total
86-
return outbound
15+
def __state2letter__(state: str) -> str:
    """
    Return the letter a state name starts with.

    A state name without ``"("`` is returned unchanged.  Otherwise the result
    is the text between the first character and the first space, i.e. the
    head of a name shaped like ``"(letter arg ...)"``.
    """
    if "(" not in state:
        return state
    # str.find returns -1 when there is no space, so for a name such as "(f)"
    # the slice drops the first and the last character.
    head_end = state.find(" ")
    return state[1:head_end]
8720

8821

8922
def __can_states_merge(
90-
dfta: DFTA[str, str], state_to_letter: dict[str, str], original: str, candidate: str
23+
reversed_rules: dict[tuple[str, tuple[str, ...]], str],
24+
original: str,
25+
candidate: str,
9126
) -> bool:
92-
if state_to_letter[candidate] != state_to_letter[original] and not str(
93-
state_to_letter[candidate]
27+
if __state2letter__(candidate) != __state2letter__(original) and not str(
28+
__state2letter__(candidate)
9429
).startswith("var"):
9530
return False
96-
for P1, args1 in dfta.reversed_rules[original]:
31+
for P1, args1 in reversed_rules[original]:
9732
has_equivalent = False
98-
for P2, args2 in dfta.reversed_rules[candidate]:
33+
for P2, args2 in reversed_rules[candidate]:
9934
if all(
100-
__can_states_merge(dfta, state_to_letter, arg1, arg2)
35+
__can_states_merge(reversed_rules, arg1, arg2)
10136
for arg1, arg2 in zip(args1, args2)
10237
):
10338
has_equivalent = True
@@ -107,98 +42,115 @@ def __can_states_merge(
10742
return True
10843

10944

45+
def __find_merge__(
    dfta: DFTA[str, str], P: str, args: tuple[str, ...], candidates: set[str]
) -> str | None:
    """
    Pick the existing state that the result of ``P(args)`` can be merged into.

    A candidate is eligible when its letter (see ``__state2letter__``) is ``P``
    or starts with ``"var"``, and at least one rule in
    ``dfta.reversed_rules[candidate]`` has arguments that pairwise satisfy
    ``__can_states_merge`` against ``args``.

    Among eligible candidates, the one whose name contains the most spaces is
    chosen.  Ties are broken by the lexicographically smallest name so that
    the result does not depend on the iteration order of ``candidates``.

    Returns the chosen state, or ``None`` when no candidate is eligible.
    """
    best_candidate: str | None = None
    best_size = -1
    for candidate in candidates:
        letter = __state2letter__(candidate)
        if letter != P and not letter.startswith("var"):
            continue
        # zip stops at the shorter argument list, so a rule with no arguments
        # matches any ``args`` vacuously.
        has_equivalent = any(
            all(
                __can_states_merge(dfta.reversed_rules, arg1, arg2)
                for arg1, arg2 in zip(args, args2)
            )
            for _, args2 in dfta.reversed_rules[candidate]
        )
        if not has_equivalent:
            continue
        size = candidate.count(" ")
        if (
            best_candidate is None
            or size > best_size
            # Deterministic tie-break: set iteration order of strings changes
            # from run to run with hash randomisation.
            or (size == best_size and candidate < best_candidate)
        ):
            best_candidate, best_size = candidate, size
    return best_candidate
67+
68+
69+
def __convert_automaton__(dfta: DFTA[str, str]) -> DFTA[str, Program]:
    """
    Map every letter of the automaton's alphabet to a ``Program``.

    A letter whose string form is ``"var"`` followed by an integer becomes
    ``Variable(<that integer>)``; every other letter becomes
    ``Primitive(str(letter))``.  The mapping is applied through
    ``dfta.map_alphabet`` and its result is returned.
    """

    def to_program(letter) -> Program:
        name = str(letter)
        if name.startswith("var"):
            try:
                index = int(name[len("var") :])
            except ValueError:
                # Starts with "var" but is not a numbered variable (for
                # example a primitive named "variance"): keep it a primitive
                # instead of crashing.
                return Primitive(name)
            return Variable(index)
        return Primitive(name)

    return dfta.map_alphabet(to_program)
75+
76+
11077
def add_loops(
111-
dfta: DFTA[str, str],
78+
dfta: DFTA[str, Program | str],
11279
dsl: DSL,
11380
strategy: LoopStrategy,
114-
) -> DFTA[str, str]:
81+
) -> DFTA[str, Program]:
11582
"""
116-
Assumes one state is from one letter
83+
Assumes one state is from one letter and that variants are mapped.
11784
"""
11885
if strategy == LoopStrategy.NO_LOOP:
119-
return dfta
86+
return __convert_automaton__(dfta)
12087
elif dfta.is_unbounded():
12188
raise ValueError("automaton is already looping cannot add loops!")
12289
else:
123-
# In order to make the automaton loop
124-
# 1) All unconsumed must be consumed
125-
# 2) Programs of all produced types must have unbounded size
12690
state_to_type = dsl.get_state_types(dfta)
127-
state_to_letter = {s: dfta.reversed_rules[s][0][0] for s in state_to_type}
128-
prod_types_by_state = __prod_types_by_states(dfta, state_to_type)
129-
all_types = set(state_to_type.values())
130-
unbounded_types = __find_unbounded_types(dfta, state_to_type)
131-
unconsumed = __find_unconsumed_states(dfta)
132-
unconsumed_by_type = {
133-
t: {s for s in unconsumed if state_to_type[s] == t} for t in all_types
91+
state_to_size = {s: s.count(" ") for s in dfta.all_states}
92+
max_size = max(state_to_size.values())
93+
states_by_types = {
94+
t: set(s for s, st in state_to_type.items() if st == t)
95+
for t in set(state_to_type.values())
13496
}
135-
unbounded_unconsumed = {
136-
t for t in unbounded_types if t not in unconsumed_by_type
137-
}
138-
# For each unbounded unconsumed
139-
# find all states that are not consumed to produce more of that type
140-
# mark them as unconsumed
141-
for t in unbounded_unconsumed:
142-
unconsumed_by_type[t] = set()
143-
for state in dfta.all_states:
144-
if state_to_type[state] == t and t not in prod_types_by_state[state]:
145-
unconsumed.add(state)
146-
unconsumed_by_type[t].add(state)
147-
# Computes consumed
148-
consumed = dfta.all_states.difference(unconsumed)
149-
consumed_by_type = {
150-
t: {s for s in consumed if state_to_type[s] == t} for t in all_types
151-
}
152-
outbound = __compute_outbound(dfta, unconsumed)
153-
state_merged: dict[str, str] = {}
154-
new_rules = dfta.rules.copy()
155-
new_finals = dfta.finals.copy()
156-
# 1) Merge all unconsumed onto the largest subcontext that is being consumed
157-
unmerged_by_type: dict[str, set[str]] = defaultdict(set)
158-
for t, states in unconsumed_by_type.items():
159-
for state in states:
160-
has_merge = False
161-
for candidate in consumed_by_type[t]:
162-
if not __can_states_merge(dfta, state_to_letter, state, candidate):
163-
continue
164-
if (
165-
has_merge
166-
and outbound[candidate] < outbound[state_merged[state]]
167-
) or not has_merge:
168-
state_merged[state] = candidate
169-
has_merge = True
170-
if not has_merge:
171-
unmerged_by_type[t].add(state)
172-
if strategy == LoopStrategy.STATE:
173-
for (P, args), dst in dfta.rules.items():
174-
if dst in state_merged:
175-
new_rules[(P, args)] = state_merged[dst]
176-
else:
177-
assert False, f"unsupported loop strategy:{strategy}"
178-
# 2) Some can still be unmerged
179-
# this means multiple things:
180-
# - there is no variable of that type
181-
# - there is not smaller expression using the same letter
182-
# print(
183-
# "UNMERGED:\n",
184-
# "\n".join([f"\t{k} ====> {v}" for k, v in unmerged_by_type.items()]),
185-
# )
186-
for (P, args), dst in dfta.rules.items():
187-
possibles = [[arg] for arg in args]
97+
added = True
98+
new_dfta = DFTA(dfta.rules.copy(), dfta.finals.copy())
99+
virtual_vars = set()
100+
max_varno = (
101+
max(
102+
int(s[len("var") :])
103+
for s in state_to_type.keys()
104+
if s.startswith("var")
105+
)
106+
+ 1
107+
)
108+
for t, states in states_by_types.items():
109+
if all(not s.startswith("var") for s in states):
110+
virtual_vars.add(max_varno)
111+
dst = str(Variable(max_varno))
112+
new_dfta.rules[(Variable(max_varno), tuple())] = dst
113+
# Create a variant of every rule that consumes a state of this type so that it also accepts the virtual variable
114+
for (P, args), new_dst in dfta.rules.items():
115+
possibles = [
116+
[arg] + ([dst] if arg in states else []) for arg in args
117+
]
118+
for new_args in itertools.product(*possibles):
119+
if dst in new_args and (P, new_args) not in new_dfta.rules:
120+
new_dfta.rules[(P, new_args)] = new_dst
121+
print("adding:", (P, new_args), new_dst)
122+
max_varno += 1
123+
new_dfta.refresh_reversed_rules()
124+
i = 0
125+
while added and i < 1:
126+
i += 1
188127
added = False
189-
for rtype, programs in unmerged_by_type.items():
190-
for li in possibles:
191-
if state_to_type[li[0]] != rtype:
192-
continue
193-
else:
194-
added = True
195-
li.extend(programs)
196-
if added:
197-
for new_args in itertools.product(*possibles):
198-
new_rules[(P, new_args)] = dst
199-
new_dfta = DFTA(new_rules, new_finals)
200-
new_dfta.reduce()
201-
out = new_dfta.minimise(
202-
can_be_merged=lambda x, y: state_to_type[x] == state_to_type[y]
203-
).classic_state_renaming()
204-
return out
128+
for P, (Ptype, _) in dsl.primitives.items():
129+
possibles = [states_by_types[arg_t] for arg_t in types.arguments(Ptype)]
130+
for combi in itertools.product(*possibles):
131+
key = (P, combi)
132+
if key not in new_dfta.rules:
133+
args_size = list(map(lambda x: state_to_size[x], combi))
134+
dst_size = sum(args_size) + 1
135+
if (
136+
dst_size >= max_size
137+
and max(args_size) >= max_size - len(args_size) + 1
138+
):
139+
added = True
140+
rtype = types.return_type(dsl.get_type(P))
141+
dst = Function(Primitive(P), list(map(Primitive, combi)))
142+
new_state = __find_merge__(
143+
new_dfta, P, combi, states_by_types[rtype]
144+
) or str(dst)
145+
new_dfta.rules[key] = new_state
146+
states_by_types[rtype].add(new_state)
147+
if new_state not in state_to_size:
148+
state_to_size[new_state] = dst_size
149+
new_dfta.refresh_reversed_rules()
150+
151+
for no in virtual_vars:
152+
dst = Variable(no)
153+
del new_dfta.rules[(dst, tuple())]
154+
new_dfta.reduce()
155+
new_dfta.refresh_reversed_rules()
156+
return __convert_automaton__(new_dfta) # .minimise())#.classic_state_renaming())

grape/dsl.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,8 @@ def get_state_types(self, automaton: DFTA[T, str | Program]) -> dict[T, str]:
6565
# Assumes type variants are not present.
6666
specialized = spec_manager.is_specialized(automaton)
6767
if specialized:
68-
arg_types = types.arguments(
69-
spec_manager.type_request_from_specialized(automaton, self)
70-
)
68+
guessed_tr = spec_manager.type_request_from_specialized(automaton, self)
69+
arg_types = types.arguments(guessed_tr)
7170

7271
state_to_type: dict[Any, str] = {}
7372
elements = list(automaton.rules.items())
@@ -99,7 +98,10 @@ def get_state_types(self, automaton: DFTA[T, str | Program]) -> dict[T, str]:
9998
)
10099
Ptype = all_possibles.pop()
101100
if dst in state_to_type:
102-
assert state_to_type[dst] == types.return_type(Ptype)
101+
all_types = set()
102+
for variant in types.all_variants(Ptype):
103+
all_types.add(types.return_type(variant))
104+
assert state_to_type[dst] in all_types
103105
else:
104106
state_to_type[dst] = types.return_type(Ptype)
105107
return state_to_type

0 commit comments

Comments
 (0)