diff --git a/pyproject.toml b/pyproject.toml index 3f21e3da..e2ece284 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ all = [ "redshift-connector", "pyathena>=3.22.0", "adbc-driver-flightsql>=1.0.0", + "impyla>=0.18.0", + "osquery>=3.0.0", + "surrealdb>=0.3.0", ] postgres = ["psycopg2-binary>=2.9.0"] cockroachdb = ["psycopg2-binary>=2.9.0"] @@ -81,6 +84,9 @@ firebird = ["firebirdsql>=1.3.5"] snowflake = ["snowflake-connector-python>=3.7.0"] athena = ["pyathena>=3.22.0"] flight = ["adbc-driver-flightsql>=1.0.0"] +impala = ["impyla>=0.18.0"] +osquery = ["osquery>=3.0.0"] +surrealdb = ["surrealdb>=0.3.0"] ssh = [ "sshtunnel>=0.4.0", "paramiko>=2.0.0,<4.0.0", @@ -241,7 +247,11 @@ module = [ "adbc_driver_flightsql", "adbc_driver_flightsql.dbapi", "adbc_driver_manager", - "textual_fastdatatable" + "textual_fastdatatable", + "impala", + "impala.dbapi", + "osquery", + "surrealdb" ] ignore_missing_imports = true diff --git a/sqlit/domains/connections/domain/config.py b/sqlit/domains/connections/domain/config.py index c2eba8f3..5038287b 100644 --- a/sqlit/domains/connections/domain/config.py +++ b/sqlit/domains/connections/domain/config.py @@ -19,18 +19,21 @@ class DatabaseType(str, Enum): FIREBIRD = "firebird" FLIGHT = "flight" HANA = "hana" + IMPALA = "impala" MARIADB = "mariadb" MOTHERDUCK = "motherduck" MSSQL = "mssql" MYSQL = "mysql" ORACLE = "oracle" ORACLE_LEGACY = "oracle_legacy" + OSQUERY = "osquery" POSTGRESQL = "postgresql" PRESTO = "presto" REDSHIFT = "redshift" SNOWFLAKE = "snowflake" SQLITE = "sqlite" SUPABASE = "supabase" + SURREALDB = "surrealdb" TERADATA = "teradata" TRINO = "trino" TURSO = "turso" @@ -52,6 +55,7 @@ class DatabaseType(str, Enum): DatabaseType.BIGQUERY, DatabaseType.TRINO, DatabaseType.PRESTO, + DatabaseType.IMPALA, DatabaseType.DUCKDB, DatabaseType.MOTHERDUCK, DatabaseType.REDSHIFT, @@ -63,6 +67,8 @@ class DatabaseType(str, Enum): DatabaseType.ATHENA, DatabaseType.FIREBIRD, DatabaseType.FLIGHT, + DatabaseType.OSQUERY, + 
class ImpalaAdapter(CursorBasedAdapter):
    """Adapter for Apache Impala using the impyla DB-API driver.

    Tables and views are reported as ``("", name)`` pairs: the first element
    (the schema slot) is always empty because this adapter addresses objects
    as ``database.table``.
    """

    @property
    def name(self) -> str:
        """Display name of the backend."""
        return "Impala"

    @property
    def install_extra(self) -> str:
        """Name of the optional-dependency extra that installs the driver."""
        return "impala"

    @property
    def install_package(self) -> str:
        """Name of the package that provides the driver."""
        return "impyla"

    @property
    def driver_import_names(self) -> tuple[str, ...]:
        """Module names that must be importable for this adapter to work."""
        return ("impala.dbapi",)

    @property
    def supports_multiple_databases(self) -> bool:
        return True

    @property
    def supports_cross_database_queries(self) -> bool:
        return True

    @property
    def supports_stored_procedures(self) -> bool:
        return False

    @property
    def supports_indexes(self) -> bool:
        return False  # Impala uses partitions, not traditional indexes

    @property
    def supports_triggers(self) -> bool:
        return False

    @property
    def supports_sequences(self) -> bool:
        return False

    @property
    def system_databases(self) -> frozenset[str]:
        """Databases reported as system databases."""
        return frozenset({"_impala_builtins"})

    @property
    def default_schema(self) -> str:
        return ""

    def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig:
        """Return *config* with *database* as the default for unqualified queries.

        An empty *database* leaves *config* untouched.
        """
        if not database:
            return config
        return config.with_endpoint(database=database)

    def connect(self, config: ConnectionConfig) -> Any:
        """Open an impyla connection described by *config*.

        Raises:
            ValueError: if *config* has no TCP endpoint.
        """
        impala_dbapi = self._import_driver_module(
            "impala.dbapi",
            driver_name=self.name,
            extra_name=self.install_extra,
            package_name=self.install_package,
        )

        endpoint = config.tcp_endpoint
        if endpoint is None:
            raise ValueError("Impala connections require a TCP-style endpoint.")
        port = int(endpoint.port or get_default_port("impala"))

        auth_mechanism = str(config.get_option("auth_mechanism", "NOSASL"))
        # The option is stored as the text "true"/"false" (a select field).
        use_ssl = str(config.get_option("use_ssl", "false")).lower() == "true"

        connect_args: dict[str, Any] = {
            "host": endpoint.host,
            "port": port,
            "auth_mechanism": auth_mechanism,
            "use_ssl": use_ssl,
        }

        # Optional settings are only forwarded when they carry a value.
        if endpoint.database:
            connect_args["database"] = endpoint.database
        if endpoint.username:
            connect_args["user"] = endpoint.username
        if endpoint.password:
            connect_args["password"] = endpoint.password

        # Extra options are applied last, so they override the values above.
        connect_args.update(config.extra_options)
        return impala_dbapi.connect(**connect_args)

    @staticmethod
    def _fetch_all(conn: Any, sql: str) -> list[Any]:
        """Run *sql* on a fresh cursor, return every row, and close the cursor."""
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            # Metadata cursors are short-lived; release them as soon as the
            # rows have been read instead of leaving them to the connection.
            cursor.close()

    def get_databases(self, conn: Any) -> list[str]:
        """Return the names listed by ``SHOW DATABASES``."""
        return [row[0] for row in self._fetch_all(conn, "SHOW DATABASES")]

    def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Return ``("", name)`` for each entry of ``SHOW TABLES``.

        When *database* is given the listing is restricted to it; otherwise the
        connection's current database is used.
        """
        sql = "SHOW TABLES"
        if database:
            sql = f"SHOW TABLES IN {self.quote_identifier(database)}"
        return [("", row[0]) for row in self._fetch_all(conn, sql)]

    def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Return views found through ``information_schema.tables``.

        Best effort: any failure of the lookup (for example when
        ``information_schema`` is not available) yields an empty list.
        """
        source = "information_schema.tables"
        if database:
            source = f"{self.quote_identifier(database)}.{source}"
        sql = f"SELECT table_name FROM {source} WHERE table_type = 'VIEW' ORDER BY table_name"
        try:
            return [("", row[0]) for row in self._fetch_all(conn, sql)]
        except Exception:
            # information_schema might not be available
            return []

    def get_columns(
        self, conn: Any, table: str, database: str | None = None, schema: str | None = None
    ) -> list[ColumnInfo]:
        """Return column name and type from ``DESCRIBE`` (first two result columns).

        *schema* is accepted for interface compatibility and is not used.
        """
        target = self.quote_identifier(table)
        if database:
            target = f"{self.quote_identifier(database)}.{target}"
        rows = self._fetch_all(conn, f"DESCRIBE {target}")
        return [ColumnInfo(name=row[0], data_type=row[1]) for row in rows]

    def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
        return []

    def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
        return []

    def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
        return []

    def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
        return []

    def quote_identifier(self, name: str) -> str:
        """Wrap *name* in backticks, doubling any backtick it contains."""
        escaped = name.replace("`", "``")
        return f"`{escaped}`"

    def build_select_query(
        self, table: str, limit: int, database: str | None = None, schema: str | None = None
    ) -> str:
        """Build ``SELECT * FROM [db.]table LIMIT n``.

        Names go through ``quote_identifier`` so that a backtick inside a name
        cannot terminate the quoted identifier early. *schema* is not used.
        """
        target = self.quote_identifier(table)
        if database:
            target = f"{self.quote_identifier(database)}.{target}"
        return f"SELECT * FROM {target} LIMIT {limit}"
+from sqlit.domains.connections.providers.catalog import register_provider +from sqlit.domains.connections.providers.impala.schema import SCHEMA +from sqlit.domains.connections.providers.model import DatabaseProvider, ProviderSpec + + +def _provider_factory(spec: ProviderSpec) -> DatabaseProvider: + from sqlit.domains.connections.providers.impala.adapter import ImpalaAdapter + + return build_adapter_provider(spec, SCHEMA, ImpalaAdapter()) + + +SPEC = ProviderSpec( + db_type="impala", + display_name="Impala", + schema_path=("sqlit.domains.connections.providers.impala.schema", "SCHEMA"), + supports_ssh=True, + is_file_based=False, + has_advanced_auth=True, # Kerberos support + default_port="21050", + requires_auth=False, + badge_label="Impala", + url_schemes=("impala",), + provider_factory=_provider_factory, +) + +register_provider(SPEC) diff --git a/sqlit/domains/connections/providers/impala/schema.py b/sqlit/domains/connections/providers/impala/schema.py new file mode 100644 index 00000000..e0b65bf8 --- /dev/null +++ b/sqlit/domains/connections/providers/impala/schema.py @@ -0,0 +1,61 @@ +"""Connection schema for Impala.""" + +from sqlit.domains.connections.providers.schema_helpers import ( + SSH_FIELDS, + ConnectionSchema, + FieldType, + SchemaField, + SelectOption, + _password_field, + _port_field, + _server_field, + _username_field, +) + + +def _get_auth_mechanism_options() -> tuple[SelectOption, ...]: + return ( + SelectOption("NOSASL", "No Auth"), + SelectOption("PLAIN", "PLAIN (LDAP)"), + SelectOption("GSSAPI", "Kerberos (GSSAPI)"), + ) + + +SCHEMA = ConnectionSchema( + db_type="impala", + display_name="Impala", + fields=( + _server_field(), + _port_field("21050"), + SchemaField( + name="database", + label="Database", + placeholder="default", + required=False, + ), + _username_field(required=False), + _password_field(), + SchemaField( + name="auth_mechanism", + label="Auth Mechanism", + field_type=FieldType.SELECT, + options=_get_auth_mechanism_options(), + 
class OsqueryConnection:
    """Uniform wrapper around an osquery instance.

    ``instance`` is either a spawned embedded instance (``is_spawned=True``)
    or a client attached to an existing extension socket. The query client is
    resolved lazily on first access and cached.
    """

    def __init__(self, instance: Any, is_spawned: bool = False) -> None:
        self.instance = instance
        self.is_spawned = is_spawned
        self._client: Any = None

    @property
    def client(self) -> Any:
        """Return the object whose ``query()`` method runs SQL.

        A spawned instance exposes it as the ``client`` attribute. An object
        without that attribute is asked for ``extension_client()`` instead, so
        socket-mode connections do not fail with ``AttributeError``.
        """
        if self._client is None:
            instance = self.instance
            if hasattr(instance, "client"):
                self._client = instance.client
            else:
                self._client = instance.extension_client()
        return self._client

    def close(self) -> None:
        """Close a socket-mode instance; a spawned instance is left alone."""
        if self.is_spawned:
            # SpawnInstance doesn't need explicit close
            return
        self.instance.close()
class OsqueryAdapter(DatabaseAdapter):
    """Adapter for osquery using osquery-python.

    osquery is not a traditional database - it queries system information
    through virtual tables. Supports both spawning an embedded instance
    and connecting to a running osqueryd daemon via socket.
    """

    @property
    def name(self) -> str:
        return "osquery"

    @property
    def install_extra(self) -> str:
        return "osquery"

    @property
    def install_package(self) -> str:
        return "osquery"

    @property
    def driver_import_names(self) -> tuple[str, ...]:
        return ("osquery",)

    @property
    def supports_multiple_databases(self) -> bool:
        return False

    @property
    def supports_cross_database_queries(self) -> bool:
        return False

    @property
    def supports_stored_procedures(self) -> bool:
        return False

    @property
    def supports_indexes(self) -> bool:
        return False

    @property
    def supports_triggers(self) -> bool:
        return False

    @property
    def supports_sequences(self) -> bool:
        return False

    @property
    def supports_process_worker(self) -> bool:
        # osquery spawned instances may not work well across process boundaries
        return False

    @property
    def default_schema(self) -> str:
        return ""

    @property
    def test_query(self) -> str:
        return "SELECT 1 AS test"

    @staticmethod
    def _checked_query(conn: Any, sql: str, error_prefix: str = "osquery error") -> Any:
        """Run *sql* and return the result, raising when its status code is non-zero.

        Raises:
            RuntimeError: ``"<error_prefix>: <status message>"``. It is a
                subclass of ``Exception``, so handlers written against the
                broader type keep working.
        """
        result = conn.client.query(sql)
        if result.status.code != 0:
            raise RuntimeError(f"{error_prefix}: {result.status.message}")
        return result

    def execute_test_query(self, conn: Any) -> None:
        """Execute a simple query to verify the connection works."""
        self._checked_query(conn, self.test_query, "osquery test failed")

    def _get_default_socket_path(self) -> str:
        """Get the default osquery socket path for the current platform."""
        if platform.system() == "Windows":
            return r"\\.\pipe\osquery.em"
        return "/var/osquery/osquery.em"

    def connect(self, config: ConnectionConfig) -> OsqueryConnection:
        """Open an osquery connection.

        The ``connection_mode`` option selects the behaviour: ``"socket"``
        attaches to an extension socket (``socket_path`` option, or the
        platform default when unset); any other value spawns an embedded
        instance.
        """
        osquery_module = self._import_driver_module(
            "osquery",
            driver_name=self.name,
            extra_name=self.install_extra,
            package_name=self.install_package,
        )

        connection_mode = str(config.get_option("connection_mode", "spawn"))

        if connection_mode == "socket":
            socket_path = config.get_option("socket_path") or self._get_default_socket_path()
            instance = osquery_module.ExtensionClient(socket_path)
            instance.open()
            return OsqueryConnection(instance, is_spawned=False)

        # Spawn embedded instance
        instance = osquery_module.SpawnInstance()
        instance.open()
        return OsqueryConnection(instance, is_spawned=True)

    def disconnect(self, conn: Any) -> None:
        """Close *conn* when it is an ``OsqueryConnection``; ignore anything else."""
        if isinstance(conn, OsqueryConnection):
            conn.close()

    def get_databases(self, conn: Any) -> list[str]:
        # osquery has a single virtual database
        return ["main"]

    def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Get list of osquery virtual tables.

        Returns ``("", name)`` pairs; a failed registry query yields ``[]``.
        """
        result = conn.client.query(
            "SELECT name FROM osquery_registry WHERE registry = 'table' ORDER BY name"
        )
        if result.status.code != 0:
            return []
        return [("", row.get("name", "")) for row in result.response if row.get("name")]

    def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        # osquery doesn't have views
        return []

    def get_columns(
        self, conn: Any, table: str, database: str | None = None, schema: str | None = None
    ) -> list[ColumnInfo]:
        """Get column information for an osquery table using PRAGMA.

        The table name is passed through ``quote_identifier`` so that a name
        containing punctuation cannot alter the statement. A failed query
        yields ``[]``; rows without a name are skipped and a missing type
        defaults to ``"TEXT"``.
        """
        result = conn.client.query(f"PRAGMA table_info({self.quote_identifier(table)})")
        if result.status.code != 0:
            return []
        return [
            ColumnInfo(name=row["name"], data_type=row.get("type", "TEXT"))
            for row in result.response
            if row.get("name")
        ]

    def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
        return []

    def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
        return []

    def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
        return []

    def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
        return []

    def quote_identifier(self, name: str) -> str:
        """Wrap *name* in double quotes, doubling any double quote it contains."""
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def build_select_query(
        self, table: str, limit: int, database: str | None = None, schema: str | None = None
    ) -> str:
        """Build ``SELECT * FROM "table" LIMIT n`` with the name properly escaped.

        *database* and *schema* are not used.
        """
        return f"SELECT * FROM {self.quote_identifier(table)} LIMIT {limit}"

    def execute_query(
        self, conn: Any, query: str, max_rows: int | None = None
    ) -> tuple[list[str], list[tuple], bool]:
        """Execute a query and return (columns, rows, truncated).

        Column names come from the keys of the first response row. An empty
        response returns ``([], [], False)``.

        Raises:
            RuntimeError: if osquery reports a non-zero status.
        """
        result = self._checked_query(conn, query)

        response = result.response
        if not response:
            return [], [], False

        # Get columns from first row
        columns = list(response[0].keys())

        truncated = max_rows is not None and len(response) > max_rows
        visible = response[:max_rows] if truncated else response
        rows = [tuple(row.get(col, None) for col in columns) for row in visible]
        return columns, rows, truncated

    def execute_non_query(self, conn: Any, query: str) -> int:
        """Execute a non-query statement.

        osquery is read-only, so this just executes the query for compatibility
        and always reports 0 affected rows.

        Raises:
            RuntimeError: if osquery reports a non-zero status.
        """
        self._checked_query(conn, query)
        return 0

    def classify_query(self, query: str) -> bool:
        """osquery queries are always SELECT-like (read-only)."""
        return True
class SurrealDBAdapter(DatabaseAdapter):
    """Adapter for SurrealDB using the official Python SDK.

    SurrealDB is a multi-model database that uses SurrealQL,
    a query language similar to SQL but with some differences.

    NOTE(review): the metadata helpers assume ``conn.query()`` returns a list
    whose first element is the payload of the first statement - confirm
    against the SDK version that is pinned.
    """

    @property
    def name(self) -> str:
        return "SurrealDB"

    @property
    def install_extra(self) -> str:
        return "surrealdb"

    @property
    def install_package(self) -> str:
        return "surrealdb"

    @property
    def driver_import_names(self) -> tuple[str, ...]:
        return ("surrealdb",)

    @property
    def supports_multiple_databases(self) -> bool:
        return True  # Namespace/database hierarchy

    @property
    def supports_cross_database_queries(self) -> bool:
        return False  # Must use() a specific database

    @property
    def supports_stored_procedures(self) -> bool:
        return False

    @property
    def supports_indexes(self) -> bool:
        return True

    @property
    def supports_triggers(self) -> bool:
        return False

    @property
    def supports_sequences(self) -> bool:
        return False

    @property
    def supports_process_worker(self) -> bool:
        # WebSocket connections may not work well across process boundaries
        return False

    @property
    def default_schema(self) -> str:
        return ""

    @property
    def test_query(self) -> str:
        return "RETURN 1"

    def connect(self, config: ConnectionConfig) -> Any:
        """Open a WebSocket connection, sign in and select namespace/database.

        Sign-in happens only when both username and password are set. The
        namespace comes from the ``namespace`` option and the database from
        the endpoint or the ``database`` option; both default to ``"test"``.
        If sign-in or selection fails the connection is closed before the
        error propagates.

        Raises:
            ValueError: if *config* has no TCP endpoint.
        """
        surrealdb_module = self._import_driver_module(
            "surrealdb",
            driver_name=self.name,
            extra_name=self.install_extra,
            package_name=self.install_package,
        )

        endpoint = config.tcp_endpoint
        if endpoint is None:
            raise ValueError("SurrealDB connections require a TCP-style endpoint.")
        port = int(endpoint.port or get_default_port("surrealdb"))

        # Build WebSocket URL
        use_ssl = str(config.get_option("use_ssl", "false")).lower() == "true"
        scheme = "wss" if use_ssl else "ws"
        url = f"{scheme}://{endpoint.host}:{port}/rpc"

        # Create sync connection
        db = surrealdb_module.Surreal(url)
        db.connect()

        try:
            # Sign in if credentials provided
            if endpoint.username and endpoint.password:
                db.signin({"user": endpoint.username, "pass": endpoint.password})

            # Select namespace and database
            namespace = config.get_option("namespace", "test")
            database = endpoint.database or config.get_option("database", "test")
            db.use(namespace, database)
        except Exception:
            # Do not leave a half-initialised connection open.
            try:
                self.disconnect(db)
            except Exception:
                pass  # a cleanup failure must not mask the original error
            raise

        return db

    def disconnect(self, conn: Any) -> None:
        """Close *conn* when it offers a ``close`` method."""
        if hasattr(conn, "close"):
            conn.close()

    def execute_test_query(self, conn: Any) -> None:
        """Execute a simple query to verify the connection works."""
        result = conn.query(self.test_query)
        if not result:
            raise Exception("SurrealDB test query returned no result")

    @staticmethod
    def _first_dict(result: Any) -> dict[str, Any] | None:
        """Return ``result[0]`` when *result* is a list starting with a non-empty dict."""
        if result and isinstance(result, list):
            head = result[0]
            if head and isinstance(head, dict):
                return head
        return None

    def get_databases(self, conn: Any) -> list[str]:
        """Get list of databases in the current namespace.

        Best effort: any failure yields an empty list.
        """
        try:
            info = self._first_dict(conn.query("INFO FOR NS"))
            if info is not None and "databases" in info:
                return list(info["databases"].keys())
        except Exception:
            pass  # metadata browsing is optional; fall through to "nothing found"
        return []

    def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Get list of tables in the current database, sorted by name.

        Best effort: any failure yields an empty list. *database* is not used.
        """
        try:
            info = self._first_dict(conn.query("INFO FOR DB"))
            if info is not None and "tables" in info:
                return [("", table_name) for table_name in sorted(info["tables"].keys())]
        except Exception:
            pass  # metadata browsing is optional; fall through to "nothing found"
        return []

    def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        # SurrealDB doesn't have traditional views
        return []

    def get_columns(
        self, conn: Any, table: str, database: str | None = None, schema: str | None = None
    ) -> list[ColumnInfo]:
        """Get column information for a table.

        Fields declared on the table (``INFO FOR TABLE``) are used when
        present. Otherwise one record is sampled and its keys become the
        columns, with ``id`` listed first. Any failure returns whatever was
        collected up to that point.
        """
        columns: list[ColumnInfo] = []
        target = self.quote_identifier(table)

        try:
            # First try to get schema info
            info = self._first_dict(conn.query(f"INFO FOR TABLE {target}"))
            if info is not None and "fields" in info:
                for field_name, field_def in info["fields"].items():
                    # field_def might contain type info
                    data_type = "any"
                    if isinstance(field_def, dict) and "type" in field_def:
                        data_type = str(field_def["type"])
                    elif isinstance(field_def, str):
                        # Try to extract type from definition string
                        data_type = field_def.split()[0] if field_def else "any"
                    columns.append(ColumnInfo(name=field_name, data_type=data_type))

            # If no schema fields, sample data to infer columns
            if not columns:
                sample = conn.query(f"SELECT * FROM {target} LIMIT 1")
                if sample and isinstance(sample, list) and sample[0]:
                    first_row = sample[0]
                    # execute_query treats the result as one entry per
                    # statement; accept that shape here too by unwrapping
                    # the statement's row list.
                    if isinstance(first_row, list):
                        first_row = first_row[0]
                    if isinstance(first_row, dict):
                        # id is always present; list it first
                        columns.append(ColumnInfo(name="id", data_type="record"))
                        for key, value in first_row.items():
                            if key == "id":
                                continue
                            data_type = type(value).__name__ if value is not None else "any"
                            columns.append(ColumnInfo(name=key, data_type=data_type))
        except Exception:
            pass  # best effort: report what could be determined

        return columns

    def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
        return []

    def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
        """Get list of indexes across all tables.

        Issues one ``INFO FOR TABLE`` per table. An index counts as unique when
        its definition text contains ``UNIQUE``. Any failure returns the
        indexes collected so far.
        """
        indexes: list[IndexInfo] = []
        try:
            db_info = self._first_dict(conn.query("INFO FOR DB"))
            if db_info is None or "tables" not in db_info:
                return indexes
            for table_name in db_info["tables"].keys():
                table_info = self._first_dict(
                    conn.query(f"INFO FOR TABLE {self.quote_identifier(table_name)}")
                )
                if table_info is None or "indexes" not in table_info:
                    continue
                for idx_name, idx_def in table_info["indexes"].items():
                    is_unique = "UNIQUE" in str(idx_def).upper() if idx_def else False
                    indexes.append(
                        IndexInfo(name=idx_name, table_name=table_name, is_unique=is_unique)
                    )
        except Exception:
            pass  # best effort: report what could be determined
        return indexes

    def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
        return []

    def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
        return []

    def quote_identifier(self, name: str) -> str:
        """Return *name* unchanged when it is a plain identifier, else backtick-quote it.

        "Plain" means ASCII letters, digits and underscores, not starting with
        a digit. Everything else - including names containing a backtick - is
        wrapped in backticks with embedded backticks doubled. An empty name is
        returned as is.
        """
        if not name or (name.isascii() and name.isidentifier()):
            return name
        escaped = name.replace("`", "``")
        return f"`{escaped}`"

    def build_select_query(
        self, table: str, limit: int, database: str | None = None, schema: str | None = None
    ) -> str:
        """Build ``SELECT * FROM table LIMIT n``; *database* and *schema* are not used."""
        return f"SELECT * FROM {self.quote_identifier(table)} LIMIT {limit}"

    def execute_query(
        self, conn: Any, query: str, max_rows: int | None = None
    ) -> tuple[list[str], list[tuple], bool]:
        """Execute a query and return (columns, rows, truncated).

        Only the first element of a list-shaped result is rendered:

        * a dict becomes a single row keyed by its fields;
        * a list of dicts becomes rows, with columns taken from the first dict;
        * a list of other values becomes a single ``value`` column;
        * any other value becomes a single ``result`` cell.

        *max_rows* limits list-shaped data; ``truncated`` is True when rows
        were cut off.
        """
        result = conn.query(query)
        if not result:
            return [], [], False

        # SurrealDB returns a list of results (one per statement)
        # For a single query, we take the first result
        data = result[0] if isinstance(result, list) else result

        # Handle single dict
        if isinstance(data, dict):
            return list(data.keys()), [tuple(data.values())], False

        # Handle single value returns (like RETURN 1)
        if not isinstance(data, list):
            return ["result"], [(data,)], False

        # Handle empty results
        if not data:
            return [], [], False

        # Compare against None explicitly so that max_rows=0 truncates to no
        # rows instead of being mistaken for "no limit".
        truncated = max_rows is not None and len(data) > max_rows
        visible = data[:max_rows] if truncated else data

        if isinstance(data[0], dict):
            columns = list(data[0].keys())
            rows = [tuple(row.get(col) for col in columns) for row in visible]
            return columns, rows, truncated

        # List of non-dict values
        return ["value"], [(value,) for value in visible], truncated

    def execute_non_query(self, conn: Any, query: str) -> int:
        """Execute a non-query statement.

        Returns the length of the first statement's result when it is a list,
        1 for any other non-``None`` result, and 0 when the driver returned
        ``None``.
        """
        result = conn.query(query)
        # SurrealDB doesn't return row counts in the traditional sense
        if result is None:
            return 0
        if isinstance(result, list) and result and isinstance(result[0], list):
            return len(result[0])
        # Return 1 if operation succeeded
        return 1

    def classify_query(self, query: str) -> bool:
        """Return True if the query is expected to return rows."""
        words = query.split()
        query_type = words[0].upper() if words else ""
        # SurrealQL query types that return data
        return query_type in {
            "SELECT", "RETURN", "INFO", "SHOW", "LIVE",
            "CREATE", "INSERT", "UPDATE", "UPSERT", "DELETE",  # These also return the affected records
        }
schema_path=("sqlit.domains.connections.providers.surrealdb.schema", "SCHEMA"), + supports_ssh=True, + is_file_based=False, + has_advanced_auth=False, + default_port="8000", + requires_auth=True, + badge_label="Surreal", + url_schemes=("surrealdb", "surreal"), + provider_factory=_provider_factory, +) + +register_provider(SPEC) diff --git a/sqlit/domains/connections/providers/surrealdb/schema.py b/sqlit/domains/connections/providers/surrealdb/schema.py new file mode 100644 index 00000000..4f4c84cb --- /dev/null +++ b/sqlit/domains/connections/providers/surrealdb/schema.py @@ -0,0 +1,50 @@ +"""Connection schema for SurrealDB.""" + +from sqlit.domains.connections.providers.schema_helpers import ( + SSH_FIELDS, + ConnectionSchema, + FieldType, + SchemaField, + SelectOption, + _password_field, + _port_field, + _server_field, + _username_field, +) + +SCHEMA = ConnectionSchema( + db_type="surrealdb", + display_name="SurrealDB", + fields=( + _server_field(), + _port_field("8000"), + SchemaField( + name="namespace", + label="Namespace", + placeholder="test", + required=True, + ), + SchemaField( + name="database", + label="Database", + placeholder="test", + required=True, + ), + _username_field(), + _password_field(), + SchemaField( + name="use_ssl", + label="Use SSL", + field_type=FieldType.SELECT, + options=( + SelectOption("false", "No"), + SelectOption("true", "Yes"), + ), + default="false", + advanced=True, + ), + ) + + SSH_FIELDS, + default_port="8000", + requires_auth=True, +) diff --git a/tests/fixtures/data.duckdb b/tests/fixtures/data.duckdb index 2230283a..7a2f024a 100644 Binary files a/tests/fixtures/data.duckdb and b/tests/fixtures/data.duckdb differ diff --git a/tests/fixtures/data.duckdb.wal b/tests/fixtures/data.duckdb.wal deleted file mode 100644 index 2caf51b9..00000000 Binary files a/tests/fixtures/data.duckdb.wal and /dev/null differ