From a4f7f9683d01898851631eddacfd5a98f6dce1a2 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 8 Jun 2026 22:01:05 +0800 Subject: [PATCH 1/2] mark experiment cancelled or aborted in deleteExperiments api Signed-off-by: kerthcet --- alphatrion/storage/sqlstore.py | 22 +- .../server/test_graphql_mutation.py | 129 +++++--- tests/integration/test_tracking.py | 12 +- tests/unit/storage/test_sql.py | 277 ++++++++++++++++++ 4 files changed, 393 insertions(+), 47 deletions(-) diff --git a/alphatrion/storage/sqlstore.py b/alphatrion/storage/sqlstore.py index 890b919..4c1e890 100644 --- a/alphatrion/storage/sqlstore.py +++ b/alphatrion/storage/sqlstore.py @@ -779,25 +779,41 @@ def delete_experiments(self, experiment_ids: list[uuid.UUID]) -> int: Also deletes all associated runs. Returns the number of experiments successfully deleted. Caller must ensure user has permission to delete these experiments. + + Status transitions on deletion: + - PENDING experiments -> ABORTED + - RUNNING experiments -> CANCELLED + - Other statuses remain unchanged """ + from sqlalchemy import case + session = self._session() # Delete the experiments - # if experiment is running, skip deletion for that experiment filtered_exps = ( session.query(Experiment.uuid) .filter( Experiment.uuid.in_(experiment_ids), Experiment.is_del == 0, - Experiment.status != Status.RUNNING, ) .all() ) filtered_exp_ids = [exp_id for (exp_id,) in filtered_exps] # unpack tuples + # Update status based on current state: PENDING -> ABORTED, RUNNING -> CANCELLED deleted_count = ( session.query(Experiment) .filter(Experiment.uuid.in_(filtered_exp_ids)) - .update({Experiment.is_del: 1}, synchronize_session=False) + .update( + { + Experiment.is_del: 1, + Experiment.status: case( + (Experiment.status == Status.PENDING, Status.ABORTED), + (Experiment.status == Status.RUNNING, Status.CANCELLED), + else_=Experiment.status, + ), + }, + synchronize_session=False, + ) ) # Delete all runs associated with these experiments session.query(Run).filter(Run.experiment_id.in_(filtered_exp_ids)).update( diff --git a/tests/integration/server/test_graphql_mutation.py b/tests/integration/server/test_graphql_mutation.py index 86f1188..f86abf7 100644 --- a/tests/integration/server/test_graphql_mutation.py +++ b/tests/integration/server/test_graphql_mutation.py @@ -946,10 +946,10 @@ def test_delete_running_experiment_fails( assert exp.status == Status.RUNNING -def test_delete_experiments_skips_running( +def test_delete_experiments_with_running_and_pending( execute_graphql, test_org_id, test_user_id, test_team_id ): - """Test that batch delete skips running experiments""" + """Test that batch delete handles running (CANCELLED) and pending (ABORTED) experiments""" runtime.init() metadb = runtime.storage_runtime().metadb @@ -961,8 +961,6 @@ def test_delete_experiments_skips_running( name="Completed Experiment", ) metadb.update_experiment( - org_id=test_org_id, - team_id=test_team_id, experiment_id=exp_id_1, status=Status.COMPLETED, ) @@ -974,8 +972,6 @@ def test_delete_experiments_skips_running( name="Running Experiment", ) metadb.update_experiment( - org_id=test_org_id, - team_id=test_team_id, experiment_id=exp_id_2, status=Status.RUNNING, ) @@ -984,12 +980,21 @@ def test_delete_experiments_skips_running( org_id=test_org_id, team_id=test_team_id, user_id=test_user_id, - name="Failed Experiment", + name="Pending Experiment", ) metadb.update_experiment( + experiment_id=exp_id_3, + status=Status.PENDING, + ) + + exp_id_4 = metadb.create_experiment( org_id=test_org_id, team_id=test_team_id, - experiment_id=exp_id_3, + user_id=test_user_id, + name="Failed Experiment", + ) + metadb.update_experiment( + experiment_id=exp_id_4, status=Status.FAILED, ) @@ -1012,16 +1017,23 @@ def test_delete_experiments_skips_running( user_id=test_user_id, experiment_id=exp_id_3, ) + run_id_4 = metadb.create_run( + org_id=test_org_id, + team_id=test_team_id, + user_id=test_user_id, + experiment_id=exp_id_4, + ) # Verify all experiments exist assert metadb.get_experiment(experiment_id=exp_id_1) is not None assert metadb.get_experiment(experiment_id=exp_id_2) is not None assert metadb.get_experiment(experiment_id=exp_id_3) is not None + assert metadb.get_experiment(experiment_id=exp_id_4) is not None - # Try to batch delete all experiments + # Batch delete all experiments mutation = f""" mutation {{ - deleteExperiments(experimentIds: ["{exp_id_1}", "{exp_id_2}", "{exp_id_3}"]) + deleteExperiments(experimentIds: ["{exp_id_1}", "{exp_id_2}", "{exp_id_3}", "{exp_id_4}"]) }} """ response = execute_graphql( @@ -1030,35 +1042,50 @@ def test_delete_experiments_skips_running( user_id=test_user_id, ) assert response.errors is None - # Should only delete 2 experiments (skipped the running one) - assert response.data["deleteExperiments"] == 2 + # Should delete all 4 experiments + assert response.data["deleteExperiments"] == 4 - # Verify running experiment still exists - exp_2 = metadb.get_experiment(experiment_id=exp_id_2) - assert exp_2 is not None - assert exp_2.status == Status.RUNNING - - # Verify non-running experiments are deleted + # Verify all experiments are deleted (get_experiment filters out deleted ones) assert metadb.get_experiment(experiment_id=exp_id_1) is None + assert metadb.get_experiment(experiment_id=exp_id_2) is None assert metadb.get_experiment(experiment_id=exp_id_3) is None + assert metadb.get_experiment(experiment_id=exp_id_4) is None + + # Verify status changes by querying directly from database + from alphatrion.storage.sql_models import Experiment + + session = metadb._session() + exp_1 = session.query(Experiment).filter(Experiment.uuid == exp_id_1).first() + exp_2 = session.query(Experiment).filter(Experiment.uuid == exp_id_2).first() + exp_3 = session.query(Experiment).filter(Experiment.uuid == exp_id_3).first() + exp_4 = session.query(Experiment).filter(Experiment.uuid == exp_id_4).first() - # Verify runs of deleted experiments are also deleted + # COMPLETED stays COMPLETED + assert exp_1.status == Status.COMPLETED + # RUNNING becomes CANCELLED + assert exp_2.status == Status.CANCELLED + # PENDING becomes ABORTED + assert exp_3.status == Status.ABORTED + # FAILED stays FAILED + assert exp_4.status == Status.FAILED + + session.close() + + # Verify all runs are deleted assert metadb.get_run(run_id=run_id_1) is None + assert metadb.get_run(run_id=run_id_2) is None assert metadb.get_run(run_id=run_id_3) is None - - # Verify run of running experiment still exists - run_2 = metadb.get_run(run_id=run_id_2) - assert run_2 is not None + assert metadb.get_run(run_id=run_id_4) is None -def test_delete_experiments_all_running( +def test_delete_experiments_all_running_and_pending( execute_graphql, test_org_id, test_user_id, test_team_id ): - """Test that batch delete returns 0 when all experiments are running""" + """Test that batch delete handles all running and pending experiments correctly""" runtime.init() metadb = runtime.storage_runtime().metadb - # Create multiple running experiments + # Create running and pending experiments exp_id_1 = metadb.create_experiment( org_id=test_org_id, team_id=test_team_id, @@ -1066,8 +1093,6 @@ def test_delete_experiments_all_running( name="Running Experiment 1", ) metadb.update_experiment( - org_id=test_org_id, - team_id=test_team_id, experiment_id=exp_id_1, status=Status.RUNNING, ) @@ -1079,16 +1104,25 @@ def test_delete_experiments_all_running( name="Running Experiment 2", ) metadb.update_experiment( - org_id=test_org_id, - team_id=test_team_id, experiment_id=exp_id_2, status=Status.RUNNING, ) - # Try to batch delete all running experiments + exp_id_3 = metadb.create_experiment( + org_id=test_org_id, + team_id=test_team_id, + user_id=test_user_id, + name="Pending Experiment 1", + ) + metadb.update_experiment( + experiment_id=exp_id_3, + status=Status.PENDING, + ) + + # Batch delete all experiments mutation = f""" mutation {{ - deleteExperiments(experimentIds: ["{exp_id_1}", "{exp_id_2}"]) + deleteExperiments(experimentIds: ["{exp_id_1}", "{exp_id_2}", "{exp_id_3}"]) }} """ response = execute_graphql( @@ -1097,16 +1131,29 @@ def test_delete_experiments_all_running( user_id=test_user_id, ) assert response.errors is None - # Should delete 0 experiments (all are running) - assert response.data["deleteExperiments"] == 0 + # Should delete all 3 experiments + assert response.data["deleteExperiments"] == 3 + + # Verify all experiments are deleted + assert metadb.get_experiment(experiment_id=exp_id_1) is None + assert metadb.get_experiment(experiment_id=exp_id_2) is None + assert metadb.get_experiment(experiment_id=exp_id_3) is None + + # Verify status changes by querying directly from database + from alphatrion.storage.sql_models import Experiment + + session = metadb._session() + exp_1 = session.query(Experiment).filter(Experiment.uuid == exp_id_1).first() + exp_2 = session.query(Experiment).filter(Experiment.uuid == exp_id_2).first() + exp_3 = session.query(Experiment).filter(Experiment.uuid == exp_id_3).first() + + # RUNNING experiments become CANCELLED + assert exp_1.status == Status.CANCELLED + assert exp_2.status == Status.CANCELLED + # PENDING experiment becomes ABORTED + assert exp_3.status == Status.ABORTED - # Verify all experiments still exist - exp_1 = metadb.get_experiment(experiment_id=exp_id_1) - exp_2 = metadb.get_experiment(experiment_id=exp_id_2) - assert exp_1 is not None - assert exp_2 is not None - assert exp_1.status == Status.RUNNING - assert exp_2.status == Status.RUNNING + session.close() def test_create_experiment_mutation( diff --git a/tests/integration/test_tracking.py b/tests/integration/test_tracking.py index 551496a..ff3a6c9 100644 --- a/tests/integration/test_tracking.py +++ b/tests/integration/test_tracking.py @@ -699,7 +699,9 @@ async def versioned_workflow(): agent_spans = [s for s in spans if s[1] == "agent"] tool_spans = [s for s in spans if s[1] == "tool"] - assert len(agent_spans) >= 1, f"Expected at least 1 agent span, got {len(agent_spans)}" + assert len(agent_spans) >= 1, ( + f"Expected at least 1 agent span, got {len(agent_spans)}" + ) assert len(tool_spans) >= 1, f"Expected at least 1 tool span, got {len(tool_spans)}" # Check that function names appear in span names @@ -717,8 +719,12 @@ async def versioned_workflow(): agent_version = agent_spans[0][2] tool_version = tool_spans[0][2] - assert agent_version == "3", f"Agent version not tracked, expected '3', got '{agent_version}'" - assert tool_version == "2", f"Tool version not tracked, expected '2', got '{tool_version}'" + assert agent_version == "3", ( + f"Agent version not tracked, expected '3', got '{agent_version}'" + ) + assert tool_version == "2", ( + f"Tool version not tracked, expected '2', got '{tool_version}'" + ) @pytest.mark.asyncio diff --git a/tests/unit/storage/test_sql.py b/tests/unit/storage/test_sql.py index 805029f..bce2894 100644 --- a/tests/unit/storage/test_sql.py +++ b/tests/unit/storage/test_sql.py @@ -240,3 +240,280 @@ def test_user_and_team_in_same_org_wrong_target_org(db): # Verify returns False when checking against wrong target org assert db.user_and_team_in_same_org(user_id, team_id, wrong_org_id) is False + + +def test_delete_experiments_basic(db): + """Test basic batch deletion of experiments""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create multiple experiments + exp_id1 = db.create_experiment( + org_id=org_id, team_id=team_id, user_id=user_id, name="exp1" + ) + exp_id2 = db.create_experiment( + org_id=org_id, team_id=team_id, user_id=user_id, name="exp2" + ) + exp_id3 = db.create_experiment( + org_id=org_id, team_id=team_id, user_id=user_id, name="exp3" + ) + + # Create some runs for these experiments + run_id1 = db.create_run( + org_id=org_id, team_id=team_id, user_id=user_id, experiment_id=exp_id1 + ) + run_id2 = db.create_run( + org_id=org_id, team_id=team_id, user_id=user_id, experiment_id=exp_id2 + ) + + # Delete two experiments + deleted_count = db.delete_experiments([exp_id1, exp_id2]) + assert deleted_count == 2 + + # Verify experiments are marked as deleted + exp1 = db.get_experiment(exp_id1) + exp2 = db.get_experiment(exp_id2) + exp3 = db.get_experiment(exp_id3) + assert exp1 is None + assert exp2 is None + assert exp3 is not None + + # Verify runs are also marked as deleted + run1 = db.get_run(run_id1) + run2 = db.get_run(run_id2) + assert run1 is None + assert run2 is None + + +def test_delete_experiments_pending_status(db): + """Test that pending experiments are marked as ABORTED when deleted""" + + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create pending experiments + exp_id1 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="pending_exp1", + status=Status.PENDING, + ) + exp_id2 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="pending_exp2", + status=Status.PENDING, + ) + + # Delete the pending experiments + deleted_count = db.delete_experiments([exp_id1, exp_id2]) + assert deleted_count == 2 + + # Verify experiments are deleted but check status in the database directly + # (since get_experiment filters out deleted experiments) + session = db._session() + from alphatrion.storage.sql_models import Experiment + + exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() + exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() + + assert exp1 is not None + assert exp1.is_del == 1 + assert exp1.status == Status.ABORTED + + assert exp2 is not None + assert exp2.is_del == 1 + assert exp2.status == Status.ABORTED + + session.close() + + +def test_delete_experiments_running_status(db): + """Test that running experiments are marked as CANCELLED when deleted""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create running experiments + exp_id1 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="running_exp1", + status=Status.RUNNING, + ) + exp_id2 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="running_exp2", + status=Status.RUNNING, + ) + + # Delete the running experiments + deleted_count = db.delete_experiments([exp_id1, exp_id2]) + assert deleted_count == 2 + + # Verify experiments are deleted and status is CANCELLED + session = db._session() + from alphatrion.storage.sql_models import Experiment + + exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() + exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() + + assert exp1 is not None + assert exp1.is_del == 1 + assert exp1.status == Status.CANCELLED + + assert exp2 is not None + assert exp2.is_del == 1 + assert exp2.status == Status.CANCELLED + + session.close() + + +def test_delete_experiments_mixed_statuses(db): + """Test deleting experiments with various statuses""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create experiments with different statuses + exp_id1 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="pending_exp", + status=Status.PENDING, + ) + exp_id2 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="running_exp", + status=Status.RUNNING, + ) + exp_id3 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="completed_exp", + status=Status.COMPLETED, + ) + exp_id4 = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="failed_exp", + status=Status.FAILED, + ) + + # Create runs for these experiments + run_id1 = db.create_run( + org_id=org_id, team_id=team_id, user_id=user_id, experiment_id=exp_id1 + ) + run_id2 = db.create_run( + org_id=org_id, team_id=team_id, user_id=user_id, experiment_id=exp_id2 + ) + run_id3 = db.create_run( + org_id=org_id, team_id=team_id, user_id=user_id, experiment_id=exp_id3 + ) + + # Delete all experiments + deleted_count = db.delete_experiments([exp_id1, exp_id2, exp_id3, exp_id4]) + assert deleted_count == 4 + + # Verify experiments have correct status after deletion + session = db._session() + from alphatrion.storage.sql_models import Experiment + + exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() + exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() + exp3 = session.query(Experiment).filter(Experiment.uuid == exp_id3).first() + exp4 = session.query(Experiment).filter(Experiment.uuid == exp_id4).first() + + # PENDING -> ABORTED + assert exp1.is_del == 1 + assert exp1.status == Status.ABORTED + + # RUNNING -> CANCELLED + assert exp2.is_del == 1 + assert exp2.status == Status.CANCELLED + + # COMPLETED stays COMPLETED + assert exp3.is_del == 1 + assert exp3.status == Status.COMPLETED + + # FAILED stays FAILED + assert exp4.is_del == 1 + assert exp4.status == Status.FAILED + + session.close() + + # Verify all runs are marked as deleted + run1 = db.get_run(run_id1) + run2 = db.get_run(run_id2) + run3 = db.get_run(run_id3) + assert run1 is None + assert run2 is None + assert run3 is None + + +def test_delete_experiments_empty_list(db): + """Test deleting with empty experiment list""" + deleted_count = db.delete_experiments([]) + assert deleted_count == 0 + + +def test_delete_experiments_nonexistent_ids(db): + """Test deleting nonexistent experiments""" + nonexistent_id1 = uuid.uuid4() + nonexistent_id2 = uuid.uuid4() + + deleted_count = db.delete_experiments([nonexistent_id1, nonexistent_id2]) + assert deleted_count == 0 + + +def test_delete_experiments_mixed_existing_and_nonexistent(db): + """Test deleting mix of existing and nonexistent experiments""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create one experiment + exp_id1 = db.create_experiment( + org_id=org_id, team_id=team_id, user_id=user_id, name="exp1" + ) + nonexistent_id = uuid.uuid4() + + # Delete one existing and one nonexistent + deleted_count = db.delete_experiments([exp_id1, nonexistent_id]) + assert deleted_count == 1 + + # Verify the existing one is deleted + exp1 = db.get_experiment(exp_id1) + assert exp1 is None + + +def test_delete_experiments_already_deleted(db): + """Test deleting experiments that are already deleted""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create experiment + exp_id = db.create_experiment( + org_id=org_id, team_id=team_id, user_id=user_id, name="exp1" + ) + + # Delete once + deleted_count = db.delete_experiments([exp_id]) + assert deleted_count == 1 + + # Try to delete again + deleted_count = db.delete_experiments([exp_id]) + assert deleted_count == 0 From 15416ab454172a92305dca3ada6a4419c9f0a05a Mon Sep 17 00:00:00 2001 From: kerthcet Date: Tue, 9 Jun 2026 00:35:55 +0800 Subject: [PATCH 2/2] fix session leakage problem Signed-off-by: kerthcet --- .../server/test_graphql_mutation.py | 50 +++++----- tests/unit/storage/test_sql.py | 95 +++++++++---------- 2 files changed, 67 insertions(+), 78 deletions(-) diff --git a/tests/integration/server/test_graphql_mutation.py b/tests/integration/server/test_graphql_mutation.py index f86abf7..34cdfc2 100644 --- a/tests/integration/server/test_graphql_mutation.py +++ b/tests/integration/server/test_graphql_mutation.py @@ -1054,22 +1054,20 @@ def test_delete_experiments_with_running_and_pending( # Verify status changes by querying directly from database from alphatrion.storage.sql_models import Experiment - session = metadb._session() - exp_1 = session.query(Experiment).filter(Experiment.uuid == exp_id_1).first() - exp_2 = session.query(Experiment).filter(Experiment.uuid == exp_id_2).first() - exp_3 = session.query(Experiment).filter(Experiment.uuid == exp_id_3).first() - exp_4 = session.query(Experiment).filter(Experiment.uuid == exp_id_4).first() - - # COMPLETED stays COMPLETED - assert exp_1.status == Status.COMPLETED - # RUNNING becomes CANCELLED - assert exp_2.status == Status.CANCELLED - # PENDING becomes ABORTED - assert exp_3.status == Status.ABORTED - # FAILED stays FAILED - assert exp_4.status == Status.FAILED - - session.close() + with metadb._session() as session: + exp_1 = session.query(Experiment).filter(Experiment.uuid == exp_id_1).first() + exp_2 = session.query(Experiment).filter(Experiment.uuid == exp_id_2).first() + exp_3 = session.query(Experiment).filter(Experiment.uuid == exp_id_3).first() + exp_4 = session.query(Experiment).filter(Experiment.uuid == exp_id_4).first() + + # COMPLETED stays COMPLETED + assert exp_1.status == Status.COMPLETED + # RUNNING becomes CANCELLED + assert exp_2.status == Status.CANCELLED + # PENDING becomes ABORTED + assert exp_3.status == Status.ABORTED + # FAILED stays FAILED + assert exp_4.status == Status.FAILED # Verify all runs are deleted assert metadb.get_run(run_id=run_id_1) is None @@ -1142,18 +1140,16 @@ def test_delete_experiments_all_running_and_pending( # Verify status changes by querying directly from database from alphatrion.storage.sql_models import Experiment - session = metadb._session() - exp_1 = session.query(Experiment).filter(Experiment.uuid == exp_id_1).first() - exp_2 = session.query(Experiment).filter(Experiment.uuid == exp_id_2).first() - exp_3 = session.query(Experiment).filter(Experiment.uuid == exp_id_3).first() + with metadb._session() as session: + exp_1 = session.query(Experiment).filter(Experiment.uuid == exp_id_1).first() + exp_2 = session.query(Experiment).filter(Experiment.uuid == exp_id_2).first() + exp_3 = session.query(Experiment).filter(Experiment.uuid == exp_id_3).first() - # RUNNING experiments become CANCELLED - assert exp_1.status == Status.CANCELLED - assert exp_2.status == Status.CANCELLED - # PENDING experiment becomes ABORTED - assert exp_3.status == Status.ABORTED - - session.close() + # RUNNING experiments become CANCELLED + assert exp_1.status == Status.CANCELLED + assert exp_2.status == Status.CANCELLED + # PENDING experiment becomes ABORTED + assert exp_3.status == Status.ABORTED def test_create_experiment_mutation( diff --git a/tests/unit/storage/test_sql.py b/tests/unit/storage/test_sql.py index bce2894..a40fcc2 100644 --- a/tests/unit/storage/test_sql.py +++ b/tests/unit/storage/test_sql.py @@ -288,6 +288,7 @@ def test_delete_experiments_basic(db): def test_delete_experiments_pending_status(db): """Test that pending experiments are marked as ABORTED when deleted""" + from alphatrion.storage.sql_models import Experiment org_id = uuid.uuid4() team_id = uuid.uuid4() @@ -315,25 +316,23 @@ def test_delete_experiments_pending_status(db): # Verify experiments are deleted but check status in the database directly # (since get_experiment filters out deleted experiments) - session = db._session() - from alphatrion.storage.sql_models import Experiment - - exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() - exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() + with db._session() as session: + exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() + exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() - assert exp1 is not None - assert exp1.is_del == 1 - assert exp1.status == Status.ABORTED + assert exp1 is not None + assert exp1.is_del == 1 + assert exp1.status == Status.ABORTED - assert exp2 is not None - assert exp2.is_del == 1 - assert exp2.status == Status.ABORTED - - session.close() + assert exp2 is not None + assert exp2.is_del == 1 + assert exp2.status == Status.ABORTED def test_delete_experiments_running_status(db): """Test that running experiments are marked as CANCELLED when deleted""" + from alphatrion.storage.sql_models import Experiment + org_id = uuid.uuid4() team_id = uuid.uuid4() user_id = uuid.uuid4() @@ -359,25 +358,23 @@ def test_delete_experiments_running_status(db): assert deleted_count == 2 # Verify experiments are deleted and status is CANCELLED - session = db._session() - from alphatrion.storage.sql_models import Experiment - - exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() - exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() - - assert exp1 is not None - assert exp1.is_del == 1 - assert exp1.status == Status.CANCELLED + with db._session() as session: + exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() + exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() - assert exp2 is not None - assert exp2.is_del == 1 - assert exp2.status == Status.CANCELLED + assert exp1 is not None + assert exp1.is_del == 1 + assert exp1.status == Status.CANCELLED - session.close() + assert exp2 is not None + assert exp2.is_del == 1 + assert exp2.status == Status.CANCELLED def test_delete_experiments_mixed_statuses(db): """Test deleting experiments with various statuses""" + from alphatrion.storage.sql_models import Experiment + org_id = uuid.uuid4() team_id = uuid.uuid4() user_id = uuid.uuid4() @@ -428,31 +425,27 @@ def test_delete_experiments_mixed_statuses(db): assert deleted_count == 4 # Verify experiments have correct status after deletion - session = db._session() - from alphatrion.storage.sql_models import Experiment - - exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() - exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() - exp3 = session.query(Experiment).filter(Experiment.uuid == exp_id3).first() - exp4 = session.query(Experiment).filter(Experiment.uuid == exp_id4).first() - - # PENDING -> ABORTED - assert exp1.is_del == 1 - assert exp1.status == Status.ABORTED - - # RUNNING -> CANCELLED - assert exp2.is_del == 1 - assert exp2.status == Status.CANCELLED - - # COMPLETED stays COMPLETED - assert exp3.is_del == 1 - assert exp3.status == Status.COMPLETED - - # FAILED stays FAILED - assert exp4.is_del == 1 - assert exp4.status == Status.FAILED - - session.close() + with db._session() as session: + exp1 = session.query(Experiment).filter(Experiment.uuid == exp_id1).first() + exp2 = session.query(Experiment).filter(Experiment.uuid == exp_id2).first() + exp3 = session.query(Experiment).filter(Experiment.uuid == exp_id3).first() + exp4 = session.query(Experiment).filter(Experiment.uuid == exp_id4).first() + + # PENDING -> ABORTED + assert exp1.is_del == 1 + assert exp1.status == Status.ABORTED + + # RUNNING -> CANCELLED + assert exp2.is_del == 1 + assert exp2.status == Status.CANCELLED + + # COMPLETED stays COMPLETED + assert exp3.is_del == 1 + assert exp3.status == Status.COMPLETED + + # FAILED stays FAILED + assert exp4.is_del == 1 + assert exp4.status == Status.FAILED # Verify all runs are marked as deleted run1 = db.get_run(run_id1)