From 78aa05272e90bb5f35574f9e7367de703fda5c99 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Tue, 24 Mar 2026 07:26:48 +0000 Subject: [PATCH 1/2] feat: Implement initial Google Cloud Spanner DB-API 2.0 driver with core components and comprehensive unit and system tests. --- .../google/cloud/spanner_driver/__init__.py | 59 ++- .../google/cloud/spanner_driver/connection.py | 138 +++++ .../google/cloud/spanner_driver/cursor.py | 475 ++++++++++++++++++ .../google/cloud/spanner_driver/errors.py | 241 +++++++++ .../google/cloud/spanner_driver/types.py | 170 +++++++ .../noxfile.py | 1 + .../tests/system/_helper.py | 58 +++ .../tests/system/test_connection.py | 44 ++ .../tests/system/test_cursor.py | 144 ++++++ .../tests/system/test_errors.py | 74 +++ .../tests/system/test_executemany.py | 64 +++ .../tests/system/test_transaction.py | 116 +++++ .../tests/unit/conftest.py | 221 ++++++++ .../tests/unit/test_connection.py | 111 ++++ .../tests/unit/test_cursor.py | 349 +++++++++++++ .../tests/unit/test_errors.py | 57 +++ .../tests/unit/test_types.py | 57 +++ 17 files changed, 2376 insertions(+), 3 deletions(-) create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/_helper.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_connection.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_errors.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_executemany.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_transaction.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/conftest.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_errors.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_types.py diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py index b75f2a4d398f..d898b418c6f5 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py @@ -14,17 +14,70 @@ """Spanner Python Driver.""" import logging +from typing import Final -from . import version as package_version +from .connection import Connection, connect +from .cursor import Cursor from .dbapi import apilevel, paramstyle, threadsafety +from .errors import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, +) +from .types import ( + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + Binary, + Date, + DateFromTicks, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, +) -__version__ = package_version.__version__ +__version__: Final[str] = "0.0.1" logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) __all__: list[str] = [ "apilevel", - "paramstyle", "threadsafety", + "paramstyle", + "Connection", + "connect", + "Cursor", + "Date", + "Time", + "Timestamp", + "DateFromTicks", + "TimeFromTicks", + "TimestampFromTicks", + "Binary", + "STRING", + "BINARY", + "NUMBER", + "DATETIME", + "ROWID", + "InterfaceError", + "ProgrammingError", + "OperationalError", + "DatabaseError", + "DataError", + "NotSupportedError", + "IntegrityError", + "InternalError", + "Warning", + "Error", ] diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py new file mode 100644 index 000000000000..12e4c3638d98 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py @@ -0,0 +1,138 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Any + +from google.cloud.spannerlib.pool import Pool + +from . import errors +from .cursor import Cursor + +logger = logging.getLogger(__name__) + + +def check_not_closed(function): + """`Connection` class methods decorator. + + Raise an exception if the connection is closed. + + :raises: :class:`InterfaceError` if the connection is closed. + """ + + def wrapper(connection, *args, **kwargs): + if connection._closed: + raise errors.InterfaceError("Connection is closed") + + return function(connection, *args, **kwargs) + + return wrapper + + +class Connection: + """Connection to a Google Cloud Spanner database. + + This class provides a connection to the Spanner database and adheres to + PEP 249 (Python Database API Specification v2.0). + """ + + def __init__(self, internal_connection: Any): + """ + Args: + internal_connection: An instance of + google.cloud.spannerlib.Connection + """ + self._internal_conn = internal_connection + self._closed = False + self._messages: list[Any] = [] + + @property + def messages(self) -> list[Any]: + """Return the list of messages sent to the client by the database.""" + return self._messages + + @check_not_closed + def cursor(self) -> Cursor: + """Return a new Cursor Object using the connection. + + Returns: + Cursor: A cursor object. + """ + return Cursor(self) + + @check_not_closed + def begin(self) -> None: + """Begin a new transaction.""" + logger.debug("Beginning transaction") + try: + self._internal_conn.begin_transaction() + except Exception as e: + raise errors.map_spanner_error(e) + + @check_not_closed + def commit(self) -> None: + """Commit any pending transaction to the database. + + This is a no-op if there is no active client transaction. + """ + logger.debug("Committing transaction") + try: + self._internal_conn.commit() + except Exception as e: + # raise errors.map_spanner_error(e) + logger.debug(f"Commit failed {e}") + + @check_not_closed + def rollback(self) -> None: + """Rollback any pending transaction to the database. + + This is a no-op if there is no active client transaction. + """ + logger.debug("Rolling back transaction") + try: + self._internal_conn.rollback() + except Exception as e: + # raise errors.map_spanner_error(e) + logger.debug(f"Rollback failed {e}") + + def close(self) -> None: + """Close the connection now. + + The connection will be unusable from this point forward; an Error (or + subclass) exception will be raised if any operation is attempted with + the connection. The same applies to all cursor objects trying to use + the connection. + """ + if self._closed: + raise errors.InterfaceError("Connection is already closed") + + logger.debug("Closing connection") + self._internal_conn.close() + self._closed = True + + def __enter__(self) -> "Connection": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + +def connect(connection_string: str, **kwargs: Any) -> Connection: + logger.debug(f"Connecting to {connection_string}") + # Create the pool + pool = Pool.create_pool(connection_string) + + # Create the low-level connection + internal_conn = pool.create_connection() + + return Connection(internal_conn) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py new file mode 100644 index 000000000000..a81e95ef47e8 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py @@ -0,0 +1,475 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import datetime +import logging +import uuid +from enum import Enum +from typing import TYPE_CHECKING, Any + +from google.cloud.spanner_v1 import ( + ExecuteBatchDmlRequest, + ExecuteSqlRequest, + Type, + TypeCode, +) + +from . import errors +from .types import _type_code_to_dbapi_type + +if TYPE_CHECKING: + from .connection import Connection + +logger = logging.getLogger(__name__) + + +def check_not_closed(function): + """`Cursor` class methods decorator. + + Raise an exception if the cursor is closed. + + :raises: :class:`InterfaceError` if the cursor is closed. + """ + + def wrapper(cursor, *args, **kwargs): + if cursor._closed: + raise errors.InterfaceError("Cursor is closed") + + return function(cursor, *args, **kwargs) + + return wrapper + + +class FetchScope(Enum): + FETCH_ONE = 1 + FETCH_MANY = 2 + FETCH_ALL = 3 + + +class Cursor: + """Cursor object for the Google Cloud Spanner database. + + This class lets you use a cursor to interact with the database. + """ + + def __init__(self, connection: "Connection"): + self._connection = connection + self._rows: Any = None # Holds the google.cloud.spannerlib.rows.Rows object + self._closed = False + self.arraysize = 1 + self._rowcount = -1 + + @property + def description(self) -> tuple[tuple[Any, ...], ...] | None: + """ + This read-only attribute is a sequence of 7-item sequences. + + Each of these sequences contains information describing one result + column: + - name + - type_code + - display_size + - internal_size + - precision + - scale + - null_ok + + The first two items (name and type_code) are mandatory, the other + five are optional and are set to None if no meaningful values can be + provided. + + This attribute will be None for operations that do not return rows or + if the cursor has not had an operation invoked via the .execute*() + method yet. + """ + logger.debug("Fetching description for cursor") + if not self._rows: + return None + + try: + metadata = self._rows.metadata() + if not metadata or not metadata.row_type: + return None + + desc = [] + for field in metadata.row_type.fields: + desc.append( + ( + field.name, + _type_code_to_dbapi_type(field.type.code), + None, # display_size + None, # internal_size + None, # precision + None, # scale + True, # null_ok + ) + ) + return tuple(desc) + except Exception: + return None + + @property + def rowcount(self) -> int: + """ + This read-only attribute specifies the number of rows that the last + .execute*() produced (for DQL statements like 'select') or affected + (for DML statements like 'update' or 'insert'). + + The attribute is -1 in case no .execute*() has been performed on the + cursor or the rowcount of the last operation cannot be determined by + the interface. + """ + return self._rowcount + + def _prepare_params( + self, parameters: dict[str, Any] | list[Any] | tuple[Any] | None = None + ) -> (dict[str, Any] | None, dict[str, Type] | None): + """ + Prepares parameters for Spanner execution + + Args: + parameters: A dictionary (for named parameters/GoogleSQL) + or a list/tuple + (for positional parameters/PostgreSQL). + + Returns: + A tuple containing: + - converted_params: Dictionary of parameters with values + converted for Spanner (e.g. ints to strings). + - param_types: Dictionary mapping parameter names to + their Spanner Type. + """ + if not parameters: + return {}, {} + + converted_params = {} + param_types = {} + + # Normalize input to an iterable of (key, value) + if isinstance(parameters, (list, tuple)): + # PostgreSQL Dialect: Positional parameters $1, $2... are + # mapped to P1, P2... + iterator = ((f"P{i}", val) for i, val in enumerate(parameters, 1)) + elif isinstance(parameters, dict): + # GoogleSQL Dialect: Named parameters @name are mapped directly. + iterator = parameters.items() + else: + # If strictly required, raise an error for unsupported types + return {}, {} + + for key, value in iterator: + if value is None: + converted_params[key] = None + continue + # Note: check bool before int, as bool is a subclass of int + if isinstance(value, bool): + converted_params[key] = value + param_types[key] = Type(code=TypeCode.BOOL) + elif isinstance(value, int): + # Spanner expects INT64 as strings to preserve precision + # in JSON + converted_params[key] = str(value) + param_types[key] = Type(code=TypeCode.INT64) + elif isinstance(value, float): + converted_params[key] = value + param_types[key] = Type(code=TypeCode.FLOAT64) + elif isinstance(value, bytes): + converted_params[key] = value + param_types[key] = Type(code=TypeCode.BYTES) + elif isinstance(value, uuid.UUID): + # Convert UUID to string as requested + converted_params[key] = str(value) + # Use STRING type for UUIDs (unless specific UUID type is + # required/supported by your backend version) + param_types[key] = Type(code=TypeCode.STRING) + elif isinstance(value, datetime.datetime): + # Convert Datetime to string (RFC 3339 format is standard + # for str(datetime)) + converted_params[key] = str(value) + param_types[key] = Type(code=TypeCode.TIMESTAMP) + elif isinstance(value, datetime.date): + converted_params[key] = str(value) + param_types[key] = Type(code=TypeCode.DATE) + else: + # Fallback for strings and other types + converted_params[key] = value + # For strings, we can explicitly set the type or let it default. + if isinstance(value, str): + param_types[key] = Type(code=TypeCode.STRING) + + return converted_params, param_types + + @check_not_closed + def execute( + self, + operation: str, + parameters: dict[str, Any] | list[Any] | tuple[Any] | None = None, + ) -> None: + """Prepare and execute a database operation (query or command). + + Parameters may be provided as sequence or mapping and will be bound to + variables in the operation. Variables are specified in a + database-specific notation (see the module's paramstyle attribute for + details). + + Args: + operation (str): The SQL statement to execute. + parameters (dict | list | tuple, optional): parameters to bind. + """ + logger.debug(f"Executing operation: {operation}") + + request = ExecuteSqlRequest(sql=operation) + params, _ = self._prepare_params(parameters) + request.params = params + + try: + self._rows = self._connection._internal_conn.execute(request) + + if self.description: + self._rowcount = -1 + else: + update_count = self._rows.update_count() + if update_count != -1: + self._rowcount = update_count + self._rows.close() + self._rows = None + + except Exception as e: + raise errors.map_spanner_error(e) from e + + @check_not_closed + def executemany( + self, + operation: str, + seq_of_parameters: (list[dict[str, Any]] | list[list[Any]] | list[tuple[Any]]), + ) -> None: + """Prepare a database operation (query or command) and then execute it + against all parameter sequences or mappings found in the sequence + seq_of_parameters. + + Args: + operation (str): The SQL statement to execute. + seq_of_parameters (list): A list of parameter sequences/mappings. + """ + logger.debug(f"Executing batch operation: {operation}") + + request = ExecuteBatchDmlRequest() + + for parameters in seq_of_parameters: + statement = ExecuteBatchDmlRequest.Statement(sql=operation) + params, _ = self._prepare_params(parameters) + statement.params = params + + request.statements.append(statement) + + try: + response = self._connection._internal_conn.execute_batch(request) + total_rowcount = 0 + for result_set in response.result_sets: + if result_set.stats.row_count_exact != -1: + total_rowcount += result_set.stats.row_count_exact + elif result_set.stats.row_count_lower_bound != -1: + total_rowcount += result_set.stats.row_count_lower_bound + self._rowcount = total_rowcount + + except Exception as e: + raise errors.map_spanner_error(e) from e + + def _convert_value(self, value: Any, field_type: Any) -> Any: + kind = value.WhichOneof("kind") + if kind == "null_value": + return None + if kind == "bool_value": + return value.bool_value + if kind == "number_value": + return value.number_value + if kind == "string_value": + code = field_type.code + val = value.string_value + if code == TypeCode.INT64: + return int(val) + if code == TypeCode.BYTES or code == TypeCode.PROTO: + return base64.b64decode(val) + return val + if kind == "list_value": + return [ + self._convert_value(v, field_type.array_element_type) + for v in value.list_value.values + ] + # Fallback for complex types (structs) not fully mapped yet + return value + + def _convert_row(self, row: Any) -> tuple[Any, ...]: + metadata = self._rows.metadata() + fields = metadata.row_type.fields + converted = [] + for i, value in enumerate(row.values): + converted.append(self._convert_value(value, fields[i].type)) + return tuple(converted) + + def _fetch( + self, scope: FetchScope, size: int | None = None + ) -> list[tuple[Any, ...]]: + if not self._rows: + raise errors.ProgrammingError("No result set available") + try: + rows = [] + if scope == FetchScope.FETCH_ONE: + try: + row = self._rows.next() + if row is not None: + rows.append(self._convert_row(row)) + except StopIteration: + pass + elif scope == FetchScope.FETCH_MANY: + # size is guaranteed to be int if scope is FETCH_MANY and + # called from fetchmany but might be None if internal logic + # changes, strict check would satisfy type checker + limit = size if size is not None else self.arraysize + for _ in range(limit): + try: + row = self._rows.next() + if row is None: + break + rows.append(self._convert_row(row)) + except StopIteration: + break + elif scope == FetchScope.FETCH_ALL: + while True: + try: + row = self._rows.next() + if row is None: + break + rows.append(self._convert_row(row)) + except StopIteration: + break + except Exception as e: + raise errors.map_spanner_error(e) from e + + return rows + + @check_not_closed + def fetchone(self) -> tuple[Any, ...] | None: + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available. + + Returns: + tuple | None: A row of data or None. + """ + logger.debug("Fetching one row") + rows = self._fetch(FetchScope.FETCH_ONE) + if not rows: + return None + return rows[0] + + @check_not_closed + def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]: + """Fetch the next set of rows of a query result, returning a sequence + of sequences (e.g. a list of tuples). An empty sequence is returned + when no more rows are available. + + The number of rows to fetch per call is specified by the parameter. If + it is not given, the cursor's arraysize determines the number of rows + to be fetched. + + Args: + size (int, optional): The number of rows to fetch. + + Returns: + list[tuple]: A list of rows. + """ + logger.debug("Fetching many rows") + if size is None: + size = self.arraysize + return self._fetch(FetchScope.FETCH_MANY, size) + + @check_not_closed + def fetchall(self) -> list[tuple[Any, ...]]: + """Fetch all (remaining) rows of a query result, returning them as a + sequence of sequences (e.g. a list of tuples). + + Returns: + list[tuple]: A list of rows. + """ + logger.debug("Fetching all rows") + return self._fetch(FetchScope.FETCH_ALL) + + def close(self) -> None: + """Close the cursor now. + + The cursor will be unusable from this point forward; an Error (or + subclass) exception will be raised if any operation is attempted with + the cursor. + """ + logger.debug("Closing cursor") + self._closed = True + if self._rows: + self._rows.close() + + @check_not_closed + def nextset(self) -> bool | None: + """Skip to the next available set of results.""" + logger.debug("Fetching next set of results") + if not self._rows: + return None + + try: + next_metadata = self._rows.next_result_set() + if next_metadata: + return True + return None + except Exception: + return None + + def __enter__(self) -> "Cursor": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def __iter__(self) -> "Cursor": + return self + + def __next__(self) -> tuple[Any, ...]: + row = self.fetchone() + if row is None: + raise StopIteration + return row + + @check_not_closed + def setinputsizes(self, sizes: list[Any]) -> None: + """Predefine memory areas for parameters. + This operation is a no-op implementation. + """ + logger.debug("NO-OP: Setting input sizes") + pass + + @check_not_closed + def setoutputsize(self, size: int, column: int | None = None) -> None: + """Set a column buffer size. + This operation is a no-op implementation. + """ + logger.debug("NO-OP: Setting output size") + pass + + @check_not_closed + def callproc( + self, procname: str, parameters: list[Any] | tuple[Any] | None = None + ) -> None: + """Call a stored database procedure with the given name. + + This method is not supported by Spanner. + """ + logger.debug("NO-OP: Calling stored procedure") + raise errors.NotSupportedError("Stored procedures are not supported.") diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py new file mode 100644 index 000000000000..8225d374eee8 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py @@ -0,0 +1,241 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Spanner Python Driver Errors. + +DBAPI-defined Exceptions are defined in the following hierarchy:: + + Exceptions + |__Warning + |__Error + |__InterfaceError + |__DatabaseError + |__DataError + |__OperationalError + |__IntegrityError + |__InternalError + |__ProgrammingError + |__NotSupportedError + +""" + +from typing import Any, Sequence + +from google.api_core.exceptions import GoogleAPICallError + + +class Warning(Exception): + """Important DB API warning.""" + + pass + + +class Error(Exception): + """The base class for all the DB API exceptions. + + Does not include :class:`Warning`. + """ + + def _is_error_cause_instance_of_google_api_exception(self) -> bool: + return isinstance(self.__cause__, GoogleAPICallError) + + @property + def reason(self) -> str | None: + """The reason of the error. + Reference: + https://cloud.google.com/apis/design/errors#error_info + Returns: + Union[str, None]: An optional string containing reason of the error. + """ + return ( + self.__cause__.reason + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + @property + def domain(self) -> str | None: + """The logical grouping to which the "reason" belongs. + Reference: + https://cloud.google.com/apis/design/errors#error_info + Returns: + Union[str, None]: An optional string containing a logical grouping + to which the "reason" belongs. + """ + return ( + self.__cause__.domain + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + @property + def metadata(self) -> dict[str, str] | None: + """Additional structured details about this error. + Reference: + https://cloud.google.com/apis/design/errors#error_info + Returns: + Union[Dict[str, str], None]: An optional object containing + structured details about the error. + """ + return ( + self.__cause__.metadata + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + @property + def details(self) -> Sequence[Any] | None: + """Information contained in google.rpc.status.details. + Reference: + https://cloud.google.com/apis/design/errors#error_model + https://cloud.google.com/apis/design/errors#error_details + Returns: + Sequence[Any]: A list of structured objects from + error_details.proto + """ + return ( + self.__cause__.details + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + +class InterfaceError(Error): + """ + Error related to the database interface + rather than the database itself. + """ + + pass + + +class DatabaseError(Error): + """Error related to the database.""" + + pass + + +class DataError(DatabaseError): + """ + Error due to problems with the processed data like + division by zero, numeric value out of range, etc. + """ + + pass + + +class OperationalError(DatabaseError): + """ + Error related to the database's operation, e.g. an + unexpected disconnect, the data source name is not + found, a transaction could not be processed, a + memory allocation error, etc. + """ + + pass + + +class IntegrityError(DatabaseError): + """ + Error for cases of relational integrity of the database + is affected, e.g. a foreign key check fails. + """ + + pass + + +class InternalError(DatabaseError): + """ + Internal database error, e.g. the cursor is not valid + anymore, the transaction is out of sync, etc. + """ + + pass + + +class ProgrammingError(DatabaseError): + """ + Programming error, e.g. table not found or already + exists, syntax error in the SQL statement, wrong + number of parameters specified, etc. + """ + + pass + + +class NotSupportedError(DatabaseError): + """ + Error for case of a method or database API not + supported by the database was used. + """ + + pass + + +def map_spanner_error(error: Exception) -> Error: + """Map SpannerLibError or GoogleAPICallError to DB API 2.0 errors.""" + from google.api_core import exceptions + from google.cloud.spannerlib.internal.errors import SpannerLibError + + match error: + # Handle SpannerLibError by matching on the internal + # error_code attribute + case SpannerLibError(error_code=code): + match code: + # 3 - INVALID_ARGUMENT + # 5 - NOT_FOUND + case 3 | 5: + return ProgrammingError(error) + # 6 - ALREADY_EXISTS + case 6: + return IntegrityError(error) + # 11 - OUT_OF_RANGE + case 11: + return DataError(error) + # 1 - CANCELLED + # 4 - DEADLINE_EXCEEDED + # 7 - PERMISSION_DENIED + # 9 - FAILED_PRECONDITION + # 10 - ABORTED + # 14 - INTERNAL + # 16 - UNAUTHENTICATED + case 1 | 4 | 7 | 9 | 10 | 14 | 16: + return OperationalError(error) + # 13 - INTERNAL + case 13: + return InternalError(error) + case _: + return DatabaseError(error) + + # Handle standard api_core exceptions + case exceptions.InvalidArgument() | exceptions.NotFound(): + return ProgrammingError(error) + case exceptions.AlreadyExists(): + return IntegrityError(error) + case exceptions.OutOfRange(): + return DataError(error) + case ( + exceptions.FailedPrecondition() + | exceptions.Unauthenticated() + | exceptions.PermissionDenied() + | exceptions.DeadlineExceeded() + | exceptions.ServiceUnavailable() + | exceptions.Aborted() + | exceptions.Cancelled() + ): + return OperationalError(error) + case exceptions.InternalServerError(): + return InternalError(error) + case _: + return DatabaseError(error) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py new file mode 100644 index 000000000000..3b3d228ee743 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py @@ -0,0 +1,170 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Types.""" + +import datetime +from typing import Any + +from google.cloud.spanner_v1 import TypeCode + + +def Date(year: int, month: int, day: int) -> datetime.date: + """Construct a date object. + + Args: + year (int): The year of the date. + month (int): The month of the date. + day (int): The day of the date. + + Returns: + datetime.date: A date object. + """ + return datetime.date(year, month, day) + + +def Time(hour: int, minute: int, second: int) -> datetime.time: + """Construct a time object. + + Args: + hour (int): The hour of the time. + minute (int): The minute of the time. + second (int): The second of the time. + + Returns: + datetime.time: A time object. + """ + return datetime.time(hour, minute, second) + + +def Timestamp( + year: int, month: int, day: int, hour: int, minute: int, second: int +) -> datetime.datetime: + """Construct a timestamp object. + + Args: + year (int): The year of the timestamp. + month (int): The month of the timestamp. + day (int): The day of the timestamp. + hour (int): The hour of the timestamp. + minute (int): The minute of the timestamp. + second (int): The second of the timestamp. + + Returns: + datetime.datetime: A timestamp object. + """ + return datetime.datetime(year, month, day, hour, minute, second) + + +def DateFromTicks(ticks: float) -> datetime.date: + """Construct a date object from ticks. + + Args: + ticks (float): The number of seconds since the epoch. + + Returns: + datetime.date: A date object. + """ + return datetime.date.fromtimestamp(ticks) + + +def TimeFromTicks(ticks: float) -> datetime.time: + """Construct a time object from ticks. + + Args: + ticks (float): The number of seconds since the epoch. + + Returns: + datetime.time: A time object. + """ + return datetime.datetime.fromtimestamp(ticks).time() + + +def TimestampFromTicks(ticks: float) -> datetime.datetime: + """Construct a timestamp object from ticks. + + Args: + ticks (float): The number of seconds since the epoch. + + Returns: + datetime.datetime: A timestamp object. + """ + return datetime.datetime.fromtimestamp(ticks) + + +def Binary(string: str | bytes) -> bytes: + """Construct a binary object. + + Args: + string (str | bytes): The string or bytes to convert. + + Returns: + bytes: A binary object. + """ + return bytes(string, "utf-8") if isinstance(string, str) else bytes(string) + + +# Type Objects for description comparison +class DBAPITypeObject: + def __init__(self, *values: str): + self.values = values + + def __eq__(self, other: Any) -> bool: + return other in self.values + + +STRING = DBAPITypeObject("STRING") +BINARY = DBAPITypeObject("BYTES", "PROTO") +NUMBER = DBAPITypeObject("INT64", "FLOAT64", "NUMERIC") +DATETIME = DBAPITypeObject("TIMESTAMP", "DATE") +BOOLEAN = DBAPITypeObject("BOOL") +ROWID = DBAPITypeObject() + + +class Type(object): + STRING = TypeCode.STRING + BYTES = TypeCode.BYTES + BOOL = TypeCode.BOOL + INT64 = TypeCode.INT64 + FLOAT64 = TypeCode.FLOAT64 + DATE = TypeCode.DATE + TIMESTAMP = TypeCode.TIMESTAMP + NUMERIC = TypeCode.NUMERIC + JSON = TypeCode.JSON + PROTO = TypeCode.PROTO + ENUM = TypeCode.ENUM + + +def _type_code_to_dbapi_type(type_code: int) -> DBAPITypeObject: + if type_code == TypeCode.STRING: + return STRING + if type_code == TypeCode.JSON: + return STRING + if type_code == TypeCode.BYTES: + return BINARY + if type_code == TypeCode.PROTO: + return BINARY + if type_code == TypeCode.BOOL: + return BOOLEAN + if type_code == TypeCode.INT64: + return NUMBER + if type_code == TypeCode.FLOAT64: + return NUMBER + if type_code == TypeCode.NUMERIC: + return NUMBER + if type_code == TypeCode.DATE: + return DATETIME + if type_code == TypeCode.TIMESTAMP: + return DATETIME + + return STRING diff --git a/packages/google-cloud-spanner-dbapi-driver/noxfile.py b/packages/google-cloud-spanner-dbapi-driver/noxfile.py index 3be4dcc7e55c..2fedee7ee5af 100644 --- a/packages/google-cloud-spanner-dbapi-driver/noxfile.py +++ b/packages/google-cloud-spanner-dbapi-driver/noxfile.py @@ -64,6 +64,7 @@ "pytest", "pytest-cov", "pytest-asyncio", + "google-cloud-spanner", ] UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [] UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/_helper.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/_helper.py new file mode 100644 index 000000000000..e940f79c01ee --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/_helper.py @@ -0,0 +1,58 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for system tests.""" + +import os + +SPANNER_EMULATOR_HOST = os.environ.get("SPANNER_EMULATOR_HOST") +TEST_ON_PROD = not bool(SPANNER_EMULATOR_HOST) + +if TEST_ON_PROD: + PROJECT_ID = os.environ.get("SPANNER_PROJECT_ID") + INSTANCE_ID = os.environ.get("SPANNER_INSTANCE_ID") + DATABASE_ID = os.environ.get("SPANNER_DATABASE_ID") + + if not PROJECT_ID or not INSTANCE_ID or not DATABASE_ID: + raise ValueError( + "SPANNER_PROJECT_ID, SPANNER_INSTANCE_ID, and SPANNER_DATABASE_ID " + "must be set when running tests on production." + ) +else: + PROJECT_ID = "test-project" + INSTANCE_ID = "test-instance" + DATABASE_ID = "test-db" + +PROD_TEST_CONNECTION_STRING = ( + f"projects/{PROJECT_ID}/instances/{INSTANCE_ID}/databases/{DATABASE_ID}" +) + +EMULATOR_TEST_CONNECTION_STRING = ( + f"{SPANNER_EMULATOR_HOST}" + f"projects/{PROJECT_ID}" + f"/instances/{INSTANCE_ID}" + f"/databases/{DATABASE_ID}" + "?autoConfigEmulator=true" +) + + +def setup_test_env() -> None: + if not TEST_ON_PROD: + print(f"Set SPANNER_EMULATOR_HOST to {os.environ['SPANNER_EMULATOR_HOST']}") + print(f"Using Connection String: {get_test_connection_string()}") + + +def get_test_connection_string() -> str: + if TEST_ON_PROD: + return PROD_TEST_CONNECTION_STRING + return EMULATOR_TEST_CONNECTION_STRING diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_connection.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_connection.py new file mode 100644 index 000000000000..9fb4199f1b80 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_connection.py @@ -0,0 +1,44 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for connection.py""" + +from google.cloud.spanner_driver import connect + +from ._helper import get_test_connection_string + + +class TestConnect: + def test_cursor(self): + """Test the connect method.""" + connection_string = get_test_connection_string() + + # Test Context Manager + with connect(connection_string) as connection: + assert connection is not None + + # Test Cursor Context Manager + with connection.cursor() as cursor: + assert cursor is not None + + +class TestConnectMethod: + """Tests for the connection.py module.""" + + def test_connect(self): + """Test the connect method.""" + connection_string = get_test_connection_string() + + # Test Context Manager + with connect(connection_string) as connection: + assert connection is not None diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py new file mode 100644 index 000000000000..5719b4030fa5 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py @@ -0,0 +1,144 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for cursor.py""" + +from google.cloud.spanner_driver import connect, types + +from ._helper import get_test_connection_string + + +class TestCursor: + def test_execute(self): + """Test the execute method.""" + connection_string = get_test_connection_string() + + # Test Context Manager + with connect(connection_string) as connection: + assert connection is not None + + # Test Cursor Context Manager + with connection.cursor() as cursor: + assert cursor is not None + + # Test execute and fetchone + cursor.execute("SELECT 1 AS col1") + assert cursor.description is not None + assert cursor.description[0][0] == "col1" + assert ( + cursor.description[0][1] == types.NUMBER + ) # TypeCode.INT64 maps to types.NUMBER + + result = cursor.fetchone() + assert result == (1,) + + def test_execute_params(self): + """Test the execute method with parameters.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + sql = "SELECT @a AS col1" + params = {"a": 1} + cursor.execute(sql, params) + result = cursor.fetchone() + assert result == (1,) + + def test_execute_dml(self): + """Test DML execution.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS Singers") + + # Create table + cursor.execute( + """ + CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + ) PRIMARY KEY (SingerId) + """ + ) + + # Insert + cursor.execute( + "INSERT INTO Singers (SingerId, FirstName, LastName) " + "VALUES (@id, @first, @last)", + {"id": 1, "first": "John", "last": "Doe"}, + ) + assert cursor.rowcount == 1 + + # Update + cursor.execute( + "UPDATE Singers SET FirstName = 'Jane' WHERE SingerId = 1" + ) + assert cursor.rowcount == 1 + + # Select back to verify + cursor.execute("SELECT FirstName FROM Singers WHERE SingerId = 1") + row = cursor.fetchone() + assert row == ("Jane",) + + # Cleanup (optional if emulator is reset) + + def test_fetch_methods(self): + """Test fetchmany and fetchall.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + # Use UNNEST to generate rows + cursor.execute( + "SELECT * FROM UNNEST([1, 2, 3, 4, 5]) AS numbers ORDER BY numbers" + ) + + # Fetch one + row = cursor.fetchone() + assert row == (1,) + + # Fetch many + rows = cursor.fetchmany(2) + assert len(rows) == 2 + assert rows[0] == (2,) + assert rows[1] == (3,) + + # Fetch all remaining + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0] == (4,) + assert rows[1] == (5,) + + def test_data_types(self): + """Test various data types.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + sql = """ + SELECT + 1 AS int_val, + 3.14 AS float_val, + TRUE AS bool_val, + 'hello' AS str_val, + b'bytes' AS bytes_val, + DATE '2023-01-01' AS date_val, + TIMESTAMP '2023-01-01T12:00:00Z' AS timestamp_val + """ + cursor.execute(sql) + row = cursor.fetchone() + + assert row[0] == 1 + assert row[1] == 3.14 + assert row[2] is True + assert row[3] == "hello" + assert row[4] == b"bytes" + assert row[4] == b"bytes" diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_errors.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_errors.py new file mode 100644 index 000000000000..5a7b39a0f2ea --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_errors.py @@ -0,0 +1,74 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for error handling in cursor.py and connection.py""" + +import pytest + +from google.cloud.spanner_driver import connect, errors + +from ._helper import get_test_connection_string + +DDL = """ +CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), +) PRIMARY KEY (SingerId) +""" + + +class TestErrors: + def setup_method(self): + """Re-create the table before each test.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS Singers") + cursor.execute(DDL) + + def test_programming_error_table_not_found(self): + """Test that selecting from a non-existent table + raises expected error.""" + connection_string = get_test_connection_string() + + with connect(connection_string) as connection: + with connection.cursor() as cursor: + with pytest.raises(errors.ProgrammingError): + cursor.execute("SELECT * FROM NonExistentTable") + + def test_integrity_error_duplicate_pk(self): + """Test that duplicate primary key raises IntegrityError.""" + connection_string = get_test_connection_string() + + with connect(connection_string) as connection: + with connection.cursor() as cursor: + sql = ( + "INSERT INTO Singers (SingerId, FirstName, LastName) " + "VALUES (@id, @first, @last)" + ) + params = {"id": 1, "first": "Alice", "last": "A"} + + cursor.execute(sql, params) + + # Second insert with same PK + with pytest.raises(errors.IntegrityError): + cursor.execute(sql, params) + + def test_operational_error_syntax(self): + """Test bad syntax raises ProgrammingError/OperationalError.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + with pytest.raises(errors.ProgrammingError): + cursor.execute("SELECT * FROM Singers WHERE") diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_executemany.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_executemany.py new file mode 100644 index 000000000000..ea305480db1e --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_executemany.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for executemany support in cursor.py""" + +from google.cloud.spanner_driver import connect + +from ._helper import get_test_connection_string + +DDL = """ +CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), +) PRIMARY KEY (SingerId) +""" + + +class TestExecuteMany: + def setup_method(self): + """Re-create the table before each test.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS Singers") + cursor.execute(DDL) + + def test_executemany(self): + """Test executemany with multiple rows.""" + connection_string = get_test_connection_string() + + with connect(connection_string) as connection: + with connection.cursor() as cursor: + sql = ( + "INSERT INTO Singers (SingerId, FirstName, LastName) " + "VALUES (@id, @first, @last)" + ) + params_seq = [ + {"id": 1, "first": "Alice", "last": "A"}, + {"id": 2, "first": "Bob", "last": "B"}, + {"id": 3, "first": "Charlie", "last": "C"}, + ] + + cursor.executemany(sql, params_seq) + + assert cursor.rowcount == 3 + + # Verify rows + cursor.execute("SELECT * FROM Singers ORDER BY SingerId") + rows = cursor.fetchall() + assert len(rows) == 3 + assert rows[0] == (1, "Alice", "A") + assert rows[1] == (2, "Bob", "B") + assert rows[2] == (3, "Charlie", "C") diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_transaction.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_transaction.py new file mode 100644 index 000000000000..0e3db27d9e25 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_transaction.py @@ -0,0 +1,116 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for transaction support in connection.py""" + +from google.cloud.spanner_driver import connect + +from ._helper import get_test_connection_string + +DDL = """ +CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), +) PRIMARY KEY (SingerId) +""" + + +class TestTransaction: + def setup_method(self): + """Re-create the table before each test.""" + connection_string = get_test_connection_string() + with connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS Singers") + cursor.execute(DDL) + + def test_commit(self): + """Test that changes are visible after commit.""" + connection_string = get_test_connection_string() + + # 1. Insert in a transaction + with connect(connection_string) as conn1: + conn1.begin() + with conn1.cursor() as cursor: + cursor.execute( + "INSERT INTO Singers (SingerId, FirstName, LastName) " + "VALUES (@id, @first, @last)", + {"id": 1, "first": "John", "last": "Doe"}, + ) + conn1.commit() + + # 2. Verify visibility from another connection + with connect(connection_string) as conn2: + with conn2.cursor() as cursor: + cursor.execute("SELECT FirstName FROM Singers WHERE SingerId = 1") + row = cursor.fetchone() + assert row == ("John",) + + def test_rollback(self): + """Test that changes are discarded after rollback.""" + connection_string = get_test_connection_string() + + # 1. Insert then rollback + with connect(connection_string) as conn1: + conn1.begin() + with conn1.cursor() as cursor: + cursor.execute( + "INSERT INTO Singers (SingerId, FirstName, LastName) " + "VALUES (@id, @first, @last)", + {"id": 2, "first": "Jane", "last": "Doe"}, + ) + conn1.rollback() + + # 2. Verify NOT visible + with connect(connection_string) as conn2: + with conn2.cursor() as cursor: + cursor.execute("SELECT FirstName FROM Singers WHERE SingerId = 2") + row = cursor.fetchone() + assert row is None + + def test_isolation(self): + """Test that uncommitted changes are not visible to others.""" + connection_string = get_test_connection_string() + + conn1 = connect(connection_string) + conn2 = connect(connection_string) + + try: + conn1.begin() + curs1 = conn1.cursor() + curs2 = conn2.cursor() + + # Insert in conn1 (uncommitted) + curs1.execute( + "INSERT INTO Singers (SingerId, FirstName, LastName) " + "VALUES (@id, @first, @last)", + {"id": 3, "first": "Bob", "last": "Smith"}, + ) + + # Check from conn2 + curs2.execute("SELECT FirstName FROM Singers WHERE SingerId = 3") + row = curs2.fetchone() + assert row is None + + # Commit conn1 + conn1.commit() + + # Check from conn2 + curs2.execute("SELECT FirstName FROM Singers WHERE SingerId = 3") + row = curs2.fetchone() + assert row == ("Bob",) + + finally: + conn1.close() + conn2.close() diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/conftest.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/conftest.py new file mode 100644 index 000000000000..42047ef13721 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/conftest.py @@ -0,0 +1,221 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import MagicMock + +import google.cloud + + +# 1. Define Exception Classes +class MockGoogleAPICallError(Exception): + def __init__(self, message=None, errors=None, response=None, **kwargs): + super().__init__(message) + self.message = message + self.errors = errors + self.response = response + self.reason = "reason" + self.domain = "domain" + self.metadata = {} + self.details = [] + + +class AlreadyExists(MockGoogleAPICallError): + pass + + +class NotFound(MockGoogleAPICallError): + pass + + +class InvalidArgument(MockGoogleAPICallError): + pass + + +class FailedPrecondition(MockGoogleAPICallError): + pass + + +class OutOfRange(MockGoogleAPICallError): + pass + + +class Unauthenticated(MockGoogleAPICallError): + pass + + +class PermissionDenied(MockGoogleAPICallError): + pass + + +class DeadlineExceeded(MockGoogleAPICallError): + pass + + +class ServiceUnavailable(MockGoogleAPICallError): + pass + + +class Aborted(MockGoogleAPICallError): + pass + + +class InternalServerError(MockGoogleAPICallError): + pass + + +class Unknown(MockGoogleAPICallError): + pass + + +class Cancelled(MockGoogleAPICallError): + pass + + +class DataLoss(MockGoogleAPICallError): + pass + + +class MockSpannerLibError(Exception): + pass + + +# 2. Define Type/Proto Classes +class MockTypeCode: + STRING = 1 + BYTES = 2 + BOOL = 3 + INT64 = 4 + FLOAT64 = 5 + DATE = 6 + TIMESTAMP = 7 + NUMERIC = 8 + JSON = 9 + PROTO = 10 + ENUM = 11 + + +class MockExecuteSqlRequest: + def __init__(self, sql=None, params=None): + self.sql = sql + self.params = params + + +class MockType: + def __init__(self, code): + self.code = code + + def __eq__(self, other): + return isinstance(other, MockType) and self.code == other.code + + def __repr__(self): + return f"MockType(code={self.code})" + + +class MockStructField: + def __init__(self, name, type_): + self.name = name + self.type = type_ # Avoid conflict with builtin type + + def __eq__(self, other): + return ( + isinstance(other, MockStructField) + and self.name == other.name + and self.type == other.type + ) + + +class MockStructType: + def __init__(self, fields): + self.fields = fields + + +# 3. Create Module Mocks +# google.cloud.spanner_v1 +spanner_v1 = MagicMock() +spanner_v1.TypeCode = MockTypeCode +spanner_v1.ExecuteSqlRequest = MockExecuteSqlRequest +spanner_v1.Type = MockType +spanner_v1.StructField = MockStructField +spanner_v1.StructType = MockStructType + +# google.cloud.spanner_v1.types +spanner_v1_types = MagicMock() +spanner_v1_types.Type = MockType +spanner_v1_types.StructField = MockStructField +spanner_v1_types.StructType = MockStructType + +# google.api_core.exceptions +exceptions_module = MagicMock() +exceptions_module.GoogleAPICallError = MockGoogleAPICallError +exceptions_module.AlreadyExists = AlreadyExists +exceptions_module.NotFound = NotFound +exceptions_module.InvalidArgument = InvalidArgument +exceptions_module.FailedPrecondition = FailedPrecondition +exceptions_module.OutOfRange = OutOfRange +exceptions_module.Unauthenticated = Unauthenticated +exceptions_module.PermissionDenied = PermissionDenied +exceptions_module.DeadlineExceeded = DeadlineExceeded +exceptions_module.ServiceUnavailable = ServiceUnavailable +exceptions_module.Aborted = Aborted +exceptions_module.InternalServerError = InternalServerError +exceptions_module.Unknown = Unknown +exceptions_module.Cancelled = Cancelled +exceptions_module.DataLoss = DataLoss + +# google.cloud.spannerlib +spannerlib = MagicMock() +# internal.errors +spannerlib_internal_errors = MagicMock() +spannerlib_internal_errors.SpannerLibError = MockSpannerLibError +spannerlib.internal.errors = spannerlib_internal_errors + +# pool +spannerlib_pool = MagicMock() +spannerlib.pool = spannerlib_pool + + +# pool.Pool class +class MockPool: + @staticmethod + def create_pool(connection_string): + return MockPool() + + def create_connection(self): + return MagicMock() + + +spannerlib.pool.Pool = MockPool + +# connection +spannerlib_connection = MagicMock() +spannerlib.connection = spannerlib_connection + +# 4. Inject into sys.modules +sys.modules["google.cloud.spanner_v1"] = spanner_v1 +sys.modules["google.cloud.spanner_v1.types"] = spanner_v1_types +sys.modules["google.api_core.exceptions"] = exceptions_module +sys.modules["google.api_core"] = MagicMock(exceptions=exceptions_module) +sys.modules["google.cloud.spannerlib"] = spannerlib +sys.modules["google.cloud.spannerlib.internal"] = spannerlib.internal +sys.modules["google.cloud.spannerlib.internal.errors"] = spannerlib_internal_errors +sys.modules["google.cloud.spannerlib.pool"] = spannerlib_pool +sys.modules["google.cloud.spannerlib.connection"] = spannerlib_connection + + +# 4. Patch google.cloud +# This is tricky because google is a namespace package +# but spannerlib might need to be explicitly set in google.cloud +google.cloud.spannerlib = spannerlib +google.cloud.spanner_v1 = spanner_v1 diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py new file mode 100644 index 000000000000..ed9a0fa18736 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py @@ -0,0 +1,111 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock + +from google.cloud import spanner_driver +from google.cloud.spanner_driver import connection, errors + + +class TestConnect(unittest.TestCase): + def test_connect(self): + connection_string = "spanner://projects/p/instances/i/databases/d" + + with mock.patch( + "google.cloud.spannerlib.pool.Pool.create_pool" + ) as mock_create_pool: + mock_pool = mock.Mock() + mock_create_pool.return_value = mock_pool + mock_internal_conn = mock.Mock() + mock_pool.create_connection.return_value = mock_internal_conn + + conn = spanner_driver.connect(connection_string) + + self.assertIsInstance(conn, connection.Connection) + mock_create_pool.assert_called_once_with(connection_string) + mock_pool.create_connection.assert_called_once() + + +class TestConnection(unittest.TestCase): + def setUp(self): + self.mock_internal_conn = mock.Mock() + self.conn = connection.Connection(self.mock_internal_conn) + + def test_cursor(self): + cursor = self.conn.cursor() + self.assertIsInstance(cursor, spanner_driver.Cursor) + self.assertEqual(cursor._connection, self.conn) + + def test_cursor_closed(self): + self.conn.close() + with self.assertRaises(errors.InterfaceError): + self.conn.cursor() + + def test_begin(self): + self.conn.begin() + self.mock_internal_conn.begin_transaction.assert_called_once() + + def test_begin_error(self): + self.mock_internal_conn.begin_transaction.side_effect = Exception( + "Internal Error" + ) + with self.assertRaises(errors.DatabaseError): + self.conn.begin() + + def test_commit(self): + self.conn.commit() + self.mock_internal_conn.commit.assert_called_once() + + def test_commit_error(self): + self.mock_internal_conn.commit.side_effect = Exception("Commit Failed") + try: + self.conn.commit() + except Exception: + self.fail("commit() raised Exception unexpectedly!") + self.mock_internal_conn.commit.assert_called_once() + + def test_rollback(self): + self.conn.rollback() + self.mock_internal_conn.rollback.assert_called_once() + + def test_rollback_error(self): + # Similar to commit, rollback errors are caught and logged + self.mock_internal_conn.rollback.side_effect = Exception("Rollback Failed") + try: + self.conn.rollback() + except Exception: + self.fail("rollback() raised Exception unexpectedly!") + self.mock_internal_conn.rollback.assert_called_once() + + def test_close(self): + self.assertFalse(self.conn._closed) + self.conn.close() + self.assertTrue(self.conn._closed) + self.mock_internal_conn.close.assert_called_once() + + def test_close_idempotent(self): + self.conn.close() + self.mock_internal_conn.close.reset_mock() + self.assertRaises(errors.InterfaceError, self.conn.close) + + def test_messages(self): + self.assertEqual(self.conn.messages, []) + + def test_context_manager(self): + with self.conn as c: + self.assertEqual(c, self.conn) + self.assertFalse(c._closed) + self.assertTrue(self.conn._closed) + self.mock_internal_conn.close.assert_called_once() diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py new file mode 100644 index 000000000000..7cb6cf4e992f --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py @@ -0,0 +1,349 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest +import uuid +from unittest import mock + +from google.cloud.spanner_v1 import ExecuteSqlRequest, TypeCode +from google.cloud.spanner_v1.types import StructField, Type + +from google.cloud.spanner_driver import cursor + + +class TestCursor(unittest.TestCase): + def setUp(self): + self.mock_connection = mock.Mock() + self.mock_internal_conn = mock.Mock() + self.mock_connection._internal_conn = self.mock_internal_conn + self.cursor = cursor.Cursor(self.mock_connection) + + def test_init(self): + self.assertEqual(self.cursor._connection, self.mock_connection) + + def test_execute(self): + operation = "SELECT * FROM table" + mock_rows = mock.Mock() + # Mocking description to be None so it treats as DML or query with no + # result initially? If description calls metadata(), we need to mock + # that. logic: if self.description: self._rowcount = -1 + + # Scenario 1: SELECT query (returns rows) + mock_metadata = mock.Mock() + mock_metadata.row_type.fields = [ + StructField(name="col1", type_=Type(code=TypeCode.INT64)) + ] + mock_rows.metadata.return_value = mock_metadata + self.mock_internal_conn.execute.return_value = mock_rows + + self.cursor.execute(operation) + + self.mock_internal_conn.execute.assert_called_once() + call_args = self.mock_internal_conn.execute.call_args + self.assertIsInstance(call_args[0][0], ExecuteSqlRequest) + self.assertEqual(call_args[0][0].sql, operation) + self.assertEqual(self.cursor._rowcount, -1) + self.assertEqual(self.cursor._rows, mock_rows) + + def test_execute_dml(self): + operation = "UPDATE table SET col=1" + mock_rows = mock.Mock() + # Returns empty metadata or no metadata for DML? + # Actually in Spanner, DML returns a ResultSet with stats. + # But here we check `if self.description`. + + # Scenario 2: DML (no fields in metadata usually, or we can simulate + # it) If metadata calls fail or return empty, description returns + # usually None. + mock_rows.metadata.return_value = None + mock_rows.update_count.return_value = 10 + self.mock_internal_conn.execute.return_value = mock_rows + + self.cursor.execute(operation) + + self.assertEqual(self.cursor._rowcount, 10) + # rows should be closed and set to None for DML in this driver + # implementation + mock_rows.close.assert_called_once() + self.assertIsNone(self.cursor._rows) + + def test_execute_with_params(self): + operation = "SELECT * FROM table WHERE id=@id" + params = {"id": 1} + mock_rows = mock.Mock() + mock_rows.metadata.return_value = mock.Mock() + self.mock_internal_conn.execute.return_value = mock_rows + + self.cursor.execute(operation, params) + + call_args = self.mock_internal_conn.execute.call_args + request = call_args[0][0] + self.assertEqual(request.sql, operation) + self.assertEqual(request.sql, operation) + self.assertEqual(request.params, {"id": "1"}) + + def test_executemany(self): + operation = "INSERT INTO table (id) VALUES (@id)" + params_seq = [{"id": 1, "name": "val1"}, {"id": 2}] + + # Mock execute_batch response + mock_response = mock.Mock() + mock_result_set1 = mock.Mock() + mock_result_set1.stats.row_count_exact = 1 + mock_result_set2 = mock.Mock() + mock_result_set2.stats.row_count_exact = 1 + mock_response.result_sets = [mock_result_set1, mock_result_set2] + + self.mock_internal_conn.execute_batch.return_value = mock_response + + # Patch ExecuteBatchDmlRequest in cursor module + with mock.patch( + "google.cloud.spanner_driver.cursor.ExecuteBatchDmlRequest" + ) as MockRequest: + # Setup mock request instance and statements list behavior + mock_request_instance = MockRequest.return_value + mock_request_instance.statements = [] # Use a real list to verify append + + # Setup Statement mock + MockStatement = mock.Mock() + MockRequest.Statement = MockStatement + + self.cursor.executemany(operation, params_seq) + + # Verify execute_batch called with our mock request + self.mock_internal_conn.execute_batch.assert_called_once_with( + mock_request_instance + ) + + # Verify statements were created and appended + self.assertEqual(len(mock_request_instance.statements), 2) + + # Verify first statement + call1 = MockStatement.call_args_list[0] + self.assertEqual(call1.kwargs["sql"], operation) + + self.assertEqual(MockStatement.call_count, 2) + + # Verify rowcount update + self.assertEqual(self.cursor.rowcount, 2) + + def test_fetchone(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + + # Mock metadata for type information + mock_metadata = mock.Mock() + mock_metadata.row_type.fields = [ + StructField(name="col1", type_=Type(code=TypeCode.INT64)) + ] + mock_rows.metadata.return_value = mock_metadata + mock_rows.metadata.return_value = mock_metadata + + # Mock row as object with values attribute + mock_row = mock.Mock() + mock_val = mock.Mock() + mock_val.WhichOneof.return_value = "string_value" + mock_val.string_value = "1" + mock_row.values = [mock_val] + + mock_rows.next.return_value = mock_row + + row = self.cursor.fetchone() + self.assertEqual(row, (1,)) + mock_rows.next.assert_called_once() + + def test_fetchone_empty(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + mock_rows.next.side_effect = StopIteration + + row = self.cursor.fetchone() + self.assertIsNone(row) + + def test_fetchmany(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + + # Metadata + mock_metadata = mock.Mock() + mock_metadata.row_type.fields = [ + StructField(name="col1", type_=Type(code=TypeCode.INT64)) + ] + mock_rows.metadata.return_value = mock_metadata + mock_rows.metadata.return_value = mock_metadata + + # Rows + mock_row1 = mock.Mock() + v1 = mock.Mock() + v1.WhichOneof.return_value = "string_value" + v1.string_value = "1" + mock_row1.values = [v1] + + mock_row2 = mock.Mock() + v2 = mock.Mock() + v2.WhichOneof.return_value = "string_value" + v2.string_value = "2" + mock_row2.values = [v2] + + mock_rows.next.side_effect = [mock_row1, mock_row2, StopIteration] + + rows = self.cursor.fetchmany(size=5) + self.assertEqual(len(rows), 2) + self.assertEqual(rows, [(1,), (2,)]) + + def test_fetchall(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + + # Metadata + mock_metadata = mock.Mock() + mock_metadata.row_type.fields = [ + StructField(name="col1", type_=Type(code=TypeCode.INT64)) + ] + mock_rows.metadata.return_value = mock_metadata + mock_rows.metadata.return_value = mock_metadata + + # Rows + mock_row1 = mock.Mock() + v1 = mock.Mock() + v1.WhichOneof.return_value = "string_value" + v1.string_value = "1" + mock_row1.values = [v1] + + mock_row2 = mock.Mock() + v2 = mock.Mock() + v2.WhichOneof.return_value = "string_value" + v2.string_value = "2" + mock_row2.values = [v2] + + mock_rows.next.side_effect = [mock_row1, mock_row2, StopIteration] + + rows = self.cursor.fetchall() + self.assertEqual(len(rows), 2) + + def test_description(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + + mock_metadata = mock.Mock() + mock_metadata.row_type.fields = [ + StructField(name="col1", type_=Type(code=TypeCode.INT64)), + StructField(name="col2", type_=Type(code=TypeCode.STRING)), + ] + mock_rows.metadata.return_value = mock_metadata + + desc = self.cursor.description + self.assertEqual(len(desc), 2) + self.assertEqual(desc[0][0], "col1") + self.assertEqual(desc[1][0], "col2") + + def test_close(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + + self.cursor.close() + + self.assertTrue(self.cursor._closed) + mock_rows.close.assert_called_once() + + def test_context_manager(self): + with self.cursor as c: + self.assertEqual(c, self.cursor) + self.assertTrue(self.cursor._closed) + + def test_iterator(self): + mock_rows = mock.Mock() + self.cursor._rows = mock_rows + + mock_metadata = mock.Mock() + mock_metadata.row_type.fields = [ + StructField(name="col1", type_=Type(code=TypeCode.INT64)) + ] + mock_rows.metadata.return_value = mock_metadata + mock_rows.metadata.return_value = mock_metadata + + mock_row = mock.Mock() + v1 = mock.Mock() + v1.WhichOneof.return_value = "string_value" + v1.string_value = "1" + mock_row.values = [v1] + + mock_rows.next.side_effect = [mock_row, StopIteration] + + # __next__ calls fetchone + it = iter(self.cursor) + self.assertEqual(next(it), (1,)) + with self.assertRaises(StopIteration): + next(it) + + def test_prepare_params(self): + # Test 1: None + converted, types = self.cursor._prepare_params(None) + self.assertEqual(converted, {}) + self.assertEqual(types, {}) + + # Test 2: Dict (GoogleSQL) + uuid_val = uuid.uuid4() + dt_val = datetime.datetime(2024, 1, 1, 12, 0, 0) + date_val = datetime.date(2024, 1, 1) + params = { + "int_val": 123, + "bool_val": True, + "float_val": 1.23, + "bytes_val": b"bytes", + "str_val": "string", + "uuid_val": uuid_val, + "dt_val": dt_val, + "date_val": date_val, + "none_val": None, + } + converted, types = self.cursor._prepare_params(params) + + self.assertEqual(converted["int_val"], "123") + self.assertEqual(types["int_val"].code, TypeCode.INT64) + + self.assertEqual(converted["bool_val"], True) + self.assertEqual(types["bool_val"].code, TypeCode.BOOL) + + self.assertEqual(converted["float_val"], 1.23) + self.assertEqual(types["float_val"].code, TypeCode.FLOAT64) + + self.assertEqual(converted["bytes_val"], b"bytes") + self.assertEqual(types["bytes_val"].code, TypeCode.BYTES) + + self.assertEqual(converted["str_val"], "string") + self.assertEqual(types["str_val"].code, TypeCode.STRING) + + self.assertEqual(converted["uuid_val"], str(uuid_val)) + self.assertEqual(types["uuid_val"].code, TypeCode.STRING) + + self.assertEqual(converted["dt_val"], str(dt_val)) + self.assertEqual(types["dt_val"].code, TypeCode.TIMESTAMP) + + self.assertEqual(converted["date_val"], str(date_val)) + self.assertEqual(types["date_val"].code, TypeCode.DATE) + + self.assertIsNone(converted["none_val"]) + self.assertNotIn("none_val", types) + + # Test 3: List (PostgreSQL) + params_list = [1, "test"] + converted, types = self.cursor._prepare_params(params_list) + + self.assertEqual(converted["P1"], "1") + self.assertEqual(types["P1"].code, TypeCode.INT64) + + self.assertEqual(converted["P2"], "test") + self.assertEqual(types["P2"].code, TypeCode.STRING) diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_errors.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_errors.py new file mode 100644 index 000000000000..deabcaca6d79 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_errors.py @@ -0,0 +1,57 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from google.api_core import exceptions +from google.cloud.spannerlib.internal.errors import SpannerLibError + +from google.cloud.spanner_driver import errors + + +class TestErrors(unittest.TestCase): + def test_map_spanner_lib_error(self): + err = SpannerLibError("Internal Error") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.DatabaseError) + + def test_map_not_found(self): + err = exceptions.NotFound("Not found") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.ProgrammingError) + + def test_map_already_exists(self): + err = exceptions.AlreadyExists("Exists") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.IntegrityError) + + def test_map_invalid_argument(self): + err = exceptions.InvalidArgument("Invalid") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.ProgrammingError) + + def test_map_failed_precondition(self): + err = exceptions.FailedPrecondition("Precondition") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.OperationalError) + + def test_map_out_of_range(self): + err = exceptions.OutOfRange("OOR") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.DataError) + + def test_map_unknown(self): + err = exceptions.Unknown("Unknown") + mapped_err = errors.map_spanner_error(err) + self.assertIsInstance(mapped_err, errors.DatabaseError) diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_types.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_types.py new file mode 100644 index 000000000000..4dd3b45f11ed --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_types.py @@ -0,0 +1,57 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest + +from google.cloud.spanner_v1 import TypeCode + +from google.cloud.spanner_driver import types + + +class TestTypes(unittest.TestCase): + def test_date(self): + d = types.Date(2025, 1, 1) + self.assertEqual(d, datetime.date(2025, 1, 1)) + + def test_time(self): + t = types.Time(12, 30, 0) + self.assertEqual(t, datetime.time(12, 30, 0)) + + def test_timestamp(self): + ts = types.Timestamp(2025, 1, 1, 12, 30, 0) + self.assertEqual(ts, datetime.datetime(2025, 1, 1, 12, 30, 0)) + + def test_binary(self): + b = types.Binary("hello") + self.assertEqual(b, b"hello") + b2 = types.Binary(b"world") + self.assertEqual(b2, b"world") + + def test_type_objects(self): + self.assertEqual(types.STRING, types.STRING) + self.assertNotEqual(types.STRING, types.NUMBER) + self.assertEqual(types.STRING, "STRING") # DBAPITypeObject compares using 'in' + + def test_type_code_mapping(self): + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.STRING), types.STRING) + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.INT64), types.NUMBER) + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.BOOL), types.BOOLEAN) + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.FLOAT64), types.NUMBER) + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.BYTES), types.BINARY) + self.assertEqual( + types._type_code_to_dbapi_type(TypeCode.TIMESTAMP), types.DATETIME + ) + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.DATE), types.DATETIME) + self.assertEqual(types._type_code_to_dbapi_type(TypeCode.JSON), types.STRING) From 1db7aa35d346b951ca3283caa432e4d0216598ee Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Wed, 25 Mar 2026 12:29:47 +0000 Subject: [PATCH 2/2] refactor: improve error handling, parameter validation, and logging in Spanner DBAPI driver --- .../google/cloud/spanner_driver/__init__.py | 42 +++++++++---------- .../google/cloud/spanner_driver/connection.py | 6 +-- .../google/cloud/spanner_driver/cursor.py | 11 +++-- .../tests/system/test_cursor.py | 1 - .../tests/unit/test_connection.py | 9 +--- .../tests/unit/test_cursor.py | 4 ++ 6 files changed, 37 insertions(+), 36 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py index d898b418c6f5..32cac8125778 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py @@ -52,32 +52,32 @@ logger.addHandler(logging.NullHandler()) __all__: list[str] = [ - "apilevel", - "threadsafety", - "paramstyle", + "BINARY", + "Binary", "Connection", - "connect", "Cursor", - "Date", - "Time", - "Timestamp", - "DateFromTicks", - "TimeFromTicks", - "TimestampFromTicks", - "Binary", - "STRING", - "BINARY", - "NUMBER", "DATETIME", - "ROWID", - "InterfaceError", - "ProgrammingError", - "OperationalError", - "DatabaseError", "DataError", - "NotSupportedError", + "DatabaseError", + "Date", + "DateFromTicks", + "Error", "IntegrityError", + "InterfaceError", "InternalError", + "NUMBER", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "ROWID", + "STRING", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", "Warning", - "Error", + "apilevel", + "connect", + "paramstyle", + "threadsafety", ] diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py index 12e4c3638d98..e81cefd2497f 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py @@ -89,8 +89,8 @@ def commit(self) -> None: try: self._internal_conn.commit() except Exception as e: - # raise errors.map_spanner_error(e) logger.debug(f"Commit failed {e}") + raise errors.map_spanner_error(e) @check_not_closed def rollback(self) -> None: @@ -102,8 +102,8 @@ def rollback(self) -> None: try: self._internal_conn.rollback() except Exception as e: - # raise errors.map_spanner_error(e) logger.debug(f"Rollback failed {e}") + raise errors.map_spanner_error(e) def close(self) -> None: """Close the connection now. @@ -127,7 +127,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() -def connect(connection_string: str, **kwargs: Any) -> Connection: +def connect(connection_string: str) -> Connection: logger.debug(f"Connecting to {connection_string}") # Create the pool pool = Pool.create_pool(connection_string) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py index a81e95ef47e8..4278ef791d17 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py @@ -116,7 +116,8 @@ def description(self) -> tuple[tuple[Any, ...], ...] | None: ) ) return tuple(desc) - except Exception: + except Exception as e: + logger.warning("Could not determine cursor description: %s", e) return None @property @@ -165,8 +166,9 @@ def _prepare_params( # GoogleSQL Dialect: Named parameters @name are mapped directly. iterator = parameters.items() else: - # If strictly required, raise an error for unsupported types - return {}, {} + raise errors.ProgrammingError( + f"Parameters must be a dict, list, or tuple, not {type(parameters).__name__}" + ) for key, value in iterator: if value is None: @@ -429,7 +431,8 @@ def nextset(self) -> bool | None: if next_metadata: return True return None - except Exception: + except Exception as e: + logger.warning("Could not determine next set of results: %s", e) return None def __enter__(self) -> "Cursor": diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py index 5719b4030fa5..5287fc646008 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py @@ -141,4 +141,3 @@ def test_data_types(self): assert row[2] is True assert row[3] == "hello" assert row[4] == b"bytes" - assert row[4] == b"bytes" diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py index ed9a0fa18736..56feea8792ea 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py @@ -70,10 +70,8 @@ def test_commit(self): def test_commit_error(self): self.mock_internal_conn.commit.side_effect = Exception("Commit Failed") - try: + with self.assertRaises(errors.DatabaseError): self.conn.commit() - except Exception: - self.fail("commit() raised Exception unexpectedly!") self.mock_internal_conn.commit.assert_called_once() def test_rollback(self): @@ -81,12 +79,9 @@ def test_rollback(self): self.mock_internal_conn.rollback.assert_called_once() def test_rollback_error(self): - # Similar to commit, rollback errors are caught and logged self.mock_internal_conn.rollback.side_effect = Exception("Rollback Failed") - try: + with self.assertRaises(errors.DatabaseError): self.conn.rollback() - except Exception: - self.fail("rollback() raised Exception unexpectedly!") self.mock_internal_conn.rollback.assert_called_once() def test_close(self): diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py index 7cb6cf4e992f..9042a66c0645 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py @@ -347,3 +347,7 @@ def test_prepare_params(self): self.assertEqual(converted["P2"], "test") self.assertEqual(types["P2"].code, TypeCode.STRING) + + def test_prepare_params_unsupported_type(self): + with self.assertRaises(cursor.errors.ProgrammingError): + self.cursor._prepare_params(123) # Int is not supported directly