Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ all = [
"redshift-connector",
"pyathena>=3.22.0",
"adbc-driver-flightsql>=1.0.0",
"impyla>=0.18.0",
"osquery>=3.0.0",
"surrealdb>=0.3.0",
]
postgres = ["psycopg2-binary>=2.9.0"]
cockroachdb = ["psycopg2-binary>=2.9.0"]
Expand All @@ -81,6 +84,9 @@ firebird = ["firebirdsql>=1.3.5"]
snowflake = ["snowflake-connector-python>=3.7.0"]
athena = ["pyathena>=3.22.0"]
flight = ["adbc-driver-flightsql>=1.0.0"]
impala = ["impyla>=0.18.0"]
osquery = ["osquery>=3.0.0"]
surrealdb = ["surrealdb>=0.3.0"]
ssh = [
"sshtunnel>=0.4.0",
"paramiko>=2.0.0,<4.0.0",
Expand Down Expand Up @@ -241,7 +247,11 @@ module = [
"adbc_driver_flightsql",
"adbc_driver_flightsql.dbapi",
"adbc_driver_manager",
"textual_fastdatatable"
"textual_fastdatatable",
"impala",
"impala.dbapi",
"osquery",
"surrealdb"
]
ignore_missing_imports = true

Expand Down
6 changes: 6 additions & 0 deletions sqlit/domains/connections/domain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@ class DatabaseType(str, Enum):
FIREBIRD = "firebird"
FLIGHT = "flight"
HANA = "hana"
IMPALA = "impala"
MARIADB = "mariadb"
MOTHERDUCK = "motherduck"
MSSQL = "mssql"
MYSQL = "mysql"
ORACLE = "oracle"
ORACLE_LEGACY = "oracle_legacy"
OSQUERY = "osquery"
POSTGRESQL = "postgresql"
PRESTO = "presto"
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SQLITE = "sqlite"
SUPABASE = "supabase"
SURREALDB = "surrealdb"
TERADATA = "teradata"
TRINO = "trino"
TURSO = "turso"
Expand All @@ -52,6 +55,7 @@ class DatabaseType(str, Enum):
DatabaseType.BIGQUERY,
DatabaseType.TRINO,
DatabaseType.PRESTO,
DatabaseType.IMPALA,
DatabaseType.DUCKDB,
DatabaseType.MOTHERDUCK,
DatabaseType.REDSHIFT,
Expand All @@ -63,6 +67,8 @@ class DatabaseType(str, Enum):
DatabaseType.ATHENA,
DatabaseType.FIREBIRD,
DatabaseType.FLIGHT,
DatabaseType.OSQUERY,
DatabaseType.SURREALDB,
]


Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/impala/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Impala provider package."""
176 changes: 176 additions & 0 deletions sqlit/domains/connections/providers/impala/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Impala adapter using impyla."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from sqlit.domains.connections.providers.adapters.base import (
ColumnInfo,
CursorBasedAdapter,
IndexInfo,
SequenceInfo,
TableInfo,
TriggerInfo,
)
from sqlit.domains.connections.providers.registry import get_default_port

if TYPE_CHECKING:
from sqlit.domains.connections.domain.config import ConnectionConfig


class ImpalaAdapter(CursorBasedAdapter):
"""Adapter for Apache Impala using impyla."""

@property
def name(self) -> str:
return "Impala"

@property
def install_extra(self) -> str:
return "impala"

@property
def install_package(self) -> str:
return "impyla"

@property
def driver_import_names(self) -> tuple[str, ...]:
return ("impala.dbapi",)

@property
def supports_multiple_databases(self) -> bool:
return True

@property
def supports_cross_database_queries(self) -> bool:
return True

@property
def supports_stored_procedures(self) -> bool:
return False

@property
def supports_indexes(self) -> bool:
return False # Impala uses partitions, not traditional indexes

@property
def supports_triggers(self) -> bool:
return False

@property
def supports_sequences(self) -> bool:
return False

@property
def system_databases(self) -> frozenset[str]:
return frozenset({"_impala_builtins"})

@property
def default_schema(self) -> str:
return ""

def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig:
"""Apply a default database for unqualified queries."""
if not database:
return config
return config.with_endpoint(database=database)

def connect(self, config: ConnectionConfig) -> Any:
impala_dbapi = self._import_driver_module(
"impala.dbapi",
driver_name=self.name,
extra_name=self.install_extra,
package_name=self.install_package,
)

endpoint = config.tcp_endpoint
if endpoint is None:
raise ValueError("Impala connections require a TCP-style endpoint.")
port = int(endpoint.port or get_default_port("impala"))

auth_mechanism = str(config.get_option("auth_mechanism", "NOSASL"))
use_ssl = str(config.get_option("use_ssl", "false")).lower() == "true"

connect_args: dict[str, Any] = {
"host": endpoint.host,
"port": port,
"auth_mechanism": auth_mechanism,
"use_ssl": use_ssl,
}

if endpoint.database:
connect_args["database"] = endpoint.database

if endpoint.username:
connect_args["user"] = endpoint.username
if endpoint.password:
connect_args["password"] = endpoint.password

connect_args.update(config.extra_options)
return impala_dbapi.connect(**connect_args)

def get_databases(self, conn: Any) -> list[str]:
cursor = conn.cursor()
cursor.execute("SHOW DATABASES")
return [row[0] for row in cursor.fetchall()]

def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
cursor = conn.cursor()
if database:
cursor.execute(f"SHOW TABLES IN {self.quote_identifier(database)}")
else:
cursor.execute("SHOW TABLES")
return [("", row[0]) for row in cursor.fetchall()]

def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
# Impala doesn't distinguish views in SHOW TABLES by default
# We can query from information_schema if available
cursor = conn.cursor()
try:
if database:
cursor.execute(
f"SELECT table_name FROM {self.quote_identifier(database)}.information_schema.tables "
"WHERE table_type = 'VIEW' ORDER BY table_name"
)
else:
cursor.execute(
"SELECT table_name FROM information_schema.tables "
"WHERE table_type = 'VIEW' ORDER BY table_name"
)
return [("", row[0]) for row in cursor.fetchall()]
except Exception:
# information_schema might not be available
return []

def get_columns(
self, conn: Any, table: str, database: str | None = None, schema: str | None = None
) -> list[ColumnInfo]:
cursor = conn.cursor()
if database:
cursor.execute(f"DESCRIBE {self.quote_identifier(database)}.{self.quote_identifier(table)}")
else:
cursor.execute(f"DESCRIBE {self.quote_identifier(table)}")
return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()]

def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
return []

def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
return []

def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
return []

def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
return []

def quote_identifier(self, name: str) -> str:
escaped = name.replace("`", "``")
return f"`{escaped}`"

def build_select_query(
self, table: str, limit: int, database: str | None = None, schema: str | None = None
) -> str:
if database:
return f"SELECT * FROM `{database}`.`{table}` LIMIT {limit}"
return f"SELECT * FROM `{table}` LIMIT {limit}"
29 changes: 29 additions & 0 deletions sqlit/domains/connections/providers/impala/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Provider registration for Impala."""

from sqlit.domains.connections.providers.adapter_provider import build_adapter_provider
from sqlit.domains.connections.providers.catalog import register_provider
from sqlit.domains.connections.providers.impala.schema import SCHEMA
from sqlit.domains.connections.providers.model import DatabaseProvider, ProviderSpec


def _provider_factory(spec: ProviderSpec) -> DatabaseProvider:
from sqlit.domains.connections.providers.impala.adapter import ImpalaAdapter

return build_adapter_provider(spec, SCHEMA, ImpalaAdapter())


SPEC = ProviderSpec(
db_type="impala",
display_name="Impala",
schema_path=("sqlit.domains.connections.providers.impala.schema", "SCHEMA"),
supports_ssh=True,
is_file_based=False,
has_advanced_auth=True, # Kerberos support
default_port="21050",
requires_auth=False,
badge_label="Impala",
url_schemes=("impala",),
provider_factory=_provider_factory,
)

register_provider(SPEC)
61 changes: 61 additions & 0 deletions sqlit/domains/connections/providers/impala/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Connection schema for Impala."""

from sqlit.domains.connections.providers.schema_helpers import (
SSH_FIELDS,
ConnectionSchema,
FieldType,
SchemaField,
SelectOption,
_password_field,
_port_field,
_server_field,
_username_field,
)


def _get_auth_mechanism_options() -> tuple[SelectOption, ...]:
return (
SelectOption("NOSASL", "No Auth"),
SelectOption("PLAIN", "PLAIN (LDAP)"),
SelectOption("GSSAPI", "Kerberos (GSSAPI)"),
)


SCHEMA = ConnectionSchema(
db_type="impala",
display_name="Impala",
fields=(
_server_field(),
_port_field("21050"),
SchemaField(
name="database",
label="Database",
placeholder="default",
required=False,
),
_username_field(required=False),
_password_field(),
SchemaField(
name="auth_mechanism",
label="Auth Mechanism",
field_type=FieldType.SELECT,
options=_get_auth_mechanism_options(),
default="NOSASL",
advanced=True,
),
SchemaField(
name="use_ssl",
label="Use SSL",
field_type=FieldType.SELECT,
options=(
SelectOption("false", "No"),
SelectOption("true", "Yes"),
),
default="false",
advanced=True,
),
)
+ SSH_FIELDS,
default_port="21050",
requires_auth=False,
)
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/osquery/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""osquery provider package."""
Loading