88import asyncio
99import logging
1010import random
11- from collections .abc import AsyncIterable , Callable , Iterable
11+ from collections .abc import AsyncIterable , Callable
1212from contextlib import asynccontextmanager
1313from datetime import datetime , timedelta
1414from typing import Final , Literal , cast
@@ -311,94 +311,90 @@ async def _log_publisher():
311311
312312
313313@pytest .fixture
314- def computation_done () -> Iterable [Callable [[], bool ]]:
315- stop_time : Final [datetime ] = datetime .now () + timedelta (seconds = 2 )
316-
317- def _job_done () -> bool :
318- return datetime .now () >= stop_time
319-
320- return _job_done
321-
322-
323- @pytest .fixture
324- async def log_streamer_with_distributor (
314+ async def create_log_streamer_with_distributor (
325315 client : httpx .AsyncClient ,
326316 app : FastAPI ,
327317 project_id : ProjectID ,
328318 user_id : UserID ,
329319 mocked_directorv2_rest_api_base : respx .MockRouter ,
330- computation_done : Callable [[], bool ],
331320 log_distributor : LogDistributor ,
332- ) -> AsyncIterable [LogStreamer ]:
333- def _get_computation (request : httpx .Request , ** kwargs ) -> httpx .Response :
334- task = ComputationTaskGet .model_validate (
335- ComputationTaskGet .model_json_schema ()["examples" ][0 ]
336- )
337- if computation_done ():
338- task .state = RunningState .SUCCESS
339- task .stopped = datetime .now ()
340- return httpx .Response (
341- status_code = status .HTTP_200_OK , json = jsonable_encoder (task )
342- )
321+ ) -> Callable [[Callable [..., httpx .Response ]], LogStreamer ]:
343322
344- mocked_directorv2_rest_api_base . get ( f"/v2/computations/ { project_id } " ). mock (
345- side_effect = _get_computation
346- )
323+ def _create_log_streamer_with_distributor (
324+ get_computation : Callable [..., httpx . Response ],
325+ ) -> LogStreamer :
347326
348- assert isinstance (d2_client := DirectorV2Api .get_instance (app ), DirectorV2Api )
349- yield LogStreamer (
350- user_id = user_id ,
351- director2_api = d2_client ,
352- job_id = project_id ,
353- log_distributor = log_distributor ,
354- log_check_timeout = 1 ,
355- )
327+ mocked_directorv2_rest_api_base .get (f"/v2/computations/{ project_id } " ).mock (
328+ side_effect = get_computation
329+ )
356330
357- assert len (log_distributor ._log_streamers .keys ()) == 0
331+ assert isinstance (d2_client := DirectorV2Api .get_instance (app ), DirectorV2Api )
332+ return LogStreamer (
333+ user_id = user_id ,
334+ director2_api = d2_client ,
335+ job_id = project_id ,
336+ log_distributor = log_distributor ,
337+ log_check_timeout = 2 ,
338+ )
339+
340+ return _create_log_streamer_with_distributor
358341
359342
360343async def test_log_streamer_with_distributor (
361344 project_id : ProjectID ,
362345 node_id : NodeID ,
363346 produce_logs : Callable ,
364347 log_distributor : LogDistributor ,
365- log_streamer_with_distributor : LogStreamer ,
348+ create_log_streamer_with_distributor : Callable [
349+ [Callable [..., httpx .Response ]], LogStreamer
350+ ],
366351 faker : Faker ,
367- computation_done : Callable [[], bool ],
368352):
369353 published_logs : list [str ] = []
370354
371355 async def _log_publisher ():
372- while not computation_done ():
356+ start = datetime .now ()
357+ while (datetime .now () - start ) < timedelta (seconds = 1 ):
373358 msg : str = faker .text ()
374359 await produce_logs ("expected" , project_id , node_id , [msg ], logging .DEBUG )
375360 published_logs .append (msg )
361+ await asyncio .sleep (0.1 )
376362
377363 publish_task = asyncio .create_task (_log_publisher ())
378364
365+ def _get_computation (request : httpx .Request , ** kwargs ) -> httpx .Response :
366+ task = ComputationTaskGet .model_validate (
367+ ComputationTaskGet .model_json_schema ()["examples" ][0 ]
368+ )
369+ if publish_task .done ():
370+ task .state = RunningState .SUCCESS
371+ task .stopped = datetime .now ()
372+ else :
373+ task .state = RunningState .STARTED
374+ task .stopped = None
375+ return httpx .Response (
376+ status_code = status .HTTP_200_OK , json = jsonable_encoder (task )
377+ )
378+
379+ log_streamer = create_log_streamer_with_distributor (_get_computation )
380+
379381 @asynccontextmanager
380382 async def registered_log_streamer ():
381- await log_distributor .register (project_id , log_streamer_with_distributor .queue )
383+ await log_distributor .register (project_id , log_streamer .queue )
382384 try :
383385 yield
384386 finally :
385387 await log_distributor .deregister (project_id )
386388
387389 collected_messages : list [str ] = []
388390 async with registered_log_streamer ():
389- async for log in log_streamer_with_distributor .log_generator ():
391+ async for log in log_streamer .log_generator ():
390392 job_log : JobLog = JobLog .model_validate_json (log )
391393 assert len (job_log .messages ) == 1
392394 assert job_log .job_id == project_id
393395 collected_messages .append (job_log .messages [0 ])
394396
395- if not publish_task .done ():
396- publish_task .cancel ()
397- try :
398- await publish_task
399- except asyncio .CancelledError :
400- pass
401-
397+ assert publish_task .done ()
402398 assert len (published_logs ) > 0
403399 assert published_logs == collected_messages
404400
@@ -408,9 +404,9 @@ async def test_log_streamer_not_raise_with_distributor(
408404 project_id : ProjectID ,
409405 node_id : NodeID ,
410406 produce_logs : Callable ,
411- log_streamer_with_distributor : LogStreamer ,
412- faker : Faker ,
413- computation_done : Callable [[], bool ],
407+ create_log_streamer_with_distributor : Callable [
408+ [ Callable [..., httpx . Response ]], LogStreamer
409+ ],
414410):
415411 class InvalidLoggerRabbitMessage (LoggerRabbitMessage ):
416412 channel_name : Literal ["simcore.services.logs.v2" ] = "simcore.services.logs.v2"
@@ -433,8 +429,19 @@ def routing_key(self) -> str:
433429
434430 await produce_logs ("expected" , log_message = log_rabbit_message )
435431
432+ def _get_computation (request : httpx .Request , ** kwargs ) -> httpx .Response :
433+ task = ComputationTaskGet .model_validate (
434+ ComputationTaskGet .model_json_schema ()["examples" ][0 ]
435+ )
436+ task .state = RunningState .SUCCESS
437+ task .stopped = datetime .now ()
438+ return httpx .Response (
439+ status_code = status .HTTP_200_OK , json = jsonable_encoder (task )
440+ )
441+
442+ log_streamer = create_log_streamer_with_distributor (_get_computation )
436443 ii : int = 0
437- async for log in log_streamer_with_distributor .log_generator ():
444+ async for log in log_streamer .log_generator ():
438445 _ = JobLog .model_validate_json (log )
439446 ii += 1
440447 assert ii == 0
0 commit comments