Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions sqlit/domains/connections/providers/teradata/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any

from sqlit.domains.connections.providers.adapters.base import (
ColumnInfo,
CursorBasedAdapter,
IndexInfo,
SequenceInfo,
TableInfo,
TriggerInfo,
)
Expand Down Expand Up @@ -49,9 +49,27 @@ def supports_cross_database_queries(self) -> bool:
def supports_stored_procedures(self) -> bool:
return True

@property
def supports_sequences(self) -> bool:
return True
_TERADATA_SELECT_KEYWORDS = frozenset(
{"SELECT", "SEL", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "HELP"}
)

_LOCKING_RE = re.compile(
r"\bFOR\s+(?:ACCESS|READ|WRITE|EXCLUSIVE)(?:\s+NOWAIT)?\s+(\w+)",
re.IGNORECASE,
)

def classify_query(self, query: str) -> bool:
"""Classify Teradata queries, handling LOCKING/LOCK prefix and SEL abbreviation."""
query_upper = query.strip().upper()
first_word = query_upper.split()[0] if query_upper else ""

# Strip LOCKING/LOCK request modifier to find the actual statement keyword
if first_word in ("LOCKING", "LOCK"):
match = self._LOCKING_RE.search(query_upper)
if match:
first_word = match.group(1)

return first_word in self._TERADATA_SELECT_KEYWORDS

def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig:
"""Apply a default database for unqualified queries."""
Expand Down Expand Up @@ -91,8 +109,9 @@ def connect(self, config: ConnectionConfig) -> Any:
def get_databases(self, conn: Any) -> list[str]:
cursor = conn.cursor()
cursor.execute(
"lock row for access "
"SELECT DatabaseName FROM DBC.DatabasesV "
"WHERE DatabaseKind IN ('D', 'U') "
"WHERE dbkind IN ('D', 'U') "
"ORDER BY DatabaseName"
)
return [row[0] for row in cursor.fetchall()]
Expand All @@ -101,13 +120,15 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
cursor = conn.cursor()
if database:
cursor.execute(
"lock row for access "
"SELECT DatabaseName, TableName FROM DBC.TablesV "
"WHERE TableKind = 'T' AND DatabaseName = ? "
"ORDER BY TableName",
(database,),
)
else:
cursor.execute(
"lock row for access "
"SELECT DatabaseName, TableName FROM DBC.TablesV "
"WHERE TableKind = 'T' "
"ORDER BY DatabaseName, TableName"
Expand All @@ -118,13 +139,15 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
cursor = conn.cursor()
if database:
cursor.execute(
"lock row for access "
"SELECT DatabaseName, TableName FROM DBC.TablesV "
"WHERE TableKind = 'V' AND DatabaseName = ? "
"ORDER BY TableName",
(database,),
)
else:
cursor.execute(
"lock row for access "
"SELECT DatabaseName, TableName FROM DBC.TablesV "
"WHERE TableKind = 'V' "
"ORDER BY DatabaseName, TableName"
Expand All @@ -142,21 +165,21 @@ def get_columns(
pk_columns: set[str] = set()
try:
cursor.execute(
"SELECT ic.ColumnName "
"FROM DBC.IndexConstraintsV c "
"JOIN DBC.IndexColumnsV ic "
" ON c.DatabaseName = ic.DatabaseName "
" AND c.TableName = ic.TableName "
" AND c.IndexNumber = ic.IndexNumber "
"WHERE c.ConstraintType = 'P' "
"AND c.DatabaseName = ? AND c.TableName = ?",
"lock row for access "
"select "
"COLUMNNAME "
"from DBC.INDICESV "
"where DATABASENAME = ? "
"and TABLENAME = ? "
"and INDEXTYPE = 'P' ",
(schema_name, table),
)
pk_columns = {row[0] for row in cursor.fetchall()}
except Exception:
pk_columns = set()

cursor.execute(
"lock row for access "
"SELECT ColumnName, ColumnType FROM DBC.ColumnsV "
"WHERE DatabaseName = ? AND TableName = ? "
"ORDER BY ColumnId",
Expand All @@ -171,13 +194,15 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
cursor = conn.cursor()
if database:
cursor.execute(
"lock row for access "
"SELECT TableName FROM DBC.TablesV "
"WHERE TableKind = 'P' AND DatabaseName = ? "
"ORDER BY TableName",
(database,),
)
else:
cursor.execute(
"lock row for access "
"SELECT TableName FROM DBC.TablesV "
"WHERE TableKind = 'P' "
"ORDER BY TableName"
Expand All @@ -188,13 +213,15 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]
cursor = conn.cursor()
if database:
cursor.execute(
"lock row for access "
"SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV "
"WHERE DatabaseName = ? "
"ORDER BY TableName, IndexName",
(database,),
)
else:
cursor.execute(
"lock row for access "
"SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV "
"ORDER BY DatabaseName, TableName, IndexName"
)
Expand All @@ -207,33 +234,26 @@ def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerIn
cursor = conn.cursor()
if database:
cursor.execute(
"lock row for access "
"SELECT TriggerName, TableName FROM DBC.TriggersV "
"WHERE DatabaseName = ? "
"ORDER BY TableName, TriggerName",
(database,),
)
else:
cursor.execute(
"lock row for access "
"SELECT TriggerName, TableName FROM DBC.TriggersV "
"ORDER BY DatabaseName, TableName, TriggerName"
)
return [TriggerInfo(name=row[0], table_name=row[1]) for row in cursor.fetchall()]

def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
cursor = conn.cursor()
if database:
cursor.execute(
"SELECT SequenceName FROM DBC.SequencesV "
"WHERE DatabaseName = ? "
"ORDER BY SequenceName",
(database,),
)
else:
cursor.execute(
"SELECT SequenceName FROM DBC.SequencesV "
"ORDER BY DatabaseName, SequenceName"
)
return [SequenceInfo(name=row[0]) for row in cursor.fetchall()]
def get_sequences(self, conn: Any, database: str | None = None) -> list[str]:
"""Teradata does not support standalone sequences.

Auto-increment behaviour is provided by IDENTITY columns instead.
"""
return []

def quote_identifier(self, name: str) -> str:
escaped = name.replace('"', '""')
Expand All @@ -242,5 +262,5 @@ def quote_identifier(self, name: str) -> str:
def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str:
schema_name = schema or database
if schema_name:
return f'SELECT TOP {limit} * FROM "{schema_name}"."{table}"'
return f'SELECT TOP {limit} * FROM "{table}"'
return f'lock row for access select top {limit} * from "{schema_name}"."{table}"'
return f'lock row for access select top {limit} * from "{table}"'