diff --git a/pramen/core/src/main/resources/reference.conf b/pramen/core/src/main/resources/reference.conf index e14a5968..a9431025 100644 --- a/pramen/core/src/main/resources/reference.conf +++ b/pramen/core/src/main/resources/reference.conf @@ -61,7 +61,7 @@ pramen { # It is not always possible. When a table is initially created, MSCK REPAIR is always used to pick up all partitions. # Also ADD PARTTITION is only for Parquet format. # This option can be overridden per metatable. - hive.prefer.add.partition = false + hive.prefer.add.partition = true # If enabled, the job will wait for the output table to become available before running a job # If the number of seconds <=0 the waiting will be infinite diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/runner/task/ThreadClosableRegistry.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/runner/task/ThreadClosableRegistry.scala new file mode 100644 index 00000000..b611082e --- /dev/null +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/runner/task/ThreadClosableRegistry.scala @@ -0,0 +1,84 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.runner.task + +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +object ThreadClosableRegistry { + private val log = LoggerFactory.getLogger(this.getClass) + private val closeables = new java.util.LinkedList[(Long, AutoCloseable)] + + /** + * Registers a closeable resource for the current thread. + * The resource will be automatically closed when [[cleanupThread]] is called for this thread. + * + * @param closeable The AutoCloseable resource to register + */ + def registerCloseable(closeable: AutoCloseable): Unit = synchronized { + val threadId = Thread.currentThread().getId + + val iterator = closeables.iterator() + var alreadyRegistered = false + while (iterator.hasNext && !alreadyRegistered) { + alreadyRegistered = iterator.next()._2 == closeable + } + if (!alreadyRegistered) { + closeables.add((threadId, closeable)) + } + } + + /** + * Unregisters a closeable resource from the registry. + * This removes the resource regardless of which thread it was registered from. + * + * @param closeable The AutoCloseable resource to unregister + */ + def unregisterCloseable(closeable: AutoCloseable): Unit = synchronized { + val iterator = closeables.iterator() + while (iterator.hasNext) { + val (_, c) = iterator.next() + if (c == closeable) { + iterator.remove() + return + } + } + } + + /** + * Closes all registered resources for the specified thread in LIFO (Last-In-First-Out) order. + * This method is typically called when a thread times out or completes execution. + * Any exceptions during closing are logged but do not prevent other resources from being closed. + * + * @param threadId The ID of the thread whose resources should be cleaned up + */ + def cleanupThread(threadId: Long): Unit = synchronized { + val threadCloseables = closeables.asScala.filter(_._1 == threadId).map(_._2).toList + threadCloseables.reverse.foreach { closeable => + try { + closeable.close() + } catch { + case NonFatal(ex) => + log.warn(s"Error closing resource for thread $threadId.", ex) + } finally { + unregisterCloseable(closeable) + } + } + } +} diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/ThreadUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/ThreadUtils.scala index b8da1f8a..2173efac 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/ThreadUtils.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/ThreadUtils.scala @@ -16,13 +16,17 @@ package za.co.absa.pramen.core.utils +import org.slf4j.LoggerFactory import za.co.absa.pramen.core.exceptions.TimeoutException +import za.co.absa.pramen.core.runner.task.ThreadClosableRegistry import za.co.absa.pramen.core.utils.impl.ThreadWithException import java.lang.Thread.UncaughtExceptionHandler import scala.concurrent.duration.Duration object ThreadUtils { + private val log = LoggerFactory.getLogger(this.getClass) + /** * Executes an action with a timeout. If the timeout is breached the task is killed (using Thread.interrupt()) * @@ -30,8 +34,8 @@ object ThreadUtils { * * Any exception is passed to the caller. * - * @param timeout The task timeout. - * @param action An action to execute. + * @param timeout The task timeout. + * @param action An action to execute. */ @throws[TimeoutException] def runWithTimeout(timeout: Duration)(action: => Unit): Unit = { @@ -54,6 +58,15 @@ object ThreadUtils { if (thread.isAlive) { val stackTrace = thread.getStackTrace + + try { + // Execute cleanup BEFORE interrupt - e.g. close the JDBC connection/statement + ThreadClosableRegistry.cleanupThread(thread.getId) + } catch { + case ex: Throwable => + log.warn(s"Exception during timeout cleanup: ${ex.getMessage}") + } + thread.interrupt() val prettyTimeout = TimeUtils.prettyPrintElapsedTimeShort(timeout.toMillis) diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala index 62906fd8..b7288e9a 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala @@ -19,6 +19,7 @@ package za.co.absa.pramen.core.utils.hive import org.slf4j.LoggerFactory import za.co.absa.pramen.core.reader.JdbcUrlSelector import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.runner.task.ThreadClosableRegistry import java.sql._ import scala.util.control.NonFatal @@ -66,15 +67,43 @@ class QueryExecutorJdbc(jdbcUrlSelector: JdbcUrlSelector, optimizedExistQuery: B executeActionOnConnection { conn => val statement = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + val autoCloseStatement: AutoCloseable = new AutoCloseable { + val statementClosed = new java.util.concurrent.atomic.AtomicBoolean(false) + + override def close(): Unit = { + if (statementClosed.compareAndSet(false, true)) { + try { + log.info(s"Cancelling SQL statement: $query") + statement.cancel() + } finally { + log.info(s"Closing the SQL statement...") + statement.close() + } + } + } + } + + ThreadClosableRegistry.registerCloseable(autoCloseStatement) + try { statement.execute(query) } finally { - statement.close() + ThreadClosableRegistry.unregisterCloseable(autoCloseStatement) + autoCloseStatement.close() } } } - override def close(): Unit = if (connection != null) connection.close() + override def close(): Unit = { + if (connection != null) { + ThreadClosableRegistry.unregisterCloseable(connection) + try { + connection.close() + } catch { + case NonFatal(ex) => log.warn("Failed to close JDBC connection", ex) + } + } + } private[core] def executeActionOnConnection(action: Connection => Boolean): Boolean = { val currentConnection = getConnection(forceReconnect = false) @@ -97,7 +126,9 @@ class QueryExecutorJdbc(jdbcUrlSelector: JdbcUrlSelector, optimizedExistQuery: B if (connection == null || forceReconnect) { val (newConnection, url) = jdbcUrlSelector.getWorkingConnection(retries) log.info(s"Selected query executor connection: $url") + close() connection = newConnection + ThreadClosableRegistry.registerCloseable(connection) } connection } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/metastore/MetastoreSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/metastore/MetastoreSuite.scala index 2c6d13a8..4a845bed 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/metastore/MetastoreSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/metastore/MetastoreSuite.scala @@ -306,7 +306,7 @@ class MetastoreSuite extends AnyWordSpec with SparkTestBase with TextComparisonF m.repairOrCreateHiveTable("table_hive_parquet", infoDate, Option(schema), hh, recreate = false) assert(qe.queries.length == 1) - assert(qe.queries.exists(_.contains("REPAIR"))) + assert(qe.queries.exists(_.contains("ALTER TABLE"))) } "do nothing for a delta since it does not need repairing" in { diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/runner/task/ThreadClosableRegistrySuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/runner/task/ThreadClosableRegistrySuite.scala new file mode 100644 index 00000000..aa497bea --- /dev/null +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/runner/task/ThreadClosableRegistrySuite.scala @@ -0,0 +1,189 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.tests.runner.task + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.pramen.core.runner.task.ThreadClosableRegistry + +import scala.collection.mutable + +class ThreadClosableRegistrySuite extends AnyWordSpec with Matchers { + + private def createCountingCloseable(): (AutoCloseable, () => Int) = { + var count = 0 + val closeable = new AutoCloseable { + override def close(): Unit = count += 1 + } + (closeable, () => count) + } + + private def currentThreadId: Long = Thread.currentThread().getId + + "ThreadClosableRegistry" should { + "register and cleanup single or multiple closeable resources" in { + val (closeable1, getCount1) = createCountingCloseable() + val (closeable2, getCount2) = createCountingCloseable() + val (closeable3, getCount3) = createCountingCloseable() + + ThreadClosableRegistry.registerCloseable(closeable1) + ThreadClosableRegistry.registerCloseable(closeable2) + ThreadClosableRegistry.registerCloseable(closeable3) + + ThreadClosableRegistry.cleanupThread(currentThreadId) + + getCount1() shouldBe 1 + getCount2() shouldBe 1 + getCount3() shouldBe 1 + } + + "not affect other threads when cleaning up a specific thread" in { + val (closeableThread1, getCount1) = createCountingCloseable() + val (closeableThread2, getCount2) = createCountingCloseable() + + var thread1Id: Long = 0 + var thread2Id: Long = 0 + + val thread1 = new Thread { + override def run(): Unit = { + thread1Id = Thread.currentThread().getId + ThreadClosableRegistry.registerCloseable(closeableThread1) + } + } + + val thread2 = new Thread { + override def run(): Unit = { + thread2Id = Thread.currentThread().getId + ThreadClosableRegistry.registerCloseable(closeableThread2) + } + } + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + // Cleanup only thread1 + ThreadClosableRegistry.cleanupThread(thread1Id) + getCount1() shouldBe 1 + getCount2() shouldBe 0 + + // Cleanup thread2 + ThreadClosableRegistry.cleanupThread(thread2Id) + getCount2() shouldBe 1 + } + + "handle exceptions during resource cleanup gracefully" in { + val (closeable1, getCount1) = createCountingCloseable() + val (closeable3, getCount3) = createCountingCloseable() + + var closeCalled2 = 0 + val closeable2 = new AutoCloseable { + override def close(): Unit = { + closeCalled2 += 1 + throw new RuntimeException("Close failed") + } + } + + ThreadClosableRegistry.registerCloseable(closeable1) + ThreadClosableRegistry.registerCloseable(closeable2) + ThreadClosableRegistry.registerCloseable(closeable3) + + // Should not throw exception, should attempt to close all resources + noException should be thrownBy ThreadClosableRegistry.cleanupThread(currentThreadId) + + getCount1() shouldBe 1 + closeCalled2 shouldBe 1 + getCount3() shouldBe 1 + } + + "do nothing when cleaning up a thread with no registered resources" in { + val nonExistentThreadId = 999999L + + noException should be thrownBy ThreadClosableRegistry.cleanupThread(nonExistentThreadId) + } + + "do nothing when cleaning up an already cleaned thread" in { + val (closeable, getCount) = createCountingCloseable() + + ThreadClosableRegistry.registerCloseable(closeable) + ThreadClosableRegistry.cleanupThread(currentThreadId) + getCount() shouldBe 1 + + // Cleanup again - should not call close again + ThreadClosableRegistry.cleanupThread(currentThreadId) + getCount() shouldBe 1 + } + + "close resources in LIFO order (last registered, first closed)" in { + val closeOrder = mutable.ArrayBuffer[Int]() + + val closeables = (1 to 3).map { id => + new AutoCloseable { + override def close(): Unit = closeOrder += id + } + } + + closeables.foreach(ThreadClosableRegistry.registerCloseable) + ThreadClosableRegistry.cleanupThread(currentThreadId) + + closeOrder should contain theSameElementsInOrderAs Seq(3, 2, 1) + } + + "handle concurrent registrations from multiple threads" in { + val numThreads = 10 + val closeablesPerThread = 5 + val threadData = mutable.Map[Long, Seq[() => Int]]() + + val threads = (1 to numThreads).map { _ => + new Thread { + override def run(): Unit = { + val threadId = Thread.currentThread().getId + val closeablesWithCounters = (1 to closeablesPerThread).map(_ => createCountingCloseable()) + + threadData.synchronized { + threadData(threadId) = closeablesWithCounters.map(_._2) + } + + closeablesWithCounters.foreach { case (closeable, _) => + ThreadClosableRegistry.registerCloseable(closeable) + } + } + } + } + + threads.foreach(_.start()) + threads.foreach(_.join()) + + threadData.foreach { case (threadId, getCounters) => + ThreadClosableRegistry.cleanupThread(threadId) + getCounters.foreach(getCount => getCount() shouldBe 1) + } + } + + "not close an explicitly unregistered resource during cleanup" in { + val (closeable, getCount) = createCountingCloseable() + + ThreadClosableRegistry.registerCloseable(closeable) + ThreadClosableRegistry.unregisterCloseable(closeable) + + // cleanupThread should not close the unregistered resource + ThreadClosableRegistry.cleanupThread(currentThreadId) + getCount() shouldBe 0 + } + } +}