/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.v2.odps

import java.util

import com.aliyun.odps.{Column, Odps, OdpsType, TableSchema}
import com.aliyun.odps.`type`.TypeInfoFactory
import com.aliyun.odps.account.AliyunAccount
import com.aliyun.odps.task.SQLTask
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, sources}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsPartitionManagement}
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String

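// Exercises the partition-management surface of OdpsTableCatalog
// (create/list/drop partitions and partition filter pushdown) against a
// live ODPS (MaxCompute) project, so this is an integration-style suite.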
class OdpsTableSuite extends AnyFunSuite with Logging {

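  // Connection settings are intentionally left blank; fill them in with a
  // real ODPS project, credentials, and endpoint before running the suite.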
  private val project: String = ""
  private val accessId: String = ""
  private val accessKey: String = ""
  private val endPoint: String = ""

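  // Identifiers of the test tables (re)created by the helper methods below.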
  private val partTable: Identifier = Identifier.of(Array(project), "partTable")
  private val multiPartTable: Identifier = Identifier.of(Array(project), "multiPartTable")
  private val allTypePartTable: Identifier = Identifier.of(Array(project), "allTypePartTable")

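  // Local SparkSession carrying the ODPS connection settings through the
  // Hadoop configuration.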
  private def sparkSession: SparkSession = SparkSession.builder()
    .master("local[2]")
    .config("spark.hadoop.odps.access.id", accessId)
    .config("spark.hadoop.odps.access.key", accessKey)
    .config("spark.hadoop.odps.end.point", endPoint)
    .config("spark.hadoop.odps.project.name", project)
    .getOrCreate()

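  // Plain ODPS SDK client, used by the setup helpers to create tables and
  // run SQL directly, bypassing the catalog.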
  private def odps: Odps = {
    val odps = new Odps(new AliyunAccount(accessId, accessKey))
    odps.setDefaultProject(project)
    odps.setEndpoint(endPoint)
    odps
  }

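  // The catalog under test. The default session is set first so that
  // initialize() can pick up the ODPS settings from it.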
  private val catalog: OdpsTableCatalog = {
    SparkSession.setDefaultSession(sparkSession)

    val newCatalog = new OdpsTableCatalog
    newCatalog.initialize("odps", CaseInsensitiveStringMap.empty())

    newCatalog
  }

| 73 | + test("listPartitionIdentifiers") { |
| 74 | + val partTable = createPartTable() |
| 75 | + |
| 76 | + assert(!hasPartitions(partTable)) |
| 77 | + |
| 78 | + val partIdent = InternalRow.apply(UTF8String.fromString("3")) |
| 79 | + partTable.createPartition(partIdent, new util.HashMap[String, String]()) |
| 80 | + assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 1) |
| 81 | + |
| 82 | + val partIdent1 = InternalRow.apply(UTF8String.fromString("4")) |
| 83 | + partTable.createPartition(partIdent1, new util.HashMap[String, String]()) |
| 84 | + assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2) |
| 85 | + assert(partTable.listPartitionIdentifiers(Array("dt"), partIdent1).length == 1) |
| 86 | + |
| 87 | + partTable.dropPartition(partIdent) |
| 88 | + assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 1) |
| 89 | + partTable.dropPartition(partIdent1) |
| 90 | + assert(!hasPartitions(partTable)) |
| 91 | + } |
| 92 | + |
| 93 | + test("listPartitionsByFilter") { |
| 94 | + createMultiPartTable() |
| 95 | + |
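    // Each pair maps the pushed-down filters to the partition specs the
    // catalog is expected to return.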
    Seq(
      Array[Filter]() -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "0", "part1" -> "def"),
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.EqualTo("part0", 0)) -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "0", "part1" -> "def")),

      Array[Filter](sources.EqualTo("part1", "abc")) -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.And(
        sources.LessThanOrEqual("part0", 1),
        sources.GreaterThanOrEqual("part0", 0))) -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "0", "part1" -> "def"),
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.GreaterThanOrEqual("part0", 1)) -> Set(
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.GreaterThanOrEqual("part0", 2)) -> Set()
    ).foreach { case (filters, expected) =>
      assert(catalog.listPartitionsByFilter(multiPartTable, filters).toSet === expected)
    }

    createPartTableAllType()
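    // The CHAR/VARCHAR partition values differ only in trailing spaces, so
    // prefix and suffix filters must distinguish them exactly.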
    Seq(
      Array[Filter]() -> Set(
        Map("p1" -> "11", "p2" -> "22", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "11 ", "p2" -> "22 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "111 ", "p2" -> "222 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6")),

      Array[Filter](sources.StringStartsWith("p1", "11")) -> Set(
        Map("p1" -> "11", "p2" -> "22", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "11 ", "p2" -> "22 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "111 ", "p2" -> "222 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6")),

      Array[Filter](sources.StringEndsWith("p1", " ")) -> Set(
        Map("p1" -> "11 ", "p2" -> "22 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "111 ", "p2" -> "222 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"))
    ).foreach { case (filters, expected) =>
      assert(catalog.listPartitionsByFilter(allTypePartTable, filters).toSet === expected)
    }
  }

| 146 | + test("partitionExists") { |
| 147 | + val partTable = createMultiPartTable() |
| 148 | + |
| 149 | + assert(partTable.partitionExists(InternalRow(0, UTF8String.fromString("def")))) |
| 150 | + assert(!partTable.partitionExists(InternalRow(-1, UTF8String.fromString("def")))) |
| 151 | + |
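    // An identifier whose field types do not match the partition schema fails
    // when the value is cast to the partition column type.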
    val errMsg = intercept[ClassCastException] {
      partTable.partitionExists(
        InternalRow(UTF8String.fromString("abc"), UTF8String.fromString("def")))
    }.getMessage
    assert(errMsg.contains(
      "org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer"))

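    // An identifier with fewer fields than partition columns is rejected.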
    val errMsg2 = intercept[IllegalArgumentException] {
      partTable.partitionExists(InternalRow(0))
    }.getMessage
    assert(errMsg2.contains("The identifier might not refer to one partition"))
  }

| 163 | + test("listPartitionByNames") { |
| 164 | + val partTable = createMultiPartTable() |
| 165 | + |
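    // Each pair maps (partition column names, identifier prefix) to the full
    // partition identifiers expected to match.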
    Seq(
      (Array("part0", "part1"), InternalRow(0, UTF8String.fromString("abc"))) ->
        Set(InternalRow(0, UTF8String.fromString("abc"))),

      (Array("part0"), InternalRow(0)) -> Set(
        InternalRow(0, UTF8String.fromString("abc")),
        InternalRow(0, UTF8String.fromString("def"))),

      (Array("part1"), InternalRow(UTF8String.fromString("abc"))) -> Set(
        InternalRow(0, UTF8String.fromString("abc")),
        InternalRow(1, UTF8String.fromString("abc"))),

      (Array.empty[String], InternalRow.empty) -> Set(
        InternalRow(0, UTF8String.fromString("abc")),
        InternalRow(0, UTF8String.fromString("def")),
        InternalRow(1, UTF8String.fromString("abc"))),

      (Array("part0", "part1"), InternalRow(3, UTF8String.fromString("xyz"))) -> Set(),

      (Array("part1"), InternalRow(UTF8String.fromString("xyz"))) -> Set()
    ).foreach { case ((names, idents), expected) =>
      assert(partTable.listPartitionIdentifiers(names, idents).toSet === expected)
    }

    // Check invalid parameters.
    Seq(
      (Array("part0", "part1"), InternalRow(0))
    ).foreach { case (names, idents) =>
      intercept[ArrayIndexOutOfBoundsException](partTable.listPartitionIdentifiers(names, idents))
    }

    Seq(
      (Array("col0", "part1"), InternalRow(0, 1)),
      (Array("wrong"), InternalRow("invalid"))
    ).foreach { case (names, idents) =>
      intercept[IllegalArgumentException](partTable.listPartitionIdentifiers(names, idents))
    }
  }

| 200 | + test("listPartitionForAllType") { |
| 201 | + val partTable = createPartTableAllType() |
| 202 | + |
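    // Identifier matching is exact, including any trailing spaces in the
    // CHAR and VARCHAR partition values.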
    Seq(
      (Array.empty[String], InternalRow.empty) -> Set(
        InternalRow(UTF8String.fromString("11 "), UTF8String.fromString("22 "), 3, 4, 5, 6),
        InternalRow(UTF8String.fromString("111 "), UTF8String.fromString("222 "), 3, 4, 5, 6),
        InternalRow(UTF8String.fromString("11"), UTF8String.fromString("22"), 3, 4, 5, 6)),

      (Array("p1"), InternalRow(UTF8String.fromString("11"))) -> Set(
        InternalRow(UTF8String.fromString("11"), UTF8String.fromString("22"), 3, 4, 5, 6)),

      (Array("p1"), InternalRow(UTF8String.fromString("11 "))) -> Set(),

      (Array("p1"), InternalRow(UTF8String.fromString("111 "))) -> Set(
        InternalRow(UTF8String.fromString("111 "), UTF8String.fromString("222 "), 3, 4, 5, 6)),

      (Array("p2"), InternalRow(UTF8String.fromString("22 "))) -> Set(
        InternalRow(UTF8String.fromString("11 "), UTF8String.fromString("22 "), 3, 4, 5, 6))
    ).foreach { case ((names, idents), expected) =>
      assert(partTable.listPartitionIdentifiers(names, idents).toSet === expected)
    }
  }

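  // Drops and recreates a table partitioned by a single string column "dt"
  // through the catalog API.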
  private def createPartTable(): OdpsTable = {
    if (catalog.tableExists(partTable)) {
      catalog.dropTable(partTable)
    }

    val odpsTable = catalog.createTable(
      partTable,
      new StructType()
        .add("id", IntegerType)
        .add("data", StringType)
        .add("dt", StringType),
      Array[Transform](LogicalExpressions.identity(ref("dt"))),
      util.Collections.emptyMap[String, String]).asInstanceOf[OdpsTable]
    odpsTable
  }

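  // Drops and recreates a table with two partition levels (part0, part1) and
  // registers three partitions through the catalog API.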
  private def createMultiPartTable(): OdpsTable = {
    if (catalog.tableExists(multiPartTable)) {
      catalog.dropTable(multiPartTable)
    }

    val odpsTable = catalog.createTable(
      multiPartTable,
      new StructType()
        .add("col0", IntegerType)
        .add("part0", IntegerType)
        .add("part1", StringType),
      Array(LogicalExpressions.identity(ref("part0")), LogicalExpressions.identity(ref("part1"))),
      util.Collections.emptyMap[String, String]).asInstanceOf[OdpsTable]

    Seq(
      InternalRow(0, UTF8String.fromString("abc")),
      InternalRow(0, UTF8String.fromString("def")),
      InternalRow(1, UTF8String.fromString("abc"))).foreach { partIdent =>
      odpsTable.createPartition(partIdent, new util.HashMap[String, String]())
    }
    odpsTable
  }

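  // Drops and recreates a table whose partition columns cover CHAR, VARCHAR,
  // and the integral types, using the ODPS SDK and SQL directly rather than
  // the catalog API, then loads it back through the catalog.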
  private def createPartTableAllType(): OdpsTable = {
    val tableSchema = new TableSchema
    val columns = new util.ArrayList[Column]
    columns.add(new Column("c0", OdpsType.STRING))
    columns.add(new Column("c1", OdpsType.BIGINT))
    tableSchema.setColumns(columns)

    val partitionColumns = new util.ArrayList[Column]
    partitionColumns.add(new Column("p1", TypeInfoFactory.getCharTypeInfo(5)))
    partitionColumns.add(new Column("p2", TypeInfoFactory.getVarcharTypeInfo(5)))
    partitionColumns.add(new Column("p3", OdpsType.SMALLINT))
    partitionColumns.add(new Column("p4", OdpsType.TINYINT))
    partitionColumns.add(new Column("p5", OdpsType.BIGINT))
    partitionColumns.add(new Column("p6", OdpsType.INT))
    tableSchema.setPartitionColumns(partitionColumns)

    if (odps.tables().exists(allTypePartTable.name)) {
      odps.tables().delete(allTypePartTable.name)
    }
    odps.tables().create(allTypePartTable.name, tableSchema)

    createPartitionUseOdpsSQL(allTypePartTable.name, "p1='11',p2='22',p3=3,p4=4,p5=5,p6=6")
    createPartitionUseOdpsSQL(allTypePartTable.name, "p1='11 ',p2='22 ',p3=3,p4=4,p5=5,p6=6")
    createPartitionUseOdpsSQL(allTypePartTable.name, "p1='111 ',p2='222 ',p3=3,p4=4,p5=5,p6=6")

    catalog.loadTable(allTypePartTable).asInstanceOf[OdpsTable]
  }

  private def ref(name: String): NamedReference = LogicalExpressions.parseReference(name)

  private def hasPartitions(table: SupportsPartitionManagement): Boolean = {
    table.listPartitionIdentifiers(Array.empty, InternalRow.empty).nonEmpty
  }

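  // Adds a partition by running an ALTER TABLE ... ADD PARTITION statement as
  // an ODPS SQLTask and waiting for it to succeed.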
  private def createPartitionUseOdpsSQL(tableName: String, partition: String): Unit = {
    val i = SQLTask.run(odps, odps.getDefaultProject,
      s"ALTER TABLE $project.$tableName ADD PARTITION ($partition);",
      "SQLAddPartitionTask",
      new util.HashMap[String, String](),
      null)
    i.waitForSuccess()
  }
}