diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index cf572107a30..fea3702626b 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -21,6 +21,7 @@ from firebolt.common._types import ColType, ParameterType, SetParameter from firebolt.common.constants import ( JSON_OUTPUT_FORMAT, + REMOVE_PARAMETERS_HEADER, RESET_SESSION_HEADER, UPDATE_ENDPOINT_HEADER, UPDATE_PARAMETERS_HEADER, @@ -28,6 +29,7 @@ ) from firebolt.common.cursor.base_cursor import ( BaseCursor, + _parse_remove_parameters, _parse_update_endpoint, _parse_update_parameters, _raise_if_internal_set_parameter, @@ -194,6 +196,10 @@ async def _parse_response_headers(self, headers: Headers) -> None: param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER)) self._update_set_parameters(param_dict) + if headers.get(REMOVE_PARAMETERS_HEADER): + param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER)) + self._remove_set_parameters(param_list) + async def _close_rowset_and_reset(self) -> None: """Reset cursor state.""" if self._row_set is not None: diff --git a/src/firebolt/client/constants.py b/src/firebolt/client/constants.py index f5860e83d23..7fb4f1d849a 100644 --- a/src/firebolt/client/constants.py +++ b/src/firebolt/client/constants.py @@ -5,7 +5,7 @@ DEFAULT_API_URL: str = "api.app.firebolt.io" PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version" -PROTOCOL_VERSION: str = "2.3" +PROTOCOL_VERSION: str = "2.4" _REQUEST_ERRORS: Tuple[Type, ...] = ( HTTPError, InvalidURL, diff --git a/src/firebolt/common/constants.py b/src/firebolt/common/constants.py index 9e1a94cf764..9622a0a7cf9 100644 --- a/src/firebolt/common/constants.py +++ b/src/firebolt/common/constants.py @@ -34,3 +34,4 @@ class ParameterStyle(Enum): UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint" UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters" RESET_SESSION_HEADER = "Firebolt-Reset-Session" +REMOVE_PARAMETERS_HEADER = "Firebolt-Remove-Parameters" diff --git a/src/firebolt/common/cursor/base_cursor.py b/src/firebolt/common/cursor/base_cursor.py index 84937886428..932d10a7bfe 100644 --- a/src/firebolt/common/cursor/base_cursor.py +++ b/src/firebolt/common/cursor/base_cursor.py @@ -42,6 +42,13 @@ def _parse_update_parameters(parameter_header: str) -> Dict[str, str]: return param_dict +def _parse_remove_parameters(parameter_header: str) -> List[str]: + """Parse remove parameters header and return list of parameter names to remove.""" + # parse key1,key2,key3 comma separated string into list + param_list = [item.strip() for item in parameter_header.split(",")] + return param_list + + def _parse_update_endpoint( new_engine_endpoint_header: str, ) -> Tuple[str, Dict[str, str]]: @@ -223,6 +230,14 @@ def _update_set_parameters(self, parameters: Dict[str, Any]) -> None: self._set_parameters.update(user_parameters) + def _remove_set_parameters(self, parameter_names: List[str]) -> None: + """Remove parameters from both user and immutable parameter collections.""" + for param_name in parameter_names: + # Remove from user parameters + self._set_parameters.pop(param_name, None) + # Remove from immutable parameters + self.parameters.pop(param_name, None) + def _update_server_parameters(self, parameters: Dict[str, Any]) -> None: for key, value in parameters.items(): self.parameters[key] = value diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index b0c67dcedb8..0b1bb2ff24f 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -29,6 +29,7 @@ from firebolt.common._types import ColType, ParameterType, SetParameter from firebolt.common.constants import ( JSON_OUTPUT_FORMAT, + REMOVE_PARAMETERS_HEADER, RESET_SESSION_HEADER, UPDATE_ENDPOINT_HEADER, UPDATE_PARAMETERS_HEADER, @@ -36,6 +37,7 @@ ) from firebolt.common.cursor.base_cursor import ( BaseCursor, + _parse_remove_parameters, _parse_update_endpoint, _parse_update_parameters, _raise_if_internal_set_parameter, @@ -200,6 +202,10 @@ def _parse_response_headers(self, headers: Headers) -> None: param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER)) self._update_set_parameters(param_dict) + if headers.get(REMOVE_PARAMETERS_HEADER): + param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER)) + self._remove_set_parameters(param_list) + def _close_rowset_and_reset(self) -> None: """Reset the cursor state.""" if self._row_set is not None: diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index 27d41862506..94aac2d8679 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -289,12 +289,12 @@ async def test_parameterized_query_with_special_chars(connection: Connection) -> [ ( "fb_numeric", - 'INSERT INTO "test_tbl" VALUES ($1, $2)', + 'INSERT INTO "{table}" VALUES ($1, $2)', [(1, "alice"), (2, "bob"), (3, "charlie")], ), ( "qmark", - 'INSERT INTO "test_tbl" VALUES (?, ?)', + 'INSERT INTO "{table}" VALUES (?, ?)', [(4, "david"), (5, "eve"), (6, "frank")], ), ], @@ -312,7 +312,7 @@ async def test_executemany_bulk_insert_paramstyles( firebolt.async_db.paramstyle = paramstyle # Generate a unique label for this test execution unique_label = f"test_bulk_insert_async_{paramstyle}_{randint(100000, 999999)}" - table_name = "test_tbl" + table_name = create_drop_test_table_setup_teardown_async try: c = connection.cursor() @@ -323,7 +323,7 @@ async def test_executemany_bulk_insert_paramstyles( # Execute bulk insert await c.executemany( - query, + query.format(table=table_name), test_data, bulk_insert=True, ) @@ -767,3 +767,90 @@ async def test_select_quoted_bigint( assert result[0][0] == int( long_bigint_value ), "Invalid data returned by fetchall" + + +async def test_transaction_commit( + connection: Connection, create_drop_test_table_setup_teardown_async: Callable +) -> None: + """Test transaction SQL statements with COMMIT.""" + table_name = create_drop_test_table_setup_teardown_async + async with connection.cursor() as c: + # Test successful transaction with COMMIT + result = await c.execute("BEGIN TRANSACTION") + assert result == 0, "BEGIN TRANSACTION should return 0 rows" + + await c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'committed')") + + result = await c.execute("COMMIT TRANSACTION") + assert result == 0, "COMMIT TRANSACTION should return 0 rows" + + # Verify the data was committed + await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data = await c.fetchall() + assert len(data) == 1, "Committed data should be present" + assert data[0] == [ + 1, + "committed", + ], "Committed data should match inserted values" + + +async def test_transaction_rollback( + connection: Connection, create_drop_test_table_setup_teardown_async: Callable +) -> None: + """Test transaction SQL statements with ROLLBACK.""" + table_name = create_drop_test_table_setup_teardown_async + async with connection.cursor() as c: + # Test transaction with ROLLBACK + result = await c.execute("BEGIN") # Test short form + assert result == 0, "BEGIN should return 0 rows" + + await c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'rolled_back')") + + # Verify data is visible within transaction + await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data = await c.fetchall() + assert len(data) == 1, "Data should be visible within transaction" + + result = await c.execute("ROLLBACK") # Test short form + assert result == 0, "ROLLBACK should return 0 rows" + + # Verify the data was rolled back + await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data = await c.fetchall() + assert len(data) == 0, "Rolled back data should not be present" + + +async def test_transaction_cursor_isolation( + connection: Connection, create_drop_test_table_setup_teardown_async: Callable +) -> None: + """Test that one cursor can't see another's data until it commits.""" + table_name = create_drop_test_table_setup_teardown_async + cursor1 = connection.cursor() + cursor2 = connection.cursor() + + # Start transaction in cursor1 and insert data + result = await cursor1.execute("BEGIN TRANSACTION") + assert result == 0, "BEGIN TRANSACTION should return 0 rows" + + await cursor1.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'isolated_data')") + + # Verify cursor1 can see its own uncommitted data + await cursor1.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data1 = await cursor1.fetchall() + assert len(data1) == 1, "Cursor1 should see its own uncommitted data" + assert data1[0] == [1, "isolated_data"], "Cursor1 data should match inserted values" + + # Verify cursor2 cannot see cursor1's uncommitted data + await cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data2 = await cursor2.fetchall() + assert len(data2) == 0, "Cursor2 should not see cursor1's uncommitted data" + + # Commit the transaction in cursor1 + result = await cursor1.execute("COMMIT TRANSACTION") + assert result == 0, "COMMIT TRANSACTION should return 0 rows" + + # Now cursor2 should be able to see the committed data + await cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data2 = await cursor2.fetchall() + assert len(data2) == 1, "Cursor2 should see committed data after commit" + assert data2[0] == [1, "isolated_data"], "Cursor2 should see the committed data" diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index 7fc257d10cb..36929d4b7f7 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -1,4 +1,5 @@ import os +import uuid from datetime import date, datetime, timedelta, timezone from decimal import Decimal from logging import getLogger @@ -12,8 +13,8 @@ LOGGER = getLogger(__name__) -CREATE_TEST_TABLE = 'CREATE TABLE IF NOT EXISTS "test_tbl" (id int, name string)' -DROP_TEST_TABLE = 'DROP TABLE IF EXISTS "test_tbl" CASCADE' +CREATE_TEST_TABLE = 'CREATE TABLE IF NOT EXISTS "{table}" (id int, name string)' +DROP_TEST_TABLE = 'DROP TABLE IF EXISTS "{table}" CASCADE' LONG_SELECT_DEFAULT_V1 = 250000000000 LONG_SELECT_DEFAULT_V2 = 350000000000 @@ -29,38 +30,27 @@ def long_test_value_with_default(default: int = 0) -> int: return long_test_value_with_default -@fixture -def create_drop_test_table_setup_teardown(connection: Connection) -> None: - with connection.cursor() as c: - c.execute(CREATE_TEST_TABLE) - yield c - c.execute(DROP_TEST_TABLE) - - -@fixture -async def create_server_side_test_table_setup_teardown_async( - connection: Connection, -) -> None: - with connection.cursor() as c: - await c.execute(CREATE_TEST_TABLE) - yield c - await c.execute(DROP_TEST_TABLE) +def generate_unique_table_name() -> str: + """Generate a unique table name for testing purposes.""" + return f"test_table_{uuid.uuid4().hex}" @fixture def create_drop_test_table_setup_teardown(connection: Connection) -> None: + table = generate_unique_table_name() with connection.cursor() as c: - c.execute(CREATE_TEST_TABLE) - yield c - c.execute(DROP_TEST_TABLE) + c.execute(CREATE_TEST_TABLE.format(table=table)) + yield table + c.execute(DROP_TEST_TABLE.format(table=table)) @fixture async def create_drop_test_table_setup_teardown_async(connection: Connection) -> None: + table = generate_unique_table_name() async with connection.cursor() as c: - await c.execute(CREATE_TEST_TABLE) - yield c - await c.execute(DROP_TEST_TABLE) + await c.execute(CREATE_TEST_TABLE.format(table=table)) + yield table + await c.execute(DROP_TEST_TABLE.format(table=table)) @fixture diff --git a/tests/integration/dbapi/sync/V2/test_queries.py b/tests/integration/dbapi/sync/V2/test_queries.py index 55b99d5340c..c2ee62ceceb 100644 --- a/tests/integration/dbapi/sync/V2/test_queries.py +++ b/tests/integration/dbapi/sync/V2/test_queries.py @@ -290,12 +290,12 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: [ ( "fb_numeric", - 'INSERT INTO "test_tbl" VALUES ($1, $2)', + 'INSERT INTO "{table}" VALUES ($1, $2)', [(1, "alice"), (2, "bob"), (3, "charlie")], ), ( "qmark", - 'INSERT INTO "test_tbl" VALUES (?, ?)', + 'INSERT INTO "{table}" VALUES (?, ?)', [(4, "david"), (5, "eve"), (6, "frank")], ), ], @@ -313,7 +313,7 @@ def test_executemany_bulk_insert_paramstyles( firebolt.db.paramstyle = paramstyle # Generate a unique label for this test execution unique_label = f"test_bulk_insert_{paramstyle}_{randint(100000, 999999)}" - table_name = "test_tbl" + table_name = create_drop_test_table_setup_teardown try: c = connection.cursor() @@ -324,7 +324,7 @@ def test_executemany_bulk_insert_paramstyles( # Execute bulk insert c.executemany( - query, + query.format(table=table_name), test_data, bulk_insert=True, ) @@ -768,3 +768,90 @@ def test_select_quoted_bigint( assert result[0][0] == int( long_bigint_value ), "Invalid data returned by fetchall" + + +def test_transaction_commit( + connection: Connection, create_drop_test_table_setup_teardown: Callable +) -> None: + """Test transaction SQL statements with COMMIT.""" + table_name = create_drop_test_table_setup_teardown + with connection.cursor() as c: + # Test successful transaction with COMMIT + result = c.execute("BEGIN TRANSACTION") + assert result == 0, "BEGIN TRANSACTION should return 0 rows" + + c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'committed')") + + result = c.execute("COMMIT TRANSACTION") + assert result == 0, "COMMIT TRANSACTION should return 0 rows" + + # Verify the data was committed + c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data = c.fetchall() + assert len(data) == 1, "Committed data should be present" + assert data[0] == [ + 1, + "committed", + ], "Committed data should match inserted values" + + +def test_transaction_rollback( + connection: Connection, create_drop_test_table_setup_teardown: Callable +) -> None: + """Test transaction SQL statements with ROLLBACK.""" + table_name = create_drop_test_table_setup_teardown + with connection.cursor() as c: + # Test transaction with ROLLBACK + result = c.execute("BEGIN") # Test short form + assert result == 0, "BEGIN should return 0 rows" + + c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'rolled_back')") + + # Verify data is visible within transaction + c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data = c.fetchall() + assert len(data) == 1, "Data should be visible within transaction" + + result = c.execute("ROLLBACK") # Test short form + assert result == 0, "ROLLBACK should return 0 rows" + + # Verify the data was rolled back + c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data = c.fetchall() + assert len(data) == 0, "Rolled back data should not be present" + + +def test_transaction_cursor_isolation( + connection: Connection, create_drop_test_table_setup_teardown: Callable +) -> None: + """Test that one cursor can't see another's data until it commits.""" + table_name = create_drop_test_table_setup_teardown + cursor1 = connection.cursor() + cursor2 = connection.cursor() + + # Start transaction in cursor1 and insert data + result = cursor1.execute("BEGIN TRANSACTION") + assert result == 0, "BEGIN TRANSACTION should return 0 rows" + + cursor1.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'isolated_data')") + + # Verify cursor1 can see its own uncommitted data + cursor1.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data1 = cursor1.fetchall() + assert len(data1) == 1, "Cursor1 should see its own uncommitted data" + assert data1[0] == [1, "isolated_data"], "Cursor1 data should match inserted values" + + # Verify cursor2 cannot see cursor1's uncommitted data + cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data2 = cursor2.fetchall() + assert len(data2) == 0, "Cursor2 should not see cursor1's uncommitted data" + + # Commit the transaction in cursor1 + result = cursor1.execute("COMMIT TRANSACTION") + assert result == 0, "COMMIT TRANSACTION should return 0 rows" + + # Now cursor2 should be able to see the committed data + cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1') + data2 = cursor2.fetchall() + assert len(data2) == 1, "Cursor2 should see committed data after commit" + assert data2[0] == [1, "isolated_data"], "Cursor2 should see the committed data" diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index f7bb2966717..0ba8ede63da 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -796,6 +796,64 @@ def query_callback_with_headers(request: Request, **kwargs) -> Response: assert bool(cursor.database) is True, "database is not set" +async def test_cursor_remove_parameters_header( + httpx_mock: HTTPXMock, + select_one_query_callback: Callable, + query_callback_with_remove_header: Callable, + set_query_url: str, + cursor: Cursor, +): + """Test that cursor removes parameters when REMOVE_PARAMETERS_HEADER is received.""" + + # Set up initial parameters + httpx_mock.add_callback( + select_one_query_callback, + url=f"{set_query_url}¶m1=value1", + is_reusable=True, + ) + httpx_mock.add_callback( + select_one_query_callback, + url=f"{set_query_url}¶m1=value1¶m2=value2", + is_reusable=True, + ) + httpx_mock.add_callback( + select_one_query_callback, + url=f"{set_query_url}¶m1=value1¶m2=value2¶m3=value3", + is_reusable=True, + ) + + assert len(cursor._set_parameters) == 0 + + # Execute SET statements to add parameters + await cursor.execute("set param1 = value1") + await cursor.execute("set param2 = value2") + await cursor.execute("set param3 = value3") + + assert len(cursor._set_parameters) == 3 + assert "param1" in cursor._set_parameters + assert "param2" in cursor._set_parameters + assert "param3" in cursor._set_parameters + assert cursor._set_parameters["param1"] == "value1" + assert cursor._set_parameters["param2"] == "value2" + assert cursor._set_parameters["param3"] == "value3" + + # Execute query that returns remove parameters header + httpx_mock.reset() + httpx_mock.add_callback( + query_callback_with_remove_header, + url=f"{set_query_url}¶m1=value1¶m2=value2¶m3=value3&output_format=JSON_Compact", + is_reusable=True, + ) + await cursor.execute("SELECT 1") + + # Verify that param1 and param3 were removed, param2 remains + assert len(cursor._set_parameters) == 1 + assert "param1" not in cursor._set_parameters + assert "param2" in cursor._set_parameters + assert "param3" not in cursor._set_parameters + assert cursor._set_parameters["param2"] == "value2" + + async def test_cursor_timeout( httpx_mock: HTTPXMock, select_one_query_callback: Callable, diff --git a/tests/unit/common/test_base_cursor.py b/tests/unit/common/test_base_cursor.py index 830732b593b..5d5173187dc 100644 --- a/tests/unit/common/test_base_cursor.py +++ b/tests/unit/common/test_base_cursor.py @@ -3,7 +3,11 @@ from pytest import fixture, mark -from firebolt.common.cursor.base_cursor import BaseCursor +from firebolt.common.cursor.base_cursor import ( + BaseCursor, + _parse_remove_parameters, + _parse_update_parameters, +) from firebolt.common.statement_formatter import create_statement_formatter @@ -57,3 +61,75 @@ def test_update_server_parameters_known_params( updated_parameters = initial_parameters.copy() updated_parameters.update({"database": "new_database"}) assert cursor.parameters == updated_parameters + + +@mark.parametrize( + "header, expected", + [ + ( + "key1=value1,key2=value2,key3=value3", + {"key1": "value1", "key2": "value2", "key3": "value3"}, + ), + ( + "key1 = value1 , key2= value2, key3 =value3", + {"key1": "value1", "key2": "value2", "key3": "value3"}, + ), + ], +) +def test_parse_update_parameters(header: str, expected: Dict[str, str]): + """Test parsing update parameters header.""" + result = _parse_update_parameters(header) + assert result == expected + + +@mark.parametrize( + "header, expected", + [ + ("key1,key2,key3", ["key1", "key2", "key3"]), + (" key1 , key2, key3 ", ["key1", "key2", "key3"]), + ], +) +def test_parse_remove_parameters(header: str, expected: list): + """Test parsing remove parameters header.""" + result = _parse_remove_parameters(header) + assert result == expected + + +@mark.parametrize( + "initial_set_params, initial_params, params_to_remove, expected_set_params, expected_params", + [ + ( + {"key1": "value1", "key2": "value2", "key3": "value3"}, + {"param1": "value1", "param2": "value2"}, + ["key1", "key3", "param1"], + {"key2": "value2"}, + {"param2": "value2"}, + ), + ( + {"key1": "value1", "key2": "value2"}, + {"param1": "value1"}, + ["nonexistent", "also_nonexistent"], + {"key1": "value1", "key2": "value2"}, + {"param1": "value1"}, + ), + ], +) +def test_remove_set_parameters( + cursor: BaseCursor, + initial_set_params: Dict[str, str], + initial_params: Dict[str, str], + params_to_remove: list, + expected_set_params: Dict[str, str], + expected_params: Dict[str, str], +): + """Test removing parameters from cursor.""" + # Set up initial parameters + cursor._set_parameters = initial_set_params + cursor.parameters = initial_params + + # Remove parameters + cursor._remove_set_parameters(params_to_remove) + + # Assert parameters were removed correctly + assert cursor._set_parameters == expected_set_params + assert cursor.parameters == expected_params diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index ceba19a174e..7d8983bbc3f 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -782,6 +782,64 @@ def query_callback_with_headers(request: Request, **kwargs) -> Response: assert bool(cursor.database) is True, "database is not set" +def test_cursor_remove_parameters_header( + httpx_mock: HTTPXMock, + select_one_query_callback: Callable, + query_callback_with_remove_header: Callable, + set_query_url: str, + cursor: Cursor, +): + """Test that cursor removes parameters when REMOVE_PARAMETERS_HEADER is received.""" + + # Set up initial parameters + httpx_mock.add_callback( + select_one_query_callback, + url=f"{set_query_url}¶m1=value1", + is_reusable=True, + ) + httpx_mock.add_callback( + select_one_query_callback, + url=f"{set_query_url}¶m1=value1¶m2=value2", + is_reusable=True, + ) + httpx_mock.add_callback( + select_one_query_callback, + url=f"{set_query_url}¶m1=value1¶m2=value2¶m3=value3", + is_reusable=True, + ) + + assert len(cursor._set_parameters) == 0 + + # Execute SET statements to add parameters + cursor.execute("set param1 = value1") + cursor.execute("set param2 = value2") + cursor.execute("set param3 = value3") + + assert len(cursor._set_parameters) == 3 + assert "param1" in cursor._set_parameters + assert "param2" in cursor._set_parameters + assert "param3" in cursor._set_parameters + assert cursor._set_parameters["param1"] == "value1" + assert cursor._set_parameters["param2"] == "value2" + assert cursor._set_parameters["param3"] == "value3" + + # Execute query that returns remove parameters header + httpx_mock.reset() + httpx_mock.add_callback( + query_callback_with_remove_header, + url=f"{set_query_url}¶m1=value1¶m2=value2¶m3=value3&output_format=JSON_Compact", + is_reusable=True, + ) + cursor.execute("SELECT 1") + + # Verify that param1 and param3 were removed, param2 remains + assert len(cursor._set_parameters) == 1 + assert "param1" not in cursor._set_parameters + assert "param2" in cursor._set_parameters + assert "param3" not in cursor._set_parameters + assert cursor._set_parameters["param2"] == "value2" + + def test_cursor_timeout( httpx_mock: HTTPXMock, select_one_query_callback: Callable, diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 7e38dc8c5f9..a1ff6d15da4 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -260,6 +260,37 @@ def do_query(request: Request, **kwargs) -> Response: return do_query +@fixture +def remove_parameters() -> List[str]: + return ["param1", "param3"] + + +@fixture +def query_callback_with_remove_header( + query_statistics: Dict[str, Any], remove_parameters: List[str] +) -> Callable: + """Fixture for query callback that returns REMOVE_PARAMETERS_HEADER. + + Returns a callback that simulates a server response with Firebolt-Remove-Parameters + header containing 'param1,param3' to test parameter removal functionality. + """ + + def do_query(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + # Header with comma-separated parameter names to remove + headers = {"Firebolt-Remove-Parameters": ",".join(remove_parameters)} + return Response(status_code=codes.OK, json=query_response, headers=headers) + + return do_query + + def encode_param(p: Any) -> str: return jdumps(p).strip('"')