Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from packaging import version
from sqlglot import exp
from sqlglot.helper import subclasses
from sqlglot.errors import ParseError

from sqlmesh.core import engine_adapter
from sqlmesh.core.config.base import BaseConfig
Expand Down Expand Up @@ -1890,6 +1891,7 @@ class TrinoConnectionConfig(ConnectionConfig):

# SQLMesh options
schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
concurrent_tasks: int = 4
register_comments: bool = True
pre_ping: t.Literal[False] = False
Expand All @@ -1914,6 +1916,34 @@ def _validate_regex_keys(
)
return compiled

@field_validator("timestamp_mapping", mode="before")
@classmethod
def _validate_timestamp_mapping(
    cls, value: t.Optional[dict[str, str]]
) -> t.Optional[dict[exp.DataType, exp.DataType]]:
    """Parse the user-supplied ``timestamp_mapping`` strings into ``exp.DataType`` nodes.

    Args:
        value: Raw mapping of source SQL type string -> target SQL type string,
            or ``None`` when the option is not configured.

    Returns:
        The mapping with both keys and values built into ``exp.DataType``
        expressions, or ``None`` if no mapping was configured.

    Raises:
        ConfigError: If any key or value is not a valid SQL data type.
    """
    if value is None:
        return value

    def _build(type_str: str) -> exp.DataType:
        # Centralized so source- and target-type failures produce the same error shape.
        try:
            return exp.DataType.build(type_str)
        except ParseError:
            raise ConfigError(
                f"Invalid SQL type string in timestamp_mapping: "
                f"'{type_str}' is not a valid SQL data type."
            )

    return {
        _build(source_type): _build(target_type)
        for source_type, target_type in value.items()
    }

@model_validator(mode="after")
def _root_validator(self) -> Self:
port = self.port
Expand Down Expand Up @@ -2016,7 +2046,10 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:

@property
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
    """Connection-level options forwarded to the engine adapter's ``_extra_config``."""
    # NOTE: the diff span contained a leftover deletion-side line (the old
    # single-entry return); this is the clean post-change implementation.
    return {
        "schema_location_mapping": self.schema_location_mapping,
        "timestamp_mapping": self.timestamp_mapping,
    }


class ClickhouseConnectionConfig(ConnectionConfig):
Expand Down
49 changes: 43 additions & 6 deletions sqlmesh/core/engine_adapter/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ class TrinoEngineAdapter(
def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
return self._extra_config.get("schema_location_mapping")

@property
def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
    """The user-configured timestamp type mapping, or ``None`` when not set."""
    try:
        return self._extra_config["timestamp_mapping"]
    except KeyError:
        return None

def _apply_timestamp_mapping(
    self, columns_to_types: t.Dict[str, exp.DataType]
) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
    """Apply the custom timestamp mapping to column types.

    Args:
        columns_to_types: Mapping of column name to its declared data type.

    Returns:
        A tuple of (mapped_columns_to_types, mapped_column_names) where
        mapped_column_names contains the names of columns whose type was
        found in (and rewritten by) the mapping.
    """
    # Read the property once instead of re-evaluating it on every iteration
    # (the original looked it up up to three times per mapped column).
    mapping = self.timestamp_mapping
    if not mapping:
        return columns_to_types, set()

    result: t.Dict[str, exp.DataType] = {}
    mapped_columns: t.Set[str] = set()
    for column, column_type in columns_to_types.items():
        # Mapping values are always exp.DataType (validated in the config),
        # so None unambiguously means "not mapped".
        mapped_type = mapping.get(column_type)
        if mapped_type is not None:
            result[column] = mapped_type
            mapped_columns.add(column)
        else:
            result[column] = column_type
    return result, mapped_columns

@property
def catalog_support(self) -> CatalogSupport:
return CatalogSupport.FULL_SUPPORT
Expand Down Expand Up @@ -117,7 +143,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
try:
yield
finally:
self.execute(f"RESET SESSION AUTHORIZATION")
self.execute("RESET SESSION AUTHORIZATION")

def replace_query(
self,
Expand Down Expand Up @@ -286,8 +312,11 @@ def _build_schema_exp(
is_view: bool = False,
materialized: bool = False,
) -> exp.Schema:
target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
target_columns_to_types
)
if "delta_lake" in self.get_catalog_type_from_table(table):
target_columns_to_types = self._to_delta_ts(target_columns_to_types)
target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

return super()._build_schema_exp(
table, target_columns_to_types, column_descriptions, expressions, is_view
Expand All @@ -313,10 +342,15 @@ def _scd_type_2(
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
mapped_columns: t.Set[str] = set()
if target_columns_to_types:
target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
target_columns_to_types
)
if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
target_table
):
target_columns_to_types = self._to_delta_ts(target_columns_to_types)
target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

return super()._scd_type_2(
target_table,
Expand Down Expand Up @@ -346,18 +380,21 @@ def _scd_type_2(
# - `timestamp(3) with time zone` for timezone-aware
# https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
def _to_delta_ts(
    self,
    columns_to_types: t.Dict[str, exp.DataType],
    skip_columns: t.Optional[t.Set[str]] = None,
) -> t.Dict[str, exp.DataType]:
    """Coerce timestamp types to the only ones Trino's Delta Lake connector supports.

    The connector only supports:
      - ``timestamp(6)`` for timezone-naive timestamps
      - ``timestamp(3) with time zone`` for timezone-aware timestamps
    https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping

    Args:
        columns_to_types: Mapping of column name to data type.
        skip_columns: Column names to leave untouched, e.g. columns already
            rewritten by a user-configured ``timestamp_mapping``.

    Returns:
        The mapping with non-skipped TIMESTAMP / TIMESTAMPTZ columns coerced.
    """
    ts6 = exp.DataType.build("timestamp(6)")
    ts3_tz = exp.DataType.build("timestamp(3) with time zone")
    skip = skip_columns or set()

    delta_columns_to_types = {
        k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v
        for k, v in columns_to_types.items()
    }

    # Second pass handles timezone-aware timestamps.
    return {
        k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
        for k, v in delta_columns_to_types.items()
    }

Expand Down
187 changes: 187 additions & 0 deletions tests/core/engine_adapter/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,119 @@ def test_delta_timestamps(make_mocked_engine_adapter: t.Callable):
}


def test_timestamp_mapping():
    """Test that timestamp_mapping config property is properly defined and accessible."""
    # Without the option, the adapter exposes no mapping at all.
    no_mapping_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
    )
    assert no_mapping_config.create_engine_adapter().timestamp_mapping is None

    # With the option, keys and values are parsed into exp.DataType instances.
    mapping_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(6)",
            "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE",
        },
    )
    mapping = mapping_config.create_engine_adapter().timestamp_mapping
    assert mapping is not None
    assert mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build("TIMESTAMP(6)")


def test_delta_timestamps_with_custom_mapping(make_mocked_engine_adapter: t.Callable):
    """Test that _apply_timestamp_mapping + _to_delta_ts respects custom timestamp_mapping."""
    # Every input type below appears in the mapping, so each column is rewritten
    # by _apply_timestamp_mapping and then skipped by _to_delta_ts.
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(3)",
            "TIMESTAMP(1)": "TIMESTAMP(3)",
            "TIMESTAMP WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE",
            "TIMESTAMP(1) WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE",
        },
    )
    adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
    )

    input_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_1": exp.DataType.build("TIMESTAMP(1)"),
        "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
        "ts_tz_1": exp.DataType.build("TIMESTAMP(1) WITH TIME ZONE"),
    }

    # Apply the mapping first, then the delta conversion (which skips mapped columns).
    remapped, remapped_names = adapter._apply_timestamp_mapping(input_columns)
    final_columns = adapter._to_delta_ts(remapped, remapped_names)

    # Every column was mapped, so each keeps its mapped type rather than
    # the delta-lake default coercion.
    assert final_columns == {
        "ts": exp.DataType.build("timestamp(3)"),
        "ts_1": exp.DataType.build("timestamp(3)"),
        "ts_tz": exp.DataType.build("timestamp(6) with time zone"),
        "ts_tz_1": exp.DataType.build("timestamp(6) with time zone"),
    }


def test_delta_timestamps_with_partial_mapping(make_mocked_engine_adapter: t.Callable):
    """Test that _apply_timestamp_mapping + _to_delta_ts uses custom mapping for specified types."""
    # Only the exact type TIMESTAMP is remapped; everything else falls back
    # to the default delta-lake coercions.
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(3)",
        },
    )
    adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
    )

    input_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_1": exp.DataType.build("TIMESTAMP(1)"),
        "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
    }

    # Apply the mapping first, then the delta conversion (which skips mapped columns).
    remapped, remapped_names = adapter._apply_timestamp_mapping(input_columns)
    final_columns = adapter._to_delta_ts(remapped, remapped_names)

    assert final_columns == {
        # In the mapping -> TIMESTAMP(3); skipped by _to_delta_ts.
        "ts": exp.DataType.build("TIMESTAMP(3)"),
        # Not in the mapping (exact match only) -> default timestamp(6).
        "ts_1": exp.DataType.build("timestamp(6)"),
        # Not in the mapping -> default timestamp(3) with time zone.
        "ts_tz": exp.DataType.build("timestamp(3) with time zone"),
    }


def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture):
adapter = trino_mocked_engine_adapter
mocker.patch(
Expand Down Expand Up @@ -755,3 +868,77 @@ def test_insert_overwrite_time_partition_iceberg(
'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
]


def test_delta_timestamps_with_non_timestamp_columns(make_mocked_engine_adapter: t.Callable):
    """Test that _apply_timestamp_mapping + _to_delta_ts handles non-timestamp columns."""
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(3)",
        },
    )
    adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
    )

    input_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_1": exp.DataType.build("TIMESTAMP(1)"),
        "int_col": exp.DataType.build("INT"),
        "varchar_col": exp.DataType.build("VARCHAR(100)"),
        "decimal_col": exp.DataType.build("DECIMAL(10,2)"),
    }

    # Apply the mapping first, then the delta conversion (which skips mapped columns).
    remapped, remapped_names = adapter._apply_timestamp_mapping(input_columns)
    final_columns = adapter._to_delta_ts(remapped, remapped_names)

    assert final_columns == {
        # In the mapping -> TIMESTAMP(3); skipped by _to_delta_ts.
        "ts": exp.DataType.build("TIMESTAMP(3)"),
        # Not in the mapping (exact match only) -> default timestamp(6).
        "ts_1": exp.DataType.build("timestamp(6)"),
        # Non-timestamp columns pass through untouched.
        "int_col": exp.DataType.build("INT"),
        "varchar_col": exp.DataType.build("VARCHAR(100)"),
        "decimal_col": exp.DataType.build("DECIMAL(10,2)"),
    }


def test_delta_timestamps_with_empty_mapping(make_mocked_engine_adapter: t.Callable):
    """Test that _to_delta_ts handles empty custom mapping dictionary."""
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={},
    )
    adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
    )

    input_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
    }

    # No skip set is passed, so the default delta coercions apply to every column.
    final_columns = adapter._to_delta_ts(input_columns)

    assert final_columns == {
        "ts": exp.DataType.build("timestamp(6)"),
        "ts_tz": exp.DataType.build("timestamp(3) with time zone"),
    }
Loading
Loading