diff --git a/sqlit/domains/connections/providers/adapters/base.py b/sqlit/domains/connections/providers/adapters/base.py index 036003a4..696c089e 100644 --- a/sqlit/domains/connections/providers/adapters/base.py +++ b/sqlit/domains/connections/providers/adapters/base.py @@ -450,6 +450,22 @@ def execute_non_query(self, conn: Any, query: str) -> int: pass +def _sanitize_cell(value: Any) -> Any: + """Convert non-picklable types to picklable equivalents. + + psycopg2 returns memoryview for bytea columns which cannot be pickled + through multiprocessing.Pipe, causing the process worker to hang. + """ + if isinstance(value, memoryview): + return bytes(value) + return value + + +def _sanitize_row(row: Any) -> tuple: + """Sanitize a database row so it can be pickled safely.""" + return tuple(_sanitize_cell(v) for v in row) + + class CursorBasedAdapter(DatabaseAdapter): """Base class for adapters using cursor-based execution (most SQL databases). @@ -471,7 +487,7 @@ def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> t else: rows = cursor.fetchall() truncated = False - return columns, [tuple(row) for row in rows], truncated + return columns, [_sanitize_row(row) for row in rows], truncated return [], [], False def execute_non_query(self, conn: Any, query: str) -> int: diff --git a/tests/unit/test_sanitize_row.py b/tests/unit/test_sanitize_row.py new file mode 100644 index 00000000..23acc4a5 --- /dev/null +++ b/tests/unit/test_sanitize_row.py @@ -0,0 +1,32 @@ +"""Unit tests for row sanitization in CursorBasedAdapter.""" + +from __future__ import annotations + +from sqlit.domains.connections.providers.adapters.base import _sanitize_cell, _sanitize_row + + +def test_sanitize_cell_converts_memoryview_to_bytes() -> None: + mv = memoryview(b"\xde\xad\xbe\xef") + result = _sanitize_cell(mv) + assert result == b"\xde\xad\xbe\xef" + assert isinstance(result, bytes) + + +def test_sanitize_cell_passes_through_other_types() -> None: + assert _sanitize_cell(42) == 42 + assert _sanitize_cell("hello") == "hello" + assert _sanitize_cell(None) is None + assert _sanitize_cell(3.14) == 3.14 + + +def test_sanitize_row_converts_memoryview_in_tuple() -> None: + row = (1, "row1", memoryview(b"\xca\xfe\xba\xbe")) + result = _sanitize_row(row) + assert result == (1, "row1", b"\xca\xfe\xba\xbe") + assert isinstance(result[2], bytes) + + +def test_sanitize_row_without_memoryview_unchanged() -> None: + row = (1, "text", None, 3.14) + result = _sanitize_row(row) + assert result == (1, "text", None, 3.14)