Skip to content

Commit d4f33e7

Browse files
committed
add generic AddInterval and SubtractInterval traits, implements painless for all criteria, update painless for generic expression
1 parent 04c3508 commit d4f33e7

File tree

4 files changed

+169
-113
lines changed

4 files changed

+169
-113
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -984,38 +984,47 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
984984
println(query)
985985
query shouldBe
986986
"""{
987-
| "query": {
988-
| "bool": {
989-
| "filter": [
990-
| {
991-
| "script": {
992-
| "script": {
993-
| "lang": "painless",
994-
| "source": "doc['createdAt'].value.toLocalTime() < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()"
995-
| }
996-
| }
997-
| },
998-
| {
999-
| "script": {
1000-
| "script": {
1001-
| "lang": "painless",
1002-
| "source": "doc['createdAt'].value.toLocalTime() >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)"
1003-
| }
1004-
| }
1005-
| }
1006-
| ]
1007-
| }
1008-
| },
1009-
| "_source": {
1010-
| "includes": [
1011-
| "*"
1012-
| ]
1013-
| }
1014-
|}""".stripMargin
987+
| "query": {
988+
| "bool": {
989+
| "filter": [
990+
| {
991+
| "script": {
992+
| "script": {
993+
| "lang": "painless",
994+
| "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()"
995+
| }
996+
| }
997+
| },
998+
| {
999+
| "script": {
1000+
| "script": {
1001+
| "lang": "painless",
1002+
| "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)"
1003+
| }
1004+
| }
1005+
| }
1006+
| ]
1007+
| }
1008+
| },
1009+
| "_source": {
1010+
| "includes": [
1011+
| "*"
1012+
| ]
1013+
| }
1014+
|}""".stripMargin
10151015
.replaceAll("\\s", "")
10161016
.replaceAll("ChronoUnit", " ChronoUnit")
10171017
.replaceAll(">=", " >= ")
10181018
.replaceAll("<", " < ")
1019+
.replaceAll("\\|\\|", " || ")
1020+
.replaceAll("null:", "null : ")
1021+
.replaceAll("false:", "false : ")
1022+
.replaceAll(":null", " : null ")
1023+
.replaceAll("\\?", " ? ")
1024+
.replaceAll("==", " == ")
1025+
.replaceAll("\\);", "); ")
1026+
.replaceAll("=\\(", " = (")
1027+
.replaceAll("defl", "def l")
10191028
}
10201029

10211030
it should "handle having with date functions" in {

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -980,40 +980,48 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
980980
SQLQuery(filterWithTimeAndInterval)
981981
val query = select.query
982982
println(query)
983-
query shouldBe
984-
"""{
985-
| "query": {
986-
| "bool": {
987-
| "filter": [
988-
| {
989-
| "script": {
990-
| "script": {
991-
| "lang": "painless",
992-
| "source": "doc['createdAt'].value.toLocalTime() < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()"
993-
| }
994-
| }
995-
| },
996-
| {
997-
| "script": {
998-
| "script": {
999-
| "lang": "painless",
1000-
| "source": "doc['createdAt'].value.toLocalTime() >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)"
1001-
| }
1002-
| }
1003-
| }
1004-
| ]
1005-
| }
1006-
| },
1007-
| "_source": {
1008-
| "includes": [
1009-
| "*"
1010-
| ]
1011-
| }
1012-
|}""".stripMargin
1013-
.replaceAll("\\s", "")
1014-
.replaceAll("ChronoUnit", " ChronoUnit")
1015-
.replaceAll(">=", " >= ")
1016-
.replaceAll("<", " < ")
983+
"""{
984+
| "query": {
985+
| "bool": {
986+
| "filter": [
987+
| {
988+
| "script": {
989+
| "script": {
990+
| "lang": "painless",
991+
| "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()"
992+
| }
993+
| }
994+
| },
995+
| {
996+
| "script": {
997+
| "script": {
998+
| "lang": "painless",
999+
| "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)"
1000+
| }
1001+
| }
1002+
| }
1003+
| ]
1004+
| }
1005+
| },
1006+
| "_source": {
1007+
| "includes": [
1008+
| "*"
1009+
| ]
1010+
| }
1011+
|}""".stripMargin
1012+
.replaceAll("\\s", "")
1013+
.replaceAll("ChronoUnit", " ChronoUnit")
1014+
.replaceAll(">=", " >= ")
1015+
.replaceAll("<", " < ")
1016+
.replaceAll("\\|\\|", " || ")
1017+
.replaceAll("null:", "null : ")
1018+
.replaceAll("false:", "false : ")
1019+
.replaceAll(":null", " : null ")
1020+
.replaceAll("\\?", " ? ")
1021+
.replaceAll("==", " == ")
1022+
.replaceAll("\\);", "); ")
1023+
.replaceAll("=\\(", " = (")
1024+
.replaceAll("defl", "def l")
10171025
}
10181026

10191027
it should "handle having with date functions" in {

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ sealed trait TimeInterval extends PainlessScript with MathScript {
211211

212212
override def script: String = TimeInterval.script(this)
213213

214-
def applyType(in: SQLType): Either[String, SQLType] = {
214+
def checkType(in: SQLType): Either[String, SQLType] = {
215215
import TimeUnit._
216216
in match {
217217
case SQLTypes.Date =>
@@ -254,20 +254,18 @@ object TimeInterval {
254254
}
255255
}
256256

257-
sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLTemporal, SQLTemporal] {
257+
sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunction[IO, IO] {
258258
def interval: TimeInterval
259-
override def inputType: SQLTemporal = SQLTypes.Temporal
260-
override def outputType: SQLTemporal = SQLTypes.Temporal
261259
override def script: String = s"${operator.script}${interval.script}"
262260
private[this] var _out: SQLType = outputType
263261
override def out: SQLType = _out
264262

265263
override def applyType(in: SQLType): SQLType = {
266-
_out = interval.applyType(in).getOrElse(out)
264+
_out = interval.checkType(in).getOrElse(out)
267265
_out
268266
}
269267

270-
override def validate(): Either[String, Unit] = interval.applyType(out) match {
268+
override def validate(): Either[String, Unit] = interval.checkType(out) match {
271269
case Left(err) => Left(err)
272270
case Right(_) => Right(())
273271
}
@@ -279,20 +277,30 @@ sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLTemporal, SQLT
279277
s"${SQLTypeUtils.coerce(base, expr.out, out, nullable = expr.nullable)}$painless"
280278
}
281279

282-
case class SQLAddInterval(interval: TimeInterval)
283-
extends SQLExpr(interval.sql)
284-
with SQLIntervalFunction {
280+
sealed trait AddInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] {
285281
override def operator: ArithmeticOperator = Add
286282
override def painless: String = s".plus(${interval.painless})"
287283
}
288284

289-
case class SQLSubtractInterval(interval: TimeInterval)
290-
extends SQLExpr(interval.sql)
291-
with SQLIntervalFunction {
285+
sealed trait SubtractInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] {
292286
override def operator: ArithmeticOperator = Subtract
293287
override def painless: String = s".minus(${interval.painless})"
294288
}
295289

290+
case class SQLAddInterval(interval: TimeInterval)
291+
extends SQLExpr(interval.sql)
292+
with AddInterval[SQLTemporal] {
293+
override def inputType: SQLTemporal = SQLTypes.Temporal
294+
override def outputType: SQLTemporal = SQLTypes.Temporal
295+
}
296+
297+
case class SQLSubtractInterval(interval: TimeInterval)
298+
extends SQLExpr(interval.sql)
299+
with SubtractInterval[SQLTemporal] {
300+
override def inputType: SQLTemporal = SQLTypes.Temporal
301+
override def outputType: SQLTemporal = SQLTypes.Temporal
302+
}
303+
296304
sealed trait DateTimeFunction extends SQLFunction {
297305
def now: String = "ZonedDateTime.now(ZoneId.of('Z'))"
298306
override def out: SQLType = SQLTypes.DateTime
@@ -418,27 +426,27 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit)
418426
case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval)
419427
extends SQLExpr("date_add")
420428
with DateFunction
429+
with AddInterval[SQLDate]
421430
with SQLTransformFunction[SQLDate, SQLDate]
422431
with SQLFunctionWithIdentifier {
423432
override def inputType: SQLDate = SQLTypes.Date
424433
override def outputType: SQLDate = SQLTypes.Date
425434
override def toSQL(base: String): String = {
426435
s"$sql($base, ${interval.sql})"
427436
}
428-
override def painless: String = s".plus(${interval.painless})"
429437
}
430438

431439
case class DateSub(identifier: SQLIdentifier, interval: TimeInterval)
432440
extends SQLExpr("date_sub")
433441
with DateFunction
442+
with SubtractInterval[SQLDate]
434443
with SQLTransformFunction[SQLDate, SQLDate]
435444
with SQLFunctionWithIdentifier {
436445
override def inputType: SQLDate = SQLTypes.Date
437446
override def outputType: SQLDate = SQLTypes.Date
438447
override def toSQL(base: String): String = {
439448
s"$sql($base, ${interval.sql})"
440449
}
441-
override def painless: String = s".minus(${interval.painless})"
442450
}
443451

444452
case class ParseDate(identifier: SQLIdentifier, format: String)
@@ -480,27 +488,27 @@ case class FormatDate(identifier: SQLIdentifier, format: String)
480488
case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval)
481489
extends SQLExpr("datetime_add")
482490
with DateTimeFunction
491+
with AddInterval[SQLDateTime]
483492
with SQLTransformFunction[SQLDateTime, SQLDateTime]
484493
with SQLFunctionWithIdentifier {
485494
override def inputType: SQLDateTime = SQLTypes.DateTime
486495
override def outputType: SQLDateTime = SQLTypes.DateTime
487496
override def toSQL(base: String): String = {
488497
s"$sql($base, ${interval.sql})"
489498
}
490-
override def painless: String = s".plus(${interval.painless})"
491499
}
492500

493501
case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval)
494502
extends SQLExpr("datetime_sub")
495503
with DateTimeFunction
504+
with SubtractInterval[SQLDateTime]
496505
with SQLTransformFunction[SQLDateTime, SQLDateTime]
497506
with SQLFunctionWithIdentifier {
498507
override def inputType: SQLDateTime = SQLTypes.DateTime
499508
override def outputType: SQLDateTime = SQLTypes.DateTime
500509
override def toSQL(base: String): String = {
501510
s"$sql($base, ${interval.sql})"
502511
}
503-
override def painless: String = s".minus(${interval.painless})"
504512
}
505513

506514
case class ParseDateTime(identifier: SQLIdentifier, format: String)

0 commit comments

Comments
 (0)