Skip to content

Commit 7a65f9c

Browse files
committed
add SQLLogicalFunction and corresponding criteria, implements PainlessScript for every expression, mixed SQLToken with SQLValidation, validate the sql search request after parsing
1 parent 10b7165 commit 7a65f9c

File tree

17 files changed

+498
-107
lines changed

17 files changed

+498
-107
lines changed

es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ import app.softnetwork.elastic.sql.{
1313
SQLExpression,
1414
SQLIn,
1515
SQLIsNotNull,
16-
SQLIsNull
16+
SQLIsNotNullCriteria,
17+
SQLIsNull,
18+
SQLIsNullCriteria
1719
}
1820
import com.sksamuel.elastic4s.ElasticApi._
1921
import com.sksamuel.elastic4s.searches.queries.Query
@@ -71,6 +73,8 @@ case class ElasticQuery(filter: ElasticFilter) {
7173
case geoDistance: ElasticGeoDistance => geoDistance
7274
case matchExpression: ElasticMatch => matchExpression
7375
case dateMath: SQLComparisonDateMath => dateMath
76+
case isNull: SQLIsNullCriteria => isNull
77+
case isNotNull: SQLIsNotNullCriteria => isNotNull
7478
case other =>
7579
throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}")
7680
}

es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ package object bridge {
335335
existsQuery(identifier.name)
336336
}
337337

338+
implicit def isNullCriteriaToQuery(
339+
isNull: SQLIsNullCriteria
340+
): Query = {
341+
import isNull._
342+
not(existsQuery(identifier.name))
343+
}
344+
345+
implicit def isNotNullCriteriaToQuery(
346+
isNotNull: SQLIsNotNullCriteria
347+
): Query = {
348+
import isNotNull._
349+
existsQuery(identifier.name)
350+
}
351+
338352
implicit def inToQuery[R, T <: SQLValue[R]](in: SQLIn[R, T]): Query = {
339353
import in._
340354
val _values: Seq[Any] = values.innerValues

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,4 +1475,109 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
14751475
|}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit")
14761476
}
14771477

1478+
it should "handle is_null function as script field" in {
1479+
val select: ElasticSearchRequest =
1480+
SQLQuery(isnull)
1481+
val query = select.query
1482+
println(query)
1483+
query shouldBe
1484+
"""{
1485+
| "query": {
1486+
| "match_all": {}
1487+
| },
1488+
| "script_fields": {
1489+
| "flag": {
1490+
| "script": {
1491+
| "lang": "painless",
1492+
| "source": "(doc['identifier'].value == null)"
1493+
| }
1494+
| }
1495+
| },
1496+
| "_source": true
1497+
|}""".stripMargin.replaceAll("\\s+", "").replaceAll("==", " == ")
1498+
}
1499+
1500+
it should "handle is_notnull function as script field" in {
1501+
val select: ElasticSearchRequest =
1502+
SQLQuery(isnotnull)
1503+
val query = select.query
1504+
println(query)
1505+
query shouldBe
1506+
"""{
1507+
| "query": {
1508+
| "match_all": {}
1509+
| },
1510+
| "script_fields": {
1511+
| "flag": {
1512+
| "script": {
1513+
| "lang": "painless",
1514+
| "source": "(doc['identifier2'].value != null)"
1515+
| }
1516+
| }
1517+
| },
1518+
| "_source": {
1519+
| "includes": [
1520+
| "identifier"
1521+
| ]
1522+
| }
1523+
|}""".stripMargin.replaceAll("\\s+", "").replaceAll("!=", " != ")
1524+
}
1525+
1526+
it should "handle is_null criteria as must_not exists" in {
1527+
val select: ElasticSearchRequest =
1528+
SQLQuery(isNullCriteria)
1529+
val query = select.query
1530+
println(query)
1531+
query shouldBe
1532+
"""{
1533+
| "query": {
1534+
| "bool": {
1535+
| "filter": [
1536+
| {
1537+
| "bool": {
1538+
| "must_not": [
1539+
| {
1540+
| "exists": {
1541+
| "field": "identifier"
1542+
| }
1543+
| }
1544+
| ]
1545+
| }
1546+
| }
1547+
| ]
1548+
| }
1549+
| },
1550+
| "_source": {
1551+
| "includes": [
1552+
| "*"
1553+
| ]
1554+
| }
1555+
|}""".stripMargin.replaceAll("\\s+", "")
1556+
}
1557+
1558+
it should "handle is_notnull criteria as exists" in {
1559+
val select: ElasticSearchRequest =
1560+
SQLQuery(isNotNullCriteria)
1561+
val query = select.query
1562+
println(query)
1563+
query shouldBe
1564+
"""{
1565+
| "query": {
1566+
| "bool": {
1567+
| "filter": [
1568+
| {
1569+
| "exists": {
1570+
| "field": "identifier"
1571+
| }
1572+
| }
1573+
| ]
1574+
| }
1575+
| },
1576+
| "_source": {
1577+
| "includes": [
1578+
| "*"
1579+
| ]
1580+
| }
1581+
|}""".stripMargin.replaceAll("\\s+", "")
1582+
}
14781583
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,12 @@ case class SQLFrom(tables: Seq[SQLTable]) extends Updateable {
2727
}
2828
def update(request: SQLSearchRequest): SQLFrom =
2929
this.copy(tables = tables.map(_.update(request)))
30+
31+
override def validate(): Either[String, Unit] = {
32+
if (tables.isEmpty) {
33+
Left("At least one table is required in FROM clause")
34+
} else {
35+
Right(())
36+
}
37+
}
3038
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import scala.util.matching.Regex
44

55
sealed trait SQLFunction extends SQLRegex {
66
def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql
7+
def system: Boolean = false
78
}
89

910
sealed trait SQLFunctionWithIdentifier extends SQLFunction {
@@ -26,7 +27,7 @@ object SQLFunctionUtils {
2627

2728
}
2829

29-
trait SQLFunctionChain extends SQLFunction with SQLValidation {
30+
trait SQLFunctionChain extends SQLFunction {
3031
def functions: List[SQLFunction]
3132

3233
override def validate(): Either[String, Unit] =
@@ -43,13 +44,19 @@ trait SQLFunctionChain extends SQLFunction with SQLValidation {
4344
}
4445

4546
lazy val aggregation: Boolean = aggregateFunction.isDefined
47+
48+
override def in: SQLType = functions.lastOption.map(_.in).getOrElse(super.in)
49+
50+
override def out: SQLType = functions.headOption.map(_.out).getOrElse(super.out)
4651
}
4752

4853
sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType]
4954
extends SQLFunction
5055
with PainlessScript {
5156
def inputType: In
5257
def outputType: Out
58+
override def in: SQLType = inputType
59+
override def out: SQLType = outputType
5360
}
5461

5562
sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType]
@@ -174,7 +181,7 @@ case class SQLAddInterval(interval: TimeInterval)
174181
override def script: String = s"${operator.script}${interval.script}"
175182
}
176183

177-
case class SQLSubstractInterval(interval: TimeInterval)
184+
case class SQLSubtractInterval(interval: TimeInterval)
178185
extends SQLExpr(interval.sql)
179186
with SQLArithmeticFunction[SQLDateTime, SQLDateTime]
180187
with MathScript {
@@ -191,18 +198,29 @@ sealed trait DateFunction extends DateTimeFunction
191198

192199
sealed trait TimeFunction extends DateTimeFunction
193200

194-
sealed trait CurrentDateTimeFunction extends DateTimeFunction with PainlessScript with MathScript {
201+
sealed trait SystemFunction extends SQLFunction {
202+
override def system: Boolean = true
203+
}
204+
205+
sealed trait CurrentDateTimeFunction
206+
extends DateTimeFunction
207+
with PainlessScript
208+
with MathScript
209+
with SystemFunction {
195210
override def painless: String =
196211
"ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()"
197212
override def script: String = "now"
213+
override def out: SQLType = SQLTypes.DateTime
198214
}
199215

200216
sealed trait CurrentDateFunction extends CurrentDateTimeFunction with DateFunction {
201217
override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()"
218+
override def out: SQLType = SQLTypes.Date
202219
}
203220

204221
sealed trait CurrentTimeFunction extends CurrentDateTimeFunction with TimeFunction {
205222
override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()"
223+
override def out: SQLType = SQLTypes.Time
206224
}
207225

208226
case object CurrentDate extends SQLExpr("current_date") with CurrentDateFunction
@@ -396,3 +414,27 @@ case class FormatDateTime(identifier: SQLIdentifier, format: String)
396414
override def toPainless(base: String): String =
397415
s"DateTimeFormatter.ofPattern('$format').format($base)"
398416
}
417+
418+
sealed trait SQLLogicalFunction[In <: SQLType]
419+
extends SQLTransformFunction[In, SQLBool]
420+
with SQLFunctionWithIdentifier {
421+
def operator: SQLLogicalOperator
422+
override def outputType: SQLBool = SQLTypes.Boolean
423+
override def toPainless(base: String): String = s"($base$painless)"
424+
}
425+
426+
case class SQLIsNullFunction(identifier: SQLIdentifier)
427+
extends SQLExpr("isnull")
428+
with SQLLogicalFunction[SQLAny] {
429+
override def operator: SQLLogicalOperator = IsNull
430+
override def inputType: SQLAny = SQLTypes.Any
431+
override def painless: String = s" == null"
432+
}
433+
434+
case class SQLIsNotNullFunction(identifier: SQLIdentifier)
435+
extends SQLExpr("isnotnull")
436+
with SQLLogicalFunction[SQLAny] {
437+
override def operator: SQLLogicalOperator = IsNotNull
438+
override def inputType: SQLAny = SQLTypes.Any
439+
override def painless: String = s" != null"
440+
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ case class SQLGroupBy(buckets: Seq[SQLBucket]) extends Updateable {
99
lazy val bucketNames: Map[String, SQLBucket] = buckets.map { b =>
1010
b.identifier.identifierName -> b
1111
}.toMap
12+
13+
override def validate(): Either[String, Unit] = {
14+
if (buckets.isEmpty) {
15+
Left("At least one bucket is required in GROUP BY clause")
16+
} else {
17+
Right(())
18+
}
19+
}
1220
}
1321

1422
case class SQLBucket(
@@ -95,19 +103,13 @@ object BucketSelectorScript {
95103
extractBucketsPath(left) ++ extractBucketsPath(right)
96104
case relation: ElasticRelation => extractBucketsPath(relation.criteria)
97105
case _: SQLMatch => Map.empty //MATCH is not supported in bucket_selector
98-
case b: BinaryExpression =>
99-
import b._
100-
if (left.aggregation && right.aggregation)
101-
Map(left.aliasOrName -> left.aliasOrName, right.aliasOrName -> right.aliasOrName)
102-
else if (left.aggregation)
103-
Map(left.aliasOrName -> left.aliasOrName)
104-
else if (right.aggregation)
105-
Map(right.aliasOrName -> right.aliasOrName)
106-
else
107-
Map.empty
108106
case e: Expression if e.aggregation =>
109107
import e._
110-
Map(identifier.aliasOrName -> identifier.aliasOrName)
108+
maybeValue match {
109+
case Some(v: SQLIdentifier) if v.aggregation =>
110+
Map(identifier.aliasOrName -> identifier.aliasOrName, v.aliasOrName -> v.aliasOrName)
111+
case _ => Map(identifier.aliasOrName -> identifier.aliasOrName)
112+
}
111113
case _ => Map.empty
112114
}
113115

@@ -155,18 +157,7 @@ object BucketSelectorScript {
155157
case _: SQLMatch => "1 == 1" //MATCH is not supported in bucket_selector
156158

157159
case e: Expression if e.aggregation =>
158-
val param =
159-
s"params.${e.identifier.aliasOrName}"
160-
e.maybeValue match {
161-
case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty)
162-
case None =>
163-
e.operator match {
164-
case IsNull => s"$param == null"
165-
case IsNotNull => s"$param != null"
166-
case _ =>
167-
throw new IllegalArgumentException(s"Operator ${e.operator} requires a value")
168-
}
169-
}
160+
e.painless
170161

171162
case _ => "1 == 1" //throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr")
172163
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ case class SQLHaving(criteria: Option[SQLCriteria]) extends Updateable {
99
}
1010
def update(request: SQLSearchRequest): SQLHaving =
1111
this.copy(criteria = criteria.map(_.update(request)))
12+
13+
override def validate(): Either[String, Unit] = criteria.map(_.validate()).getOrElse(Right(()))
1214
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
package app.softnetwork.elastic.sql
22

3-
trait SQLOperator extends SQLToken
3+
trait SQLOperator extends SQLToken with PainlessScript {
4+
override def painless: String = this match {
5+
case And => "&&"
6+
case Or => "||"
7+
case Not => "!"
8+
case _ => sql
9+
}
10+
}
411

512
sealed trait ArithmeticOperator extends SQLOperator with MathScript {
613
override def toString: String = s" $sql "

0 commit comments

Comments
 (0)