1616from ray .exceptions import GetTimeoutError
1717from ray .util .dask import RayDaskCallback
1818from pydantic import validate_arguments , conint , confloat
19- from synthesizrr .base .util .language import ProgressBar , set_param_from_alias , type_str , get_default , first_item , if_else
19+ from synthesizrr .base .util .language import ProgressBar , set_param_from_alias , type_str , get_default , first_item , Parameters
2020from synthesizrr .base .constants .DataProcessingConstants import Parallelize , FailureAction , Status , COMPLETED_STATUSES
2121
2222from functools import partial
2626import time , inspect
2727from concurrent .futures import ThreadPoolExecutor , ProcessPoolExecutor , wait as wait_future
2828
29+ _RAY_ACCUMULATE_ITEM_WAIT : float = 10e-3 ## 10ms
30+ _LOCAL_ACCUMULATE_ITEM_WAIT : float = 1e-3 ## 1ms
31+
32+ _RAY_ACCUMULATE_ITER_WAIT : float = 1000e-3 ## 1000ms
33+ _LOCAL_ACCUMULATE_ITER_WAIT : float = 100e-3 ## 100ms
34+
2935
3036def _asyncio_start_event_loop (loop ):
3137 asyncio .set_event_loop (loop )
@@ -173,6 +179,25 @@ def run_concurrent(
173179 raise e
174180
175181
182+ class RestrictedConcurrencyThreadPoolExecutor (ThreadPoolExecutor ):
183+ """
184+ Similar functionality to @concurrent.
185+ """
186+
187+ def __init__ (self , max_active_threads : Optional [int ] = None , * args , ** kwargs ):
188+ super ().__init__ (* args , ** kwargs )
189+ if max_active_threads is None :
190+ max_active_threads : int = self ._max_workers
191+ assert isinstance (max_active_threads , int )
192+ self ._semaphore = Semaphore (max_active_threads )
193+
194+ def submit (self , * args , ** kwargs ):
195+ self ._semaphore .acquire ()
196+ future = super ().submit (* args , ** kwargs )
197+ future .add_done_callback (lambda _ : self ._semaphore .release ())
198+ return future
199+
200+
176201_GLOBAL_PROCESS_POOL_EXECUTOR : ProcessPoolExecutor = ProcessPoolExecutor (
177202 max_workers = max (1 , min (32 , mp .cpu_count () - 1 ))
178203)
@@ -255,10 +280,105 @@ def stop_executor(
255280
256281
257282@ray .remote (num_cpus = 1 )
258- def __run_parallel_ray_executor (fn , * args , ** kwargs ):
283+ def _run_parallel_ray_executor (fn , * args , ** kwargs ):
259284 return fn (* args , ** kwargs )
260285
261286
287+ def _ray_asyncio_start_event_loop (loop ):
288+ asyncio .set_event_loop (loop )
289+ loop .run_forever ()
290+
291+
292+ class RayPoolExecutor (Parameters ):
293+ max_workers : Union [int , Literal [inf ]]
294+ iter_wait : float = _RAY_ACCUMULATE_ITER_WAIT
295+ item_wait : float = _RAY_ACCUMULATE_ITEM_WAIT
296+ _asyncio_event_loop : Optional = None
297+ _asyncio_event_loop_thread : Optional = None
298+ _submission_executor : Optional [ThreadPoolExecutor ] = None
299+ _running_tasks : Dict = {}
300+ _latest_submit : Optional [int ] = None
301+
302+ def _set_asyncio (self ):
303+ # Create a new loop and a thread running this loop
304+ if self ._asyncio_event_loop is None :
305+ self ._asyncio_event_loop = asyncio .new_event_loop ()
306+ # print(f'Started _asyncio_event_loop')
307+ if self ._asyncio_event_loop_thread is None :
308+ self ._asyncio_event_loop_thread = threading .Thread (
309+ target = _ray_asyncio_start_event_loop ,
310+ args = (self ._asyncio_event_loop ,),
311+ )
312+ self ._asyncio_event_loop_thread .start ()
313+ # print(f'Started _asyncio_event_loop_thread')
314+
315+ def submit (
316+ self ,
317+ fn ,
318+ * args ,
319+ scheduling_strategy : str = "SPREAD" ,
320+ num_cpus : int = 1 ,
321+ num_gpus : int = 0 ,
322+ max_retries : int = 0 ,
323+ retry_exceptions : Union [List , bool ] = True ,
324+ ** kwargs ,
325+ ):
326+ # print(f'Running {fn_str(fn)} using {Parallelize.ray} with num_cpus={num_cpus}, num_gpus={num_gpus}')
327+ def _submit_task ():
328+ return _run_parallel_ray_executor .options (
329+ scheduling_strategy = scheduling_strategy ,
330+ num_cpus = num_cpus ,
331+ num_gpus = num_gpus ,
332+ max_retries = max_retries ,
333+ retry_exceptions = retry_exceptions ,
334+ ).remote (fn , * args , ** kwargs )
335+
336+ _task_uid = str (time .time_ns ())
337+
338+ if self .max_workers == inf :
339+ return _submit_task () ## Submit to Ray directly
340+ self ._set_asyncio ()
341+ ## Create a coroutine (i.e. Future), but do not actually start executing it.
342+ coroutine = self ._ray_run_fn_async (
343+ submit_task = _submit_task ,
344+ task_uid = _task_uid ,
345+ )
346+
347+ ## Schedule the coroutine to execute on the event loop (which is running on thread _asyncio_event_loop).
348+ fut = asyncio .run_coroutine_threadsafe (coroutine , self ._asyncio_event_loop )
349+ # while _task_uid not in self._running_tasks: ## Ensure task has started scheduling
350+ # time.sleep(self.item_wait)
351+ return fut
352+
353+ async def _ray_run_fn_async (
354+ self ,
355+ submit_task : Callable ,
356+ task_uid : str ,
357+ ):
358+ # self._running_tasks[task_uid] = None
359+ while len (self ._running_tasks ) >= self .max_workers :
360+ for _task_uid in sorted (self ._running_tasks .keys ()):
361+ if is_done (self ._running_tasks [_task_uid ]):
362+ self ._running_tasks .pop (_task_uid , None )
363+ # print(f'Popped {_task_uid}')
364+ if len (self ._running_tasks ) < self .max_workers :
365+ break
366+ time .sleep (self .item_wait )
367+ if len (self ._running_tasks ) < self .max_workers :
368+ break
369+ time .sleep (self .iter_wait )
370+ fut = submit_task ()
371+ self ._running_tasks [task_uid ] = fut
372+ # print(f'Started {task_uid}. Num running: {len(self._running_tasks)}')
373+
374+ # ## Cleanup any completed tasks:
375+ # for k in list(self._running_tasks.keys()):
376+ # if is_done(self._running_tasks[k]):
377+ # self._running_tasks.pop(k, None)
378+ # time.sleep(self.item_wait)
379+ return fut
380+
381+
262382def run_parallel_ray (
263383 fn ,
264384 * args ,
@@ -270,7 +390,7 @@ def run_parallel_ray(
270390 ** kwargs ,
271391):
272392 # print(f'Running {fn_str(fn)} using {Parallelize.ray} with num_cpus={num_cpus}, num_gpus={num_gpus}')
273- return __run_parallel_ray_executor .options (
393+ return _run_parallel_ray_executor .options (
274394 scheduling_strategy = scheduling_strategy ,
275395 num_cpus = num_cpus ,
276396 num_gpus = num_gpus ,
@@ -285,7 +405,7 @@ def dispatch(
285405 parallelize : Parallelize ,
286406 forward_parallelize : bool = False ,
287407 delay : float = 0.0 ,
288- executor : Optional [Union [ThreadPoolExecutor , ProcessPoolExecutor ]] = None ,
408+ executor : Optional [Union [ThreadPoolExecutor , ProcessPoolExecutor , RayPoolExecutor ]] = None ,
289409 ** kwargs
290410) -> Any :
291411 parallelize : Parallelize = Parallelize .from_str (parallelize )
@@ -301,24 +421,26 @@ def dispatch(
301421 elif parallelize is Parallelize .processes :
302422 return run_parallel (fn , * args , executor = executor , ** kwargs )
303423 elif parallelize is Parallelize .ray :
304- return run_parallel_ray (fn , * args , ** kwargs )
424+ return run_parallel_ray (fn , * args , executor = executor , ** kwargs )
305425 raise NotImplementedError (f'Unsupported parallelization: { parallelize } ' )
306426
307427
308428def dispatch_executor (
309429 parallelize : Parallelize ,
310430 ** kwargs
311- ) -> Optional [Union [ProcessPoolExecutor , ThreadPoolExecutor ]]:
431+ ) -> Optional [Union [ProcessPoolExecutor , ThreadPoolExecutor , RayPoolExecutor ]]:
312432 parallelize : Parallelize = Parallelize .from_str (parallelize )
313433 set_param_from_alias (kwargs , param = 'max_workers' , alias = ['num_workers' ], default = None )
314434 max_workers : Optional [int ] = kwargs .pop ('max_workers' , None )
315435 if max_workers is None :
316- ## Uses the default executor for threads/processes.
436+ ## Uses the default executor for threads/processes/ray .
317437 return None
318438 if parallelize is Parallelize .threads :
319439 return ThreadPoolExecutor (max_workers = max_workers )
320440 elif parallelize is Parallelize .processes :
321441 return ProcessPoolExecutor (max_workers = max_workers )
442+ elif parallelize is Parallelize .ray :
443+ return RayPoolExecutor (max_workers = max_workers )
322444 else :
323445 return None
324446
@@ -329,7 +451,7 @@ def get_result(
329451 wait : float = 1.0 , ## 1000 ms
330452) -> Optional [Any ]:
331453 if isinstance (x , Future ):
332- return x .result ()
454+ return get_result ( x .result (), wait = wait )
333455 if isinstance (x , ray .ObjectRef ):
334456 while True :
335457 try :
@@ -399,13 +521,6 @@ def is_failed(x, *, pending_returns_false: bool = False) -> Optional[bool]:
399521 return True
400522
401523
402- _RAY_ACCUMULATE_ITEM_WAIT : float = 100e-3 ## 100ms
403- _LOCAL_ACCUMULATE_ITEM_WAIT : float = 10e-3 ## 10ms
404-
405- _RAY_ACCUMULATE_ITER_WAIT : float = 1000e-3 ## 1000ms
406- _LOCAL_ACCUMULATE_ITER_WAIT : float = 100e-3 ## 100ms
407-
408-
409524def accumulate (
410525 futures : Union [Tuple , List , Set , Dict , Any ],
411526 * ,
@@ -678,8 +793,9 @@ def retry(
678793 wait : confloat (ge = 0.0 ) = 10.0 ,
679794 jitter : confloat (gt = 0.0 ) = 0.5 ,
680795 silent : bool = True ,
796+ return_num_failures : bool = False ,
681797 ** kwargs
682- ):
798+ ) -> Union [ Any , Tuple [ Any , int ]] :
683799 """
684800 Retries a function call a certain number of times, waiting between calls (with a jitter in the wait period).
685801 :param fn: the function to call.
@@ -694,15 +810,21 @@ def retry(
694810 """
695811 wait : float = float (wait )
696812 latest_exception = None
813+ num_failures : int = 0
697814 for retry_num in range (retries + 1 ):
698815 try :
699- return fn (* args , ** kwargs )
816+ out = fn (* args , ** kwargs )
817+ if return_num_failures :
818+ return out , num_failures
819+ else :
820+ return out
700821 except Exception as e :
822+ num_failures += 1
701823 latest_exception = traceback .format_exc ()
702824 if not silent :
703- logging . debug (f'Function call failed with the following exception:\n { latest_exception } ' )
825+ print (f'Function call failed with the following exception:\n { latest_exception } ' )
704826 if retry_num < (retries - 1 ):
705- logging . debug (f'Retrying { retries - (retry_num + 1 )} more times...\n ' )
827+ print (f'Retrying { retries - (retry_num + 1 )} more times...\n ' )
706828 time .sleep (np .random .uniform (wait - wait * jitter , wait + wait * jitter ))
707829 raise RuntimeError (f'Function call failed { retries } times.\n Latest exception:\n { latest_exception } \n ' )
708830
0 commit comments