Skip to content

Commit 04c3508

Browse files
committed
fix sql types applied for interval functions, finalize implementation for CAST
1 parent ee0606b commit 04c3508

File tree

6 files changed

+205
-23
lines changed

6 files changed

+205
-23
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,4 +1834,49 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
18341834
.replaceAll("ZonedDateTime", " ZonedDateTime")
18351835
}
18361836

1837+
it should "handle cast function as script field" in {
  // The CAST(... long) query must be rendered as a single painless script field `c`
  // whose source ends with `.toInstant().toEpochMilli()`.
  val select: ElasticSearchRequest =
    SQLQuery(cast)
  val query = select.query
  // No debug printing here: on mismatch `shouldBe` already reports the actual query.
  //
  // The expected JSON is written pretty-printed, stripped of ALL whitespace, and the
  // replaceAll chain then re-inserts exactly the spaces that belong to the painless source.
  query shouldBe
  """{
    | "query": {
    | "match_all": {}
    | },
    | "script_fields": {
    | "c": {
    | "script": {
    | "lang": "painless",
    | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()"
    | }
    | }
    | },
    | "_source": {
    | "includes": [
    | "identifier"
    | ]
    | }
    |}""".stripMargin
    .replaceAll("\\s+", "")
    .replaceAll("defv", " def v")
    .replaceAll("defe", " def e")
    .replaceAll("if\\(", "if (")
    .replaceAll("=\\(", " = (")
    .replaceAll("\\?", " ? ")
    .replaceAll(":null", " : null")
    .replaceAll("null:", "null : ")
    .replaceAll("return", " return ")
    .replaceAll("between\\(s,", "between(s, ")
    .replaceAll(";", "; ")
    .replaceAll("; if", ";if")
    .replaceAll("==", " == ")
    .replaceAll("!=", " != ")
    .replaceAll("&&", " && ")
    .replaceAll("\\|\\|", " || ")
    .replaceAll(";\\s\\s", "; ")
    .replaceAll("ChronoUnit", " ChronoUnit")
    .replaceAll(",LocalDate", ", LocalDate")
    .replaceAll("=DateTimeFormatter", " = DateTimeFormatter")
}
18371882
}

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,4 +1823,50 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
18231823
.replaceAll("ZonedDateTime", " ZonedDateTime")
18241824
}
18251825

1826+
1827+
it should "handle cast function as script field" in {
  // The CAST(... long) query must be rendered as a single painless script field `c`
  // whose source ends with `.toInstant().toEpochMilli()`.
  val select: ElasticSearchRequest =
    SQLQuery(cast)
  val query = select.query
  // No debug printing here: on mismatch `shouldBe` already reports the actual query.
  //
  // The expected JSON is written pretty-printed, stripped of ALL whitespace, and the
  // replaceAll chain then re-inserts exactly the spaces that belong to the painless source.
  query shouldBe
  """{
    | "query": {
    | "match_all": {}
    | },
    | "script_fields": {
    | "c": {
    | "script": {
    | "lang": "painless",
    | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()"
    | }
    | }
    | },
    | "_source": {
    | "includes": [
    | "identifier"
    | ]
    | }
    |}""".stripMargin
    .replaceAll("\\s+", "")
    .replaceAll("defv", " def v")
    .replaceAll("defe", " def e")
    .replaceAll("if\\(", "if (")
    .replaceAll("=\\(", " = (")
    .replaceAll("\\?", " ? ")
    .replaceAll(":null", " : null")
    .replaceAll("null:", "null : ")
    .replaceAll("return", " return ")
    .replaceAll("between\\(s,", "between(s, ")
    .replaceAll(";", "; ")
    .replaceAll("; if", ";if")
    .replaceAll("==", " == ")
    .replaceAll("!=", " != ")
    .replaceAll("&&", " && ")
    .replaceAll("\\|\\|", " || ")
    .replaceAll(";\\s\\s", "; ")
    .replaceAll("ChronoUnit", " ChronoUnit")
    .replaceAll(",LocalDate", ", LocalDate")
    .replaceAll("=DateTimeFormatter", " = DateTimeFormatter")
}
18261872
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,18 +254,29 @@ object TimeInterval {
254254
}
255255
}
256256

257-
sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLDateTime, SQLDateTime] {
257+
sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLTemporal, SQLTemporal] {
258258
def interval: TimeInterval
259-
override def inputType: SQLDateTime = SQLTypes.DateTime
260-
override def outputType: SQLDateTime = SQLTypes.DateTime
259+
override def inputType: SQLTemporal = SQLTypes.Temporal
260+
override def outputType: SQLTemporal = SQLTypes.Temporal
261261
override def script: String = s"${operator.script}${interval.script}"
262+
private[this] var _out: SQLType = outputType
263+
override def out: SQLType = _out
262264

263-
override def applyType(in: SQLType): SQLType = interval.applyType(in).getOrElse(out)
265+
override def applyType(in: SQLType): SQLType = {
266+
_out = interval.applyType(in).getOrElse(out)
267+
_out
268+
}
264269

265270
override def validate(): Either[String, Unit] = interval.applyType(out) match {
266271
case Left(err) => Left(err)
267272
case Right(_) => Right(())
268273
}
274+
275+
override def toPainless(base: String, idx: Int): String =
276+
if (nullable)
277+
s"(def e$idx = $base; e$idx != null ? ${SQLTypeUtils.coerce(s"e$idx", expr.out, out, nullable = false)}$painless : null)"
278+
else
279+
s"${SQLTypeUtils.coerce(base, expr.out, out, nullable = expr.nullable)}$painless"
269280
}
270281

271282
case class SQLAddInterval(interval: TimeInterval)
@@ -636,5 +647,15 @@ case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = tru
636647
override def sql: String =
637648
s"$Cast(${value.sql} ${if (as) s"$Alias " else ""}${targetType.typeId})"
638649

639-
override def painless: String = SQLTypeUtils.coerce(value, targetType)
650+
override def toSQL(base: String): String = sql
651+
652+
override def painless: String =
653+
SQLTypeUtils.coerce(value, targetType)
654+
655+
/** Renders this cast in painless by coercing `base` from `value.out` to `targetType`.
  * `value.nullable` is forwarded to `SQLTypeUtils.coerce`, which decides whether the
  * result is wrapped in a null check.
  *
  * @param base painless expression to coerce
  * @param idx  not used by the active implementation
  */
override def toPainless(base: String, idx: Int): String =
656+
SQLTypeUtils.coerce(base, value.out, targetType, value.nullable)
657+
// NOTE(review): the block comment below is a disabled alternative that first binds `base`
// to a local `e$idx`; it is dead code — confirm whether it can simply be deleted.
/*if (nullable)
658+
s"(def e$idx = $base; e$idx != null ? ${SQLTypeUtils.coerce(s"e$idx", value.out, out, nullable = false)}$painless : null)"
659+
else
660+
s"${SQLTypeUtils.coerce(base, value.out, targetType, nullable = value.nullable)}$painless"*/
640661
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ trait SQLParser extends RegexParsers with PackratParsers {
153153
SQLSubtractInterval(it)
154154
}
155155

156-
def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] =
156+
def intervalFunction: PackratParser[SQLArithmeticFunction[SQLTemporal, SQLTemporal]] =
157157
addInterval | substractInterval
158158

159159
def identifierWithSystemFunction: PackratParser[SQLIdentifier] =
@@ -373,7 +373,31 @@ trait SQLParser extends RegexParsers with PackratParsers {
373373
"day",
374374
"hour",
375375
"minute",
376-
"second"
376+
"second",
377+
"quarter",
378+
"string",
379+
"int",
380+
"integer",
381+
"long",
382+
"double",
383+
"boolean",
384+
"time",
385+
"date",
386+
"datetime",
387+
"timestamp",
388+
"and",
389+
"or",
390+
"not",
391+
"like",
392+
"in",
393+
"between",
394+
"distinct",
395+
"cast",
396+
"count",
397+
"min",
398+
"max",
399+
"avg",
400+
"sum"
377401
)
378402

379403
private val identifierRegexStr =
@@ -392,6 +416,36 @@ trait SQLParser extends RegexParsers with PackratParsers {
392416
)
393417
}
394418

419+
/** Matches the `string` type keyword in any letter case and yields `SQLTypes.String`. */
def string_type: PackratParser[SQLTypes.String.type] = "(?i)string".r ^^^ SQLTypes.String
420+
421+
/** Matches the `date` type keyword in any letter case and yields `SQLTypes.Date`. */
def date_type: PackratParser[SQLTypes.Date.type] = "(?i)date".r ^^^ SQLTypes.Date
422+
423+
/** Matches the `time` type keyword in any letter case and yields `SQLTypes.Time`. */
def time_type: PackratParser[SQLTypes.Time.type] = "(?i)time".r ^^^ SQLTypes.Time
424+
425+
/** Matches the `datetime` type keyword in any letter case and yields `SQLTypes.DateTime`. */
def datetime_type: PackratParser[SQLTypes.DateTime.type] =
  "(?i)(datetime)".r ^^^ SQLTypes.DateTime
427+
428+
/** Matches the `timestamp` type keyword in any letter case and yields `SQLTypes.Timestamp`. */
def timestamp_type: PackratParser[SQLTypes.Timestamp.type] =
  "(?i)(timestamp)".r ^^^ SQLTypes.Timestamp
430+
431+
/** Matches the `boolean` type keyword in any letter case and yields `SQLTypes.Boolean`. */
def boolean_type: PackratParser[SQLTypes.Boolean.type] =
  "(?i)boolean".r ^^^ SQLTypes.Boolean
433+
434+
/** Matches the `long` type keyword in any letter case and yields `SQLTypes.Long`. */
def long_type: PackratParser[SQLTypes.Long.type] = "(?i)long".r ^^^ SQLTypes.Long
435+
436+
/** Matches the `double` type keyword in any letter case and yields `SQLTypes.Double`. */
def double_type: PackratParser[SQLTypes.Double.type] = "(?i)double".r ^^^ SQLTypes.Double
437+
438+
/** Matches the `int` / `integer` type keyword in any letter case and yields `SQLTypes.Int`.
  *
  * `integer` must be listed before `int`: regex alternation is ordered, so `(int|integer)`
  * would consume only the `int` prefix of `integer` and leave `eger` for the next parser,
  * making `CAST(x AS INTEGER)` fail to parse.
  */
def int_type: PackratParser[SQLTypes.Int.type] = "(?i)(integer|int)".r ^^ (_ => SQLTypes.Int)
439+
440+
/** Parses any supported SQL type keyword by trying each type parser in turn.
  *
  * The order of alternatives matters: `date_type` and `time_type` use the patterns `date`
  * and `time`, which are prefixes of `datetime` and `timestamp`, so `datetime_type` and
  * `timestamp_type` are tried first to keep the shorter keywords from winning on the
  * longer input.
  */
def sql_type: PackratParser[SQLType] =
441+
string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | int_type
442+
443+
/** Parses `cast <start> operand [alias keyword] <sql type> <end>` and returns the operand
  * identifier with an `SQLCast` prepended to its function list.
  *
  * The operand may be any of the identifier forms listed in the alternation; the first one
  * that parses is used. The optional `Alias.regex` match is recorded in `SQLCast.as`, so it
  * is known whether the keyword was present in the input.
  *
  * NOTE(review): `start` / `end` are defined outside this view — presumably the opening and
  * closing parentheses; confirm. `identifierWithTransformation` itself begins with this
  * parser, so the two are mutually recursive; keep the operand alternation inline rather
  * than hoisting it into an eagerly evaluated value.
  */
private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] =
444+
"(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ^^ {
445+
case _ ~ _ ~ i ~ as ~ t ~ _ =>
446+
i.copy(functions = SQLCast(i, targetType = t, as = as.isDefined) +: i.functions)
447+
}
448+
395449
private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] =
396450
(parse_date | format_date | date_add | date_sub) ~ arithmeticFunction.? ^^ { case t ~ af =>
397451
af match {
@@ -409,13 +463,13 @@ trait SQLParser extends RegexParsers with PackratParsers {
409463
}
410464
}
411465

412-
private[this] def logicalFunctionWithIdentifier: PackratParser[SQLIdentifier] =
466+
private[this] def conditionalFunctionWithIdentifier: PackratParser[SQLIdentifier] =
413467
(is_null | is_notnull | coalesce | nullif) ^^ { t =>
414468
t.identifier.copy(functions = t +: t.identifier.functions)
415469
}
416470

417471
def identifierWithTransformation: PackratParser[SQLIdentifier] =
418-
logicalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier
472+
castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier
419473

420474
def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction
421475

sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ object SQLTypeUtils {
2929
.contains(
3030
out.typeId
3131
)) ||
32+
(out.typeId == String.typeId && in.typeId == String.typeId) ||
33+
(out.typeId == Boolean.typeId && in.typeId == Boolean.typeId) ||
3234
out.typeId == Any.typeId || in.typeId == Any.typeId ||
3335
out.typeId == Null.typeId || in.typeId == Null.typeId
3436

@@ -37,29 +39,29 @@ object SQLTypeUtils {
3739
if (distinct.size == 1) return distinct.head
3840

3941
// 1. String
40-
if (distinct.exists(matches(SQLTypes.String, _))) return SQLTypes.String
42+
if (distinct.contains(SQLTypes.String)) return SQLTypes.String
4143

4244
// 2. Number
43-
if (distinct.exists(matches(SQLTypes.Double, _))) return SQLTypes.Double
44-
if (distinct.exists(matches(SQLTypes.Long, _))) return SQLTypes.Long
45-
if (distinct.exists(matches(SQLTypes.Int, _))) return SQLTypes.Int
46-
if (distinct.exists(matches(SQLTypes.Number, _))) return SQLTypes.Number
45+
if (distinct.contains(SQLTypes.Double)) return SQLTypes.Double
46+
if (distinct.contains(SQLTypes.Long)) return SQLTypes.Long
47+
if (distinct.contains(SQLTypes.Int)) return SQLTypes.Int
48+
if (distinct.contains(SQLTypes.Number)) return SQLTypes.Number
4749

4850
// 3. Temporal
49-
if (distinct.exists(matches(SQLTypes.Timestamp, _))) return SQLTypes.Timestamp
50-
if (distinct.exists(matches(SQLTypes.DateTime, _))) return SQLTypes.DateTime
51+
if (distinct.contains(SQLTypes.Timestamp)) return SQLTypes.Timestamp
52+
if (distinct.contains(SQLTypes.DateTime)) return SQLTypes.DateTime
5153

5254
// mixed case DATE + TIME → DATETIME
53-
if (distinct.exists(matches(SQLTypes.Date, _)) && distinct.exists(matches(SQLTypes.Time, _)))
55+
if (distinct.contains(SQLTypes.Date) && distinct.contains(SQLTypes.Time))
5456
return SQLTypes.DateTime
5557

56-
if (distinct.exists(matches(SQLTypes.Date, _))) return SQLTypes.Date
57-
if (distinct.exists(matches(SQLTypes.Time, _))) return SQLTypes.Time
58-
if (distinct.exists(matches(SQLTypes.Temporal, _))) return SQLTypes.Timestamp
58+
if (distinct.contains(SQLTypes.Date)) return SQLTypes.Date
59+
if (distinct.contains(SQLTypes.Time)) return SQLTypes.Time
60+
if (distinct.contains(SQLTypes.Temporal)) return SQLTypes.Timestamp
5961

6062
// 4. Null or Any
61-
if (distinct.exists(matches(SQLTypes.Null, _))) return SQLTypes.Any
62-
if (distinct.exists(matches(SQLTypes.Any, _))) return SQLTypes.Any
63+
if (distinct.contains(SQLTypes.Null)) return SQLTypes.Any
64+
if (distinct.contains(SQLTypes.Any)) return SQLTypes.Any
6365

6466
// 5. Fallback
6567
SQLTypes.Any
@@ -68,6 +70,11 @@ object SQLTypeUtils {
6870
/** Coerces the painless expression of `in` to the SQL type `to`.
  *
  * Convenience overload: it reads the expression text, source type and nullability from
  * `in` and delegates to the four-argument `coerce`.
  *
  * @param in script whose `painless`, `out` and `nullable` members describe the input
  * @param to target SQL type
  * @return the painless expression produced by the four-argument overload
  */
def coerce(in: PainlessScript, to: SQLType): String =
  coerce(in.painless, in.out, to, in.nullable)
76+
77+
def coerce(expr: String, from: SQLType, to: SQLType, nullable: Boolean): String = {
7178
val ret = {
7279
(from, to) match {
7380
// ---- DATE & TIME ----
@@ -105,7 +112,7 @@ object SQLTypeUtils {
105112
return expr // fallback
106113
}
107114
}
108-
if (!in.nullable)
115+
if (!nullable)
109116
return ret
110117
s"($expr != null ? $ret : null)"
111118
}

sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ object Queries {
141141
"select coalesce(createdAt - interval 35 minute, current_date) as c, identifier from Table"
142142
val nullif: String =
143143
"select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table"
144+
val cast: String =
145+
"select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) long) as c, identifier from Table"
144146
}
145147

146148
/** Created by smanciot on 15/02/17.
@@ -556,4 +558,11 @@ class SQLParserSpec extends AnyFlatSpec with Matchers {
556558
nullif
557559
)
558560
}
561+
562+
it should "parse cast function" in {
  // Round trip: the SQL rendered from the parsed statement must equal the input text.
  // A parse failure, or a result that is not on the Left side, collapses to "".
  val rendered =
    for {
      parsed    <- SQLParser(cast).toOption
      statement <- parsed.left.toOption
    } yield statement.sql
  rendered.getOrElse("") should ===(cast)
}
559568
}

0 commit comments

Comments
 (0)