diff --git a/sqlit/domains/connections/providers/teradata/adapter.py b/sqlit/domains/connections/providers/teradata/adapter.py index 29fe886f..aa8e0201 100644 --- a/sqlit/domains/connections/providers/teradata/adapter.py +++ b/sqlit/domains/connections/providers/teradata/adapter.py @@ -2,13 +2,13 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Any from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, IndexInfo, - SequenceInfo, TableInfo, TriggerInfo, ) @@ -49,9 +49,27 @@ def supports_cross_database_queries(self) -> bool: def supports_stored_procedures(self) -> bool: return True - @property - def supports_sequences(self) -> bool: - return True + _TERADATA_SELECT_KEYWORDS = frozenset( + {"SELECT", "SEL", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "HELP"} + ) + + _LOCKING_RE = re.compile( + r"\bFOR\s+(?:ACCESS|READ|WRITE|EXCLUSIVE)(?:\s+NOWAIT)?\s+(\w+)", + re.IGNORECASE, + ) + + def classify_query(self, query: str) -> bool: + """Classify Teradata queries, handling LOCKING/LOCK prefix and SEL abbreviation.""" + query_upper = query.strip().upper() + first_word = query_upper.split()[0] if query_upper else "" + + # Strip LOCKING/LOCK request modifier to find the actual statement keyword + if first_word in ("LOCKING", "LOCK"): + match = self._LOCKING_RE.search(query_upper) + if match: + first_word = match.group(1) + + return first_word in self._TERADATA_SELECT_KEYWORDS def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig: """Apply a default database for unqualified queries.""" @@ -91,8 +109,9 @@ def connect(self, config: ConnectionConfig) -> Any: def get_databases(self, conn: Any) -> list[str]: cursor = conn.cursor() cursor.execute( + "lock row for access " "SELECT DatabaseName FROM DBC.DatabasesV " - "WHERE DatabaseKind IN ('D', 'U') " + "WHERE dbkind IN ('D', 'U') " "ORDER BY DatabaseName" ) return [row[0] for row in cursor.fetchall()] @@ -101,6 +120,7 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'T' AND DatabaseName = ? " "ORDER BY TableName", @@ -108,6 +128,7 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) else: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'T' " "ORDER BY DatabaseName, TableName" @@ -118,6 +139,7 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'V' AND DatabaseName = ? " "ORDER BY TableName", @@ -125,6 +147,7 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) else: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'V' " "ORDER BY DatabaseName, TableName" @@ -142,14 +165,13 @@ def get_columns( pk_columns: set[str] = set() try: cursor.execute( - "SELECT ic.ColumnName " - "FROM DBC.IndexConstraintsV c " - "JOIN DBC.IndexColumnsV ic " - " ON c.DatabaseName = ic.DatabaseName " - " AND c.TableName = ic.TableName " - " AND c.IndexNumber = ic.IndexNumber " - "WHERE c.ConstraintType = 'P' " - "AND c.DatabaseName = ? AND c.TableName = ?", + "lock row for access " + "select " + "COLUMNNAME " + "from DBC.INDICESV " + "where DATABASENAME = ? " + "and TABLENAME = ? " + "and INDEXTYPE = 'P' ", (schema_name, table), ) pk_columns = {row[0] for row in cursor.fetchall()} @@ -157,6 +179,7 @@ def get_columns( pk_columns = set() cursor.execute( + "lock row for access " "SELECT ColumnName, ColumnType FROM DBC.ColumnsV " "WHERE DatabaseName = ? AND TableName = ? " "ORDER BY ColumnId", @@ -171,6 +194,7 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT TableName FROM DBC.TablesV " "WHERE TableKind = 'P' AND DatabaseName = ? " "ORDER BY TableName", @@ -178,6 +202,7 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) else: cursor.execute( + "lock row for access " "SELECT TableName FROM DBC.TablesV " "WHERE TableKind = 'P' " "ORDER BY TableName" @@ -188,6 +213,7 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo] cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV " "WHERE DatabaseName = ? " "ORDER BY TableName, IndexName", @@ -195,6 +221,7 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo] ) else: cursor.execute( + "lock row for access " "SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV " "ORDER BY DatabaseName, TableName, IndexName" ) @@ -207,6 +234,7 @@ def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerIn cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT TriggerName, TableName FROM DBC.TriggersV " "WHERE DatabaseName = ? " "ORDER BY TableName, TriggerName", @@ -214,26 +242,18 @@ def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerIn ) else: cursor.execute( + "lock row for access " "SELECT TriggerName, TableName FROM DBC.TriggersV " "ORDER BY DatabaseName, TableName, TriggerName" ) return [TriggerInfo(name=row[0], table_name=row[1]) for row in cursor.fetchall()] - def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT SequenceName FROM DBC.SequencesV " - "WHERE DatabaseName = ? " - "ORDER BY SequenceName", - (database,), - ) - else: - cursor.execute( - "SELECT SequenceName FROM DBC.SequencesV " - "ORDER BY DatabaseName, SequenceName" - ) - return [SequenceInfo(name=row[0]) for row in cursor.fetchall()] + def get_sequences(self, conn: Any, database: str | None = None) -> list[str]: + """Teradata does not support standalone sequences. + + Auto-increment behaviour is provided by IDENTITY columns instead. + """ + return [] def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') @@ -242,5 +262,5 @@ def quote_identifier(self, name: str) -> str: def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: schema_name = schema or database if schema_name: - return f'SELECT TOP {limit} * FROM "{schema_name}"."{table}"' - return f'SELECT TOP {limit} * FROM "{table}"' + return f'lock row for access select top {limit} * from "{schema_name}"."{table}"' + return f'lock row for access select top {limit} * from "{table}"'