Skip to content

Commit a5797a1

Browse files
authored
feat(duckdb)!: handle named arguments and non-integer scale input for ROUND (#6495)
* support named arguments and non-integer scale values * remove * unpack Kwargs at Snowflake parse time + address other Kwarg feedback
1 parent c97a81d commit a5797a1

File tree

5 files changed

+112
-2
lines changed

5 files changed

+112
-2
lines changed

sqlglot/dialects/duckdb.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,15 @@ def round_sql(self, expression: exp.Round) -> str:
16051605
decimals = expression.args.get("decimals")
16061606
truncate = expression.args.get("truncate")
16071607

1608+
# DuckDB requires the scale (decimals) argument to be an INT
1609+
# Some dialects (e.g., Snowflake) allow non-integer scales and cast to an integer internally
1610+
if decimals is not None and expression.args.get("casts_non_integer_decimals"):
1611+
if isinstance(decimals, exp.Literal):
1612+
if not decimals.is_int:
1613+
decimals = exp.cast(decimals, exp.DataType.Type.INT)
1614+
elif not decimals.is_type(*exp.DataType.INTEGER_TYPES):
1615+
decimals = exp.cast(decimals, exp.DataType.Type.INT)
1616+
16081617
func = "ROUND"
16091618
if truncate:
16101619
# BigQuery uses ROUND_HALF_EVEN; Snowflake uses HALF_TO_EVEN

sqlglot/dialects/snowflake.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,36 @@ def _build_timestamp_from_parts(args: t.List) -> exp.Func:
550550
return exp.TimestampFromParts.from_arg_list(args)
551551

552552

553+
def _build_round(args: t.List) -> exp.Round:
554+
"""
555+
Build Round expression, unwrapping Snowflake's named parameters.
556+
557+
Maps EXPR => this, SCALE => decimals, ROUNDING_MODE => truncate.
558+
559+
Note: Snowflake does not support mixing named and positional arguments.
560+
Arguments are either all named or all positional.
561+
"""
562+
kwarg_map = {"EXPR": "this", "SCALE": "decimals", "ROUNDING_MODE": "truncate"}
563+
round_args = {}
564+
positional_keys = ["this", "decimals", "truncate"]
565+
positional_idx = 0
566+
567+
for arg in args:
568+
if isinstance(arg, exp.Kwarg):
569+
key = arg.this.name.upper()
570+
round_key = kwarg_map.get(key)
571+
if round_key:
572+
round_args[round_key] = arg.expression
573+
else:
574+
if positional_idx < len(positional_keys):
575+
round_args[positional_keys[positional_idx]] = arg
576+
positional_idx += 1
577+
578+
expression = exp.Round(**round_args)
579+
expression.set("casts_non_integer_decimals", True)
580+
return expression
581+
582+
553583
class Snowflake(Dialect):
554584
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
555585
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -717,6 +747,7 @@ class Parser(parser.Parser):
717747
"REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll),
718748
"REPLACE": build_replace_with_optional_replacement,
719749
"RLIKE": exp.RegexpLike.from_arg_list,
750+
"ROUND": _build_round,
720751
"SHA1_BINARY": exp.SHA1Digest.from_arg_list,
721752
"SHA1_HEX": exp.SHA.from_arg_list,
722753
"SHA2_BINARY": exp.SHA2Digest.from_arg_list,

sqlglot/expressions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7727,7 +7727,12 @@ class Radians(Func):
77277727
# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16
77287728
# tsql third argument function == trunctaion if not 0
77297729
class Round(Func):
7730-
arg_types = {"this": True, "decimals": False, "truncate": False}
7730+
arg_types = {
7731+
"this": True,
7732+
"decimals": False,
7733+
"truncate": False,
7734+
"casts_non_integer_decimals": False,
7735+
}
77317736

77327737

77337738
class RowNumber(Func):

tests/dialects/test_duckdb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,7 @@ def test_duckdb(self):
12531253
)
12541254
self.validate_identity("SELECT GREATEST(1.0, 2.5, NULL, 3.7)")
12551255
self.validate_identity("FROM t1, t2 SELECT *", "SELECT * FROM t1, t2")
1256+
self.validate_identity("ROUND(2.256, 1)")
12561257

12571258
# TODO: This is incorrect AST, DATE_PART creates a STRUCT of values but it's stored in 'year' arg
12581259
self.validate_identity("SELECT MAKE_DATE(DATE_PART(['year', 'month', 'day'], TODAY()))")

tests/dialects/test_snowflake.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3712,6 +3712,22 @@ def test_round(self):
37123712
},
37133713
)
37143714

3715+
self.validate_all(
3716+
"SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value",
3717+
write={
3718+
"snowflake": "SELECT ROUND(2.25, 1) AS value",
3719+
"duckdb": "SELECT ROUND(2.25, 1) AS value",
3720+
},
3721+
)
3722+
3723+
self.validate_all(
3724+
"SELECT ROUND(SCALE => 1, EXPR => 2.25) AS value",
3725+
write={
3726+
"snowflake": "SELECT ROUND(2.25, 1) AS value",
3727+
"duckdb": "SELECT ROUND(2.25, 1) AS value",
3728+
},
3729+
)
3730+
37153731
self.validate_all(
37163732
"SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value",
37173733
write={
@@ -3720,10 +3736,58 @@ def test_round(self):
37203736
},
37213737
)
37223738

3739+
self.validate_all(
3740+
"SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value",
3741+
write={
3742+
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value",
3743+
"duckdb": "SELECT ROUND(2.25, 1) AS value",
3744+
},
3745+
)
3746+
37233747
self.validate_all(
37243748
"SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
37253749
write={
3726-
"snowflake": """SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value""",
3750+
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
3751+
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
3752+
},
3753+
)
3754+
3755+
self.validate_all(
3756+
"SELECT ROUND(ROUNDING_MODE => 'HALF_TO_EVEN', EXPR => 2.25, SCALE => 1) AS value",
3757+
write={
3758+
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
3759+
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
3760+
},
3761+
)
3762+
3763+
self.validate_all(
3764+
"SELECT ROUND(SCALE => 1, EXPR => 2.25, , ROUNDING_MODE => 'HALF_TO_EVEN') AS value",
3765+
write={
3766+
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
3767+
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
3768+
},
3769+
)
3770+
3771+
self.validate_all(
3772+
"SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value",
3773+
write={
3774+
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
37273775
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
37283776
},
37293777
)
3778+
3779+
self.validate_all(
3780+
"SELECT ROUND(2.256, 1.8) AS value",
3781+
write={
3782+
"snowflake": "SELECT ROUND(2.256, 1.8) AS value",
3783+
"duckdb": "SELECT ROUND(2.256, CAST(1.8 AS INT)) AS value",
3784+
},
3785+
)
3786+
3787+
self.validate_all(
3788+
"SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value",
3789+
write={
3790+
"snowflake": "SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value",
3791+
"duckdb": "SELECT ROUND(2.256, CAST(CAST(1.8 AS DECIMAL(38, 0)) AS INT)) AS value",
3792+
},
3793+
)

0 commit comments

Comments
 (0)