Skip to content

Commit 0aaaf71

Browse files
committed
fix painless script for coalesce, nullif, case functions
1 parent da9d6f1 commit 0aaaf71

File tree

9 files changed

+108
-72
lines changed

9 files changed

+108
-72
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,7 +2153,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
21532153
| "c": {
21542154
| "script": {
21552155
| "lang": "painless",
2156-
| "source": "def param1 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def param2 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value.minus(3, ChronoUnit.DAYS)); def param3 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value.plus(2, ChronoUnit.DAYS)); def param4 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); param1 == param2 ? param2 : param1 == param3 ? param3 : param4"
2156+
| "source": "def param1 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def param2 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value.toLocalDate().minus(3, ChronoUnit.DAYS)); def param3 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value.toLocalDate().plus(2, ChronoUnit.DAYS)); def param4 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); param1 == param2 ? param2 : param1 == param3 ? param3 : param4"
21572157
| }
21582158
| }
21592159
| },
@@ -2748,7 +2748,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
27482748
| "hire_date": {
27492749
| "script": {
27502750
| "lang": "painless",
2751-
| "source": "def param1 = (!doc.containsKey('hire_date') || doc['hire_date'].empty ? null : doc['hire_date'].value); param1"
2751+
| "source": "def param1 = (!doc.containsKey('hire_date') || doc['hire_date'].empty ? null : doc['hire_date'].value.toLocalDate()); param1"
27522752
| }
27532753
| }
27542754
| },
@@ -2882,7 +2882,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
28822882
| "ld": {
28832883
| "script": {
28842884
| "lang": "painless",
2885-
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (param1 == null) ? null : param1.withDayOfMonth(param1.lengthOfMonth())"
2885+
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); (param1 == null) ? null : param1.withDayOfMonth(param1.lengthOfMonth())"
28862886
| }
28872887
| }
28882888
| },
@@ -3503,7 +3503,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
35033503
| "script": {
35043504
| "script": {
35053505
| "lang": "painless",
3506-
| "source": "def param1 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def param2 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param1 == null ? false : (param1.isBefore(param2))"
3506+
| "source": "def param1 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value.toLocalDate()); def param2 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param1 == null ? false : (param1.isBefore(param2))"
35073507
| }
35083508
| }
35093509
| }

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,7 +2157,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
21572157
| "c": {
21582158
| "script": {
21592159
| "lang": "painless",
2160-
| "source": "def param1 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def param2 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value.minus(3, ChronoUnit.DAYS)); def param3 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value.plus(2, ChronoUnit.DAYS)); def param4 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); param1 == param2 ? param2 : param1 == param3 ? param3 : param4"
2160+
| "source": "def param1 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def param2 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value.toLocalDate().minus(3, ChronoUnit.DAYS)); def param3 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value.toLocalDate().plus(2, ChronoUnit.DAYS)); def param4 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); param1 == param2 ? param2 : param1 == param3 ? param3 : param4"
21612161
| }
21622162
| }
21632163
| },
@@ -2752,7 +2752,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
27522752
| "hire_date": {
27532753
| "script": {
27542754
| "lang": "painless",
2755-
| "source": "def param1 = (!doc.containsKey('hire_date') || doc['hire_date'].empty ? null : doc['hire_date'].value); param1"
2755+
| "source": "def param1 = (!doc.containsKey('hire_date') || doc['hire_date'].empty ? null : doc['hire_date'].value.toLocalDate()); param1"
27562756
| }
27572757
| }
27582758
| },
@@ -2886,7 +2886,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
28862886
| "ld": {
28872887
| "script": {
28882888
| "lang": "painless",
2889-
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (param1 == null) ? null : param1.withDayOfMonth(param1.lengthOfMonth())"
2889+
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); (param1 == null) ? null : param1.withDayOfMonth(param1.lengthOfMonth())"
28902890
| }
28912891
| }
28922892
| },
@@ -3507,7 +3507,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
35073507
| "script": {
35083508
| "script": {
35093509
| "lang": "painless",
3510-
| "source": "def param1 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def param2 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param1 == null ? false : (param1.isBefore(param2))"
3510+
| "source": "def param1 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value.toLocalDate()); def param2 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param1 == null ? false : (param1.isBefore(param2))"
35113511
| }
35123512
| }
35133513
| }

sql/src/main/scala/app/softnetwork/elastic/sql/function/cond/package.scala

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ package object cond {
105105

106106
override def args: List[PainlessScript] = values
107107

108-
override def outputType: SQLType = SQLTypeUtils.leastCommonSuperType(args.map(_.baseType))
108+
override def outputType: SQLType =
109+
baseType //SQLTypeUtils.leastCommonSuperType(args.map(_.baseType))
109110

110111
override def identifier: Identifier = Identifier()
111112

@@ -114,10 +115,9 @@ package object cond {
114115
override def sql: String = s"$Coalesce(${values.map(_.sql).mkString(", ")})"
115116

116117
// Reprend l’idée de SQLValues mais pour n’importe quel token
117-
override def baseType: SQLType =
118-
SQLTypeUtils.leastCommonSuperType(values.map(_.baseType).distinct)
118+
override def baseType: SQLType = SQLTypeUtils.leastCommonSuperType(argTypes)
119119

120-
override def applyType(in: SQLType): SQLType = out
120+
override def applyType(in: SQLType): SQLType = baseType
121121

122122
override def validate(): Either[String, Unit] = {
123123
if (values.isEmpty) Left("COALESCE requires at least one argument")
@@ -151,14 +151,17 @@ package object cond {
151151

152152
override def inputType: SQLAny = SQLTypes.Any
153153

154-
override def baseType: SQLType = expr1.out
154+
override def baseType: SQLType = SQLTypeUtils.leastCommonSuperType(argTypes)
155155

156-
override def applyType(in: SQLType): SQLType = out
156+
override def applyType(in: SQLType): SQLType = baseType
157157

158-
override def checkIfNullable: Boolean = expr1.nullable && (expr1 match {
158+
private[this] def checkIfExpressionNullable(expr: PainlessScript): Boolean = expr match {
159159
case f: FunctionChain if f.functions.nonEmpty => true
160160
case _ => false
161-
})
161+
}
162+
163+
override def checkIfNullable: Boolean =
164+
false //checkIfExpressionNullable(expr1) || checkIfExpressionNullable(expr2)
162165

163166
override def toPainlessCall(
164167
callArgs: List[String],
@@ -195,7 +198,9 @@ package object cond {
195198
conditions: List[(PainlessScript, PainlessScript)],
196199
default: Option[PainlessScript]
197200
) extends TransformFunction[SQLAny, SQLAny] {
198-
override def args: List[PainlessScript] = List.empty
201+
override def args: List[PainlessScript] = expression.toList ++
202+
conditions.map { case (_, res) => res } ++
203+
default.toList
199204

200205
override def inputType: SQLAny = SQLTypes.Any
201206
override def outputType: SQLAny = SQLTypes.Any
@@ -210,9 +215,7 @@ package object cond {
210215
}
211216

212217
override def baseType: SQLType =
213-
SQLTypeUtils.leastCommonSuperType(
214-
conditions.map(_._2.baseType) ++ default.map(_.baseType).toList
215-
)
218+
SQLTypeUtils.leastCommonSuperType(argTypes)
216219

217220
override def applyType(in: SQLType): SQLType = baseType
218221

@@ -237,7 +240,7 @@ package object cond {
237240
var cases =
238241
expression match {
239242
case Some(expr) => // case with expression to evaluate
240-
val e = SQLTypeUtils.coerce(expr, expr.out, context)
243+
val e = SQLTypeUtils.coerce(expr, out, context)
241244
val expParam = ctx.addParam(
242245
LiteralParam(e)
243246
)
@@ -253,16 +256,14 @@ package object cond {
253256
i.name
254257
case _ => ""
255258
}
256-
val c = SQLTypeUtils.coerce(cond, expr.out, context)
259+
val c = SQLTypeUtils.coerce(cond, out, context)
257260
val r =
258261
res match {
259262
case i: Identifier if i.name == name && name.nonEmpty =>
260263
i.withNullable(false)
261264
SQLTypeUtils.coerce(
262-
i.painless(context),
263-
i.baseType,
265+
i,
264266
out,
265-
nullable = false,
266267
context
267268
)
268269
case _ =>
@@ -302,10 +303,8 @@ package object cond {
302303
case i: Identifier if i.name == name && name.nonEmpty =>
303304
i.withNullable(false)
304305
SQLTypeUtils.coerce(
305-
i.painless(context),
306-
i.baseType,
306+
i,
307307
out,
308-
nullable = false,
309308
context
310309
)
311310
case _ =>

sql/src/main/scala/app/softnetwork/elastic/sql/function/convert/package.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ import app.softnetwork.elastic.sql.{
2020
Alias,
2121
DateMathRounding,
2222
Expr,
23+
Identifier,
2324
PainlessContext,
2425
PainlessScript,
2526
TokenRegex
2627
}
27-
import app.softnetwork.elastic.sql.`type`.{SQLType, SQLTypeUtils}
28+
import app.softnetwork.elastic.sql.`type`.{SQLType, SQLTypeUtils, SQLTypes}
2829

2930
package object convert {
3031

@@ -46,6 +47,29 @@ package object convert {
4647
SQLTypeUtils.coerce(value, targetType, context)
4748

4849
override def toPainless(base: String, idx: Int, context: Option[PainlessContext]): String = {
50+
context match {
51+
case Some(ctx) =>
52+
value match {
53+
case _: Identifier =>
54+
inputType match {
55+
case SQLTypes.Any =>
56+
ctx.find(base) match {
57+
case Some(identifier) =>
58+
outputType match {
59+
case SQLTypes.Date =>
60+
identifier.addPainlessMethod(".toLocalDate()")
61+
case SQLTypes.Time =>
62+
identifier.addPainlessMethod(".toLocalTime()")
63+
case _ => // do nothing
64+
}
65+
case _ => // do nothing
66+
}
67+
case _ => // do nothing
68+
}
69+
case _ => // do nothing
70+
}
71+
case _ => // do nothing
72+
}
4973
val ret = SQLTypeUtils.coerce(base, value.baseType, targetType, value.nullable, context)
5074
val bloc = ret.startsWith("{") && ret.endsWith("}")
5175
val retWithBrackets = if (bloc) ret else s"{ $ret }"

sql/src/main/scala/app/softnetwork/elastic/sql/function/package.scala

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ package object function {
147147

148148
override def applyType(in: SQLType): SQLType = outputType
149149

150-
lazy val targetedType: SQLType = SQLTypeUtils.leastCommonSuperType(argTypes)
151-
152150
override def sql: String =
153151
s"${fun.map(_.sql).getOrElse("")}(${args.map(_.sql).mkString(argsSeparator)})"
154152

@@ -198,27 +196,10 @@ package object function {
198196
case chain: FunctionChain if chain.functions.nonEmpty =>
199197
val ret = SQLTypeUtils
200198
.coerce(
201-
a.painless(context),
202-
a.baseType,
199+
a,
203200
argTypes(i),
204-
nullable = a.nullable,
205201
context
206202
)
207-
a match {
208-
case identifier: Identifier =>
209-
identifier.baseType match {
210-
case SQLTypes.Any => // in painless context, Any is ZonedDateTime
211-
targetedType match {
212-
case SQLTypes.Date =>
213-
identifier.addPainlessMethod(".toLocalDate()")
214-
case SQLTypes.Time =>
215-
identifier.addPainlessMethod(".toLocalTime()")
216-
case _ =>
217-
}
218-
case _ =>
219-
}
220-
case _ =>
221-
}
222203
if (ret.startsWith(".")) {
223204
// apply methods
224205
ctx.find(paramName) match {
@@ -235,7 +216,7 @@ package object function {
235216
case identifier: Identifier =>
236217
identifier.baseType match {
237218
case SQLTypes.Any => // in painless context, Any is ZonedDateTime
238-
targetedType match {
219+
out match {
239220
case SQLTypes.Date =>
240221
identifier.addPainlessMethod(".toLocalDate()")
241222
case SQLTypes.Time =>

sql/src/main/scala/app/softnetwork/elastic/sql/function/time/package.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,17 @@ package object time {
6262
case _ => None
6363
}
6464

65-
private[this] var _out: SQLType = outputType
65+
//private[this] var _out: SQLType = outputType
6666

67-
override def out: SQLType = _out
67+
//override def out: SQLType = _out
6868

6969
override def applyType(in: SQLType): SQLType = {
70-
_out = interval.checkType(in).getOrElse(out)
71-
_out
70+
interval.checkType(in) match {
71+
case Left(_) => baseType
72+
case Right(_) => cast(in)
73+
}
74+
//_out = interval.checkType(in).getOrElse(out)
75+
//_out
7276
}
7377

7478
override def validate(): Either[String, Unit] = interval.checkType(out) match {

sql/src/main/scala/app/softnetwork/elastic/sql/package.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ package object sql {
5252
def in: SQLType = baseType
5353
private[this] var _out: SQLType = SQLTypes.Null
5454
def out: SQLType = if (_out == SQLTypes.Null) baseType else _out
55-
def out_=(t: SQLType): Unit = {
55+
/*def out_=(t: SQLType): Unit = {
5656
_out = t
57-
}
57+
}*/
5858
def cast(targetType: SQLType): SQLType = {
59-
this.out = targetType
59+
this._out = targetType
6060
this.out
6161
}
6262
def system: Boolean = false

sql/src/main/scala/app/softnetwork/elastic/sql/query/Where.scala

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,25 @@ sealed trait Expression extends FunctionChain with ElasticFilter with Criteria {
389389
case Some(v) => SQLTypeUtils.leastCommonSuperType(List(identifier.out, v.out))
390390
case None => identifier.out
391391
}
392+
context match {
393+
case Some(ctx) =>
394+
ctx.addParam(identifier) match {
395+
case Some(_) =>
396+
identifier.baseType match {
397+
case SQLTypes.Any => // in painless context, Any is ZonedDateTime
398+
maybeValue.map(_.out).getOrElse(SQLTypes.Any) match {
399+
case SQLTypes.Date =>
400+
identifier.addPainlessMethod(".toLocalDate()")
401+
case SQLTypes.Time =>
402+
identifier.addPainlessMethod(".toLocalTime()")
403+
case _ =>
404+
}
405+
case _ =>
406+
}
407+
case _ => // do nothing
408+
}
409+
case _ => // do nothing
410+
}
392411
SQLTypeUtils.coerce(identifier, targetedType, context)
393412
}
394413

@@ -452,21 +471,11 @@ sealed trait Expression extends FunctionChain with ElasticFilter with Criteria {
452471
}
453472

454473
override def painless(context: Option[PainlessContext]): String = {
474+
val innerLeft = left(context)
455475
context match {
456476
case Some(ctx) =>
457-
ctx.addParam(identifier) match {
477+
ctx.get(identifier) match {
458478
case Some(p) =>
459-
identifier.baseType match {
460-
case SQLTypes.Any => // in painless context, Any is ZonedDateTime
461-
maybeValue.map(_.out).getOrElse(SQLTypes.Any) match {
462-
case SQLTypes.Date =>
463-
identifier.addPainlessMethod(".toLocalDate()")
464-
case SQLTypes.Time =>
465-
identifier.addPainlessMethod(".toLocalTime()")
466-
case _ =>
467-
}
468-
case _ =>
469-
}
470479
if (identifier.nullable)
471480
return s"$p == null ? false : $painlessNot(${check(context, p)})"
472481
else
@@ -476,9 +485,9 @@ sealed trait Expression extends FunctionChain with ElasticFilter with Criteria {
476485
case _ =>
477486
}
478487
if (identifier.nullable) {
479-
return s"def left = ${left(context)}; left == null ? false : $painlessNot(${check(context, "left")})"
488+
return s"def left = $innerLeft; left == null ? false : $painlessNot(${check(context, "left")})"
480489
}
481-
s"$painlessNot${check(context, left(context))}"
490+
s"$painlessNot${check(context, innerLeft)}"
482491
}
483492

484493
override def validate(): Either[String, Unit] = {

0 commit comments

Comments
 (0)