Skip to content

Commit c2f818e

Browse files
committed
Feat(trino): Introduce custom timestamp type mapping
1 parent d719391 commit c2f818e

File tree

4 files changed

+322
-7
lines changed

4 files changed

+322
-7
lines changed

sqlmesh/core/config/connection.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from packaging import version
1818
from sqlglot import exp
1919
from sqlglot.helper import subclasses
20+
from sqlglot.errors import ParseError
2021

2122
from sqlmesh.core import engine_adapter
2223
from sqlmesh.core.config.base import BaseConfig
@@ -1890,6 +1891,7 @@ class TrinoConnectionConfig(ConnectionConfig):
18901891

18911892
# SQLMesh options
18921893
schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
1894+
timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
18931895
concurrent_tasks: int = 4
18941896
register_comments: bool = True
18951897
pre_ping: t.Literal[False] = False
@@ -1914,6 +1916,34 @@ def _validate_regex_keys(
19141916
)
19151917
return compiled
19161918

1919+
@field_validator("timestamp_mapping", mode="before")
@classmethod
def _validate_timestamp_mapping(
    cls, value: t.Optional[dict[str, str]]
) -> t.Optional[dict[exp.DataType, exp.DataType]]:
    """Parse the raw `timestamp_mapping` option into `exp.DataType` pairs.

    Args:
        value: Mapping of source SQL type string to target SQL type string,
            or None when the option is not set.

    Returns:
        None when `value` is None; otherwise a dict whose keys and values are
        the `exp.DataType` objects built from the corresponding strings.

    Raises:
        ConfigError: If a source or target string cannot be parsed as a SQL
            data type. The underlying ParseError is attached as the cause.
    """
    if value is None:
        return value

    def _build_type(type_str: str) -> exp.DataType:
        # Single place for the parse + error translation so that source and
        # target strings are validated identically.
        try:
            return exp.DataType.build(type_str)
        except ParseError as err:
            # Chain explicitly so the original parser message stays visible
            # as the direct cause of the configuration error.
            raise ConfigError(
                f"Invalid SQL type string in timestamp_mapping: "
                f"'{type_str}' is not a valid SQL data type."
            ) from err

    result: dict[exp.DataType, exp.DataType] = {}
    for source_type, target_type in value.items():
        # Source is parsed before target, so an entry with two bad strings
        # reports the source first.
        source_datatype = _build_type(source_type)
        result[source_datatype] = _build_type(target_type)

    return result
1946+
19171947
@model_validator(mode="after")
19181948
def _root_validator(self) -> Self:
19191949
port = self.port
@@ -2016,7 +2046,10 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
20162046

20172047
@property
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
    """Collect the SQLMesh-specific options of this connection config.

    Returns:
        A dict with the keys `schema_location_mapping` and
        `timestamp_mapping`, each holding the value of the config field of
        the same name (None when the field is unset).
    """
    # NOTE(review): these keys match what TrinoEngineAdapter looks up in its
    # `_extra_config`; the hand-off itself is not visible here -- confirm.
    extra_config: t.Dict[str, t.Any] = {}
    extra_config["schema_location_mapping"] = self.schema_location_mapping
    extra_config["timestamp_mapping"] = self.timestamp_mapping
    return extra_config
20202053

20212054

20222055
class ClickhouseConnectionConfig(ConnectionConfig):

sqlmesh/core/engine_adapter/trino.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ class TrinoEngineAdapter(
7474
def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
7575
return self._extra_config.get("schema_location_mapping")
7676

77+
@property
def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
    """Return the `timestamp_mapping` entry of the extra config, or None if it is absent."""
    configured_mapping = self._extra_config.get("timestamp_mapping")
    return configured_mapping
80+
81+
def _apply_timestamp_mapping(
    self, columns_to_types: t.Dict[str, exp.DataType]
) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
    """Replace column types according to the configured `timestamp_mapping`.

    A column is remapped only when its type is itself a key of the mapping;
    all other columns keep their type.

    Args:
        columns_to_types: Column name to data type.

    Returns:
        A tuple of (columns_to_types after remapping, names of the columns
        whose type was found in the mapping). When no mapping is configured,
        or it is empty, the input dict is returned as-is with an empty set.
    """
    mapping = self.timestamp_mapping
    if not mapping:
        return columns_to_types, set()

    remapped_names: t.Set[str] = {
        name for name, dtype in columns_to_types.items() if dtype in mapping
    }
    remapped_types = {
        name: mapping[dtype] if name in remapped_names else dtype
        for name, dtype in columns_to_types.items()
    }
    return remapped_types, remapped_names
102+
77103
@property
78104
def catalog_support(self) -> CatalogSupport:
79105
return CatalogSupport.FULL_SUPPORT
@@ -117,7 +143,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
117143
try:
118144
yield
119145
finally:
120-
self.execute(f"RESET SESSION AUTHORIZATION")
146+
self.execute("RESET SESSION AUTHORIZATION")
121147

122148
def replace_query(
123149
self,
@@ -286,8 +312,11 @@ def _build_schema_exp(
286312
is_view: bool = False,
287313
materialized: bool = False,
288314
) -> exp.Schema:
315+
target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
316+
target_columns_to_types
317+
)
289318
if "delta_lake" in self.get_catalog_type_from_table(table):
290-
target_columns_to_types = self._to_delta_ts(target_columns_to_types)
319+
target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)
291320

292321
return super()._build_schema_exp(
293322
table, target_columns_to_types, column_descriptions, expressions, is_view
@@ -313,10 +342,15 @@ def _scd_type_2(
313342
source_columns: t.Optional[t.List[str]] = None,
314343
**kwargs: t.Any,
315344
) -> None:
345+
mapped_columns: t.Set[str] = set()
346+
if target_columns_to_types:
347+
target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
348+
target_columns_to_types
349+
)
316350
if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
317351
target_table
318352
):
319-
target_columns_to_types = self._to_delta_ts(target_columns_to_types)
353+
target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)
320354

321355
return super()._scd_type_2(
322356
target_table,
@@ -346,18 +380,21 @@ def _scd_type_2(
346380
# - `timestamp(3) with time zone` for timezone-aware
347381
# https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
348382
def _to_delta_ts(
349-
self, columns_to_types: t.Dict[str, exp.DataType]
383+
self,
384+
columns_to_types: t.Dict[str, exp.DataType],
385+
skip_columns: t.Optional[t.Set[str]] = None,
350386
) -> t.Dict[str, exp.DataType]:
351387
ts6 = exp.DataType.build("timestamp(6)")
352388
ts3_tz = exp.DataType.build("timestamp(3) with time zone")
389+
skip = skip_columns or set()
353390

354391
delta_columns_to_types = {
355-
k: ts6 if v.is_type(exp.DataType.Type.TIMESTAMP) else v
392+
k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v
356393
for k, v in columns_to_types.items()
357394
}
358395

359396
delta_columns_to_types = {
360-
k: ts3_tz if v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
397+
k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
361398
for k, v in delta_columns_to_types.items()
362399
}
363400

tests/core/engine_adapter/test_trino.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,119 @@ def test_delta_timestamps(make_mocked_engine_adapter: t.Callable):
404404
}
405405

406406

407+
def test_timestamp_mapping():
    """Check that `timestamp_mapping` is optional and reaches the adapter as DataType pairs."""
    # Without the option the adapter exposes no mapping.
    plain_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
    )
    plain_adapter = plain_config.create_engine_adapter()
    assert plain_adapter.timestamp_mapping is None

    # With the option the string pairs are available as parsed data types.
    mapped_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(6)",
            "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE",
        },
    )
    mapped_adapter = mapped_config.create_engine_adapter()
    assert mapped_adapter.timestamp_mapping is not None

    source_type = exp.DataType.build("TIMESTAMP")
    expected_target = exp.DataType.build("TIMESTAMP(6)")
    assert mapped_adapter.timestamp_mapping[source_type] == expected_target
432+
433+
434+
def test_delta_timestamps_with_custom_mapping(make_mocked_engine_adapter: t.Callable):
    """Check that columns remapped by `timestamp_mapping` are left alone by `_to_delta_ts`."""
    # Every column type used below has an entry in this mapping.
    mapping_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(3)",
            "TIMESTAMP(1)": "TIMESTAMP(3)",
            "TIMESTAMP WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE",
            "TIMESTAMP(1) WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE",
        },
    )
    engine_adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=mapping_config.timestamp_mapping
    )

    source_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_1": exp.DataType.build("TIMESTAMP(1)"),
        "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
        "ts_tz_1": exp.DataType.build("TIMESTAMP(1) WITH TIME ZONE"),
    }

    # Step 1: custom mapping. Step 2: delta conversion, skipping what step 1 touched.
    remapped_columns, remapped_names = engine_adapter._apply_timestamp_mapping(source_columns)
    actual = engine_adapter._to_delta_ts(remapped_columns, remapped_names)

    # All four columns were remapped, so each keeps its mapped type.
    expected_naive = exp.DataType.build("timestamp(3)")
    expected_tz = exp.DataType.build("timestamp(6) with time zone")
    assert actual == {
        "ts": expected_naive,
        "ts_1": expected_naive,
        "ts_tz": expected_tz,
        "ts_tz_1": expected_tz,
    }
477+
478+
479+
def test_delta_timestamps_with_partial_mapping(make_mocked_engine_adapter: t.Callable):
    """Check that unmapped timestamp types still get the delta defaults while mapped ones do not."""
    # Only plain TIMESTAMP has a custom target.
    mapping_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(3)",
        },
    )
    engine_adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=mapping_config.timestamp_mapping
    )

    source_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_1": exp.DataType.build("TIMESTAMP(1)"),
        "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
    }

    # Step 1: custom mapping. Step 2: delta conversion, skipping what step 1 touched.
    remapped_columns, remapped_names = engine_adapter._apply_timestamp_mapping(source_columns)
    actual = engine_adapter._to_delta_ts(remapped_columns, remapped_names)

    custom_ts = exp.DataType.build("TIMESTAMP(3)")
    default_ts = exp.DataType.build("timestamp(6)")
    default_ts_tz = exp.DataType.build("timestamp(3) with time zone")
    assert actual == {
        # In the mapping: keeps TIMESTAMP(3) because _to_delta_ts skips it.
        "ts": custom_ts,
        # TIMESTAMP(1) is not a mapping key: falls back to the delta default.
        "ts_1": default_ts,
        # TIMESTAMP WITH TIME ZONE is not a mapping key: delta default for tz.
        "ts_tz": default_ts_tz,
    }
518+
519+
407520
def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture):
408521
adapter = trino_mocked_engine_adapter
409522
mocker.patch(
@@ -755,3 +868,77 @@ def test_insert_overwrite_time_partition_iceberg(
755868
'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
756869
'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
757870
]
871+
872+
873+
def test_delta_timestamps_with_non_timestamp_columns(make_mocked_engine_adapter: t.Callable):
    """Check that non-timestamp columns pass through mapping and delta conversion unchanged."""
    mapping_config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={
            "TIMESTAMP": "TIMESTAMP(3)",
        },
    )
    engine_adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=mapping_config.timestamp_mapping
    )

    source_columns = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_1": exp.DataType.build("TIMESTAMP(1)"),
        "int_col": exp.DataType.build("INT"),
        "varchar_col": exp.DataType.build("VARCHAR(100)"),
        "decimal_col": exp.DataType.build("DECIMAL(10,2)"),
    }

    # Step 1: custom mapping. Step 2: delta conversion, skipping what step 1 touched.
    remapped_columns, remapped_names = engine_adapter._apply_timestamp_mapping(source_columns)
    actual = engine_adapter._to_delta_ts(remapped_columns, remapped_names)

    custom_ts = exp.DataType.build("TIMESTAMP(3)")
    default_ts = exp.DataType.build("timestamp(6)")
    assert actual == {
        # In the mapping: keeps TIMESTAMP(3) because _to_delta_ts skips it.
        "ts": custom_ts,
        # TIMESTAMP(1) is not an exact mapping key: delta default applies.
        "ts_1": default_ts,
        # Everything that is not a timestamp is untouched by both steps.
        "int_col": exp.DataType.build("INT"),
        "varchar_col": exp.DataType.build("VARCHAR(100)"),
        "decimal_col": exp.DataType.build("DECIMAL(10,2)"),
    }
915+
916+
917+
def test_delta_timestamps_with_empty_mapping(make_mocked_engine_adapter: t.Callable):
    """Test that an empty custom mapping remaps nothing and the delta defaults still apply.

    The mapping step is exercised explicitly so that the empty-dict path of
    `_apply_timestamp_mapping` is covered, not only `_to_delta_ts` on its own.
    """
    config = TrinoConnectionConfig(
        user="user",
        host="host",
        catalog="catalog",
        timestamp_mapping={},
    )

    adapter = make_mocked_engine_adapter(
        TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
    )

    ts6 = exp.DataType.build("timestamp(6)")
    ts3_tz = exp.DataType.build("timestamp(3) with time zone")

    columns_to_types = {
        "ts": exp.DataType.build("TIMESTAMP"),
        "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
    }

    # An empty mapping must leave every column as-is and report no mapped columns.
    mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping(
        columns_to_types
    )
    assert mapped_columns_to_types == columns_to_types
    assert mapped_column_names == set()

    delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names)

    # With empty custom mapping, should fall back to defaults
    assert delta_columns_to_types == {
        "ts": ts6,
        "ts_tz": ts3_tz,
    }

0 commit comments

Comments
 (0)