Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from firebolt.common._types import ColType, ParameterType, SetParameter
from firebolt.common.constants import (
JSON_OUTPUT_FORMAT,
REMOVE_PARAMETERS_HEADER,
RESET_SESSION_HEADER,
UPDATE_ENDPOINT_HEADER,
UPDATE_PARAMETERS_HEADER,
CursorState,
)
from firebolt.common.cursor.base_cursor import (
BaseCursor,
_parse_remove_parameters,
_parse_update_endpoint,
_parse_update_parameters,
_raise_if_internal_set_parameter,
Expand Down Expand Up @@ -194,6 +196,10 @@ async def _parse_response_headers(self, headers: Headers) -> None:
param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER))
self._update_set_parameters(param_dict)

if headers.get(REMOVE_PARAMETERS_HEADER):
param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER))
self._remove_set_parameters(param_list)

async def _close_rowset_and_reset(self) -> None:
"""Reset cursor state."""
if self._row_set is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/firebolt/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

DEFAULT_API_URL: str = "api.app.firebolt.io"
PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version"
PROTOCOL_VERSION: str = "2.3"
PROTOCOL_VERSION: str = "2.4"
_REQUEST_ERRORS: Tuple[Type, ...] = (
HTTPError,
InvalidURL,
Expand Down
1 change: 1 addition & 0 deletions src/firebolt/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ class ParameterStyle(Enum):
UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint"
UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters"
RESET_SESSION_HEADER = "Firebolt-Reset-Session"
REMOVE_PARAMETERS_HEADER = "Firebolt-Remove-Parameters"
15 changes: 15 additions & 0 deletions src/firebolt/common/cursor/base_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def _parse_update_parameters(parameter_header: str) -> Dict[str, str]:
return param_dict


def _parse_remove_parameters(parameter_header: str) -> List[str]:
"""Parse remove parameters header and return list of parameter names to remove."""
# parse key1,key2,key3 comma separated string into list
param_list = [item.strip() for item in parameter_header.split(",")]
return param_list


def _parse_update_endpoint(
new_engine_endpoint_header: str,
) -> Tuple[str, Dict[str, str]]:
Expand Down Expand Up @@ -223,6 +230,14 @@ def _update_set_parameters(self, parameters: Dict[str, Any]) -> None:

self._set_parameters.update(user_parameters)

def _remove_set_parameters(self, parameter_names: List[str]) -> None:
"""Remove parameters from both user and immutable parameter collections."""
for param_name in parameter_names:
# Remove from user parameters
self._set_parameters.pop(param_name, None)
# Remove from immutable parameters
self.parameters.pop(param_name, None)

def _update_server_parameters(self, parameters: Dict[str, Any]) -> None:
for key, value in parameters.items():
self.parameters[key] = value
Expand Down
6 changes: 6 additions & 0 deletions src/firebolt/db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
from firebolt.common._types import ColType, ParameterType, SetParameter
from firebolt.common.constants import (
JSON_OUTPUT_FORMAT,
REMOVE_PARAMETERS_HEADER,
RESET_SESSION_HEADER,
UPDATE_ENDPOINT_HEADER,
UPDATE_PARAMETERS_HEADER,
CursorState,
)
from firebolt.common.cursor.base_cursor import (
BaseCursor,
_parse_remove_parameters,
_parse_update_endpoint,
_parse_update_parameters,
_raise_if_internal_set_parameter,
Expand Down Expand Up @@ -200,6 +202,10 @@ def _parse_response_headers(self, headers: Headers) -> None:
param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER))
self._update_set_parameters(param_dict)

if headers.get(REMOVE_PARAMETERS_HEADER):
param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER))
self._remove_set_parameters(param_list)

def _close_rowset_and_reset(self) -> None:
"""Reset the cursor state."""
if self._row_set is not None:
Expand Down
95 changes: 91 additions & 4 deletions tests/integration/dbapi/async/V2/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,12 @@ async def test_parameterized_query_with_special_chars(connection: Connection) ->
[
(
"fb_numeric",
'INSERT INTO "test_tbl" VALUES ($1, $2)',
'INSERT INTO "{table}" VALUES ($1, $2)',
[(1, "alice"), (2, "bob"), (3, "charlie")],
),
(
"qmark",
'INSERT INTO "test_tbl" VALUES (?, ?)',
'INSERT INTO "{table}" VALUES (?, ?)',
[(4, "david"), (5, "eve"), (6, "frank")],
),
],
Expand All @@ -312,7 +312,7 @@ async def test_executemany_bulk_insert_paramstyles(
firebolt.async_db.paramstyle = paramstyle
# Generate a unique label for this test execution
unique_label = f"test_bulk_insert_async_{paramstyle}_{randint(100000, 999999)}"
table_name = "test_tbl"
table_name = create_drop_test_table_setup_teardown_async

try:
c = connection.cursor()
Expand All @@ -323,7 +323,7 @@ async def test_executemany_bulk_insert_paramstyles(

# Execute bulk insert
await c.executemany(
query,
query.format(table=table_name),
test_data,
bulk_insert=True,
)
Expand Down Expand Up @@ -767,3 +767,90 @@ async def test_select_quoted_bigint(
assert result[0][0] == int(
long_bigint_value
), "Invalid data returned by fetchall"


async def test_transaction_commit(
connection: Connection, create_drop_test_table_setup_teardown_async: Callable
) -> None:
"""Test transaction SQL statements with COMMIT."""
table_name = create_drop_test_table_setup_teardown_async
async with connection.cursor() as c:
# Test successful transaction with COMMIT
result = await c.execute("BEGIN TRANSACTION")
assert result == 0, "BEGIN TRANSACTION should return 0 rows"

await c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'committed')")

result = await c.execute("COMMIT TRANSACTION")
assert result == 0, "COMMIT TRANSACTION should return 0 rows"

# Verify the data was committed
await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
data = await c.fetchall()
assert len(data) == 1, "Committed data should be present"
assert data[0] == [
1,
"committed",
], "Committed data should match inserted values"


async def test_transaction_rollback(
connection: Connection, create_drop_test_table_setup_teardown_async: Callable
) -> None:
"""Test transaction SQL statements with ROLLBACK."""
table_name = create_drop_test_table_setup_teardown_async
async with connection.cursor() as c:
# Test transaction with ROLLBACK
result = await c.execute("BEGIN") # Test short form
assert result == 0, "BEGIN should return 0 rows"

await c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'rolled_back')")

# Verify data is visible within transaction
await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
data = await c.fetchall()
assert len(data) == 1, "Data should be visible within transaction"

result = await c.execute("ROLLBACK") # Test short form
assert result == 0, "ROLLBACK should return 0 rows"

# Verify the data was rolled back
await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
data = await c.fetchall()
assert len(data) == 0, "Rolled back data should not be present"


async def test_transaction_cursor_isolation(
connection: Connection, create_drop_test_table_setup_teardown_async: Callable
) -> None:
"""Test that one cursor can't see another's data until it commits."""
table_name = create_drop_test_table_setup_teardown_async
cursor1 = connection.cursor()
cursor2 = connection.cursor()

# Start transaction in cursor1 and insert data
result = await cursor1.execute("BEGIN TRANSACTION")
assert result == 0, "BEGIN TRANSACTION should return 0 rows"

await cursor1.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'isolated_data')")

# Verify cursor1 can see its own uncommitted data
await cursor1.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
data1 = await cursor1.fetchall()
assert len(data1) == 1, "Cursor1 should see its own uncommitted data"
assert data1[0] == [1, "isolated_data"], "Cursor1 data should match inserted values"

# Verify cursor2 cannot see cursor1's uncommitted data
await cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
data2 = await cursor2.fetchall()
assert len(data2) == 0, "Cursor2 should not see cursor1's uncommitted data"

# Commit the transaction in cursor1
result = await cursor1.execute("COMMIT TRANSACTION")
assert result == 0, "COMMIT TRANSACTION should return 0 rows"

# Now cursor2 should be able to see the committed data
await cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
data2 = await cursor2.fetchall()
assert len(data2) == 1, "Cursor2 should see committed data after commit"
assert data2[0] == [1, "isolated_data"], "Cursor2 should see the committed data"
38 changes: 14 additions & 24 deletions tests/integration/dbapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from logging import getLogger
Expand All @@ -12,8 +13,8 @@

LOGGER = getLogger(__name__)

CREATE_TEST_TABLE = 'CREATE TABLE IF NOT EXISTS "test_tbl" (id int, name string)'
DROP_TEST_TABLE = 'DROP TABLE IF EXISTS "test_tbl" CASCADE'
CREATE_TEST_TABLE = 'CREATE TABLE IF NOT EXISTS "{table}" (id int, name string)'
DROP_TEST_TABLE = 'DROP TABLE IF EXISTS "{table}" CASCADE'

LONG_SELECT_DEFAULT_V1 = 250000000000
LONG_SELECT_DEFAULT_V2 = 350000000000
Expand All @@ -29,38 +30,27 @@ def long_test_value_with_default(default: int = 0) -> int:
return long_test_value_with_default


@fixture
def create_drop_test_table_setup_teardown(connection: Connection) -> None:
with connection.cursor() as c:
c.execute(CREATE_TEST_TABLE)
yield c
c.execute(DROP_TEST_TABLE)


@fixture
async def create_server_side_test_table_setup_teardown_async(
connection: Connection,
) -> None:
with connection.cursor() as c:
await c.execute(CREATE_TEST_TABLE)
yield c
await c.execute(DROP_TEST_TABLE)
def generate_unique_table_name() -> str:
"""Generate a unique table name for testing purposes."""
return f"test_table_{uuid.uuid4().hex}"


@fixture
def create_drop_test_table_setup_teardown(connection: Connection) -> None:
table = generate_unique_table_name()
with connection.cursor() as c:
c.execute(CREATE_TEST_TABLE)
yield c
c.execute(DROP_TEST_TABLE)
c.execute(CREATE_TEST_TABLE.format(table=table))
yield table
c.execute(DROP_TEST_TABLE.format(table=table))


@fixture
async def create_drop_test_table_setup_teardown_async(connection: Connection) -> None:
table = generate_unique_table_name()
async with connection.cursor() as c:
await c.execute(CREATE_TEST_TABLE)
yield c
await c.execute(DROP_TEST_TABLE)
await c.execute(CREATE_TEST_TABLE.format(table=table))
yield table
await c.execute(DROP_TEST_TABLE.format(table=table))


@fixture
Expand Down
Loading
Loading