Skip to content

Commit c1442cf

Browse files
Added new RayPoolExecutor for limited concurrency with Ray.
1 parent 4bcd8bb commit c1442cf

File tree

2 files changed

+198
-23
lines changed

2 files changed

+198
-23
lines changed

src/synthesizrr/base/util/concurrency.py

Lines changed: 141 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ray.exceptions import GetTimeoutError
1717
from ray.util.dask import RayDaskCallback
1818
from pydantic import validate_arguments, conint, confloat
19-
from synthesizrr.base.util.language import ProgressBar, set_param_from_alias, type_str, get_default, first_item, if_else
19+
from synthesizrr.base.util.language import ProgressBar, set_param_from_alias, type_str, get_default, first_item, Parameters
2020
from synthesizrr.base.constants.DataProcessingConstants import Parallelize, FailureAction, Status, COMPLETED_STATUSES
2121

2222
from functools import partial
@@ -26,6 +26,12 @@
2626
import time, inspect
2727
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, wait as wait_future
2828

29+
_RAY_ACCUMULATE_ITEM_WAIT: float = 10e-3 ## 10ms
30+
_LOCAL_ACCUMULATE_ITEM_WAIT: float = 1e-3 ## 1ms
31+
32+
_RAY_ACCUMULATE_ITER_WAIT: float = 1000e-3 ## 1000ms
33+
_LOCAL_ACCUMULATE_ITER_WAIT: float = 100e-3 ## 100ms
34+
2935

3036
def _asyncio_start_event_loop(loop):
3137
asyncio.set_event_loop(loop)
@@ -173,6 +179,25 @@ def run_concurrent(
173179
raise e
174180

175181

182+
class RestrictedConcurrencyThreadPoolExecutor(ThreadPoolExecutor):
183+
"""
184+
Similar functionality to @concurrent.
185+
"""
186+
187+
def __init__(self, max_active_threads: Optional[int] = None, *args, **kwargs):
188+
super().__init__(*args, **kwargs)
189+
if max_active_threads is None:
190+
max_active_threads: int = self._max_workers
191+
assert isinstance(max_active_threads, int)
192+
self._semaphore = Semaphore(max_active_threads)
193+
194+
def submit(self, *args, **kwargs):
195+
self._semaphore.acquire()
196+
future = super().submit(*args, **kwargs)
197+
future.add_done_callback(lambda _: self._semaphore.release())
198+
return future
199+
200+
176201
_GLOBAL_PROCESS_POOL_EXECUTOR: ProcessPoolExecutor = ProcessPoolExecutor(
177202
max_workers=max(1, min(32, mp.cpu_count() - 1))
178203
)
@@ -255,10 +280,105 @@ def stop_executor(
255280

256281

257282
@ray.remote(num_cpus=1)
258-
def __run_parallel_ray_executor(fn, *args, **kwargs):
283+
def _run_parallel_ray_executor(fn, *args, **kwargs):
259284
return fn(*args, **kwargs)
260285

261286

287+
def _ray_asyncio_start_event_loop(loop):
288+
asyncio.set_event_loop(loop)
289+
loop.run_forever()
290+
291+
292+
class RayPoolExecutor(Parameters):
293+
max_workers: Union[int, Literal[inf]]
294+
iter_wait: float = _RAY_ACCUMULATE_ITER_WAIT
295+
item_wait: float = _RAY_ACCUMULATE_ITEM_WAIT
296+
_asyncio_event_loop: Optional = None
297+
_asyncio_event_loop_thread: Optional = None
298+
_submission_executor: Optional[ThreadPoolExecutor] = None
299+
_running_tasks: Dict = {}
300+
_latest_submit: Optional[int] = None
301+
302+
def _set_asyncio(self):
303+
# Create a new loop and a thread running this loop
304+
if self._asyncio_event_loop is None:
305+
self._asyncio_event_loop = asyncio.new_event_loop()
306+
# print(f'Started _asyncio_event_loop')
307+
if self._asyncio_event_loop_thread is None:
308+
self._asyncio_event_loop_thread = threading.Thread(
309+
target=_ray_asyncio_start_event_loop,
310+
args=(self._asyncio_event_loop,),
311+
)
312+
self._asyncio_event_loop_thread.start()
313+
# print(f'Started _asyncio_event_loop_thread')
314+
315+
def submit(
316+
self,
317+
fn,
318+
*args,
319+
scheduling_strategy: str = "SPREAD",
320+
num_cpus: int = 1,
321+
num_gpus: int = 0,
322+
max_retries: int = 0,
323+
retry_exceptions: Union[List, bool] = True,
324+
**kwargs,
325+
):
326+
# print(f'Running {fn_str(fn)} using {Parallelize.ray} with num_cpus={num_cpus}, num_gpus={num_gpus}')
327+
def _submit_task():
328+
return _run_parallel_ray_executor.options(
329+
scheduling_strategy=scheduling_strategy,
330+
num_cpus=num_cpus,
331+
num_gpus=num_gpus,
332+
max_retries=max_retries,
333+
retry_exceptions=retry_exceptions,
334+
).remote(fn, *args, **kwargs)
335+
336+
_task_uid = str(time.time_ns())
337+
338+
if self.max_workers == inf:
339+
return _submit_task() ## Submit to Ray directly
340+
self._set_asyncio()
341+
## Create a coroutine (i.e. Future), but do not actually start executing it.
342+
coroutine = self._ray_run_fn_async(
343+
submit_task=_submit_task,
344+
task_uid=_task_uid,
345+
)
346+
347+
## Schedule the coroutine to execute on the event loop (which is running on thread _asyncio_event_loop).
348+
fut = asyncio.run_coroutine_threadsafe(coroutine, self._asyncio_event_loop)
349+
# while _task_uid not in self._running_tasks: ## Ensure task has started scheduling
350+
# time.sleep(self.item_wait)
351+
return fut
352+
353+
async def _ray_run_fn_async(
354+
self,
355+
submit_task: Callable,
356+
task_uid: str,
357+
):
358+
# self._running_tasks[task_uid] = None
359+
while len(self._running_tasks) >= self.max_workers:
360+
for _task_uid in sorted(self._running_tasks.keys()):
361+
if is_done(self._running_tasks[_task_uid]):
362+
self._running_tasks.pop(_task_uid, None)
363+
# print(f'Popped {_task_uid}')
364+
if len(self._running_tasks) < self.max_workers:
365+
break
366+
time.sleep(self.item_wait)
367+
if len(self._running_tasks) < self.max_workers:
368+
break
369+
time.sleep(self.iter_wait)
370+
fut = submit_task()
371+
self._running_tasks[task_uid] = fut
372+
# print(f'Started {task_uid}. Num running: {len(self._running_tasks)}')
373+
374+
# ## Cleanup any completed tasks:
375+
# for k in list(self._running_tasks.keys()):
376+
# if is_done(self._running_tasks[k]):
377+
# self._running_tasks.pop(k, None)
378+
# time.sleep(self.item_wait)
379+
return fut
380+
381+
262382
def run_parallel_ray(
263383
fn,
264384
*args,
@@ -270,7 +390,7 @@ def run_parallel_ray(
270390
**kwargs,
271391
):
272392
# print(f'Running {fn_str(fn)} using {Parallelize.ray} with num_cpus={num_cpus}, num_gpus={num_gpus}')
273-
return __run_parallel_ray_executor.options(
393+
return _run_parallel_ray_executor.options(
274394
scheduling_strategy=scheduling_strategy,
275395
num_cpus=num_cpus,
276396
num_gpus=num_gpus,
@@ -285,7 +405,7 @@ def dispatch(
285405
parallelize: Parallelize,
286406
forward_parallelize: bool = False,
287407
delay: float = 0.0,
288-
executor: Optional[Union[ThreadPoolExecutor, ProcessPoolExecutor]] = None,
408+
executor: Optional[Union[ThreadPoolExecutor, ProcessPoolExecutor, RayPoolExecutor]] = None,
289409
**kwargs
290410
) -> Any:
291411
parallelize: Parallelize = Parallelize.from_str(parallelize)
@@ -301,24 +421,26 @@ def dispatch(
301421
elif parallelize is Parallelize.processes:
302422
return run_parallel(fn, *args, executor=executor, **kwargs)
303423
elif parallelize is Parallelize.ray:
304-
return run_parallel_ray(fn, *args, **kwargs)
424+
return run_parallel_ray(fn, *args, executor=executor, **kwargs)
305425
raise NotImplementedError(f'Unsupported parallelization: {parallelize}')
306426

307427

308428
def dispatch_executor(
309429
parallelize: Parallelize,
310430
**kwargs
311-
) -> Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]]:
431+
) -> Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor, RayPoolExecutor]]:
312432
parallelize: Parallelize = Parallelize.from_str(parallelize)
313433
set_param_from_alias(kwargs, param='max_workers', alias=['num_workers'], default=None)
314434
max_workers: Optional[int] = kwargs.pop('max_workers', None)
315435
if max_workers is None:
316-
## Uses the default executor for threads/processes.
436+
## Uses the default executor for threads/processes/ray.
317437
return None
318438
if parallelize is Parallelize.threads:
319439
return ThreadPoolExecutor(max_workers=max_workers)
320440
elif parallelize is Parallelize.processes:
321441
return ProcessPoolExecutor(max_workers=max_workers)
442+
elif parallelize is Parallelize.ray:
443+
return RayPoolExecutor(max_workers=max_workers)
322444
else:
323445
return None
324446

@@ -329,7 +451,7 @@ def get_result(
329451
wait: float = 1.0, ## 1000 ms
330452
) -> Optional[Any]:
331453
if isinstance(x, Future):
332-
return x.result()
454+
return get_result(x.result(), wait=wait)
333455
if isinstance(x, ray.ObjectRef):
334456
while True:
335457
try:
@@ -399,13 +521,6 @@ def is_failed(x, *, pending_returns_false: bool = False) -> Optional[bool]:
399521
return True
400522

401523

402-
_RAY_ACCUMULATE_ITEM_WAIT: float = 100e-3 ## 100ms
403-
_LOCAL_ACCUMULATE_ITEM_WAIT: float = 10e-3 ## 10ms
404-
405-
_RAY_ACCUMULATE_ITER_WAIT: float = 1000e-3 ## 1000ms
406-
_LOCAL_ACCUMULATE_ITER_WAIT: float = 100e-3 ## 100ms
407-
408-
409524
def accumulate(
410525
futures: Union[Tuple, List, Set, Dict, Any],
411526
*,
@@ -678,8 +793,9 @@ def retry(
678793
wait: confloat(ge=0.0) = 10.0,
679794
jitter: confloat(gt=0.0) = 0.5,
680795
silent: bool = True,
796+
return_num_failures: bool = False,
681797
**kwargs
682-
):
798+
) -> Union[Any, Tuple[Any, int]]:
683799
"""
684800
Retries a function call a certain number of times, waiting between calls (with a jitter in the wait period).
685801
:param fn: the function to call.
@@ -694,15 +810,21 @@ def retry(
694810
"""
695811
wait: float = float(wait)
696812
latest_exception = None
813+
num_failures: int = 0
697814
for retry_num in range(retries + 1):
698815
try:
699-
return fn(*args, **kwargs)
816+
out = fn(*args, **kwargs)
817+
if return_num_failures:
818+
return out, num_failures
819+
else:
820+
return out
700821
except Exception as e:
822+
num_failures += 1
701823
latest_exception = traceback.format_exc()
702824
if not silent:
703-
logging.debug(f'Function call failed with the following exception:\n{latest_exception}')
825+
print(f'Function call failed with the following exception:\n{latest_exception}')
704826
if retry_num < (retries - 1):
705-
logging.debug(f'Retrying {retries - (retry_num + 1)} more times...\n')
827+
print(f'Retrying {retries - (retry_num + 1)} more times...\n')
706828
time.sleep(np.random.uniform(wait - wait * jitter, wait + wait * jitter))
707829
raise RuntimeError(f'Function call failed {retries} times.\nLatest exception:\n{latest_exception}\n')
708830

src/synthesizrr/base/util/language.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,38 @@ def as_list(l) -> List:
11111111
return [l]
11121112

11131113

1114+
def list_pop_inplace(l: List, *, pop_condition: Callable) -> List:
1115+
assert isinstance(l, list) ## Needs to be a mutable
1116+
## Iterate backwards to preserve indexes while iterating
1117+
for i in range(len(l) - 1, -1, -1): # Iterate backwards
1118+
if pop_condition(l[i]):
1119+
l.pop(i) ## Remove the item inplace
1120+
return l
1121+
1122+
1123+
def set_union(*args) -> Set:
1124+
_union: Set = set()
1125+
for s in args:
1126+
if isinstance(s, (pd.Series, np.ndarray)):
1127+
s: List = s.tolist()
1128+
s: Set = set(s)
1129+
_union: Set = _union.union(s)
1130+
return _union
1131+
1132+
1133+
def set_intersection(*args) -> Set:
1134+
_intersection: Optional[Set] = None
1135+
for s in args:
1136+
if isinstance(s, (pd.Series, np.ndarray)):
1137+
s: List = s.tolist()
1138+
s: Set = set(s)
1139+
if _intersection is None:
1140+
_intersection: Set = s
1141+
else:
1142+
_intersection: Set = _intersection.intersection(s)
1143+
return _intersection
1144+
1145+
11141146
def filter_string_list(l: List[str], pattern: str, ignorecase: bool = False) -> List[str]:
11151147
"""
11161148
Filter a list of strings based on an exact match to a regex pattern. Leaves non-string items untouched.
@@ -1201,7 +1233,7 @@ def as_tuple(l) -> Tuple:
12011233

12021234

12031235
## ======================== Set utils ======================== ##
1204-
def is_set_like(l: Union[Set, frozenset]) -> bool:
1236+
def is_set_like(l: Any) -> bool:
12051237
return isinstance(l, (set, frozenset, KeysView))
12061238

12071239

@@ -1945,7 +1977,7 @@ def mean(vals):
19451977

19461978

19471979
def random_sample(
1948-
data: Union[List, Tuple, np.ndarray],
1980+
data: Union[List, Tuple, Set, np.ndarray],
19491981
n: SampleSizeType,
19501982
*,
19511983
replacement: bool = False,
@@ -1961,6 +1993,8 @@ def random_sample(
19611993
"""
19621994
np_random = np.random.RandomState(seed)
19631995
py_random = random.Random(seed)
1996+
if is_set_like(data):
1997+
data: List = list(data)
19641998
if not is_list_like(data):
19651999
raise ValueError(
19662000
f'Input `data` must be {list}, {tuple} or {np.ndarray}; '
@@ -2973,7 +3007,7 @@ def pd_partial_column_order(df: pd.DataFrame, columns: List) -> pd.DataFrame:
29733007

29743008
class ProgressBar(MutableParameters):
29753009
pbar: Optional[TqdmProgressBar] = None
2976-
style: Literal['auto', 'notebook', 'std'] = 'auto'
3010+
style: Literal['auto', 'notebook', 'std', 'ray'] = 'auto'
29773011
unit: str = 'row'
29783012
color: str = '#0288d1' ## Bluish
29793013
ncols: int = 100
@@ -2998,7 +3032,7 @@ def _set_params(cls, params: Dict) -> Dict:
29983032
@classmethod
29993033
def _create_pbar(
30003034
cls,
3001-
style: Literal['auto', 'notebook', 'std'],
3035+
style: Literal['auto', 'notebook', 'std', 'ray'],
30023036
**kwargs,
30033037
) -> TqdmProgressBar:
30043038
if style == 'auto':
@@ -3009,6 +3043,15 @@ def _create_pbar(
30093043
with optional_dependency('ipywidgets'):
30103044
kwargs['ncols']: Optional[int] = None
30113045
return NotebookTqdmProgressBar(**kwargs)
3046+
elif style == 'ray':
3047+
from ray.experimental import tqdm_ray
3048+
kwargs = filter_keys(
3049+
kwargs,
3050+
keys=set(get_fn_spec(tqdm_ray.tqdm).args + get_fn_spec(tqdm_ray.tqdm).kwargs),
3051+
how='include',
3052+
)
3053+
from ray.experimental import tqdm_ray
3054+
return tqdm_ray.tqdm(**kwargs)
30123055
else:
30133056
return StdTqdmProgressBar(**kwargs)
30143057

@@ -3363,3 +3406,13 @@ def plotsum(
33633406
else:
33643407
raise not_impl('how', how)
33653408
return plots
3409+
3410+
3411+
def to_pct(counts: pd.Series): ## Converts value counts to percentages
3412+
_sum = counts.sum()
3413+
return pd.DataFrame({
3414+
'value': counts.index.tolist(),
3415+
'count': counts.tolist(),
3416+
'pct': counts.apply(lambda x: 100 * x / _sum).tolist(),
3417+
'count_str': counts.apply(lambda x: f'{x} of {_sum}').tolist(),
3418+
})

0 commit comments

Comments
 (0)