Skip to content

Commit 6137001

Browse files
committed
add SQLFunctionChain, SQLBinaryFunction, SQLFunctionField, implements date_diff
1 parent 2bbc83e commit 6137001

File tree

13 files changed

+204
-88
lines changed

13 files changed

+204
-88
lines changed

es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import app.softnetwork.elastic.sql.{
1212
Min,
1313
SQLBucket,
1414
SQLCriteria,
15-
SQLTransformFunction,
15+
SQLFunctionUtils,
1616
SortOrder,
1717
Sum
1818
}
@@ -101,12 +101,7 @@ object ElasticAggregation {
101101
buildScript: (String, Script) => Aggregation
102102
): Aggregation = {
103103
if (transformFuncs.nonEmpty) {
104-
val base = s"doc['$sourceField'].value"
105-
val orderedTransforms = transformFuncs.reverse
106-
val scriptSrc = orderedTransforms.foldLeft(base) {
107-
case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr)
108-
case (expr, f) => f.toSQL(expr) // fallback
109-
}
104+
val scriptSrc = SQLFunctionUtils.buildPainless(Option(identifier), transformFuncs)
110105
val script = Script(scriptSrc).lang("painless")
111106
buildScript(aggName, script)
112107
} else {

sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import app.softnetwork.elastic.sql.{
1212
Min,
1313
SQLBucket,
1414
SQLCriteria,
15-
SQLTransformFunction,
15+
SQLFunctionUtils,
1616
SortOrder,
1717
Sum
1818
}
@@ -100,12 +100,7 @@ object ElasticAggregation {
100100
buildScript: (String, Script) => Aggregation
101101
): Aggregation = {
102102
if (transformFuncs.nonEmpty) {
103-
val base = s"doc['$sourceField'].value"
104-
val orderedTransforms = transformFuncs.reverse
105-
val scriptSrc = orderedTransforms.foldLeft(base) {
106-
case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr)
107-
case (expr, f) => f.toSQL(expr) // fallback
108-
}
103+
val scriptSrc = SQLFunctionUtils.buildPainless(Option(identifier), transformFuncs)
109104
val script = Script(scriptSrc).lang("painless")
110105
buildScript(aggName, script)
111106
} else {

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,62 @@ sealed trait SQLFunction extends SQLRegex {
66
def toSQL(base: String): String = s"$sql($base)"
77
}
88

9-
trait SQLTypedFunction[In <: SQLType, Out <: SQLType] extends SQLFunction {
9+
object SQLFunctionUtils {
10+
def buildPainless(functions: List[SQLFunction]): String =
11+
buildPainless(None, functions)
12+
13+
def buildPainless(
14+
painless: Option[PainlessScript] = None,
15+
functions: List[SQLFunction]
16+
): String = {
17+
val base = painless.map(_.painless).getOrElse("")
18+
val orderedFunctions = functions.reverse
19+
orderedFunctions.foldLeft(base) {
20+
case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr)
21+
case (_, f: PainlessScript) => f.painless
22+
case (expr, f) => f.toSQL(expr) // fallback
23+
}
24+
}
25+
}
26+
27+
trait SQLFunctionChain extends SQLFunction with SQLValidation {
28+
def functions: List[SQLFunction]
29+
30+
override def validate(): Either[String, Unit] =
31+
SQLValidator.validateChain(functions)
32+
33+
override def toSQL(base: String): String =
34+
functions.reverse.foldLeft(base)((expr, fun) => {
35+
fun.toSQL(expr)
36+
})
37+
38+
lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match {
39+
case Some(af: AggregateFunction) => Some(af)
40+
case _ => None
41+
}
42+
43+
lazy val aggregation: Boolean = aggregateFunction.isDefined
44+
}
45+
46+
sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType]
47+
extends SQLFunction
48+
with PainlessScript {
1049
def inputType: In
1150
def outputType: Out
12-
def from(other: SQLTypedFunction[_, _]): Boolean =
13-
inputType.typeId == other.outputType.asInstanceOf[SQLType].typeId ||
14-
(inputType.typeId == "temporal" && Set("date", "datetime").contains(
15-
other.outputType.asInstanceOf[SQLType].typeId
16-
))
17-
def to(other: SQLTypedFunction[_, _]): Boolean =
18-
outputType.typeId == other.inputType.asInstanceOf[SQLType].typeId ||
19-
(outputType.typeId == "temporal" && Set("date", "datetime").contains(
20-
other.inputType.asInstanceOf[SQLType].typeId
21-
))
2251
}
2352

24-
sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLTypedFunction[In, Out] {
25-
def toPainless(base: String): String
53+
trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType]
54+
extends SQLUnaryFunction[SQLAny, Out] { self: SQLFunction =>
55+
56+
override def inputType: SQLAny = SQLTypes.Any
57+
58+
def left: PainlessScript
59+
def right: PainlessScript
60+
61+
}
62+
63+
sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] {
64+
def toPainless(base: String): String = s"$base$painless"
2665
}
2766

2867
sealed trait ParametrizedFunction extends SQLFunction {
@@ -162,7 +201,7 @@ case class DateTrunc(unit: TimeUnit)
162201
override def inputType: SQLTemporal = SQLTypes.Temporal // par défaut
163202
override def outputType: SQLTemporal = SQLTypes.Temporal // idem
164203
override def params: Seq[String] = Seq(unit.sql)
165-
override def toPainless(base: String): String = s"$base.truncatedTo(${unit.painless})"
204+
override def painless: String = s".truncatedTo(${unit.painless})"
166205
}
167206

168207
case class Extract(unit: TimeUnit, override val sql: String = "extract")
@@ -173,7 +212,7 @@ case class Extract(unit: TimeUnit, override val sql: String = "extract")
173212
override def inputType: SQLTemporal = SQLTypes.Temporal
174213
override def outputType: SQLNumber = SQLTypes.Number
175214
override def params: Seq[String] = Seq(unit.sql)
176-
override def toPainless(base: String): String = s"$base.get(${unit.painless})"
215+
override def painless: String = s".get(${unit.painless})"
177216
}
178217

179218
object YEAR extends Extract(Year, Year.sql) {
@@ -200,6 +239,20 @@ object SECOND extends Extract(Second, Second.sql) {
200239
override def params: Seq[String] = Seq.empty
201240
}
202241

242+
case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit)
243+
extends SQLExpr("date_diff")
244+
with DateTimeFunction
245+
with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumber]
246+
with PainlessScript {
247+
override def outputType: SQLNumber = SQLTypes.Number
248+
override def left: PainlessScript = end
249+
override def right: PainlessScript = start
250+
override def toSQL(base: String): String = {
251+
s"$sql(${end.sql}, ${start.sql}, ${unit.sql})"
252+
}
253+
override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})"
254+
}
255+
203256
case class DateAdd(interval: TimeInterval)
204257
extends SQLExpr("date_add")
205258
with DateFunction
@@ -208,7 +261,7 @@ case class DateAdd(interval: TimeInterval)
208261
override def inputType: SQLDate = SQLTypes.Date
209262
override def outputType: SQLDate = SQLTypes.Date
210263
override def params: Seq[String] = Seq(interval.sql)
211-
override def toPainless(base: String): String = s"$base.plus(${interval.painless})"
264+
override def painless: String = s".plus(${interval.painless})"
212265
}
213266

214267
case class DateSub(interval: TimeInterval)
@@ -219,7 +272,7 @@ case class DateSub(interval: TimeInterval)
219272
override def inputType: SQLDate = SQLTypes.Date
220273
override def outputType: SQLDate = SQLTypes.Date
221274
override def params: Seq[String] = Seq(interval.sql)
222-
override def toPainless(base: String): String = s"$base.minus(${interval.painless})"
275+
override def painless: String = s".minus(${interval.painless})"
223276
}
224277

225278
case class ParseDate(format: String)
@@ -230,6 +283,7 @@ case class ParseDate(format: String)
230283
override def inputType: SQLString = SQLTypes.String
231284
override def outputType: SQLDate = SQLTypes.Date
232285
override def params: Seq[String] = Seq(s"'$format'")
286+
override def painless: String = throw new NotImplementedError("Use toPainless instead")
233287
override def toPainless(base: String): String =
234288
s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)"
235289
}
@@ -242,6 +296,7 @@ case class FormatDate(format: String)
242296
override def inputType: SQLDate = SQLTypes.Date
243297
override def outputType: SQLString = SQLTypes.String
244298
override def params: Seq[String] = Seq(s"'$format'")
299+
override def painless: String = throw new NotImplementedError("Use toPainless instead")
245300
override def toPainless(base: String): String =
246301
s"DateTimeFormatter.ofPattern('$format').format($base)"
247302
}
@@ -254,7 +309,7 @@ case class DateTimeAdd(interval: TimeInterval)
254309
override def inputType: SQLDateTime = SQLTypes.DateTime
255310
override def outputType: SQLDateTime = SQLTypes.DateTime
256311
override def params: Seq[String] = Seq(interval.sql)
257-
override def toPainless(base: String): String = s"$base.plus(${interval.painless})"
312+
override def painless: String = s".plus(${interval.painless})"
258313
}
259314

260315
case class DateTimeSub(interval: TimeInterval)
@@ -265,7 +320,7 @@ case class DateTimeSub(interval: TimeInterval)
265320
override def inputType: SQLDateTime = SQLTypes.DateTime
266321
override def outputType: SQLDateTime = SQLTypes.DateTime
267322
override def params: Seq[String] = Seq(interval.sql)
268-
override def toPainless(base: String): String = s"$base.minus(${interval.painless})"
323+
override def painless: String = s".minus(${interval.painless})"
269324
}
270325

271326
case class ParseDateTime(format: String)
@@ -276,6 +331,7 @@ case class ParseDateTime(format: String)
276331
override def inputType: SQLString = SQLTypes.String
277332
override def outputType: SQLDateTime = SQLTypes.DateTime
278333
override def params: Seq[String] = Seq(s"'$format'")
334+
override def painless: String = throw new NotImplementedError("Use toPainless instead")
279335
override def toPainless(base: String): String =
280336
s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDateTime::from)"
281337
}
@@ -288,6 +344,7 @@ case class FormatDateTime(format: String)
288344
override def inputType: SQLDateTime = SQLTypes.DateTime
289345
override def outputType: SQLString = SQLTypes.String
290346
override def params: Seq[String] = Seq(s"'$format'")
347+
override def painless: String = throw new NotImplementedError("Use toPainless instead")
291348
override def toPainless(base: String): String =
292349
s"DateTimeFormatter.ofPattern('$format').format($base)"
293350
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,9 @@ case class SQLFieldSort(
1212
field: String,
1313
order: Option[SortOrder],
1414
functions: List[SQLFunction] = List.empty
15-
) extends SQLTokenWithFunction {
16-
private[this] lazy val fieldWithFunction: String =
17-
functions.foldLeft(field)((expr, fun) => {
18-
fun.toSQL(expr)
19-
})
20-
15+
) extends SQLFunctionChain {
2116
lazy val direction: SortOrder = order.getOrElse(Asc)
22-
lazy val name: String = fieldWithFunction
17+
lazy val name: String = toSQL(field)
2318
override def sql: String = s"$name $direction"
2419
}
2520

sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -146,34 +146,34 @@ trait SQLParser extends RegexParsers with PackratParsers {
146146
)
147147
}
148148

149-
def date_trunc: PackratParser[SQLTypedFunction[SQLTemporal, SQLTemporal]] =
149+
def date_trunc: PackratParser[SQLUnaryFunction[SQLTemporal, SQLTemporal]] =
150150
"(?i)date_trunc".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ =>
151151
DateTrunc(u)
152152
}
153153

154-
def extract: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
154+
def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
155155
"(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ =>
156156
Extract(u)
157157
}
158158

159-
def extract_year: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
159+
def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
160160
Year.regex ^^ (_ => YEAR)
161161

162-
def extract_month: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
162+
def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
163163
Month.regex ^^ (_ => MONTH)
164164

165-
def extract_day: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY)
165+
def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY)
166166

167-
def extract_hour: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
167+
def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
168168
Hour.regex ^^ (_ => HOUR)
169169

170-
def extract_minute: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
170+
def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
171171
Minute.regex ^^ (_ => MINUTE)
172172

173-
def extract_second: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
173+
def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
174174
Second.regex ^^ (_ => SECOND)
175175

176-
def extractors: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] =
176+
def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] =
177177
extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second
178178

179179
def date_add: PackratParser[DateFunction] =
@@ -225,15 +225,35 @@ trait SQLParser extends RegexParsers with PackratParsers {
225225

226226
def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance)
227227

228+
def date_painless: PackratParser[PainlessScript] =
229+
repsep(
230+
date_trunc | extractors | date_functions | datetime_functions,
231+
start
232+
) ~ start.? ~ identifier.? ~ rep(end) ^^ { case f ~ _ ~ i ~ _ =>
233+
SQLValidator.validateChain(f) match {
234+
case Left(error) => throw SQLValidationError(error)
235+
case _ =>
236+
}
237+
i match {
238+
case Some(id) => id.copy(functions = f)
239+
case None => SQLIdentifier("", functions = f)
240+
}
241+
}
242+
243+
def date_diff: PackratParser[DateDiff] =
244+
"(?i)date_diff".r ~ start ~ (date_painless | identifier) ~ separator ~ (date_painless | identifier) ~ separator ~ time_unit ~ end ^^ {
245+
case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u)
246+
}
247+
228248
def sql_functions: PackratParser[SQLFunction] =
229-
aggregates | distance | date_trunc | extractors | date_functions | datetime_functions
249+
aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions
230250

231251
private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*"""
232252

233253
def identifierWithFunction: PackratParser[SQLIdentifier] =
234254
rep1sep(sql_functions, start) ~ start.? ~ identifier ~ rep1(end) ^^ { case f ~ _ ~ i ~ _ =>
235255
SQLValidator.validateChain(f) match {
236-
case Left(error) => throw new IllegalArgumentException(error)
256+
case Left(error) => throw SQLValidationError(error)
237257
case _ =>
238258
}
239259
i.copy(functions = f)
@@ -271,6 +291,20 @@ trait SQLParser extends RegexParsers with PackratParsers {
271291
(dateTimeWithInterval | identifierWithInterval) ~ alias.? ^^ { case d ~ a =>
272292
d.copy(fieldAlias = a)
273293
}
294+
295+
def date_diff_field: PackratParser[SQLFunctionField] = date_diff ~ alias.? ^^ { case d ~ a =>
296+
SQLFunctionField(d :: Nil, a)
297+
}
298+
299+
def functionField: PackratParser[SQLFunctionField] =
300+
rep1sep(sql_functions, start) ~ start.? ~ rep1(end) ~ alias.? ^^ { case f ~ _ ~ _ ~ a =>
301+
SQLValidator.validateChain(f) match {
302+
case Left(error) => throw SQLValidationError(error)
303+
case _ =>
304+
}
305+
SQLFunctionField(f, a)
306+
}
307+
274308
}
275309

276310
trait SQLSelectParser {
@@ -282,7 +316,10 @@ trait SQLSelectParser {
282316
}
283317

284318
def select: PackratParser[SQLSelect] =
285-
Select.regex ~ rep1sep(scriptField | field, separator) ~ except.? ^^ { case _ ~ fields ~ e =>
319+
Select.regex ~ rep1sep(
320+
date_diff_field | functionField | scriptField | field,
321+
separator
322+
) ~ except.? ^^ { case _ ~ fields ~ e =>
286323
SQLSelect(fields, e)
287324
}
288325

@@ -550,7 +587,7 @@ trait SQLWhereParser {
550587
case _ :: Nil =>
551588
processTokensHelper(rest, op :: stack)
552589
case _ =>
553-
throw new IllegalStateException("Invalid stack state for predicate creation")
590+
throw SQLValidationError("Invalid stack state for predicate creation")
554591
}
555592
case (_: EndDelimiter) :: rest =>
556593
processTokensHelper(rest, stack) // Ignore and move on
@@ -581,7 +618,7 @@ trait SQLWhereParser {
581618
*/
582619
private def processSubTokens(tokens: List[SQLToken]): SQLCriteria = {
583620
processTokensHelper(tokens, Nil).getOrElse(
584-
throw new IllegalStateException("Empty sub-expression")
621+
throw SQLValidationError("Empty sub-expression")
585622
)
586623
}
587624

@@ -604,7 +641,7 @@ trait SQLWhereParser {
604641
subTokens: List[SQLToken] = Nil
605642
): (List[SQLToken], List[SQLToken]) = {
606643
tokens match {
607-
case Nil => throw new IllegalStateException("Unbalanced parentheses")
644+
case Nil => throw SQLValidationError("Unbalanced parentheses")
608645
case (start: StartDelimiter) :: rest =>
609646
extractSubTokens(rest, openCount + 1, start :: subTokens)
610647
case (end: EndDelimiter) :: rest =>
@@ -654,7 +691,7 @@ trait SQLOrderByParser {
654691
def fieldWithFunction: PackratParser[(String, List[SQLFunction])] =
655692
rep1sep(sql_functions, start) ~ start.? ~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ =>
656693
SQLValidator.validateChain(f) match {
657-
case Left(error) => throw new IllegalArgumentException(error)
694+
case Left(error) => throw SQLValidationError(error)
658695
case _ =>
659696
}
660697
(n, f)

0 commit comments

Comments
 (0)