Skip to content

Commit 68a5e61

Browse files
feat(snowflake)!: annotate type for REGR_* functions (#6452)
* Type annotation for REGR_* functions * removed unrequired change * added tests for other databases and made all REGR classes inherit from AggFunc * removed unsupported databases
1 parent 01e5a05 commit 68a5e61

File tree

6 files changed

+367
-10
lines changed

6 files changed

+367
-10
lines changed

sqlglot/expressions.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7643,23 +7643,47 @@ class RegexpCount(Func):
76437643
}
76447644

76457645

7646-
class RegrValx(Func):
7646+
class RegrValx(AggFunc):
76477647
arg_types = {"this": True, "expression": True}
76487648

76497649

7650-
class RegrValy(Func):
7650+
class RegrValy(AggFunc):
76517651
arg_types = {"this": True, "expression": True}
76527652

76537653

7654-
class RegrAvgy(Func):
7654+
class RegrAvgy(AggFunc):
76557655
arg_types = {"this": True, "expression": True}
76567656

76577657

7658-
class RegrAvgx(Func):
7658+
class RegrAvgx(AggFunc):
76597659
arg_types = {"this": True, "expression": True}
76607660

76617661

7662-
class RegrSlope(Func):
7662+
class RegrCount(AggFunc):
7663+
arg_types = {"this": True, "expression": True}
7664+
7665+
7666+
class RegrIntercept(AggFunc):
7667+
arg_types = {"this": True, "expression": True}
7668+
7669+
7670+
class RegrR2(AggFunc):
7671+
arg_types = {"this": True, "expression": True}
7672+
7673+
7674+
class RegrSxx(AggFunc):
7675+
arg_types = {"this": True, "expression": True}
7676+
7677+
7678+
class RegrSxy(AggFunc):
7679+
arg_types = {"this": True, "expression": True}
7680+
7681+
7682+
class RegrSyy(AggFunc):
7683+
arg_types = {"this": True, "expression": True}
7684+
7685+
7686+
class RegrSlope(AggFunc):
76637687
arg_types = {"this": True, "expression": True}
76647688

76657689

sqlglot/typing/snowflake.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,6 @@ def _annotate_math_with_float_decfloat(
276276
exp.ManhattanDistance,
277277
exp.MonthsBetween,
278278
exp.Normal,
279-
exp.RegrAvgx,
280-
exp.RegrAvgy,
281-
exp.RegrSlope,
282-
exp.RegrValx,
283-
exp.RegrValy,
284279
exp.Sinh,
285280
}
286281
},
@@ -306,6 +301,17 @@ def _annotate_math_with_float_decfloat(
306301
exp.Log,
307302
exp.Pow,
308303
exp.Radians,
304+
exp.RegrAvgx,
305+
exp.RegrAvgy,
306+
exp.RegrCount,
307+
exp.RegrIntercept,
308+
exp.RegrR2,
309+
exp.RegrSlope,
310+
exp.RegrSxx,
311+
exp.RegrSxy,
312+
exp.RegrSyy,
313+
exp.RegrValx,
314+
exp.RegrValy,
309315
exp.Sin,
310316
exp.Sqrt,
311317
exp.Tan,

tests/dialects/test_dialect.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3943,6 +3943,101 @@ def test_reverse(self):
39433943
},
39443944
)
39453945

3946+
def test_regr_count(self):
3947+
self.validate_all(
3948+
"REGR_COUNT(x, y)",
3949+
read={
3950+
"": "REGR_COUNT(x, y)",
3951+
"databricks": "REGR_COUNT(x, y)",
3952+
"duckdb": "REGR_COUNT(x, y)",
3953+
"exasol": "REGR_COUNT(x, y)",
3954+
"hive": "REGR_COUNT(x, y)",
3955+
"oracle": "REGR_COUNT(x, y)",
3956+
"postgres": "REGR_COUNT(x, y)",
3957+
"presto": "REGR_COUNT(x, y)",
3958+
"snowflake": "REGR_COUNT(x, y)",
3959+
"spark": "REGR_COUNT(x, y)",
3960+
"teradata": "REGR_COUNT(x, y)",
3961+
"trino": "REGR_COUNT(x, y)",
3962+
},
3963+
write={
3964+
"": "REGR_COUNT(x, y)",
3965+
"databricks": "REGR_COUNT(x, y)",
3966+
"duckdb": "REGR_COUNT(x, y)",
3967+
"exasol": "REGR_COUNT(x, y)",
3968+
"hive": "REGR_COUNT(x, y)",
3969+
"oracle": "REGR_COUNT(x, y)",
3970+
"postgres": "REGR_COUNT(x, y)",
3971+
"presto": "REGR_COUNT(x, y)",
3972+
"snowflake": "REGR_COUNT(x, y)",
3973+
"spark": "REGR_COUNT(x, y)",
3974+
"teradata": "REGR_COUNT(x, y)",
3975+
"trino": "REGR_COUNT(x, y)",
3976+
},
3977+
)
3978+
3979+
def test_regr_intercept(self):
3980+
self.validate_all(
3981+
"REGR_INTERCEPT(x, y)",
3982+
read={
3983+
"": "REGR_INTERCEPT(x, y)",
3984+
"databricks": "REGR_INTERCEPT(x, y)",
3985+
"duckdb": "REGR_INTERCEPT(x, y)",
3986+
"exasol": "REGR_INTERCEPT(x, y)",
3987+
"hive": "REGR_INTERCEPT(x, y)",
3988+
"oracle": "REGR_INTERCEPT(x, y)",
3989+
"postgres": "REGR_INTERCEPT(x, y)",
3990+
"presto": "REGR_INTERCEPT(x, y)",
3991+
"snowflake": "REGR_INTERCEPT(x, y)",
3992+
"spark": "REGR_INTERCEPT(x, y)",
3993+
"teradata": "REGR_INTERCEPT(x, y)",
3994+
},
3995+
write={
3996+
"": "REGR_INTERCEPT(x, y)",
3997+
"databricks": "REGR_INTERCEPT(x, y)",
3998+
"duckdb": "REGR_INTERCEPT(x, y)",
3999+
"exasol": "REGR_INTERCEPT(x, y)",
4000+
"hive": "REGR_INTERCEPT(x, y)",
4001+
"oracle": "REGR_INTERCEPT(x, y)",
4002+
"postgres": "REGR_INTERCEPT(x, y)",
4003+
"presto": "REGR_INTERCEPT(x, y)",
4004+
"snowflake": "REGR_INTERCEPT(x, y)",
4005+
"spark": "REGR_INTERCEPT(x, y)",
4006+
"teradata": "REGR_INTERCEPT(x, y)",
4007+
},
4008+
)
4009+
4010+
def test_regr_r2(self):
4011+
self.validate_all(
4012+
"REGR_R2(x, y)",
4013+
read={
4014+
"": "REGR_R2(x, y)",
4015+
"databricks": "REGR_R2(x, y)",
4016+
"duckdb": "REGR_R2(x, y)",
4017+
"exasol": "REGR_R2(x, y)",
4018+
"hive": "REGR_R2(x, y)",
4019+
"oracle": "REGR_R2(x, y)",
4020+
"postgres": "REGR_R2(x, y)",
4021+
"presto": "REGR_R2(x, y)",
4022+
"snowflake": "REGR_R2(x, y)",
4023+
"spark": "REGR_R2(x, y)",
4024+
"teradata": "REGR_R2(x, y)",
4025+
},
4026+
write={
4027+
"": "REGR_R2(x, y)",
4028+
"databricks": "REGR_R2(x, y)",
4029+
"duckdb": "REGR_R2(x, y)",
4030+
"exasol": "REGR_R2(x, y)",
4031+
"hive": "REGR_R2(x, y)",
4032+
"oracle": "REGR_R2(x, y)",
4033+
"postgres": "REGR_R2(x, y)",
4034+
"presto": "REGR_R2(x, y)",
4035+
"snowflake": "REGR_R2(x, y)",
4036+
"spark": "REGR_R2(x, y)",
4037+
"teradata": "REGR_R2(x, y)",
4038+
},
4039+
)
4040+
39464041
def test_regr_slope(self):
39474042
self.validate_all(
39484043
"REGR_SLOPE(x, y)",
@@ -3974,6 +4069,99 @@ def test_regr_slope(self):
39744069
},
39754070
)
39764071

4072+
def test_regr_sxx(self):
4073+
self.validate_all(
4074+
"REGR_SXX(x, y)",
4075+
read={
4076+
"": "REGR_SXX(x, y)",
4077+
"databricks": "REGR_SXX(x, y)",
4078+
"duckdb": "REGR_SXX(x, y)",
4079+
"exasol": "REGR_SXX(x, y)",
4080+
"hive": "REGR_SXX(x, y)",
4081+
"oracle": "REGR_SXX(x, y)",
4082+
"postgres": "REGR_SXX(x, y)",
4083+
"presto": "REGR_SXX(x, y)",
4084+
"snowflake": "REGR_SXX(x, y)",
4085+
"spark": "REGR_SXX(x, y)",
4086+
"teradata": "REGR_SXX(x, y)",
4087+
},
4088+
write={
4089+
"": "REGR_SXX(x, y)",
4090+
"databricks": "REGR_SXX(x, y)",
4091+
"duckdb": "REGR_SXX(x, y)",
4092+
"exasol": "REGR_SXX(x, y)",
4093+
"hive": "REGR_SXX(x, y)",
4094+
"oracle": "REGR_SXX(x, y)",
4095+
"postgres": "REGR_SXX(x, y)",
4096+
"presto": "REGR_SXX(x, y)",
4097+
"snowflake": "REGR_SXX(x, y)",
4098+
"spark": "REGR_SXX(x, y)",
4099+
"teradata": "REGR_SXX(x, y)",
4100+
},
4101+
)
4102+
4103+
def test_regr_sxy(self):
4104+
self.validate_all(
4105+
"REGR_SXY(x, y)",
4106+
read={
4107+
"": "REGR_SXY(x, y)",
4108+
"databricks": "REGR_SXY(x, y)",
4109+
"duckdb": "REGR_SXY(x, y)",
4110+
"exasol": "REGR_SXY(x, y)",
4111+
"hive": "REGR_SXY(x, y)",
4112+
"oracle": "REGR_SXY(x, y)",
4113+
"postgres": "REGR_SXY(x, y)",
4114+
"presto": "REGR_SXY(x, y)",
4115+
"snowflake": "REGR_SXY(x, y)",
4116+
"spark": "REGR_SXY(x, y)",
4117+
"teradata": "REGR_SXY(x, y)",
4118+
},
4119+
write={
4120+
"": "REGR_SXY(x, y)",
4121+
"databricks": "REGR_SXY(x, y)",
4122+
"duckdb": "REGR_SXY(x, y)",
4123+
"exasol": "REGR_SXY(x, y)",
4124+
"hive": "REGR_SXY(x, y)",
4125+
"oracle": "REGR_SXY(x, y)",
4126+
"postgres": "REGR_SXY(x, y)",
4127+
"presto": "REGR_SXY(x, y)",
4128+
"snowflake": "REGR_SXY(x, y)",
4129+
"spark": "REGR_SXY(x, y)",
4130+
"teradata": "REGR_SXY(x, y)",
4131+
},
4132+
)
4133+
4134+
def test_regr_syy(self):
4135+
self.validate_all(
4136+
"REGR_SYY(x, y)",
4137+
read={
4138+
"": "REGR_SYY(x, y)",
4139+
"databricks": "REGR_SYY(x, y)",
4140+
"duckdb": "REGR_SYY(x, y)",
4141+
"exasol": "REGR_SYY(x, y)",
4142+
"hive": "REGR_SYY(x, y)",
4143+
"oracle": "REGR_SYY(x, y)",
4144+
"postgres": "REGR_SYY(x, y)",
4145+
"presto": "REGR_SYY(x, y)",
4146+
"snowflake": "REGR_SYY(x, y)",
4147+
"spark": "REGR_SYY(x, y)",
4148+
"teradata": "REGR_SYY(x, y)",
4149+
},
4150+
write={
4151+
"": "REGR_SYY(x, y)",
4152+
"databricks": "REGR_SYY(x, y)",
4153+
"duckdb": "REGR_SYY(x, y)",
4154+
"exasol": "REGR_SYY(x, y)",
4155+
"hive": "REGR_SYY(x, y)",
4156+
"oracle": "REGR_SYY(x, y)",
4157+
"postgres": "REGR_SYY(x, y)",
4158+
"presto": "REGR_SYY(x, y)",
4159+
"snowflake": "REGR_SYY(x, y)",
4160+
"spark": "REGR_SYY(x, y)",
4161+
"teradata": "REGR_SYY(x, y)",
4162+
},
4163+
)
4164+
39774165
def test_translate(self):
39784166
self.validate_all(
39794167
"TRANSLATE(x, y, z)",

tests/dialects/test_snowflake.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ def test_snowflake(self):
138138
self.validate_identity("SELECT REGR_VALY(y, x)")
139139
self.validate_identity("SELECT REGR_AVGX(y, x)")
140140
self.validate_identity("SELECT REGR_AVGY(y, x)")
141+
self.validate_identity("SELECT REGR_COUNT(y, x)")
142+
self.validate_identity("SELECT REGR_INTERCEPT(y, x)")
143+
self.validate_identity("SELECT REGR_R2(y, x)")
144+
self.validate_identity("SELECT REGR_SXX(y, x)")
145+
self.validate_identity("SELECT REGR_SXY(y, x)")
146+
self.validate_identity("SELECT REGR_SYY(y, x)")
141147
self.validate_identity("SELECT REGR_SLOPE(y, x)")
142148
self.validate_all(
143149
"SELECT SKEW(a)",

0 commit comments

Comments
 (0)