Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self.done()
# If we are here because of an exception, we want to mark the experiment as done with error.
if exc_type is not None:
self.done_with_err()
else:
self.done()
self._end_status = None

if self._signal_task:
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/experiment/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,36 @@ async def test_experiment_with_done_with_err():
assert global_runtime().metadb.get_run(run_id=run_id).status == Status.CANCELLED


@pytest.mark.asyncio
async def test_experiment_exception_handling():
"""Test that exceptions in experiment context automatically mark it as failed."""
init(
team_id=uuid.uuid4(),
user_id=uuid.uuid4(),
org_id=uuid.uuid4(),
)

exp_id = None
run_id = None

with pytest.raises(ValueError, match="Simulated error"):
async with CraftExperiment.start(name="failing-experiment") as exp:
exp_id = exp.id
run = exp.run(lambda: asyncio.sleep(2))
run_id = run.id
raise ValueError("Simulated error")

# Verify experiment was marked as FAILED due to exception
exp_obj = global_runtime().metadb.get_experiment(experiment_id=exp_id)
assert exp_obj is not None
assert exp_obj.status == Status.FAILED
assert exp_obj.duration is not None

# Verify run was cancelled
run_obj = global_runtime().metadb.get_run(run_id=run_id)
assert run_obj.status == Status.CANCELLED


@pytest.mark.asyncio
async def test_experiment_with_resume():
init(
Expand Down
Loading