Skip to content

Commit e457052

Browse files
authored
Merge branch 'main' into fix-javadoc-c-007
2 parents bfe7e98 + 190c071 commit e457052

62 files changed

Lines changed: 1804 additions & 34 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ out/**
1414
.metals
1515
.bloop
1616
.project
17+
.classpath
18+
.settings
19+
bin/

core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,19 @@ public O visit(Expression.IfThen expr, C context) throws E {
461461
return visitFallback(expr, context);
462462
}
463463

464+
/**
465+
* Visits a Lambda expression.
466+
*
467+
* @param expr the Lambda expression
468+
* @param context the visitation context
469+
* @return the visit result
470+
* @throws E if visitation fails
471+
*/
472+
@Override
473+
public O visit(Expression.Lambda expr, C context) throws E {
474+
return visitFallback(expr, context);
475+
}
476+
464477
/**
465478
* Visits a scalar function invocation.
466479
*

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,30 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
758758
}
759759
}
760760

761+
@Value.Immutable
762+
abstract class Lambda implements Expression {
763+
public abstract Type.Struct parameters();
764+
765+
public abstract Expression body();
766+
767+
@Override
768+
public Type getType() {
769+
List<Type> paramTypes = parameters().fields();
770+
Type returnType = body().getType();
771+
772+
// TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for
773+
// declaring otherwise.
774+
// See: https://github.com/substrait-io/substrait/issues/976
775+
return Type.withNullability(false).func(paramTypes, returnType);
776+
}
777+
778+
@Override
779+
public <R, C extends VisitationContext, E extends Throwable> R accept(
780+
ExpressionVisitor<R, C, E> visitor, C context) throws E {
781+
return visitor.visit(this, context);
782+
}
783+
}
784+
761785
/**
762786
* Base interface for user-defined literals.
763787
*

core/src/main/java/io/substrait/expression/ExpressionVisitor.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
321321
*/
322322
R visit(Expression.NestedStruct expr, C context) throws E;
323323

324+
/**
 * Visit a Lambda expression.
 *
 * <p>{@code Expression.Lambda#accept} dispatches to this overload, passing the lambda itself and
 * the context it was given.
 *
 * @param expr the Lambda expression
 * @param context visitation context
 * @return visit result
 * @throws E on visit failure
 */
332+
R visit(Expression.Lambda expr, C context) throws E;
333+
324334
/**
325335
* Visit a user-defined any literal.
326336
*

core/src/main/java/io/substrait/expression/FieldReference.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public abstract class FieldReference implements Expression {
2020

2121
public abstract Optional<Integer> outerReferenceStepsOut();
2222

23+
public abstract Optional<Integer> lambdaParameterReferenceStepsOut();
24+
2325
@Override
2426
public Type getType() {
2527
return type();
@@ -35,16 +37,30 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
3537
return visitor.visit(this, context);
3638
}
3739

40+
@Value.Check
41+
protected void check() {
42+
if (outerReferenceStepsOut().isPresent() && lambdaParameterReferenceStepsOut().isPresent()) {
43+
throw new IllegalArgumentException(
44+
"FieldReference cannot have both outerReferenceStepsOut and lambdaParameterReferenceStepsOut set");
45+
}
46+
}
47+
3848
public boolean isSimpleRootReference() {
3949
return segments().size() == 1
4050
&& !inputExpression().isPresent()
41-
&& !outerReferenceStepsOut().isPresent();
51+
&& !outerReferenceStepsOut().isPresent()
52+
&& !lambdaParameterReferenceStepsOut().isPresent();
4253
}
4354

4455
public boolean isOuterReference() {
4556
return outerReferenceStepsOut().orElse(0) > 0;
4657
}
4758

59+
/** Returns true if this field reference refers to a lambda parameter. */
60+
public boolean isLambdaParameterReference() {
61+
return lambdaParameterReferenceStepsOut().isPresent();
62+
}
63+
4864
public FieldReference dereferenceStruct(int index) {
4965
Type newType = StructFieldFinder.getReferencedType(type(), index);
5066
return dereference(newType, StructField.of(index));
@@ -134,6 +150,14 @@ public static FieldReference newInputRelReference(int index, List<Rel> rels) {
134150
index, currentOffset));
135151
}
136152

153+
static FieldReference newLambdaParameterReference(int stepsOut, int paramIndex, Type knownType) {
154+
return ImmutableFieldReference.builder()
155+
.addSegments(StructField.of(paramIndex))
156+
.type(knownType)
157+
.lambdaParameterReferenceStepsOut(stepsOut)
158+
.build();
159+
}
160+
137161
public interface ReferenceSegment {
138162
FieldReference apply(FieldReference reference);
139163

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package io.substrait.expression;
2+
3+
import io.substrait.type.Type;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import java.util.function.Function;
7+
8+
/**
9+
* Builds lambda expressions with build-time validation of parameter references.
10+
*
11+
* <p>Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters
12+
* onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code
13+
* lambda()} again on the same builder.
14+
*
15+
* <p>The callback receives a {@link Scope} handle for creating validated parameter references. The
16+
* correct {@code stepsOut} value is computed automatically from the stack.
17+
*
18+
* <pre>{@code
19+
* LambdaBuilder lb = new LambdaBuilder();
20+
*
21+
* // Simple: (x: i32) -> x
22+
* Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0));
23+
*
24+
* // Nested: (x: i32) -> (y: i64) -> add(x, y)
25+
* Expression.Lambda nested = lb.lambda(List.of(R.I32), outer ->
26+
* lb.lambda(List.of(R.I64), inner ->
27+
* add(outer.ref(0), inner.ref(0))
28+
* )
29+
* );
30+
* }</pre>
31+
*/
32+
public class LambdaBuilder {
33+
private final List<Type.Struct> lambdaContext = new ArrayList<>();
34+
35+
/**
36+
* Builds a lambda expression. The body function receives a {@link Scope} for creating validated
37+
* parameter references. Nested lambdas are built by calling this method again inside the
38+
* callback.
39+
*
40+
* @param paramTypes the lambda's parameter types
41+
* @param bodyFn function that builds the lambda body given a scope handle
42+
* @return the constructed lambda expression
43+
*/
44+
public Expression.Lambda lambda(List<Type> paramTypes, Function<Scope, Expression> bodyFn) {
45+
Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build();
46+
pushLambdaContext(params);
47+
try {
48+
Scope scope = new Scope(params);
49+
Expression body = bodyFn.apply(scope);
50+
return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
51+
} finally {
52+
popLambdaContext();
53+
}
54+
}
55+
56+
/**
57+
* Builds a lambda expression from a pre-built parameter struct. Used by internal converters that
58+
* already have a Type.Struct (e.g., during protobuf deserialization).
59+
*
60+
* @param params the lambda's parameter struct
61+
* @param bodyFn function that builds the lambda body
62+
* @return the constructed lambda expression
63+
*/
64+
public Expression.Lambda lambdaFromStruct(
65+
Type.Struct params, java.util.function.Supplier<Expression> bodyFn) {
66+
pushLambdaContext(params);
67+
try {
68+
Expression body = bodyFn.get();
69+
return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
70+
} finally {
71+
popLambdaContext();
72+
}
73+
}
74+
75+
/**
76+
* Resolves the parameter struct for a lambda at the given stepsOut from the current innermost
77+
* scope. Used by internal converters to validate lambda parameter references during
78+
* deserialization.
79+
*
80+
* @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
81+
* @return the parameter struct at the target scope level
82+
* @throws IllegalArgumentException if stepsOut exceeds the current nesting depth
83+
*/
84+
public Type.Struct resolveParams(int stepsOut) {
85+
int targetDepth = lambdaContext.size() - stepsOut;
86+
if (targetDepth <= 0 || targetDepth > lambdaContext.size()) {
87+
throw new IllegalArgumentException(
88+
String.format(
89+
"Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
90+
stepsOut, lambdaContext.size()));
91+
}
92+
return lambdaContext.get(targetDepth - 1);
93+
}
94+
95+
/**
96+
* Creates a validated field reference to a lambda parameter. Validates that stepsOut is valid for
97+
* the current lambda nesting context.
98+
*
99+
* @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
100+
* @param paramIndex index of the parameter within the target lambda's parameter struct
101+
* @return a field reference to the specified lambda parameter
102+
* @throws IllegalArgumentException if stepsOut exceeds the current nesting depth
103+
* @throws IndexOutOfBoundsException if paramIndex is out of bounds for the target lambda
104+
*/
105+
public FieldReference newParameterReference(int stepsOut, int paramIndex) {
106+
Type.Struct params = resolveParams(stepsOut);
107+
Type type = params.fields().get(paramIndex);
108+
return FieldReference.newLambdaParameterReference(stepsOut, paramIndex, type);
109+
}
110+
111+
/**
112+
* Pushes a lambda's parameters onto the context stack. This makes the parameters available for
113+
* validation when building the lambda's body, and allows nested lambda parameter references to
114+
* correctly compute their stepsOut values.
115+
*/
116+
private void pushLambdaContext(Type.Struct params) {
117+
lambdaContext.add(params);
118+
}
119+
120+
/**
121+
* Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's
122+
* body has been built, restoring the context to the enclosing lambda's scope.
123+
*/
124+
private void popLambdaContext() {
125+
lambdaContext.remove(lambdaContext.size() - 1);
126+
}
127+
128+
/**
129+
* A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated
130+
* parameter references.
131+
*
132+
* <p>Each Scope captures the depth of the lambdaContext stack at the time it was created. When
133+
* {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference
134+
* between the current stack depth and the captured depth. This means the same Scope produces
135+
* different stepsOut values depending on the nesting level at the time of the call, which is what
136+
* allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda.
137+
*/
138+
public class Scope {
139+
private final Type.Struct params;
140+
private final int depth;
141+
142+
private Scope(Type.Struct params) {
143+
this.params = params;
144+
this.depth = lambdaContext.size();
145+
}
146+
147+
/**
148+
* Computes the number of lambda boundaries between this scope and the current innermost scope.
149+
* This value changes dynamically as nested lambdas are built.
150+
*/
151+
private int stepsOut() {
152+
return lambdaContext.size() - depth;
153+
}
154+
155+
/**
156+
* Creates a validated reference to a parameter of this lambda.
157+
*
158+
* @param paramIndex index of the parameter within this lambda's parameter struct
159+
* @return a {@link FieldReference} pointing to the specified parameter
160+
* @throws IndexOutOfBoundsException if paramIndex is out of bounds
161+
*/
162+
public FieldReference ref(int paramIndex) throws IndexOutOfBoundsException {
163+
Type type = params.fields().get(paramIndex);
164+
return FieldReference.newLambdaParameterReference(stepsOut(), paramIndex, type);
165+
}
166+
}
167+
}

core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,18 @@ public Expression visit(
387387
});
388388
}
389389

390+
@Override
391+
public Expression visit(
392+
io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context)
393+
throws RuntimeException {
394+
return io.substrait.proto.Expression.newBuilder()
395+
.setLambda(
396+
io.substrait.proto.Expression.Lambda.newBuilder()
397+
.setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct())
398+
.setBody(expr.body().accept(this, context)))
399+
.build();
400+
}
401+
390402
@Override
391403
public Expression visit(
392404
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
@@ -617,6 +629,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) {
617629
out.setOuterReference(
618630
io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder()
619631
.setStepsOut(expr.outerReferenceStepsOut().get()));
632+
} else if (expr.lambdaParameterReferenceStepsOut().isPresent()) {
633+
out.setLambdaParameterReference(
634+
io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder()
635+
.setStepsOut(expr.lambdaParameterReferenceStepsOut().get()));
620636
} else {
621637
out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
622638
}

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.substrait.expression.FieldReference.ReferenceSegment;
77
import io.substrait.expression.FunctionArg;
88
import io.substrait.expression.FunctionOption;
9+
import io.substrait.expression.LambdaBuilder;
910
import io.substrait.expression.WindowBound;
1011
import io.substrait.extension.ExtensionLookup;
1112
import io.substrait.extension.SimpleExtension;
@@ -37,6 +38,7 @@ public class ProtoExpressionConverter {
3738
private final Type.Struct rootType;
3839
private final ProtoTypeConverter protoTypeConverter;
3940
private final ProtoRelConverter protoRelConverter;
41+
private final LambdaBuilder lambdaBuilder = new LambdaBuilder();
4042

4143
public ProtoExpressionConverter(
4244
ExtensionLookup lookup,
@@ -75,6 +77,21 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
7577
reference.getDirectReference().getStructField().getField(),
7678
rootType,
7779
reference.getOuterReference().getStepsOut());
80+
case LAMBDA_PARAMETER_REFERENCE:
81+
{
82+
io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef =
83+
reference.getLambdaParameterReference();
84+
85+
// Check for unsupported nested field access
86+
if (reference.getDirectReference().getStructField().hasChild()) {
87+
throw new UnsupportedOperationException(
88+
"Nested field access in lambda parameters is not yet supported");
89+
}
90+
91+
return lambdaBuilder.newParameterReference(
92+
lambdaParamRef.getStepsOut(),
93+
reference.getDirectReference().getStructField().getField());
94+
}
7895
case ROOTTYPE_NOT_SET:
7996
default:
8097
throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase());
@@ -260,6 +277,18 @@ public Type visit(Type.Struct type) throws RuntimeException {
260277
}
261278
}
262279

280+
case LAMBDA:
281+
{
282+
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda();
283+
Type.Struct parameters =
284+
(Type.Struct)
285+
protoTypeConverter.from(
286+
io.substrait.proto.Type.newBuilder()
287+
.setStruct(protoLambda.getParameters())
288+
.build());
289+
290+
return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody()));
291+
}
263292
// TODO enum.
264293
case ENUM:
265294
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());

0 commit comments

Comments
 (0)