Skip to content

Commit 92839f2

Browse files
Fix flaky log streaming test (#8601)
1 parent d30151b commit 92839f2

File tree

1 file changed

+59
-52
lines changed

1 file changed

+59
-52
lines changed

services/api-server/tests/unit/test_services_rabbitmq.py

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import asyncio
99
import logging
1010
import random
11-
from collections.abc import AsyncIterable, Callable, Iterable
11+
from collections.abc import AsyncIterable, Callable
1212
from contextlib import asynccontextmanager
1313
from datetime import datetime, timedelta
1414
from typing import Final, Literal, cast
@@ -311,94 +311,90 @@ async def _log_publisher():
311311

312312

313313
@pytest.fixture
314-
def computation_done() -> Iterable[Callable[[], bool]]:
315-
stop_time: Final[datetime] = datetime.now() + timedelta(seconds=2)
316-
317-
def _job_done() -> bool:
318-
return datetime.now() >= stop_time
319-
320-
return _job_done
321-
322-
323-
@pytest.fixture
324-
async def log_streamer_with_distributor(
314+
async def create_log_streamer_with_distributor(
325315
client: httpx.AsyncClient,
326316
app: FastAPI,
327317
project_id: ProjectID,
328318
user_id: UserID,
329319
mocked_directorv2_rest_api_base: respx.MockRouter,
330-
computation_done: Callable[[], bool],
331320
log_distributor: LogDistributor,
332-
) -> AsyncIterable[LogStreamer]:
333-
def _get_computation(request: httpx.Request, **kwargs) -> httpx.Response:
334-
task = ComputationTaskGet.model_validate(
335-
ComputationTaskGet.model_json_schema()["examples"][0]
336-
)
337-
if computation_done():
338-
task.state = RunningState.SUCCESS
339-
task.stopped = datetime.now()
340-
return httpx.Response(
341-
status_code=status.HTTP_200_OK, json=jsonable_encoder(task)
342-
)
321+
) -> Callable[[Callable[..., httpx.Response]], LogStreamer]:
343322

344-
mocked_directorv2_rest_api_base.get(f"/v2/computations/{project_id}").mock(
345-
side_effect=_get_computation
346-
)
323+
def _create_log_streamer_with_distributor(
324+
get_computation: Callable[..., httpx.Response],
325+
) -> LogStreamer:
347326

348-
assert isinstance(d2_client := DirectorV2Api.get_instance(app), DirectorV2Api)
349-
yield LogStreamer(
350-
user_id=user_id,
351-
director2_api=d2_client,
352-
job_id=project_id,
353-
log_distributor=log_distributor,
354-
log_check_timeout=1,
355-
)
327+
mocked_directorv2_rest_api_base.get(f"/v2/computations/{project_id}").mock(
328+
side_effect=get_computation
329+
)
356330

357-
assert len(log_distributor._log_streamers.keys()) == 0
331+
assert isinstance(d2_client := DirectorV2Api.get_instance(app), DirectorV2Api)
332+
return LogStreamer(
333+
user_id=user_id,
334+
director2_api=d2_client,
335+
job_id=project_id,
336+
log_distributor=log_distributor,
337+
log_check_timeout=2,
338+
)
339+
340+
return _create_log_streamer_with_distributor
358341

359342

360343
async def test_log_streamer_with_distributor(
361344
project_id: ProjectID,
362345
node_id: NodeID,
363346
produce_logs: Callable,
364347
log_distributor: LogDistributor,
365-
log_streamer_with_distributor: LogStreamer,
348+
create_log_streamer_with_distributor: Callable[
349+
[Callable[..., httpx.Response]], LogStreamer
350+
],
366351
faker: Faker,
367-
computation_done: Callable[[], bool],
368352
):
369353
published_logs: list[str] = []
370354

371355
async def _log_publisher():
372-
while not computation_done():
356+
start = datetime.now()
357+
while (datetime.now() - start) < timedelta(seconds=1):
373358
msg: str = faker.text()
374359
await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG)
375360
published_logs.append(msg)
361+
await asyncio.sleep(0.1)
376362

377363
publish_task = asyncio.create_task(_log_publisher())
378364

365+
def _get_computation(request: httpx.Request, **kwargs) -> httpx.Response:
366+
task = ComputationTaskGet.model_validate(
367+
ComputationTaskGet.model_json_schema()["examples"][0]
368+
)
369+
if publish_task.done():
370+
task.state = RunningState.SUCCESS
371+
task.stopped = datetime.now()
372+
else:
373+
task.state = RunningState.STARTED
374+
task.stopped = None
375+
return httpx.Response(
376+
status_code=status.HTTP_200_OK, json=jsonable_encoder(task)
377+
)
378+
379+
log_streamer = create_log_streamer_with_distributor(_get_computation)
380+
379381
@asynccontextmanager
380382
async def registered_log_streamer():
381-
await log_distributor.register(project_id, log_streamer_with_distributor.queue)
383+
await log_distributor.register(project_id, log_streamer.queue)
382384
try:
383385
yield
384386
finally:
385387
await log_distributor.deregister(project_id)
386388

387389
collected_messages: list[str] = []
388390
async with registered_log_streamer():
389-
async for log in log_streamer_with_distributor.log_generator():
391+
async for log in log_streamer.log_generator():
390392
job_log: JobLog = JobLog.model_validate_json(log)
391393
assert len(job_log.messages) == 1
392394
assert job_log.job_id == project_id
393395
collected_messages.append(job_log.messages[0])
394396

395-
if not publish_task.done():
396-
publish_task.cancel()
397-
try:
398-
await publish_task
399-
except asyncio.CancelledError:
400-
pass
401-
397+
assert publish_task.done()
402398
assert len(published_logs) > 0
403399
assert published_logs == collected_messages
404400

@@ -408,9 +404,9 @@ async def test_log_streamer_not_raise_with_distributor(
408404
project_id: ProjectID,
409405
node_id: NodeID,
410406
produce_logs: Callable,
411-
log_streamer_with_distributor: LogStreamer,
412-
faker: Faker,
413-
computation_done: Callable[[], bool],
407+
create_log_streamer_with_distributor: Callable[
408+
[Callable[..., httpx.Response]], LogStreamer
409+
],
414410
):
415411
class InvalidLoggerRabbitMessage(LoggerRabbitMessage):
416412
channel_name: Literal["simcore.services.logs.v2"] = "simcore.services.logs.v2"
@@ -433,8 +429,19 @@ def routing_key(self) -> str:
433429

434430
await produce_logs("expected", log_message=log_rabbit_message)
435431

432+
def _get_computation(request: httpx.Request, **kwargs) -> httpx.Response:
433+
task = ComputationTaskGet.model_validate(
434+
ComputationTaskGet.model_json_schema()["examples"][0]
435+
)
436+
task.state = RunningState.SUCCESS
437+
task.stopped = datetime.now()
438+
return httpx.Response(
439+
status_code=status.HTTP_200_OK, json=jsonable_encoder(task)
440+
)
441+
442+
log_streamer = create_log_streamer_with_distributor(_get_computation)
436443
ii: int = 0
437-
async for log in log_streamer_with_distributor.log_generator():
444+
async for log in log_streamer.log_generator():
438445
_ = JobLog.model_validate_json(log)
439446
ii += 1
440447
assert ii == 0

0 commit comments

Comments
 (0)