Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 209 additions & 13 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -39,6 +40,7 @@
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
Expand Down Expand Up @@ -73,6 +75,17 @@ public class SubstraitBuilder {

private final SimpleExtension.ExtensionCollection extensions;

/**
* Constructs a new SubstraitBuilder backed by the default extension collection.
*
* <p>Delegates to the single-argument constructor with {@link
* DefaultExtensionCatalog#DEFAULT_COLLECTION}, which includes standard Substrait functions for
* strings, arithmetic, comparison, datetime, and other operations.
*/
public SubstraitBuilder() {
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

/**
* Constructs a new SubstraitBuilder with the specified extension collection.
*
Expand All @@ -83,6 +96,34 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
this.extensions = extensions;
}

/**
* Gets the extension collection that this builder resolves functions against.
*
* <p>This is the collection supplied at construction time. It is {@link
* DefaultExtensionCatalog#DEFAULT_COLLECTION} only when the no-argument constructor was used; a
* builder created with a custom collection returns that custom collection instead.
*
* @return the ExtensionCollection held by this builder
*/
public SimpleExtension.ExtensionCollection getExtensions() {
return extensions;
}

/**
* Creates a Plan.Root, which is the top-level container for a Substrait query plan.
*
* <p>The returned Plan.Root pairs the given relational expression tree with the given output
* column names. Neither argument is validated here; both are passed straight to the Plan.Root
* builder.
*
* @param input the root relational expression of the query plan
* @param names the ordered list of output column names corresponding to the input relation's
* output schema
* @return a Plan.Root containing the query plan and output column names
*/
public Plan.Root root(final Rel input, final List<String> names) {
return Plan.Root.builder().input(input).names(names).build();
}

// Relations

/**
Expand Down Expand Up @@ -142,13 +183,32 @@ public Aggregate aggregate(
return aggregate(groupingsFn, measuresFn, Optional.of(remap), input);
}

private Aggregate aggregate(
Function<Rel, List<Aggregate.Grouping>> groupingsFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Optional<Rel.Remap> remap,
Rel input) {
List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
List<Aggregate.Measure> measures = measuresFn.apply(input);
/**
* Creates an aggregate relation that groups and aggregates data from an input relation.
*
* <p>This method constructs a Substrait aggregate operation by applying grouping and measure
* functions to the input relation. The grouping function defines how rows are grouped together,
* while the measure function defines the aggregate computations (e.g., SUM, COUNT, AVG) to
* perform on each group.
*
* <p>The optional remap parameter allows reordering or filtering of output columns, which is
* useful for controlling the final schema of the aggregate result.
*
* @param groupingsFn a function that takes the input relation and returns a list of grouping
* expressions defining how to partition the data
* @param measuresFn a function that takes the input relation and returns a list of aggregate
* measures to compute for each group
* @param remap an optional remapping specification to reorder or filter output columns
* @param input the input relation to aggregate
* @return an Aggregate relation representing the grouping and aggregation operation
*/
public Aggregate aggregate(
final Function<Rel, List<Aggregate.Grouping>> groupingsFn,
final Function<Rel, List<Aggregate.Measure>> measuresFn,
final Optional<Rel.Remap> remap,
final Rel input) {
final List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
final List<Aggregate.Measure> measures = measuresFn.apply(input);
return Aggregate.builder()
.groupings(groupings)
.measures(measures)
Expand Down Expand Up @@ -853,24 +913,64 @@ public Expression.BoolLiteral bool(boolean v) {
return Expression.BoolLiteral.builder().value(v).build();
}

/**
* Creates an 8-bit integer (i8) literal expression.
*
* <p>NOTE(review): the parameter is an {@code int} and this method performs no range check
* against the 8-bit range itself — confirm whether the literal builder validates the value.
*
* @param value the integer value of the literal
* @return a new {@link Expression.I8Literal}
*/
public Expression.I8Literal i8(final int value) {
return Expression.I8Literal.builder().value(value).build();
}

/**
* Creates a 16-bit integer (i16) literal expression.
*
* <p>NOTE(review): the parameter is an {@code int} and this method performs no range check
* against the 16-bit range itself — confirm whether the literal builder validates the value.
*
* @param value the integer value of the literal
* @return a new {@link Expression.I16Literal}
*/
public Expression.I16Literal i16(final int value) {
return Expression.I16Literal.builder().value(value).build();
}

/**
* Creates a 32-bit integer literal expression.
*
* @param v the integer value
* @param value the integer value
* @return a new {@link Expression.I32Literal}
*/
public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
public Expression.I32Literal i32(final int value) {
return Expression.I32Literal.builder().value(value).build();
}

/**
* Creates a 64-bit integer (i64) literal expression.
*
* @param value the long value of the literal
* @return a new {@link Expression.I64Literal}
*/
public Expression.I64Literal i64(final long value) {
return Expression.I64Literal.builder().value(value).build();
}

/**
* Creates a 32-bit floating point (fp32) literal expression.
*
* @param value the float value of the literal
* @return a new {@link Expression.FP32Literal}
*/
public Expression.FP32Literal fp32(final float value) {
return Expression.FP32Literal.builder().value(value).build();
}

/**
* Creates a 64-bit floating point literal expression.
*
* @param v the double value
* @param value the double value
* @return a new {@link Expression.FP64Literal}
*/
public Expression.FP64Literal fp64(double v) {
return Expression.FP64Literal.builder().value(v).build();
public Expression.FP64Literal fp64(final double value) {
return Expression.FP64Literal.builder().value(value).build();
}

/**
Expand Down Expand Up @@ -1439,6 +1539,79 @@ public Expression.ScalarFunctionInvocation or(Expression... args) {
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
}

/**
 * Builds the logical negation of a boolean expression.
 *
 * <p>The call is resolved as the {@code not:bool} function from {@link
 * DefaultExtensionCatalog#FUNCTIONS_BOOLEAN}. The declared output type is a nullable boolean, so
 * a NULL input can be represented in the result.
 *
 * @param expression the boolean expression to negate
 * @return a scalar function invocation representing the logical NOT of the input expression
 */
public Expression not(final Expression expression) {
  final Type negatedType = TypeCreator.NULLABLE.BOOLEAN;
  return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "not:bool", negatedType, expression);
}

/**
 * Builds an expression that tests whether another expression evaluates to null.
 *
 * <p>The call is resolved as the {@code is_null:any} function from {@link
 * DefaultExtensionCatalog#FUNCTIONS_COMPARISON}, invoked with the given expression as its only
 * argument and with no function options.
 *
 * <p>The declared output type is a required (non-nullable) boolean: the check itself always
 * yields a definite true/false answer.
 *
 * @param expression the expression to test for null
 * @return a scalar function invocation that returns true if the expression is null, false
 *     otherwise
 */
public Expression isNull(final Expression expression) {
  final List<Expression> soleArgument = Collections.singletonList(expression);
  final List<FunctionOption> noOptions = Collections.emptyList();

  return scalarFn(
      DefaultExtensionCatalog.FUNCTIONS_COMPARISON,
      "is_null:any",
      TypeCreator.REQUIRED.BOOLEAN,
      soleArgument,
      noOptions);
}

/**
 * Builds a scalar function invocation that carries explicit function options.
 *
 * <p>The function declaration is looked up in this builder's extension collection using the
 * anchor formed from {@code urn} and {@code key}. The supplied options are attached to the
 * invocation unchanged; they control function behavior (e.g., rounding modes, overflow
 * handling).
 *
 * @param urn the extension URN of the function library (e.g., {@link
 *     DefaultExtensionCatalog#FUNCTIONS_STRING})
 * @param key the function signature (e.g., "substring:str_i32_i32")
 * @param returnType the output type recorded on the invocation
 * @param args the function arguments
 * @param optionsList the function options controlling behavior
 * @return a scalar function invocation expression
 */
public Expression scalarFn(
    final String urn,
    final String key,
    final Type returnType,
    final List<? extends FunctionArg> args,
    final List<FunctionOption> optionsList) {
  final SimpleExtension.ScalarFunctionVariant variant =
      extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key));

  return Expression.ScalarFunctionInvocation.builder()
      .declaration(variant)
      .arguments(args)
      .options(optionsList)
      .outputType(returnType)
      .build();
}

/**
* Creates a scalar function invocation with specified arguments.
*
Expand All @@ -1459,6 +1632,29 @@ public Expression.ScalarFunctionInvocation scalarFn(
.build();
}

/**
 * Creates a scalar function invocation from a list of arguments, without function options.
 *
 * <p>This is a convenience overload of the options-aware {@code scalarFn}: it delegates to that
 * overload with an empty option list, so function lookup and invocation construction live in a
 * single place.
 *
 * @param urn the extension URN of the function library (e.g., {@link
 *     DefaultExtensionCatalog#FUNCTIONS_STRING})
 * @param key the function signature (e.g., "substring:str_i32_i32")
 * @param returnType the return type of the function
 * @param args the function arguments
 * @return a scalar function invocation expression
 */
public Expression scalarFn(
    final String urn,
    final String key,
    final Type returnType,
    final List<? extends FunctionArg> args) {
  // An empty option list produces the same invocation as never setting options at all.
  return scalarFn(urn, key, returnType, args, Collections.<FunctionOption>emptyList());
}

/**
* Creates a window function invocation with specified arguments and window bounds.
*
Expand Down
Loading
Loading