diff --git a/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala b/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala index 56b84f859..87a2d78a2 100644 --- a/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala +++ b/lance-spark-3.4_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala @@ -15,7 +15,7 @@ package org.apache.spark.sql.catalyst.parser.extensions import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowIndexes, UpdateColumnsBackfill, Vacuum} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, CreateBranch, CreateTag, DropBranch, DropTag, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowBranches, ShowIndexes, ShowTags, UpdateColumnsBackfill, Vacuum} import org.lance.spark.utils.ParserUtils import scala.collection.JavaConverters._ @@ -105,6 +105,42 @@ class LanceSqlExtensionsAstBuilder(delegate: ParserInterface) LanceDropIndex(table, indexName) } + override def visitShowBranches(ctx: LanceSqlExtensionsParser.ShowBranchesContext) + : ShowBranches = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowBranches(table) + } + + override def visitCreateBranch(ctx: LanceSqlExtensionsParser.CreateBranchContext) + : CreateBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + CreateBranch(table, branchName) + } + + override def 
visitDropBranch(ctx: LanceSqlExtensionsParser.DropBranchContext): DropBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + DropBranch(table, branchName) + } + + override def visitShowTags(ctx: LanceSqlExtensionsParser.ShowTagsContext): ShowTags = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowTags(table) + } + + override def visitCreateTag(ctx: LanceSqlExtensionsParser.CreateTagContext): CreateTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + CreateTag(table, tagName) + } + + override def visitDropTag(ctx: LanceSqlExtensionsParser.DropTagContext): DropTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + DropTag(table, tagName) + } + override def visitSetUnenforcedPrimaryKey( ctx: LanceSqlExtensionsParser.SetUnenforcedPrimaryKeyContext): SetUnenforcedPrimaryKey = { val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) diff --git a/lance-spark-3.4_2.12/src/test/java/org/lance/spark/branch/BranchTest.java b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/branch/BranchTest.java new file mode 100755 index 000000000..aaf2975da --- /dev/null +++ b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/branch/BranchTest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +/** Concrete implementation of BaseBranchTest for Spark 3.4. */ +public class BranchTest extends BaseBranchTest { + // All test methods are inherited from BaseBranchTest +} diff --git a/lance-spark-3.4_2.12/src/test/java/org/lance/spark/branch/TagTest.java b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/branch/TagTest.java new file mode 100755 index 000000000..658d916fb --- /dev/null +++ b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/branch/TagTest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +/** Concrete implementation of BaseTagTest for Spark 3.4. 
*/ +public class TagTest extends BaseTagTest { + // All test methods are inherited from BaseTagTest +} diff --git a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java index 41b07f44a..0717fc692 100644 --- a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java +++ b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/LancePositionDeltaOperation.java @@ -90,7 +90,8 @@ public DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { .datasetUri(readOptions.getDatasetUri()) .storageOptions(readOptions.getStorageOptions()) .namespace(readOptions.getNamespace()) - .tableId(readOptions.getTableId()); + .tableId(readOptions.getTableId()) + .branchName(readOptions.getBranchName()); if (fileFormatVersion != null) { writeOptionsBuilder.fileFormatVersion(fileFormatVersion); } diff --git a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java index 530bea9cd..ed394eae0 100644 --- a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java +++ b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java @@ -292,7 +292,7 @@ public DeltaWriter createWriter(int partitionId, long taskId) { try (ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(LanceRuntime.allocator())) { Data.exportArrayStream(LanceRuntime.allocator(), writeBuffer, arrowStream); - return Fragment.create(writeOptions.getDatasetUri(), arrowStream, params); + return Fragment.create(writeOptions.getActualDatasetUri(), arrowStream, params); } }; FutureTask> fragmentCreationTask = diff --git a/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala 
b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala index 56b84f859..87a2d78a2 100644 --- a/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala +++ b/lance-spark-3.5_2.12/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala @@ -15,7 +15,7 @@ package org.apache.spark.sql.catalyst.parser.extensions import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowIndexes, UpdateColumnsBackfill, Vacuum} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, CreateBranch, CreateTag, DropBranch, DropTag, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowBranches, ShowIndexes, ShowTags, UpdateColumnsBackfill, Vacuum} import org.lance.spark.utils.ParserUtils import scala.collection.JavaConverters._ @@ -105,6 +105,42 @@ class LanceSqlExtensionsAstBuilder(delegate: ParserInterface) LanceDropIndex(table, indexName) } + override def visitShowBranches(ctx: LanceSqlExtensionsParser.ShowBranchesContext) + : ShowBranches = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowBranches(table) + } + + override def visitCreateBranch(ctx: LanceSqlExtensionsParser.CreateBranchContext) + : CreateBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + CreateBranch(table, branchName) + } + + override def visitDropBranch(ctx: LanceSqlExtensionsParser.DropBranchContext): DropBranch = { + val table = 
UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + DropBranch(table, branchName) + } + + override def visitShowTags(ctx: LanceSqlExtensionsParser.ShowTagsContext): ShowTags = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowTags(table) + } + + override def visitCreateTag(ctx: LanceSqlExtensionsParser.CreateTagContext): CreateTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + CreateTag(table, tagName) + } + + override def visitDropTag(ctx: LanceSqlExtensionsParser.DropTagContext): DropTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + DropTag(table, tagName) + } + override def visitSetUnenforcedPrimaryKey( ctx: LanceSqlExtensionsParser.SetUnenforcedPrimaryKeyContext): SetUnenforcedPrimaryKey = { val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/BranchTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/BranchTest.java new file mode 100755 index 000000000..4671b72e5 --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/BranchTest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +/** Concrete implementation of BaseBranchTest for Spark 3.5. */ +public class BranchTest extends BaseBranchTest { + // All test methods are inherited from BaseBranchTest +} diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/TagTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/TagTest.java new file mode 100755 index 000000000..eaee782a6 --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/TagTest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +/** Concrete implementation of BaseTagTest for Spark 3.5. */ +public class TagTest extends BaseTagTest { + // All test methods are inherited from BaseTagTest +} diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/UpdateBranchTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/UpdateBranchTest.java new file mode 100644 index 000000000..9497cec2b --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/branch/UpdateBranchTest.java @@ -0,0 +1,162 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class UpdateBranchTest { + protected String catalogName = "lance_test"; + protected String tableName = "branch_test"; + protected String fullTable = catalogName + ".default." + tableName; + + protected SparkSession spark; + + @TempDir Path tempDir; + + @BeforeEach + public void setup() throws IOException { + Path rootPath = tempDir.resolve(UUID.randomUUID().toString()); + Files.createDirectories(rootPath); + String testRoot = rootPath.toString(); + spark = + SparkSession.builder() + .appName("lance-create-index-test") + .master("local[3]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", "dir") + .config("spark.sql.catalog." + catalogName + ".root", testRoot) + .getOrCreate(); + this.tableName = "branch_test_" + UUID.randomUUID().toString().replace("-", ""); + this.fullTable = this.catalogName + ".default." 
+ this.tableName; + } + + @AfterEach + public void tearDown() throws IOException { + if (spark != null) { + spark.close(); + } + } + + private void prepareDataset() { + spark.sql(String.format("create table %s (id int, text string) using lance;", fullTable)); + // First insert to create initial fragments + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(0, 10) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + // Second insert to ensure multiple fragments + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(10, 20) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + } + + @Test + public void testUpdate() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)).collectAsList(); + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(20, 30) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + spark.sql( + String.format("update %s__branch__branch_0 set text=concat('new_text_',id);", fullTable)); + + List rows = + spark.sql(String.format("select * from %s__branch__branch_0", fullTable)).collectAsList(); + for (Row row : rows) { + Assertions.assertEquals("new_text_" + row.getInt(0), row.getString(1)); + } + } + + @Test + public void testDelete() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)).collectAsList(); + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(20, 30) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + spark.sql(String.format("delete from %s__branch__branch_0 where id >= 10", fullTable)); + + List rows = + spark.sql(String.format("select * 
from %s__branch__branch_0", fullTable)).collectAsList(); + Assertions.assertEquals(10, rows.size()); + } + + @Test + public void testMergeInto() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)).collectAsList(); + + spark.sql( + String.format( + "create temporary view v as select id, concat('new_text_',id) as text from %s", + fullTable)); + + spark.sql( + String.format( + "merge into %s__branch__branch_0 as target " + + "using v as source on target.id = source.id " + + "when matched then update set target.text = source.text;", + fullTable)); + + List rows = + spark.sql(String.format("select * from %s__branch__branch_0", fullTable)).collectAsList(); + Assertions.assertEquals(20, rows.size()); + for (Row row : rows) { + Assertions.assertEquals("new_text_" + row.getInt(0), row.getString(1)); + } + } +} diff --git a/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala index 9e5d64780..fdf59ebba 100644 --- a/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala +++ b/lance-spark-4.0_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala @@ -15,7 +15,7 @@ package org.apache.spark.sql.catalyst.parser.extensions import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowIndexes, UpdateColumnsBackfill, Vacuum} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, CreateBranch, CreateTag, DropBranch, DropTag, LanceDropIndex, LanceNamedArgument, 
LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowBranches, ShowIndexes, ShowTags, UpdateColumnsBackfill, Vacuum} import org.lance.spark.utils.ParserUtils import scala.jdk.CollectionConverters._ @@ -105,6 +105,42 @@ class LanceSqlExtensionsAstBuilder(delegate: ParserInterface) LanceDropIndex(table, indexName) } + override def visitShowBranches(ctx: LanceSqlExtensionsParser.ShowBranchesContext) + : ShowBranches = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowBranches(table) + } + + override def visitCreateBranch(ctx: LanceSqlExtensionsParser.CreateBranchContext) + : CreateBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + CreateBranch(table, branchName) + } + + override def visitDropBranch(ctx: LanceSqlExtensionsParser.DropBranchContext): DropBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + DropBranch(table, branchName) + } + + override def visitShowTags(ctx: LanceSqlExtensionsParser.ShowTagsContext): ShowTags = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowTags(table) + } + + override def visitCreateTag(ctx: LanceSqlExtensionsParser.CreateTagContext): CreateTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + CreateTag(table, tagName) + } + + override def visitDropTag(ctx: LanceSqlExtensionsParser.DropTagContext): DropTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + DropTag(table, tagName) + } + override def visitSetUnenforcedPrimaryKey( ctx: LanceSqlExtensionsParser.SetUnenforcedPrimaryKeyContext): SetUnenforcedPrimaryKey = { val table = 
UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) diff --git a/lance-spark-4.1_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala b/lance-spark-4.1_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala index 9e5d64780..fdf59ebba 100644 --- a/lance-spark-4.1_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala +++ b/lance-spark-4.1_2.13/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensionsAstBuilder.scala @@ -15,7 +15,7 @@ package org.apache.spark.sql.catalyst.parser.extensions import org.apache.spark.sql.catalyst.analysis.{UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowIndexes, UpdateColumnsBackfill, Vacuum} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumnsBackfill, AddIndex, CreateBranch, CreateTag, DropBranch, DropTag, LanceDropIndex, LanceNamedArgument, LogicalPlan, Optimize, SetUnenforcedPrimaryKey, ShowBranches, ShowIndexes, ShowTags, UpdateColumnsBackfill, Vacuum} import org.lance.spark.utils.ParserUtils import scala.jdk.CollectionConverters._ @@ -105,6 +105,42 @@ class LanceSqlExtensionsAstBuilder(delegate: ParserInterface) LanceDropIndex(table, indexName) } + override def visitShowBranches(ctx: LanceSqlExtensionsParser.ShowBranchesContext) + : ShowBranches = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowBranches(table) + } + + override def visitCreateBranch(ctx: LanceSqlExtensionsParser.CreateBranchContext) + : CreateBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + 
CreateBranch(table, branchName) + } + + override def visitDropBranch(ctx: LanceSqlExtensionsParser.DropBranchContext): DropBranch = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val branchName = cleanIdentifier(ctx.branchName.getText) + DropBranch(table, branchName) + } + + override def visitShowTags(ctx: LanceSqlExtensionsParser.ShowTagsContext): ShowTags = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + ShowTags(table) + } + + override def visitCreateTag(ctx: LanceSqlExtensionsParser.CreateTagContext): CreateTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + CreateTag(table, tagName) + } + + override def visitDropTag(ctx: LanceSqlExtensionsParser.DropTagContext): DropTag = { + val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tagName = cleanIdentifier(ctx.tagName.getText) + DropTag(table, tagName) + } + override def visitSetUnenforcedPrimaryKey( ctx: LanceSqlExtensionsParser.SetUnenforcedPrimaryKeyContext): SetUnenforcedPrimaryKey = { val table = UnresolvedIdentifier(visitMultipartIdentifier(ctx.multipartIdentifier())) diff --git a/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4 b/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4 index 9fa6299e6..25badfa3a 100644 --- a/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4 +++ b/lance-spark-base_2.12/src/main/antlr4/org/apache/spark/sql/catalyst/parser/extensions/LanceSqlExtensions.g4 @@ -23,7 +23,13 @@ statement | ALTER TABLE multipartIdentifier UPDATE COLUMNS columnList FROM identifier #updateColumnsBackfill | ALTER TABLE multipartIdentifier CREATE INDEX indexName=identifier USING method=identifier '(' columnList 
')' (WITH '(' (namedArgument (',' namedArgument)*)? ')')? #createIndex | ALTER TABLE multipartIdentifier DROP INDEX indexName=identifier #dropIndex + | ALTER TABLE multipartIdentifier CREATE BRANCH branchName=identifier #createBranch + | ALTER TABLE multipartIdentifier DROP BRANCH branchName=identifier #dropBranch + | ALTER TABLE multipartIdentifier CREATE TAG tagName=identifier #createTag + | ALTER TABLE multipartIdentifier DROP TAG tagName=identifier #dropTag | SHOW (INDEXES | INDEX) (FROM | IN) multipartIdentifier #showIndexes + | SHOW BRANCHES (FROM | IN) multipartIdentifier #showBranches + | SHOW TAGS (FROM | IN) multipartIdentifier #showTags | OPTIMIZE multipartIdentifier (WITH '(' (namedArgument (',' namedArgument)*)? ')')? #optimize | VACUUM multipartIdentifier (WITH '(' (namedArgument (',' namedArgument)*)? ')')? #vacuum | ALTER TABLE multipartIdentifier SET UNENFORCED PRIMARY KEY '(' columnList ')' #setUnenforcedPrimaryKey @@ -68,6 +74,8 @@ number ADD: 'ADD'; ALTER: 'ALTER'; +BRANCH: 'BRANCH'; +BRANCHES: 'BRANCHES'; COLUMNS: 'COLUMNS'; CREATE: 'CREATE'; DROP: 'DROP'; @@ -81,6 +89,8 @@ PRIMARY: 'PRIMARY'; SET: 'SET'; SHOW: 'SHOW'; TABLE: 'TABLE'; +TAG: 'TAG'; +TAGS: 'TAGS'; UNENFORCED: 'UNENFORCED'; UPDATE: 'UPDATE'; USING: 'USING'; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java index 73bddc9ac..9425de8e9 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java @@ -1454,6 +1454,16 @@ private Table loadTableInternal( return loadTableFromPath(ident, timestamp, version); } + // Handle branch-based access + if (isBranchBasedIdentifier(ident)) { + return loadTableAtBranch(ident, timestamp, version); + } + + // Handle tag-based access + if (isTagBasedIdentifier(ident)) { + return 
loadTableAtTag(ident, timestamp, version);
+    }
+
     ResolvedTable resolved = resolveIdentifier(ident);
     DescribeTableResponse describeResponse = resolved.describeResponse;
     Map initialStorageOptions = describeResponse.getStorageOptions();
@@ -1508,6 +1518,187 @@ private Table loadTableInternal(
         null);
   }
 
+  /** Marker between a table name and a branch name, as in {@code tbl__branch__branch_0}. */
+  private static final String BRANCH_KEY = "__branch__";
+
+  /**
+   * Returns the index of the first case-insensitive occurrence of {@code key} in {@code text},
+   * or -1 when absent. The search runs on the original string, so the index is always valid
+   * for {@code text.substring(...)}.
+   */
+  private static int indexOfIgnoreCase(String text, String key) {
+    for (int i = 0; i + key.length() <= text.length(); i++) {
+      if (text.regionMatches(true, i, key, 0, key.length())) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  /** True when the name holds the branch marker (any case) after a non-empty table name. */
+  private static boolean isBranchBasedIdentifier(Identifier identifier) {
+    return indexOfIgnoreCase(identifier.name(), BRANCH_KEY) > 0;
+  }
+
+  /**
+   * Loads the table named before the branch marker, passing the branch name that follows the
+   * marker to the read options. A timestamp or version, when present, is resolved to a
+   * version id that is applied to the returned table's read options.
+   */
+  private Table loadTableAtBranch(
+      Identifier ident, Optional timestamp, Optional version)
+      throws NoSuchTableException {
+    // Use the same case-insensitive match as isBranchBasedIdentifier. A case-sensitive indexOf
+    // returned -1 for names such as "t__BRANCH__b", which made substring(0, pos) throw.
+    int pos = indexOfIgnoreCase(ident.name(), BRANCH_KEY);
+    String actualName = ident.name().substring(0, pos);
+    String branchName = ident.name().substring(pos + BRANCH_KEY.length());
+    if (branchName == null || branchName.isBlank()) {
+      throw new IllegalArgumentException("No specified branch name in:" + ident.name());
+    }
+
+    ident = Identifier.of(ident.namespace(), actualName);
+    Identifier actualIdent = transformIdentifierForApi(ident);
+    List tableIdList = buildTableId(actualIdent);
+    DescribeTableRequest describeRequest = new DescribeTableRequest();
+    tableIdList.forEach(describeRequest::addIdItem);
+    DescribeTableResponse describeResponse = describeTableOrThrow(describeRequest, ident);
+    String location = describeResponse.getLocation();
+    LanceSparkReadOptions readOptions =
+        createReadOptions(
+            location,
+            catalogConfig,
+            Optional.empty(),
+            Optional.of(namespace),
+            Optional.of(tableIdList),
+            name,
+            Optional.of(branchName),
+            Optional.empty());
+
+    Map initialStorageOptions = describeResponse.getStorageOptions();
+
+    Optional versionId = Optional.empty();
+    if (timestamp.isPresent()) {
+      try (Dataset dataset = Utils.openDatasetBuilder(readOptions).build()) {
+        versionId = Optional.of(Utils.findVersion(dataset.listVersions(), timestamp.get()));
+      } catch (TableNotFoundException e) {
+        throw new NoSuchTableException(ident);
+      }
+    } else if (version.isPresent()) {
+      versionId = Optional.of(Utils.parseVersion(version.get()));
+    }
+
+    // If time travel requested, rebuild readOptions with the resolved version
+    LanceSparkReadOptions branchReadOptions;
+    if (versionId.isPresent()) {
+      branchReadOptions =
+          createReadOptions(
+              describeResponse.getLocation(),
+              catalogConfig,
+              versionId,
+              Optional.of(namespace),
+              Optional.of(tableIdList),
+              name,
+              Optional.of(branchName),
+              Optional.empty());
+    } else {
+      branchReadOptions = readOptions;
+    }
+
+    // Read schema, file format version, and config from the dataset. Open it with
+    // branchReadOptions so a time-travel load reports the schema of the requested version.
+    String fileFormatVersion;
+    StructType schema;
+    Map tableProperties;
+    try (Dataset dataset = Utils.openDatasetBuilder(branchReadOptions).build()) {
+      schema = LanceArrowUtils.fromArrowSchema(dataset.getSchema());
+      fileFormatVersion = dataset.getLanceFileFormatVersion();
+      tableProperties = dataset.getConfig();
+    }
+
+    // A missing (null) managedVersioning flag is treated as false
+    boolean managedVersioning = Boolean.TRUE.equals(describeResponse.getManagedVersioning());
+    return createDataset(
+        branchReadOptions,
+        schema,
+        initialStorageOptions,
+        namespaceImpl,
+        namespaceProperties,
+        managedVersioning,
+        fileFormatVersion,
+        tableProperties,
+        null);
+  }
+
+  /** Marker between a table name and a tag name. */
+  private static final String TAG_KEY = "__tag__";
+
+  /** True when the name holds the tag marker (any case) after a non-empty table name. */
+  private static boolean isTagBasedIdentifier(Identifier identifier) {
+    return indexOfIgnoreCase(identifier.name(), TAG_KEY) > 0;
+  }
+
+  /**
+   * Loads the table named before the tag marker, passing the tag name that follows the marker
+   * to the read options.
+   *
+   * <p>NOTE(review): {@code timestamp} and {@code version} are accepted but never read here;
+   * confirm that time travel on a tag is meant to be ignored rather than rejected.
+   */
+  private Table loadTableAtTag(Identifier ident, Optional timestamp, Optional version)
+      throws NoSuchTableException {
+    // Same case-insensitive match as isTagBasedIdentifier, so pos is valid for substring().
+    int pos = indexOfIgnoreCase(ident.name(), TAG_KEY);
+    String actualName = ident.name().substring(0, pos);
+    String tagName = ident.name().substring(pos + TAG_KEY.length());
+    if (tagName == null || tagName.isBlank()) {
+      throw new IllegalArgumentException("No specified tag name in:" + ident.name());
+    }
+
+    ident = Identifier.of(ident.namespace(), actualName);
+    Identifier actualIdent = transformIdentifierForApi(ident);
+    List tableIdList = buildTableId(actualIdent);
+    DescribeTableRequest describeRequest = new DescribeTableRequest();
+    tableIdList.forEach(describeRequest::addIdItem);
+    DescribeTableResponse describeResponse = describeTableOrThrow(describeRequest, ident);
+    String location = describeResponse.getLocation();
+    LanceSparkReadOptions tagReadOptions =
+        createReadOptions(
+            location,
+            catalogConfig,
+            Optional.empty(),
+            Optional.of(namespace),
+            Optional.of(tableIdList),
+            name,
+            Optional.empty(),
+            Optional.of(tagName));
+
+    Map initialStorageOptions = describeResponse.getStorageOptions();
+
+    // Read schema, file format version, and config from the dataset
+    String fileFormatVersion;
+    StructType schema;
+    Map tableProperties;
+    try (Dataset dataset = Utils.openDatasetBuilder(tagReadOptions).build()) {
+      schema = LanceArrowUtils.fromArrowSchema(dataset.getSchema());
+      fileFormatVersion = dataset.getLanceFileFormatVersion();
+      tableProperties = dataset.getConfig();
+    }
+
+    // A missing (null) managedVersioning flag is treated as false
+    boolean managedVersioning = Boolean.TRUE.equals(describeResponse.getManagedVersioning());
+    return createDataset(
+        tagReadOptions,
+        schema,
+        initialStorageOptions,
+        namespaceImpl,
+        namespaceProperties,
+        managedVersioning,
+        fileFormatVersion,
+        tableProperties,
+        null);
+  }
+
   /**
    * Calls namespace.describeTable and translates table-not-found errors into Spark's {@link
    * NoSuchTableException}.
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java index c3e817dc9..0379a6120 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java @@ -336,6 +336,7 @@ public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { .datasetUri(readOptions.getDatasetUri()) .namespace(readOptions.getNamespace()) .tableId(readOptions.getTableId()) + .branchName(readOptions.getBranchName()) .fromOptions(mergedOptions); // Use table's file format version if not explicitly set in write options if (!mergedOptions.containsKey(LanceSparkWriteOptions.CONFIG_FILE_FORMAT_VERSION) diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java index 3cb835d42..2e4074828 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java @@ -16,6 +16,7 @@ import org.lance.ReadOptions; import org.lance.ipc.Query; import org.lance.namespace.LanceNamespace; +import org.lance.spark.utils.Optional; import org.lance.spark.utils.QueryUtils; import com.google.common.base.Preconditions; @@ -105,6 +106,8 @@ public class LanceSparkReadOptions implements Serializable { private final String datasetUri; private final String dbPath; private final String datasetName; + private final Optional branchName; + private final Optional tagName; private final boolean pushDownFilters; private final Integer blockSize; private final Integer version; @@ -135,6 +138,8 @@ private LanceSparkReadOptions(Builder builder) { String[] paths = extractDbPathAndDatasetName(datasetUri); this.dbPath = paths[0]; this.datasetName = paths[1]; + this.branchName = builder.branchName; + this.tagName = 
builder.tagName; this.pushDownFilters = builder.pushDownFilters; this.blockSize = builder.blockSize; this.version = builder.version; @@ -230,6 +235,14 @@ public String getDatasetName() { return datasetName; } + public Optional getBranchName() { + return branchName; + } + + public Optional getTagName() { + return tagName; + } + public boolean isPushDownFilters() { return pushDownFilters; } @@ -326,6 +339,8 @@ public LanceSparkReadOptions withVersion(int newVersion) { .storageOptions(this.storageOptions) .namespace(this.namespace) .tableId(this.tableId) + .branchName(this.branchName) + .tagName(this.tagName) .catalogName(this.catalogName) .executorCredentialRefresh(this.executorCredentialRefresh) .build(); @@ -408,6 +423,8 @@ public int hashCode() { /** Builder for creating LanceSparkReadOptions instances. */ public static class Builder { private String datasetUri; + private Optional branchName = Optional.empty(); + private Optional tagName = Optional.empty(); private boolean pushDownFilters = DEFAULT_PUSH_DOWN_FILTERS; private Integer blockSize; private Query nearest; @@ -429,6 +446,16 @@ public Builder datasetUri(String datasetUri) { return this; } + public Builder branchName(Optional branchName) { + this.branchName = branchName; + return this; + } + + public Builder tagName(Optional tagName) { + this.tagName = tagName; + return this; + } + public Builder pushDownFilters(boolean pushDownFilters) { this.pushDownFilters = pushDownFilters; return this; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java index 0e0f6db06..d64c95870 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java @@ -16,6 +16,7 @@ import org.lance.WriteParams; import org.lance.WriteParams.WriteMode; import org.lance.namespace.LanceNamespace; +import 
org.lance.spark.utils.Optional; import com.google.common.base.Preconditions; @@ -70,6 +71,7 @@ public class LanceSparkWriteOptions implements Serializable { public static final long DEFAULT_MAX_BATCH_BYTES = 256L * 1024 * 1024; private final String datasetUri; + private final Optional branchName; private final WriteMode writeMode; private final Integer maxRowsPerFile; private final Integer maxRowsPerGroup; @@ -98,6 +100,7 @@ public class LanceSparkWriteOptions implements Serializable { private LanceSparkWriteOptions(Builder builder) { this.datasetUri = builder.datasetUri; + this.branchName = builder.branchName; this.writeMode = builder.writeMode; this.maxRowsPerFile = builder.maxRowsPerFile; this.maxRowsPerGroup = builder.maxRowsPerGroup; @@ -148,6 +151,17 @@ public String getDatasetUri() { return datasetUri; } + public String getActualDatasetUri() { + if (branchName.isEmpty()) { + return getDatasetUri(); + } + return getDatasetUri() + "/tree/" + branchName.get(); + } + + public Optional getBranchName() { + return branchName; + } + public WriteMode getWriteMode() { return writeMode; } @@ -218,6 +232,7 @@ public Long getVersion() { public Builder toBuilder() { return builder() .datasetUri(datasetUri) + .branchName(branchName) .writeMode(writeMode) .maxRowsPerFile(maxRowsPerFile) .maxRowsPerGroup(maxRowsPerGroup) @@ -352,6 +367,7 @@ public int hashCode() { /** Builder for creating LanceSparkWriteOptions instances. 
*/ public static class Builder { private String datasetUri; + private Optional branchName = Optional.empty(); private WriteMode writeMode = DEFAULT_WRITE_MODE; private Integer maxRowsPerFile; private Integer maxRowsPerGroup; @@ -376,6 +392,11 @@ public Builder datasetUri(String datasetUri) { return this; } + public Builder branchName(Optional branchName) { + this.branchName = branchName; + return this; + } + public Builder writeMode(WriteMode writeMode) { this.writeMode = writeMode; return this; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/Utils.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/Utils.java index eb019c646..37dd9ffa9 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/Utils.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/Utils.java @@ -15,6 +15,7 @@ import org.lance.Dataset; import org.lance.ReadOptions; +import org.lance.Ref; import org.lance.Version; import org.lance.namespace.LanceNamespace; import org.lance.spark.LanceRuntime; @@ -70,6 +71,8 @@ public static class OpenDatasetBuilder { private final String uri; private final LanceNamespace namespace; private final List tableId; + private final Optional branchName; + private final Optional tagName; private final Map storageOptions; private final String catalogName; private final Long version; @@ -89,6 +92,8 @@ private OpenDatasetBuilder(LanceSparkReadOptions opts) { this.catalogName = opts.getCatalogName(); this.namespace = opts.getNamespace(); this.tableId = opts.getTableId(); + this.branchName = opts.getBranchName(); + this.tagName = opts.getTagName(); this.blockSize = opts.getBlockSize(); this.indexCacheSize = opts.getIndexCacheSize(); this.metadataCacheSize = opts.getMetadataCacheSize(); @@ -99,6 +104,8 @@ private OpenDatasetBuilder(LanceSparkWriteOptions opts) { this.storageOptions = opts.getStorageOptions(); this.namespace = opts.getNamespace(); this.tableId = opts.getTableId(); + this.branchName = 
opts.getBranchName(); + this.tagName = Optional.empty(); this.catalogName = null; this.version = opts.getVersion(); this.blockSize = null; @@ -130,7 +137,7 @@ public Dataset build() { .setStorageOptions(merged) .setSession( catalogName != null ? LanceRuntime.session(catalogName) : LanceRuntime.session()); - if (version != null) { + if (version != null && branchName.isEmpty()) { roBuilder.setVersion(version); } if (blockSize != null) { @@ -143,32 +150,56 @@ public Dataset build() { roBuilder.setMetadataCacheSize(metadataCacheSize); } + Dataset ds = null; if (namespace != null && tableId != null) { - return Dataset.open() - .allocator(LanceRuntime.allocator()) - .namespaceClient(namespace) - .tableId(tableId) - .readOptions(roBuilder.build()) - .build(); - } - if (runtimeNamespaceImpl != null) { + ds = + Dataset.open() + .allocator(LanceRuntime.allocator()) + .namespaceClient(namespace) + .tableId(tableId) + .readOptions(roBuilder.build()) + .build(); + } else if (runtimeNamespaceImpl != null) { LanceNamespace runtimeNamespace = LanceRuntime.getOrCreateNamespace(runtimeNamespaceImpl, runtimeNamespaceProperties); List effectiveTableId = runtimeTableId != null ? 
runtimeTableId : tableId; if (runtimeNamespace != null && effectiveTableId != null) { - return Dataset.open() - .allocator(LanceRuntime.allocator()) - .namespaceClient(runtimeNamespace) - .tableId(effectiveTableId) - .readOptions(roBuilder.build()) - .build(); + ds = + Dataset.open() + .allocator(LanceRuntime.allocator()) + .namespaceClient(runtimeNamespace) + .tableId(effectiveTableId) + .readOptions(roBuilder.build()) + .build(); + } + } + if (ds == null) { + ds = + Dataset.open() + .allocator(LanceRuntime.allocator()) + .uri(uri) + .readOptions(roBuilder.build()) + .build(); + } + + Ref ref = null; + if (branchName.isPresent()) { + if (version != null) { + ref = Ref.ofBranch(branchName.get(), version); + } else { + ref = Ref.ofBranch(branchName.get()); } + } else if (tagName.isPresent()) { + ref = Ref.ofTag(tagName.get()); } - return Dataset.open() - .allocator(LanceRuntime.allocator()) - .uri(uri) - .readOptions(roBuilder.build()) - .build(); + + if (ref != null) { + Dataset newDs = ds.checkout(ref); + ds.close(); + return newDs; + } + + return ds; } } @@ -209,6 +240,49 @@ public static LanceSparkReadOptions createReadOptions( return builder.build(); } + /** + * Creates LanceSparkReadOptions for this catalog. 
+ * + * @param location the dataset URI + * @param catalogConfig catalog configuration + * @param versionId optional dataset version id + * @param namespace optional namespace for credential vending + * @param tableId optional table identifier + * @param catalogName catalog name for cache isolation + * @param branchName branch name + * @param tagName tag name + * @return a new LanceSparkReadOptions with catalog settings + */ + public static LanceSparkReadOptions createReadOptions( + String location, + LanceSparkCatalogConfig catalogConfig, + Optional versionId, + Optional namespace, + Optional> tableId, + String catalogName, + Optional branchName, + Optional tagName) { + LanceSparkReadOptions.Builder builder = + LanceSparkReadOptions.builder() + .datasetUri(location) + .withCatalogDefaults(catalogConfig) + .catalogName(catalogName); + + if (versionId.isPresent()) { + builder.version(versionId.get().intValue()); + } + if (tableId.isPresent()) { + builder.tableId(tableId.get()); + } + if (namespace.isPresent()) { + builder.namespace(namespace.get()); + } + builder.branchName(branchName); + builder.tagName(tagName); + + return builder.build(); + } + // Determine if the timestamp is in microseconds or nanoseconds and convert to Instant private static Instant instantFromTimestamp(long timestamp) { if (timestamp <= 0) { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java index d515ef354..9a58fef9b 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java @@ -305,7 +305,7 @@ private BufferAndTask buildBufferAndTask(BlobReferenceResolver resolver) { try (ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(LanceRuntime.allocator())) { Data.exportArrayStream(LanceRuntime.allocator(), bufferRef, arrowStream); - return 
Fragment.create(writeOptions.getDatasetUri(), arrowStream, params); + return Fragment.create(writeOptions.getActualDatasetUri(), arrowStream, params); } }; FutureTask> task = writeBuffer.createTrackedTask(fragmentCreator); diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java index 7e5f6bd30..007e561b4 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java @@ -246,6 +246,7 @@ public Write build() { .storageOptions(writeOptions.getStorageOptions()) .namespace(writeOptions.getNamespace()) .tableId(writeOptions.getTableId()) + .branchName(writeOptions.getBranchName()) .batchSize(writeOptions.getBatchSize()) .datasetUri(writeOptions.getDatasetUri()) .fileFormatVersion(writeOptions.getFileFormatVersion()) diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala new file mode 100755 index 000000000..1a71bdd76 --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

/**
 * Logical command carrying a request to create a branch on a table.
 *
 * @param table      plan node identifying the target table; it is this command's only child
 * @param branchName name of the branch to create
 */
case class CreateBranch(table: LogicalPlan, branchName: String) extends Command {

  override def children: Seq[LogicalPlan] = Seq(table)

  override def output: Seq[Attribute] = CreateBranchOutputType.SCHEMA

  override def simpleString(maxFields: Int): String = s"CreateBranch($branchName)"

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): CreateBranch =
    copy(table = newChildren(0))
}

/** Output attributes of [[CreateBranch]]: a single non-nullable string column `name`. */
object CreateBranchOutputType {
  private val fields = Array(StructField("name", DataTypes.StringType, nullable = false))

  val SCHEMA: Seq[AttributeReference] =
    StructType(fields).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

/**
 * Logical command carrying a request to create a tag on a table.
 *
 * @param table   plan node identifying the target table; it is this command's only child
 * @param tagName name of the tag to create
 */
case class CreateTag(table: LogicalPlan, tagName: String) extends Command {

  override def children: Seq[LogicalPlan] = Seq(table)

  override def output: Seq[Attribute] = CreateTagOutputType.SCHEMA

  override def simpleString(maxFields: Int): String = s"CreateTag($tagName)"

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): CreateTag =
    copy(table = newChildren(0))
}

/** Output attributes of [[CreateTag]]: a single non-nullable string column `name`. */
object CreateTagOutputType {
  private val fields = Array(StructField("name", DataTypes.StringType, nullable = false))

  val SCHEMA: Seq[AttributeReference] =
    StructType(fields).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

/**
 * Logical command carrying a request to drop a branch of a table.
 *
 * @param table      plan node identifying the target table; it is this command's only child
 * @param branchName name of the branch to drop
 */
case class DropBranch(table: LogicalPlan, branchName: String) extends Command {

  override def children: Seq[LogicalPlan] = Seq(table)

  override def output: Seq[Attribute] = DropBranchOutputType.SCHEMA

  override def simpleString(maxFields: Int): String = s"DropBranch($branchName)"

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): DropBranch =
    copy(table = newChildren(0))
}

/** Output attributes of [[DropBranch]]: a single non-nullable string column `name`. */
object DropBranchOutputType {
  private val fields = Array(StructField("name", DataTypes.StringType, nullable = false))

  val SCHEMA: Seq[AttributeReference] =
    StructType(fields).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

/**
 * Logical command carrying a request to drop a tag of a table.
 *
 * @param table   plan node identifying the target table; it is this command's only child
 * @param tagName name of the tag to drop
 */
case class DropTag(table: LogicalPlan, tagName: String) extends Command {

  override def children: Seq[LogicalPlan] = Seq(table)

  override def output: Seq[Attribute] = DropTagOutputType.SCHEMA

  override def simpleString(maxFields: Int): String = s"DropTag($tagName)"

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): DropTag =
    copy(table = newChildren(0))
}

/** Output attributes of [[DropTag]]: a single non-nullable string column `name`. */
object DropTagOutputType {
  private val fields = Array(StructField("name", DataTypes.StringType, nullable = false))

  val SCHEMA: Seq[AttributeReference] =
    StructType(fields).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

/**
 * Logical command carrying a request to list the branches of a table.
 *
 * @param table plan node identifying the target table; it is this command's only child
 */
case class ShowBranches(table: LogicalPlan) extends Command {

  override def children: Seq[LogicalPlan] = Seq(table)

  override def output: Seq[Attribute] = ShowBranchesOutputType.SCHEMA

  override def simpleString(maxFields: Int): String = "ShowBranches"

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): ShowBranches =
    copy(table = newChildren(0))
}

/**
 * Output attributes of [[ShowBranches]], in column order: `name`, `parent_branch`,
 * `parent_version`, `create_at`, `manifest_size`. Only `parent_branch` is nullable.
 */
object ShowBranchesOutputType {
  private val fields = Array(
    StructField("name", DataTypes.StringType, nullable = false),
    StructField("parent_branch", DataTypes.StringType, nullable = true),
    StructField("parent_version", DataTypes.LongType, nullable = false),
    StructField("create_at", DataTypes.StringType, nullable = false),
    StructField("manifest_size", DataTypes.IntegerType, nullable = false))

  val SCHEMA: Seq[AttributeReference] =
    StructType(fields).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

/**
 * Logical command carrying a request to list the tags of a table.
 *
 * @param table plan node identifying the target table; it is this command's only child
 */
case class ShowTags(table: LogicalPlan) extends Command {

  override def children: Seq[LogicalPlan] = Seq(table)

  override def output: Seq[Attribute] = ShowTagsOutputType.SCHEMA

  override def simpleString(maxFields: Int): String = "ShowTags"

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[LogicalPlan]): ShowTags =
    copy(table = newChildren(0))
}

/**
 * Output attributes of [[ShowTags]], in column order: `name`, `version`, `manifest_size`.
 * All three columns are non-nullable.
 */
object ShowTagsOutputType {
  private val fields = Array(
    StructField("name", DataTypes.StringType, nullable = false),
    StructField("version", DataTypes.LongType, nullable = false),
    StructField("manifest_size", DataTypes.IntegerType, nullable = false))

  val SCHEMA: Seq[AttributeReference] =
    StructType(fields).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow}
import org.apache.spark.sql.catalyst.plans.logical.CreateBranchOutputType
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.unsafe.types.UTF8String
import org.lance.Ref
import org.lance.spark.LanceDataset
import org.lance.spark.utils.Utils

/**
 * Physical command that creates a branch on a Lance table.
 *
 * The branch is created from `Ref.ofMain` at the version of the dataset opened with the table's
 * read options. The command returns one row holding the branch name.
 *
 * @param catalog    catalog used to load the table
 * @param ident      identifier of the table to branch
 * @param branchName name of the branch to create
 */
case class CreateBranchExec(
    catalog: TableCatalog,
    ident: Identifier,
    branchName: String) extends LeafV2CommandExec {

  override def output: Seq[Attribute] = CreateBranchOutputType.SCHEMA

  override protected def run(): Seq[InternalRow] = {
    val source = Utils.openDatasetBuilder(lanceTable.readOptions()).build()
    try {
      // createBranch hands back a closeable result that is not needed here; release it at once.
      source.createBranch(branchName, Ref.ofMain(source.version())).close()
    } finally {
      source.close()
    }

    val row = new GenericInternalRow(Array[Any](UTF8String.fromString(branchName)))
    Seq(row)
  }

  /** Loads the table from the catalog and rejects anything that is not a [[LanceDataset]]. */
  private def lanceTable: LanceDataset = catalog.loadTable(ident) match {
    case lance: LanceDataset => lance
    case _ => throw new UnsupportedOperationException("CreateBranch only supports LanceDataset")
  }
}
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow}
import org.apache.spark.sql.catalyst.plans.logical.CreateTagOutputType
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.unsafe.types.UTF8String
import org.lance.spark.LanceDataset
import org.lance.spark.utils.Utils

/**
 * Physical command that creates a tag on a Lance table.
 *
 * The tag is created for the version of the dataset opened with the table's read options. The
 * command returns one row holding the tag name.
 *
 * @param catalog catalog used to load the table
 * @param ident   identifier of the table to tag
 * @param tagName name of the tag to create
 */
case class CreateTagExec(
    catalog: TableCatalog,
    ident: Identifier,
    tagName: String) extends LeafV2CommandExec {

  override def output: Seq[Attribute] = CreateTagOutputType.SCHEMA

  override protected def run(): Seq[InternalRow] = {
    val source = Utils.openDatasetBuilder(lanceTable.readOptions()).build()
    try {
      source.tags().create(tagName, source.version())
    } finally {
      source.close()
    }

    val row = new GenericInternalRow(Array[Any](UTF8String.fromString(tagName)))
    Seq(row)
  }

  /** Loads the table from the catalog and rejects anything that is not a [[LanceDataset]]. */
  private def lanceTable: LanceDataset = catalog.loadTable(ident) match {
    case lance: LanceDataset => lance
    case _ => throw new UnsupportedOperationException("CreateTag only supports LanceDataset")
  }
}
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow}
import org.apache.spark.sql.catalyst.plans.logical.DropBranchOutputType
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.unsafe.types.UTF8String
import org.lance.spark.LanceDataset
import org.lance.spark.utils.Utils

/**
 * Physical command that deletes a branch of a Lance table.
 *
 * The dataset is opened with the table's read options and the branch is deleted through its
 * `branches()` handle. The command returns one row holding the branch name.
 *
 * @param catalog    catalog used to load the table
 * @param ident      identifier of the table that owns the branch
 * @param branchName name of the branch to delete
 */
case class DropBranchExec(
    catalog: TableCatalog,
    ident: Identifier,
    branchName: String) extends LeafV2CommandExec {

  override def output: Seq[Attribute] = DropBranchOutputType.SCHEMA

  override protected def run(): Seq[InternalRow] = {
    val source = Utils.openDatasetBuilder(lanceTable.readOptions()).build()
    try {
      source.branches().delete(branchName)
    } finally {
      source.close()
    }

    val row = new GenericInternalRow(Array[Any](UTF8String.fromString(branchName)))
    Seq(row)
  }

  /** Loads the table from the catalog and rejects anything that is not a [[LanceDataset]]. */
  private def lanceTable: LanceDataset = catalog.loadTable(ident) match {
    case lance: LanceDataset => lance
    case _ => throw new UnsupportedOperationException("DropBranch only supports LanceDataset")
  }
}
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow}
import org.apache.spark.sql.catalyst.plans.logical.DropTagOutputType
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.unsafe.types.UTF8String
import org.lance.spark.LanceDataset
import org.lance.spark.utils.Utils

/**
 * Physical command that deletes a tag of a Lance table.
 *
 * The dataset is opened with the table's read options and the tag is deleted through its
 * `tags()` handle. The command returns one row holding the tag name.
 *
 * @param catalog catalog used to load the table
 * @param ident   identifier of the table that owns the tag
 * @param tagName name of the tag to delete
 */
case class DropTagExec(
    catalog: TableCatalog,
    ident: Identifier,
    tagName: String) extends LeafV2CommandExec {

  override def output: Seq[Attribute] = DropTagOutputType.SCHEMA

  override protected def run(): Seq[InternalRow] = {
    val source = Utils.openDatasetBuilder(lanceTable.readOptions()).build()
    try {
      source.tags().delete(tagName)
    } finally {
      source.close()
    }

    val row = new GenericInternalRow(Array[Any](UTF8String.fromString(tagName)))
    Seq(row)
  }

  /** Loads the table from the catalog and rejects anything that is not a [[LanceDataset]]. */
  private def lanceTable: LanceDataset = catalog.loadTable(ident) match {
    case lance: LanceDataset => lance
    case _ => throw new UnsupportedOperationException("DropTag only supports LanceDataset")
  }
}
b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/LanceDataSourceV2Strategy.scala index fb0556824..33e00d657 100644 --- a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/LanceDataSourceV2Strategy.scala +++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/LanceDataSourceV2Strategy.scala @@ -51,6 +51,24 @@ case class LanceDataSourceV2Strategy(session: SparkSession) extends SparkStrateg case LanceDropIndex(ResolvedIdentifier(catalog, ident), indexName) => LanceDropIndexExec(asTableCatalog(catalog), ident, indexName.toLowerCase) :: Nil + case CreateBranch(ResolvedIdentifier(catalog, ident), branchName) => + CreateBranchExec(asTableCatalog(catalog), ident, branchName) :: Nil + + case DropBranch(ResolvedIdentifier(catalog, ident), branchName) => + DropBranchExec(asTableCatalog(catalog), ident, branchName) :: Nil + + case CreateTag(ResolvedIdentifier(catalog, ident), tagName) => + CreateTagExec(asTableCatalog(catalog), ident, tagName) :: Nil + + case DropTag(ResolvedIdentifier(catalog, ident), tagName) => + DropTagExec(asTableCatalog(catalog), ident, tagName) :: Nil + + case ShowBranches(ResolvedIdentifier(catalog, ident)) => + ShowBranchesExec(asTableCatalog(catalog), ident) :: Nil + + case ShowTags(ResolvedIdentifier(catalog, ident)) => + ShowTagsExec(asTableCatalog(catalog), ident) :: Nil + case SetUnenforcedPrimaryKey(ResolvedIdentifier(catalog, ident), columns) => SetUnenforcedPrimaryKeyExec(asTableCatalog(catalog), ident, columns) :: Nil diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowBranchesExec.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowBranchesExec.scala new file mode 100755 index 000000000..88e8dc2c0 --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowBranchesExec.scala @@ -0,0 +1,62 @@ +/* + * 
Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow}
+import org.apache.spark.sql.catalyst.plans.logical.ShowBranchesOutputType
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+import org.apache.spark.unsafe.types.UTF8String
+import org.lance.spark.LanceDataset
+import org.lance.spark.utils.Utils
+
+import java.time.Instant
+import java.time.ZoneId
+import java.time.format.DateTimeFormatter
+
+import scala.collection.JavaConverters._
+
+/**
+ * Physical command that lists the branches of the Lance dataset behind a catalog table.
+ *
+ * @param catalog catalog used to load the table identified by `ident`
+ * @param ident identifier of the table whose branches are listed
+ */
+case class ShowBranchesExec(
+    catalog: TableCatalog,
+    ident: Identifier) extends LeafV2CommandExec {
+
+  /** Output attributes, taken from `ShowBranchesOutputType.SCHEMA`. */
+  override def output: Seq[Attribute] = ShowBranchesOutputType.SCHEMA
+
+  /**
+   * Returns one row per branch, ordered by `getCreateAt`. Each row holds the branch name,
+   * the parent branch (empty string when absent), the parent version, the creation time
+   * formatted as `yyyy-MM-dd HH:mm:ss XXX` in the JVM default zone, and the manifest size.
+   *
+   * @throws UnsupportedOperationException if the loaded table is not a `LanceDataset`
+   */
+  override protected def run(): Seq[InternalRow] = {
+    val readOptions = catalog.loadTable(ident) match {
+      case lance: LanceDataset => lance.readOptions()
+      case _ => throw new UnsupportedOperationException("ShowBranches only supports LanceDataset")
+    }
+
+    val timestampFormat =
+      DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss XXX").withZone(ZoneId.systemDefault())
+
+    val opened = Utils.openDatasetBuilder(readOptions).build()
+    try {
+      val ordered = opened.branches().list().asScala.sortBy(_.getCreateAt)
+      // Rows are materialised here, before the dataset handle is closed in `finally`.
+      val rows = for (branch <- ordered) yield {
+        // NOTE(review): the value is scaled by 1000 before being read as epoch millis,
+        // so getCreateAt is presumably in seconds - confirm against the Lance API.
+        val createdAt = Instant.ofEpochMilli(branch.getCreateAt * 1000)
+          .atZone(ZoneId.systemDefault())
+          .format(timestampFormat)
+        new GenericInternalRow(Array[Any](
+          UTF8String.fromString(branch.getName),
+          UTF8String.fromString(branch.getParentBranch.orElse("")),
+          branch.getParentVersion,
+          UTF8String.fromString(createdAt),
+          branch.getManifestSize))
+      }
+      rows.toSeq
+    } finally {
+      opened.close()
+    }
+  }
+}
diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTagsExec.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTagsExec.scala
new file mode 100755
index 000000000..8c231c480
--- /dev/null
+++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTagsExec.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow}
+import org.apache.spark.sql.catalyst.plans.logical.ShowTagsOutputType
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+import org.apache.spark.unsafe.types.UTF8String
+import org.lance.spark.LanceDataset
+import org.lance.spark.utils.Utils
+
+import scala.collection.JavaConverters._
+
+/**
+ * Physical command that lists the tags of the Lance dataset behind a catalog table.
+ *
+ * @param catalog catalog used to load the table identified by `ident`
+ * @param ident identifier of the table whose tags are listed
+ */
+case class ShowTagsExec(
+    catalog: TableCatalog,
+    ident: Identifier) extends LeafV2CommandExec {
+
+  /** Output attributes, taken from `ShowTagsOutputType.SCHEMA`. */
+  override def output: Seq[Attribute] = ShowTagsOutputType.SCHEMA
+
+  /**
+   * Returns one row per tag, ordered by `getVersion`. Each row holds the tag name, the
+   * tagged version and the manifest size.
+   *
+   * @throws UnsupportedOperationException if the loaded table is not a `LanceDataset`
+   */
+  override protected def run(): Seq[InternalRow] = {
+    val readOptions = catalog.loadTable(ident) match {
+      case lance: LanceDataset => lance.readOptions()
+      case _ => throw new UnsupportedOperationException("ShowTags only supports LanceDataset")
+    }
+
+    val opened = Utils.openDatasetBuilder(readOptions).build()
+    try {
+      val ordered = opened.tags().list().asScala.sortBy(_.getVersion)
+      // Rows are materialised here, before the dataset handle is closed in `finally`.
+      val rows = for (tag <- ordered) yield {
+        new GenericInternalRow(Array[Any](
+          UTF8String.fromString(tag.getName),
+          tag.getVersion,
+          tag.getManifestSize))
+      }
+      rows.toSeq
+    } finally {
+      opened.close()
+    }
+  }
+}
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/branch/BaseBranchTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/branch/BaseBranchTest.java
new file mode 100755
index 000000000..ae03156da
--- /dev/null
+++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/branch/BaseBranchTest.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public abstract class BaseBranchTest { + protected String catalogName = "lance_test"; + protected String tableName = "branch_test"; + protected String fullTable = catalogName + ".default." + tableName; + + protected SparkSession spark; + + @TempDir Path tempDir; + + @BeforeEach + public void setup() throws IOException { + Path rootPath = tempDir.resolve(UUID.randomUUID().toString()); + Files.createDirectories(rootPath); + String testRoot = rootPath.toString(); + spark = + SparkSession.builder() + .appName("lance-branch-test") + .master("local[10]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", "dir") + .config("spark.sql.catalog." + catalogName + ".root", testRoot) + .getOrCreate(); + this.tableName = "branch_test_" + UUID.randomUUID().toString().replace("-", ""); + this.fullTable = this.catalogName + ".default." 
+ this.tableName; + } + + @AfterEach + public void tearDown() throws IOException { + if (spark != null) { + spark.close(); + } + } + + private void prepareDataset() { + spark.sql(String.format("create table %s (id int, text string) using lance;", fullTable)); + // First insert to create initial fragments + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(0, 10) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + // Second insert to ensure multiple fragments + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(10, 20) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + } + + @Test + public void testCreateBranch() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)).collectAsList(); + for (Row row : spark.sql(String.format("show branches in %s", fullTable)).collectAsList()) { + Assertions.assertEquals("branch_0", row.getString(0)); + Assertions.assertEquals("", row.getString(1)); + Assertions.assertTrue(row.getLong(2) > 0); + Assertions.assertNotNull(row.getString(3)); + Assertions.assertTrue(row.getInt(4) > 0); + } + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(10, 20) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + spark.sql(String.format("alter table %s create branch branch_1", fullTable)).collectAsList(); + + List rows = spark.sql(String.format("show branches from %s", fullTable)).collectAsList(); + Assertions.assertEquals(2, rows.size()); + } + + @Test + public void testQueryFromBranch() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)).collectAsList(); + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + 
IntStream.range(20, 30) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + List rows = + spark + .sql(String.format("select * from %s__branch__branch_0 where id >= 10", fullTable)) + .collectAsList(); + Assertions.assertEquals(10, rows.size()); + + spark.sql(String.format("alter table %s create branch branch_1", fullTable)).collectAsList(); + + rows = + spark.sql(String.format("select * from %s__branch__branch_1", fullTable)).collectAsList(); + Assertions.assertEquals(30, rows.size()); + } + + @Test + public void testDropBranch() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)); + for (Row row : spark.sql(String.format("show branches from %s", fullTable)).collectAsList()) { + Assertions.assertEquals("branch_0", row.getString(0)); + Assertions.assertEquals("", row.getString(1)); + Assertions.assertTrue(row.getLong(2) > 0); + Assertions.assertNotNull(row.getString(3)); + Assertions.assertTrue(row.getInt(4) > 0); + } + + spark.sql(String.format("alter table %s drop branch branch_0", fullTable)); + Assertions.assertEquals( + 0, spark.sql(String.format("show branches from %s", fullTable)).collectAsList().size()); + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(20, 30) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + spark.sql(String.format("alter table %s create branch branch_1", fullTable)).collectAsList(); + + List rows = spark.sql(String.format("show branches from %s", fullTable)).collectAsList(); + Assertions.assertEquals(1, rows.size()); + + for (Row row : rows) { + Assertions.assertEquals("branch_1", row.getString(0)); + Assertions.assertEquals("", row.getString(1)); + Assertions.assertTrue(row.getLong(2) > 0); + Assertions.assertNotNull(row.getString(3)); + Assertions.assertTrue(row.getInt(4) > 0); + } + } + + @Test + public void testInsert() 
{ + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)); + + spark.sql( + String.format( + "insert into %s__branch__branch_0 (id, text) values %s ;", + fullTable, + IntStream.range(20, 30) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + List rows = + spark + .sql(String.format("select * from %s__branch__branch_0 where id >= 20", fullTable)) + .collectAsList(); + Assertions.assertEquals(10, rows.size()); + + for (Row row : rows) { + Assertions.assertEquals("text_" + row.getInt(0), row.getString(1)); + } + + // main branch does not change + rows = spark.sql(String.format("select * from %s", fullTable)).collectAsList(); + Assertions.assertEquals(20, rows.size()); + } + + @Test + public void testAddColumnsWithBackFill() { + prepareDataset(); + + spark.sql(String.format("alter table %s create branch branch_0", fullTable)).collectAsList(); + + spark.sql( + String.format( + "create temporary view v " + + "as select _rowaddr, _fragid, concat('new_text_',id) as new_text " + + "from %s__branch__branch_0", + fullTable)); + + spark.sql( + String.format("alter table %s__branch__branch_0 add columns new_text from v", fullTable)); + + List rows = + spark.sql(String.format("select * from %s__branch__branch_0", fullTable)).collectAsList(); + Assertions.assertEquals(20, rows.size()); + for (Row row : rows) { + Assertions.assertEquals("new_text_" + row.getInt(0), row.getString(2)); + } + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/branch/BaseTagTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/branch/BaseTagTest.java new file mode 100755 index 000000000..fbea203e4 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/branch/BaseTagTest.java @@ -0,0 +1,178 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.branch; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public abstract class BaseTagTest { + protected String catalogName = "lance_test"; + protected String tableName = "tag_test"; + protected String fullTable = catalogName + ".default." + tableName; + + protected SparkSession spark; + + @TempDir Path tempDir; + + @BeforeEach + public void setup() throws IOException { + Path rootPath = tempDir.resolve(UUID.randomUUID().toString()); + Files.createDirectories(rootPath); + String testRoot = rootPath.toString(); + spark = + SparkSession.builder() + .appName("lance-tag-test") + .master("local[10]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", "dir") + .config("spark.sql.catalog." + catalogName + ".root", testRoot) + .getOrCreate(); + this.tableName = "tag_test_" + UUID.randomUUID().toString().replace("-", ""); + this.fullTable = this.catalogName + ".default." 
+ this.tableName; + } + + @AfterEach + public void tearDown() throws IOException { + if (spark != null) { + spark.close(); + } + } + + private void prepareDataset() { + spark.sql(String.format("create table %s (id int, text string) using lance;", fullTable)); + // First insert to create initial fragments + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(0, 10) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + // Second insert to ensure multiple fragments + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(10, 20) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + } + + @Test + public void testCreateTag() { + prepareDataset(); + + spark.sql(String.format("alter table %s create tag tag_0", fullTable)).collectAsList(); + for (Row row : spark.sql(String.format("show tags in %s", fullTable)).collectAsList()) { + Assertions.assertEquals("tag_0", row.getString(0)); + Assertions.assertTrue(row.getLong(1) > 0); + Assertions.assertTrue(row.getInt(2) > 0); + } + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(10, 20) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + spark.sql(String.format("alter table %s create tag tag_1", fullTable)).collectAsList(); + + List rows = spark.sql(String.format("show tags in %s", fullTable)).collectAsList(); + Assertions.assertEquals(2, rows.size()); + } + + @Test + public void testDropTag() { + prepareDataset(); + + spark.sql(String.format("alter table %s create tag tag_0", fullTable)); + for (Row row : spark.sql(String.format("show tags in %s", fullTable)).collectAsList()) { + Assertions.assertEquals("tag_0", row.getString(0)); + Assertions.assertTrue(row.getLong(1) > 0); + Assertions.assertTrue(row.getInt(2) > 0); + } + + 
spark.sql(String.format("alter table %s drop tag tag_0", fullTable)); + Assertions.assertEquals( + 0, spark.sql(String.format("show tags in %s", fullTable)).collectAsList().size()); + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(10, 20) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + spark.sql(String.format("alter table %s create tag tag_1", fullTable)).collectAsList(); + + List rows = spark.sql(String.format("show tags in %s", fullTable)).collectAsList(); + Assertions.assertEquals(1, rows.size()); + + for (Row row : rows) { + Assertions.assertEquals("tag_1", row.getString(0)); + Assertions.assertTrue(row.getLong(1) > 0); + Assertions.assertTrue(row.getInt(2) > 0); + } + } + + @Test + public void testQueryFromTag() { + prepareDataset(); + + spark.sql(String.format("alter table %s create tag tag_0", fullTable)).collectAsList(); + + spark.sql( + String.format( + "insert into %s (id, text) values %s ;", + fullTable, + IntStream.range(20, 30) + .boxed() + .map(i -> String.format("(%d, 'text_%d')", i, i)) + .collect(Collectors.joining(",")))); + + List rows = + spark + .sql(String.format("select * from %s__tag__tag_0 where id >= 10", fullTable)) + .collectAsList(); + Assertions.assertEquals(10, rows.size()); + + spark.sql(String.format("alter table %s create tag tag_1", fullTable)).collectAsList(); + + rows = spark.sql(String.format("select * from %s__tag__tag_1", fullTable)).collectAsList(); + Assertions.assertEquals(30, rows.size()); + } +}