1- from collections import defaultdict
21from enum import StrEnum
32import itertools
43
4+ from grape import types
55from grape .automaton .tree_automaton import DFTA
66from grape .dsl import DSL
7+ from grape .program import Function , Primitive , Program , Variable
78
89
class LoopStrategy(StrEnum):
    """Strategies accepted by ``add_loops`` to decide how an automaton is made to loop."""

    # ``add_loops`` adds no rule and only converts the alphabet of the automaton.
    NO_LOOP = "none"
    # ``add_loops`` adds new rules to the automaton (see that function for details).
    STATE = "state"
1213
1314
def __find_unbounded_types(
    dfta: DFTA[str, str], state_to_type: dict[str, str]
) -> set[str]:
    """Collect the types that some rule can produce recursively.

    A type is reported as soon as one rule producing it takes an argument
    of that very same type, or an argument whose type has already been
    reported.  The rules are swept repeatedly until a sweep reports
    nothing new.

    Args:
        dfta: automaton whose ``rules`` map ``(letter, args)`` to a
            destination state.
        state_to_type: type of every state appearing in the rules.

    Returns:
        The set of reported types.
    """
    found: set[str] = set()
    changed = True
    while changed:
        changed = False
        for (_, children), target in dfta.rules.items():
            produced = state_to_type[target]
            if produced in found:
                continue
            for child in children:
                child_type = state_to_type[child]
                if child_type in found or child_type == produced:
                    found.add(produced)
                    changed = True
                    break
    return found
31-
32-
def __find_unconsumed_states(dfta: DFTA[str, str]) -> set[str]:
    """Return the states that are never used as an argument of a rule.

    Args:
        dfta: automaton whose ``rules`` keys are ``(letter, args)`` pairs.

    Returns:
        A new set holding every state of ``dfta.all_states`` that does not
        occur in the argument tuple of any rule.
    """
    consumed = {arg_state for _, args in dfta.rules for arg_state in args}
    # Build a fresh set instead of removing elements in place: the object
    # returned by ``dfta.all_states`` may be owned by the automaton.
    return set(dfta.all_states) - consumed
40-
41-
def __prod_types_by_states(
    dfta: DFTA[str, str], state_to_type: dict[str, str]
) -> dict[str, set[str]]:
    """Map every state to the types of the states it is built from.

    The dependency relation "destination uses argument" given by the rules
    is closed transitively, then each dependency is replaced by its type.

    Args:
        dfta: automaton whose ``rules`` map ``(letter, args)`` to a
            destination state.
        state_to_type: type of every state appearing in the rules.

    Returns:
        For each state occurring in a rule (as destination or argument),
        the set of types of all states it depends on, directly or not.
        States without arguments map to an empty set.
    """
    # Direct dependencies: a destination depends on each of its arguments.
    closure: dict[str, set[str]] = {}
    for (_, children), target in dfta.rules.items():
        closure.setdefault(target, set()).update(children)
    # States only ever seen as arguments still get an (empty) entry.
    for _, children in dfta.rules:
        for child in children:
            closure.setdefault(child, set())

    # Saturate until no dependency set grows any more.
    changed = True
    while changed:
        changed = False
        for reached in closure.values():
            extra = set().union(*(closure[s] for s in reached)) - reached
            if extra:
                reached |= extra
                changed = True

    return {
        state: {state_to_type[s] for s in reached}
        for state, reached in closure.items()
    }
62-
63-
def __compute_outbound(dfta: DFTA[str, str], unconsumed: set[str]) -> dict[str, int]:
    """Count, for every state, the weighted number of rules consuming it.

    States in ``unconsumed`` get the value 1.  Any other state gets the sum
    of the values of the destinations of the rules that have it among
    their arguments (each rule counted once, even if the state appears
    several times in its arguments); a state used by no rule gets 0.

    A state is only resolved once all the destinations consuming it are
    resolved, so the loop does not terminate if unresolved states consume
    each other cyclically.

    Args:
        dfta: automaton providing ``rules`` and ``all_states``.
        unconsumed: states used as the starting points, with value 1.

    Returns:
        A dictionary from state to its outbound count.
    """
    from collections import defaultdict, deque

    # Index the rules once instead of rescanning all of them per state.
    consumers: dict[str, list[str]] = defaultdict(list)
    for (_, args), dst in dfta.rules.items():
        for arg_state in set(args):
            consumers[arg_state].append(dst)

    outbound: dict[str, int] = dict.fromkeys(unconsumed, 1)
    # deque gives O(1) re-queuing at the front (list.insert(0, x) is O(n)).
    pending = deque(dfta.all_states)
    while pending:
        state = pending.pop()
        if state in outbound:
            continue
        parents = consumers.get(state, ())
        if all(dst in outbound for dst in parents):
            outbound[state] = sum(outbound[dst] for dst in parents)
        else:
            # Some consumer is not resolved yet: retry after the others.
            pending.appendleft(state)
    return outbound
def __state2letter__(state: str) -> str:
    """Extract the letter a state name starts with.

    A name without any ``"("`` is returned unchanged.  Otherwise the result
    is the text between the first character and the first space, which for
    a name shaped like ``"(letter arg ...)"`` is ``letter``.  When there is
    no space, ``str.find`` yields -1 and the slice drops the last character
    instead (``"(f)"`` gives ``"f"``).
    """
    if "(" not in state:
        return state
    first_space = state.find(" ")
    return state[1:first_space]
8720
8821
8922def __can_states_merge (
90- dfta : DFTA [str , str ], state_to_letter : dict [str , str ], original : str , candidate : str
23+ reversed_rules : dict [tuple [str , tuple [str , ...]], str ],
24+ original : str ,
25+ candidate : str ,
9126) -> bool :
92- if state_to_letter [ candidate ] != state_to_letter [ original ] and not str (
93- state_to_letter [ candidate ]
27+ if __state2letter__ ( candidate ) != __state2letter__ ( original ) and not str (
28+ __state2letter__ ( candidate )
9429 ).startswith ("var" ):
9530 return False
96- for P1 , args1 in dfta . reversed_rules [original ]:
31+ for P1 , args1 in reversed_rules [original ]:
9732 has_equivalent = False
98- for P2 , args2 in dfta . reversed_rules [candidate ]:
33+ for P2 , args2 in reversed_rules [candidate ]:
9934 if all (
100- __can_states_merge (dfta , state_to_letter , arg1 , arg2 )
35+ __can_states_merge (reversed_rules , arg1 , arg2 )
10136 for arg1 , arg2 in zip (args1 , args2 )
10237 ):
10338 has_equivalent = True
@@ -107,98 +42,115 @@ def __can_states_merge(
10742 return True
10843
10944
def __find_merge__(
    dfta: DFTA[str, str], P: str, args: tuple[str, ...], candidates: set[str]
) -> str | None:
    """Pick an existing state that a new rule ``P(args)`` can be sent to.

    A candidate is eligible when its letter (see ``__state2letter__``) is
    ``P`` or starts with ``"var"``, and at least one rule producing it has
    arguments that pairwise pass ``__can_states_merge`` against ``args``
    (pairs are formed with ``zip``, so extra arguments on either side are
    ignored).  Among eligible candidates the one whose name contains the
    most spaces wins.

    Args:
        dfta: automaton providing ``reversed_rules``.
        P: letter of the rule being added.
        args: argument states of the rule being added.
        candidates: states the rule may be redirected to.

    Returns:
        The selected candidate, or ``None`` when no candidate is eligible.
    """
    best_candidate: str | None = None
    best_size = -1
    # Iterate in sorted order so that ties between candidates of equal size
    # are resolved the same way on every run; plain set iteration order
    # depends on string hashing, which changes between interpreter runs.
    for candidate in sorted(candidates):
        letter = __state2letter__(candidate)
        if letter != P and not letter.startswith("var"):
            continue
        has_equivalent = any(
            all(
                __can_states_merge(dfta.reversed_rules, arg1, arg2)
                for arg1, arg2 in zip(args, args2)
            )
            for _, args2 in dfta.reversed_rules[candidate]
        )
        size = candidate.count(" ")
        if has_equivalent and size > best_size:
            best_candidate = candidate
            best_size = size
    return best_candidate
67+
68+
def __convert_automaton__(dfta: DFTA[str, str]) -> DFTA[str, Program]:
    """Turn the letters of an automaton into ``Program`` objects.

    Each letter is converted through ``str``.  A letter of the form
    ``"var<integer>"`` becomes ``Variable(<integer>)``; every other letter
    becomes ``Primitive(<letter>)``.

    Args:
        dfta: automaton whose alphabet is to be converted.

    Returns:
        The automaton returned by ``dfta.map_alphabet`` with the conversion
        above.
    """
    prefix = "var"

    def to_program(letter: object) -> Program:
        name = str(letter)
        if name.startswith(prefix):
            try:
                index = int(name[len(prefix) :])
            except ValueError:
                # Not a variable after all, e.g. a primitive whose name
                # merely starts with "var": keep it as a primitive rather
                # than failing on the integer conversion.
                return Primitive(name)
            return Variable(index)
        return Primitive(name)

    return dfta.map_alphabet(to_program)
75+
76+
def add_loops(
    dfta: DFTA[str, Program | str],
    dsl: DSL,
    strategy: LoopStrategy,
) -> DFTA[str, Program]:
    """
    Add rules to a bounded automaton so that larger programs are accepted.

    Assumes one state is from one letter and that variants are mapped.

    The size of a state is taken to be the number of spaces in its name.
    The work is done on a copy of the rules and finals of ``dfta``:

    1. Every type without a state starting with ``"var"`` temporarily gets
       a virtual variable state, and each existing rule is duplicated with
       that state substituted for arguments of the type.
    2. For every primitive of ``dsl`` and every combination of argument
       states not yet covered by a rule, a rule is added when the
       combination is large enough with respect to the largest state.  Its
       destination is an existing state chosen by ``__find_merge__`` or,
       failing that, a new state named after the produced program.
    3. The virtual variables are removed and the automaton is reduced.

    Args:
        dfta: the automaton to extend.
        dsl: provides the state types, the primitives and their types.
        strategy: ``LoopStrategy.NO_LOOP`` only converts the alphabet; any
            other value triggers the procedure above.

    Returns:
        A new automaton whose letters are ``Program`` objects.

    Raises:
        ValueError: if loops are requested while ``dfta.is_unbounded()``
            is already true.
    """
    if strategy == LoopStrategy.NO_LOOP:
        return __convert_automaton__(dfta)
    if dfta.is_unbounded():
        raise ValueError("automaton is already looping cannot add loops!")

    state_to_type = dsl.get_state_types(dfta)
    state_to_size = {s: s.count(" ") for s in dfta.all_states}
    max_size = max(state_to_size.values())
    states_by_types: dict[str, set[str]] = {}
    for state, state_type in state_to_type.items():
        states_by_types.setdefault(state_type, set()).add(state)

    new_dfta = DFTA(dfta.rules.copy(), dfta.finals.copy())

    # First variable number that is not used by any existing state; the
    # default covers automata that contain no variable state at all.
    next_varno = 1 + max(
        (int(s[len("var") :]) for s in state_to_type if s.startswith("var")),
        default=-1,
    )
    virtual_vars: set[int] = set()
    for states in states_by_types.values():
        if any(s.startswith("var") for s in states):
            continue
        # No variable of this type: introduce a temporary one.
        virtual_vars.add(next_varno)
        var_state = str(Variable(next_varno))
        new_dfta.rules[(Variable(next_varno), ())] = var_state
        # Duplicate every rule with the virtual variable standing in for
        # one or more arguments of this type.
        for (P, args), rule_dst in dfta.rules.items():
            possibles = [[arg, var_state] if arg in states else [arg] for arg in args]
            for new_args in itertools.product(*possibles):
                if var_state in new_args and (P, new_args) not in new_dfta.rules:
                    new_dfta.rules[(P, new_args)] = rule_dst
        next_varno += 1
    new_dfta.refresh_reversed_rules()

    # The original loop was capped to a single pass; the cap is kept as is.
    max_passes = 1
    for _ in range(max_passes):
        added = False
        for P, (Ptype, _impl) in dsl.primitives.items():
            # A type with no state in the automaton yields no combination.
            possibles = [
                states_by_types.get(arg_t, ()) for arg_t in types.arguments(Ptype)
            ]
            # itertools.product snapshots its inputs, so the sets may be
            # extended below while the combinations are being consumed.
            for combi in itertools.product(*possibles):
                key = (P, combi)
                if key in new_dfta.rules:
                    continue
                args_size = [state_to_size[s] for s in combi]
                dst_size = sum(args_size) + 1
                # Only combinations at least as large as the largest
                # existing state, and with one large enough argument,
                # give rise to a new rule.
                if dst_size < max_size:
                    continue
                if max(args_size, default=0) < max_size - len(args_size) + 1:
                    continue
                added = True
                rtype = types.return_type(dsl.get_type(P))
                candidates = states_by_types.setdefault(rtype, set())
                new_state = __find_merge__(new_dfta, P, combi, candidates) or str(
                    Function(Primitive(P), [Primitive(s) for s in combi])
                )
                new_dfta.rules[key] = new_state
                candidates.add(new_state)
                state_to_size.setdefault(new_state, dst_size)
        # NOTE(review): the reversed rules are refreshed once per pass; the
        # exact nesting of this call could not be recovered from the diff
        # rendering this was reconstructed from -- confirm against the file.
        new_dfta.refresh_reversed_rules()
        if not added:
            break

    # The virtual variables were only scaffolding: drop their rules and let
    # ``reduce`` clean up the automaton.
    for no in virtual_vars:
        del new_dfta.rules[(Variable(no), ())]
    new_dfta.reduce()
    new_dfta.refresh_reversed_rules()
    # NOTE(review): the result is neither minimised nor renamed; those
    # calls were commented out in the original.
    return __convert_automaton__(new_dfta)
0 commit comments