From 17175fb057f937a1bce5532c18bb51453b65c5d9 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 30 Apr 2026 08:24:21 +0200 Subject: [PATCH 01/12] [FLINK-39268][table] Expand and reuse local refs in `CalcCodeGenerator` --- .../nodes/exec/common/CommonExecCalc.java | 1 + .../table/planner/utils/ShortcutUtils.java | 72 +++ .../planner/codegen/CalcCodeGenerator.scala | 62 +- .../codegen/CodeGeneratorContext.scala | 51 +- .../planner/codegen/ExprCodeGenerator.scala | 145 ++++- .../planner/codegen/ExpressionReducer.scala | 9 +- .../planner/codegen/JsonGenerateUtils.scala | 86 ++- .../codegen/LongHashJoinGenerator.scala | 2 +- .../calls/BridgingFunctionGenUtil.scala | 7 +- .../calls/BridgingSqlFunctionCallGen.scala | 8 +- .../codegen/calls/JsonArrayCallGen.scala | 11 +- .../codegen/calls/JsonObjectCallGen.scala | 10 +- .../codegen/calls/JsonStringCallGen.scala | 10 +- .../functions/JsonFunctionsITCase.java | 186 +++++- .../runtime/stream/sql/FunctionITCase.java | 561 +++++++++++++++--- 15 files changed, 1036 insertions(+), 185 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java index cf389655031cd..97e3c43e2ac19 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java @@ -100,6 +100,7 @@ protected Transformation translateToPlanInternal( CalcCodeGenerator.generateCalcOperator( ctx, inputTransform, + (RowType) inputEdge.getOutputType(), (RowType) getOutputType(), JavaScalaConversionUtil.toScala(projection), JavaScalaConversionUtil.toScala(Optional.ofNullable(this.condition)), diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java index 415b78efeac09..2016f8a0c10c7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java @@ -24,6 +24,7 @@ import org.apache.flink.table.delegation.Planner; import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.planner.calcite.FlinkContext; @@ -40,13 +41,20 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.tools.RelBuilder; import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + /** * Utilities for quick access of commonly used instances (like {@link FlinkTypeFactory}) without * long chains of getters or casting like {@code (FlinkTypeFactory) @@ -169,6 +177,70 @@ public static boolean isFunctionKind(SqlOperator operator, FunctionKind kind) { return functionDefinition != null && functionDefinition.getKind() == kind; } + public static boolean isOneOfFunctionDefinitions( + RexNode rexNode, FunctionDefinition... 
expectedDefinitions) { + if (!(rexNode instanceof RexCall)) { + return false; + } + final RexCall call = (RexCall) rexNode; + final FunctionDefinition unwrapped = unwrapFunctionDefinition(call); + final String operatorName = call.getOperator().getName(); + for (FunctionDefinition expected : expectedDefinitions) { + if (unwrapped != null && unwrapped == expected) { + return true; + } + if (expected instanceof BuiltInFunctionDefinition + && ((BuiltInFunctionDefinition) expected) + .getName() + .equalsIgnoreCase(operatorName)) { + return true; + } + } + return false; + } + + public static RexNode expandLocalRef(RexNode operand, @Nullable List exprs) { + while (operand instanceof RexLocalRef && exprs != null) { + operand = exprs.get(((RexLocalRef) operand).getIndex()); + } + return operand; + } + + public static boolean isDeterministicThroughProgram( + RexNode node, @Nullable List exprs) { + if (exprs == null) { + return RexUtil.isDeterministic(node); + } + return isDeterministicThroughProgram(node, exprs, new HashSet<>()); + } + + private static boolean isDeterministicThroughProgram( + RexNode node, List exprs, Set visited) { + if (node instanceof RexCall) { + final RexCall call = (RexCall) node; + if (!call.getOperator().isDeterministic()) { + return false; + } + for (RexNode operand : call.getOperands()) { + if (!isDeterministicThroughProgram(operand, exprs, visited)) { + return false; + } + } + return true; + } + if (node instanceof RexLocalRef) { + final int idx = ((RexLocalRef) node).getIndex(); + // already on the stack: skip rather than recurse forever + return !visited.add(idx) + || isDeterministicThroughProgram(exprs.get(idx), exprs, visited); + } + if (node instanceof RexFieldAccess) { + return isDeterministicThroughProgram( + ((RexFieldAccess) node).getReferenceExpr(), exprs, visited); + } + return true; + } + public static @Nullable BridgingSqlFunction unwrapBridgingSqlFunction(RexCall call) { final SqlOperator operator = call.getOperator(); if (operator 
instanceof BridgingSqlFunction) { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala index 8072a9ca42711..464e449644dff 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala @@ -23,6 +23,7 @@ import org.apache.flink.configuration.ReadableConfig import org.apache.flink.table.api.{TableException, ValidationException} import org.apache.flink.table.data.{BoxedWrapperRowData, RowData} import org.apache.flink.table.functions.FunctionKind +import org.apache.flink.table.planner.calcite.{FlinkRexBuilder, FlinkTypeFactory, FlinkTypeSystem} import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction import org.apache.flink.table.runtime.generated.GeneratedFunction import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory @@ -31,19 +32,19 @@ import org.apache.flink.table.types.logical.RowType import org.apache.calcite.rex._ +import scala.collection.JavaConverters._ + object CalcCodeGenerator { def generateCalcOperator( ctx: CodeGeneratorContext, inputTransform: Transformation[RowData], + inputType: RowType, outputType: RowType, projection: Seq[RexNode], condition: Option[RexNode], retainHeader: Boolean = false, opName: String): CodeGenOperatorFactory[RowData] = { - val inputType = inputTransform.getOutputType - .asInstanceOf[InternalTypeInfo[RowData]] - .toRowType // filter out time attributes val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM val processCode = generateProcessCode( @@ -53,8 +54,12 @@ object CalcCodeGenerator { classOf[BoxedWrapperRowData], projection, condition, + inputTerm, + CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM, eagerInputUnboxingCode = true, - retainHeader = 
retainHeader) + retainHeader = retainHeader, + outputDirectly = false + ) val genOperator = OperatorCodeGenerator.generateOneInputStreamOperator[RowData, RowData]( @@ -87,6 +92,7 @@ object CalcCodeGenerator { outRowClass, calcProjection, calcCondition, + inputTerm, collectorTerm = collectorTerm, eagerInputUnboxingCode = false, outputDirectly = true @@ -121,7 +127,9 @@ object CalcCodeGenerator { projection.foreach(_.accept(ScalarFunctionsValidator)) condition.foreach(_.accept(ScalarFunctionsValidator)) - val exprGenerator = new ExprCodeGenerator(ctx, false) + val rexProgram = buildRexProgram(ctx.classLoader, inputType, projection, condition) + + val exprGenerator = new ExprCodeGenerator(ctx, false, rexProgram) .bindInput(inputType, inputTerm = inputTerm) val onlyFilter = projection.lengthCompare(inputType.getFieldCount) == 0 && @@ -137,6 +145,8 @@ object CalcCodeGenerator { } def produceProjectionCode: String = { + val projection = rexProgram.getProjectList.asScala + val projectionExprs = projection.map(exprGenerator.generateExpression) val projectionExpression = exprGenerator.generateResultExpression(projectionExprs, outRowType, outRowClass) @@ -162,16 +172,20 @@ object CalcCodeGenerator { "It should be removed by CalcRemoveRule.") } else if (condition.isEmpty) { // only projection val projectionCode = produceProjectionCode + val localRefCode = ctx.reuseLocalRefCode() s""" |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""} + |$localRefCode |$projectionCode |""".stripMargin } else { - val filterCondition = exprGenerator.generateExpression(condition.get) + val filterCondition = exprGenerator.generateExpression(rexProgram.getCondition) // only filter if (onlyFilter) { + val localRefCode = ctx.reuseLocalRefCode() s""" |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""} + |$localRefCode |${filterCondition.code} |if (${filterCondition.resultTerm}) { | ${produceOutputCode(inputTerm)} @@ -181,19 +195,35 @@ object CalcCodeGenerator { val 
filterInputCode = ctx.reuseInputUnboxingCode() val filterInputSet = Set(ctx.reusableInputUnboxingExprs.keySet.toSeq: _*) + val filterLocalRefSet: Set[Int] = ctx.reusableLocalRefExprs.keySet.toSet + // if any filter conditions, projection code will enter an new scope val projectionCode = produceProjectionCode val projectionInputCode = ctx.reusableInputUnboxingExprs - .filter(entry => !filterInputSet.contains(entry._1)) + .filter { case (k, _) => !filterInputSet.contains(k) } + .values + .map(_.code) + .mkString("\n") + + val filterLocalRefCode = ctx.reusableLocalRefExprs + .filter { case (k, _) => filterLocalRefSet.contains(k) } .values .map(_.code) .mkString("\n") + val projectionLocalRefCode = ctx.reusableLocalRefExprs + .filter { case (k, _) => !filterLocalRefSet.contains(k) } + .values + .map(_.code) + .mkString("\n") + s""" |${if (eagerInputUnboxingCode) filterInputCode else ""} + |$filterLocalRefCode |${filterCondition.code} |if (${filterCondition.resultTerm}) { - | ${if (eagerInputUnboxingCode) projectionInputCode else ""} + | ${if (eagerInputUnboxingCode) projectionInputCode else ""} + | $projectionLocalRefCode | $projectionCode |} |""".stripMargin @@ -201,6 +231,22 @@ object CalcCodeGenerator { } } + private def buildRexProgram( + classLoader: ClassLoader, + inputType: RowType, + projection: Seq[RexNode], + condition: Option[RexNode]): RexProgram = { + val typeFactory = new FlinkTypeFactory(classLoader, FlinkTypeSystem.INSTANCE) + val rexBuilder = new FlinkRexBuilder(typeFactory) + val relInputType = typeFactory.createFieldTypeFromLogicalType(inputType) + val builder = new RexProgramBuilder(relInputType, rexBuilder) + projection.foreach(p => builder.addProject(p, null)) + if (condition.isDefined) { + builder.addCondition(condition.get) + } + builder.getProgram + } + private object ScalarFunctionsValidator extends RexVisitorImpl[Unit](true) { override def visitCall(call: RexCall): Unit = { super.visitCall(call) diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index 02706cc130909..6c2a5b20ba74e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -17,7 +17,6 @@ */ package org.apache.flink.table.planner.codegen -import org.apache.flink.api.common.functions.Function import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.configuration.ReadableConfig import org.apache.flink.table.data.GenericRowData @@ -116,6 +115,24 @@ class CodeGeneratorContext( val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] = mutable.Map[(String, Int), GeneratedExpression]() + // map of expressions for shared RexProgram exprList entries that will be added only once + // exprList index -> expr + val reusableLocalRefExprs: mutable.LinkedHashMap[Int, GeneratedExpression] = + mutable.LinkedHashMap[Int, GeneratedExpression]() + + // Stack of RexLocalRef cache scopes. The bottom scope IS reusableLocalRefExprs and is read + // by CalcCodeGenerator.reuseLocalRefCode() — its bodies are hoisted to the top of the + // generated method and so must be safe to evaluate unconditionally. + // + // ExprCodeGenerator pushes an inner scope before visiting a guarded operand (CASE WHEN's + // THEN/ELSE branch, AND/OR's right-hand side, ...) and pops it after. Any RexLocalRef body + // cached during that visit lives only in the inner scope; ExprCodeGenerator folds those + // bodies into the operand's generated code so they execute only when the guard fires. 
+ // Without this scoping, an arithmetic expression like (a / b) inside CASE WHEN b > 0 would + // be hoisted above the if-block and divide by zero on rows where b == 0. + private val localRefScopes: mutable.ArrayBuffer[mutable.LinkedHashMap[Int, GeneratedExpression]] = + mutable.ArrayBuffer(reusableLocalRefExprs) + // set of constructor statements that will be added only once // we use a LinkedHashSet to keep the insertion order private val reusableConstructorStatements: mutable.LinkedHashSet[(String, String)] = @@ -173,6 +190,19 @@ class CodeGeneratorContext( def getReusableInputUnboxingExprs(inputTerm: String, index: Int): Option[GeneratedExpression] = reusableInputUnboxingExprs.get((inputTerm, index)) + def getReusableLocalRefExpr(index: Int): Option[GeneratedExpression] = { + // Search innermost-out: a body cached in an inner (guarded) scope wins over outer + // entries. In practice the cache is monotone — an entry never appears in two scopes + // simultaneously. + var i = localRefScopes.size - 1 + while (i >= 0) { + val maybe = localRefScopes(i).get(index) + if (maybe.isDefined) return maybe + i -= 1 + } + None + } + /** Prioritize using the nameCounter of the ancestor. */ def getNameCounter: AtomicLong = if (parentCtx == null) nameCounter else parentCtx.getNameCounter @@ -375,6 +405,10 @@ class CodeGeneratorContext( reusableInputUnboxingExprs.values.map(_.code).mkString("\n") } + def reuseLocalRefCode(): String = { + reusableLocalRefExprs.values.map(_.code).mkString("\n") + } + /** Returns code block of unboxing input variables which belongs to the given inputTerm. */ def reuseInputUnboxingCode(inputTerm: String): String = { val exprs = reusableInputUnboxingExprs.filter { @@ -458,6 +492,10 @@ class CodeGeneratorContext( index: Int, expr: GeneratedExpression): Unit = reusableInputUnboxingExprs((inputTerm, index)) = expr + /** Adds a reusable RexLocalRef expression keyed by its index in the program's exprList. 
*/ + def addReusableLocalRefExpr(index: Int, expr: GeneratedExpression): Unit = + localRefScopes.last(index) = expr + /** Adds a reusable output record statement to member area. */ def addReusableOutputRecord( t: LogicalType, @@ -1075,4 +1113,15 @@ class CodeGeneratorContext( fieldTerm } + + def pushLocalRefScope(): Unit = { + localRefScopes.append(mutable.LinkedHashMap.empty) + } + + def popLocalRefScope(): scala.collection.Map[Int, GeneratedExpression] = { + require( + localRefScopes.size > 1, + "Cannot pop the bottom RexLocalRef cache scope (reusableLocalRefExprs).") + localRefScopes.remove(localRefScopes.size - 1) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index 7154aa09f0a6e..97454165db992 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._ import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction} import org.apache.flink.table.planner.plan.utils.RexLiteralUtil +import org.apache.flink.table.planner.utils.ShortcutUtils import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable import org.apache.flink.table.runtime.typeutils.TypeCheckUtils @@ -55,9 +56,15 @@ import scala.collection.JavaConversions._ * This code generator is mainly responsible for generating codes for a given calcite [[RexNode]]. * It can also generate type conversion codes for the result converter. 
*/ -class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) +class ExprCodeGenerator( + ctx: CodeGeneratorContext, + nullableInput: Boolean, + val rexProgram: RexProgram) extends RexVisitor[GeneratedExpression] { + def this(ctx: CodeGeneratorContext, nullableInput: Boolean) = + this(ctx, nullableInput, null) + /** term of the [[ProcessFunction]]'s context, can be changed when needed */ var contextTerm = "ctx" @@ -344,7 +351,6 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = { - // for specific custom code generation if (input1Type == null) { return GeneratedExpression( inputRef.getName, @@ -416,8 +422,53 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type) } - override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = - throw new CodeGenException("RexLocalRef are not supported yet.") + override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = { + // addReusableLocalVariable + // for specific custom code generation + if (input1Type == null) { + return GeneratedExpression( + localRef.getName, + localRef.getName + "IsNull", + NO_CODE, + FlinkTypeFactory.toLogicalType(localRef.getType)) + } + // for the general cases with a previous call to bindInput() + val input1Arity = input1Type match { + case r: RowType => r.getFieldCount + case _ => 1 + } + if (localRef.getIndex >= input1Arity) { + if (rexProgram == null) { + throw new CodeGenException(s"RexLocalRef(${localRef.getIndex}) requires a RexProgram.") + } + val idx = localRef.getIndex + val target = rexProgram.getExprList.get(idx) + if (!isDeterministicThroughProgram(target)) { + return target.accept(this) + } + val full = ctx.getReusableLocalRefExpr(idx) match { + case Some(cached) => cached + case None => + val expr = target.accept(this) + ctx.addReusableLocalRefExpr(idx, expr) + 
expr + } + return GeneratedExpression( + full.resultTerm, + full.nullTerm, + NO_CODE, + full.resultType, + full.literalValue) + } + + generateInputAccess( + ctx, + input1Type, + input1Term, + localRef.getIndex, + nullableInput, + deepCopy = true) + } def visitRexFieldVariable(variable: RexFieldVariable): GeneratedExpression = { val internalType = FlinkTypeFactory.toLogicalType(variable.dataType) @@ -462,7 +513,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) val resultType = FlinkTypeFactory.toLogicalType(call.getType) // throw exception if json function is called outside JSON_OBJECT or JSON_ARRAY function - if (isJsonFunctionOperand(call)) { + if (isJsonFunctionOperand(call, if (rexProgram == null) null else rexProgram.getExprList)) { throw new ValidationException( "The JSON() function is currently only supported inside JSON_ARRAY() or as the VALUE param" + " of JSON_OBJECT(). Example: JSON_OBJECT('a', JSON('{\"key\": \"value\"}')) or " + @@ -470,13 +521,19 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } if (call.getKind == SqlKind.SEARCH) { - return generateSearch( - ctx, - generateExpression(call.getOperands.get(0)), - call.getOperands.get(1).asInstanceOf[RexLiteral]) + val sargLiteral = + if (rexProgram != null && call.getOperands.get(1).isInstanceOf[RexLocalRef]) { + rexProgram.getExprList + .get(call.getOperands.get(1).asInstanceOf[RexLocalRef].getIndex) + .asInstanceOf[RexLiteral] + } else { + call.getOperands.get(1).asInstanceOf[RexLiteral] + } + return generateSearch(ctx, generateExpression(call.getOperands.get(0)), sargLiteral) } // convert operands and help giving untyped NULL literals a type + val condIdxs = conditionalOperandIndices(call) val operands = call.getOperands.zipWithIndex.map { // this helps e.g. 
for AS(null) @@ -487,15 +544,59 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) generateNullLiteral(resultType) // We only support the JSON function inside of JSON_OBJECT or JSON_ARRAY - case (operand: RexNode, i) if isSupportedJsonOperand(operand, call, i) => + case (operand: RexNode, i) + if isSupportedJsonOperand( + operand, + call, + i, + if (rexProgram == null) null else rexProgram.getExprList) => generateJsonCall(operand) + case (o @ _, i) if condIdxs.contains(i) => visitOperandInScopedCache(o) + case (o @ _, _) => o.accept(this) } generateCallExpression(ctx, call, operands, resultType) } + /** + * Indices of `call`'s operands that are NOT unconditionally evaluated at runtime. Used to scope + * the RexLocalRef cache so that bodies cached while visiting these operands are not hoisted out + * of the surrounding short-circuit / if-block. + * + * - `CASE(when_1, then_1, when_2, then_2, ..., else)`: only `when_1` is unconditional. + * - `AND(a_0, a_1, ..., a_n)` / `OR(...)`: only `a_0` is unconditional; subsequent operands are + * short-circuited by the operator semantics and the codegen. 
+ */ + private def conditionalOperandIndices(call: RexCall): Set[Int] = call.getKind match { + case SqlKind.CASE | SqlKind.AND | SqlKind.OR | SqlKind.COALESCE => + (1 until call.getOperands.size).toSet + case _ => Set.empty + } + + private def visitOperandInScopedCache(operand: RexNode): GeneratedExpression = { + ctx.pushLocalRefScope() + val (operandExpr, scopedBodies) = + try { + val expr = operand.accept(this) + val popped = ctx.popLocalRefScope() + (expr, popped.values.map(_.code).mkString("\n")) + } catch { + case t: Throwable => + ctx.popLocalRefScope() + throw t + } + if (scopedBodies.isEmpty) operandExpr + else + GeneratedExpression( + operandExpr.resultTerm, + operandExpr.nullTerm, + scopedBodies + "\n" + operandExpr.code, + operandExpr.resultType, + operandExpr.literalValue) + } + override def visitOver(over: RexOver): GeneratedExpression = throw new CodeGenException("Aggregate functions over windows are not supported yet.") @@ -786,9 +887,11 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) case JSON_QUERY => new JsonQueryCallGen().generate(ctx, operands, resultType) - case JSON_OBJECT => new JsonObjectCallGen(call).generate(ctx, operands, resultType) + case JSON_OBJECT => + new JsonObjectCallGen(call, rexProgram).generate(ctx, operands, resultType) - case JSON_ARRAY => new JsonArrayCallGen(call).generate(ctx, operands, resultType) + case JSON_ARRAY => + new JsonArrayCallGen(call, rexProgram).generate(ctx, operands, resultType) case _: SqlThrowExceptionFunction => val nullValue = generateNullLiteral(resultType) @@ -827,7 +930,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) generateGreatestLeast(ctx, resultType, operands, greatest = false) case BuiltInFunctionDefinitions.JSON_STRING => - new JsonStringCallGen(call).generate(ctx, operands, resultType) + new JsonStringCallGen(call, rexProgram).generate(ctx, operands, resultType) case BuiltInFunctionDefinitions.INTERNAL_HASHCODE => new 
HashCodeCallGen().generate(ctx, operands, resultType) @@ -847,7 +950,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) new JsonCallGen().generate(ctx, operands, FlinkTypeFactory.toLogicalType(call.getType)) case _ => - new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType) + new BridgingSqlFunctionCallGen(call, rexProgram).generate(ctx, operands, resultType) } // advanced scalar functions @@ -875,7 +978,14 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } private def generateJsonCall(operand: RexNode) = { - val jsonCall = operand.asInstanceOf[RexCall] + // After unification of projections + condition into a single RexProgram, structurally + // identical sub-expressions are collapsed into one exprList entry referenced via + // RexLocalRef. JSON_OBJECT/JSON_ARRAY operands recognised as JSON via + // isSupportedJsonOperand may therefore arrive here as a RexLocalRef; resolve it back to + // the underlying RexCall before casting. 
+ val jsonCall = ShortcutUtils + .expandLocalRef(operand, if (rexProgram == null) null else rexProgram.getExprList) + .asInstanceOf[RexCall] val jsonOperands = jsonCall.getOperands.map(_.accept(this)) generateCallExpression( ctx, @@ -896,4 +1006,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } }.toArray } + + private def isDeterministicThroughProgram(node: RexNode): Boolean = + ShortcutUtils.isDeterministicThroughProgram( + node, + if (rexProgram == null) null else rexProgram.getExprList) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala index 567554e6e1c84..b1c70fb9c5260 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala @@ -139,8 +139,8 @@ class ExpressionReducer( } else unreduced match { case call: RexCall - if (nonReducibleJsonFunctions.contains(call.getOperator) || isJsonFunctionOperand( - call)) => + if (nonReducibleJsonFunctions.contains(call.getOperator) + || isJsonFunctionOperand(call, null)) => reducedValues.add(unreduced) case _ => unreduced.getType.getSqlTypeName match { @@ -297,7 +297,10 @@ class ExpressionReducer( } // Exclude some JSON functions which behave differently // when called as an argument of another call of one of these functions. 
- if (nonReducibleJsonFunctions.contains(call.getOperator) || isJsonFunctionOperand(call)) { + if ( + nonReducibleJsonFunctions.contains(call.getOperator) + || isJsonFunctionOperand(call, null) + ) { None } else { Some(call) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala index 3fd256e071d4b..73c2f30ad8819 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala @@ -21,11 +21,11 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, ObjectNode} import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawValue import org.apache.flink.table.api.{DataTypes, JsonOnNull} -import org.apache.flink.table.functions.BuiltInFunctionDefinitions.JSON +import org.apache.flink.table.functions.BuiltInFunctionDefinitions import org.apache.flink.table.planner.codegen.CodeGenUtils._ -import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable.{JSON_ARRAY, JSON_OBJECT} import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala -import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapFunctionDefinition +import org.apache.flink.table.planner.utils.ShortcutUtils +import org.apache.flink.table.planner.utils.ShortcutUtils.expandLocalRef import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isCharacterString import org.apache.flink.table.types.logical._ @@ -33,7 +33,7 @@ import org.apache.flink.table.types.logical.LogicalTypeRoot._ import 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks import org.apache.flink.table.utils.EncodingUtils -import org.apache.calcite.rex.{RexCall, RexNode} +import org.apache.calcite.rex.RexNode import java.time.format.DateTimeFormatter @@ -51,8 +51,9 @@ object JsonGenerateUtils { def createNodeTerm( ctx: CodeGeneratorContext, expression: GeneratedExpression, - operand: RexNode): String = { - if (isJsonObjectOrArrayOperand(operand) || isJsonFunctionOperand(operand)) { + operand: RexNode, + exprs: java.util.List[RexNode]): String = { + if (isJsonObjectOrArrayOperand(operand, exprs) || isJsonFunctionOperand(operand, exprs)) { createRawNodeTerm(expression) } else { createNodeTerm(ctx, expression) @@ -177,59 +178,36 @@ object JsonGenerateUtils { } } - /** Determines whether the given operand is a call to a JSON_OBJECT */ - def isJsonObjectOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - rexCall.getOperator match { - case JSON_OBJECT => true - case _ => false - } - case _ => false - } - } + /** Determines whether the given operand is a call to a JSON_OBJECT. */ + def isJsonObjectOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + ShortcutUtils.isOneOfFunctionDefinitions( + expandLocalRef(operand, exprs), + BuiltInFunctionDefinitions.JSON_OBJECT) - /** Determines whether the given operand is a call to a JSON_ARRAY */ - def isJsonArrayOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - rexCall.getOperator match { - case JSON_ARRAY => true - case _ => false - } - case _ => false - } - } + /** Determines whether the given operand is a call to a JSON_ARRAY. 
*/ + def isJsonArrayOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + ShortcutUtils.isOneOfFunctionDefinitions( + expandLocalRef(operand, exprs), + BuiltInFunctionDefinitions.JSON_ARRAY) /** * Determines whether the given operand is a call to a JSON_OBJECT or JSON_ARRAY whose result * should be inserted as a raw value instead of as a character string. */ - def isJsonObjectOrArrayOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - rexCall.getOperator match { - case JSON_OBJECT | JSON_ARRAY => true - case _ => false - } - case _ => false - } - } + def isJsonObjectOrArrayOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + ShortcutUtils.isOneOfFunctionDefinitions( + expandLocalRef(operand, exprs), + BuiltInFunctionDefinitions.JSON_OBJECT, + BuiltInFunctionDefinitions.JSON_ARRAY) /** * Determines whether the given operand is a call to JSON function whose call currently just - * passes through the input value as output value + * passes through the input value as output value. */ - def isJsonFunctionOperand(operand: RexNode): Boolean = { - operand match { - case rexCall: RexCall => - unwrapFunctionDefinition(rexCall) match { - case JSON => true - case _ => false - } - case _ => false - } - } + def isJsonFunctionOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + ShortcutUtils.isOneOfFunctionDefinitions( + expandLocalRef(operand, exprs), + BuiltInFunctionDefinitions.JSON) /** * Determines whether a JSON function is allowed in the current context. JSON functions are @@ -237,9 +215,13 @@ object JsonGenerateUtils { * of a JSON_OBJECT call, we do (i % 2) == 0 to check if it's being used in second parameter, the * values' parameter. 
*/ - def isSupportedJsonOperand(operand: RexNode, call: RexNode, i: Int): Boolean = { - isJsonFunctionOperand(operand) && - (isJsonArrayOperand(call) || isJsonObjectOperand(call) && (i % 2) == 0) + def isSupportedJsonOperand( + operand: RexNode, + call: RexNode, + i: Int, + exprs: java.util.List[RexNode]): Boolean = { + isJsonFunctionOperand(operand, exprs) && + (isJsonArrayOperand(call, exprs) || isJsonObjectOperand(call, exprs) && (i % 2) == 0) } /** Generates a method to convert arrays into [[ArrayNode]]. */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala index 4256c90b5960d..bae94282d1da7 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.planner.codegen import org.apache.flink.api.common.functions.DefaultOpenContext -import org.apache.flink.configuration.{Configuration, ReadableConfig} +import org.apache.flink.configuration.ReadableConfig import org.apache.flink.metrics.Gauge import org.apache.flink.table.data.{RowData, TimestampData} import org.apache.flink.table.data.utils.JoinedRowData diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala index cb241c8feb4dd..2115e76304d96 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala +++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.api.common.functions.{AbstractRichFunction, OpenContext, RichFunction} -import org.apache.flink.configuration.{Configuration, ReadableConfig} +import org.apache.flink.configuration.ReadableConfig import org.apache.flink.table.api.{DataTypes, TableException} import org.apache.flink.table.api.Expressions.callSql import org.apache.flink.table.data.{GenericRowData, RawValueData, StringData} @@ -27,9 +27,10 @@ import org.apache.flink.table.expressions.ApiExpressionUtils.{typeLiteral, unres import org.apache.flink.table.expressions.Expression import org.apache.flink.table.functions._ import org.apache.flink.table.functions.SpecializedFunction.{ExpressionEvaluator, ExpressionEvaluatorFactory} -import org.apache.flink.table.functions.UserDefinedFunctionHelper.{validateClassForRuntime, ASYNC_SCALAR_EVAL, ASYNC_TABLE_EVAL, SCALAR_EVAL, TABLE_EVAL} +import org.apache.flink.table.functions.UserDefinedFunctionHelper._ import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexFactory} import org.apache.flink.table.planner.codegen._ +import org.apache.flink.table.planner.codegen.AsyncCodeGenerator.DEFAULT_DELEGATING_FUTURE_TERM import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala @@ -48,8 +49,6 @@ import org.apache.flink.table.types.utils.DataTypeUtils import org.apache.flink.table.types.utils.DataTypeUtils.{isInternal, validateInputDataType, validateOutputDataType} import org.apache.flink.util.Preconditions -import AsyncCodeGenerator.{generateFunction, DEFAULT_DELEGATING_FUTURE_TERM} - import java.util.concurrent.CompletableFuture import scala.collection.JavaConverters._ diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala index cada84a1bca8b..77a3850f664f8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala @@ -26,7 +26,7 @@ import org.apache.flink.table.planner.functions.inference.OperatorBindingCallCon import org.apache.flink.table.runtime.collector.WrappingCollector import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.{RexCall, RexCallBinding} +import org.apache.calcite.rex.{RexCall, RexCallBinding, RexProgram} import java.util.Collections @@ -37,7 +37,9 @@ import java.util.Collections * generator will be a reference to a [[WrappingCollector]]. Furthermore, atomic types are wrapped * into a row by the collector. 
*/ -class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator { +class BridgingSqlFunctionCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + + def this(call: RexCall) = this(call, null) override def generate( ctx: CodeGeneratorContext, @@ -54,7 +56,7 @@ class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator { val callContext = new OperatorBindingCallContext( dataTypeFactory, definition, - RexCallBinding.create(function.getTypeFactory, call, Collections.emptyList()), + RexCallBinding.create(function.getTypeFactory, call, rexProgram, Collections.emptyList()), call.getType) // create the final UDF for runtime diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala index ea324ec6082b5..eadc6413026dd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala @@ -25,10 +25,13 @@ import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.RexCall +import org.apache.calcite.rex.{RexCall, RexProgram} /** [[CallGenerator]] for `JSON_ARRAY`. 
*/ -class JsonArrayCallGen(call: RexCall) extends CallGenerator { +class JsonArrayCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + + def this(call: RexCall) = this(call, null) + private def jsonUtils = className[SqlJsonUtils] override def generate( @@ -47,7 +50,9 @@ class JsonArrayCallGen(call: RexCall) extends CallGenerator { .drop(1) .map { case (elementExpr, elementIdx) => - val elementTerm = createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx)) + val exprs = if (rexProgram == null) null else rexProgram.getExprList + val elementTerm = + createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx), exprs) onNull match { case JsonOnNull.NULL => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala index 9a5c87fb06f79..b876ec85e2961 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.RexCall +import org.apache.calcite.rex.{RexCall, RexProgram} /** * [[CallGenerator]] for `JSON_OBJECT`. @@ -37,7 +37,10 @@ import org.apache.calcite.rex.RexCall * We remedy this by treating nested calls to this function differently and inserting the value as a * raw node instead of as a string node. 
*/ -class JsonObjectCallGen(call: RexCall) extends CallGenerator { +class JsonObjectCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + + def this(call: RexCall) = this(call, null) + private def jsonUtils = className[SqlJsonUtils] override def generate( @@ -57,7 +60,8 @@ class JsonObjectCallGen(call: RexCall) extends CallGenerator { .grouped(2) .map { case Seq((keyExpr, _), (valueExpr, valueIdx)) => - val valueTerm = createNodeTerm(ctx, valueExpr, call.operands.get(valueIdx)) + val exprs = if (rexProgram == null) null else rexProgram.getExprList + val valueTerm = createNodeTerm(ctx, valueExpr, call.operands.get(valueIdx), exprs) onNull match { case JsonOnNull.NULL => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala index 1265d11a2787b..8dd4dc9fe7c7c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala @@ -23,10 +23,13 @@ import org.apache.flink.table.planner.codegen.JsonGenerateUtils.createNodeTerm import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.rex.RexCall +import org.apache.calcite.rex.{RexCall, RexProgram} /** [[CallGenerator]] for `JSON_STRING`. 
*/ -class JsonStringCallGen(call: RexCall) extends CallGenerator { +class JsonStringCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { + + def this(call: RexCall) = this(call, null) + private def jsonUtils = className[SqlJsonUtils] override def generate( @@ -34,7 +37,8 @@ class JsonStringCallGen(call: RexCall) extends CallGenerator { operands: Seq[GeneratedExpression], returnType: LogicalType): GeneratedExpression = { - val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0)) + val exprs = if (rexProgram == null) null else rexProgram.getExprList + val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0), exprs) val resultTerm = newName(ctx, "result") val resultTermType = primitiveTypeTermForType(returnType) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java index 6c236ccd04d5d..bfcd299d937bb 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java @@ -41,7 +41,6 @@ import java.time.Instant; import java.time.LocalDateTime; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -94,6 +93,7 @@ Stream getTestSetSpecs() { testCases.addAll(jsonQuoteSpec()); testCases.addAll(jsonUnquoteSpecWithValidInput()); testCases.addAll(jsonUnquoteSpecWithInvalidInput()); + testCases.addAll(jsonLocalRefReuseSpec()); return testCases.stream(); } @@ -296,7 +296,7 @@ private static TestSetSpec jsonValueSpec() { } private static List isJsonSpec() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.IS_JSON) .onFieldsWithData(1) .andDataTypes(INT()) 
@@ -367,7 +367,7 @@ private static List isJsonSpec() { private static List jsonQuerySpec() { final String jsonValue = getJsonFromResource("/json/json-query.json"); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_QUERY) .onFieldsWithData((String) null) .andDataTypes(STRING()) @@ -599,7 +599,7 @@ private static List jsonStringSpec() { multisetData.put("M1", 1); multisetData.put("M2", 2); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_STRING) .onFieldsWithData(0) .testResult( @@ -616,7 +616,7 @@ private static List jsonStringSpec() { 1.23, LocalDateTime.parse("1990-06-02T13:37:42.001"), Instant.parse("1990-06-02T13:37:42.001Z"), - Arrays.asList("A1", "A2", "A3"), + List.of("A1", "A2", "A3"), Row.of("R1", Instant.parse("1990-06-02T13:37:42.001Z")), mapData, multisetData, @@ -717,7 +717,7 @@ private static List jsonStringSpec() { } private static List jsonSpec() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_OBJECT) .onFieldsWithData("{\"key\":\"value\"}", "{\"key\": {\"value\": 42}}") .andDataTypes(STRING(), STRING()) @@ -946,7 +946,7 @@ private static List jsonObjectSpec() { multisetData.put("M1", 1); multisetData.put("M2", 2); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_OBJECT) .onFieldsWithData(0) .testResult( @@ -977,7 +977,7 @@ private static List jsonObjectSpec() { 1.23, LocalDateTime.parse("1990-06-02T13:37:42.001"), Instant.parse("1990-06-02T13:37:42.001Z"), - Arrays.asList("A1", "A2", "A3"), + List.of("A1", "A2", "A3"), Row.of("R1", Instant.parse("1990-06-02T13:37:42.001Z")), mapData, multisetData, @@ -1105,7 +1105,7 @@ private static List jsonObjectSpec() { private static List jsonQuoteSpec() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_QUOTE) .onFieldsWithData(0) .testResult( @@ -1177,7 +1177,7 @@ 
private static List jsonQuoteSpec() { private static List jsonUnquoteSpecWithValidInput() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_UNQUOTE) .onFieldsWithData(0) .testResult( @@ -1319,7 +1319,7 @@ private static List jsonUnquoteSpecWithValidInput() { private static List jsonUnquoteSpecWithInvalidInput() { - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_UNQUOTE) .onFieldsWithData(0) .testResult( @@ -1406,7 +1406,7 @@ private static List jsonArraySpec() { multisetData.put("M1", 1); multisetData.put("M2", 2); - return Arrays.asList( + return List.of( TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_ARRAY) .onFieldsWithData(0) .testResult( @@ -1436,7 +1436,7 @@ private static List jsonArraySpec() { 1.23, LocalDateTime.parse("1990-06-02T13:37:42.001"), Instant.parse("1990-06-02T13:37:42.001Z"), - Arrays.asList("A1", "A2", "A3"), + List.of("A1", "A2", "A3"), Row.of("R1", Instant.parse("1990-06-02T13:37:42.001Z")), mapData, multisetData, @@ -1540,6 +1540,166 @@ private static List jsonArraySpec() { STRING().notNull())); } + /** + * Pins the local-ref / common-sub-expression handling for JSON construction calls. + * + *

When two projections share a JSON-producing sub-expression, the planner deduplicates them + * into a {@link org.apache.calcite.rex.RexLocalRef} that points at the shared {@link + * org.apache.calcite.rex.RexCall}. The codegen helpers in {@code JsonGenerateUtils} must + * dereference that local ref through the surrounding {@code RexProgram} to recognize it as a + * JSON / JSON_OBJECT / JSON_ARRAY operand and embed the value as a raw JSON node. Without that + * dereference the helpers see a plain {@code RexLocalRef}, fall back to the string-quoting + * branch, and produce wrong output (e.g. {@code "{\"k\":\"[1,2,3]\"}"} instead of {@code + * "{\"k\":[1,2,3]}"}). + * + *

The scenarios below cover each callsite — {@code JsonObjectCallGen}, {@code + * JsonArrayCallGen}, {@code JsonStringCallGen} — and each branch of the inspection helpers + * ({@code JSON}, {@code JSON_OBJECT}, {@code JSON_ARRAY}). + */ + private static List jsonLocalRefReuseSpec() { + return List.of( + // Shared JSON(f) inside two JSON_OBJECT projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_OBJECT, + "Shared JSON(f) sub-expression across JSON_OBJECT projections") + .onFieldsWithData("[1,2,3]") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonObject(JsonOnNull.NULL, "k1", json($("f0"))), + "JSON_OBJECT(KEY 'k1' VALUE JSON(f0))", + "{\"k1\":[1,2,3]}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonObject(JsonOnNull.NULL, "k2", json($("f0"))), + "JSON_OBJECT(KEY 'k2' VALUE JSON(f0))", + "{\"k2\":[1,2,3]}", + STRING().notNull(), + STRING().notNull())), + // Shared JSON_ARRAY(...) inside two JSON_OBJECT projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_OBJECT, + "Shared JSON_ARRAY sub-expression across JSON_OBJECT projections") + .onFieldsWithData(1, 2, 3) + .andDataTypes(INT(), INT(), INT()) + .testResult( + resultSpec( + jsonObject( + JsonOnNull.NULL, + "a", + jsonArray( + JsonOnNull.NULL, + $("f0"), + $("f1"), + $("f2"))), + "JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(f0, f1, f2))", + "{\"a\":[1,2,3]}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonObject( + JsonOnNull.NULL, + "b", + jsonArray( + JsonOnNull.NULL, + $("f0"), + $("f1"), + $("f2"))), + "JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))", + "{\"b\":[1,2,3]}", + STRING().notNull(), + STRING().notNull())), + // Shared inner JSON_OBJECT inside two outer JSON_OBJECT projections. 
+ TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_OBJECT, + "Shared inner JSON_OBJECT across outer JSON_OBJECT projections") + .onFieldsWithData("V") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonObject( + JsonOnNull.NULL, + "outer1", + jsonObject(JsonOnNull.NULL, "inner", $("f0"))), + "JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", + "{\"outer1\":{\"inner\":\"V\"}}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonObject( + JsonOnNull.NULL, + "outer2", + jsonObject(JsonOnNull.NULL, "inner", $("f0"))), + "JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", + "{\"outer2\":{\"inner\":\"V\"}}", + STRING().notNull(), + STRING().notNull())), + // Shared JSON_OBJECT inside two JSON_ARRAY projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_ARRAY, + "Shared JSON_OBJECT inside JSON_ARRAY across projections") + .onFieldsWithData("V") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonArray( + JsonOnNull.NULL, + jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", + "[{\"k\":\"V\"}]", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonArray( + JsonOnNull.NULL, + jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", + "[{\"k\":\"V\"}]", + STRING().notNull(), + STRING().notNull())), + // Shared JSON(f) inside two JSON_ARRAY projections. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_ARRAY, + "Shared JSON(f) inside JSON_ARRAY across projections") + .onFieldsWithData("[1,2,3]") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonArray(JsonOnNull.NULL, json($("f0"))), + "JSON_ARRAY(JSON(f0))", + "[[1,2,3]]", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonArray(JsonOnNull.NULL, json($("f0"))), + "JSON_ARRAY(JSON(f0))", + "[[1,2,3]]", + STRING().notNull(), + STRING().notNull())), + // Shared JSON_OBJECT inside two JSON_STRING projections. 
JSON_STRING re-serializes + // the operand; without dereferencing the local ref it would wrap the already + // serialized JSON string a second time. + TestSetSpec.forFunction( + BuiltInFunctionDefinitions.JSON_STRING, + "Shared JSON_OBJECT inside JSON_STRING across projections") + .onFieldsWithData("V") + .andDataTypes(STRING()) + .testResult( + resultSpec( + jsonString(jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", + "{\"k\":\"V\"}", + STRING().notNull(), + STRING().notNull()), + resultSpec( + jsonString(jsonObject(JsonOnNull.NULL, "k", $("f0"))), + "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", + "{\"k\":\"V\"}", + STRING().notNull(), + STRING().notNull()))); + } + // --------------------------------------------------------------------------------------------- /** diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java index c0246b6fe4103..c9d31af6e97e9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java @@ -69,6 +69,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import java.lang.invoke.MethodHandle; @@ -76,7 +78,6 @@ import java.nio.ByteBuffer; import java.time.DayOfWeek; import java.time.LocalDateTime; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -85,7 +86,9 @@ import java.util.Random; import java.util.UUID; import 
java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.apache.flink.table.api.Expressions.$; import static org.apache.flink.table.utils.UserDefinedFunctions.GENERATED_LOWER_UDF_CLASS; @@ -131,10 +134,10 @@ public void before() throws Exception { void testCreateCatalogFunctionInDefaultCatalog() { String ddl1 = "create function f1 as 'org.apache.flink.function.TestFunction'"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f1"); + assertThat(List.of(tEnv().listFunctions())).contains("f1"); tEnv().executeSql("DROP FUNCTION IF EXISTS default_catalog.default_database.f1"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f1"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f1"); } @Test @@ -143,10 +146,10 @@ void testCreateFunctionWithFullPath() { "create function default_catalog.default_database.f2 as" + " 'org.apache.flink.function.TestFunction'"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f2"); + assertThat(List.of(tEnv().listFunctions())).contains("f2"); tEnv().executeSql("DROP FUNCTION IF EXISTS default_catalog.default_database.f2"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f2"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f2"); } @Test @@ -155,10 +158,10 @@ void testCreateFunctionWithoutCatalogIdentifier() { "create function default_database.f3 as" + " 'org.apache.flink.function.TestFunction'"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f3"); + assertThat(List.of(tEnv().listFunctions())).contains("f3"); tEnv().executeSql("DROP FUNCTION IF EXISTS default_catalog.default_database.f3"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f3"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f3"); } @Test @@ -186,7 
+189,7 @@ void testDynamicDatetimeFunctionsAreEqual() { + " CURRENT_DATE = CURRENT_DATE()") .execute(); List actualRows = CollectionUtil.iteratorToList(tableResult.collect()); - assertThat(actualRows).isEqualTo(Arrays.asList(Row.of(true, true, true, true, true))); + assertThat(actualRows).isEqualTo(List.of(Row.of(true, true, true, true, true))); } @Test @@ -232,13 +235,13 @@ void testCreateTemporaryCatalogFunction() { String ddl4 = "drop temporary function if exists default_catalog.default_database.f4"; tEnv().executeSql(ddl1); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f4"); + assertThat(List.of(tEnv().listFunctions())).contains("f4"); tEnv().executeSql(ddl2); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f4"); + assertThat(List.of(tEnv().listFunctions())).contains("f4"); tEnv().executeSql(ddl3); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f4"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f4"); tEnv().executeSql(ddl1); assertThatThrownBy(() -> tEnv().executeSql(ddl1)) @@ -276,24 +279,24 @@ void testCreateTemporarySystemFunctionByUsingJar() throws Exception { "CREATE TEMPORARY SYSTEM FUNCTION f10 AS '%s' USING JAR '%s'", udfClassName, jarPath); tEnv().executeSql(ddl); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f10"); + assertThat(List.of(tEnv().listFunctions())).contains("f10"); try (CloseableIterator itor = tEnv().executeSql("SHOW JARS").collect()) { assertThat(itor.hasNext()).isFalse(); } tEnv().executeSql("DROP TEMPORARY SYSTEM FUNCTION f10"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f10"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f10"); } @Test void testCreateTemporarySystemFunctionWithTableAPI() { ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath); - tEnv().createTemporarySystemFunction("f10", udfClassName, Arrays.asList(resourceUri)); - 
assertThat(Arrays.asList(tEnv().listFunctions())).contains("f10"); + tEnv().createTemporarySystemFunction("f10", udfClassName, List.of(resourceUri)); + assertThat(List.of(tEnv().listFunctions())).contains("f10"); tEnv().executeSql("DROP TEMPORARY SYSTEM FUNCTION f10"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f10"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f10"); } @Test @@ -303,7 +306,7 @@ void testUserDefinedTemporarySystemFunctionWithTableAPI() throws Exception { testUserDefinedFunctionByUsingJar( environment -> environment.createTemporarySystemFunction( - "lowerUdf", udfClassName, Arrays.asList(resourceUri)), + "lowerUdf", udfClassName, List.of(resourceUri)), dropFunctionSql); } @@ -314,20 +317,20 @@ void testCreateCatalogFunctionByUsingJar() { "CREATE FUNCTION default_database.f11 AS '%s' USING JAR '%s'", udfClassName, jarPath); tEnv().executeSql(ddl); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f11"); + assertThat(List.of(tEnv().listFunctions())).contains("f11"); tEnv().executeSql("DROP FUNCTION default_database.f11"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f11"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f11"); } @Test void testCreateCatalogFunctionWithTableAPI() { ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath); - tEnv().createFunction("f11", udfClassName, Arrays.asList(resourceUri)); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f11"); + tEnv().createFunction("f11", udfClassName, List.of(resourceUri)); + assertThat(List.of(tEnv().listFunctions())).contains("f11"); tEnv().executeSql("DROP FUNCTION default_database.f11"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f11"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f11"); } @Test @@ -336,8 +339,7 @@ void testUserDefinedCatalogFunctionWithTableAPI() throws Exception { String dropFunctionSql = "DROP FUNCTION 
default_database.lowerUdf"; testUserDefinedFunctionByUsingJar( environment -> - environment.createFunction( - "lowerUdf", udfClassName, Arrays.asList(resourceUri)), + environment.createFunction("lowerUdf", udfClassName, List.of(resourceUri)), dropFunctionSql); } @@ -348,20 +350,20 @@ void testCreateTemporaryCatalogFunctionByUsingJar() { "CREATE TEMPORARY FUNCTION default_database.f12 AS '%s' USING JAR '%s'", udfClassName, jarPath); tEnv().executeSql(ddl); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f12"); + assertThat(List.of(tEnv().listFunctions())).contains("f12"); tEnv().executeSql("DROP TEMPORARY FUNCTION default_database.f12"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f12"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f12"); } @Test void testCreateTemporaryCatalogFunctionWithTableAPI() { ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath); - tEnv().createTemporaryFunction("f12", udfClassName, Arrays.asList(resourceUri)); - assertThat(Arrays.asList(tEnv().listFunctions())).contains("f12"); + tEnv().createTemporaryFunction("f12", udfClassName, List.of(resourceUri)); + assertThat(List.of(tEnv().listFunctions())).contains("f12"); tEnv().executeSql("DROP TEMPORARY FUNCTION default_database.f12"); - assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f12"); + assertThat(List.of(tEnv().listFunctions())).doesNotContain("f12"); } @Test @@ -371,7 +373,7 @@ void testUserDefinedTemporaryCatalogFunctionWithTableAPI() throws Exception { testUserDefinedFunctionByUsingJar( environment -> environment.createTemporaryFunction( - "lowerUdf", udfClassName, Arrays.asList(resourceUri)), + "lowerUdf", udfClassName, List.of(resourceUri)), dropFunctionSql); } @@ -596,7 +598,7 @@ void testExpressionReducerByUsingJar() { TableResult tableResult = tEnv().executeSql("SELECT lowerUdf('HELLO')"); List actualRows = CollectionUtil.iteratorToList(tableResult.collect()); - 
assertThat(actualRows).isEqualTo(Arrays.asList(Row.of("hello"))); + assertThat(actualRows).isEqualTo(List.of(Row.of("hello"))); tEnv().executeSql("drop temporary function lowerUdf"); } @@ -611,7 +613,7 @@ public Integer eval(Integer a, Integer b) { private void testUserDefinedCatalogFunction(String createFunctionDDL) throws Exception { List sourceData = - Arrays.asList( + List.of( Row.of(1, "1000", 2), Row.of(2, "1", 3), Row.of(3, "2000", 4), @@ -644,7 +646,7 @@ private void testUserDefinedCatalogFunction(String createFunctionDDL) throws Exc private void testUserDefinedFunctionByUsingJar(FunctionCreator creator, String dropFunctionDDL) throws Exception { List sourceData = - Arrays.asList( + List.of( Row.of(1, "JARK"), Row.of(2, "RON"), Row.of(3, "LeoNard"), @@ -667,7 +669,7 @@ private void testUserDefinedFunctionByUsingJar(FunctionCreator creator, String d List result = TestCollectionTableFactory.RESULT(); List expected = - Arrays.asList( + List.of( Row.of(1, "jark"), Row.of(2, "ron"), Row.of(3, "leonard"), @@ -684,10 +686,10 @@ private void testUserDefinedFunctionByUsingJar(FunctionCreator creator, String d @Test void testPrimitiveScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of(1, 1L, "-"), Row.of(2, 2L, "--"), Row.of(3, 3L, "---")); + List.of(Row.of(1, 1L, "-"), Row.of(2, 2L, "--"), Row.of(3, 3L, "---")); final List sinkData = - Arrays.asList(Row.of(1, 3L, "-"), Row.of(2, 6L, "--"), Row.of(3, 9L, "---")); + List.of(Row.of(1, 3L, "-"), Row.of(2, 6L, "--"), Row.of(3, 9L, "---")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -738,7 +740,7 @@ void testNullScalarFunction() throws Exception { @Test void testRowScalarFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(1, Row.of(1, "1")), Row.of(2, Row.of(2, "2")), Row.of(3, Row.of(3, "3"))); @@ -761,14 +763,14 @@ void testRowScalarFunction() throws Exception { @Test void testComplexScalarFunction() throws 
Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(1, new byte[] {1, 2, 3}), Row.of(2, new byte[] {2, 3, 4}), Row.of(3, new byte[] {3, 4, 5}), Row.of(null, null)); final List sinkData = - Arrays.asList( + List.of( Row.of( 1, "1+2012-12-12 12:12:12.123456789", @@ -834,11 +836,10 @@ void testComplexScalarFunction() throws Exception { @Test void testCustomScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of(1), Row.of(2), Row.of(3), Row.of((Integer) null)); + List.of(Row.of(1), Row.of(2), Row.of(3), Row.of((Integer) null)); final List sinkData = - Arrays.asList( - Row.of(1, 1, 5), Row.of(2, 2, 5), Row.of(3, 3, 5), Row.of(null, null, 5)); + List.of(Row.of(1, 1, 5), Row.of(2, 2, 5), Row.of(3, 3, 5), Row.of(null, null, 5)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -862,7 +863,7 @@ void testCustomScalarFunction() throws Exception { @Test void testVarArgScalarFunction() { - final List sourceData = Arrays.asList(Row.of("Bob", 1), Row.of("Alice", 2)); + final List sourceData = List.of(Row.of("Bob", 1), Row.of("Alice", 2)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -890,7 +891,7 @@ void testVarArgScalarFunction() { final List actual = CollectionUtil.iteratorToList(result.collect()); final List expected = - Arrays.asList( + List.of( Row.of( "(INT...)", "(INT...)", @@ -909,7 +910,7 @@ void testVarArgScalarFunction() { @Test void testRawLiteralScalarFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(1, DayOfWeek.MONDAY), Row.of(2, DayOfWeek.FRIDAY), Row.of(null, null)); @@ -968,13 +969,390 @@ void testRawLiteralScalarFunction() throws Exception { assertThat(TestCollectionTableFactory.getResult()).containsExactlyInAnyOrder(sinkData); } + @ParameterizedTest(name = "{0}") + @MethodSource("inputForTestCalcLocalRefReuse") + void testCalcLocalRefReuse( + String sql, List expectedRows, int expectedDetCalls, 
int expectedNonDetCalls) { + final List sourceData = List.of(Row.of("Bob"), Row.of("Alice")); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + CountingUpperScalarFunction.COUNT.set(0); + NonDeterministicCountingScalarFunction.COUNT.set(0); + + tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class); + tEnv().createTemporarySystemFunction( + "Nondet", NonDeterministicCountingScalarFunction.class); + tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')"); + + final List actual = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect()); + + assertThat(actual).containsExactlyElementsOf(expectedRows); + assertThat(CountingUpperScalarFunction.COUNT.get()) + .as("Deterministic invocations") + .isEqualTo(expectedDetCalls); + assertThat(NonDeterministicCountingScalarFunction.COUNT.get()) + .as("Non-deterministic invocations") + .isEqualTo(expectedNonDetCalls); + } + + static Stream inputForTestCalcLocalRefReuse() { + return Stream.of( + Arguments.of( + "SELECT Det(s), Det(s), Det(s) FROM SourceTable", + List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")), + 2, // expected localref calls: rows × 1 (cached) + 0), + Arguments.of( + "SELECT Det(s), Det(s), UPPER(s) FROM SourceTable", + List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")), + 2, // rows × 1 (cached); built-in UPPER not counted + 0), + Arguments.of( + "SELECT Det(Det(s)), Det(Det(s)), Det(Det(s)) FROM SourceTable", + List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")), + 4, // rows × 2 layers + 0), + Arguments.of( + "SELECT Nondet(s), Nondet(s), Nondet(s) FROM SourceTable", + List.of( + Row.of("BOB_1", "BOB_2", "BOB_3"), + Row.of("ALICE_4", "ALICE_5", "ALICE_6")), + 0, + 6 // rows × 3 projections + ), + Arguments.of( + "SELECT Nondet(Det(s)), Nondet(Det(s)), Nondet(Det(s)) FROM SourceTable", + List.of( + Row.of("BOB_1", "BOB_2", "BOB_3"), + 
Row.of("ALICE_4", "ALICE_5", "ALICE_6")), + 2, // rows × 1 (inner cached) + 6 // rows × 3 projections + ), + Arguments.of( + "SELECT Det(Nondet(s)), Det(Nondet(s)), Det(Nondet(s)) FROM SourceTable", + List.of( + Row.of("BOB_1", "BOB_2", "BOB_3"), + Row.of("ALICE_4", "ALICE_5", "ALICE_6")), + 6, // rows × 3 (nondet input disables cache) + 6 // rows × 3 projections + ), + // shared Det in filter → cached once per row + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND Det(s) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 2, + 0), + // mixed UDF + built-in + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND UPPER(s) <> ''", + List.of(Row.of("Bob"), Row.of("Alice")), + 2, + 0), + // nested Det in filter; both layers cached + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(Det(s)) IS NOT NULL" + + " AND Det(Det(s)) <> '' AND Det(Det(s)) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 4, + 0), + // non-deterministic in filter — never cached + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Nondet(s) IS NOT NULL" + + " AND Nondet(s) <> '' AND Nondet(s) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 0, + 6), + // outer nondet, inner Det cached + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Nondet(Det(s)) IS NOT NULL" + + " AND Nondet(Det(s)) <> '' AND Nondet(Det(s)) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 2, + 6), + // Det with nondet input → cache bypassed + Arguments.of( + "SELECT s FROM SourceTable" + + " WHERE Det(Nondet(s)) IS NOT NULL" + + " AND Det(Nondet(s)) <> '' AND Det(Nondet(s)) <> ' '", + List.of(Row.of("Bob"), Row.of("Alice")), + 6, + 6), + // filter ↔ projection share via unified program + Arguments.of( + "SELECT Det(s) FROM SourceTable WHERE Det(s) = 'BOB'", + List.of(Row.of("BOB")), + 2, + 0), + Arguments.of( + "SELECT Det(s), Det(s) FROM SourceTable WHERE Det(s) = 'BOB'", + List.of(Row.of("BOB", 
"BOB")), + 2, + 0), + + // --------------------------------------------------------------------------- + // JSON construction scenarios. These verify that the localref / RexProgram CSE + // cache also fires when the shared sub-expression is wrapped inside (or itself + // is) a JSON_OBJECT / JSON_ARRAY / JSON_STRING call. + // --------------------------------------------------------------------------- + + // JSON_OBJECT × 2 sharing inner Det → cached once per row. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s))," + + " JSON_OBJECT(KEY 'b' VALUE Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"), + Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")), + 2, // rows × 1 (cached) + 0), + // JSON_ARRAY × 2 sharing inner Det → cached. + Arguments.of( + "SELECT JSON_ARRAY(Det(s)), JSON_ARRAY(Det(s)) FROM SourceTable", + List.of( + Row.of("[\"BOB\"]", "[\"BOB\"]"), + Row.of("[\"ALICE\"]", "[\"ALICE\"]")), + 2, + 0), + // JSON_STRING × 2 sharing inner Det → cached. + Arguments.of( + "SELECT JSON_STRING(Det(s)), JSON_STRING(Det(s)) FROM SourceTable", + List.of(Row.of("\"BOB\"", "\"BOB\""), Row.of("\"ALICE\"", "\"ALICE\"")), + 2, + 0), + // Mixed JSON_OBJECT + JSON_ARRAY sharing same Det. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_ARRAY(Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"k\":\"BOB\"}", "[\"BOB\"]"), + Row.of("{\"k\":\"ALICE\"}", "[\"ALICE\"]")), + 2, + 0), + // Mixed JSON_OBJECT + JSON_STRING sharing same Det. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_STRING(Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"k\":\"BOB\"}", "\"BOB\""), + Row.of("{\"k\":\"ALICE\"}", "\"ALICE\"")), + 2, + 0), + // JSON_OBJECT × 3 sharing same Det → cached across all 3 sites. 
+ Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s))," + + " JSON_OBJECT(KEY 'b' VALUE Det(s))," + + " JSON_OBJECT(KEY 'c' VALUE Det(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}", "{\"c\":\"BOB\"}"), + Row.of( + "{\"a\":\"ALICE\"}", + "{\"b\":\"ALICE\"}", + "{\"c\":\"ALICE\"}")), + 2, + 0), + // Nested Det(Det(s)) inside two JSON_OBJECT projections → both layers cached. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Det(s)))," + + " JSON_OBJECT(KEY 'b' VALUE Det(Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"), + Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")), + 4, // rows × 2 layers + 0), + // Nondet inside two JSON_OBJECT projections → never cached. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(s))," + + " JSON_OBJECT(KEY 'b' VALUE Nondet(s))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"), + Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")), + 0, + 4 // rows × 2 projections + ), + // Outer Nondet, inner Det inside two JSON_OBJECT projections — Det cached. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(Det(s)))," + + " JSON_OBJECT(KEY 'b' VALUE Nondet(Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"), + Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")), + 2, // inner Det cached + 4), + // Outer Det, inner Nondet → outer cache disabled by nondet operand. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Nondet(s)))," + + " JSON_OBJECT(KEY 'b' VALUE Det(Nondet(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"), + Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")), + 4, // outer Det not cached (nondet operand) + 4), + // Filter ↔ JSON projection share Det via unified program. 
+ Arguments.of( + "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s))" + + " FROM SourceTable WHERE Det(s) = 'BOB'", + List.of(Row.of("{\"k\":\"BOB\"}")), + 2, + 0), + // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two outer JSON_OBJECT + // projections — verifies CSE works when the cached node is itself a JSON + // construction call (and validates the JSON helpers' RexLocalRef deref path + // along the way). + Arguments.of( + "SELECT JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s)))," + + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s)))" + + " FROM SourceTable", + List.of( + Row.of( + "{\"outer1\":{\"k\":\"BOB\"}}", + "{\"outer2\":{\"k\":\"BOB\"}}"), + Row.of( + "{\"outer1\":{\"k\":\"ALICE\"}}", + "{\"outer2\":{\"k\":\"ALICE\"}}")), + 2, + 0), + // Shared inner JSON_ARRAY(Det(s)) inside two outer JSON_OBJECT projections. + Arguments.of( + "SELECT JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(Det(s)))," + + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("{\"a\":[\"BOB\"]}", "{\"b\":[\"BOB\"]}"), + Row.of("{\"a\":[\"ALICE\"]}", "{\"b\":[\"ALICE\"]}")), + 2, + 0), + // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two JSON_ARRAY + // projections. 
+ Arguments.of( + "SELECT JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s)))," + + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s)))" + + " FROM SourceTable", + List.of( + Row.of("[{\"k\":\"BOB\"}]", "[{\"k\":\"BOB\"}]"), + Row.of("[{\"k\":\"ALICE\"}]", "[{\"k\":\"ALICE\"}]")), + 2, + 0)); + } + + @Test + void testLocalRefReuseForMixedArgs() { + final List sourceData = List.of(Row.of("Bob"), Row.of("Alice")); + final int callSites = 2; + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + CountingUpperScalarFunction.COUNT.set(0); + NonDeterministicCountingScalarFunction.COUNT.set(0); + CountingConcat3ScalarFunction.COUNT.set(0); + + tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class); + tEnv().createTemporarySystemFunction( + "Nondet", NonDeterministicCountingScalarFunction.class); + tEnv().createTemporarySystemFunction("Concat3", CountingConcat3ScalarFunction.class); + tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')"); + + final List actual = + CollectionUtil.iteratorToList( + tEnv().executeSql( + "SELECT Concat3(Det(s), Nondet(s), Det(s))," + + " Concat3(Det(s), Nondet(s), Det(s))" + + " FROM SourceTable") + .collect()); + + assertThat(actual) + .containsExactly( + Row.of("BOB/BOB_1/BOB", "BOB/BOB_2/BOB"), + Row.of("ALICE/ALICE_3/ALICE", "ALICE/ALICE_4/ALICE")); + + assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size()); + assertThat(NonDeterministicCountingScalarFunction.COUNT.get()) + .isEqualTo(sourceData.size() * callSites); + // Concat3 is deterministic however has non-deterministic input + assertThat(CountingConcat3ScalarFunction.COUNT.get()) + .isEqualTo(sourceData.size() * callSites); + } + + @Test + void testCalcSharesSubExpressionBetweenFilterAndProjection() { + final List sourceData = + List.of(Row.of("Bob"), Row.of("Bob"), Row.of("Alice"), Row.of("Alice")); + + TestCollectionTableFactory.reset(); + 
TestCollectionTableFactory.initData(sourceData); + CountingUpperScalarFunction.COUNT.set(0); + + tEnv().createTemporarySystemFunction("CountingUpper", CountingUpperScalarFunction.class); + tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')"); + + final List actual = + CollectionUtil.iteratorToList( + tEnv().executeSql( + "SELECT CountingUpper(s) FROM SourceTable" + + " WHERE CountingUpper(s) = 'BOB' AND CountingUpper(s) <> 'BOB2'") + .collect()); + + assertThat(actual).containsExactly(Row.of("BOB"), Row.of("BOB")); + + // Filter and projection share via the unified RexProgram, so the UDF runs once per + // source row regardless of how many call sites name it. + assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size()); + } + + /** + * Pins the CASE-WHEN guard interaction with the RexLocalRef cache. + * + *

Prior to scoped caching, RexProgramBuilder collapsed the division {@code a / b} into a + * single exprList entry; the codegen visitor cached the body and {@code + * CalcCodeGenerator.reuseLocalRefCode()} hoisted that body to the top of the generated method, + * evaluating {@code a / b} for every row regardless of the surrounding {@code CASE WHEN b > 0}. + * Rows with {@code b = 0} then threw {@code java.lang.ArithmeticException: Division undefined} + * — caught in the wild on TPC-DS query 34. With scoped caching the division body lives inside + * the THEN-branch's generated code and never executes when the guard is false. + */ + @Test + void testCalcCaseGuardShortCircuit() { + final List sourceData = + List.of(Row.of(10, 0), Row.of(10, 2), Row.of(20, 0), Row.of(30, 5), Row.of(40, 0)); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + tEnv().executeSql( + "CREATE TABLE SourceTable (a INT, b INT) WITH ('connector' = 'COLLECTION')"); + + final List actual = + CollectionUtil.iteratorToList( + tEnv().executeSql( + "SELECT a FROM SourceTable WHERE" + + " (CASE WHEN b > 0" + + " THEN CAST(a AS DECIMAL(7,2))" + + " / CAST(b AS DECIMAL(7,2))" + + " ELSE NULL END) > 1.2") + .collect()); + + // Row(10,2) → 10/2 = 5.0 (>1.2) + // Row(30,5) → 30/5 = 6.0 (>1.2) + // Rows with b=0 must NOT enter the THEN-branch (the division would fail). 
+ assertThat(actual).containsExactly(Row.of(10), Row.of(30)); + } + @Test void testStructuredScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); + List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); final List sinkData = - Arrays.asList( + List.of( Row.of("Bob 42", "Tyler"), Row.of("Alice 12", "Tyler"), Row.of("<>", "Tyler")); @@ -1020,11 +1398,10 @@ void testInvalidCustomScalarFunction() { @Test void testRowTableFunction() throws Exception { final List sourceData = - Arrays.asList( - Row.of("1,2,3"), Row.of("2,3,4"), Row.of("3,4,5"), Row.of((String) null)); + List.of(Row.of("1,2,3"), Row.of("2,3,4"), Row.of("3,4,5"), Row.of((String) null)); final List sinkData = - Arrays.asList( + List.of( Row.of("1,2,3", new String[] {"1", "2", "3"}), Row.of("2,3,4", new String[] {"2", "3", "4"}), Row.of("3,4,5", new String[] {"3", "4", "5"})); @@ -1048,10 +1425,9 @@ void testRowTableFunction() throws Exception { @Test void testStructuredTableFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); + List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); - final List sinkData = - Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); + final List sinkData = List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null, 0)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1157,10 +1533,10 @@ void testNamedArgumentsTableFunctionWithOptionalArguments() throws Exception { @Test void testNamedArgumentsScalarFunction() throws Exception { final List sourceData = - Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); + List.of(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); final List sinkData = - Arrays.asList(Row.of(1, 2, "1: 2"), Row.of(3, 4, "3: 4"), Row.of(5, 6, "5: 6")); + List.of(Row.of(1, 2, "1: 2"), 
Row.of(3, 4, "3: 4"), Row.of(5, 6, "5: 6")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1182,7 +1558,7 @@ void testNamedArgumentsScalarFunction() throws Exception { @Test void testNamedParametersScalarFunctionWithOverloadedMethod() throws Exception { final List sourceData = - Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); + List.of(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5, 6, "str3")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1206,8 +1582,7 @@ void testNamedParametersScalarFunctionWithOverloadedMethod() throws Exception { @Test void testNamedArgumentsScalarFunctionWithOptionalArguments() throws Exception { - final List sinkData = - Arrays.asList(Row.of("s1: null", "null: s2", "s1: s2", "null: null")); + final List sinkData = List.of(Row.of("s1: null", "null: s2", "s1: s2", "null: null")); TestCollectionTableFactory.reset(); tEnv().executeSql( @@ -1230,14 +1605,13 @@ void testNamedArgumentsScalarFunctionWithOptionalArguments() throws Exception { @Test void testNamedArgumentAggregateFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "a", "b", 1, 2), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "c", "d", 33, 44), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "e", "f", 5, 6), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "gg", "hh", 7, 88)); - final List sinkData = - Arrays.asList(Row.of("a: b", "b: a"), Row.of("gg: hh", "hh: gg")); + final List sinkData = List.of(Row.of("a: b", "b: a"), Row.of("gg: hh", "hh: gg")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1265,14 +1639,14 @@ void testNamedArgumentAggregateFunction() throws Exception { @Test void testNamedArgumentAggregateFunctionWithOptionalArguments() throws Exception { final List sourceData = - Arrays.asList( + List.of( 
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "a", "b", 1, 2), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "c", "d", 33, 44), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "e", "f", 5, 6), Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "gg", "hh", 7, 88)); final List sinkData = - Arrays.asList(Row.of("a: null", "null: b"), Row.of("gg: null", "null: hh")); + List.of(Row.of("a: null", "null: b"), Row.of("gg: null", "null: hh")); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1346,7 +1720,7 @@ void testInvalidUseOfTableFunction() { @Test void testAggregateFunction() throws Exception { final List sourceData = - Arrays.asList( + List.of( Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "Bob"), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), "Alice"), Row.of(LocalDateTime.parse("2007-12-03T10:15:30"), null), @@ -1355,7 +1729,7 @@ void testAggregateFunction() throws Exception { Row.of(LocalDateTime.parse("2007-12-03T10:15:32"), "Alice")); final List sinkData = - Arrays.asList( + List.of( Row.of( "Jonathan", "Alice=(Alice, 5), Bob=(Bob, 3), Jonathan=(Jonathan, 8)"), @@ -1409,10 +1783,10 @@ void testLookupTableFunctionWithoutHintLevel1() private void testLookupTableFunctionBase(String lookupTableFunctionClassName) throws ExecutionException, InterruptedException { - final List sourceData = Arrays.asList(Row.of("Bob"), Row.of("Alice")); + final List sourceData = List.of(Row.of("Bob"), Row.of("Alice")); final List sinkData = - Arrays.asList( + List.of( Row.of("Bob", new byte[0]), Row.of("Bob", new byte[] {66, 111, 98}), Row.of("Alice", new byte[0]), @@ -1459,7 +1833,7 @@ private void testLookupTableFunctionBase(String lookupTableFunctionClassName) @Test void testSpecializedFunction() { final List sourceData = - Arrays.asList( + List.of( Row.of("Bob", 1, new BigDecimal("123.45")), Row.of("Alice", 2, new BigDecimal("123.456"))); @@ -1489,7 +1863,7 @@ void testSpecializedFunction() { final List actual = 
CollectionUtil.iteratorToList(result.collect()); final List expected = - Arrays.asList( + List.of( Row.of("CHAR(7) NOT NULL", "STRING", "INT", "DECIMAL(6, 3)"), Row.of("CHAR(7) NOT NULL", "STRING", "INT", "DECIMAL(6, 3)")); assertThat(actual).isEqualTo(expected); @@ -1498,7 +1872,7 @@ void testSpecializedFunction() { @Test void testSpecializedFunctionWithExpressionEvaluation() { final List sourceData = - Arrays.asList( + List.of( Row.of("Bob", new Integer[] {1, 2, 3}, new BigDecimal("123.000")), Row.of("Bob", new Integer[] {4, 5, 6}, new BigDecimal("123.456")), Row.of("Alice", new Integer[] {1, 2, 3}, null), @@ -1530,7 +1904,7 @@ void testSpecializedFunctionWithExpressionEvaluation() { final List actual = CollectionUtil.iteratorToList(result.collect()); final List expected = - Arrays.asList( + List.of( Row.of("Bob", null, null), Row.of( "Bob", @@ -1543,7 +1917,7 @@ void testSpecializedFunctionWithExpressionEvaluation() { @Test void testTimestampNotNull() { - List sourceData = Arrays.asList(Row.of(1), Row.of(2)); + List sourceData = List.of(Row.of(1), Row.of(2)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1557,7 +1931,7 @@ void testTimestampNotNull() { @Test void testIsNullType() { - List sourceData = Arrays.asList(Row.of(1), Row.of((Object) null)); + List sourceData = List.of(Row.of(1), Row.of((Object) null)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1571,7 +1945,7 @@ void testIsNullType() { @Test void testWithBoolNotNullTypeHint() { - List sourceData = Arrays.asList(Row.of(1, 2), Row.of(2, 3)); + List sourceData = List.of(Row.of(1, 2), Row.of(2, 3)); TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1605,7 +1979,7 @@ void testUsingAddJar() throws Exception { @Test void testUdfWithMultiLocalVariables() { - List sourceData = Arrays.asList(Row.of(1L, 2L), Row.of(2L, 3L)); + List sourceData = List.of(Row.of(1L, 2L), Row.of(2L, 3L)); 
TestCollectionTableFactory.reset(); TestCollectionTableFactory.initData(sourceData); @@ -1620,7 +1994,7 @@ void testUdfWithMultiLocalVariables() { CollectionUtil.iteratorToList( tEnv().executeSql("SELECT MultiLocalVariables(x, y) FROM SourceTable") .collect()); - assertThat(actualRows).isEqualTo(Arrays.asList(Row.of(2L), Row.of(6L))); + assertThat(actualRows).isEqualTo(List.of(Row.of(2L), Row.of(6L))); } // -------------------------------------------------------------------------------------------- @@ -1757,6 +2131,41 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) { } } + /** Deterministic function with a counter. */ + public static class CountingUpperScalarFunction extends ScalarFunction { + public static final AtomicInteger COUNT = new AtomicInteger(); + + public String eval(String s) { + COUNT.incrementAndGet(); + return s == null ? null : s.toUpperCase(); + } + } + + /** Deterministic function with a counter and 3 args. */ + public static class CountingConcat3ScalarFunction extends ScalarFunction { + public static final AtomicInteger COUNT = new AtomicInteger(); + + public String eval(String a, String b, String c) { + COUNT.incrementAndGet(); + return a + "/" + b + "/" + c; + } + } + + /** Non-deterministic function with a counter. */ + public static class NonDeterministicCountingScalarFunction extends ScalarFunction { + public static final AtomicInteger COUNT = new AtomicInteger(); + + public String eval(String s) { + final int count = COUNT.incrementAndGet(); + return s == null ? null : s.toUpperCase() + "_" + count; + } + + @Override + public boolean isDeterministic() { + return false; + } + } + /** Function that has a custom type inference that is broader than the actual implementation. */ public static class CustomScalarFunction extends ScalarFunction { public Integer eval(Integer... 
args) { From 90a1f6bc63bebf719c29ec5c4b0a88b9024b6bdf Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Tue, 5 May 2026 13:06:11 +0200 Subject: [PATCH 02/12] Address feedback --- .../table/planner/utils/ShortcutUtils.java | 7 ------ .../planner/codegen/ExprCodeGenerator.scala | 9 +++---- .../planner/codegen/JsonGenerateUtils.scala | 24 +++++++++---------- .../calls/BridgingSqlFunctionCallGen.scala | 2 -- .../codegen/calls/JsonArrayCallGen.scala | 2 -- .../codegen/calls/JsonObjectCallGen.scala | 2 -- .../codegen/calls/JsonStringCallGen.scala | 2 -- .../planner/plan/utils/FlinkRexUtil.scala | 8 +++++++ 8 files changed, 25 insertions(+), 31 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java index 2016f8a0c10c7..af5d2ee0e3aeb 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java @@ -199,13 +199,6 @@ public static boolean isOneOfFunctionDefinitions( return false; } - public static RexNode expandLocalRef(RexNode operand, @Nullable List exprs) { - while (operand instanceof RexLocalRef && exprs != null) { - operand = exprs.get(((RexLocalRef) operand).getIndex()); - } - return operand; - } - public static boolean isDeterministicThroughProgram( RexNode node, @Nullable List exprs) { if (exprs == null) { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index 97454165db992..e9179f76f676d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -37,7 +37,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._ import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction} -import org.apache.flink.table.planner.plan.utils.RexLiteralUtil +import org.apache.flink.table.planner.plan.utils.{FlinkRexUtil, RexLiteralUtil} import org.apache.flink.table.planner.utils.ShortcutUtils import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable @@ -587,8 +587,9 @@ class ExprCodeGenerator( ctx.popLocalRefScope() throw t } - if (scopedBodies.isEmpty) operandExpr - else + if (scopedBodies.isEmpty) { + operandExpr + } else GeneratedExpression( operandExpr.resultTerm, operandExpr.nullTerm, @@ -983,7 +984,7 @@ class ExprCodeGenerator( // RexLocalRef. JSON_OBJECT/JSON_ARRAY operands recognised as JSON via // isSupportedJsonOperand may therefore arrive here as a RexLocalRef; resolve it back to // the underlying RexCall before casting. 
- val jsonCall = ShortcutUtils + val jsonCall = FlinkRexUtil .expandLocalRef(operand, if (rexProgram == null) null else rexProgram.getExprList) .asInstanceOf[RexCall] val jsonOperands = jsonCall.getOperands.map(_.accept(this)) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala index 73c2f30ad8819..bbc12d96c49bd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala @@ -23,9 +23,9 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawV import org.apache.flink.table.api.{DataTypes, JsonOnNull} import org.apache.flink.table.functions.BuiltInFunctionDefinitions import org.apache.flink.table.planner.codegen.CodeGenUtils._ +import org.apache.flink.table.planner.plan.utils.FlinkRexUtil import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.planner.utils.ShortcutUtils -import org.apache.flink.table.planner.utils.ShortcutUtils.expandLocalRef import org.apache.flink.table.runtime.functions.SqlJsonUtils import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isCharacterString import org.apache.flink.table.types.logical._ @@ -179,24 +179,24 @@ object JsonGenerateUtils { } /** Determines whether the given operand is a call to a JSON_OBJECT. 
*/ - def isJsonObjectOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + def isJsonObjectOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = ShortcutUtils.isOneOfFunctionDefinitions( - expandLocalRef(operand, exprs), + FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON_OBJECT) /** Determines whether the given operand is a call to a JSON_ARRAY. */ - def isJsonArrayOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + def isJsonArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = ShortcutUtils.isOneOfFunctionDefinitions( - expandLocalRef(operand, exprs), + FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON_ARRAY) /** * Determines whether the given operand is a call to a JSON_OBJECT or JSON_ARRAY whose result * should be inserted as a raw value instead of as a character string. */ - def isJsonObjectOrArrayOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + def isJsonObjectOrArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = ShortcutUtils.isOneOfFunctionDefinitions( - expandLocalRef(operand, exprs), + FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON_OBJECT, BuiltInFunctionDefinitions.JSON_ARRAY) @@ -204,9 +204,9 @@ object JsonGenerateUtils { * Determines whether the given operand is a call to JSON function whose call currently just * passes through the input value as output value. 
*/ - def isJsonFunctionOperand(operand: RexNode, exprs: java.util.List[RexNode]): Boolean = + def isJsonFunctionOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = ShortcutUtils.isOneOfFunctionDefinitions( - expandLocalRef(operand, exprs), + FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON) /** @@ -219,9 +219,9 @@ object JsonGenerateUtils { operand: RexNode, call: RexNode, i: Int, - exprs: java.util.List[RexNode]): Boolean = { - isJsonFunctionOperand(operand, exprs) && - (isJsonArrayOperand(call, exprs) || isJsonObjectOperand(call, exprs) && (i % 2) == 0) + localRefs: java.util.List[RexNode]): Boolean = { + isJsonFunctionOperand(operand, localRefs) && + (isJsonArrayOperand(call, localRefs) || isJsonObjectOperand(call, localRefs) && (i % 2) == 0) } /** Generates a method to convert arrays into [[ArrayNode]]. */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala index 77a3850f664f8..83223e63200d9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala @@ -39,8 +39,6 @@ import java.util.Collections */ class BridgingSqlFunctionCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { - def this(call: RexCall) = this(call, null) - override def generate( ctx: CodeGeneratorContext, operands: Seq[GeneratedExpression], diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala index 
eadc6413026dd..d7fcc73695a42 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala @@ -30,8 +30,6 @@ import org.apache.calcite.rex.{RexCall, RexProgram} /** [[CallGenerator]] for `JSON_ARRAY`. */ class JsonArrayCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { - def this(call: RexCall) = this(call, null) - private def jsonUtils = className[SqlJsonUtils] override def generate( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala index b876ec85e2961..d012de58995e3 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala @@ -39,8 +39,6 @@ import org.apache.calcite.rex.{RexCall, RexProgram} */ class JsonObjectCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { - def this(call: RexCall) = this(call, null) - private def jsonUtils = className[SqlJsonUtils] override def generate( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala index 8dd4dc9fe7c7c..8fa2bfe6fcbf9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala @@ -28,8 +28,6 @@ import org.apache.calcite.rex.{RexCall, 
RexProgram} /** [[CallGenerator]] for `JSON_STRING`. */ class JsonStringCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator { - def this(call: RexCall) = this(call, null) - private def jsonUtils = className[SqlJsonUtils] override def generate( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala index 5cdbbff533a66..94d262463c6ff 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala @@ -525,6 +525,14 @@ object FlinkRexUtil { RexUtil.expandSearch(rexBuilder, program, program.expandLocalRef(program.getCondition))) RelOptUtil.conjunctions(condition) } + + def expandLocalRef(operand: RexNode, localRefs: util.List[RexNode]): RexNode = { + var expanded = operand + while (expanded.isInstanceOf[RexLocalRef] && localRefs != null) { + expanded = localRefs.get(expanded.asInstanceOf[RexLocalRef].getIndex) + } + expanded + } } /** From 941526a109ef891b28a702747b6c9cb256fd6ce9 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 12:51:07 +0200 Subject: [PATCH 03/12] Address feedback --- .../sql/FunctionDefinitionQueryable.java | 30 +++++++++++++++++++ .../sql/SqlJsonArrayFunctionWrapper.java | 11 ++++++- .../sql/SqlJsonObjectFunctionWrapper.java | 11 ++++++- .../sql/SqlJsonQueryFunctionWrapper.java | 10 ++++++- .../sql/SqlJsonValueFunctionWrapper.java | 11 ++++++- .../table/planner/utils/ShortcutUtils.java | 23 +++++++------- .../planner/codegen/ExprCodeGenerator.scala | 15 ++++------ .../planner/codegen/JsonGenerateUtils.scala | 1 + 8 files changed, 85 insertions(+), 27 deletions(-) create mode 100644 
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java new file mode 100644 index 0000000000000..abfa8c9200e9c --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.sql; + +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.TableFunction; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; + +/** + * Function that is not bridged by {@link BridgingSqlFunction} and is not a legacy {@link + * TableFunction}, but whose {@link FunctionDefinition} can still be queried.
+ */ +public interface FunctionDefinitionQueryable { + FunctionDefinition getFunctionDefinition(); +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java index e1b60699d7fb2..ffb3036b94462 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java @@ -18,6 +18,9 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlJsonArrayFunction; @@ -29,7 +32,8 @@ * This class is a wrapper class for the {@link SqlJsonArrayFunction} but using the {@code * VARCHAR_NOT_NULL} return type inference. 
*/ -class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction { +class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction + implements FunctionDefinitionQueryable { @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { @@ -49,4 +53,9 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { public SqlReturnTypeInference getReturnTypeInference() { return VARCHAR_NOT_NULL; } + + @Override + public FunctionDefinition getFunctionDefinition() { + return BuiltInFunctionDefinitions.JSON_ARRAY; + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java index b09ab149a641f..e6d05d9d50e55 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java @@ -18,6 +18,9 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlJsonObjectFunction; @@ -29,7 +32,8 @@ * This class is a wrapper class for the {@link SqlJsonObjectFunction} but using the {@code * VARCHAR_NOT_NULL} return type inference. 
*/ -class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction { +class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction + implements FunctionDefinitionQueryable { @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { @@ -49,4 +53,9 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { public SqlReturnTypeInference getReturnTypeInference() { return VARCHAR_NOT_NULL; } + + @Override + public FunctionDefinition getFunctionDefinition() { + return BuiltInFunctionDefinitions.JSON_OBJECT; + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java index 7a145ba9cce29..772708d3fd818 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java @@ -19,6 +19,8 @@ package org.apache.flink.table.planner.functions.sql; import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCallBinding; @@ -42,7 +44,8 @@ * This class is a wrapper class for the {@link SqlJsonQueryFunction} but using the {@code * VARCHAR_FORCE_NULLABLE} return type inference. 
*/ -class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction { +class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction + implements FunctionDefinitionQueryable { private final SqlReturnTypeInference returnTypeInference; SqlJsonQueryFunctionWrapper() { @@ -142,4 +145,9 @@ private static RelDataType explicitTypeSpec(SqlOperatorBinding opBinding) { } return null; } + + @Override + public FunctionDefinition getFunctionDefinition() { + return BuiltInFunctionDefinitions.JSON_QUERY; + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java index b28ef4786e47f..06c189e047cb8 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java @@ -18,6 +18,9 @@ package org.apache.flink.table.planner.functions.sql; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlJsonValueReturning; import org.apache.calcite.sql.SqlOperatorBinding; @@ -35,7 +38,8 @@ * VARCHAR_FORCE_NULLABLE} return type inference by default. It also supports specifying return type * with the RETURNING keyword just like the original {@link SqlJsonValueFunction}. 
*/ -class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction { +class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction + implements FunctionDefinitionQueryable { private final SqlReturnTypeInference returnTypeInference; @@ -80,4 +84,9 @@ private static RelDataType explicitTypeSpec(SqlOperatorBinding opBinding) { } return null; } + + @Override + public FunctionDefinition getFunctionDefinition() { + return BuiltInFunctionDefinitions.JSON_VALUE; + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java index af5d2ee0e3aeb..07150b43f6009 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java @@ -24,7 +24,6 @@ import org.apache.flink.table.delegation.Planner; import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.ResolvedExpression; -import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.planner.calcite.FlinkContext; @@ -32,6 +31,7 @@ import org.apache.flink.table.planner.delegation.PlannerBase; import org.apache.flink.table.planner.expressions.RexNodeExpression; import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; +import org.apache.flink.table.planner.functions.sql.FunctionDefinitionQueryable; import org.apache.flink.table.planner.functions.utils.TableSqlFunction; import org.apache.calcite.plan.Context; @@ -155,14 +155,18 @@ public static DataTypeFactory unwrapDataTypeFactory(RelBuilder relBuilder) { return null; } final RexCall call = (RexCall) rexNode; - if (!(call.getOperator() instanceof 
BridgingSqlFunction)) { + final SqlOperator operator = call.getOperator(); + if (!(operator instanceof BridgingSqlFunction)) { + if (operator instanceof FunctionDefinitionQueryable) { + return ((FunctionDefinitionQueryable) operator).getFunctionDefinition(); + } // legacy - if (call.getOperator() instanceof TableSqlFunction) { - return ((TableSqlFunction) call.getOperator()).udtf(); + if (operator instanceof TableSqlFunction) { + return ((TableSqlFunction) operator).udtf(); } return null; } - return ((BridgingSqlFunction) call.getOperator()).getDefinition(); + return ((BridgingSqlFunction) operator).getDefinition(); } public static @Nullable FunctionDefinition unwrapFunctionDefinition(SqlOperator operator) { @@ -184,15 +188,8 @@ public static boolean isOneOfFunctionDefinitions( } final RexCall call = (RexCall) rexNode; final FunctionDefinition unwrapped = unwrapFunctionDefinition(call); - final String operatorName = call.getOperator().getName(); for (FunctionDefinition expected : expectedDefinitions) { - if (unwrapped != null && unwrapped == expected) { - return true; - } - if (expected instanceof BuiltInFunctionDefinition - && ((BuiltInFunctionDefinition) expected) - .getName() - .equalsIgnoreCase(operatorName)) { + if (unwrapped == expected) { return true; } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index e9179f76f676d..e18b6b5bf1d8e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -577,16 +577,11 @@ class ExprCodeGenerator( private def visitOperandInScopedCache(operand: RexNode): GeneratedExpression = { ctx.pushLocalRefScope() - val (operandExpr, scopedBodies) = - try { - val expr 
= operand.accept(this) - val popped = ctx.popLocalRefScope() - (expr, popped.values.map(_.code).mkString("\n")) - } catch { - case t: Throwable => - ctx.popLocalRefScope() - throw t - } + val (operandExpr, scopedBodies) = { + val expr = operand.accept(this) + val popped = ctx.popLocalRefScope() + (expr, popped.values.map(_.code).mkString("\n")) + } if (scopedBodies.isEmpty) { operandExpr } else diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala index bbc12d96c49bd..3478fefa63df2 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala @@ -23,6 +23,7 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawV import org.apache.flink.table.api.{DataTypes, JsonOnNull} import org.apache.flink.table.functions.BuiltInFunctionDefinitions import org.apache.flink.table.planner.codegen.CodeGenUtils._ +import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.planner.plan.utils.FlinkRexUtil import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.planner.utils.ShortcutUtils From 647791c439025b7cd77dc7e0e33388377978960d Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 13:42:48 +0200 Subject: [PATCH 04/12] Address feedback --- .../table/planner/codegen/ExprCodeGenerator.scala | 14 +++++--------- .../planner/codegen/calls/SearchOperatorGen.scala | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index e18b6b5bf1d8e..38011a3894371 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -521,15 +521,11 @@ class ExprCodeGenerator( } if (call.getKind == SqlKind.SEARCH) { - val sargLiteral = - if (rexProgram != null && call.getOperands.get(1).isInstanceOf[RexLocalRef]) { - rexProgram.getExprList - .get(call.getOperands.get(1).asInstanceOf[RexLocalRef].getIndex) - .asInstanceOf[RexLiteral] - } else { - call.getOperands.get(1).asInstanceOf[RexLiteral] - } - return generateSearch(ctx, generateExpression(call.getOperands.get(0)), sargLiteral) + return generateSearch( + ctx, + generateExpression(call.getOperands.get(0)), + rexProgram, + call.getOperands) } // convert operands and help giving untyped NULL literals a type diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala index de55d1f8c308b..0f3e36e64175f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala @@ -27,7 +27,7 @@ import org.apache.flink.table.planner.plan.utils.RexLiteralUtil.toFlinkInternalV import org.apache.flink.table.types.logical.{BooleanType, LogicalType} import org.apache.flink.table.types.logical.utils.LogicalTypeMerging.findCommonType -import org.apache.calcite.rex.{RexLiteral, RexUnknownAs} +import org.apache.calcite.rex.{RexLiteral, RexLocalRef, RexNode, RexProgram, RexUnknownAs} import 
org.apache.calcite.util.{RangeSets, Sarg} import java.util.Arrays.asList @@ -53,7 +53,17 @@ object SearchOperatorGen { def generateSearch( ctx: CodeGeneratorContext, target: GeneratedExpression, - sargLiteral: RexLiteral): GeneratedExpression = { + rexProgram: RexProgram, + operands: java.util.List[RexNode]): GeneratedExpression = { + val sargLiteral = + if (rexProgram != null && operands.get(1).isInstanceOf[RexLocalRef]) { + rexProgram.getExprList + .get(operands.get(1).asInstanceOf[RexLocalRef].getIndex) + .asInstanceOf[RexLiteral] + } else { + operands.get(1).asInstanceOf[RexLiteral] + } + val sarg: Sarg[Nothing] = sargLiteral.getValueAs(classOf[Sarg[Nothing]]) val targetType = target.resultType val sargType = FlinkTypeFactory.toLogicalType(sargLiteral.getType) From 909ebc944f9779f445705c540c10a1ed0d2cfd51 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 13:54:14 +0200 Subject: [PATCH 05/12] Address feedback --- .../flink/table/planner/codegen/CodeGenUtils.scala | 10 ++++++++++ .../table/planner/codegen/ExprCodeGenerator.scala | 8 ++++---- .../table/planner/codegen/calls/JsonArrayCallGen.scala | 4 ++-- .../planner/codegen/calls/JsonObjectCallGen.scala | 4 ++-- .../planner/codegen/calls/JsonStringCallGen.scala | 4 ++-- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala index 811fd1a842087..ff924fe0f30f8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala @@ -46,6 +46,8 @@ import org.apache.flink.types.{ColumnList, Row, RowKind} import org.apache.flink.types.bitmap.Bitmap import org.apache.flink.types.variant.Variant +import 
org.apache.calcite.rex.{RexNode, RexProgram} + import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Object => JObject, Short => JShort} import java.lang.reflect.Method import java.util.concurrent.atomic.AtomicLong @@ -1120,4 +1122,12 @@ object CodeGenUtils { GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, index) } } + + def getExprsFromProgramOrNull(rexProgram: RexProgram): java.util.List[RexNode] = { + if (rexProgram == null) { + null + } else { + rexProgram.getExprList + } + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index 38011a3894371..1de3ca9bee880 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -513,7 +513,7 @@ class ExprCodeGenerator( val resultType = FlinkTypeFactory.toLogicalType(call.getType) // throw exception if json function is called outside JSON_OBJECT or JSON_ARRAY function - if (isJsonFunctionOperand(call, if (rexProgram == null) null else rexProgram.getExprList)) { + if (isJsonFunctionOperand(call, CodeGenUtils.getExprsFromProgramOrNull(rexProgram))) { throw new ValidationException( "The JSON() function is currently only supported inside JSON_ARRAY() or as the VALUE param" + " of JSON_OBJECT(). 
Example: JSON_OBJECT('a', JSON('{\"key\": \"value\"}')) or " + @@ -545,7 +545,7 @@ class ExprCodeGenerator( operand, call, i, - if (rexProgram == null) null else rexProgram.getExprList) => + CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) => generateJsonCall(operand) case (o @ _, i) if condIdxs.contains(i) => visitOperandInScopedCache(o) @@ -976,7 +976,7 @@ class ExprCodeGenerator( // isSupportedJsonOperand may therefore arrive here as a RexLocalRef; resolve it back to // the underlying RexCall before casting. val jsonCall = FlinkRexUtil - .expandLocalRef(operand, if (rexProgram == null) null else rexProgram.getExprList) + .expandLocalRef(operand, CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) .asInstanceOf[RexCall] val jsonOperands = jsonCall.getOperands.map(_.accept(this)) generateCallExpression( @@ -1002,5 +1002,5 @@ class ExprCodeGenerator( private def isDeterministicThroughProgram(node: RexNode): Boolean = ShortcutUtils.isDeterministicThroughProgram( node, - if (rexProgram == null) null else rexProgram.getExprList) + CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala index d7fcc73695a42..524e1cb763587 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, NullNode} import org.apache.flink.table.api.JsonOnNull -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} +import 
org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, newName, primitiveTypeTermForType, BINARY_STRING} import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, getOnNullBehavior} import org.apache.flink.table.runtime.functions.SqlJsonUtils @@ -48,7 +48,7 @@ class JsonArrayCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenera .drop(1) .map { case (elementExpr, elementIdx) => - val exprs = if (rexProgram == null) null else rexProgram.getExprList + val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram) val elementTerm = createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx), exprs) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala index d012de58995e3..37d126776c451 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{NullNode, ObjectNode} import org.apache.flink.table.api.JsonOnNull -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, getOnNullBehavior} import org.apache.flink.table.runtime.functions.SqlJsonUtils @@ -58,7 +58,7 @@ class JsonObjectCallGen(call: RexCall, rexProgram: RexProgram) 
extends CallGener .grouped(2) .map { case Seq((keyExpr, _), (valueExpr, valueIdx)) => - val exprs = if (rexProgram == null) null else rexProgram.getExprList + val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram) val valueTerm = createNodeTerm(ctx, valueExpr, call.operands.get(valueIdx), exprs) onNull match { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala index 8fa2bfe6fcbf9..3655943327bd9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.table.planner.codegen.calls -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, newName, primitiveTypeTermForType, BINARY_STRING} import org.apache.flink.table.planner.codegen.JsonGenerateUtils.createNodeTerm import org.apache.flink.table.runtime.functions.SqlJsonUtils @@ -35,7 +35,7 @@ class JsonStringCallGen(call: RexCall, rexProgram: RexProgram) extends CallGener operands: Seq[GeneratedExpression], returnType: LogicalType): GeneratedExpression = { - val exprs = if (rexProgram == null) null else rexProgram.getExprList + val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram) val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0), exprs) val resultTerm = newName(ctx, "result") From c4a9a038c3c555c596fbaceb02e8b010561f66d9 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 16:37:03 +0200 Subject: [PATCH 06/12] Address feedback --- 
.../nodes/exec/common/CommonExecCalc.java | 2 + .../exec/common/CommonExecLookupJoin.java | 6 ++- .../plan/nodes/exec/spec/DeltaJoinTree.java | 3 +- .../exec/stream/StreamExecDeltaJoin.java | 3 +- .../planner/codegen/CalcCodeGenerator.scala | 15 +++++--- .../codegen/CodeGeneratorContext.scala | 37 ++++++++++++++----- .../codegen/LookupJoinCodeGenerator.scala | 18 ++++++--- 7 files changed, 59 insertions(+), 25 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java index 97e3c43e2ac19..ec53d746556d7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java @@ -32,6 +32,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator; import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.RowType; @@ -104,6 +105,7 @@ protected Transformation translateToPlanInternal( (RowType) getOutputType(), JavaScalaConversionUtil.toScala(projection), JavaScalaConversionUtil.toScala(Optional.ofNullable(this.condition)), + ShortcutUtils.unwrapTypeFactory(planner), retainHeader, getClass().getSimpleName()); return ExecNodeUtil.createOneInputTransformation( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java index 05b07458a10ed..31cb09a545f6d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java @@ -502,7 +502,8 @@ protected StreamOperatorFactory createAsyncLookupJoin( JavaScalaConversionUtil.toScala(projectionOnTemporalTable), filterOnTemporalTable, projectionOutputRelDataType, - tableSourceRowType); + tableSourceRowType, + ShortcutUtils.unwrapTypeFactory(relBuilder)); asyncFunc = new AsyncLookupJoinWithCalcRunner( generatedFuncWithType.tableFunc(), @@ -647,7 +648,8 @@ protected ProcessFunction createSyncLookupJoinFunction( JavaScalaConversionUtil.toScala(projectionOnTemporalTable), filterOnTemporalTable, projectionOutputRelDataType, - tableSourceRowType); + tableSourceRowType, + ShortcutUtils.unwrapTypeFactory(relBuilder)); processFunc = new LookupJoinWithCalcRunner( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java index f7555e06822e4..c826fcd1ff121 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java @@ -228,7 +228,8 @@ private DeltaJoinRuntimeTree.Node convert2RuntimeTreeInternal( node.filter, rowTypePassThroughCalc, rowTypeBeforeCalc, - generatedCalcName)) + generatedCalcName, + typeFactory)) .orElse(null); if (node instanceof BinaryInputNode) { diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java index c540c543a6d0e..c1b6ffcd17947 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java @@ -574,7 +574,8 @@ private static LookupHandlerBase generateLookupHandler( JavaScalaConversionUtil.toScala(projectionOnTemporalTable), filterOnTemporalTable, lookupSidePassThroughCalcRowType, - lookupTableSourceRowType); + lookupTableSourceRowType, + typeFactory); } Preconditions.checkState(!generatedFetcherCollector.containsKey(lookupTableOrdinal)); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala index 464e449644dff..32812a3ace38a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala @@ -43,6 +43,7 @@ object CalcCodeGenerator { outputType: RowType, projection: Seq[RexNode], condition: Option[RexNode], + typeFactory: FlinkTypeFactory, retainHeader: Boolean = false, opName: String): CodeGenOperatorFactory[RowData] = { // filter out time attributes @@ -54,6 +55,7 @@ object CalcCodeGenerator { classOf[BoxedWrapperRowData], projection, condition, + typeFactory, inputTerm, CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM, eagerInputUnboxingCode = true, @@ -81,7 +83,8 @@ object CalcCodeGenerator { calcProjection: Seq[RexNode], calcCondition: Option[RexNode], 
tableConfig: ReadableConfig, - classLoader: ClassLoader): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + classLoader: ClassLoader, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { val ctx = new CodeGeneratorContext(tableConfig, classLoader) val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM val collectorTerm = CodeGenUtils.DEFAULT_COLLECTOR_TERM @@ -92,6 +95,7 @@ object CalcCodeGenerator { outRowClass, calcProjection, calcCondition, + typeFactory, inputTerm, collectorTerm = collectorTerm, eagerInputUnboxingCode = false, @@ -116,6 +120,7 @@ object CalcCodeGenerator { outRowClass: Class[_ <: RowData], projection: Seq[RexNode], condition: Option[RexNode], + typeFactory: FlinkTypeFactory, inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM, collectorTerm: String = CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM, eagerInputUnboxingCode: Boolean, @@ -127,7 +132,7 @@ object CalcCodeGenerator { projection.foreach(_.accept(ScalarFunctionsValidator)) condition.foreach(_.accept(ScalarFunctionsValidator)) - val rexProgram = buildRexProgram(ctx.classLoader, inputType, projection, condition) + val rexProgram = buildRexProgram(typeFactory, inputType, projection, condition) val exprGenerator = new ExprCodeGenerator(ctx, false, rexProgram) .bindInput(inputType, inputTerm = inputTerm) @@ -232,11 +237,11 @@ object CalcCodeGenerator { } private def buildRexProgram( - classLoader: ClassLoader, + typeFactory: FlinkTypeFactory, inputType: RowType, projection: Seq[RexNode], - condition: Option[RexNode]): RexProgram = { - val typeFactory = new FlinkTypeFactory(classLoader, FlinkTypeSystem.INSTANCE) + condition: Option[RexNode] + ): RexProgram = { val rexBuilder = new FlinkRexBuilder(typeFactory) val relInputType = typeFactory.createFieldTypeFromLogicalType(inputType) val builder = new RexProgramBuilder(relInputType, rexBuilder) diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index 6c2a5b20ba74e..7d03a7c4a73b3 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -120,17 +120,34 @@ class CodeGeneratorContext( val reusableLocalRefExprs: mutable.LinkedHashMap[Int, GeneratedExpression] = mutable.LinkedHashMap[Int, GeneratedExpression]() - // Stack of RexLocalRef cache scopes. The bottom scope IS reusableLocalRefExprs and is read - // by CalcCodeGenerator.reuseLocalRefCode() — its bodies are hoisted to the top of the - // generated method and so must be safe to evaluate unconditionally. + // Stack of RexLocalRef cache scopes (`exprList-index -> generated body`). + // * Bottom scope == reusableLocalRefExprs: bodies are hoisted to the top of the method + // and run unconditionally for every row. + // * Inner scopes (push/popLocalRefScope): bodies are folded into a single guarded + // operand's code by ExprCodeGenerator.visitOperandInScopedCache and run only when + // the guard fires. Inserts always target the innermost scope; lookup walks innermost-out. // - // ExprCodeGenerator pushes an inner scope before visiting a guarded operand (CASE WHEN's - // THEN/ELSE branch, AND/OR's right-hand side, ...) and pops it after. Any RexLocalRef body - // cached during that visit lives only in the inner scope; ExprCodeGenerator folds those - // bodies into the operand's generated code so they execute only when the guard fires. - // Without this scoping, an arithmetic expression like (a / b) inside CASE WHEN b > 0 would - // be hoisted above the if-block and divide by zero on rows where b == 0. 
- private val localRefScopes: mutable.ArrayBuffer[mutable.LinkedHashMap[Int, GeneratedExpression]] = + // Example — `CASE WHEN b <> 0 THEN a / b ELSE NULL`: + // + // With scoping (correct): + // boolean cmp = b != 0; + // if (cmp) { + // int div = a / b; // emitted inside the guarded scope + // result = div; + // } else { + // result = null; + // } + // + // Without scoping (buggy): + // int div = a / b; // throws ArithmeticException when b == 0 + // boolean cmp = b != 0; + // if (cmp) { result = div; } + // else { result = null; } + // + // The set of operand positions that get scoped lives in + // ExprCodeGenerator.conditionalOperandIndices — extend it when adding new short-circuit + // operators. + private val localRefScopes = mutable.ArrayBuffer(reusableLocalRefExprs) // set of constructor statements that will be added only once diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala index db158572df9e9..e011dd2cfcec6 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala @@ -354,14 +354,16 @@ object LookupJoinCodeGenerator { projection: Seq[RexNode], condition: RexNode, outputType: RelDataType, - tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + tableSourceRowType: RowType, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { generateCalcMapFunction( tableConfig, classLoader, projection, condition, FlinkTypeFactory.toLogicalRowType(outputType), - tableSourceRowType + tableSourceRowType, + typeFactory ) } @@ -375,7 +377,8 @@ object LookupJoinCodeGenerator { projection: Seq[RexNode], condition: RexNode, 
outputType: RowType, - tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + tableSourceRowType: RowType, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { generateCalcMapFunction( tableConfig, classLoader, @@ -383,7 +386,8 @@ object LookupJoinCodeGenerator { condition, outputType, tableSourceRowType, - "TableCalcMapFunction") + "TableCalcMapFunction", + typeFactory) } /** @@ -397,7 +401,8 @@ object LookupJoinCodeGenerator { condition: RexNode, outputType: RowType, tableSourceRowType: RowType, - name: String): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + name: String, + typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { CalcCodeGenerator.generateFunction( tableSourceRowType, name, @@ -406,7 +411,8 @@ object LookupJoinCodeGenerator { projection, Option(condition), tableConfig, - classLoader + classLoader, + typeFactory ) } } From 4a2715b64ee38a0a50c6826e2baa82748a93a801 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 17:30:53 +0200 Subject: [PATCH 07/12] Address feedback --- .../functions/JsonFunctionsITCase.java | 60 ++++++++++++------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java index bfcd299d937bb..1ec9561b25083 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java @@ -1540,22 +1540,7 @@ private static List jsonArraySpec() { STRING().notNull())); } - /** - * Pins the local-ref / common-sub-expression handling for JSON construction calls. - * - *

When two projections share a JSON-producing sub-expression, the planner deduplicates them - * into a {@link org.apache.calcite.rex.RexLocalRef} that points at the shared {@link - * org.apache.calcite.rex.RexCall}. The codegen helpers in {@code JsonGenerateUtils} must - * dereference that local ref through the surrounding {@code RexProgram} to recognize it as a - * JSON / JSON_OBJECT / JSON_ARRAY operand and embed the value as a raw JSON node. Without that - * dereference the helpers see a plain {@code RexLocalRef}, fall back to the string-quoting - * branch, and produce wrong output (e.g. {@code "{\"k\":\"[1,2,3]\"}"} instead of {@code - * "{\"k\":[1,2,3]}"}). - * - *

The scenarios below cover each callsite — {@code JsonObjectCallGen}, {@code - * JsonArrayCallGen}, {@code JsonStringCallGen} — and each branch of the inspection helpers - * ({@code JSON}, {@code JSON_OBJECT}, {@code JSON_ARRAY}). - */ + /** Pins the local-ref / common-sub-expression handling for JSON construction calls. */ private static List jsonLocalRefReuseSpec() { return List.of( // Shared JSON(f) inside two JSON_OBJECT projections. @@ -1576,7 +1561,12 @@ private static List jsonLocalRefReuseSpec() { "JSON_OBJECT(KEY 'k2' VALUE JSON(f0))", "{\"k2\":[1,2,3]}", STRING().notNull(), - STRING().notNull())), + STRING().notNull())) + .testSqlResult( + "JSON_OBJECT(KEY 'k1' VALUE JSON(f0))," + + " JSON_OBJECT(KEY 'k2' VALUE JSON(f0))", + List.of("{\"k1\":[1,2,3]}", "{\"k2\":[1,2,3]}"), + List.of(STRING().notNull(), STRING().notNull())), // Shared JSON_ARRAY(...) inside two JSON_OBJECT projections. TestSetSpec.forFunction( BuiltInFunctionDefinitions.JSON_OBJECT, @@ -1609,7 +1599,12 @@ private static List jsonLocalRefReuseSpec() { "JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))", "{\"b\":[1,2,3]}", STRING().notNull(), - STRING().notNull())), + STRING().notNull())) + .testSqlResult( + "JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(f0, f1, f2))," + + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))", + List.of("{\"a\":[1,2,3]}", "{\"b\":[1,2,3]}"), + List.of(STRING().notNull(), STRING().notNull())), // Shared inner JSON_OBJECT inside two outer JSON_OBJECT projections. 
TestSetSpec.forFunction( BuiltInFunctionDefinitions.JSON_OBJECT, @@ -1634,7 +1629,14 @@ private static List jsonLocalRefReuseSpec() { "JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", "{\"outer2\":{\"inner\":\"V\"}}", STRING().notNull(), - STRING().notNull())), + STRING().notNull())) + .testSqlResult( + "JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))," + + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))", + List.of( + "{\"outer1\":{\"inner\":\"V\"}}", + "{\"outer2\":{\"inner\":\"V\"}}"), + List.of(STRING().notNull(), STRING().notNull())), // Shared JSON_OBJECT inside two JSON_ARRAY projections. TestSetSpec.forFunction( BuiltInFunctionDefinitions.JSON_ARRAY, @@ -1657,7 +1659,12 @@ private static List jsonLocalRefReuseSpec() { "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", "[{\"k\":\"V\"}]", STRING().notNull(), - STRING().notNull())), + STRING().notNull())) + .testSqlResult( + "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))," + + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))", + List.of("[{\"k\":\"V\"}]", "[{\"k\":\"V\"}]"), + List.of(STRING().notNull(), STRING().notNull())), // Shared JSON(f) inside two JSON_ARRAY projections. TestSetSpec.forFunction( BuiltInFunctionDefinitions.JSON_ARRAY, @@ -1676,7 +1683,11 @@ private static List jsonLocalRefReuseSpec() { "JSON_ARRAY(JSON(f0))", "[[1,2,3]]", STRING().notNull(), - STRING().notNull())), + STRING().notNull())) + .testSqlResult( + "JSON_ARRAY(JSON(f0)), JSON_ARRAY(JSON(f0))", + List.of("[[1,2,3]]", "[[1,2,3]]"), + List.of(STRING().notNull(), STRING().notNull())), // Shared JSON_OBJECT inside two JSON_STRING projections. JSON_STRING re-serializes // the operand; without dereferencing the local ref it would wrap the already // serialized JSON string a second time. 
@@ -1697,7 +1708,12 @@ private static List jsonLocalRefReuseSpec() { "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", "{\"k\":\"V\"}", STRING().notNull(), - STRING().notNull()))); + STRING().notNull())) + .testSqlResult( + "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))," + + " JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))", + List.of("{\"k\":\"V\"}", "{\"k\":\"V\"}"), + List.of(STRING().notNull(), STRING().notNull()))); } // --------------------------------------------------------------------------------------------- From 94aa427bea4446f3a7ad347bf8ab92ee34fdc59a Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 17:58:33 +0200 Subject: [PATCH 08/12] Address feedback --- .../codegen/CodeGeneratorContext.scala | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index 7d03a7c4a73b3..00a457b2f047a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -207,19 +207,6 @@ class CodeGeneratorContext( def getReusableInputUnboxingExprs(inputTerm: String, index: Int): Option[GeneratedExpression] = reusableInputUnboxingExprs.get((inputTerm, index)) - def getReusableLocalRefExpr(index: Int): Option[GeneratedExpression] = { - // Search innermost-out: a body cached in an inner (guarded) scope wins over outer - // entries. In practice the cache is monotone — an entry never appears in two scopes - // simultaneously. - var i = localRefScopes.size - 1 - while (i >= 0) { - val maybe = localRefScopes(i).get(index) - if (maybe.isDefined) return maybe - i -= 1 - } - None - } - /** Prioritize using the nameCounter of the ancestor. 
*/ def getNameCounter: AtomicLong = if (parentCtx == null) nameCounter else parentCtx.getNameCounter @@ -422,10 +409,6 @@ class CodeGeneratorContext( reusableInputUnboxingExprs.values.map(_.code).mkString("\n") } - def reuseLocalRefCode(): String = { - reusableLocalRefExprs.values.map(_.code).mkString("\n") - } - /** Returns code block of unboxing input variables which belongs to the given inputTerm. */ def reuseInputUnboxingCode(inputTerm: String): String = { val exprs = reusableInputUnboxingExprs.filter { @@ -1131,6 +1114,27 @@ class CodeGeneratorContext( fieldTerm } + // --------------------------------------------------------------------------------- + // Reusable local ref code with scope + // --------------------------------------------------------------------------------- + + def getReusableLocalRefExpr(index: Int): Option[GeneratedExpression] = { + // Search innermost-out: a body cached in an inner (guarded) scope wins over outer + // entries. In practice the cache is monotone — an entry never appears in two scopes + // simultaneously. 
+ var i = localRefScopes.size - 1 + while (i >= 0) { + val maybe = localRefScopes(i).get(index) + if (maybe.isDefined) return maybe + i -= 1 + } + None + } + + def reuseLocalRefCode(): String = { + reusableLocalRefExprs.values.map(_.code).mkString("\n") + } + def pushLocalRefScope(): Unit = { localRefScopes.append(mutable.LinkedHashMap.empty) } From c7f8040a0d694ac0557a772cd3ad3594e9b1e1cf Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 6 May 2026 20:07:49 +0200 Subject: [PATCH 09/12] optimize import --- .../table/planner/plan/nodes/exec/common/CommonExecCalc.java | 1 - .../flink/table/planner/codegen/CalcCodeGenerator.scala | 5 +---- .../flink/table/planner/codegen/FunctionCodeGenerator.scala | 3 +-- .../flink/table/planner/codegen/JsonGenerateUtils.scala | 1 - 4 files changed, 2 insertions(+), 8 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java index ec53d746556d7..e1ddbcbc46e6f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java @@ -100,7 +100,6 @@ protected Transformation translateToPlanInternal( final CodeGenOperatorFactory substituteStreamOperator = CalcCodeGenerator.generateCalcOperator( ctx, - inputTransform, (RowType) inputEdge.getOutputType(), (RowType) getOutputType(), JavaScalaConversionUtil.toScala(projection), diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala index 32812a3ace38a..f5c4f83be3900 100644 --- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala @@ -18,16 +18,14 @@ package org.apache.flink.table.planner.codegen import org.apache.flink.api.common.functions.{FlatMapFunction, Function} -import org.apache.flink.api.dag.Transformation import org.apache.flink.configuration.ReadableConfig import org.apache.flink.table.api.{TableException, ValidationException} import org.apache.flink.table.data.{BoxedWrapperRowData, RowData} import org.apache.flink.table.functions.FunctionKind -import org.apache.flink.table.planner.calcite.{FlinkRexBuilder, FlinkTypeFactory, FlinkTypeSystem} +import org.apache.flink.table.planner.calcite.{FlinkRexBuilder, FlinkTypeFactory} import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction import org.apache.flink.table.runtime.generated.GeneratedFunction import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.flink.table.types.logical.RowType import org.apache.calcite.rex._ @@ -38,7 +36,6 @@ object CalcCodeGenerator { def generateCalcOperator( ctx: CodeGeneratorContext, - inputTransform: Transformation[RowData], inputType: RowType, outputType: RowType, projection: Seq[RexNode], diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala index 817d399bacc8e..3131de4f8f909 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala @@ -18,12 +18,11 @@ package 
org.apache.flink.table.planner.codegen import org.apache.flink.api.common.functions._ -import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.functions.async.{AsyncFunction, RichAsyncFunction} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.Indenter.toISC -import org.apache.flink.table.runtime.generated.{FilterCondition, GeneratedFilterCondition, GeneratedFunction, GeneratedJoinCondition, JoinCondition} +import org.apache.flink.table.runtime.generated._ import org.apache.flink.table.types.logical.LogicalType /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala index 3478fefa63df2..bbc12d96c49bd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala @@ -23,7 +23,6 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawV import org.apache.flink.table.api.{DataTypes, JsonOnNull} import org.apache.flink.table.functions.BuiltInFunctionDefinitions import org.apache.flink.table.planner.codegen.CodeGenUtils._ -import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.planner.plan.utils.FlinkRexUtil import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.planner.utils.ShortcutUtils From 2de60e3eecfe9b9daca82d04156bf601c182da7e Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 7 May 2026 09:41:34 +0200 Subject: [PATCH 10/12] Address feedback --- .../codegen/CodeGeneratorContext.scala | 26 ++++++++++++++++--- 1 file changed, 22
insertions(+), 4 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index 00a457b2f047a..b1bbe0c4e6a75 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -492,10 +492,6 @@ class CodeGeneratorContext( index: Int, expr: GeneratedExpression): Unit = reusableInputUnboxingExprs((inputTerm, index)) = expr - /** Adds a reusable RexLocalRef expression keyed by its index in the program's exprList. */ - def addReusableLocalRefExpr(index: Int, expr: GeneratedExpression): Unit = - localRefScopes.last(index) = expr - /** Adds a reusable output record statement to member area. */ def addReusableOutputRecord( t: LogicalType, @@ -821,6 +817,7 @@ class CodeGeneratorContext( /** * Adds a reusable Object to the member area of the generated class + * * @param obj * the object to be added to the generated class * @param fieldNamePrefix @@ -1118,6 +1115,18 @@ class CodeGeneratorContext( // Reusable local ref code with scope // --------------------------------------------------------------------------------- + /** + * Adds a reusable [[org.apache.calcite.rex.RexLocalRef]] expression keyed by its index in the + * program's exprList. The expression is stored in the innermost active scope. + */ + def addReusableLocalRefExpr(index: Int, expr: GeneratedExpression): Unit = + localRefScopes.last(index) = expr + + /** + * Looks up a previously cached [[org.apache.calcite.rex.RexLocalRef]] expression by its exprList + * index. Scopes are searched innermost-out so that a body cached inside a guarded scope takes + * precedence over an outer entry. 
+ */ def getReusableLocalRefExpr(index: Int): Option[GeneratedExpression] = { // Search innermost-out: a body cached in an inner (guarded) scope wins over outer // entries. In practice the cache is monotone — an entry never appears in two scopes @@ -1131,14 +1140,23 @@ class CodeGeneratorContext( None } + /** + * Returns the generated code for all unconditionally-evaluated local-ref expressions (bottom + * scope), concatenated in insertion order. + */ def reuseLocalRefCode(): String = { reusableLocalRefExprs.values.map(_.code).mkString("\n") } + /** Pushes a new, empty local-ref cache scope onto the scope stack. */ def pushLocalRefScope(): Unit = { localRefScopes.append(mutable.LinkedHashMap.empty) } + /** + * Pops the innermost local-ref cache scope and returns its entries. The bottom scope + * ([[reusableLocalRefExprs]]) cannot be popped. + */ def popLocalRefScope(): scala.collection.Map[Int, GeneratedExpression] = { require( localRefScopes.size > 1, From 3a1531d93d8099c85e64aeef64a1d6323a4a27ac Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 7 May 2026 13:48:57 +0200 Subject: [PATCH 11/12] Address feedback --- .../sql/FunctionDefinitionQueryable.java | 30 ----------------- .../sql/SqlJsonArrayFunctionWrapper.java | 11 +------ .../sql/SqlJsonObjectFunctionWrapper.java | 11 +------ .../sql/SqlJsonQueryFunctionWrapper.java | 10 +----- .../sql/SqlJsonValueFunctionWrapper.java | 11 +------ .../table/planner/utils/ShortcutUtils.java | 19 ----------- .../planner/codegen/JsonGenerateUtils.scala | 32 +++++++++++++++---- 7 files changed, 30 insertions(+), 94 deletions(-) delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java deleted file mode 100644 index abfa8c9200e9c..0000000000000 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FunctionDefinitionQueryable.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to you under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.planner.functions.sql; - -import org.apache.flink.table.functions.FunctionDefinition; -import org.apache.flink.table.functions.TableFunction; -import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; - -/** - * Function which is not bridged by {@link BridgingSqlFunction} and which is not legacy {@link - * TableFunction} however whose {@link FunctionDefinition} might be queried. 
- */ -public interface FunctionDefinitionQueryable { - FunctionDefinition getFunctionDefinition(); -} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java index ffb3036b94462..eb6aae3a8edef 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java @@ -18,9 +18,6 @@ package org.apache.flink.table.planner.functions.sql; -import org.apache.flink.table.functions.BuiltInFunctionDefinitions; -import org.apache.flink.table.functions.FunctionDefinition; - import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlJsonArrayFunction; @@ -32,8 +29,7 @@ * This class is a wrapper class for the {@link SqlJsonArrayFunction} but using the {@code * VARCHAR_NOT_NULL} return type inference. 
*/ -class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction - implements FunctionDefinitionQueryable { +public class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction { @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { @@ -53,9 +49,4 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { public SqlReturnTypeInference getReturnTypeInference() { return VARCHAR_NOT_NULL; } - - @Override - public FunctionDefinition getFunctionDefinition() { - return BuiltInFunctionDefinitions.JSON_ARRAY; - } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java index e6d05d9d50e55..b4ef34b94c572 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java @@ -18,9 +18,6 @@ package org.apache.flink.table.planner.functions.sql; -import org.apache.flink.table.functions.BuiltInFunctionDefinitions; -import org.apache.flink.table.functions.FunctionDefinition; - import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlJsonObjectFunction; @@ -32,8 +29,7 @@ * This class is a wrapper class for the {@link SqlJsonObjectFunction} but using the {@code * VARCHAR_NOT_NULL} return type inference. 
*/ -class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction - implements FunctionDefinitionQueryable { +public class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction { @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { @@ -53,9 +49,4 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { public SqlReturnTypeInference getReturnTypeInference() { return VARCHAR_NOT_NULL; } - - @Override - public FunctionDefinition getFunctionDefinition() { - return BuiltInFunctionDefinitions.JSON_OBJECT; - } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java index 772708d3fd818..ddae97fa23259 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java @@ -19,8 +19,6 @@ package org.apache.flink.table.planner.functions.sql; import org.apache.flink.table.api.ValidationException; -import org.apache.flink.table.functions.BuiltInFunctionDefinitions; -import org.apache.flink.table.functions.FunctionDefinition; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCallBinding; @@ -44,8 +42,7 @@ * This class is a wrapper class for the {@link SqlJsonQueryFunction} but using the {@code * VARCHAR_FORCE_NULLABLE} return type inference. 
*/ -class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction - implements FunctionDefinitionQueryable { +public class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction { private final SqlReturnTypeInference returnTypeInference; SqlJsonQueryFunctionWrapper() { @@ -145,9 +142,4 @@ private static RelDataType explicitTypeSpec(SqlOperatorBinding opBinding) { } return null; } - - @Override - public FunctionDefinition getFunctionDefinition() { - return BuiltInFunctionDefinitions.JSON_QUERY; - } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java index 06c189e047cb8..03e60ecdf3c9d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java @@ -18,9 +18,6 @@ package org.apache.flink.table.planner.functions.sql; -import org.apache.flink.table.functions.BuiltInFunctionDefinitions; -import org.apache.flink.table.functions.FunctionDefinition; - import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlJsonValueReturning; import org.apache.calcite.sql.SqlOperatorBinding; @@ -38,8 +35,7 @@ * VARCHAR_FORCE_NULLABLE} return type inference by default. It also supports specifying return type * with the RETURNING keyword just like the original {@link SqlJsonValueFunction}. 
*/ -class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction - implements FunctionDefinitionQueryable { +public class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction { private final SqlReturnTypeInference returnTypeInference; @@ -84,9 +80,4 @@ private static RelDataType explicitTypeSpec(SqlOperatorBinding opBinding) { } return null; } - - @Override - public FunctionDefinition getFunctionDefinition() { - return BuiltInFunctionDefinitions.JSON_VALUE; - } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java index 07150b43f6009..b079eddf5570a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java @@ -31,7 +31,6 @@ import org.apache.flink.table.planner.delegation.PlannerBase; import org.apache.flink.table.planner.expressions.RexNodeExpression; import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; -import org.apache.flink.table.planner.functions.sql.FunctionDefinitionQueryable; import org.apache.flink.table.planner.functions.utils.TableSqlFunction; import org.apache.calcite.plan.Context; @@ -157,9 +156,6 @@ public static DataTypeFactory unwrapDataTypeFactory(RelBuilder relBuilder) { final RexCall call = (RexCall) rexNode; final SqlOperator operator = call.getOperator(); if (!(operator instanceof BridgingSqlFunction)) { - if (operator instanceof FunctionDefinitionQueryable) { - return ((FunctionDefinitionQueryable) operator).getFunctionDefinition(); - } // legacy if (operator instanceof TableSqlFunction) { return ((TableSqlFunction) operator).udtf(); @@ -181,21 +177,6 @@ public static boolean isFunctionKind(SqlOperator operator, FunctionKind kind) { return functionDefinition != null && 
functionDefinition.getKind() == kind; } - public static boolean isOneOfFunctionDefinitions( - RexNode rexNode, FunctionDefinition... expectedDefinitions) { - if (!(rexNode instanceof RexCall)) { - return false; - } - final RexCall call = (RexCall) rexNode; - final FunctionDefinition unwrapped = unwrapFunctionDefinition(call); - for (FunctionDefinition expected : expectedDefinitions) { - if (unwrapped == expected) { - return true; - } - } - return false; - } - public static boolean isDeterministicThroughProgram( RexNode node, @Nullable List exprs) { if (exprs == null) { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala index bbc12d96c49bd..a64a97727ccd8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala @@ -21,8 +21,9 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, ObjectNode} import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawValue import org.apache.flink.table.api.{DataTypes, JsonOnNull} -import org.apache.flink.table.functions.BuiltInFunctionDefinitions +import org.apache.flink.table.functions.{BuiltInFunctionDefinitions, FunctionDefinition} import org.apache.flink.table.planner.codegen.CodeGenUtils._ +import org.apache.flink.table.planner.functions.sql.{SqlJsonArrayFunctionWrapper, SqlJsonObjectFunctionWrapper, SqlJsonQueryFunctionWrapper, SqlJsonValueFunctionWrapper} import org.apache.flink.table.planner.plan.utils.FlinkRexUtil import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import 
org.apache.flink.table.planner.utils.ShortcutUtils @@ -33,7 +34,7 @@ import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks import org.apache.flink.table.utils.EncodingUtils -import org.apache.calcite.rex.RexNode +import org.apache.calcite.rex.{RexCall, RexNode} import java.time.format.DateTimeFormatter @@ -180,13 +181,13 @@ object JsonGenerateUtils { /** Determines whether the given operand is a call to a JSON_OBJECT. */ def isJsonObjectOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = - ShortcutUtils.isOneOfFunctionDefinitions( + isOneOfFunctionDefinitions( FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON_OBJECT) /** Determines whether the given operand is a call to a JSON_ARRAY. */ def isJsonArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = - ShortcutUtils.isOneOfFunctionDefinitions( + isOneOfFunctionDefinitions( FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON_ARRAY) @@ -195,7 +196,7 @@ object JsonGenerateUtils { * should be inserted as a raw value instead of as a character string. */ def isJsonObjectOrArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = - ShortcutUtils.isOneOfFunctionDefinitions( + isOneOfFunctionDefinitions( FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON_OBJECT, BuiltInFunctionDefinitions.JSON_ARRAY) @@ -205,7 +206,7 @@ object JsonGenerateUtils { * passes through the input value as output value. 
*/ def isJsonFunctionOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean = - ShortcutUtils.isOneOfFunctionDefinitions( + isOneOfFunctionDefinitions( FlinkRexUtil.expandLocalRef(operand, localRefs), BuiltInFunctionDefinitions.JSON) @@ -313,4 +314,23 @@ object JsonGenerateUtils { ctx.addReusableMember(methodCode) methodName } + + def isOneOfFunctionDefinitions( + rexNode: RexNode, + expectedDefinitions: FunctionDefinition*): Boolean = { + if (!rexNode.isInstanceOf[RexCall]) return false + val call = rexNode.asInstanceOf[RexCall] + val unwrapped = ShortcutUtils.unwrapFunctionDefinition(call) match { + case d if d != null => d + case _ => + call.getOperator match { + case _: SqlJsonArrayFunctionWrapper => BuiltInFunctionDefinitions.JSON_ARRAY + case _: SqlJsonObjectFunctionWrapper => BuiltInFunctionDefinitions.JSON_OBJECT + case _: SqlJsonQueryFunctionWrapper => BuiltInFunctionDefinitions.JSON_QUERY + case _: SqlJsonValueFunctionWrapper => BuiltInFunctionDefinitions.JSON_VALUE + case _ => return false + } + } + expectedDefinitions.exists(_ eq unwrapped) + } } From d227528c70fec21738327d4e211da5caaa3104e4 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 7 May 2026 15:15:00 +0200 Subject: [PATCH 12/12] Address feedback --- .../planner/codegen/CalcCodeGenerator.scala | 6 +++--- .../planner/codegen/CodeGeneratorContext.scala | 17 ++++++++--------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala index f5c4f83be3900..966aca0abec22 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala @@ -197,7 +197,7 @@ object CalcCodeGenerator { 
val filterInputCode = ctx.reuseInputUnboxingCode() val filterInputSet = Set(ctx.reusableInputUnboxingExprs.keySet.toSeq: _*) - val filterLocalRefSet: Set[Int] = ctx.reusableLocalRefExprs.keySet.toSet + val filterLocalRefSet: Set[Int] = ctx.getReusableLocalRefExprBottomScope.keySet.toSet // if any filter conditions, projection code will enter an new scope val projectionCode = produceProjectionCode @@ -208,12 +208,12 @@ object CalcCodeGenerator { .map(_.code) .mkString("\n") - val filterLocalRefCode = ctx.reusableLocalRefExprs + val filterLocalRefCode = ctx.getReusableLocalRefExprBottomScope .filter { case (k, _) => filterLocalRefSet.contains(k) } .values .map(_.code) .mkString("\n") - val projectionLocalRefCode = ctx.reusableLocalRefExprs + val projectionLocalRefCode = ctx.getReusableLocalRefExprBottomScope .filter { case (k, _) => !filterLocalRefSet.contains(k) } .values .map(_.code) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index b1bbe0c4e6a75..fade7279606a1 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -115,13 +115,8 @@ class CodeGeneratorContext( val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] = mutable.Map[(String, Int), GeneratedExpression]() - // map of expressions for shared RexProgram exprList entries that will be added only once - // exprList index -> expr - val reusableLocalRefExprs: mutable.LinkedHashMap[Int, GeneratedExpression] = - mutable.LinkedHashMap[Int, GeneratedExpression]() - // Stack of RexLocalRef cache scopes (`exprList-index -> generated body`). 
- // * Bottom scope == reusableLocalRefExprs: bodies are hoisted to the top of the method + // * Bottom scope == getReusableLocalRefExprBottomScope: bodies are hoisted to the top of the method // and run unconditionally for every row. // * Inner scopes (push/popLocalRefScope): bodies are folded into a single guarded // operand's code by ExprCodeGenerator.visitOperandInScopedCache and run only when @@ -148,7 +143,7 @@ class CodeGeneratorContext( // ExprCodeGenerator.conditionalOperandIndices — extend it when adding new short-circuit // operators. private val localRefScopes = - mutable.ArrayBuffer(reusableLocalRefExprs) + mutable.ArrayBuffer(mutable.LinkedHashMap.empty[Int, GeneratedExpression]) // set of constructor statements that will be added only once // we use a LinkedHashSet to keep the insertion order @@ -1115,6 +1110,10 @@ class CodeGeneratorContext( // Reusable local ref code with scope // --------------------------------------------------------------------------------- + // Bottom scope of localRefScopes: holds unconditionally evaluated local refs. + def getReusableLocalRefExprBottomScope: mutable.LinkedHashMap[Int, GeneratedExpression] = + localRefScopes(0) + /** * Adds a reusable [[org.apache.calcite.rex.RexLocalRef]] expression keyed by its index in the * program's exprList. The expression is stored in the innermost active scope. @@ -1145,7 +1144,7 @@ class CodeGeneratorContext( * scope), concatenated in insertion order. */ def reuseLocalRefCode(): String = { - reusableLocalRefExprs.values.map(_.code).mkString("\n") + getReusableLocalRefExprBottomScope.values.map(_.code).mkString("\n") } /** Pushes a new, empty local-ref cache scope onto the scope stack. */ @@ -1155,7 +1154,7 @@ class CodeGeneratorContext( /** * Pops the innermost local-ref cache scope and returns its entries. The bottom scope - * ([[reusableLocalRefExprs]]) cannot be popped. + * ([[getReusableLocalRefExprBottomScope]]) cannot be popped. 
*/ def popLocalRefScope(): scala.collection.Map[Int, GeneratedExpression] = { require(