@@ -6,23 +6,62 @@ sealed trait SQLFunction extends SQLRegex {
66 def toSQL (base : String ): String = s " $sql( $base) "
77}
88
9- trait SQLTypedFunction [In <: SQLType , Out <: SQLType ] extends SQLFunction {
9+ object SQLFunctionUtils {
10+ def buildPainless (functions : List [SQLFunction ]): String =
11+ buildPainless(None , functions)
12+
13+ def buildPainless (
14+ painless : Option [PainlessScript ] = None ,
15+ functions : List [SQLFunction ]
16+ ): String = {
17+ val base = painless.map(_.painless).getOrElse(" " )
18+ val orderedFunctions = functions.reverse
19+ orderedFunctions.foldLeft(base) {
20+ case (expr, f : SQLTransformFunction [_, _]) => f.toPainless(expr)
21+ case (_, f : PainlessScript ) => f.painless
22+ case (expr, f) => f.toSQL(expr) // fallback
23+ }
24+ }
25+ }
26+
27+ trait SQLFunctionChain extends SQLFunction with SQLValidation {
28+ def functions : List [SQLFunction ]
29+
30+ override def validate (): Either [String , Unit ] =
31+ SQLValidator .validateChain(functions)
32+
33+ override def toSQL (base : String ): String =
34+ functions.reverse.foldLeft(base)((expr, fun) => {
35+ fun.toSQL(expr)
36+ })
37+
38+ lazy val aggregateFunction : Option [AggregateFunction ] = functions.headOption match {
39+ case Some (af : AggregateFunction ) => Some (af)
40+ case _ => None
41+ }
42+
43+ lazy val aggregation : Boolean = aggregateFunction.isDefined
44+ }
45+
46+ sealed trait SQLUnaryFunction [In <: SQLType , Out <: SQLType ]
47+ extends SQLFunction
48+ with PainlessScript {
1049 def inputType : In
1150 def outputType : Out
12- def from (other : SQLTypedFunction [_, _]): Boolean =
13- inputType.typeId == other.outputType.asInstanceOf [SQLType ].typeId ||
14- (inputType.typeId == " temporal" && Set (" date" , " datetime" ).contains(
15- other.outputType.asInstanceOf [SQLType ].typeId
16- ))
17- def to (other : SQLTypedFunction [_, _]): Boolean =
18- outputType.typeId == other.inputType.asInstanceOf [SQLType ].typeId ||
19- (outputType.typeId == " temporal" && Set (" date" , " datetime" ).contains(
20- other.inputType.asInstanceOf [SQLType ].typeId
21- ))
2251}
2352
24- sealed trait SQLTransformFunction [In <: SQLType , Out <: SQLType ] extends SQLTypedFunction [In , Out ] {
25- def toPainless (base : String ): String
53+ trait SQLBinaryFunction [In1 <: SQLType , In2 <: SQLType , Out <: SQLType ]
54+ extends SQLUnaryFunction [SQLAny , Out ] { self : SQLFunction =>
55+
56+ override def inputType : SQLAny = SQLTypes .Any
57+
58+ def left : PainlessScript
59+ def right : PainlessScript
60+
61+ }
62+
63+ sealed trait SQLTransformFunction [In <: SQLType , Out <: SQLType ] extends SQLUnaryFunction [In , Out ] {
64+ def toPainless (base : String ): String = s " $base$painless"
2665}
2766
2867sealed trait ParametrizedFunction extends SQLFunction {
@@ -162,7 +201,7 @@ case class DateTrunc(unit: TimeUnit)
162201 override def inputType : SQLTemporal = SQLTypes .Temporal // par défaut
163202 override def outputType : SQLTemporal = SQLTypes .Temporal // idem
164203 override def params : Seq [String ] = Seq (unit.sql)
165- override def toPainless ( base : String ) : String = s " $base .truncatedTo( ${unit.painless}) "
204+ override def painless : String = s " .truncatedTo( ${unit.painless}) "
166205}
167206
168207case class Extract (unit : TimeUnit , override val sql : String = " extract" )
@@ -173,7 +212,7 @@ case class Extract(unit: TimeUnit, override val sql: String = "extract")
173212 override def inputType : SQLTemporal = SQLTypes .Temporal
174213 override def outputType : SQLNumber = SQLTypes .Number
175214 override def params : Seq [String ] = Seq (unit.sql)
176- override def toPainless ( base : String ) : String = s " $base .get( ${unit.painless}) "
215+ override def painless : String = s " .get( ${unit.painless}) "
177216}
178217
179218object YEAR extends Extract (Year , Year .sql) {
@@ -200,6 +239,20 @@ object SECOND extends Extract(Second, Second.sql) {
200239 override def params : Seq [String ] = Seq .empty
201240}
202241
242+ case class DateDiff (end : PainlessScript , start : PainlessScript , unit : TimeUnit )
243+ extends SQLExpr (" date_diff" )
244+ with DateTimeFunction
245+ with SQLBinaryFunction [SQLDateTime , SQLDateTime , SQLNumber ]
246+ with PainlessScript {
247+ override def outputType : SQLNumber = SQLTypes .Number
248+ override def left : PainlessScript = end
249+ override def right : PainlessScript = start
250+ override def toSQL (base : String ): String = {
251+ s " $sql( ${end.sql}, ${start.sql}, ${unit.sql}) "
252+ }
253+ override def painless : String = s " ${unit.painless}.between( ${start.painless}, ${end.painless}) "
254+ }
255+
203256case class DateAdd (interval : TimeInterval )
204257 extends SQLExpr (" date_add" )
205258 with DateFunction
@@ -208,7 +261,7 @@ case class DateAdd(interval: TimeInterval)
208261 override def inputType : SQLDate = SQLTypes .Date
209262 override def outputType : SQLDate = SQLTypes .Date
210263 override def params : Seq [String ] = Seq (interval.sql)
211- override def toPainless ( base : String ) : String = s " $base .plus( ${interval.painless}) "
264+ override def painless : String = s " .plus( ${interval.painless}) "
212265}
213266
214267case class DateSub (interval : TimeInterval )
@@ -219,7 +272,7 @@ case class DateSub(interval: TimeInterval)
219272 override def inputType : SQLDate = SQLTypes .Date
220273 override def outputType : SQLDate = SQLTypes .Date
221274 override def params : Seq [String ] = Seq (interval.sql)
222- override def toPainless ( base : String ) : String = s " $base .minus( ${interval.painless}) "
275+ override def painless : String = s " .minus( ${interval.painless}) "
223276}
224277
225278case class ParseDate (format : String )
@@ -230,6 +283,7 @@ case class ParseDate(format: String)
230283 override def inputType : SQLString = SQLTypes .String
231284 override def outputType : SQLDate = SQLTypes .Date
232285 override def params : Seq [String ] = Seq (s " ' $format' " )
286+ override def painless : String = throw new NotImplementedError (" Use toPainless instead" )
233287 override def toPainless (base : String ): String =
234288 s " DateTimeFormatter.ofPattern(' $format').parse( $base, LocalDate::from) "
235289}
@@ -242,6 +296,7 @@ case class FormatDate(format: String)
242296 override def inputType : SQLDate = SQLTypes .Date
243297 override def outputType : SQLString = SQLTypes .String
244298 override def params : Seq [String ] = Seq (s " ' $format' " )
299+ override def painless : String = throw new NotImplementedError (" Use toPainless instead" )
245300 override def toPainless (base : String ): String =
246301 s " DateTimeFormatter.ofPattern(' $format').format( $base) "
247302}
@@ -254,7 +309,7 @@ case class DateTimeAdd(interval: TimeInterval)
254309 override def inputType : SQLDateTime = SQLTypes .DateTime
255310 override def outputType : SQLDateTime = SQLTypes .DateTime
256311 override def params : Seq [String ] = Seq (interval.sql)
257- override def toPainless ( base : String ) : String = s " $base .plus( ${interval.painless}) "
312+ override def painless : String = s " .plus( ${interval.painless}) "
258313}
259314
260315case class DateTimeSub (interval : TimeInterval )
@@ -265,7 +320,7 @@ case class DateTimeSub(interval: TimeInterval)
265320 override def inputType : SQLDateTime = SQLTypes .DateTime
266321 override def outputType : SQLDateTime = SQLTypes .DateTime
267322 override def params : Seq [String ] = Seq (interval.sql)
268- override def toPainless ( base : String ) : String = s " $base .minus( ${interval.painless}) "
323+ override def painless : String = s " .minus( ${interval.painless}) "
269324}
270325
271326case class ParseDateTime (format : String )
@@ -276,6 +331,7 @@ case class ParseDateTime(format: String)
276331 override def inputType : SQLString = SQLTypes .String
277332 override def outputType : SQLDateTime = SQLTypes .DateTime
278333 override def params : Seq [String ] = Seq (s " ' $format' " )
334+ override def painless : String = throw new NotImplementedError (" Use toPainless instead" )
279335 override def toPainless (base : String ): String =
280336 s " DateTimeFormatter.ofPattern(' $format').parse( $base, LocalDateTime::from) "
281337}
@@ -288,6 +344,7 @@ case class FormatDateTime(format: String)
288344 override def inputType : SQLDateTime = SQLTypes .DateTime
289345 override def outputType : SQLString = SQLTypes .String
290346 override def params : Seq [String ] = Seq (s " ' $format' " )
347+ override def painless : String = throw new NotImplementedError (" Use toPainless instead" )
291348 override def toPainless (base : String ): String =
292349 s " DateTimeFormatter.ofPattern(' $format').format( $base) "
293350}
0 commit comments