diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
index 0f58da6af..d820031cd 100644
--- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
+++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
@@ -12,6 +12,7 @@
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
+import io.substrait.expression.FunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
@@ -39,6 +40,7 @@
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
@@ -73,6 +75,17 @@ public class SubstraitBuilder {
private final SimpleExtension.ExtensionCollection extensions;
+  /**
+   * Constructs a new SubstraitBuilder backed by the default extension collection.
+   *
+   * <p>Equivalent to calling {@link #SubstraitBuilder(SimpleExtension.ExtensionCollection)} with
+   * {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}, the catalog of standard Substrait
+   * functions (strings, arithmetic, comparison, datetime, and other operations).
+   */
+  public SubstraitBuilder() {
+    this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
+  }
+
/**
* Constructs a new SubstraitBuilder with the specified extension collection.
*
@@ -83,6 +96,18 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
this.extensions = extensions;
}
+  /**
+   * Returns the extension collection this builder resolves function declarations against.
+   *
+   * <p>This is the collection supplied at construction time; it is {@link
+   * DefaultExtensionCatalog#DEFAULT_COLLECTION} only when the no-argument constructor was used.
+   *
+   * @return the {@link SimpleExtension.ExtensionCollection} held by this builder
+   */
+  public SimpleExtension.ExtensionCollection getExtensions() {
+    return extensions;
+  }
+
// Relations
/**
@@ -142,13 +167,32 @@ public Aggregate aggregate(
return aggregate(groupingsFn, measuresFn, Optional.of(remap), input);
}
- private Aggregate aggregate(
- Function> groupingsFn,
- Function> measuresFn,
- Optional remap,
- Rel input) {
- List groupings = groupingsFn.apply(input);
- List measures = measuresFn.apply(input);
+ /**
+ * Creates an aggregate relation that groups and aggregates data from an input relation.
+ *
+ * This method constructs a Substrait aggregate operation by applying grouping and measure
+ * functions to the input relation. The grouping function defines how rows are grouped together,
+ * while the measure function defines the aggregate computations (e.g., SUM, COUNT, AVG) to
+ * perform on each group.
+ *
+ * <p>The optional remap parameter allows reordering or filtering of output columns, which is
+ * useful for controlling the final schema of the aggregate result.
+ *
+ * @param groupingsFn a function that takes the input relation and returns a list of grouping
+ * expressions defining how to partition the data
+ * @param measuresFn a function that takes the input relation and returns a list of aggregate
+ * measures to compute for each group
+ * @param remap an optional remapping specification to reorder or filter output columns
+ * @param input the input relation to aggregate
+ * @return an Aggregate relation representing the grouping and aggregation operation
+ */
+ public Aggregate aggregate(
+ final Function> groupingsFn,
+ final Function> measuresFn,
+ final Optional remap,
+ final Rel input) {
+ final List groupings = groupingsFn.apply(input);
+ final List measures = measuresFn.apply(input);
return Aggregate.builder()
.groupings(groupings)
.measures(measures)
@@ -853,6 +897,26 @@ public Expression.BoolLiteral bool(boolean v) {
return Expression.BoolLiteral.builder().value(v).build();
}
+  /**
+   * Creates an 8-bit integer literal expression.
+   *
+   * @param value the literal value; accepted as {@code int} and not range-checked by this method
+   * @return a new {@link Expression.I8Literal}
+   */
+  public Expression.I8Literal i8(final int value) {
+    return Expression.I8Literal.builder().value(value).build();
+  }
+
+  /**
+   * Creates a 16-bit integer literal expression.
+   *
+   * @param value the literal value; accepted as {@code int} and not range-checked by this method
+   * @return a new {@link Expression.I16Literal}
+   */
+  public Expression.I16Literal i16(final int value) {
+    return Expression.I16Literal.builder().value(value).build();
+  }
+
/**
* Creates a 32-bit integer literal expression.
*
@@ -863,6 +927,26 @@ public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}
+  /**
+   * Creates a 64-bit integer literal expression.
+   *
+   * @param value the long value of the literal
+   * @return a new {@link Expression.I64Literal}
+   */
+  public Expression.I64Literal i64(final long value) {
+    return Expression.I64Literal.builder().value(value).build();
+  }
+
+  /**
+   * Creates a 32-bit floating point literal expression.
+   *
+   * @param value the float value of the literal, handed to the literal builder unchanged
+   * @return a new {@link Expression.FP32Literal}
+   */
+  public Expression.FP32Literal fp32(final float value) {
+    return Expression.FP32Literal.builder().value(value).build();
+  }
+
/**
* Creates a 64-bit floating point literal expression.
*
@@ -1439,6 +1523,79 @@ public Expression.ScalarFunctionInvocation or(Expression... args) {
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
}
+  /**
+   * Creates a logical NOT expression that negates a boolean expression.
+   *
+   * <p>Invokes {@code not:bool} from {@link DefaultExtensionCatalog#FUNCTIONS_BOOLEAN}. The output
+   * type is always declared as a nullable boolean, whatever the nullability of the input.
+   *
+   * @param expression the boolean expression to negate
+   * @return a scalar function invocation representing the logical NOT of the input expression
+   */
+  public Expression not(final Expression expression) {
+    return this.scalarFn(
+        DefaultExtensionCatalog.FUNCTIONS_BOOLEAN,
+        "not:bool",
+        TypeCreator.NULLABLE.BOOLEAN,
+        expression);
+  }
+
+  /**
+   * Creates a null-check expression that tests whether an expression is null.
+   *
+   * <p>Invokes {@code is_null:any} from {@link DefaultExtensionCatalog#FUNCTIONS_COMPARISON}
+   * through the varargs {@code scalarFn} overload, so no intermediate argument or option lists
+   * are allocated here and no function options are set on the invocation.
+   *
+   * <p>The declared output type is always a required (non-nullable) boolean, as the null check
+   * itself always produces a definite true/false result.
+   *
+   * @param expression the expression to test for null
+   * @return a scalar function invocation that returns true if the expression is null, false
+   *     otherwise
+   */
+  public Expression isNull(final Expression expression) {
+    return this.scalarFn(
+        DefaultExtensionCatalog.FUNCTIONS_COMPARISON,
+        "is_null:any",
+        TypeCreator.REQUIRED.BOOLEAN,
+        expression);
+  }
+
+  /**
+   * Creates a scalar function invocation with explicit function options.
+   *
+   * <p>The function declaration is looked up in this builder's extension collection by {@code urn}
+   * and {@code key}. The options are attached to the invocation as given; they control function
+   * behavior (e.g., rounding modes, overflow handling).
+   *
+   * @param urn the extension URN (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING})
+   * @param key the function signature (e.g., "substring:str_i32_i32")
+   * @param returnType the output type to declare on the invocation
+   * @param args the function arguments
+   * @param optionsList the function options controlling behavior
+   * @return a scalar function invocation expression
+   */
+  public Expression scalarFn(
+      final String urn,
+      final String key,
+      final Type returnType,
+      final List<? extends FunctionArg> args,
+      final List<FunctionOption> optionsList) {
+    final SimpleExtension.ScalarFunctionVariant declaration =
+        extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key));
+    return Expression.ScalarFunctionInvocation.builder()
+        .declaration(declaration)
+        .options(optionsList)
+        .outputType(returnType)
+        .arguments(args)
+        .build();
+  }
+
/**
* Creates a scalar function invocation with specified arguments.
*
@@ -1459,6 +1616,29 @@ public Expression.ScalarFunctionInvocation scalarFn(
.build();
}
+  /**
+   * Creates a scalar function invocation from a list of arguments, without function options.
+   *
+   * <p>The function declaration is looked up in this builder's extension collection by {@code urn}
+   * and {@code key}.
+   *
+   * @param urn the extension URN (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING})
+   * @param key the function signature (e.g., "substring:str_i32_i32")
+   * @param returnType the output type to declare on the invocation
+   * @param args the function arguments
+   * @return a scalar function invocation expression
+   */
+  public Expression scalarFn(
+      final String urn,
+      final String key,
+      final Type returnType,
+      final List<? extends FunctionArg> args) {
+    final SimpleExtension.ScalarFunctionVariant declaration =
+        extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key));
+    return Expression.ScalarFunctionInvocation.builder()
+        .declaration(declaration)
+        .outputType(returnType)
+        .arguments(args)
+        .build();
+  }
+
/**
* Creates a window function invocation with specified arguments and window bounds.
*
@@ -1532,6 +1712,22 @@ public Plan plan(Plan.Root root) {
return Plan.builder().addRoots(root).build();
}
+  /**
+   * Creates a {@link Plan.Root}, pairing a relational expression tree with its output column
+   * names.
+   *
+   * <p>The returned root can be passed to {@link #plan(Plan.Root)} to obtain a complete {@link
+   * Plan}.
+   *
+   * @param input the root relational expression of the query plan
+   * @param names the ordered list of output column names corresponding to the input relation's
+   *     output schema
+   * @return a new {@link Plan.Root}
+   */
+  public Plan.Root root(final Rel input, final List<String> names) {
+    return Plan.Root.builder().input(input).names(names).build();
+  }
+
/**
* Creates a field remapping specification from field indexes.
*
diff --git a/core/src/test/java/io/substrait/dsl/SubstraitBuilderTest.java b/core/src/test/java/io/substrait/dsl/SubstraitBuilderTest.java
new file mode 100644
index 000000000..2a199de03
--- /dev/null
+++ b/core/src/test/java/io/substrait/dsl/SubstraitBuilderTest.java
@@ -0,0 +1,265 @@
+package io.substrait.dsl;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.substrait.TestBase;
+import io.substrait.expression.AggregateFunctionInvocation;
+import io.substrait.expression.Expression;
+import io.substrait.expression.FieldReference;
+import io.substrait.extension.SimpleExtension;
+import io.substrait.plan.Plan;
+import io.substrait.relation.AbstractWriteRel;
+import io.substrait.relation.Aggregate;
+import io.substrait.relation.Cross;
+import io.substrait.relation.Fetch;
+import io.substrait.relation.Filter;
+import io.substrait.relation.Join;
+import io.substrait.relation.NamedScan;
+import io.substrait.relation.NamedWrite;
+import io.substrait.relation.Project;
+import io.substrait.relation.Rel.Remap;
+import io.substrait.relation.Sort;
+import io.substrait.type.Type;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+
+class SubstraitBuilderTest extends TestBase {
+
+ private SubstraitBuilder builder;
+
+  /** Gives every test a fresh builder created with the no-argument constructor. */
+  @BeforeEach
+  void newbuilder() {
+    builder = new SubstraitBuilder();
+  }
+
+  /** A builder from the no-argument constructor exists and exposes an extension collection. */
+  @Test
+  void basicCreation() {
+    final SubstraitBuilder created = builder;
+    assertNotNull(created);
+    assertInstanceOf(SimpleExtension.ExtensionCollection.class, created.getExtensions());
+  }
+
+  /** Exercises the literal factories, field references and casts produced by the builder. */
+  @Nested
+  @DisplayName("Literal and Expression Tests")
+  class ExpressionTests {
+
+    /** Each literal factory must carry the value it was given. */
+    @Test
+    void testLiterals() {
+      assertTrue(builder.bool(true).value());
+      assertEquals(10, builder.i8(10).value());
+      assertEquals(100, builder.i16(100).value());
+      assertEquals(1000, builder.i32(1000).value());
+      assertEquals(10_000L, builder.i64(10_000L).value());
+      assertEquals(1.5f, builder.fp32(1.5f).value());
+      assertEquals(2.5, builder.fp64(2.5).value());
+      assertEquals("foo", builder.str("foo").value());
+    }
+
+    /** Single and multiple field references can be created against a scan. */
+    @Test
+    void testFieldReferences() {
+      final NamedScan scan = createSimpleScan();
+      final FieldReference ref = builder.fieldReference(scan, 0);
+      assertNotNull(ref);
+
+      final List<FieldReference> refs = builder.fieldReferences(scan, 0, 1);
+      assertEquals(2, refs.size());
+    }
+
+    /** Casting a literal yields an {@link Expression.Cast}. */
+    @Test
+    void testCast() {
+      final Expression.I32Literal input = builder.i32(1);
+      final Expression cast = builder.cast(input, R.I64);
+      // assertInstanceOf also fails on null, so a separate null check is not needed.
+      assertInstanceOf(Expression.Cast.class, cast);
+    }
+  }
+
+  /** Exercises the relation factories: scan, project, filter, join, fetch, sort and cross. */
+  @Nested
+  @DisplayName("Relation Building Tests")
+  class RelationTests {
+
+    /** A named scan keeps the table name it was built with. */
+    @Test
+    void testNamedScan() {
+      final NamedScan scan =
+          builder.namedScan(List.of("table"), List.of("c1", "c2"), List.of(R.I32, R.STRING));
+      assertNotNull(scan);
+      assertEquals(List.of("table"), scan.getNames());
+    }
+
+    /** A projection keeps a reference to its input relation. */
+    @Test
+    void testProject() {
+      final NamedScan scan = createSimpleScan();
+      final Project project =
+          builder.project(rel -> List.of(builder.i32(1), builder.fieldReference(rel, 0)), scan);
+      assertNotNull(project);
+      assertEquals(scan, project.getInput());
+    }
+
+    /** A filter carries the condition produced by the supplied function. */
+    @Test
+    void testFilter() {
+      final NamedScan scan = createSimpleScan();
+      final Filter filter =
+          builder.filter(
+              rel -> builder.equal(builder.fieldReference(rel, 0), builder.i32(10)), scan);
+      assertNotNull(filter);
+      assertNotNull(filter.getCondition());
+    }
+
+    /** innerJoin produces a join of type INNER. */
+    @Test
+    void testInnerJoin() {
+      final NamedScan left = createSimpleScan();
+      final NamedScan right = createSimpleScan();
+      final Join join =
+          builder.innerJoin(
+              inputs ->
+                  builder.equal(
+                      builder.fieldReference(inputs.left(), 0),
+                      builder.fieldReference(inputs.right(), 0)),
+              left,
+              right);
+      assertNotNull(join);
+      assertEquals(Join.JoinType.INNER, join.getJoinType());
+    }
+
+    /** limit and offset populate the count and offset of the resulting fetch. */
+    @Test
+    void testFetchAndLimit() {
+      final NamedScan scan = createSimpleScan();
+      Fetch limit = builder.limit(10, scan);
+      assertEquals(10, limit.getCount().getAsLong());
+      assertEquals(0, limit.getOffset());
+      // Arrays.asList takes varargs directly; no explicit Integer[] is needed.
+      limit = builder.limit(10, Remap.of(Arrays.asList(0, 1)), scan);
+      assertNotNull(limit);
+
+      Fetch offset = builder.offset(5, scan);
+      assertEquals(5, offset.getOffset());
+
+      offset = builder.offset(5, Remap.of(Arrays.asList(0, 1)), scan);
+      assertEquals(5, offset.getOffset());
+    }
+
+    /** sort can be built with and without a remap. */
+    @Test
+    void testSort() {
+      final NamedScan scan = createSimpleScan();
+      Sort sort = builder.sort(rel -> builder.sortFields(rel, 0), scan);
+      assertNotNull(sort);
+      sort = builder.sort(rel -> builder.sortFields(rel, 0), Remap.of(Arrays.asList(0, 1)), scan);
+      assertNotNull(sort);
+    }
+
+    /** cross can be built with and without a remap. */
+    @Test
+    void testCross() {
+      final NamedScan left = createSimpleScan();
+      final NamedScan right = createSimpleScan();
+
+      Cross cross = builder.cross(left, right);
+      assertNotNull(cross);
+
+      cross = builder.cross(left, right, Remap.of(Arrays.asList(0, 1)));
+      assertNotNull(cross);
+    }
+
+    /** fetch can be built with and without a remap. */
+    @Test
+    void testFetch() {
+      final NamedScan left = createSimpleScan();
+      Fetch fetch = builder.fetch(0, 1, left);
+      assertNotNull(fetch);
+
+      fetch = builder.fetch(0, 1, Remap.of(Arrays.asList(0, 1)), left);
+      assertNotNull(fetch);
+    }
+  }
+
+  /** Exercises aggregate measures plus arithmetic and boolean scalar function helpers. */
+  @Nested
+  @DisplayName("Aggregate and Scalar Function Tests")
+  class FunctionTests {
+
+    /** Measures, groupings and raw aggregate function invocations can be created. */
+    @Test
+    void testAggregateMeasures() {
+      final NamedScan scan = createSimpleScan();
+      final Aggregate.Measure count = builder.count(scan, 0);
+      final Aggregate.Measure sum = builder.sum(builder.fieldReference(scan, 0));
+
+      assertNotNull(count);
+      assertNotNull(sum);
+
+      final Aggregate.Grouping grouping = builder.grouping(scan, 1);
+      assertNotNull(grouping);
+      assertEquals(1, grouping.getExpressions().size());
+
+      // count is declared with an i64 return type in the Substrait generic aggregate extension,
+      // so the invocation declares i64 rather than i32.
+      final AggregateFunctionInvocation countInvocation =
+          builder.aggregateFn(
+              "extension:io.substrait:functions_aggregate_generic",
+              "count:any",
+              Type.I64.builder().nullable(false).build(),
+              builder.fieldReference(scan, 0));
+      assertNotNull(countInvocation);
+    }
+
+    /** The arithmetic helpers accept two i32 literals. */
+    @Test
+    void testArithmeticFunctions() {
+      final Expression left = builder.i32(10);
+      final Expression right = builder.i32(20);
+
+      assertNotNull(builder.add(left, right));
+      assertNotNull(builder.subtract(left, right));
+      assertNotNull(builder.multiply(left, right));
+      assertNotNull(builder.divide(left, right));
+      assertNotNull(builder.negate(left));
+    }
+
+    /** The boolean helpers accept boolean literals. */
+    @Test
+    void testBooleanLogic() {
+      final Expression b1 = builder.bool(true);
+      final Expression b2 = builder.bool(false);
+
+      assertNotNull(builder.and(b1, b2));
+      assertNotNull(builder.or(b1, b2));
+      assertNotNull(builder.not(b1));
+      assertNotNull(builder.isNull(b1));
+    }
+  }
+
+  /** Covers write relations and wrapping a relation into a plan. */
+  @Nested
+  @DisplayName("Write and Plan Tests")
+  class PlanTests {
+
+    /** An INSERT write over a simple scan can be built. */
+    @Test
+    void testNamedWrite() {
+      final NamedScan source = createSimpleScan();
+      final List<String> targetTable = List.of("target_table");
+      final List<String> columns = List.of("c1", "c2");
+
+      final NamedWrite write =
+          builder.namedWrite(
+              targetTable,
+              columns,
+              AbstractWriteRel.WriteOp.INSERT,
+              AbstractWriteRel.CreateMode.APPEND_IF_EXISTS,
+              AbstractWriteRel.OutputMode.MODIFIED_RECORDS,
+              source);
+      assertNotNull(write);
+    }
+
+    /** A plan built from one root contains exactly that one root. */
+    @Test
+    void testPlanCreation() {
+      final Plan.Root root = builder.root(createSimpleScan(), List.of("out1", "out2"));
+
+      final Plan plan = builder.plan(root);
+      assertNotNull(plan);
+      assertEquals(1, plan.getRoots().size());
+    }
+  }
+
+  /** Builds the scan over table {@code t} with columns {@code a} and {@code b} used as input. */
+  private NamedScan createSimpleScan() {
+    final List<String> tableName = List.of("t");
+    final List<String> columnNames = List.of("a", "b");
+    return builder.namedScan(tableName, columnNames, List.of(R.I32, R.STRING));
+  }
+}