diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index 376bce9..77d3952 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -194,7 +194,11 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - self.done() + # If we are here because of an exception, we want to mark the experiment as done with error. + if exc_type is not None: + self.done_with_err() + else: + self.done() self._end_status = None if self._signal_task: diff --git a/tests/unit/experiment/test_experiment.py b/tests/unit/experiment/test_experiment.py index 7544600..8327575 100644 --- a/tests/unit/experiment/test_experiment.py +++ b/tests/unit/experiment/test_experiment.py @@ -160,6 +160,36 @@ async def test_experiment_with_done_with_err(): assert global_runtime().metadb.get_run(run_id=run_id).status == Status.CANCELLED +@pytest.mark.asyncio +async def test_experiment_exception_handling(): + """Test that exceptions in experiment context automatically mark it as failed.""" + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + org_id=uuid.uuid4(), + ) + + exp_id = None + run_id = None + + with pytest.raises(ValueError, match="Simulated error"): + async with CraftExperiment.start(name="failing-experiment") as exp: + exp_id = exp.id + run = exp.run(lambda: asyncio.sleep(2)) + run_id = run.id + raise ValueError("Simulated error") + + # Verify experiment was marked as FAILED due to exception + exp_obj = global_runtime().metadb.get_experiment(experiment_id=exp_id) + assert exp_obj is not None + assert exp_obj.status == Status.FAILED + assert exp_obj.duration is not None + + # Verify run was cancelled + run_obj = global_runtime().metadb.get_run(run_id=run_id) + assert run_obj.status == Status.CANCELLED + + @pytest.mark.asyncio async def test_experiment_with_resume(): init(