diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java index e1b60699d7fb2..eb6aae3a8edef 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java @@ -29,7 +29,7 @@ * This class is a wrapper class for the {@link SqlJsonArrayFunction} but using the {@code * VARCHAR_NOT_NULL} return type inference. */ -class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction { +public class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction { @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java index b09ab149a641f..b4ef34b94c572 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java @@ -29,7 +29,7 @@ * This class is a wrapper class for the {@link SqlJsonObjectFunction} but using the {@code * VARCHAR_NOT_NULL} return type inference. */ -class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction { +public class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction { @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java index 7a145ba9cce29..ddae97fa23259 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java @@ -42,7 +42,7 @@ * This class is a wrapper class for the {@link SqlJsonQueryFunction} but using the {@code * VARCHAR_FORCE_NULLABLE} return type inference. */ -class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction { +public class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction { private final SqlReturnTypeInference returnTypeInference; SqlJsonQueryFunctionWrapper() { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java index b28ef4786e47f..03e60ecdf3c9d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java @@ -35,7 +35,7 @@ * VARCHAR_FORCE_NULLABLE} return type inference by default. 
It also supports specifying return type * with the RETURNING keyword just like the original {@link SqlJsonValueFunction}. */ -class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction { +public class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction { private final SqlReturnTypeInference returnTypeInference; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java index cf389655031cd..e1ddbcbc46e6f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java @@ -32,6 +32,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator; import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.RowType; @@ -99,10 +100,11 @@ protected Transformation translateToPlanInternal( final CodeGenOperatorFactory substituteStreamOperator = CalcCodeGenerator.generateCalcOperator( ctx, - inputTransform, + (RowType) inputEdge.getOutputType(), (RowType) getOutputType(), JavaScalaConversionUtil.toScala(projection), JavaScalaConversionUtil.toScala(Optional.ofNullable(this.condition)), + ShortcutUtils.unwrapTypeFactory(planner), retainHeader, getClass().getSimpleName()); return ExecNodeUtil.createOneInputTransformation( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java index 05b07458a10ed..31cb09a545f6d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java @@ -502,7 +502,8 @@ protected StreamOperatorFactory createAsyncLookupJoin( JavaScalaConversionUtil.toScala(projectionOnTemporalTable), filterOnTemporalTable, projectionOutputRelDataType, - tableSourceRowType); + tableSourceRowType, + ShortcutUtils.unwrapTypeFactory(relBuilder)); asyncFunc = new AsyncLookupJoinWithCalcRunner( generatedFuncWithType.tableFunc(), @@ -647,7 +648,8 @@ protected ProcessFunction createSyncLookupJoinFunction( JavaScalaConversionUtil.toScala(projectionOnTemporalTable), filterOnTemporalTable, projectionOutputRelDataType, - tableSourceRowType); + tableSourceRowType, + ShortcutUtils.unwrapTypeFactory(relBuilder)); processFunc = new LookupJoinWithCalcRunner( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java index f7555e06822e4..c826fcd1ff121 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java +++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java
@@ -228,7 +228,8 @@ private DeltaJoinRuntimeTree.Node convert2RuntimeTreeInternal(
                                 node.filter,
                                 rowTypePassThroughCalc,
                                 rowTypeBeforeCalc,
-                                generatedCalcName))
+                                generatedCalcName,
+                                typeFactory))
                 .orElse(null);

         if (node instanceof BinaryInputNode) {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
index c540c543a6d0e..c1b6ffcd17947 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
@@ -574,7 +574,8 @@ private static LookupHandlerBase generateLookupHandler(
                             JavaScalaConversionUtil.toScala(projectionOnTemporalTable),
                             filterOnTemporalTable,
                             lookupSidePassThroughCalcRowType,
-                            lookupTableSourceRowType);
+                            lookupTableSourceRowType,
+                            typeFactory);
         }

         Preconditions.checkState(!generatedFetcherCollector.containsKey(lookupTableOrdinal));
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java
index 415b78efeac09..b079eddf5570a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java
@@ -40,13 +40,20 @@
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexLocalRef;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.SqlOperatorBinding;
 import org.apache.calcite.tools.RelBuilder;

 import javax.annotation.Nullable;

+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
 /**
  * Utilities for quick access of commonly used instances (like {@link FlinkTypeFactory}) without
  * long chains of getters or casting like {@code (FlinkTypeFactory)
@@ -147,14 +154,15 @@ public static DataTypeFactory unwrapDataTypeFactory(RelBuilder relBuilder) {
             return null;
         }
         final RexCall call = (RexCall) rexNode;
-        if (!(call.getOperator() instanceof BridgingSqlFunction)) {
+        final SqlOperator operator = call.getOperator();
+        if (!(operator instanceof BridgingSqlFunction)) {
             // legacy
-            if (call.getOperator() instanceof TableSqlFunction) {
-                return ((TableSqlFunction) call.getOperator()).udtf();
+            if (operator instanceof TableSqlFunction) {
+                return ((TableSqlFunction) operator).udtf();
             }
             return null;
         }
-        return ((BridgingSqlFunction) call.getOperator()).getDefinition();
+        return ((BridgingSqlFunction) operator).getDefinition();
     }

     public static @Nullable FunctionDefinition unwrapFunctionDefinition(SqlOperator operator) {
@@ -169,6 +177,41 @@ public static boolean isFunctionKind(SqlOperator operator, FunctionKind kind) {
         return functionDefinition != null && functionDefinition.getKind() == kind;
     }

+    public static boolean isDeterministicThroughProgram(
+            RexNode node, @Nullable List<RexNode> exprs) {
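+        // Behaves like RexUtil.isDeterministic, but also follows RexLocalRef operands through
+        // the program's expression list (e.g. a local ref that resolves to RAND() is reported
+        // as non-deterministic).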
+        if (exprs == null) {
+            return RexUtil.isDeterministic(node);
+        }
+        return isDeterministicThroughProgram(node, exprs, new HashSet<>());
+    }
+
+    private static boolean isDeterministicThroughProgram(
+            RexNode node, List<RexNode> exprs, Set<Integer> visited) {
+        if (node instanceof RexCall) {
+            final RexCall call = (RexCall) node;
+            if (!call.getOperator().isDeterministic()) {
+                return false;
+            }
+            for (RexNode operand : call.getOperands()) {
+                if (!isDeterministicThroughProgram(operand, exprs, visited)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        if (node instanceof RexLocalRef) {
+            final int idx = ((RexLocalRef) node).getIndex();
+            // already on the stack: skip rather than recurse forever
+            return !visited.add(idx)
+                    || isDeterministicThroughProgram(exprs.get(idx), exprs, visited);
+        }
+        if (node instanceof RexFieldAccess) {
+            return isDeterministicThroughProgram(
+                    ((RexFieldAccess) node).getReferenceExpr(), exprs, visited);
+        }
+        return true;
+    }
+
     public static @Nullable BridgingSqlFunction unwrapBridgingSqlFunction(RexCall call) {
         final SqlOperator operator = call.getOperator();
         if (operator instanceof BridgingSqlFunction) {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala
index 8072a9ca42711..966aca0abec22 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala
@@ -18,32 +18,31 @@
 package org.apache.flink.table.planner.codegen

 import org.apache.flink.api.common.functions.{FlatMapFunction, Function}
-import org.apache.flink.api.dag.Transformation
 import org.apache.flink.configuration.ReadableConfig
 import org.apache.flink.table.api.{TableException, ValidationException}
 import org.apache.flink.table.data.{BoxedWrapperRowData, RowData}
 import org.apache.flink.table.functions.FunctionKind
+import org.apache.flink.table.planner.calcite.{FlinkRexBuilder, FlinkTypeFactory}
 import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
 import org.apache.flink.table.runtime.generated.GeneratedFunction
 import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
 import org.apache.flink.table.types.logical.RowType

 import org.apache.calcite.rex._

+import scala.collection.JavaConverters._
+
 object CalcCodeGenerator {

   def generateCalcOperator(
       ctx: CodeGeneratorContext,
-      inputTransform: Transformation[RowData],
+      inputType: RowType,
       outputType: RowType,
       projection: Seq[RexNode],
       condition: Option[RexNode],
+      typeFactory: FlinkTypeFactory,
       retainHeader: Boolean = false,
       opName: String): CodeGenOperatorFactory[RowData] = {
-    val inputType = inputTransform.getOutputType
-      .asInstanceOf[InternalTypeInfo[RowData]]
-      .toRowType
     // filter out time attributes
     val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
     val processCode = generateProcessCode(
@@ -53,8 +52,13 @@ object CalcCodeGenerator {
       classOf[BoxedWrapperRowData],
       projection,
       condition,
+      typeFactory,
+      inputTerm,
+      CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM,
       eagerInputUnboxingCode = true,
-      retainHeader = retainHeader)
+      retainHeader = retainHeader,
+      outputDirectly = false
+    )

     val genOperator =
       OperatorCodeGenerator.generateOneInputStreamOperator[RowData, RowData](
@@ -76,7 +80,8 @@
calcProjection: Seq[RexNode], calcCondition: Option[RexNode], tableConfig: ReadableConfig, - classLoader: ClassLoader): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + classLoader: ClassLoader, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { val ctx = new CodeGeneratorContext(tableConfig, classLoader) val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM val collectorTerm = CodeGenUtils.DEFAULT_COLLECTOR_TERM @@ -87,6 +92,8 @@ object CalcCodeGenerator { outRowClass, calcProjection, calcCondition, + typeFactory, + inputTerm, collectorTerm = collectorTerm, eagerInputUnboxingCode = false, outputDirectly = true @@ -110,6 +117,7 @@ object CalcCodeGenerator { outRowClass: Class[_ <: RowData], projection: Seq[RexNode], condition: Option[RexNode], + typeFactory: FlinkTypeFactory, inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM, collectorTerm: String = CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM, eagerInputUnboxingCode: Boolean, @@ -121,7 +129,9 @@ object CalcCodeGenerator { projection.foreach(_.accept(ScalarFunctionsValidator)) condition.foreach(_.accept(ScalarFunctionsValidator)) - val exprGenerator = new ExprCodeGenerator(ctx, false) + val rexProgram = buildRexProgram(typeFactory, inputType, projection, condition) + + val exprGenerator = new ExprCodeGenerator(ctx, false, rexProgram) .bindInput(inputType, inputTerm = inputTerm) val onlyFilter = projection.lengthCompare(inputType.getFieldCount) == 0 && @@ -137,6 +147,8 @@ object CalcCodeGenerator { } def produceProjectionCode: String = { + val projection = rexProgram.getProjectList.asScala + val projectionExprs = projection.map(exprGenerator.generateExpression) val projectionExpression = exprGenerator.generateResultExpression(projectionExprs, outRowType, outRowClass) @@ -162,16 +174,20 @@ object CalcCodeGenerator { "It should be removed by CalcRemoveRule.") } else if (condition.isEmpty) { // only projection val projectionCode = produceProjectionCode + val localRefCode = ctx.reuseLocalRefCode() s""" |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""} + |$localRefCode |$projectionCode |""".stripMargin } else { - val filterCondition = exprGenerator.generateExpression(condition.get) + val filterCondition = exprGenerator.generateExpression(rexProgram.getCondition) // only filter if (onlyFilter) { + val localRefCode = ctx.reuseLocalRefCode() s""" |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""} + |$localRefCode |${filterCondition.code} |if (${filterCondition.resultTerm}) { | ${produceOutputCode(inputTerm)} @@ -181,19 +197,35 @@ object CalcCodeGenerator { val filterInputCode = ctx.reuseInputUnboxingCode() val filterInputSet = Set(ctx.reusableInputUnboxingExprs.keySet.toSeq: _*) + val filterLocalRefSet: Set[Int] = ctx.getReusableLocalRefExprBottomScope.keySet.toSet + // if any filter conditions, projection code will enter an new scope val projectionCode = produceProjectionCode val projectionInputCode = ctx.reusableInputUnboxingExprs - .filter(entry => !filterInputSet.contains(entry._1)) + .filter { case (k, _) => !filterInputSet.contains(k) } + .values + .map(_.code) + .mkString("\n") + + val filterLocalRefCode = ctx.getReusableLocalRefExprBottomScope + .filter { case (k, _) => filterLocalRefSet.contains(k) } .values .map(_.code) .mkString("\n") + val projectionLocalRefCode = ctx.getReusableLocalRefExprBottomScope + .filter { case (k, _) => !filterLocalRefSet.contains(k) } + .values + .map(_.code) + .mkString("\n") + s""" |${if (eagerInputUnboxingCode) 
filterInputCode else ""} + |$filterLocalRefCode |${filterCondition.code} |if (${filterCondition.resultTerm}) { - | ${if (eagerInputUnboxingCode) projectionInputCode else ""} + | ${if (eagerInputUnboxingCode) projectionInputCode else ""} + | $projectionLocalRefCode | $projectionCode |} |""".stripMargin @@ -201,6 +233,22 @@ object CalcCodeGenerator { } } + private def buildRexProgram( + typeFactory: FlinkTypeFactory, + inputType: RowType, + projection: Seq[RexNode], + condition: Option[RexNode] + ): RexProgram = { + val rexBuilder = new FlinkRexBuilder(typeFactory) + val relInputType = typeFactory.createFieldTypeFromLogicalType(inputType) + val builder = new RexProgramBuilder(relInputType, rexBuilder) + projection.foreach(p => builder.addProject(p, null)) + if (condition.isDefined) { + builder.addCondition(condition.get) + } + builder.getProgram + } + private object ScalarFunctionsValidator extends RexVisitorImpl[Unit](true) { override def visitCall(call: RexCall): Unit = { super.visitCall(call) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala index 811fd1a842087..ff924fe0f30f8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala @@ -46,6 +46,8 @@ import org.apache.flink.types.{ColumnList, Row, RowKind} import org.apache.flink.types.bitmap.Bitmap import org.apache.flink.types.variant.Variant +import org.apache.calcite.rex.{RexNode, RexProgram} + import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Object => JObject, Short => JShort} import java.lang.reflect.Method import java.util.concurrent.atomic.AtomicLong @@ -1120,4 +1122,12 @@ object CodeGenUtils { GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, index) } } + + def getExprsFromProgramOrNull(rexProgram: RexProgram): java.util.List[RexNode] = { + if (rexProgram == null) { + null + } else { + rexProgram.getExprList + } + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index 02706cc130909..fade7279606a1 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -17,7 +17,6 @@ */ package org.apache.flink.table.planner.codegen -import org.apache.flink.api.common.functions.Function import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.configuration.ReadableConfig import org.apache.flink.table.data.GenericRowData @@ -116,6 +115,36 @@ class CodeGeneratorContext( val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] = mutable.Map[(String, Int), GeneratedExpression]() + // Stack of RexLocalRef cache scopes (`exprList-index -> generated body`). + // * Bottom scope == getReusableLocalRefExprBottomScope: bodies are hoisted to the top of the method + // and run unconditionally for every row. 
+ // * Inner scopes (push/popLocalRefScope): bodies are folded into a single guarded + // operand's code by ExprCodeGenerator.visitOperandInScopedCache and run only when + // the guard fires. Inserts always target the innermost scope; lookup walks innermost-out. + // + // Example — `CASE WHEN b <> 0 THEN a / b ELSE NULL`: + // + // With scoping (correct): + // boolean cmp = b != 0; + // if (cmp) { + // int div = a / b; // emitted inside the guarded scope + // result = div; + // } else { + // result = null; + // } + // + // Without scoping (buggy): + // int div = a / b; // throws ArithmeticException when b == 0 + // boolean cmp = b != 0; + // if (cmp) { result = div; } + // else { result = null; } + // + // The set of operand positions that get scoped lives in + // ExprCodeGenerator.conditionalOperandIndices — extend it when adding new short-circuit + // operators. + private val localRefScopes = + mutable.ArrayBuffer(mutable.LinkedHashMap.empty[Int, GeneratedExpression]) + // set of constructor statements that will be added only once // we use a LinkedHashSet to keep the insertion order private val reusableConstructorStatements: mutable.LinkedHashSet[(String, String)] = @@ -783,6 +812,7 @@ class CodeGeneratorContext( /** * Adds a reusable Object to the member area of the generated class + * * @param obj * the object to be added to the generated class * @param fieldNamePrefix @@ -1075,4 +1105,61 @@ class CodeGeneratorContext( fieldTerm } + + // --------------------------------------------------------------------------------- + // Reusable local ref code with scope + // --------------------------------------------------------------------------------- + + // Bottom scope of localRefScopes: holds unconditionally evaluated local refs. + def getReusableLocalRefExprBottomScope: mutable.LinkedHashMap[Int, GeneratedExpression] = + localRefScopes(0) + + /** + * Adds a reusable [[org.apache.calcite.rex.RexLocalRef]] expression keyed by its index in the + * program's exprList. The expression is stored in the innermost active scope. + */ + def addReusableLocalRefExpr(index: Int, expr: GeneratedExpression): Unit = + localRefScopes.last(index) = expr + + /** + * Looks up a previously cached [[org.apache.calcite.rex.RexLocalRef]] expression by its exprList + * index. Scopes are searched innermost-out so that a body cached inside a guarded scope takes + * precedence over an outer entry. + */ + def getReusableLocalRefExpr(index: Int): Option[GeneratedExpression] = { + // Search innermost-out: a body cached in an inner (guarded) scope wins over outer + // entries. In practice the cache is monotone — an entry never appears in two scopes + // simultaneously. + var i = localRefScopes.size - 1 + while (i >= 0) { + val maybe = localRefScopes(i).get(index) + if (maybe.isDefined) return maybe + i -= 1 + } + None + } + + /** + * Returns the generated code for all unconditionally-evaluated local-ref expressions (bottom + * scope), concatenated in insertion order. + */ + def reuseLocalRefCode(): String = { + getReusableLocalRefExprBottomScope.values.map(_.code).mkString("\n") + } + + /** Pushes a new, empty local-ref cache scope onto the scope stack. */ + def pushLocalRefScope(): Unit = { + localRefScopes.append(mutable.LinkedHashMap.empty) + } + + /** + * Pops the innermost local-ref cache scope and returns its entries. The bottom scope + * ([[getReusableLocalRefExprBottomScope]]) cannot be popped. 
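+   * Every call is paired with a preceding [[pushLocalRefScope]]; see
+   * ExprCodeGenerator#visitOperandInScopedCache, which folds the popped bodies into the
+   * guarded operand's code.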
+ */ + def popLocalRefScope(): scala.collection.Map[Int, GeneratedExpression] = { + require( + localRefScopes.size > 1, + "Cannot pop the bottom RexLocalRef cache scope (reusableLocalRefExprs).") + localRefScopes.remove(localRefScopes.size - 1) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index 7154aa09f0a6e..1de3ca9bee880 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -37,7 +37,8 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._ import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction} -import org.apache.flink.table.planner.plan.utils.RexLiteralUtil +import org.apache.flink.table.planner.plan.utils.{FlinkRexUtil, RexLiteralUtil} +import org.apache.flink.table.planner.utils.ShortcutUtils import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable import org.apache.flink.table.runtime.typeutils.TypeCheckUtils @@ -55,9 +56,15 @@ import scala.collection.JavaConversions._ * This code generator is mainly responsible for generating codes for a given calcite [[RexNode]]. * It can also generate type conversion codes for the result converter. */ -class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) +class ExprCodeGenerator( + ctx: CodeGeneratorContext, + nullableInput: Boolean, + val rexProgram: RexProgram) extends RexVisitor[GeneratedExpression] { + def this(ctx: CodeGeneratorContext, nullableInput: Boolean) = + this(ctx, nullableInput, null) + /** term of the [[ProcessFunction]]'s context, can be changed when needed */ var contextTerm = "ctx" @@ -344,7 +351,6 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = { - // for specific custom code generation if (input1Type == null) { return GeneratedExpression( inputRef.getName, @@ -416,8 +422,53 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type) } - override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = - throw new CodeGenException("RexLocalRef are not supported yet.") + override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = { + // addReusableLocalVariable + // for specific custom code generation + if (input1Type == null) { + return GeneratedExpression( + localRef.getName, + localRef.getName + "IsNull", + NO_CODE, + FlinkTypeFactory.toLogicalType(localRef.getType)) + } + // for the general cases with a previous call to bindInput() + val input1Arity = input1Type match { + case r: RowType => r.getFieldCount + case _ => 1 + } + if (localRef.getIndex >= input1Arity) { + if (rexProgram == null) { + throw new CodeGenException(s"RexLocalRef(${localRef.getIndex}) requires a RexProgram.") + } + val idx = localRef.getIndex + val target = rexProgram.getExprList.get(idx) + if 
(!isDeterministicThroughProgram(target)) { + return target.accept(this) + } + val full = ctx.getReusableLocalRefExpr(idx) match { + case Some(cached) => cached + case None => + val expr = target.accept(this) + ctx.addReusableLocalRefExpr(idx, expr) + expr + } + return GeneratedExpression( + full.resultTerm, + full.nullTerm, + NO_CODE, + full.resultType, + full.literalValue) + } + + generateInputAccess( + ctx, + input1Type, + input1Term, + localRef.getIndex, + nullableInput, + deepCopy = true) + } def visitRexFieldVariable(variable: RexFieldVariable): GeneratedExpression = { val internalType = FlinkTypeFactory.toLogicalType(variable.dataType) @@ -462,7 +513,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) val resultType = FlinkTypeFactory.toLogicalType(call.getType) // throw exception if json function is called outside JSON_OBJECT or JSON_ARRAY function - if (isJsonFunctionOperand(call)) { + if (isJsonFunctionOperand(call, CodeGenUtils.getExprsFromProgramOrNull(rexProgram))) { throw new ValidationException( "The JSON() function is currently only supported inside JSON_ARRAY() or as the VALUE param" + " of JSON_OBJECT(). Example: JSON_OBJECT('a', JSON('{\"key\": \"value\"}')) or " + @@ -473,10 +524,12 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) return generateSearch( ctx, generateExpression(call.getOperands.get(0)), - call.getOperands.get(1).asInstanceOf[RexLiteral]) + rexProgram, + call.getOperands) } // convert operands and help giving untyped NULL literals a type + val condIdxs = conditionalOperandIndices(call) val operands = call.getOperands.zipWithIndex.map { // this helps e.g. for AS(null) @@ -487,15 +540,55 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) generateNullLiteral(resultType) // We only support the JSON function inside of JSON_OBJECT or JSON_ARRAY - case (operand: RexNode, i) if isSupportedJsonOperand(operand, call, i) => + case (operand: RexNode, i) + if isSupportedJsonOperand( + operand, + call, + i, + CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) => generateJsonCall(operand) + case (o @ _, i) if condIdxs.contains(i) => visitOperandInScopedCache(o) + case (o @ _, _) => o.accept(this) } generateCallExpression(ctx, call, operands, resultType) } + /** + * Indices of `call`'s operands that are NOT unconditionally evaluated at runtime. Used to scope + * the RexLocalRef cache so that bodies cached while visiting these operands are not hoisted out + * of the surrounding short-circuit / if-block. + * + * - `CASE(when_1, then_1, when_2, then_2, ..., else)`: only `when_1` is unconditional. + * - `AND(a_0, a_1, ..., a_n)` / `OR(...)`: only `a_0` is unconditional; subsequent operands are + * short-circuited by the operator semantics and the codegen. 
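+   * - `COALESCE(a_0, a_1, ..., a_n)`: only `a_0` is unconditional; each later operand is
+   *   evaluated only if all previous operands returned NULL.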
+ */ + private def conditionalOperandIndices(call: RexCall): Set[Int] = call.getKind match { + case SqlKind.CASE | SqlKind.AND | SqlKind.OR | SqlKind.COALESCE => + (1 until call.getOperands.size).toSet + case _ => Set.empty + } + + private def visitOperandInScopedCache(operand: RexNode): GeneratedExpression = { + ctx.pushLocalRefScope() + val (operandExpr, scopedBodies) = { + val expr = operand.accept(this) + val popped = ctx.popLocalRefScope() + (expr, popped.values.map(_.code).mkString("\n")) + } + if (scopedBodies.isEmpty) { + operandExpr + } else + GeneratedExpression( + operandExpr.resultTerm, + operandExpr.nullTerm, + scopedBodies + "\n" + operandExpr.code, + operandExpr.resultType, + operandExpr.literalValue) + } + override def visitOver(over: RexOver): GeneratedExpression = throw new CodeGenException("Aggregate functions over windows are not supported yet.") @@ -786,9 +879,11 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) case JSON_QUERY => new JsonQueryCallGen().generate(ctx, operands, resultType) - case JSON_OBJECT => new JsonObjectCallGen(call).generate(ctx, operands, resultType) + case JSON_OBJECT => + new JsonObjectCallGen(call, rexProgram).generate(ctx, operands, resultType) - case JSON_ARRAY => new JsonArrayCallGen(call).generate(ctx, operands, resultType) + case JSON_ARRAY => + new JsonArrayCallGen(call, rexProgram).generate(ctx, operands, resultType) case _: SqlThrowExceptionFunction => val nullValue = generateNullLiteral(resultType) @@ -827,7 +922,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) generateGreatestLeast(ctx, resultType, operands, greatest = false) case BuiltInFunctionDefinitions.JSON_STRING => - new JsonStringCallGen(call).generate(ctx, operands, resultType) + new JsonStringCallGen(call, rexProgram).generate(ctx, operands, resultType) case BuiltInFunctionDefinitions.INTERNAL_HASHCODE => new HashCodeCallGen().generate(ctx, operands, resultType) @@ -847,7 +942,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) new JsonCallGen().generate(ctx, operands, FlinkTypeFactory.toLogicalType(call.getType)) case _ => - new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType) + new BridgingSqlFunctionCallGen(call, rexProgram).generate(ctx, operands, resultType) } // advanced scalar functions @@ -875,7 +970,14 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } private def generateJsonCall(operand: RexNode) = { - val jsonCall = operand.asInstanceOf[RexCall] + // After unification of projections + condition into a single RexProgram, structurally + // identical sub-expressions are collapsed into one exprList entry referenced via + // RexLocalRef. JSON_OBJECT/JSON_ARRAY operands recognised as JSON via + // isSupportedJsonOperand may therefore arrive here as a RexLocalRef; resolve it back to + // the underlying RexCall before casting. 
+ val jsonCall = FlinkRexUtil + .expandLocalRef(operand, CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) + .asInstanceOf[RexCall] val jsonOperands = jsonCall.getOperands.map(_.accept(this)) generateCallExpression( ctx, @@ -896,4 +998,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } }.toArray } + + private def isDeterministicThroughProgram(node: RexNode): Boolean = + ShortcutUtils.isDeterministicThroughProgram( + node, + CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala index 567554e6e1c84..b1c70fb9c5260 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala @@ -139,8 +139,8 @@ class ExpressionReducer( } else unreduced match { case call: RexCall - if (nonReducibleJsonFunctions.contains(call.getOperator) || isJsonFunctionOperand( - call)) => + if (nonReducibleJsonFunctions.contains(call.getOperator) + || isJsonFunctionOperand(call, null)) => reducedValues.add(unreduced) case _ => unreduced.getType.getSqlTypeName match { @@ -297,7 +297,10 @@ class ExpressionReducer( } // Exclude some JSON functions which behave differently // when called as an argument of another call of one of these functions. - if (nonReducibleJsonFunctions.contains(call.getOperator) || isJsonFunctionOperand(call)) { + if ( + nonReducibleJsonFunctions.contains(call.getOperator) + || isJsonFunctionOperand(call, null) + ) { None } else { Some(call) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala index 817d399bacc8e..3131de4f8f909 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala @@ -18,12 +18,11 @@ package org.apache.flink.table.planner.codegen import org.apache.flink.api.common.functions._ -import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.functions.async.{AsyncFunction, RichAsyncFunction} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.Indenter.toISC -import org.apache.flink.table.runtime.generated.{FilterCondition, GeneratedFilterCondition, GeneratedFunction, GeneratedJoinCondition, JoinCondition} +import org.apache.flink.table.runtime.generated._ import org.apache.flink.table.types.logical.LogicalType /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala index 3fd256e071d4b..a64a97727ccd8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala @@ -21,11 +21,12 @@ 
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, ObjectNode} import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawValue import org.apache.flink.table.api.{DataTypes, JsonOnNull} -import org.apache.flink.table.functions.BuiltInFunctionDefinitions.JSON +import org.apache.flink.table.functions.{BuiltInFunctionDefinitions, FunctionDefinition} import org.apache.flink.table.planner.codegen.CodeGenUtils._ -import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable.{JSON_ARRAY, JSON_OBJECT} +import org.apache.flink.table.planner.functions.sql.{SqlJsonArrayFunctionWrapper, SqlJsonObjectFunctionWrapper, SqlJsonQueryFunctionWrapper, SqlJsonValueFunctionWrapper} +import org.apache.flink.table.planner.plan.utils.FlinkRexUtil import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala -import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapFunctionDefinition +import org.apache.flink.table.planner.utils.ShortcutUtils import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isCharacterString import org.apache.flink.table.types.logical._ @@ -51,8 +52,9 @@ object JsonGenerateUtils { def createNodeTerm( ctx: CodeGeneratorContext, expression: GeneratedExpression, - operand: RexNode): String = { - if (isJsonObjectOrArrayOperand(operand) || isJsonFunctionOperand(operand)) { + operand: RexNode, + exprs: java.util.List[RexNode]): String = { + if (isJsonObjectOrArrayOperand(operand, exprs) || isJsonFunctionOperand(operand, exprs)) { createRawNodeTerm(expression) } else { createNodeTerm(ctx, expression) @@ -177,59 +179,36 @@ object JsonGenerateUtils { } } - /** Determines whether the given operand is a call to a JSON_OBJECT */ - def isJsonObjectOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - rexCall.getOperator match { - case JSON_OBJECT => true - case _ => false - } - case _ => false - } - } + /** Determines whether the given operand is a call to a JSON_OBJECT. */ + def isJsonObjectOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = + isOneOfFunctionDefinitions( + FlinkRexUtil.expandLocalRef(operand, localRefs), + BuiltInFunctionDefinitions.JSON_OBJECT) - /** Determines whether the given operand is a call to a JSON_ARRAY */ - def isJsonArrayOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - rexCall.getOperator match { - case JSON_ARRAY => true - case _ => false - } - case _ => false - } - } + /** Determines whether the given operand is a call to a JSON_ARRAY. */ + def isJsonArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = + isOneOfFunctionDefinitions( + FlinkRexUtil.expandLocalRef(operand, localRefs), + BuiltInFunctionDefinitions.JSON_ARRAY) /** * Determines whether the given operand is a call to a JSON_OBJECT or JSON_ARRAY whose result * should be inserted as a raw value instead of as a character string. 
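+   * The operand may also be a [[org.apache.calcite.rex.RexLocalRef]]; it is expanded through
+   * `localRefs` before the check.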
*/ - def isJsonObjectOrArrayOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - rexCall.getOperator match { - case JSON_OBJECT | JSON_ARRAY => true - case _ => false - } - case _ => false - } - } + def isJsonObjectOrArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = + isOneOfFunctionDefinitions( + FlinkRexUtil.expandLocalRef(operand, localRefs), + BuiltInFunctionDefinitions.JSON_OBJECT, + BuiltInFunctionDefinitions.JSON_ARRAY) /** * Determines whether the given operand is a call to JSON function whose call currently just - * passes through the input value as output value + * passes through the input value as output value. */ - def isJsonFunctionOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - unwrapFunctionDefinition(rexCall) match { - case JSON => true - case _ => false - } - case _ => false - } - } + def isJsonFunctionOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = + isOneOfFunctionDefinitions( + FlinkRexUtil.expandLocalRef(operand, localRefs), + BuiltInFunctionDefinitions.JSON) /** * Determines whether a JSON function is allowed in the current context. JSON functions are @@ -237,9 +216,13 @@ object JsonGenerateUtils { * of a JSON_OBJECT call, we do (i % 2) == 0 to check if it's being used in second parameter, the * values' parameter. */ - def isSupportedJsonOperand(operand: RexNode, call: RexNode, i: Int): Boolean = { - isJsonFunctionOperand(operand) && - (isJsonArrayOperand(call) || isJsonObjectOperand(call) && (i % 2) == 0) + def isSupportedJsonOperand( + operand: RexNode, + call: RexNode, + i: Int, + localRefs: java.util.List[RexNode]): Boolean = { + isJsonFunctionOperand(operand, localRefs) && + (isJsonArrayOperand(call, localRefs) || isJsonObjectOperand(call, localRefs) && (i % 2) == 0) } /** Generates a method to convert arrays into [[ArrayNode]]. 
*/ @@ -331,4 +314,23 @@ object JsonGenerateUtils { ctx.addReusableMember(methodCode) methodName } + + def isOneOfFunctionDefinitions( + rexNode: RexNode, + expectedDefinitions: FunctionDefinition*): Boolean = { + if (!rexNode.isInstanceOf[RexCall]) return false + val call = rexNode.asInstanceOf[RexCall] + val unwrapped = ShortcutUtils.unwrapFunctionDefinition(call) match { + case d if d != null => d + case _ => + call.getOperator match { + case _: SqlJsonArrayFunctionWrapper => BuiltInFunctionDefinitions.JSON_ARRAY + case _: SqlJsonObjectFunctionWrapper => BuiltInFunctionDefinitions.JSON_OBJECT + case _: SqlJsonQueryFunctionWrapper => BuiltInFunctionDefinitions.JSON_QUERY + case _: SqlJsonValueFunctionWrapper => BuiltInFunctionDefinitions.JSON_VALUE + case _ => return false + } + } + expectedDefinitions.exists(_ eq unwrapped) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala index 4256c90b5960d..bae94282d1da7 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.planner.codegen import org.apache.flink.api.common.functions.DefaultOpenContext -import org.apache.flink.configuration.{Configuration, ReadableConfig} +import org.apache.flink.configuration.ReadableConfig import org.apache.flink.metrics.Gauge import org.apache.flink.table.data.{RowData, TimestampData} import org.apache.flink.table.data.utils.JoinedRowData diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala index db158572df9e9..e011dd2cfcec6 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala @@ -354,14 +354,16 @@ object LookupJoinCodeGenerator { projection: Seq[RexNode], condition: RexNode, outputType: RelDataType, - tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + tableSourceRowType: RowType, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { generateCalcMapFunction( tableConfig, classLoader, projection, condition, FlinkTypeFactory.toLogicalRowType(outputType), - tableSourceRowType + tableSourceRowType, + typeFactory ) } @@ -375,7 +377,8 @@ object LookupJoinCodeGenerator { projection: Seq[RexNode], condition: RexNode, outputType: RowType, - tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + tableSourceRowType: RowType, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { generateCalcMapFunction( tableConfig, classLoader, @@ -383,7 +386,8 @@ object LookupJoinCodeGenerator { condition, outputType, tableSourceRowType, - "TableCalcMapFunction") + "TableCalcMapFunction", + typeFactory) } /** @@ -397,7 +401,8 @@ object LookupJoinCodeGenerator { condition: RexNode, outputType: RowType, tableSourceRowType: RowType, - name: String): 
GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + name: String, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { CalcCodeGenerator.generateFunction( tableSourceRowType, name, @@ -406,7 +411,8 @@ object LookupJoinCodeGenerator { projection, Option(condition), tableConfig, - classLoader + classLoader, + typeFactory ) } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala index cb241c8feb4dd..2115e76304d96 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.api.common.functions.{AbstractRichFunction, OpenContext, RichFunction} -import org.apache.flink.configuration.{Configuration, ReadableConfig} +import org.apache.flink.configuration.ReadableConfig import org.apache.flink.table.api.{DataTypes, TableException} import org.apache.flink.table.api.Expressions.callSql import org.apache.flink.table.data.{GenericRowData, RawValueData, StringData} @@ -27,9 +27,10 @@ import org.apache.flink.table.expressions.ApiExpressionUtils.{typeLiteral, unres import org.apache.flink.table.expressions.Expression import org.apache.flink.table.functions._ import org.apache.flink.table.functions.SpecializedFunction.{ExpressionEvaluator, ExpressionEvaluatorFactory} -import org.apache.flink.table.functions.UserDefinedFunctionHelper.{validateClassForRuntime, ASYNC_SCALAR_EVAL, ASYNC_TABLE_EVAL, SCALAR_EVAL, TABLE_EVAL} +import org.apache.flink.table.functions.UserDefinedFunctionHelper._ import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexFactory} import org.apache.flink.table.planner.codegen._ +import org.apache.flink.table.planner.codegen.AsyncCodeGenerator.DEFAULT_DELEGATING_FUTURE_TERM import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala @@ -48,8 +49,6 @@ import org.apache.flink.table.types.utils.DataTypeUtils import org.apache.flink.table.types.utils.DataTypeUtils.{isInternal, validateInputDataType, validateOutputDataType} import org.apache.flink.util.Preconditions -import AsyncCodeGenerator.{generateFunction, DEFAULT_DELEGATING_FUTURE_TERM} - import java.util.concurrent.CompletableFuture import scala.collection.JavaConverters._ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala index cada84a1bca8b..83223e63200d9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala @@ -26,7 +26,7 @@ import org.apache.flink.table.planner.functions.inference.OperatorBindingCallCon import org.apache.flink.table.runtime.collector.WrappingCollector import 
org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.{RexCall, RexCallBinding} +import org.apache.calcite.rex.{RexCall, RexCallBinding, RexProgram} import java.util.Collections @@ -37,7 +37,7 @@ import java.util.Collections * generator will be a reference to a [[WrappingCollector]]. Furthermore, atomic types are wrapped * into a row by the collector. */ -class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator { +class BridgingSqlFunctionCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { override def generate( ctx: CodeGeneratorContext, @@ -54,7 +54,7 @@ class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator { val callContext = new OperatorBindingCallContext( dataTypeFactory, definition, - RexCallBinding.create(function.getTypeFactory, call, Collections.emptyList()), + RexCallBinding.create(function.getTypeFactory, call, rexProgram, Collections.emptyList()), call.getType) // create the final UDF for runtime diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala index ea324ec6082b5..524e1cb763587 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala @@ -19,16 +19,17 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, NullNode} import org.apache.flink.table.api.JsonOnNull -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, newName, primitiveTypeTermForType, BINARY_STRING} import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, getOnNullBehavior} import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.RexCall +import org.apache.calcite.rex.{RexCall, RexProgram} /** [[CallGenerator]] for `JSON_ARRAY`. 
*/ -class JsonArrayCallGen(call: RexCall) extends CallGenerator { +class JsonArrayCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + private def jsonUtils = className[SqlJsonUtils] override def generate( @@ -47,7 +48,9 @@ class JsonArrayCallGen(call: RexCall) extends CallGenerator { .drop(1) .map { case (elementExpr, elementIdx) => - val elementTerm = createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx)) + val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram) + val elementTerm = + createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx), exprs) onNull match { case JsonOnNull.NULL => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala index 9a5c87fb06f79..37d126776c451 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala @@ -19,13 +19,13 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{NullNode, ObjectNode} import org.apache.flink.table.api.JsonOnNull -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, getOnNullBehavior} import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.RexCall +import org.apache.calcite.rex.{RexCall, RexProgram} /** * [[CallGenerator]] for `JSON_OBJECT`. @@ -37,7 +37,8 @@ import org.apache.calcite.rex.RexCall * We remedy this by treating nested calls to this function differently and inserting the value as a * raw node instead of as a string node. 
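+ * The enclosing [[RexProgram]] (if any) is used to resolve [[org.apache.calcite.rex.RexLocalRef]]
+ * value operands before deciding whether they are such nested calls.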
*/ -class JsonObjectCallGen(call: RexCall) extends CallGenerator { +class JsonObjectCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + private def jsonUtils = className[SqlJsonUtils] override def generate( @@ -57,7 +58,8 @@ class JsonObjectCallGen(call: RexCall) extends CallGenerator { .grouped(2) .map { case Seq((keyExpr, _), (valueExpr, valueIdx)) => - val valueTerm = createNodeTerm(ctx, valueExpr, call.operands.get(valueIdx)) + val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram) + val valueTerm = createNodeTerm(ctx, valueExpr, call.operands.get(valueIdx), exprs) onNull match { case JsonOnNull.NULL => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala index 1265d11a2787b..3655943327bd9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala @@ -17,16 +17,17 @@ */ package org.apache.flink.table.planner.codegen.calls -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, newName, primitiveTypeTermForType, BINARY_STRING} import org.apache.flink.table.planner.codegen.JsonGenerateUtils.createNodeTerm import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.RexCall +import org.apache.calcite.rex.{RexCall, RexProgram} /** [[CallGenerator]] for `JSON_STRING`. 
*/ -class JsonStringCallGen(call: RexCall) extends CallGenerator { +class JsonStringCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + private def jsonUtils = className[SqlJsonUtils] override def generate( @@ -34,7 +35,8 @@ class JsonStringCallGen(call: RexCall) extends CallGenerator { operands: Seq[GeneratedExpression], returnType: LogicalType): GeneratedExpression = { - val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0)) + val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram) + val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0), exprs) val resultTerm = newName(ctx, "result") val resultTermType = primitiveTypeTermForType(returnType) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala index de55d1f8c308b..0f3e36e64175f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala @@ -27,7 +27,7 @@ import org.apache.flink.table.planner.plan.utils.RexLiteralUtil.toFlinkInternalV import org.apache.flink.table.types.logical.{BooleanType, LogicalType} import org.apache.flink.table.types.logical.utils.LogicalTypeMerging.findCommonType -import org.apache.calcite.rex.{RexLiteral, RexUnknownAs} +import org.apache.calcite.rex.{RexLiteral, RexLocalRef, RexNode, RexProgram, RexUnknownAs} import org.apache.calcite.util.{RangeSets, Sarg} import java.util.Arrays.asList @@ -53,7 +53,17 @@ object SearchOperatorGen { def generateSearch( ctx: CodeGeneratorContext, target: GeneratedExpression, - sargLiteral: RexLiteral): GeneratedExpression = { + rexProgram: RexProgram, + operands: java.util.List[RexNode]): GeneratedExpression = { + val sargLiteral = + if (rexProgram != null && operands.get(1).isInstanceOf[RexLocalRef]) { + rexProgram.getExprList + .get(operands.get(1).asInstanceOf[RexLocalRef].getIndex) + .asInstanceOf[RexLiteral] + } else { + operands.get(1).asInstanceOf[RexLiteral] + } + val sarg: Sarg[Nothing] = sargLiteral.getValueAs(classOf[Sarg[Nothing]]) val targetType = target.resultType val sargType = FlinkTypeFactory.toLogicalType(sargLiteral.getType) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala index 5cdbbff533a66..94d262463c6ff 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala @@ -525,6 +525,14 @@ object FlinkRexUtil { RexUtil.expandSearch(rexBuilder, program, program.expandLocalRef(program.getCondition))) RelOptUtil.conjunctions(condition) } + + def expandLocalRef(operand: RexNode, localRefs: util.List[RexNode]): RexNode = { + var expanded = operand + while (expanded.isInstanceOf[RexLocalRef] && localRefs != null) { + expanded = localRefs.get(expanded.asInstanceOf[RexLocalRef].getIndex) + } + expanded + } } /** diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java index 6c236ccd04d5d..1ec9561b25083 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java @@ -41,7 +41,6 @@ import java.time.Instant; import java.time.LocalDateTime; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -94,6 +93,7 @@ Stream getTestSetSpecs() { testCases.addAll(jsonQuoteSpec()); testCases.addAll(jsonUnquoteSpecWithValidInput()); testCases.addAll(jsonUnquoteSpecWithInvalidInput()); + testCases.addAll(jsonLocalRefReuseSpec()); return testCases.stream(); } @@ -296,7 +296,7 @@ private static TestSetSpec jsonValueSpec() { } private static List isJsonSpec() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.IS_JSON) .onFieldsWithData(1) .andDataTypes(INT()) @@ -367,7 +367,7 @@ private static List isJsonSpec() { private static List jsonQuerySpec() { final String jsonValue = getJsonFromResource("/json/json-query.json"); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_QUERY) .onFieldsWithData((String) null) .andDataTypes(STRING()) @@ -599,7 +599,7 @@ private static List jsonStringSpec() { multisetData.put("M1", 1); multisetData.put("M2", 2); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_STRING) .onFieldsWithData(0) .testResult( @@ -616,7 +616,7 @@ private static List jsonStringSpec() { 1.23, LocalDateTime.parse("1990-06-02T13:37:42.001"), Instant.parse("1990-06-02T13:37:42.001Z"), - Arrays.asList("A1", "A2", "A3"), + List.of("A1", "A2", "A3"), Row.of("R1", Instant.parse("1990-06-02T13:37:42.001Z")), mapData, multisetData, @@ -717,7 +717,7 @@ private static List jsonStringSpec() { } private static List jsonSpec() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_OBJECT) .onFieldsWithData("{\"key\":\"value\"}", "{\"key\": {\"value\": 42}}") .andDataTypes(STRING(), STRING()) @@ -946,7 +946,7 @@ private static List jsonObjectSpec() { multisetData.put("M1", 1); multisetData.put("M2", 2); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_OBJECT) .onFieldsWithData(0) .testResult( @@ -977,7 +977,7 @@ private static List jsonObjectSpec() { 1.23, LocalDateTime.parse("1990-06-02T13:37:42.001"), Instant.parse("1990-06-02T13:37:42.001Z"), - Arrays.asList("A1", "A2", "A3"), + List.of("A1", "A2", "A3"), Row.of("R1", Instant.parse("1990-06-02T13:37:42.001Z")), mapData, multisetData, @@ -1105,7 +1105,7 @@ private static List jsonObjectSpec() { private static List jsonQuoteSpec() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_QUOTE) .onFieldsWithData(0) .testResult( @@ -1177,7 +1177,7 @@ private static List jsonQuoteSpec() { private static List jsonUnquoteSpecWithValidInput() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_UNQUOTE) .onFieldsWithData(0) .testResult( @@ -1319,7 +1319,7 @@ private static List jsonUnquoteSpecWithValidInput() { private static List jsonUnquoteSpecWithInvalidInput() { - return Arrays.asList( + return List.of( 
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_UNQUOTE) .onFieldsWithData(0) .testResult( @@ -1406,7 +1406,7 @@ private static List jsonArraySpec() { multisetData.put("M1", 1); multisetData.put("M2", 2); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_ARRAY) .onFieldsWithData(0) .testResult( @@ -1436,7 +1436,7 @@ private static List jsonArraySpec() { 1.23, LocalDateTime.parse("1990-06-02T13:37:42.001"), Instant.parse("1990-06-02T13:37:42.001Z"), - Arrays.asList("A1", "A2", "A3"), + List.of("A1", "A2", "A3"), Row.of("R1", Instant.parse("1990-06-02T13:37:42.001Z")), mapData, multisetData, @@ -1540,6 +1540,182 @@ private static List jsonArraySpec() { STRING().notNull())); } + /** Pins the local-ref / common-sub-expression handling for JSON construction calls. */ + private static List jsonLocalRefReuseSpec() { + return List.of( + // Shared JSON(f) inside two JSON_OBJECT projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_OBJECT, + "Shared JSON(f) sub-expression across JSON_OBJECT projections") + .onFieldsWithData("[1,2,3]") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonObject(JsonOnNull.NULL, "k1", json($("f0"))), + "JSON_OBJECT(KEY 'k1' VALUE JSON(f0))", + "{\"k1\":[1,2,3]}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonObject(JsonOnNull.NULL, "k2", json($("f0"))), + "JSON_OBJECT(KEY 'k2' VALUE JSON(f0))", + "{\"k2\":[1,2,3]}", + STRING().notNull(), + STRING().notNull())) + .testSqlResult( + "JSON_OBJECT(KEY 'k1' VALUE JSON(f0))," + + " JSON_OBJECT(KEY 'k2' VALUE JSON(f0))", + List.of("{\"k1\":[1,2,3]}", "{\"k2\":[1,2,3]}"), + List.of(STRING().notNull(), STRING().notNull())), + // Shared JSON_ARRAY(...) inside two JSON_OBJECT projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_OBJECT, + "Shared JSON_ARRAY sub-expression across JSON_OBJECT projections") + .onFieldsWithData(1, 2, 3) + .andDataTypes(INT(), INT(), INT()) + .testResult( + resultSpec( + jsonObject( + JsonOnNull.NULL, + "a", + jsonArray( + JsonOnNull.NULL, + $("f0"), + $("f1"), + $("f2"))), + "JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(f0, f1, f2))", + "{\"a\":[1,2,3]}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonObject( + JsonOnNull.NULL, + "b", + jsonArray( + JsonOnNull.NULL, + $("f0"), + $("f1"), + $("f2"))), + "JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))", + "{\"b\":[1,2,3]}", + STRING().notNull(), + STRING().notNull())) + .testSqlResult( + "JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(f0, f1, f2))," + + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))", + List.of("{\"a\":[1,2,3]}", "{\"b\":[1,2,3]}"), + List.of(STRING().notNull(), STRING().notNull())), + // Shared inner JSON_OBJECT inside two outer JSON_OBJECT projections. 
+ TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_OBJECT, + "Shared inner JSON_OBJECT across outer JSON_OBJECT projections") + .onFieldsWithData("V") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonObject( + JsonOnNull.NULL, + "outer1", + jsonObject(JsonOnNull.NULL, "inner", $("f0"))), + "JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", + "{\"outer1\":{\"inner\":\"V\"}}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonObject( + JsonOnNull.NULL, + "outer2", + jsonObject(JsonOnNull.NULL, "inner", $("f0"))), + "JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", + "{\"outer2\":{\"inner\":\"V\"}}", + STRING().notNull(), + STRING().notNull())) + .testSqlResult( + "JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))," + + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", + List.of( + "{\"outer1\":{\"inner\":\"V\"}}", + "{\"outer2\":{\"inner\":\"V\"}}"), + List.of(STRING().notNull(), STRING().notNull())), + // Shared JSON_OBJECT inside two JSON_ARRAY projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_ARRAY, + "Shared JSON_OBJECT inside JSON_ARRAY across projections") + .onFieldsWithData("V") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonArray( + JsonOnNull.NULL, + jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", + "[{\"k\":\"V\"}]", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonArray( + JsonOnNull.NULL, + jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", + "[{\"k\":\"V\"}]", + STRING().notNull(), + STRING().notNull())) + .testSqlResult( + "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))," + + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", + List.of("[{\"k\":\"V\"}]", "[{\"k\":\"V\"}]"), + List.of(STRING().notNull(), STRING().notNull())), + // Shared JSON(f) inside two JSON_ARRAY projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_ARRAY, + "Shared JSON(f) inside JSON_ARRAY across projections") + .onFieldsWithData("[1,2,3]") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonArray(JsonOnNull.NULL, json($("f0"))), + "JSON_ARRAY(JSON(f0))", + "[[1,2,3]]", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonArray(JsonOnNull.NULL, json($("f0"))), + "JSON_ARRAY(JSON(f0))", + "[[1,2,3]]", + STRING().notNull(), + STRING().notNull())) + .testSqlResult( + "JSON_ARRAY(JSON(f0)), JSON_ARRAY(JSON(f0))", + List.of("[[1,2,3]]", "[[1,2,3]]"), + List.of(STRING().notNull(), STRING().notNull())), + // Shared JSON_OBJECT inside two JSON_STRING projections. JSON_STRING re-serializes + // the operand; without dereferencing the local ref it would wrap the already + // serialized JSON string a second time. 
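For concreteness, the double-wrapping the comment above warns about looks like this at the string level, using the values from the spec that follows; quoteAsJsonString is a hypothetical helper written only for this sketch, not planner code.

object DoubleWrapExample {
  // How JSON_STRING serializes an operand it treats as a plain string.
  def quoteAsJsonString(s: String): String = "\"" + s.replace("\"", "\\\"") + "\""

  def main(args: Array[String]): Unit = {
    val sharedJsonObject = """{"k":"V"}"""       // already-serialized result of the shared JSON_OBJECT
    println(sharedJsonObject)                    // expected result:        {"k":"V"}
    println(quoteAsJsonString(sharedJsonObject)) // without the deref step: "{\"k\":\"V\"}"
  }
}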
+ TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_STRING, + "Shared JSON_OBJECT inside JSON_STRING across projections") + .onFieldsWithData("V") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonString(jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", + "{\"k\":\"V\"}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonString(jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", + "{\"k\":\"V\"}", + STRING().notNull(), + STRING().notNull())) + .testSqlResult( + "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))," + + " JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", + List.of("{\"k\":\"V\"}", "{\"k\":\"V\"}"), + List.of(STRING().notNull(), STRING().notNull()))); + } + // --------------------------------------------------------------------------------------------- /** diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java index c0246b6fe4103..c9d31af6e97e9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java @@ -69,6 +69,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import java.lang.invoke.MethodHandle; @@ -76,7 +78,6 @@ import java.nio.ByteBuffer; import java.time.DayOfWeek; import java.time.LocalDateTime; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -85,7 +86,9 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.apache.flink.table.api.Expressions.$; import static org.apache.flink.table.utils.UserDefinedFunctions.GENERATED_LOWER_UDF_CLASS; @@ -131,10 +134,10 @@ public void before() throws Exception { void testCreateCatalogFunctionInDefaultCatalog() { String ddl1 = "create function f1 as 'org.apache.flink.function.TestFunction'"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f1"); + assertThat(List.of(tEnv().listFunctions())).contains("f1"); tEnv().executeSql("DROP FUNCTION IF EXISTS default_catalog.default_database.f1"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f1"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f1"); } @Test @@ -143,10 +146,10 @@ void testCreateFunctionWithFullPath() { "create function default_catalog.default_database.f2 as" + " 'org.apache.flink.function.TestFunction'"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f2"); + assertThat(List.of(tEnv().listFunctions())).contains("f2"); tEnv().executeSql("DROP FUNCTION IF EXISTS default_catalog.default_database.f2"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f2"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f2"); } @Test @@ -155,10 +158,10 @@ void 
testCreateFunctionWithoutCatalogIdentifier() { "create function default_database.f3 as" + " 'org.apache.flink.function.TestFunction'"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f3"); + assertThat(List.of(tEnv().listFunctions())).contains("f3"); tEnv().executeSql("DROP FUNCTION IF EXISTS default_catalog.default_database.f3"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f3"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f3"); } @Test @@ -186,7 +189,7 @@ void testDynamicDatetimeFunctionsAreEqual() { + " CURRENT_DATE = CURRENT_DATE()") .execute(); List actualRows = CollectionUtil.iteratorToList(tableResult.collect()); - assertThat(actualRows).isEqualTo(Arrays.asList(Row.of(true, true, true, true, true))); + assertThat(actualRows).isEqualTo(List.of(Row.of(true, true, true, true, true))); } @Test @@ -232,13 +235,13 @@ void testCreateTemporaryCatalogFunction() { String ddl4 = "drop temporary function if exists default_catalog.default_database.f4"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f4"); + assertThat(List.of(tEnv().listFunctions())).contains("f4"); tEnv().executeSql(ddl2); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f4"); + assertThat(List.of(tEnv().listFunctions())).contains("f4"); tEnv().executeSql(ddl3); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f4"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f4"); tEnv().executeSql(ddl1); assertThatThrownBy(() -> tEnv().executeSql(ddl1)) @@ -276,24 +279,24 @@ void testCreateTemporarySystemFunctionByUsingJar() throws Exception { "CREATE TEMPORARY SYSTEM FUNCTION f10 AS '%s' USING JAR '%s'", udfClassName, jarPath); tEnv().executeSql(ddl); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f10"); + assertThat(List.of(tEnv().listFunctions())).contains("f10"); try (CloseableIterator itor = tEnv().executeSql("SHOW JARS").collect()) { assertThat(itor.hasNext()).isFalse(); } tEnv().executeSql("DROP TEMPORARY SYSTEM FUNCTION f10"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f10"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f10"); } @Test void testCreateTemporarySystemFunctionWithTableAPI() { ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath); - tEnv().createTemporarySystemFunction("f10", udfClassName, Arrays.asList(resourceUri)); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f10"); + tEnv().createTemporarySystemFunction("f10", udfClassName, List.of(resourceUri)); + assertThat(List.of(tEnv().listFunctions())).contains("f10"); tEnv().executeSql("DROP TEMPORARY SYSTEM FUNCTION f10"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f10"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f10"); } @Test @@ -303,7 +306,7 @@ void testUserDefinedTemporarySystemFunctionWithTableAPI() throws Exception { testUserDefinedFunctionByUsingJar( environment -> environment.createTemporarySystemFunction( - "lowerUdf", udfClassName, Arrays.asList(resourceUri)), + "lowerUdf", udfClassName, List.of(resourceUri)), dropFunctionSql); } @@ -314,20 +317,20 @@ void testCreateCatalogFunctionByUsingJar() { "CREATE FUNCTION default_database.f11 AS '%s' USING JAR '%s'", udfClassName, jarPath); tEnv().executeSql(ddl); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f11"); + assertThat(List.of(tEnv().listFunctions())).contains("f11"); tEnv().executeSql("DROP FUNCTION 
default_database.f11"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f11"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f11"); } @Test void testCreateCatalogFunctionWithTableAPI() { ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath); - tEnv().createFunction("f11", udfClassName, Arrays.asList(resourceUri)); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f11"); + tEnv().createFunction("f11", udfClassName, List.of(resourceUri)); + assertThat(List.of(tEnv().listFunctions())).contains("f11"); tEnv().executeSql("DROP FUNCTION default_database.f11"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f11"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f11"); } @Test @@ -336,8 +339,7 @@ void testUserDefinedCatalogFunctionWithTableAPI() throws Exception { String dropFunctionSql = "DROP FUNCTION default_database.lowerUdf"; testUserDefinedFunctionByUsingJar( environment -> - environment.createFunction( - "lowerUdf", udfClassName, Arrays.asList(resourceUri)), + environment.createFunction("lowerUdf", udfClassName, List.of(resourceUri)), dropFunctionSql); } @@ -348,20 +350,20 @@ void testCreateTemporaryCatalogFunctionByUsingJar() { "CREATE TEMPORARY FUNCTION default_database.f12 AS '%s' USING JAR '%s'", udfClassName, jarPath); tEnv().executeSql(ddl); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f12"); + assertThat(List.of(tEnv().listFunctions())).contains("f12"); tEnv().executeSql("DROP TEMPORARY FUNCTION default_database.f12"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f12"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f12"); } @Test void testCreateTemporaryCatalogFunctionWithTableAPI() { ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath); - tEnv().createTemporaryFunction("f12", udfClassName, Arrays.asList(resourceUri)); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f12"); + tEnv().createTemporaryFunction("f12", udfClassName, List.of(resourceUri)); + assertThat(List.of(tEnv().listFunctions())).contains("f12"); tEnv().executeSql("DROP TEMPORARY FUNCTION default_database.f12"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f12"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f12"); } @Test @@ -371,7 +373,7 @@ void testUserDefinedTemporaryCatalogFunctionWithTableAPI() throws Exception { testUserDefinedFunctionByUsingJar( environment -> environment.createTemporaryFunction( - "lowerUdf", udfClassName, Arrays.asList(resourceUri)), + "lowerUdf", udfClassName, List.of(resourceUri)), dropFunctionSql); } @@ -596,7 +598,7 @@ void testExpressionReducerByUsingJar() { TableResult tableResult = tEnv().executeSql("SELECT lowerUdf('HELLO')"); List actualRows = CollectionUtil.iteratorToList(tableResult.collect()); - assertThat(actualRows).isEqualTo(Arrays.asList(Row.of("hello"))); + assertThat(actualRows).isEqualTo(List.of(Row.of("hello"))); tEnv().executeSql("drop temporary function lowerUdf"); } @@ -611,7 +613,7 @@ public Integer eval(Integer a, Integer b) { private void testUserDefinedCatalogFunction(String createFunctionDDL) throws Exception { List sourceData = - Arrays.asList( + List.of( Row.of(1, "1000", 2), Row.of(2, "1", 3), Row.of(3, "2000", 4), @@ -644,7 +646,7 @@ private void testUserDefinedCatalogFunction(String createFunctionDDL) throws Exc private void testUserDefinedFunctionByUsingJar(FunctionCreator creator, String dropFunctionDDL) throws Exception { List sourceData = 
- Arrays.asList( + List.of( Row.of(1, "JARK"), Row.of(2, "RON"), Row.of(3, "LeoNard"), @@ -667,7 +669,7 @@ private void testUserDefinedFunctionByUsingJar(FunctionCreator creator, String d List result = TestCollectionTableFactory.RESULT(); List expected = - Arrays.asList( + List.of( Row.of(1, "jark"), Row.of(2, "ron"), Row.of(3, "leonard"), @@ -684,10 +686,10 @@ private void testUserDefinedFunctionByUsingJar(FunctionCreator creator, String d @Test void testPrimitiveScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of(1, 1L, "-"), Row.of(2, 2L, "--"), Row.of(3, 3L, "---")); + List.of(Row.of(1, 1L, "-"), Row.of(2, 2L, "--"), Row.of(3, 3L, "---")); final List sinkData = - Arrays.asList(Row.of(1, 3L, "-"), Row.of(2, 6L, "--"), Row.of(3, 9L, "---")); + List.of(Row.of(1, 3L, "-"), Row.of(2, 6L, "--"), Row.of(3, 9L, "---")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -738,7 +740,7 @@ void testNullScalarFunction() throws Exception { @Test void testRowScalarFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(1, Row.of(1, "1")), Row.of(2, Row.of(2, "2")), Row.of(3, Row.of(3, "3"))); @@ -761,14 +763,14 @@ void testRowScalarFunction() throws Exception { @Test void testComplexScalarFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(1, new byte[] {1, 2, 3}), Row.of(2, new byte[] {2, 3, 4}), Row.of(3, new byte[] {3, 4, 5}), Row.of(null, null)); final List sinkData = - Arrays.asList( + List.of( Row.of( 1, "1+2012-12-12 12:12:12.123456789", @@ -834,11 +836,10 @@ void testComplexScalarFunction() throws Exception { @Test void testCustomScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of(1), Row.of(2), Row.of(3), Row.of((Integer) null)); + List.of(Row.of(1), Row.of(2), Row.of(3), Row.of((Integer) null)); final List sinkData = - Arrays.asList( - Row.of(1, 1, 5), Row.of(2, 2, 5), Row.of(3, 3, 5), Row.of(null, null, 5)); + List.of(Row.of(1, 1, 5), Row.of(2, 2, 5), Row.of(3, 3, 5), Row.of(null, null, 5)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -862,7 +863,7 @@ void testCustomScalarFunction() throws Exception { @Test void testVarArgScalarFunction() { - final List sourceData = Arrays.asList(Row.of("Bob", 1), Row.of("Alice", 2)); + final List sourceData = List.of(Row.of("Bob", 1), Row.of("Alice", 2)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -890,7 +891,7 @@ void testVarArgScalarFunction() { final List actual = CollectionUtil.iteratorToList(result.collect()); final List expected = - Arrays.asList( + List.of( Row.of( "(INT...)", "(INT...)", @@ -909,7 +910,7 @@ void testVarArgScalarFunction() { @Test void testRawLiteralScalarFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(1, DayOfWeek.MONDAY), Row.of(2, DayOfWeek.FRIDAY), Row.of(null, null)); @@ -968,13 +969,390 @@ void testRawLiteralScalarFunction() throws Exception { assertThat(TestCollectionTableFactory.getResult()).containsExactlyInAnyOrder(sinkData); } + @ParameterizedTest(name = "{0}") + @MethodSource("inputForTestCalcLocalRefReuse") + void testCalcLocalRefReuse( + String sql, List expectedRows, int expectedDetCalls, int expectedNonDetCalls) { + final List sourceData = List.of(Row.of("Bob"), Row.of("Alice")); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + CountingUpperScalarFunction.COUNT.set(0); + 
NonDeterministicCountingScalarFunction.COUNT.set(0); + + tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class); + tEnv().createTemporarySystemFunction( + "Nondet", NonDeterministicCountingScalarFunction.class); + tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')"); + + final List actual = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect()); + + assertThat(actual).containsExactlyElementsOf(expectedRows); + assertThat(CountingUpperScalarFunction.COUNT.get()) + .as("Deterministic invocations") + .isEqualTo(expectedDetCalls); + assertThat(NonDeterministicCountingScalarFunction.COUNT.get()) + .as("Non-deterministic invocations") + .isEqualTo(expectedNonDetCalls); + } + + static Stream inputForTestCalcLocalRefReuse() { + return Stream.of( + Arguments.of( + "SELECT Det(s), Det(s), Det(s) FROM SourceTable", + List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")), + 2, // expected localref calls: rows × 1 (cached) + 0), + Arguments.of( + "SELECT Det(s), Det(s), UPPER(s) FROM SourceTable", + List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")), + 2, // rows × 1 (cached); built-in UPPER not counted + 0), + Arguments.of( + "SELECT Det(Det(s)), Det(Det(s)), Det(Det(s)) FROM SourceTable", + List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")), + 4, // rows × 2 layers + 0), + Arguments.of( + "SELECT Nondet(s), Nondet(s), Nondet(s) FROM SourceTable", + List.of( + Row.of("BOB_1", "BOB_2", "BOB_3"), + Row.of("ALICE_4", "ALICE_5", "ALICE_6")), + 0, + 6 // rows × 3 projections + ), + Arguments.of( + "SELECT Nondet(Det(s)), Nondet(Det(s)), Nondet(Det(s)) FROM SourceTable", + List.of( + Row.of("BOB_1", "BOB_2", "BOB_3"), + Row.of("ALICE_4", "ALICE_5", "ALICE_6")), + 2, // rows × 1 (inner cached) + 6 // rows × 3 projections + ), + Arguments.of( + "SELECT Det(Nondet(s)), Det(Nondet(s)), Det(Nondet(s)) FROM SourceTable", + List.of( + Row.of("BOB_1", "BOB_2", "BOB_3"), + Row.of("ALICE_4", "ALICE_5", "ALICE_6")), + 6, // rows × 3 (nondet input disables cache) + 6 // rows × 3 projections + ), + // shared Det in filter → cached once per row + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND Det(s) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 2, + 0), + // mixed UDF + built-in + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND UPPER(s) <> ''", + List.of(Row.of("Bob"), Row.of("Alice")), + 2, + 0), + // nested Det in filter; both layers cached + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(Det(s)) IS NOT NULL" + + " AND Det(Det(s)) <> '' AND Det(Det(s)) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 4, + 0), + // non-deterministic in filter — never cached + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Nondet(s) IS NOT NULL" + + " AND Nondet(s) <> '' AND Nondet(s) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 0, + 6), + // outer nondet, inner Det cached + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Nondet(Det(s)) IS NOT NULL" + + " AND Nondet(Det(s)) <> '' AND Nondet(Det(s)) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 2, + 6), + // Det with nondet input → cache bypassed + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(Nondet(s)) IS NOT NULL" + + " AND Det(Nondet(s)) <> '' AND Det(Nondet(s)) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 6, + 6), + // filter ↔ projection share via unified program + 
Arguments.of( + "SELECT Det(s) FROM SourceTable WHERE Det(s) = 'BOB'", + List.of(Row.of("BOB")), + 2, + 0), + Arguments.of( + "SELECT Det(s), Det(s) FROM SourceTable WHERE Det(s) = 'BOB'", + List.of(Row.of("BOB", "BOB")), + 2, + 0), + + // --------------------------------------------------------------------------- + // JSON construction scenarios. These verify that the localref / RexProgram CSE + // cache also fires when the shared sub-expression is wrapped inside (or itself + // is) a JSON_OBJECT / JSON_ARRAY / JSON_STRING call. + // --------------------------------------------------------------------------- + + // JSON_OBJECT × 2 sharing inner Det → cached once per row. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s))," + + " JSON_OBJECT(KEY 'b' VALUE Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"), + Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")), + 2, // rows × 1 (cached) + 0), + // JSON_ARRAY × 2 sharing inner Det → cached. + Arguments.of( + "SELECT JSON_ARRAY(Det(s)), JSON_ARRAY(Det(s)) FROM SourceTable", + List.of( + Row.of("[\"BOB\"]", "[\"BOB\"]"), + Row.of("[\"ALICE\"]", "[\"ALICE\"]")), + 2, + 0), + // JSON_STRING × 2 sharing inner Det → cached. + Arguments.of( + "SELECT JSON_STRING(Det(s)), JSON_STRING(Det(s)) FROM SourceTable", + List.of(Row.of("\"BOB\"", "\"BOB\""), Row.of("\"ALICE\"", "\"ALICE\"")), + 2, + 0), + // Mixed JSON_OBJECT + JSON_ARRAY sharing same Det. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_ARRAY(Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"k\":\"BOB\"}", "[\"BOB\"]"), + Row.of("{\"k\":\"ALICE\"}", "[\"ALICE\"]")), + 2, + 0), + // Mixed JSON_OBJECT + JSON_STRING sharing same Det. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_STRING(Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"k\":\"BOB\"}", "\"BOB\""), + Row.of("{\"k\":\"ALICE\"}", "\"ALICE\"")), + 2, + 0), + // JSON_OBJECT × 3 sharing same Det → cached across all 3 sites. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s))," + + " JSON_OBJECT(KEY 'b' VALUE Det(s))," + + " JSON_OBJECT(KEY 'c' VALUE Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}", "{\"c\":\"BOB\"}"), + Row.of( + "{\"a\":\"ALICE\"}", + "{\"b\":\"ALICE\"}", + "{\"c\":\"ALICE\"}")), + 2, + 0), + // Nested Det(Det(s)) inside two JSON_OBJECT projections → both layers cached. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Det(s)))," + + " JSON_OBJECT(KEY 'b' VALUE Det(Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"), + Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")), + 4, // rows × 2 layers + 0), + // Nondet inside two JSON_OBJECT projections → never cached. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(s))," + + " JSON_OBJECT(KEY 'b' VALUE Nondet(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"), + Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")), + 0, + 4 // rows × 2 projections + ), + // Outer Nondet, inner Det inside two JSON_OBJECT projections — Det cached. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(Det(s)))," + + " JSON_OBJECT(KEY 'b' VALUE Nondet(Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"), + Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")), + 2, // inner Det cached + 4), + // Outer Det, inner Nondet → outer cache disabled by nondet operand. 
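The expected counters in these cases all encode one rule: a deterministic sub-expression that the program shares via a local ref is evaluated once per row, while a non-deterministic call, or a deterministic call fed by one, is evaluated at every call site. The sketch below is a simplified, self-contained model of that rule, not the generated code; Expr and evaluateRow are illustrative names.

object LocalRefReuseModel {
  final case class Expr(
      key: String,                    // identity of the shared sub-expression (its local-ref slot)
      deterministic: Boolean,
      inputs: Seq[Expr] = Nil,
      body: () => String)

  // Evaluates one row: results of cacheable expressions are reused across call sites.
  def evaluateRow(callSites: Seq[Expr]): Seq[String] = {
    val cache = scala.collection.mutable.Map.empty[String, String]
    def cacheable(e: Expr): Boolean = e.deterministic && e.inputs.forall(cacheable)
    def eval(e: Expr): String =
      if (cacheable(e)) cache.getOrElseUpdate(e.key, e.body()) else e.body()
    callSites.map(e => eval(e))
  }

  def main(args: Array[String]): Unit = {
    var detCalls, nondetCalls = 0
    val det = Expr("Det(s)", deterministic = true, body = () => { detCalls += 1; "BOB" })
    val nondet = Expr("Nondet(s)", deterministic = false,
      body = () => { nondetCalls += 1; "BOB_" + nondetCalls })

    evaluateRow(Seq(det, det, det))          // reused: evaluated once
    evaluateRow(Seq(nondet, nondet, nondet)) // never cached: evaluated three times
    println(s"det=$detCalls nondet=$nondetCalls") // det=1 nondet=3
  }
}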
+ Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Nondet(s)))," + + " JSON_OBJECT(KEY 'b' VALUE Det(Nondet(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"), + Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")), + 4, // outer Det not cached (nondet operand) + 4), + // Filter ↔ JSON projection share Det via unified program. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s))" + + " FROM SourceTable WHERE Det(s) = 'BOB'", + List.of(Row.of("{\"k\":\"BOB\"}")), + 2, + 0), + // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two outer JSON_OBJECT + // projections — verifies CSE works when the cached node is itself a JSON + // construction call (and validates the JSON helpers' RexLocalRef deref path + // along the way). + Arguments.of( + "SELECT JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s)))," + + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s)))" + + " FROM SourceTable", + List.of( + Row.of( + "{\"outer1\":{\"k\":\"BOB\"}}", + "{\"outer2\":{\"k\":\"BOB\"}}"), + Row.of( + "{\"outer1\":{\"k\":\"ALICE\"}}", + "{\"outer2\":{\"k\":\"ALICE\"}}")), + 2, + 0), + // Shared inner JSON_ARRAY(Det(s)) inside two outer JSON_OBJECT projections. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(Det(s)))," + + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":[\"BOB\"]}", "{\"b\":[\"BOB\"]}"), + Row.of("{\"a\":[\"ALICE\"]}", "{\"b\":[\"ALICE\"]}")), + 2, + 0), + // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two JSON_ARRAY + // projections. + Arguments.of( + "SELECT JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s)))," + + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("[{\"k\":\"BOB\"}]", "[{\"k\":\"BOB\"}]"), + Row.of("[{\"k\":\"ALICE\"}]", "[{\"k\":\"ALICE\"}]")), + 2, + 0)); + } + + @Test + void testLocalRefReuseForMixedArgs() { + final List sourceData = List.of(Row.of("Bob"), Row.of("Alice")); + final int callSites = 2; + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + CountingUpperScalarFunction.COUNT.set(0); + NonDeterministicCountingScalarFunction.COUNT.set(0); + CountingConcat3ScalarFunction.COUNT.set(0); + + tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class); + tEnv().createTemporarySystemFunction( + "Nondet", NonDeterministicCountingScalarFunction.class); + tEnv().createTemporarySystemFunction("Concat3", CountingConcat3ScalarFunction.class); + tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')"); + + final List actual = + CollectionUtil.iteratorToList( + tEnv().executeSql( + "SELECT Concat3(Det(s), Nondet(s), Det(s))," + + " Concat3(Det(s), Nondet(s), Det(s))" + + " FROM SourceTable") + .collect()); + + assertThat(actual) + .containsExactly( + Row.of("BOB/BOB_1/BOB", "BOB/BOB_2/BOB"), + Row.of("ALICE/ALICE_3/ALICE", "ALICE/ALICE_4/ALICE")); + + assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size()); + assertThat(NonDeterministicCountingScalarFunction.COUNT.get()) + .isEqualTo(sourceData.size() * callSites); + // Concat3 is deterministic however has non-deterministic input + assertThat(CountingConcat3ScalarFunction.COUNT.get()) + .isEqualTo(sourceData.size() * callSites); + } + + @Test + void testCalcSharesSubExpressionBetweenFilterAndProjection() { + final List sourceData = + List.of(Row.of("Bob"), Row.of("Bob"), Row.of("Alice"), Row.of("Alice")); + + 
TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + CountingUpperScalarFunction.COUNT.set(0); + + tEnv().createTemporarySystemFunction("CountingUpper", CountingUpperScalarFunction.class); + tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')"); + + final List actual = + CollectionUtil.iteratorToList( + tEnv().executeSql( + "SELECT CountingUpper(s) FROM SourceTable" + + " WHERE CountingUpper(s) = 'BOB' AND CountingUpper(s) <> 'BOB2'") + .collect()); + + assertThat(actual).containsExactly(Row.of("BOB"), Row.of("BOB")); + + // Filter and projection share via the unified RexProgram, so the UDF runs once per + // source row regardless of how many call sites name it. + assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size()); + } + + /** + * Pins the CASE-WHEN guard interaction with the RexLocalRef cache. + * + *

Prior to scoped caching, RexProgramBuilder collapsed the division {@code a / b} into a + * single exprList entry; the codegen visitor cached the body and {@code + * CalcCodeGenerator.reuseLocalRefCode()} hoisted that body to the top of the generated method, + * evaluating {@code a / b} for every row regardless of the surrounding {@code CASE WHEN b > 0}. + * Rows with {@code b = 0} then threw {@code java.lang.ArithmeticException: Division undefined} + * — caught in the wild on TPC-DS query 34. With scoped caching the division body lives inside + * the THEN-branch's generated code and never executes when the guard is false. + */ + @Test + void testCalcCaseGuardShortCircuit() { + final List sourceData = + List.of(Row.of(10, 0), Row.of(10, 2), Row.of(20, 0), Row.of(30, 5), Row.of(40, 0)); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + tEnv().executeSql( + "CREATE TABLE SourceTable (a INT, b INT) WITH ('connector' = 'COLLECTION')"); + + final List actual = + CollectionUtil.iteratorToList( + tEnv().executeSql( + "SELECT a FROM SourceTable WHERE" + + " (CASE WHEN b > 0" + + " THEN CAST(a AS DECIMAL(7,2))" + + " / CAST(b AS DECIMAL(7,2))" + + " ELSE NULL END) > 1.2") + .collect()); + + // Row(10,2) → 10/2 = 5.0 (>1.2) + // Row(30,5) → 30/5 = 6.0 (>1.2) + // Rows with b=0 must NOT enter the THEN-branch (the division would fail). + assertThat(actual).containsExactly(Row.of(10), Row.of(30)); + } + @Test void testStructuredScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); + List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); final List sinkData = - Arrays.asList( + List.of( Row.of("Bob 42", "Tyler"), Row.of("Alice 12", "Tyler"), Row.of("<>", "Tyler")); @@ -1020,11 +1398,10 @@ void testInvalidCustomScalarFunction() { @Test void testRowTableFunction() throws Exception { final List sourceData = - Arrays.asList( - Row.of("1,2,3"), Row.of("2,3,4"), Row.of("3,4,5"), Row.of((String) null)); + List.of(Row.of("1,2,3"), Row.of("2,3,4"), Row.of("3,4,5"), Row.of((String) null)); final List sinkData = - Arrays.asList( + List.of( Row.of("1,2,3", new String[] {"1", "2", "3"}), Row.of("2,3,4", new String[] {"2", "3", "4"}), Row.of("3,4,5", new String[] {"3", "4", "5"})); @@ -1048,10 +1425,9 @@ void testRowTableFunction() throws Exception { @Test void testStructuredTableFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); + List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); - final List sinkData = - Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); + final List sinkData = List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1157,10 +1533,10 @@ void testNamedArgumentsTableFunctionWithOptionalArguments() throws Exception { @Test void testNamedArgumentsScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); + List.of(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); final List sinkData = - Arrays.asList(Row.of(1, 2, "1: 2"), Row.of(3, 4, "3: 4"), Row.of(5, 6, "5: 6")); + List.of(Row.of(1, 2, "1: 2"), Row.of(3, 4, "3: 4"), Row.of(5, 6, "5: 6")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1182,7 +1558,7 @@ void 
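The guard interaction described in the javadoc above can be condensed into the following sketch; hoisted and scoped are hypothetical helpers, and java.math.BigDecimal stands in for the DECIMAL(7,2) arithmetic. Hoisting the shared division above the CASE guard makes it run even when b = 0, while keeping it scoped to the THEN branch does not.

import java.math.{BigDecimal, MathContext}

object CaseGuardSketch {
  // Pre-fix shape: the cached sub-expression is evaluated before the guard is checked.
  def hoisted(a: Int, b: Int): Option[BigDecimal] = {
    val division = new BigDecimal(a).divide(new BigDecimal(b), MathContext.DECIMAL64)
    if (b > 0) Some(division) else None
  }

  // Post-fix shape: the division only exists inside the THEN branch.
  def scoped(a: Int, b: Int): Option[BigDecimal] =
    if (b > 0) Some(new BigDecimal(a).divide(new BigDecimal(b), MathContext.DECIMAL64)) else None

  def main(args: Array[String]): Unit = {
    println(scoped(10, 2))  // Some(5)
    println(scoped(10, 0))  // None -- the division never runs
    println(hoisted(10, 0)) // throws java.lang.ArithmeticException
  }
}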
testNamedArgumentsScalarFunction() throws Exception { @Test void testNamedParametersScalarFunctionWithOverloadedMethod() throws Exception { final List sourceData = - Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); + List.of(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1206,8 +1582,7 @@ void testNamedParametersScalarFunctionWithOverloadedMethod() throws Exception { @Test void testNamedArgumentsScalarFunctionWithOptionalArguments() throws Exception { - final List sinkData = - Arrays.asList(Row.of("s1: null", "null: s2", "s1: s2", "null: null")); + final List sinkData = List.of(Row.of("s1: null", "null: s2", "s1: s2", "null: null")); TestCollectionTableFactory.reset(); tEnv().executeSql( @@ -1230,14 +1605,13 @@ void testNamedArgumentsScalarFunctionWithOptionalArguments() throws Exception { @Test void testNamedArgumentAggregateFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "a", "b", 1, 2), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "c", "d", 33, 44), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "e", "f", 5, 6), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "gg", "hh", 7, 88)); - final List sinkData = - Arrays.asList(Row.of("a: b", "b: a"), Row.of("gg: hh", "hh: gg")); + final List sinkData = List.of(Row.of("a: b", "b: a"), Row.of("gg: hh", "hh: gg")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1265,14 +1639,14 @@ void testNamedArgumentAggregateFunction() throws Exception { @Test void testNamedArgumentAggregateFunctionWithOptionalArguments() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "a", "b", 1, 2), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "c", "d", 33, 44), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "e", "f", 5, 6), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "gg", "hh", 7, 88)); final List sinkData = - Arrays.asList(Row.of("a: null", "null: b"), Row.of("gg: null", "null: hh")); + List.of(Row.of("a: null", "null: b"), Row.of("gg: null", "null: hh")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1346,7 +1720,7 @@ void testInvalidUseOfTableFunction() { @Test void testAggregateFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "Bob"), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "Alice"), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), null), @@ -1355,7 +1729,7 @@ void testAggregateFunction() throws Exception { Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "Alice")); final List sinkData = - Arrays.asList( + List.of( Row.of( "Jonathan", "Alice=(Alice, 5), Bob=(Bob, 3), Jonathan=(Jonathan, 8)"), @@ -1409,10 +1783,10 @@ void testLookupTableFunctionWithoutHintLevel1() private void testLookupTableFunctionBase(String lookupTableFunctionClassName) throws ExecutionException, InterruptedException { - final List sourceData = Arrays.asList(Row.of("Bob"), Row.of("Alice")); + final List sourceData = List.of(Row.of("Bob"), Row.of("Alice")); final List sinkData = - Arrays.asList( + List.of( Row.of("Bob", new byte[0]), Row.of("Bob", new byte[] {66, 111, 98}), Row.of("Alice", new byte[0]), @@ -1459,7 +1833,7 @@ private void testLookupTableFunctionBase(String 
lookupTableFunctionClassName) @Test void testSpecializedFunction() { final List sourceData = - Arrays.asList( + List.of( Row.of("Bob", 1, new BigDecimal("123.45")), Row.of("Alice", 2, new BigDecimal("123.456"))); @@ -1489,7 +1863,7 @@ void testSpecializedFunction() { final List actual = CollectionUtil.iteratorToList(result.collect()); final List expected = - Arrays.asList( + List.of( Row.of("CHAR(7) NOT NULL", "STRING", "INT", "DECIMAL(6, 3)"), Row.of("CHAR(7) NOT NULL", "STRING", "INT", "DECIMAL(6, 3)")); assertThat(actual).isEqualTo(expected); @@ -1498,7 +1872,7 @@ void testSpecializedFunction() { @Test void testSpecializedFunctionWithExpressionEvaluation() { final List sourceData = - Arrays.asList( + List.of( Row.of("Bob", new Integer[] {1, 2, 3}, new BigDecimal("123.000")), Row.of("Bob", new Integer[] {4, 5, 6}, new BigDecimal("123.456")), Row.of("Alice", new Integer[] {1, 2, 3}, null), @@ -1530,7 +1904,7 @@ void testSpecializedFunctionWithExpressionEvaluation() { final List actual = CollectionUtil.iteratorToList(result.collect()); final List expected = - Arrays.asList( + List.of( Row.of("Bob", null, null), Row.of( "Bob", @@ -1543,7 +1917,7 @@ void testSpecializedFunctionWithExpressionEvaluation() { @Test void testTimestampNotNull() { - List sourceData = Arrays.asList(Row.of(1), Row.of(2)); + List sourceData = List.of(Row.of(1), Row.of(2)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1557,7 +1931,7 @@ void testTimestampNotNull() { @Test void testIsNullType() { - List sourceData = Arrays.asList(Row.of(1), Row.of((Object) null)); + List sourceData = List.of(Row.of(1), Row.of((Object) null)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1571,7 +1945,7 @@ void testIsNullType() { @Test void testWithBoolNotNullTypeHint() { - List sourceData = Arrays.asList(Row.of(1, 2), Row.of(2, 3)); + List sourceData = List.of(Row.of(1, 2), Row.of(2, 3)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1605,7 +1979,7 @@ void testUsingAddJar() throws Exception { @Test void testUdfWithMultiLocalVariables() { - List sourceData = Arrays.asList(Row.of(1L, 2L), Row.of(2L, 3L)); + List sourceData = List.of(Row.of(1L, 2L), Row.of(2L, 3L)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1620,7 +1994,7 @@ void testUdfWithMultiLocalVariables() { CollectionUtil.iteratorToList( tEnv().executeSql("SELECT MultiLocalVariables(x, y) FROM SourceTable") .collect()); - assertThat(actualRows).isEqualTo(Arrays.asList(Row.of(2L), Row.of(6L))); + assertThat(actualRows).isEqualTo(List.of(Row.of(2L), Row.of(6L))); } // -------------------------------------------------------------------------------------------- @@ -1757,6 +2131,41 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) { } } + /** Deterministic function with a counter. */ + public static class CountingUpperScalarFunction extends ScalarFunction { + public static final AtomicInteger COUNT = new AtomicInteger(); + + public String eval(String s) { + COUNT.incrementAndGet(); + return s == null ? null : s.toUpperCase(); + } + } + + /** Deterministic function with a counter and 3 args. 
*/ + public static class CountingConcat3ScalarFunction extends ScalarFunction { + public static final AtomicInteger COUNT = new AtomicInteger(); + + public String eval(String a, String b, String c) { + COUNT.incrementAndGet(); + return a + "/" + b + "/" + c; + } + } + + /** Non-deterministic function with a counter. */ + public static class NonDeterministicCountingScalarFunction extends ScalarFunction { + public static final AtomicInteger COUNT = new AtomicInteger(); + + public String eval(String s) { + final int count = COUNT.incrementAndGet(); + return s == null ? null : s.toUpperCase() + "_" + count; + } + + @Override + public boolean isDeterministic() { + return false; + } + } + /** Function that has a custom type inference that is broader than the actual implementation. */ public static class CustomScalarFunction extends ScalarFunction { public Integer eval(Integer... args) {