From c2f818eec25c5df1635bf168455a4673447bb260 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Wed, 17 Dec 2025 12:37:07 -0800 Subject: [PATCH] Feat(trino): Introduce custom timestamp type mapping --- sqlmesh/core/config/connection.py | 35 ++++- sqlmesh/core/engine_adapter/trino.py | 49 ++++++- tests/core/engine_adapter/test_trino.py | 187 ++++++++++++++++++++++++ tests/core/test_connection_config.py | 58 ++++++++ 4 files changed, 322 insertions(+), 7 deletions(-) diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index d89d896897..638f0c28c8 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -17,6 +17,7 @@ from packaging import version from sqlglot import exp from sqlglot.helper import subclasses +from sqlglot.errors import ParseError from sqlmesh.core import engine_adapter from sqlmesh.core.config.base import BaseConfig @@ -1890,6 +1891,7 @@ class TrinoConnectionConfig(ConnectionConfig): # SQLMesh options schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None + timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None concurrent_tasks: int = 4 register_comments: bool = True pre_ping: t.Literal[False] = False @@ -1914,6 +1916,34 @@ def _validate_regex_keys( ) return compiled + @field_validator("timestamp_mapping", mode="before") + @classmethod + def _validate_timestamp_mapping( + cls, value: t.Optional[dict[str, str]] + ) -> t.Optional[dict[exp.DataType, exp.DataType]]: + if value is None: + return value + + result: dict[exp.DataType, exp.DataType] = {} + for source_type, target_type in value.items(): + try: + source_datatype = exp.DataType.build(source_type) + except ParseError: + raise ConfigError( + f"Invalid SQL type string in timestamp_mapping: " + f"'{source_type}' is not a valid SQL data type." + ) + try: + target_datatype = exp.DataType.build(target_type) + except ParseError: + raise ConfigError( + f"Invalid SQL type string in timestamp_mapping: " + f"'{target_type}' is not a valid SQL data type." + ) + result[source_datatype] = target_datatype + + return result + @model_validator(mode="after") def _root_validator(self) -> Self: port = self.port @@ -2016,7 +2046,10 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: @property def _extra_engine_config(self) -> t.Dict[str, t.Any]: - return {"schema_location_mapping": self.schema_location_mapping} + return { + "schema_location_mapping": self.schema_location_mapping, + "timestamp_mapping": self.timestamp_mapping, + } class ClickhouseConnectionConfig(ConnectionConfig): diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 74df3667ff..89470728f2 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -74,6 +74,32 @@ class TrinoEngineAdapter( def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]: return self._extra_config.get("schema_location_mapping") + @property + def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]: + return self._extra_config.get("timestamp_mapping") + + def _apply_timestamp_mapping( + self, columns_to_types: t.Dict[str, exp.DataType] + ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]: + """Apply custom timestamp mapping to column types. + + Returns: + A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names + contains the names of columns that were found in the mapping. + """ + if not self.timestamp_mapping: + return columns_to_types, set() + + result = {} + mapped_columns: t.Set[str] = set() + for column, column_type in columns_to_types.items(): + if column_type in self.timestamp_mapping: + result[column] = self.timestamp_mapping[column_type] + mapped_columns.add(column) + else: + result[column] = column_type + return result, mapped_columns + @property def catalog_support(self) -> CatalogSupport: return CatalogSupport.FULL_SUPPORT @@ -117,7 +143,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]: try: yield finally: - self.execute(f"RESET SESSION AUTHORIZATION") + self.execute("RESET SESSION AUTHORIZATION") def replace_query( self, @@ -286,8 +312,11 @@ def _build_schema_exp( is_view: bool = False, materialized: bool = False, ) -> exp.Schema: + target_columns_to_types, mapped_columns = self._apply_timestamp_mapping( + target_columns_to_types + ) if "delta_lake" in self.get_catalog_type_from_table(table): - target_columns_to_types = self._to_delta_ts(target_columns_to_types) + target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns) return super()._build_schema_exp( table, target_columns_to_types, column_descriptions, expressions, is_view @@ -313,10 +342,15 @@ def _scd_type_2( source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: + mapped_columns: t.Set[str] = set() + if target_columns_to_types: + target_columns_to_types, mapped_columns = self._apply_timestamp_mapping( + target_columns_to_types + ) if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table( target_table ): - target_columns_to_types = self._to_delta_ts(target_columns_to_types) + target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns) return super()._scd_type_2( target_table, @@ -346,18 +380,21 @@ def _scd_type_2( # - `timestamp(3) with time zone` for timezone-aware # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping def _to_delta_ts( - self, columns_to_types: t.Dict[str, exp.DataType] + self, + columns_to_types: t.Dict[str, exp.DataType], + skip_columns: t.Optional[t.Set[str]] = None, ) -> t.Dict[str, exp.DataType]: ts6 = exp.DataType.build("timestamp(6)") ts3_tz = exp.DataType.build("timestamp(3) with time zone") + skip = skip_columns or set() delta_columns_to_types = { - k: ts6 if v.is_type(exp.DataType.Type.TIMESTAMP) else v + k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v for k, v in columns_to_types.items() } delta_columns_to_types = { - k: ts3_tz if v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v + k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v for k, v in delta_columns_to_types.items() } diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py index bf925c875a..a3c67eb023 100644 --- a/tests/core/engine_adapter/test_trino.py +++ b/tests/core/engine_adapter/test_trino.py @@ -404,6 +404,119 @@ def test_delta_timestamps(make_mocked_engine_adapter: t.Callable): } +def test_timestamp_mapping(): + """Test that timestamp_mapping config property is properly defined and accessible.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + ) + + adapter = config.create_engine_adapter() + assert adapter.timestamp_mapping is None + + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(6)", + "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE", + }, + ) + adapter = config.create_engine_adapter() + assert adapter.timestamp_mapping is not None + assert adapter.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build( + "TIMESTAMP(6)" + ) + + +def test_delta_timestamps_with_custom_mapping(make_mocked_engine_adapter: t.Callable): + """Test that _apply_timestamp_mapping + _to_delta_ts respects custom timestamp_mapping.""" + # Create config with custom timestamp mapping + # Mapped columns are skipped by _to_delta_ts + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(3)", + "TIMESTAMP(1)": "TIMESTAMP(3)", + "TIMESTAMP WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE", + "TIMESTAMP(1) WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE", + }, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts3 = exp.DataType.build("timestamp(3)") + ts6_tz = exp.DataType.build("timestamp(6) with time zone") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_1": exp.DataType.build("TIMESTAMP(1)"), + "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"), + "ts_tz_1": exp.DataType.build("TIMESTAMP(1) WITH TIME ZONE"), + } + + # Apply mapping first, then convert to delta types (skipping mapped columns) + mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping( + columns_to_types + ) + delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names) + + # All types were mapped, so _to_delta_ts skips them - they keep their mapped types + assert delta_columns_to_types == { + "ts": ts3, + "ts_1": ts3, + "ts_tz": ts6_tz, + "ts_tz_1": ts6_tz, + } + + +def test_delta_timestamps_with_partial_mapping(make_mocked_engine_adapter: t.Callable): + """Test that _apply_timestamp_mapping + _to_delta_ts uses custom mapping for specified types.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(3)", + }, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts3 = exp.DataType.build("TIMESTAMP(3)") + ts6 = exp.DataType.build("timestamp(6)") + ts3_tz = exp.DataType.build("timestamp(3) with time zone") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_1": exp.DataType.build("TIMESTAMP(1)"), + "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"), + } + + # Apply mapping first, then convert to delta types (skipping mapped columns) + mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping( + columns_to_types + ) + delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names) + + # TIMESTAMP is in mapping → TIMESTAMP(3), skipped by _to_delta_ts + # TIMESTAMP(1) is NOT in mapping, uses default TIMESTAMP → ts6 + # TIMESTAMP WITH TIME ZONE is NOT in mapping, uses default TIMESTAMPTZ → ts3_tz + assert delta_columns_to_types == { + "ts": ts3, # Mapped to TIMESTAMP(3), skipped by _to_delta_ts + "ts_1": ts6, # Not in mapping, uses default + "ts_tz": ts3_tz, # Not in mapping, uses default + } + + def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture): adapter = trino_mocked_engine_adapter mocker.patch( @@ -755,3 +868,77 @@ def test_insert_overwrite_time_partition_iceberg( 'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', 'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', ] + + +def test_delta_timestamps_with_non_timestamp_columns(make_mocked_engine_adapter: t.Callable): + """Test that _apply_timestamp_mapping + _to_delta_ts handles non-timestamp columns.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(3)", + }, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts3 = exp.DataType.build("TIMESTAMP(3)") + ts6 = exp.DataType.build("timestamp(6)") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_1": exp.DataType.build("TIMESTAMP(1)"), + "int_col": exp.DataType.build("INT"), + "varchar_col": exp.DataType.build("VARCHAR(100)"), + "decimal_col": exp.DataType.build("DECIMAL(10,2)"), + } + + # Apply mapping first, then convert to delta types (skipping mapped columns) + mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping( + columns_to_types + ) + delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names) + + # TIMESTAMP is in mapping → TIMESTAMP(3), skipped by _to_delta_ts + # TIMESTAMP(1) is NOT in mapping (exact match), uses default TIMESTAMP → ts6 + # Non-timestamp columns should pass through unchanged + assert delta_columns_to_types == { + "ts": ts3, # Mapped to TIMESTAMP(3), skipped by _to_delta_ts + "ts_1": ts6, # Not in mapping, uses default + "int_col": exp.DataType.build("INT"), + "varchar_col": exp.DataType.build("VARCHAR(100)"), + "decimal_col": exp.DataType.build("DECIMAL(10,2)"), + } + + +def test_delta_timestamps_with_empty_mapping(make_mocked_engine_adapter: t.Callable): + """Test that _to_delta_ts handles empty custom mapping dictionary.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={}, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts6 = exp.DataType.build("timestamp(6)") + ts3_tz = exp.DataType.build("timestamp(3) with time zone") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"), + } + + delta_columns_to_types = adapter._to_delta_ts(columns_to_types) + + # With empty custom mapping, should fall back to defaults + assert delta_columns_to_types == { + "ts": ts6, + "ts_tz": ts3_tz, + } diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index a0d54e03dd..dd979a2551 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -4,6 +4,7 @@ import pytest from _pytest.fixtures import FixtureRequest +from sqlglot import exp from unittest.mock import patch, MagicMock from sqlmesh.core.config.connection import ( @@ -444,6 +445,63 @@ def test_trino_catalog_type_override(make_config): assert config.catalog_type_overrides == {"my_catalog": "iceberg"} +def test_trino_timestamp_mapping(make_config): + required_kwargs = dict( + type="trino", + user="user", + host="host", + catalog="catalog", + ) + + # Test config without timestamp_mapping + config = make_config(**required_kwargs) + assert config.timestamp_mapping is None + + # Test config with timestamp_mapping + config = make_config( + **required_kwargs, + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(6)", + "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE", + }, + ) + + assert config.timestamp_mapping is not None + assert config.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build( + "TIMESTAMP(6)" + ) + + # Test with invalid source type + with pytest.raises(ConfigError) as exc_info: + make_config( + **required_kwargs, + timestamp_mapping={ + "INVALID_TYPE": "TIMESTAMP", + }, + ) + assert "Invalid SQL type string" in str(exc_info.value) + assert "INVALID_TYPE" in str(exc_info.value) + + # Test with invalid target type (not a valid SQL type) + with pytest.raises(ConfigError) as exc_info: + make_config( + **required_kwargs, + timestamp_mapping={ + "TIMESTAMP": "INVALID_TARGET_TYPE", + }, + ) + assert "Invalid SQL type string" in str(exc_info.value) + assert "INVALID_TARGET_TYPE" in str(exc_info.value) + + # Test with empty mapping + config = make_config( + **required_kwargs, + timestamp_mapping={}, + ) + assert config.timestamp_mapping is not None + assert config.timestamp_mapping == {} + + def test_duckdb(make_config): config = make_config( type="duckdb",