Skip to content

Commit 2035b3b

Browse files
committed
add support for arithmetic functions
1 parent 5e595d8 commit 2035b3b

File tree

8 files changed

+410
-80
lines changed

8 files changed

+410
-80
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,4 +2061,90 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
20612061
.replaceAll("if \\(\\s*def", "if (def")
20622062
}
20632063

2064+
it should "handle arithmetic function as script field and condition" in {
2065+
val select: ElasticSearchRequest =
2066+
SQLQuery(arithmetic.replace("as group1", ""))
2067+
val query = select.query
2068+
println(query)
2069+
query shouldBe
2070+
"""{
2071+
| "query": {
2072+
| "bool": {
2073+
| "filter": [
2074+
| {
2075+
| "script": {
2076+
| "script": {
2077+
| "lang": "painless",
2078+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().get(ChronoUnit.YEARS) - 10)) > 10000"
2079+
| }
2080+
| }
2081+
| }
2082+
| ]
2083+
| }
2084+
| },
2085+
| "script_fields": {
2086+
| "add": {
2087+
| "script": {
2088+
| "lang": "painless",
2089+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 + 1)"
2090+
| }
2091+
| },
2092+
| "sub": {
2093+
| "script": {
2094+
| "lang": "painless",
2095+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 - 1)"
2096+
| }
2097+
| },
2098+
| "mul": {
2099+
| "script": {
2100+
| "lang": "painless",
2101+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * 2)"
2102+
| }
2103+
| },
2104+
| "div": {
2105+
| "script": {
2106+
| "lang": "painless",
2107+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 / 2)"
2108+
| }
2109+
| },
2110+
| "mod": {
2111+
| "script": {
2112+
| "lang": "painless",
2113+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 % 2)"
2114+
| }
2115+
| },
2116+
| "identifier_mul_identifier2_minus_10": {
2117+
| "script": {
2118+
| "lang": "painless",
2119+
| "source": "def lv0 = ((def lv1 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); def rv1 = ((!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value)); ( lv1 == null || rv1 == null ) ? null : (lv1 * rv1))); ( lv0 == null ) ? null : (lv0 - 10)"
2120+
| }
2121+
| }
2122+
| },
2123+
| "_source": {
2124+
| "includes": [
2125+
| "identifier"
2126+
| ]
2127+
| }
2128+
|}""".stripMargin
2129+
.replaceAll("\\s", "")
2130+
.replaceAll("defv", "def v")
2131+
.replaceAll("defe", "def e")
2132+
.replaceAll("defl", "def l")
2133+
.replaceAll("defr", "def r")
2134+
.replaceAll("if\\(", "if (")
2135+
.replaceAll("=\\(", " = (")
2136+
.replaceAll("\\?", " ? ")
2137+
.replaceAll(":null", " : null")
2138+
.replaceAll("null:", "null : ")
2139+
.replaceAll("return", " return ")
2140+
.replaceAll(";", "; ")
2141+
.replaceAll(">", " > ")
2142+
.replaceAll("\\*", " * ")
2143+
.replaceAll("/", " / ")
2144+
.replaceAll("%", " % ")
2145+
.replaceAll("\\+", " + ")
2146+
.replaceAll("-", " - ")
2147+
.replaceAll("==", " == ")
2148+
.replaceAll("\\|\\|", " || ")
2149+
}
20642150
}

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2050,4 +2050,90 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
20502050
.replaceAll("if \\(\\s*def", "if (def")
20512051
}
20522052

2053+
it should "handle arithmetic function as script field and condition" in {
2054+
val select: ElasticSearchRequest =
2055+
SQLQuery(arithmetic.replace("as group1", ""))
2056+
val query = select.query
2057+
println(query)
2058+
query shouldBe
2059+
"""{
2060+
| "query": {
2061+
| "bool": {
2062+
| "filter": [
2063+
| {
2064+
| "script": {
2065+
| "script": {
2066+
| "lang": "painless",
2067+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().get(ChronoUnit.YEARS) - 10)) > 10000"
2068+
| }
2069+
| }
2070+
| }
2071+
| ]
2072+
| }
2073+
| },
2074+
| "script_fields": {
2075+
| "add": {
2076+
| "script": {
2077+
| "lang": "painless",
2078+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 + 1)"
2079+
| }
2080+
| },
2081+
| "sub": {
2082+
| "script": {
2083+
| "lang": "painless",
2084+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 - 1)"
2085+
| }
2086+
| },
2087+
| "mul": {
2088+
| "script": {
2089+
| "lang": "painless",
2090+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * 2)"
2091+
| }
2092+
| },
2093+
| "div": {
2094+
| "script": {
2095+
| "lang": "painless",
2096+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 / 2)"
2097+
| }
2098+
| },
2099+
| "mod": {
2100+
| "script": {
2101+
| "lang": "painless",
2102+
| "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 % 2)"
2103+
| }
2104+
| },
2105+
| "identifier_mul_identifier2_minus_10": {
2106+
| "script": {
2107+
| "lang": "painless",
2108+
| "source": "def lv0 = ((def lv1 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); def rv1 = ((!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value)); ( lv1 == null || rv1 == null ) ? null : (lv1 * rv1))); ( lv0 == null ) ? null : (lv0 - 10)"
2109+
| }
2110+
| }
2111+
| },
2112+
| "_source": {
2113+
| "includes": [
2114+
| "identifier"
2115+
| ]
2116+
| }
2117+
|}""".stripMargin
2118+
.replaceAll("\\s", "")
2119+
.replaceAll("defv", "def v")
2120+
.replaceAll("defe", "def e")
2121+
.replaceAll("defl", "def l")
2122+
.replaceAll("defr", "def r")
2123+
.replaceAll("if\\(", "if (")
2124+
.replaceAll("=\\(", " = (")
2125+
.replaceAll("\\?", " ? ")
2126+
.replaceAll(":null", " : null")
2127+
.replaceAll("null:", "null : ")
2128+
.replaceAll("return", " return ")
2129+
.replaceAll(";", "; ")
2130+
.replaceAll(">", " > ")
2131+
.replaceAll("\\*", " * ")
2132+
.replaceAll("/", " / ")
2133+
.replaceAll("%", " % ")
2134+
.replaceAll("\\+", " + ")
2135+
.replaceAll("-", " - ")
2136+
.replaceAll("==", " == ")
2137+
.replaceAll("\\|\\|", " || ")
2138+
}
20532139
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ trait SQLFunctionChain extends SQLFunction {
9797
}
9898
}
9999

100+
def arithmetic: Boolean = functions.nonEmpty && functions.forall {
101+
case _: SQLArithmeticExpression => true
102+
case _ => false
103+
}
100104
}
101105

102106
sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType]
@@ -110,9 +114,7 @@ sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType]
110114
}
111115

112116
sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType]
113-
extends SQLUnaryFunction[SQLAny, Out] { self: SQLFunction =>
114-
115-
override def inputType: SQLAny = SQLTypes.Any
117+
extends SQLUnaryFunction[In2, Out] { self: SQLFunction =>
116118

117119
def left: PainlessScript
118120
def right: PainlessScript
@@ -122,21 +124,13 @@ sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType]
122124

123125
sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] {
124126
def toPainless(base: String, idx: Int): String = {
125-
if (nullable)
127+
if (nullable && base.nonEmpty)
126128
s"(def e$idx = $base; e$idx != null ? e$idx$painless : null)"
127129
else
128130
s"$base$painless"
129131
}
130132
}
131133

132-
sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType]
133-
extends SQLTransformFunction[In, Out]
134-
with MathScript {
135-
def operator: ArithmeticOperator
136-
override def toSQL(base: String): String = s"$base$operator$sql"
137-
override def applyType(in: SQLType): SQLType = in
138-
}
139-
140134
sealed trait AggregateFunction extends SQLFunction
141135
case object Count extends SQLExpr("count") with AggregateFunction
142136
case object Min extends SQLExpr("min") with AggregateFunction
@@ -241,7 +235,11 @@ object TimeInterval {
241235
}
242236
}
243237

244-
sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunction[IO, IO] {
238+
sealed trait SQLIntervalFunction[IO <: SQLTemporal]
239+
extends SQLTransformFunction[IO, IO]
240+
with MathScript {
241+
def operator: IntervalOperator
242+
override def toSQL(base: String): String = s"$base$operator$sql"
245243
def interval: TimeInterval
246244
override def script: String = s"${operator.script}${interval.script}"
247245
private[this] var _out: SQLType = outputType
@@ -265,12 +263,12 @@ sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunctio
265263
}
266264

267265
sealed trait AddInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] {
268-
override def operator: ArithmeticOperator = Add
266+
override def operator: IntervalOperator = Add
269267
override def painless: String = s".plus(${interval.painless})"
270268
}
271269

272270
sealed trait SubtractInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] {
273-
override def operator: ArithmeticOperator = Subtract
271+
override def operator: IntervalOperator = Subtract
274272
override def painless: String = s".minus(${interval.painless})"
275273
}
276274

@@ -391,6 +389,7 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit)
391389
with DateTimeFunction
392390
with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumeric]
393391
with PainlessScript {
392+
override def inputType: SQLDateTime = SQLTypes.DateTime
394393
override def outputType: SQLNumeric = SQLTypes.Numeric
395394
override def left: PainlessScript = end
396395
override def right: PainlessScript = start
@@ -751,3 +750,79 @@ case class SQLCaseWhen(
751750
override def nullable: Boolean =
752751
conditions.exists { case (_, res) => res.nullable } || default.forall(_.nullable)
753752
}
753+
754+
case class SQLArithmeticExpression(
755+
left: PainlessScript,
756+
operator: ArithmeticOperator,
757+
right: PainlessScript,
758+
group: Boolean = false
759+
) extends SQLTransformFunction[SQLNumeric, SQLNumeric]
760+
with SQLBinaryFunction[SQLNumeric, SQLNumeric, SQLNumeric] {
761+
762+
override def inputType: SQLNumeric = SQLTypes.Numeric
763+
override def outputType: SQLNumeric = SQLTypes.Numeric
764+
765+
override def applyType(in: SQLType): SQLType = in
766+
767+
override def sql: String = {
768+
val expr = s"${left.sql}$operator${right.sql}"
769+
if (group)
770+
s"($expr)"
771+
else
772+
expr
773+
}
774+
775+
override def out: SQLType =
776+
SQLTypeUtils.leastCommonSuperType(List(left.out, right.out))
777+
778+
override def validate(): Either[String, Unit] = {
779+
for {
780+
_ <- left.validate()
781+
_ <- right.validate()
782+
_ <- SQLValidator.validateTypesMatching(left.out, right.out)
783+
} yield ()
784+
}
785+
786+
override def nullable: Boolean = left.nullable || right.nullable
787+
788+
override def toPainless(base: String, idx: Int): String = {
789+
if (nullable) {
790+
val l = left match {
791+
case t: SQLTransformFunction[_, _] =>
792+
SQLTypeUtils.coerce(t.toPainless("", idx + 1), left.out, out, nullable = false)
793+
case _ => SQLTypeUtils.coerce(left.painless, left.out, out, nullable = false)
794+
}
795+
val r = right match {
796+
case t: SQLTransformFunction[_, _] =>
797+
SQLTypeUtils.coerce(t.toPainless("", idx + 1), right.out, out, nullable = false)
798+
case _ => SQLTypeUtils.coerce(right.painless, right.out, out, nullable = false)
799+
}
800+
var expr = ""
801+
if (left.nullable)
802+
expr += s"def lv$idx = ($l); "
803+
if (right.nullable)
804+
expr += s"def rv$idx = ($r); "
805+
if (left.nullable && right.nullable)
806+
expr += s"(lv$idx == null || rv$idx == null) ? null : (lv$idx ${operator.painless} rv$idx)"
807+
else if (left.nullable)
808+
expr += s"(lv$idx == null) ? null : (lv$idx ${operator.painless} $r)"
809+
else
810+
expr += s"(rv$idx == null) ? null : ($l ${operator.painless} rv$idx)"
811+
if (group)
812+
expr = s"($expr)"
813+
return s"$base$expr"
814+
}
815+
s"$base$painless"
816+
}
817+
818+
override def painless: String = {
819+
val l = SQLTypeUtils.coerce(left, out)
820+
val r = SQLTypeUtils.coerce(right, out)
821+
val expr = s"$l ${operator.painless} $r"
822+
if (group)
823+
s"($expr)"
824+
else
825+
expr
826+
}
827+
828+
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ sealed trait ArithmeticOperator extends SQLOperator with MathScript {
1717
override def toString: String = s" $sql "
1818
override def script: String = sql
1919
}
20-
case object Add extends SQLExpr("+") with ArithmeticOperator
21-
case object Subtract extends SQLExpr("-") with ArithmeticOperator
20+
21+
sealed trait IntervalOperator extends ArithmeticOperator
22+
23+
case object Add extends SQLExpr("+") with IntervalOperator
24+
case object Subtract extends SQLExpr("-") with IntervalOperator
2225
case object Multiply extends SQLExpr("*") with ArithmeticOperator
2326
case object Divide extends SQLExpr("/") with ArithmeticOperator
2427
case object Modulo extends SQLExpr("%") with ArithmeticOperator

0 commit comments

Comments
 (0)