diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 59417e19..fdfda10e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -4,13 +4,16 @@ aspectj = "1.9.22.1" assertj = "3.27.3" cron-utils = "9.2.1" hikaricp = "7.0.2" -kryo = "5.6.2" jackson = "2.21.2" java-websocket = "1.6.0" +jaxb-api = "4.0.2" +jdbi = "3.47.0" +jooq = "3.19.15" jspecify = "1.0.0" junit = "6.0.3" junit-pioneer = "2.3.0" kotlin = "2.3.10" +kryo = "5.6.2" logback = "1.5.32" maven-artifact = "3.9.13" maven-publish = "0.36.0" @@ -20,10 +23,10 @@ postgresql = "42.7.10" rest-assured = "6.0.0" shadow = "9.4.1" slf4j = "2.0.17" -sqlite-jdbc = "3.49.1.0" spotless = "8.4.0" spring-boot = "3.4.4" spring-framework = "6.2.5" +sqlite-jdbc = "3.49.1.0" system-stubs = "2.1.8" testcontainers = "2.0.4" versions = "0.53.0" @@ -37,6 +40,9 @@ hikaricp = { module = "com.zaxxer:HikariCP", version.ref = "hikaricp" } jackson-databind = { module = "com.fasterxml.jackson.core:jackson-databind", version.ref = "jackson" } jackson-jsr310 = { module = "com.fasterxml.jackson.datatype:jackson-datatype-jsr310", version.ref = "jackson" } java-websocket = { module = "org.java-websocket:Java-WebSocket", version.ref = "java-websocket" } +jaxb-api = { module = "jakarta.xml.bind:jakarta.xml.bind-api", version.ref = "jaxb-api" } +jdbi-core = { module = "org.jdbi:jdbi3-core", version.ref = "jdbi" } +jooq = { module = "org.jooq:jooq", version.ref = "jooq" } jspecify = { module = "org.jspecify:jspecify", version.ref = "jspecify" } junit-bom = { module = "org.junit:junit-bom", version.ref = "junit" } junit-jupiter = { module = "org.junit.jupiter:junit-jupiter" } @@ -51,11 +57,11 @@ postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" rest-assured = { module = "io.rest-assured:rest-assured", version.ref = "rest-assured" } slf4j-api = { module = "org.slf4j:slf4j-api", version.ref = "slf4j" } slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" } -sqlite-jdbc = { module = "org.xerial:sqlite-jdbc", version.ref = "sqlite-jdbc" } spring-aop = { module = "org.springframework:spring-aop", version.ref = "spring-framework" } spring-boot-autoconfigure = { module = "org.springframework.boot:spring-boot-autoconfigure", version.ref = "spring-boot" } spring-boot-configuration-processor = { module = "org.springframework.boot:spring-boot-configuration-processor", version.ref = "spring-boot" } spring-boot-test = { module = "org.springframework.boot:spring-boot-test", version.ref = "spring-boot" } +sqlite-jdbc = { module = "org.xerial:sqlite-jdbc", version.ref = "sqlite-jdbc" } system-stubs-jupiter = { module = "uk.org.webcompere:system-stubs-jupiter", version.ref = "system-stubs" } testcontainers-postgresql = { module = "org.testcontainers:testcontainers-postgresql", version.ref = "testcontainers" } diff --git a/settings.gradle.kts b/settings.gradle.kts index ae7fee16..af6f2a87 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,6 +1,12 @@ rootProject.name = "dbos-transact-java" -include("transact", "transact-cli", "transact-spring-boot-starter") +include( + "transact", + "transact-cli", + "transact-spring-boot-starter", + "transact-jdbi-step-factory", + "transact-jooq-step-factory", +) plugins { id("org.gradle.toolchains.foojay-resolver") version "1.0.0" } diff --git a/transact-jdbi-step-factory/build.gradle.kts b/transact-jdbi-step-factory/build.gradle.kts new file mode 100644 index 00000000..35e623b2 --- /dev/null +++ b/transact-jdbi-step-factory/build.gradle.kts @@ -0,0 +1,22 @@ +plugins { id("java-library") } + +tasks.withType { + options.compilerArgs.add("-Xlint:unchecked") + options.compilerArgs.add("-Xlint:deprecation") + options.compilerArgs.add("-Xlint:rawtypes") + options.compilerArgs.add("-Werror") +} + +dependencies { + api(project(":transact")) + api(libs.jdbi.core) + + testImplementation(platform(libs.junit.bom)) + testImplementation(libs.junit.jupiter) + testRuntimeOnly(libs.junit.platform.launcher) + + testRuntimeOnly(libs.logback.classic) + testImplementation(libs.testcontainers.postgresql) + testImplementation(libs.postgresql) + testImplementation(libs.hikaricp) +} diff --git a/transact-jdbi-step-factory/src/main/java/dev/dbos/transact/jdbi/JdbiStepFactory.java b/transact-jdbi-step-factory/src/main/java/dev/dbos/transact/jdbi/JdbiStepFactory.java new file mode 100644 index 00000000..779f737c --- /dev/null +++ b/transact-jdbi-step-factory/src/main/java/dev/dbos/transact/jdbi/JdbiStepFactory.java @@ -0,0 +1,193 @@ +package dev.dbos.transact.jdbi; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.json.DBOSSerializer; +import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.txstep.PostgresStepFactory; +import dev.dbos.transact.workflow.internal.StepResult; + +import java.util.Objects; +import java.util.Optional; + +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.HandleCallback; +import org.jdbi.v3.core.HandleConsumer; +import org.jdbi.v3.core.Jdbi; + +/** + * Runs idempotent transactional steps inside DBOS workflows using Jdbi3 {@link Handle} objects. + * + *

Construct one with a {@link Jdbi} instance pointing at a PostgreSQL database. The constructor + * verifies the datasource is PostgreSQL and creates the {@code tx_step_outputs} table if needed. + * Lambdas passed to {@link #inStep} or {@link #useStep} receive a {@link Handle} with a transaction + * already open; they must not call {@code commit} or {@code close} themselves. + * + *

{@code
+ * JdbiStepFactory factory = new JdbiStepFactory(dbos, Jdbi.create(dataSource));
+ *
+ * // inside a @Workflow method:
+ * int count = factory.inStep(handle -> {
+ *     return handle.createUpdate("INSERT INTO ...").execute();
+ * }, "myStep");
+ * }
+ */ +public class JdbiStepFactory extends PostgresStepFactory { + + private final Jdbi jdbi; + + /** + * Creates a factory using the schema and serializer from {@code dbos} configuration. + * + * @param dbos the DBOS runtime instance + * @param jdbi a Jdbi instance connected to a PostgreSQL database + */ + public JdbiStepFactory(DBOS dbos, Jdbi jdbi) { + this(dbos, jdbi, null, null); + } + + /** + * Creates a factory using the given schema and the serializer from {@code dbos} configuration. + * + * @param dbos the DBOS runtime instance + * @param jdbi a Jdbi instance connected to a PostgreSQL database + * @param schema the PostgreSQL schema to use for {@code tx_step_outputs}; {@code null} uses the + * schema from {@code dbos} configuration + */ + public JdbiStepFactory(DBOS dbos, Jdbi jdbi, String schema) { + this(dbos, jdbi, schema, null); + } + + /** + * Creates a factory using the given serializer and the schema from {@code dbos} configuration. + * + * @param dbos the DBOS runtime instance + * @param jdbi a Jdbi instance connected to a PostgreSQL database + * @param serializer the serializer to use for step outputs; {@code null} uses the serializer from + * {@code dbos} configuration + */ + public JdbiStepFactory(DBOS dbos, Jdbi jdbi, DBOSSerializer serializer) { + this(dbos, jdbi, null, serializer); + } + + /** + * Creates a factory with explicit schema and serializer overrides. + * + *

Connects to the database immediately to verify it is PostgreSQL and to create the {@code + * tx_step_outputs} table in the given schema if it does not already exist. + * + * @param dbos the DBOS runtime instance + * @param jdbi a Jdbi instance connected to a PostgreSQL database + * @param schema the PostgreSQL schema to use for {@code tx_step_outputs}; {@code null} uses the + * schema from {@code dbos} configuration + * @param serializer the serializer to use for step outputs; {@code null} uses the serializer from + * {@code dbos} configuration + * @throws RuntimeException if the datasource is not PostgreSQL or the schema setup fails + */ + public JdbiStepFactory(DBOS dbos, Jdbi jdbi, String schema, DBOSSerializer serializer) { + super(dbos, schema, serializer, () -> jdbi.open().getConnection()); + this.jdbi = Objects.requireNonNull(jdbi); + } + + /** + * Executes {@code callback} as an idempotent DBOS step inside a Jdbi transaction. + * + *

If a result for this step is already recorded (e.g. on workflow retry), the callback is + * skipped and the cached result is returned. Otherwise the callback runs inside an open + * transaction; the output is recorded atomically with the database work so the step is + * exactly-once on success. + * + * @param the return type of the callback + * @param the checked exception type the callback may throw + * @param callback the database work to perform; receives an open {@link Handle} and must not + * commit or close it + * @param stepName a stable name that identifies this step within the workflow + * @return the value returned by {@code callback} + * @throws X if the callback throws + */ + public R inStep(final HandleCallback callback, String stepName) + throws X { + return runTxStep( + (wfId, stepId) -> + jdbi.inTransaction( + h -> { + var result = callback.withHandle(h); + recordOutput(h, wfId, stepId, result); + return result; + }), + stepName); + } + + /** + * Executes {@code callback} as an idempotent DBOS step inside a Jdbi transaction, with no return + * value. + * + *

Behaves identically to {@link #inStep} but accepts a {@link HandleConsumer} for callers that + * do not need to return a result. + * + * @param the checked exception type the callback may throw + * @param callback the database work to perform; receives an open {@link Handle} and must not + * commit or close it + * @param stepName a stable name that identifies this step within the workflow + * @throws X if the callback throws + */ + public void useStep(final HandleConsumer callback, String stepName) + throws X { + inStep( + handle -> { + callback.useHandle(handle); + return null; + }, + stepName); + } + + @Override + protected Optional checkExecution(String workflowId, int stepId, String stepName) { + return jdbi.withHandle( + h -> + h.createQuery(checkSql()) + .bind(0, workflowId) + .bind(1, stepId) + .map( + (rs, ctx) -> + new StepResult( + workflowId, + stepId, + stepName, + rs.getString("output"), + rs.getString("error"), + null, + rs.getString("serialization"))) + .findOne()); + } + + private void recordOutput(Handle handle, String workflowId, int stepId, R result) { + var value = SerializationUtil.serializeValue(result, null, serializer); + recordResult(handle, workflowId, stepId, value.serializedValue(), null, value.serialization()); + } + + @Override + protected void recordError(String workflowId, int stepId, Exception exception) { + var value = SerializationUtil.serializeError(exception, null, serializer); + jdbi.useTransaction( + h -> + recordResult( + h, workflowId, stepId, null, value.serializedValue(), value.serialization())); + } + + private void recordResult( + Handle handle, + String workflowId, + int stepId, + String output, + String error, + String serialization) { + handle + .createUpdate(upsertSql()) + .bind(0, workflowId) + .bind(1, stepId) + .bind(2, output) + .bind(3, error) + .bind(4, serialization) + .execute(); + } +} diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/jdbi/JdbiStepFactoryTest.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/jdbi/JdbiStepFactoryTest.java new file mode 100644 index 00000000..2e101468 --- /dev/null +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/jdbi/JdbiStepFactoryTest.java @@ -0,0 +1,428 @@ +package dev.dbos.transact.jdbi; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.context.WorkflowOptions; +import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.utils.DBUtils; +import dev.dbos.transact.utils.PgContainer; +import dev.dbos.transact.workflow.Workflow; +import dev.dbos.transact.workflow.WorkflowHandle; + +import java.sql.SQLException; +import java.util.Objects; + +import com.zaxxer.hikari.HikariDataSource; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +interface FactoryTestService { + record TestResult(String user, int greetCount) {} + + TestResult insertWorkflow(String user); + + TestResult errorWorkflow(String user); + + TestResult readWorkflow(String user); + + TestResult insertThenReadWorkflow(String user); +} + +class FactoryTestServiceImpl implements FactoryTestService { + + private final JdbiStepFactory stepFactory; + + public FactoryTestServiceImpl(JdbiStepFactory stepFactory) { + this.stepFactory = stepFactory; + } + + FactoryTestService.TestResult insertGreeting(Handle handle, String user) { + var sql = + """ + INSERT INTO greetings(name, greet_count) + VALUES (?, 1) + ON CONFLICT(name) + DO UPDATE SET greet_count = greetings.greet_count + 1 + RETURNING greet_count + """; + return handle + .createQuery(sql) + .bind(0, Objects.requireNonNull(user)) + .map((rs, ctx) -> new FactoryTestService.TestResult(user, rs.getInt("greet_count"))) + .findFirst() + .orElse(new FactoryTestService.TestResult(user, 0)); + } + + FactoryTestService.TestResult errorGreeting(Handle handle, String user) { + insertGreeting(handle, user); + throw new RuntimeException("Test Exception %d".formatted(System.currentTimeMillis())); + } + + FactoryTestService.TestResult readGreeting(Handle handle, String user) { + var sql = + """ + SELECT greet_count + FROM greetings + WHERE name = ? + """; + return handle + .createQuery(sql) + .bind(0, Objects.requireNonNull(user)) + .map((rs, ctx) -> new FactoryTestService.TestResult(user, rs.getInt("greet_count"))) + .findFirst() + .orElse(new FactoryTestService.TestResult(user, 0)); + } + + @Override + @Workflow + public FactoryTestService.TestResult insertWorkflow(String user) { + return stepFactory.inStep((Handle h) -> insertGreeting(h, user), "insertGreeting"); + } + + @Override + @Workflow + public FactoryTestService.TestResult errorWorkflow(String user) { + return stepFactory.inStep((Handle h) -> errorGreeting(h, user), "errorGreeting"); + } + + @Override + @Workflow + public FactoryTestService.TestResult readWorkflow(String user) { + return stepFactory.inStep((Handle h) -> readGreeting(h, user), "readGreeting"); + } + + @Override + @Workflow + public FactoryTestService.TestResult insertThenReadWorkflow(String user) { + stepFactory.useStep((Handle h) -> insertGreeting(h, user), "insertGreeting"); + return stepFactory.inStep((Handle h) -> readGreeting(h, user), "readGreeting"); + } +} + +public class JdbiStepFactoryTest { + @AutoClose final PgContainer pgContainer = new PgContainer(); + + DBOSConfig dbosConfig; + @AutoClose DBOS dbos; + @AutoClose HikariDataSource dataSource; + JdbiStepFactory stepFactory; + FactoryTestService proxy; + FactoryTestServiceImpl impl; + + @BeforeEach + void beforeEach() throws SQLException { + + dbosConfig = pgContainer.dbosConfig(); + dataSource = pgContainer.dataSource(); + + try (var conn = dataSource.getConnection(); + var stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE greetings(name text NOT NULL, greet_count integer DEFAULT 0, PRIMARY KEY(name))"); + } + + dbos = new DBOS(dbosConfig); + Jdbi jdbi = Jdbi.create(dataSource); + stepFactory = new JdbiStepFactory(dbos, jdbi); + + impl = new FactoryTestServiceImpl(stepFactory); + proxy = dbos.registerProxy(FactoryTestService.class, impl); + + dbos.launch(); + } + + private int getGreetCount(String user) throws SQLException { + var sql = "SELECT greet_count FROM greetings WHERE name = ?"; + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, user); + try (var rs = stmt.executeQuery()) { + return rs.next() ? rs.getInt("greet_count") : 0; + } + } + } + + @Test + public void testInsert() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(wfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNotNull(row.output()); + assertNull(row.error()); + assertEquals(SerializationUtil.NATIVE, row.serialization()); + var output = SerializationUtil.deserializeValue(row.output(), row.serialization(), null); + assertEquals(new FactoryTestService.TestResult(user, 1), output); + + assertEquals(1, getGreetCount(user)); + } + + @Test + public void testError() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + assertThrows(RuntimeException.class, () -> proxy.errorWorkflow(user)); + } + + // Transaction rolled back — no greeting inserted + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(wfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNull(row.output()); + assertNotNull(row.error()); + + assertEquals(0, getGreetCount(user)); + } + + @Test + public void testRead() throws Exception { + var insertWfid = "wf1"; + var readWfid = "wf2"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(insertWfid).setContext()) { + proxy.insertWorkflow(user); + } + + try (var _o = new WorkflowOptions(readWfid).setContext()) { + var result = proxy.readWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + var rows = DBUtils.getTxStepRows(dataSource, readWfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(readWfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNotNull(row.output()); + assertNull(row.error()); + assertEquals(SerializationUtil.NATIVE, row.serialization()); + var output = SerializationUtil.deserializeValue(row.output(), row.serialization(), null); + assertEquals(new FactoryTestService.TestResult(user, 1), output); + + assertEquals(1, getGreetCount(user)); + } + + @Test + public void testIdempotency() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + // Second call with same wfid — txStep output is cached, insert not re-executed + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + assertEquals(1, getGreetCount(user)); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + } + + @Test + public void testRetryError() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + assertThrows(RuntimeException.class, () -> proxy.errorWorkflow(user)); + } + assertEquals(0, getGreetCount(user)); + dbos.close(); + + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + dbos.launch(); + WorkflowHandle handle = + dbos.retrieveWorkflow(wfid); + assertThrows(RuntimeException.class, handle::getResult); + + // Cached error replayed — insert still not committed + assertEquals(0, getGreetCount(user)); + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, txSteps.size()); + assertNull(txSteps.get(0).output()); + assertNotNull(txSteps.get(0).error()); + } + + @Test + public void testMultipleTxSteps() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertThenReadWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + assertEquals(1, getGreetCount(user)); + + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(2, rows.size()); + assertEquals(0, rows.get(0).stepId()); + assertNotNull(rows.get(0).output()); + assertNull(rows.get(0).error()); + assertEquals(1, rows.get(1).stepId()); + assertNotNull(rows.get(1).output()); + assertNull(rows.get(1).error()); + } + + @Test + public void testDistinctWorkflows() throws Exception { + var wfid1 = "wf1"; + var wfid2 = "wf2"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid1).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + } + + try (var _o = new WorkflowOptions(wfid2).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(2, result.greetCount()); + } + + assertEquals(2, getGreetCount(user)); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid1).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid2).size()); + } + + @Test + public void testRetryPartialMultipleSteps() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertThenReadWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + assertEquals(1, getGreetCount(user)); + dbos.close(); + + // Simulate crash after step 0 wrote tx_step_outputs but before step 1 ran: + // both operation_outputs rows are gone, and step 1 has no tx_step_outputs entry + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement( + "DELETE FROM dbos.tx_step_outputs WHERE workflow_id = ? AND step_id = 1")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + var relaunchTimestamp = System.currentTimeMillis(); + dbos.launch(); + WorkflowHandle handle = + dbos.retrieveWorkflow(wfid); + var result = (FactoryTestService.TestResult) handle.getResult(); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + + // Step 0 cache hit — insert not re-executed + assertEquals(1, getGreetCount(user)); + + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(2, txSteps.size()); + assertTrue(txSteps.get(0).createdAt() < relaunchTimestamp); // step 0: original run + assertTrue(txSteps.get(1).createdAt() >= relaunchTimestamp); // step 1: re-executed on retry + } + + @Test + public void testRetryInsert() throws Exception { + var timestamp = System.currentTimeMillis(); + + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + dbos.close(); + + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + var relaunchTimestamp = System.currentTimeMillis(); + dbos.launch(); + var handle = dbos.retrieveWorkflow(wfid); + var result = (FactoryTestService.TestResult) handle.getResult(); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + + var steps = DBUtils.getStepRows(dataSource, wfid); + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, steps.size()); + assertEquals(1, txSteps.size()); + + var step = steps.get(0); + var txStep = txSteps.get(0); + assertEquals(step.output(), txStep.output()); + assertEquals(step.error(), txStep.error()); + + assertTrue(txStep.createdAt() < step.startedAt()); + assertTrue(timestamp < txStep.createdAt()); + assertTrue(txStep.createdAt() < relaunchTimestamp); + assertTrue(relaunchTimestamp < step.startedAt()); + + // Retry reads from tx_step_outputs cache — insert not re-executed + assertEquals(1, getGreetCount(user)); + } +} diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/DBUtils.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/DBUtils.java new file mode 100644 index 00000000..7f16faa1 --- /dev/null +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/DBUtils.java @@ -0,0 +1,69 @@ +package dev.dbos.transact.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import dev.dbos.transact.database.SystemDatabase; + +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import javax.sql.DataSource; + +public class DBUtils { + + public static List getTxStepRows(DataSource ds, String workflowId) + throws SQLException { + var sql = + "SELECT * FROM \"%s\".tx_step_outputs WHERE workflow_id = ? ORDER BY step_id" + .formatted(SystemDatabase.sanitizeSchema(null)); + try (var conn = ds.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, Objects.requireNonNull(workflowId)); + try (var rs = stmt.executeQuery()) { + List rows = new ArrayList<>(); + while (rs.next()) { + rows.add(new TxStepOutputRow(rs)); + } + return rows; + } + } + } + + public static List getStepRows(DataSource ds, String workflowId) + throws SQLException { + var sql = + "SELECT * FROM \"%s\".operation_outputs WHERE workflow_uuid = ? ORDER BY function_id" + .formatted(SystemDatabase.sanitizeSchema(null)); + try (var conn = ds.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + try (var rs = stmt.executeQuery()) { + List rows = new ArrayList<>(); + while (rs.next()) { + rows.add(new OperationOutputRow(rs)); + } + return rows; + } + } + } + + public static void setWorkflowState(DataSource ds, String workflowId, String newState) + throws SQLException { + String sql = + "UPDATE dbos.workflow_status SET status = ?, updated_at = ? WHERE workflow_uuid = ?"; + + try (var connection = ds.getConnection(); + PreparedStatement pstmt = connection.prepareStatement(sql)) { + pstmt.setString(1, newState); + pstmt.setLong(2, Instant.now().toEpochMilli()); + pstmt.setString(3, workflowId); + + int rowsAffected = pstmt.executeUpdate(); + assertEquals(1, rowsAffected); + } + } +} diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/OperationOutputRow.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/OperationOutputRow.java new file mode 100644 index 00000000..c1ac2938 --- /dev/null +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/OperationOutputRow.java @@ -0,0 +1,27 @@ +package dev.dbos.transact.utils; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public record OperationOutputRow( + String workflowId, + int functionId, + String output, + String error, + String functionName, + String childWorkflowId, + Long startedAt, + Long completedAt) { + + public OperationOutputRow(ResultSet rs) throws SQLException { + this( + rs.getString("workflow_uuid"), + rs.getInt("function_id"), + rs.getString("output"), + rs.getString("error"), + rs.getString("function_name"), + rs.getString("child_workflow_id"), + rs.getObject("started_at_epoch_ms", Long.class), + rs.getObject("completed_at_epoch_ms", Long.class)); + } +} diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java new file mode 100644 index 00000000..2c1431dd --- /dev/null +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -0,0 +1,117 @@ +package dev.dbos.transact.utils; + +import dev.dbos.transact.DBOSClient; +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.database.SystemDatabase; + +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Semaphore; + +import com.zaxxer.hikari.HikariDataSource; +import org.testcontainers.postgresql.PostgreSQLContainer; + +public class PgContainer implements AutoCloseable { + + private static final int SIZE = Runtime.getRuntime().availableProcessors(); + private static final BlockingQueue POOL = new ArrayBlockingQueue<>(SIZE); + private static final Semaphore PERMITS = new Semaphore(SIZE); + + static { + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + var containers = new ArrayList(); + POOL.drainTo(containers); + containers.forEach(PostgreSQLContainer::stop); + })); + } + + static PostgreSQLContainer acquire() { + try { + PERMITS.acquire(); + var container = POOL.poll(); + if (container == null) { + container = new PostgreSQLContainer("postgres:18"); + container.start(); + } + return container; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + static void release(PostgreSQLContainer c) { + POOL.offer(c); + PERMITS.release(); + } + + private final PostgreSQLContainer pgContainer; + private final String jdbcUrl; + private final String dbName; + + public PgContainer() { + // take a container from the pool and create a new database for it + pgContainer = acquire(); + dbName = "test_" + UUID.randomUUID().toString().replace("-", ""); + jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + dbName); + + try (var conn = + DriverManager.getConnection( + pgContainer.getJdbcUrl(), pgContainer.getUsername(), pgContainer.getPassword()); + var stmt = conn.createStatement()) { + stmt.execute("CREATE DATABASE " + dbName); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() throws Exception { + // drop the database we created and return the container too the pool + var _jdbcUrl = pgContainer.getJdbcUrl(); + try (var conn = DriverManager.getConnection(_jdbcUrl, username(), password()); + var stmt = conn.createStatement()) { + var sql = "DROP DATABASE IF EXISTS %s WITH (FORCE)".formatted(dbName); + stmt.execute(sql); + } + release(pgContainer); + } + + public String jdbcUrl() { + return jdbcUrl; + } + + public String username() { + return pgContainer.getUsername(); + } + + public String password() { + return pgContainer.getPassword(); + } + + public DBOSConfig dbosConfig() { + return dbosConfig(null); + } + + public DBOSConfig dbosConfig(String appName) { + return DBOSConfig.defaults(Objects.requireNonNullElse(appName, "transact-java-test")) + .withDatabaseUrl(jdbcUrl()) + .withDbUser(username()) + .withDbPassword(password()); + } + + public HikariDataSource dataSource() { + return SystemDatabase.createDataSource(jdbcUrl(), username(), password()); + } + + public DBOSClient dbosClient() { + return new DBOSClient(jdbcUrl(), username(), password()); + } +} diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java new file mode 100644 index 00000000..7472fe31 --- /dev/null +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java @@ -0,0 +1,23 @@ +package dev.dbos.transact.utils; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public record TxStepOutputRow( + String workflowId, + int stepId, + String output, + String error, + String serialization, + Long createdAt) { + + public TxStepOutputRow(ResultSet rs) throws SQLException { + this( + rs.getString("workflow_id"), + rs.getInt("step_id"), + rs.getString("output"), + rs.getString("error"), + rs.getString("serialization"), + rs.getObject("created_at", Long.class)); + } +} diff --git a/transact-jooq-step-factory/build.gradle.kts b/transact-jooq-step-factory/build.gradle.kts new file mode 100644 index 00000000..e8377772 --- /dev/null +++ b/transact-jooq-step-factory/build.gradle.kts @@ -0,0 +1,23 @@ +plugins { id("java-library") } + +tasks.withType { + options.compilerArgs.add("-Xlint:unchecked") + options.compilerArgs.add("-Xlint:deprecation") + options.compilerArgs.add("-Xlint:rawtypes") + options.compilerArgs.add("-Werror") +} + +dependencies { + api(project(":transact")) + api(libs.jooq) + compileOnly(libs.jaxb.api) + + testImplementation(platform(libs.junit.bom)) + testImplementation(libs.junit.jupiter) + testRuntimeOnly(libs.junit.platform.launcher) + + testRuntimeOnly(libs.logback.classic) + testImplementation(libs.testcontainers.postgresql) + testImplementation(libs.postgresql) + testImplementation(libs.hikaricp) +} diff --git a/transact-jooq-step-factory/src/main/java/dev/dbos/transact/jooq/JooqStepFactory.java b/transact-jooq-step-factory/src/main/java/dev/dbos/transact/jooq/JooqStepFactory.java new file mode 100644 index 00000000..225fb851 --- /dev/null +++ b/transact-jooq-step-factory/src/main/java/dev/dbos/transact/jooq/JooqStepFactory.java @@ -0,0 +1,164 @@ +package dev.dbos.transact.jooq; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.json.DBOSSerializer; +import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.txstep.PostgresStepFactory; +import dev.dbos.transact.workflow.internal.StepResult; + +import java.util.Objects; +import java.util.Optional; + +import org.jooq.Configuration; +import org.jooq.DSLContext; +import org.jooq.TransactionalCallable; +import org.jooq.TransactionalRunnable; + +/** + * Runs idempotent transactional steps inside DBOS workflows using jOOQ {@link DSLContext} objects. + * + *

Construct one with a {@link DSLContext} connected to a PostgreSQL database. The constructor + * verifies the datasource is PostgreSQL and creates the {@code tx_step_outputs} table if needed. + * Lambdas passed to {@link #txStepResult} or {@link #txStep} receive a jOOQ {@link + * org.jooq.Configuration} with a transaction already open; they must not commit or close the + * underlying connection themselves. + * + *

{@code
+ * JooqStepFactory factory = new JooqStepFactory(dbos, dslContext);
+ *
+ * // inside a @Workflow method:
+ * int count = factory.txStepResult(trx -> {
+ *     return trx.dsl().insertInto(...).execute();
+ * }, "myStep");
+ * }
+ */ +public class JooqStepFactory extends PostgresStepFactory { + + private final DSLContext dsl; + + /** Creates a factory using the schema from the DBOS config. */ + public JooqStepFactory(DBOS dbos, DSLContext dsl) { + this(dbos, dsl, null, null); + } + + /** Creates a factory using a custom schema for {@code tx_step_outputs}. */ + public JooqStepFactory(DBOS dbos, DSLContext dsl, String schema) { + this(dbos, dsl, schema, null); + } + + /** Creates a factory using a custom serializer. */ + public JooqStepFactory(DBOS dbos, DSLContext dsl, DBOSSerializer serializer) { + this(dbos, dsl, null, serializer); + } + + /** + * Creates a factory with a custom schema and serializer. + * + *

Connects to the database immediately to verify it is PostgreSQL and to create the {@code + * tx_step_outputs} table in the given schema if it does not already exist. + * + * @param dbos the DBOS runtime instance + * @param dsl a DSLContext connected to a PostgreSQL database + * @param schema the PostgreSQL schema to use for {@code tx_step_outputs}; {@code null} uses the + * schema from {@code dbos} configuration + * @param serializer the serializer to use for step outputs; {@code null} uses the serializer from + * {@code dbos} configuration + * @throws RuntimeException if the datasource is not PostgreSQL or the schema setup fails + */ + public JooqStepFactory(DBOS dbos, DSLContext dsl, String schema, DBOSSerializer serializer) { + super(dbos, schema, serializer, () -> dsl.configuration().connectionProvider().acquire()); + this.dsl = Objects.requireNonNull(dsl); + } + + /** + * Executes {@code callback} as an idempotent DBOS step inside a jOOQ transaction. + * + *

If a result for this step is already recorded (e.g. on workflow retry), the callback is + * skipped and the cached result is returned. Otherwise the callback runs inside an open + * transaction; the output is recorded atomically with the database work so the step is + * exactly-once on success. + * + * @param the return type of the callback + * @param callback the database work to perform; receives a jOOQ {@link org.jooq.Configuration} + * with an open transaction and must not commit or close the underlying connection + * @param stepName a stable name that identifies this step within the workflow + * @return the value returned by {@code callback} + */ + public T txStepResult(TransactionalCallable callback, String stepName) { + return runTxStep( + (wfId, stepId) -> + dsl.transactionResult( + trx -> { + var result = callback.run(trx); + recordOutput(trx, wfId, stepId, result); + return result; + }), + stepName); + } + + /** + * Executes {@code transactional} as an idempotent DBOS step inside a jOOQ transaction, with no + * return value. + * + *

Behaves identically to {@link #txStepResult} but accepts a {@link TransactionalRunnable} for + * callers that do not need to return a result. + * + * @param transactional the database work to perform; receives a jOOQ {@link + * org.jooq.Configuration} with an open transaction and must not commit or close the + * underlying connection + * @param stepName a stable name that identifies this step within the workflow + */ + public void txStep(TransactionalRunnable transactional, String stepName) { + txStepResult( + c -> { + transactional.run(c); + return null; + }, + stepName); + } + + @Override + protected Optional checkExecution(String workflowId, int stepId, String stepName) { + return dsl.fetchOptional(checkSql(), workflowId, stepId) + .map( + r -> + new StepResult( + workflowId, + stepId, + stepName, + r.get("output", String.class), + r.get("error", String.class), + null, + r.get("serialization", String.class))); + } + + private void recordOutput(Configuration trx, String workflowId, int stepId, R result) { + var value = SerializationUtil.serializeValue(result, null, serializer); + recordResult( + trx.dsl(), workflowId, stepId, value.serializedValue(), null, value.serialization()); + } + + @Override + protected void recordError(String workflowId, int stepId, Exception exception) { + var value = SerializationUtil.serializeError(exception, null, serializer); + dsl.transaction( + trx -> + recordResult( + trx.dsl(), + workflowId, + stepId, + null, + value.serializedValue(), + value.serialization())); + } + + private void recordResult( + DSLContext ctx, + String workflowId, + int stepId, + String output, + String error, + String serialization) { + ctx.execute(upsertSql(), workflowId, stepId, output, error, serialization); + } +} diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/jooq/JooqStepFactoryTest.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/jooq/JooqStepFactoryTest.java new file mode 100644 index 00000000..2f5846ff --- /dev/null +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/jooq/JooqStepFactoryTest.java @@ -0,0 +1,417 @@ +package dev.dbos.transact.jooq; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.context.WorkflowOptions; +import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.utils.DBUtils; +import dev.dbos.transact.utils.PgContainer; +import dev.dbos.transact.workflow.Workflow; +import dev.dbos.transact.workflow.WorkflowHandle; + +import java.sql.SQLException; +import java.util.Objects; + +import com.zaxxer.hikari.HikariDataSource; +import org.jooq.DSLContext; +import org.jooq.SQLDialect; +import org.jooq.impl.DSL; +import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +interface FactoryTestService { + record TestResult(String user, int greetCount) {} + + TestResult insertWorkflow(String user); + + TestResult errorWorkflow(String user); + + TestResult readWorkflow(String user); + + TestResult insertThenReadWorkflow(String user); +} + +class FactoryTestServiceImpl implements FactoryTestService { + + private final JooqStepFactory stepFactory; + + public FactoryTestServiceImpl(JooqStepFactory stepFactory) { + this.stepFactory = stepFactory; + } + + FactoryTestService.TestResult insertGreeting(DSLContext ctx, String user) { + var sql = + """ + INSERT INTO greetings(name, greet_count) + VALUES (?, 1) + ON CONFLICT(name) + DO UPDATE SET greet_count = greetings.greet_count + 1 + RETURNING greet_count + """; + var record = ctx.fetchOne(sql, Objects.requireNonNull(user)); + int greetCount = record != null ? record.get("greet_count", Integer.class) : 0; + return new FactoryTestService.TestResult(user, greetCount); + } + + FactoryTestService.TestResult errorGreeting(DSLContext ctx, String user) { + insertGreeting(ctx, user); + throw new RuntimeException("Test Exception %d".formatted(System.currentTimeMillis())); + } + + FactoryTestService.TestResult readGreeting(DSLContext ctx, String user) { + var sql = "SELECT greet_count FROM greetings WHERE name = ?"; + var record = ctx.fetchOne(sql, Objects.requireNonNull(user)); + int greetCount = record != null ? record.get("greet_count", Integer.class) : 0; + return new FactoryTestService.TestResult(user, greetCount); + } + + @Override + @Workflow + public FactoryTestService.TestResult insertWorkflow(String user) { + return stepFactory.txStepResult(ctx -> insertGreeting(ctx.dsl(), user), "insertGreeting"); + } + + @Override + @Workflow + public FactoryTestService.TestResult errorWorkflow(String user) { + return stepFactory.txStepResult(ctx -> errorGreeting(ctx.dsl(), user), "errorGreeting"); + } + + @Override + @Workflow + public FactoryTestService.TestResult readWorkflow(String user) { + return stepFactory.txStepResult(ctx -> readGreeting(ctx.dsl(), user), "readGreeting"); + } + + @Override + @Workflow + public FactoryTestService.TestResult insertThenReadWorkflow(String user) { + stepFactory.txStep(ctx -> insertGreeting(ctx.dsl(), user), "insertGreeting"); + return stepFactory.txStepResult(ctx -> readGreeting(ctx.dsl(), user), "readGreeting"); + } +} + +public class JooqStepFactoryTest { + @AutoClose final PgContainer pgContainer = new PgContainer(); + + DBOSConfig dbosConfig; + @AutoClose DBOS dbos; + @AutoClose HikariDataSource dataSource; + JooqStepFactory stepFactory; + FactoryTestService proxy; + FactoryTestServiceImpl impl; + + @BeforeEach + void beforeEach() throws SQLException { + dbosConfig = pgContainer.dbosConfig(); + dataSource = pgContainer.dataSource(); + + try (var conn = dataSource.getConnection(); + var stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE greetings(name text NOT NULL, greet_count integer DEFAULT 0, PRIMARY KEY(name))"); + } + + dbos = new DBOS(dbosConfig); + DSLContext dsl = DSL.using(dataSource, SQLDialect.POSTGRES); + stepFactory = new JooqStepFactory(dbos, dsl); + + impl = new FactoryTestServiceImpl(stepFactory); + proxy = dbos.registerProxy(FactoryTestService.class, impl); + + dbos.launch(); + } + + private int getGreetCount(String user) throws SQLException { + var sql = "SELECT greet_count FROM greetings WHERE name = ?"; + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, user); + try (var rs = stmt.executeQuery()) { + return rs.next() ? rs.getInt("greet_count") : 0; + } + } + } + + @Test + public void testInsert() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(wfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNotNull(row.output()); + assertNull(row.error()); + assertEquals(SerializationUtil.NATIVE, row.serialization()); + var output = SerializationUtil.deserializeValue(row.output(), row.serialization(), null); + assertEquals(new FactoryTestService.TestResult(user, 1), output); + + assertEquals(1, getGreetCount(user)); + } + + @Test + public void testError() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + assertThrows(RuntimeException.class, () -> proxy.errorWorkflow(user)); + } + + // Transaction rolled back — no greeting inserted + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(wfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNull(row.output()); + assertNotNull(row.error()); + + assertEquals(0, getGreetCount(user)); + } + + @Test + public void testRead() throws Exception { + var insertWfid = "wf1"; + var readWfid = "wf2"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(insertWfid).setContext()) { + proxy.insertWorkflow(user); + } + + try (var _o = new WorkflowOptions(readWfid).setContext()) { + var result = proxy.readWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + var rows = DBUtils.getTxStepRows(dataSource, readWfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(readWfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNotNull(row.output()); + assertNull(row.error()); + assertEquals(SerializationUtil.NATIVE, row.serialization()); + var output = SerializationUtil.deserializeValue(row.output(), row.serialization(), null); + assertEquals(new FactoryTestService.TestResult(user, 1), output); + + assertEquals(1, getGreetCount(user)); + } + + @Test + public void testIdempotency() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + // Second call with same wfid — txStep output is cached, insert not re-executed + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + assertEquals(1, getGreetCount(user)); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + } + + @Test + public void testRetryError() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + assertThrows(RuntimeException.class, () -> proxy.errorWorkflow(user)); + } + assertEquals(0, getGreetCount(user)); + dbos.close(); + + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + dbos.launch(); + WorkflowHandle handle = + dbos.retrieveWorkflow(wfid); + assertThrows(RuntimeException.class, handle::getResult); + + // Cached error replayed — insert still not committed + assertEquals(0, getGreetCount(user)); + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, txSteps.size()); + assertNull(txSteps.get(0).output()); + assertNotNull(txSteps.get(0).error()); + } + + @Test + public void testMultipleTxSteps() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertThenReadWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + assertEquals(1, getGreetCount(user)); + + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(2, rows.size()); + assertEquals(0, rows.get(0).stepId()); + assertNotNull(rows.get(0).output()); + assertNull(rows.get(0).error()); + assertEquals(1, rows.get(1).stepId()); + assertNotNull(rows.get(1).output()); + assertNull(rows.get(1).error()); + } + + @Test + public void testDistinctWorkflows() throws Exception { + var wfid1 = "wf1"; + var wfid2 = "wf2"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid1).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + } + + try (var _o = new WorkflowOptions(wfid2).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(2, result.greetCount()); + } + + assertEquals(2, getGreetCount(user)); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid1).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid2).size()); + } + + @Test + public void testRetryPartialMultipleSteps() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertThenReadWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + assertEquals(1, getGreetCount(user)); + dbos.close(); + + // Simulate crash after step 0 wrote tx_step_outputs but before step 1 ran: + // both operation_outputs rows are gone, and step 1 has no tx_step_outputs entry + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement( + "DELETE FROM dbos.tx_step_outputs WHERE workflow_id = ? AND step_id = 1")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + var relaunchTimestamp = System.currentTimeMillis(); + dbos.launch(); + WorkflowHandle handle = + dbos.retrieveWorkflow(wfid); + var result = (FactoryTestService.TestResult) handle.getResult(); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + + // Step 0 cache hit — insert not re-executed + assertEquals(1, getGreetCount(user)); + + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(2, txSteps.size()); + assertTrue(txSteps.get(0).createdAt() < relaunchTimestamp); // step 0: original run + assertTrue(txSteps.get(1).createdAt() >= relaunchTimestamp); // step 1: re-executed on retry + } + + @Test + public void testRetryInsert() throws Exception { + var timestamp = System.currentTimeMillis(); + + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + dbos.close(); + + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + var relaunchTimestamp = System.currentTimeMillis(); + dbos.launch(); + var handle = dbos.retrieveWorkflow(wfid); + var result = (FactoryTestService.TestResult) handle.getResult(); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + + var steps = DBUtils.getStepRows(dataSource, wfid); + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, steps.size()); + assertEquals(1, txSteps.size()); + + var step = steps.get(0); + var txStep = txSteps.get(0); + assertEquals(step.output(), txStep.output()); + assertEquals(step.error(), txStep.error()); + + assertTrue(txStep.createdAt() < step.startedAt()); + assertTrue(timestamp < txStep.createdAt()); + assertTrue(txStep.createdAt() < relaunchTimestamp); + assertTrue(relaunchTimestamp < step.startedAt()); + + // Retry reads from tx_step_outputs cache — insert not re-executed + assertEquals(1, getGreetCount(user)); + } +} diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/DBUtils.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/DBUtils.java new file mode 100644 index 00000000..7f16faa1 --- /dev/null +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/DBUtils.java @@ -0,0 +1,69 @@ +package dev.dbos.transact.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import dev.dbos.transact.database.SystemDatabase; + +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import javax.sql.DataSource; + +public class DBUtils { + + public static List getTxStepRows(DataSource ds, String workflowId) + throws SQLException { + var sql = + "SELECT * FROM \"%s\".tx_step_outputs WHERE workflow_id = ? ORDER BY step_id" + .formatted(SystemDatabase.sanitizeSchema(null)); + try (var conn = ds.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, Objects.requireNonNull(workflowId)); + try (var rs = stmt.executeQuery()) { + List rows = new ArrayList<>(); + while (rs.next()) { + rows.add(new TxStepOutputRow(rs)); + } + return rows; + } + } + } + + public static List getStepRows(DataSource ds, String workflowId) + throws SQLException { + var sql = + "SELECT * FROM \"%s\".operation_outputs WHERE workflow_uuid = ? ORDER BY function_id" + .formatted(SystemDatabase.sanitizeSchema(null)); + try (var conn = ds.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + try (var rs = stmt.executeQuery()) { + List rows = new ArrayList<>(); + while (rs.next()) { + rows.add(new OperationOutputRow(rs)); + } + return rows; + } + } + } + + public static void setWorkflowState(DataSource ds, String workflowId, String newState) + throws SQLException { + String sql = + "UPDATE dbos.workflow_status SET status = ?, updated_at = ? WHERE workflow_uuid = ?"; + + try (var connection = ds.getConnection(); + PreparedStatement pstmt = connection.prepareStatement(sql)) { + pstmt.setString(1, newState); + pstmt.setLong(2, Instant.now().toEpochMilli()); + pstmt.setString(3, workflowId); + + int rowsAffected = pstmt.executeUpdate(); + assertEquals(1, rowsAffected); + } + } +} diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/OperationOutputRow.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/OperationOutputRow.java new file mode 100644 index 00000000..c1ac2938 --- /dev/null +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/OperationOutputRow.java @@ -0,0 +1,27 @@ +package dev.dbos.transact.utils; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public record OperationOutputRow( + String workflowId, + int functionId, + String output, + String error, + String functionName, + String childWorkflowId, + Long startedAt, + Long completedAt) { + + public OperationOutputRow(ResultSet rs) throws SQLException { + this( + rs.getString("workflow_uuid"), + rs.getInt("function_id"), + rs.getString("output"), + rs.getString("error"), + rs.getString("function_name"), + rs.getString("child_workflow_id"), + rs.getObject("started_at_epoch_ms", Long.class), + rs.getObject("completed_at_epoch_ms", Long.class)); + } +} diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java new file mode 100644 index 00000000..2c1431dd --- /dev/null +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -0,0 +1,117 @@ +package dev.dbos.transact.utils; + +import dev.dbos.transact.DBOSClient; +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.database.SystemDatabase; + +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Semaphore; + +import com.zaxxer.hikari.HikariDataSource; +import org.testcontainers.postgresql.PostgreSQLContainer; + +public class PgContainer implements AutoCloseable { + + private static final int SIZE = Runtime.getRuntime().availableProcessors(); + private static final BlockingQueue POOL = new ArrayBlockingQueue<>(SIZE); + private static final Semaphore PERMITS = new Semaphore(SIZE); + + static { + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + var containers = new ArrayList(); + POOL.drainTo(containers); + containers.forEach(PostgreSQLContainer::stop); + })); + } + + static PostgreSQLContainer acquire() { + try { + PERMITS.acquire(); + var container = POOL.poll(); + if (container == null) { + container = new PostgreSQLContainer("postgres:18"); + container.start(); + } + return container; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + static void release(PostgreSQLContainer c) { + POOL.offer(c); + PERMITS.release(); + } + + private final PostgreSQLContainer pgContainer; + private final String jdbcUrl; + private final String dbName; + + public PgContainer() { + // take a container from the pool and create a new database for it + pgContainer = acquire(); + dbName = "test_" + UUID.randomUUID().toString().replace("-", ""); + jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + dbName); + + try (var conn = + DriverManager.getConnection( + pgContainer.getJdbcUrl(), pgContainer.getUsername(), pgContainer.getPassword()); + var stmt = conn.createStatement()) { + stmt.execute("CREATE DATABASE " + dbName); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() throws Exception { + // drop the database we created and return the container too the pool + var _jdbcUrl = pgContainer.getJdbcUrl(); + try (var conn = DriverManager.getConnection(_jdbcUrl, username(), password()); + var stmt = conn.createStatement()) { + var sql = "DROP DATABASE IF EXISTS %s WITH (FORCE)".formatted(dbName); + stmt.execute(sql); + } + release(pgContainer); + } + + public String jdbcUrl() { + return jdbcUrl; + } + + public String username() { + return pgContainer.getUsername(); + } + + public String password() { + return pgContainer.getPassword(); + } + + public DBOSConfig dbosConfig() { + return dbosConfig(null); + } + + public DBOSConfig dbosConfig(String appName) { + return DBOSConfig.defaults(Objects.requireNonNullElse(appName, "transact-java-test")) + .withDatabaseUrl(jdbcUrl()) + .withDbUser(username()) + .withDbPassword(password()); + } + + public HikariDataSource dataSource() { + return SystemDatabase.createDataSource(jdbcUrl(), username(), password()); + } + + public DBOSClient dbosClient() { + return new DBOSClient(jdbcUrl(), username(), password()); + } +} diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java new file mode 100644 index 00000000..7472fe31 --- /dev/null +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java @@ -0,0 +1,23 @@ +package dev.dbos.transact.utils; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public record TxStepOutputRow( + String workflowId, + int stepId, + String output, + String error, + String serialization, + Long createdAt) { + + public TxStepOutputRow(ResultSet rs) throws SQLException { + this( + rs.getString("workflow_id"), + rs.getInt("step_id"), + rs.getString("output"), + rs.getString("error"), + rs.getString("serialization"), + rs.getObject("created_at", Long.class)); + } +} diff --git a/transact/src/main/java/dev/dbos/transact/DBOS.java b/transact/src/main/java/dev/dbos/transact/DBOS.java index a6d1a4ca..7278a881 100644 --- a/transact/src/main/java/dev/dbos/transact/DBOS.java +++ b/transact/src/main/java/dev/dbos/transact/DBOS.java @@ -62,9 +62,7 @@ public class DBOS implements AutoCloseable { private final Set lifecycleRegistry = ConcurrentHashMap.newKeySet(); private final DBOSConfig config; private final AtomicReference dbosExecutor = new AtomicReference<>(); - private final DBOSIntegration integration = - new DBOSIntegration( - dbosExecutor::get, this::registerLifecycleListener, this::registerWorkflow); + private final DBOSIntegration integration; private AlertHandler alertHandler; @@ -83,7 +81,13 @@ public DBOS(@NonNull DBOSConfig config) { Objects.requireNonNull(config.dbPassword(), "DBOSConfig.dbPassword must not be null"); } - this.config = config; + this.config = new DBOSConfig(config); + this.integration = + new DBOSIntegration( + this.config, + dbosExecutor::get, + this::registerLifecycleListener, + this::registerWorkflow); } /** diff --git a/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java b/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java index ac00eb2e..79b06be5 100644 --- a/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java +++ b/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java @@ -58,6 +58,31 @@ public record DBOSConfig( listenQueues = (listenQueues == null) ? Set.of() : Set.copyOf(listenQueues); } + // Copy constructor + public DBOSConfig(DBOSConfig other) { + this( + other.appName, + other.databaseUrl, + other.dbUser, + other.dbPassword, + other.dataSource, + other.adminServer, + other.adminServerPort, + other.migrate, + other.conductorKey, + other.conductorDomain, + (other.conductorExecutorMetadata == null + ? null + : Map.copyOf(other.conductorExecutorMetadata)), + other.appVersion, + other.executorId, + other.databaseSchema, + other.enablePatching, + (other.listenQueues == null ? null : Set.copyOf(other.listenQueues)), + other.serializer, + other.schedulerPollingInterval); + } + public static @NonNull DBOSConfig defaults(@NonNull String appName) { return new DBOSConfig( appName, null, null, null, null, false, // adminServer diff --git a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java index 66af3d26..f5cf67eb 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java @@ -69,7 +69,7 @@ static void recordStepResultTxn( try (PreparedStatement pstmt = connection.prepareStatement(sql)) { pstmt.setString(1, result.workflowId()); pstmt.setInt(2, result.stepId()); - pstmt.setString(3, result.functionName()); + pstmt.setString(3, result.stepName()); if (result.output() != null) { pstmt.setString(4, result.output()); @@ -99,7 +99,7 @@ static void recordStepResultTxn( logger.warn( String.format( "Step output for %s:%d-%s was already recorded", - result.workflowId(), result.stepId(), result.functionName())); + result.workflowId(), result.stepId(), result.stepName())); throw new DBOSWorkflowExecutionConflictException(result.workflowId()); } } diff --git a/transact/src/main/java/dev/dbos/transact/internal/DBOSIntegration.java b/transact/src/main/java/dev/dbos/transact/internal/DBOSIntegration.java index 32e708cc..5338bedf 100644 --- a/transact/src/main/java/dev/dbos/transact/internal/DBOSIntegration.java +++ b/transact/src/main/java/dev/dbos/transact/internal/DBOSIntegration.java @@ -1,6 +1,7 @@ package dev.dbos.transact.internal; import dev.dbos.transact.StartWorkflowOptions; +import dev.dbos.transact.config.DBOSConfig; import dev.dbos.transact.database.ExternalState; import dev.dbos.transact.execution.DBOSExecutor; import dev.dbos.transact.execution.DBOSLifecycleListener; @@ -35,14 +36,17 @@ public interface RegisteredWorkflowConsumer { void register(Workflow wfTag, Object target, Method method, String instanceName); } + private final DBOSConfig config; private final Supplier executorSupplier; private final Consumer listenerConsumer; private final RegisteredWorkflowConsumer workflowConsumer; public DBOSIntegration( + @NonNull DBOSConfig config, @NonNull Supplier executorSupplier, @NonNull Consumer lifecycleConsumer, @NonNull RegisteredWorkflowConsumer workflowConsumer) { + this.config = Objects.requireNonNull(config); this.executorSupplier = Objects.requireNonNull(executorSupplier); this.listenerConsumer = Objects.requireNonNull(lifecycleConsumer); this.workflowConsumer = Objects.requireNonNull(workflowConsumer); @@ -57,6 +61,10 @@ private DBOSExecutor executor(String caller) { return exec; } + public DBOSConfig config() { + return this.config; + } + /** * Register a lifecycle listener that receives callbacks when DBOS is launched or shut down * diff --git a/transact/src/main/java/dev/dbos/transact/txstep/JdbcStepFactory.java b/transact/src/main/java/dev/dbos/transact/txstep/JdbcStepFactory.java new file mode 100644 index 00000000..da17e8ef --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/txstep/JdbcStepFactory.java @@ -0,0 +1,234 @@ +package dev.dbos.transact.txstep; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.execution.ThrowingRunnable; +import dev.dbos.transact.execution.ThrowingSupplier; +import dev.dbos.transact.json.DBOSSerializer; +import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.workflow.internal.StepResult; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Objects; +import java.util.Optional; + +import javax.sql.DataSource; + +/** + * A StepFactory implementation backed by plain JDBC {@link Connection} objects. + * + *

Construct one with a {@link DataSource} pointing at a PostgreSQL database. The constructor + * verifies the datasource is PostgreSQL and creates the {@code tx_step_outputs} table if needed. + * User lambdas passed to {@code txStep} receive a {@link Connection} with a transaction already + * started; they should not call {@code commit} or {@code close} themselves. + * + *

{@code
+ * JdbcStepFactory factory = new JdbcStepFactory(dbos, dataSource);
+ *
+ * // inside a @Workflow method:
+ * int count = factory.txStep(conn -> {
+ *     try (var stmt = conn.prepareStatement("INSERT INTO ...")) { ... }
+ *     return rowCount;
+ * }, "myStep");
+ * }
+ */ +public class JdbcStepFactory extends PostgresStepFactory { + + private final DataSource dataSource; + + /** Creates a factory using the schema from the DBOS config. */ + public JdbcStepFactory(DBOS dbos, DataSource dataSource) { + this(dbos, dataSource, null, null); + } + + /** Creates a factory using a custom schema for {@code tx_step_outputs}. */ + public JdbcStepFactory(DBOS dbos, DataSource dataSource, String schema) { + this(dbos, dataSource, schema, null); + } + + /** Creates a factory using a custom serializer. */ + public JdbcStepFactory(DBOS dbos, DataSource dataSource, DBOSSerializer serializer) { + this(dbos, dataSource, null, serializer); + } + + /** + * Creates a factory with a custom schema and serializer. + * + *

Connects to the database immediately to verify it is PostgreSQL and to create the {@code + * tx_step_outputs} table in the given schema if it does not already exist. + * + * @param dbos the DBOS runtime instance + * @param dataSource a DataSource connected to a PostgreSQL database + * @param schema the PostgreSQL schema to use for {@code tx_step_outputs}; {@code null} uses the + * schema from {@code dbos} configuration + * @param serializer the serializer to use for step outputs; {@code null} uses the serializer from + * {@code dbos} configuration + * @throws RuntimeException if the datasource is not PostgreSQL or the schema setup fails + */ + public JdbcStepFactory( + DBOS dbos, DataSource dataSource, String schema, DBOSSerializer serializer) { + super(dbos, schema, serializer, dataSource::getConnection); + this.dataSource = Objects.requireNonNull(dataSource); + } + + @Override + protected Optional checkExecution(String workflowId, int stepId, String stepName) { + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(checkSql())) { + stmt.setString(1, workflowId); + stmt.setInt(2, stepId); + try (var rs = stmt.executeQuery()) { + if (!rs.next()) return Optional.empty(); + return Optional.of( + new StepResult( + workflowId, + stepId, + stepName, + rs.getString("output"), + rs.getString("error"), + null, + rs.getString("serialization"))); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + /** Database work that runs inside a JDBC transaction and returns a result. */ + @FunctionalInterface + public interface TransactionalFunction { + R execute(Connection conn) throws X; + } + + /** + * Executes {@code callback} as an idempotent DBOS step inside a JDBC transaction. + * + *

If a result for this step is already recorded (e.g. on workflow retry), the callback is + * skipped and the cached result is returned. Otherwise the callback runs inside an open + * transaction; the output is recorded atomically with the database work so the step is + * exactly-once on success. + * + * @param the return type of the callback + * @param the checked exception type the callback may throw + * @param callback the database work to perform; receives an open {@link Connection} and must not + * commit or close it + * @param stepName a stable name that identifies this step within the workflow + * @return the value returned by {@code callback} + * @throws X if the callback throws + */ + public R txStep( + final TransactionalFunction callback, String stepName) throws X { + return runTxStep( + (wfId, stepId) -> + executeTransaction( + dataSource, + c -> { + var result = callback.execute(c); + recordOutput(c, wfId, stepId, result); + return result; + }), + stepName); + } + + /** Database work that runs inside a JDBC transaction without returning a result. */ + @FunctionalInterface + public interface TransactionalRunnable { + void execute(Connection conn) throws X; + } + + /** + * Executes {@code callback} as an idempotent DBOS step inside a JDBC transaction, with no return + * value. + * + *

Behaves identically to {@link #txStep(TransactionalFunction, String)} but accepts a {@link + * TransactionalRunnable} for callers that do not need to return a result. + * + * @param the checked exception type the callback may throw + * @param callback the database work to perform; receives an open {@link Connection} and must not + * commit or close it + * @param stepName a stable name that identifies this step within the workflow + * @throws X if the callback throws + */ + public void txStep(final TransactionalRunnable callback, String stepName) + throws X { + txStep( + c -> { + callback.execute(c); + return null; + }, + stepName); + } + + private static R executeTransaction( + final DataSource ds, TransactionalFunction func) throws X { + var conn = + safeGet( + () -> { + var c = ds.getConnection(); + c.setAutoCommit(false); + return c; + }); + try { + var result = func.execute(conn); + safely(conn::commit); + return result; + } catch (Exception e) { + safely(conn::rollback); + throw e; + } finally { + safely(conn::close); + } + } + + private static void safely(ThrowingRunnable op) { + try { + op.execute(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private static T safeGet(ThrowingSupplier supplier) { + try { + return supplier.execute(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private void recordOutput(Connection conn, String workflowId, int stepId, R result) { + var value = SerializationUtil.serializeValue(result, null, serializer); + recordResult(conn, workflowId, stepId, value.serializedValue(), null, value.serialization()); + } + + @Override + protected void recordError(String workflowId, int stepId, Exception exception) { + var value = SerializationUtil.serializeError(exception, null, serializer); + executeTransaction( + dataSource, + (Connection conn) -> { + recordResult( + conn, workflowId, stepId, null, value.serializedValue(), value.serialization()); + return null; + }); + } + + private void recordResult( + Connection conn, + String workflowId, + int stepId, + String output, + String error, + String serialization) { + try (var stmt = conn.prepareStatement(upsertSql())) { + stmt.setString(1, workflowId); + stmt.setInt(2, stepId); + stmt.setString(3, output); + stmt.setString(4, error); + stmt.setString(5, serialization); + stmt.executeUpdate(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } +} diff --git a/transact/src/main/java/dev/dbos/transact/txstep/PostgresStepFactory.java b/transact/src/main/java/dev/dbos/transact/txstep/PostgresStepFactory.java new file mode 100644 index 00000000..8ba79b09 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/txstep/PostgresStepFactory.java @@ -0,0 +1,123 @@ +package dev.dbos.transact.txstep; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.json.DBOSSerializer; +import dev.dbos.transact.workflow.internal.StepResult; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Objects; +import java.util.Optional; + +/** + * Abstract base for transactional step factories backed by a PostgreSQL database. + * + *

Subclasses provide a database-library-specific public API (e.g. plain JDBC {@link Connection}, + * JDBI {@code Handle}, jOOQ {@code DSLContext}) while this class owns the shared step lifecycle: + * idempotency checking, error recording, and the {@link #runTxStep} template method that integrates + * with the DBOS runtime. + * + *

The constructor verifies that the datasource is PostgreSQL and creates the {@code + * tx_step_outputs} table (and its enclosing schema) if they do not already exist. + */ +public abstract class PostgresStepFactory { + + protected final DBOS dbos; + protected final String schema; + protected final DBOSSerializer serializer; + + @FunctionalInterface + protected interface ConnectionOpener { + Connection open() throws SQLException; + } + + protected PostgresStepFactory( + DBOS dbos, String schema, DBOSSerializer serializer, ConnectionOpener opener) { + this.dbos = Objects.requireNonNull(dbos); + var config = dbos.integration().config(); + this.schema = SystemDatabase.sanitizeSchema(schema == null ? config.databaseSchema() : schema); + this.serializer = serializer == null ? config.serializer() : serializer; + + try (var conn = opener.open()) { + // ensure we're running on Postgres + var productName = conn.getMetaData().getDatabaseProductName(); + if (!productName.equalsIgnoreCase("PostgreSQL")) { + throw new IllegalArgumentException( + "TxStepFactory requires a PostgreSQL datasource, got: " + productName); + } + + // ensure provided schema and tx_step_outputs table exist + try (var stmt = conn.createStatement()) { + stmt.addBatch("CREATE SCHEMA IF NOT EXISTS \"%s\"".formatted(this.schema)); + stmt.addBatch( + """ + CREATE TABLE IF NOT EXISTS "%1$s".tx_step_outputs ( + workflow_id TEXT NOT NULL, + step_id INT NOT NULL, + output TEXT, + error TEXT, + serialization TEXT, + created_at BIGINT NOT NULL DEFAULT (EXTRACT(EPOCH FROM now())*1000)::bigint, + PRIMARY KEY (workflow_id, step_id) + )""" + .formatted(this.schema)); + stmt.executeBatch(); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + protected String checkSql() { + return """ + SELECT output, error, serialization + FROM "%s".tx_step_outputs + WHERE workflow_id = ? AND step_id = ? + """ + .formatted(schema); + } + + protected abstract Optional checkExecution( + String workflowId, int stepId, String stepName); + + protected String upsertSql() { + return """ + INSERT INTO "%s".tx_step_outputs + (workflow_id, step_id, output, error, serialization) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT DO NOTHING + """ + .formatted(schema); + } + + protected abstract void recordError(String workflowId, int stepId, Exception exception); + + @FunctionalInterface + protected interface TxStepFunction { + R execute(String workflowId, int stepId) throws X; + } + + @SuppressWarnings("unchecked") + protected R runTxStep(TxStepFunction execute, String stepName) + throws X { + return dbos.runStep( + () -> { + var workflowId = Objects.requireNonNull(DBOS.workflowId()); + int stepId = Objects.requireNonNull(DBOS.stepId()); + + var prev = checkExecution(workflowId, stepId, stepName); + if (prev.isPresent()) { + return prev.get().toResult(serializer); + } + + try { + return execute.execute(workflowId, stepId); + } catch (Exception e) { + recordError(workflowId, stepId, e); + throw (X) e; + } + }, + stepName); + } +} diff --git a/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java b/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java index f3e6f3e4..c706e912 100644 --- a/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java +++ b/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java @@ -1,9 +1,12 @@ package dev.dbos.transact.workflow.internal; +import dev.dbos.transact.json.DBOSSerializer; +import dev.dbos.transact.json.SerializationUtil; + public record StepResult( String workflowId, int stepId, - String functionName, + String stepName, String output, String error, String childWorkflowId, @@ -14,20 +17,38 @@ public StepResult(String workflowId, int stepId, String functionName) { } public StepResult withOutput(String v) { - return new StepResult( - workflowId, stepId, functionName, v, error, childWorkflowId, serialization); + return new StepResult(workflowId, stepId, stepName, v, error, childWorkflowId, serialization); } public StepResult withError(String v) { - return new StepResult( - workflowId, stepId, functionName, output, v, childWorkflowId, serialization); + return new StepResult(workflowId, stepId, stepName, output, v, childWorkflowId, serialization); } public StepResult withChildWorkflowId(String v) { - return new StepResult(workflowId, stepId, functionName, output, error, v, serialization); + return new StepResult(workflowId, stepId, stepName, output, error, v, serialization); } public StepResult withSerialization(String v) { - return new StepResult(workflowId, stepId, functionName, output, error, childWorkflowId, v); + return new StepResult(workflowId, stepId, stepName, output, error, childWorkflowId, v); + } + + @SuppressWarnings("unchecked") + public R toResult(DBOSSerializer serializer) throws E { + if (error != null) { + var t = SerializationUtil.deserializeError(error, serialization, serializer); + if (t instanceof Exception) { + throw (E) t; + } else { + throw new RuntimeException(t.getMessage(), t); + } + } + + if (output != null) { + return (R) SerializationUtil.deserializeValue(output, serialization, serializer); + } + + throw new IllegalStateException( + "Recorded output and error are both null for workflow %s step %d (%s)" + .formatted(workflowId, stepId, stepName)); } } diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java new file mode 100644 index 00000000..7458e8df --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java @@ -0,0 +1,173 @@ +package dev.dbos.transact.txstep; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.Constants; +import dev.dbos.transact.DBOS; +import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.utils.DBUtils; +import dev.dbos.transact.utils.PgContainer; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Objects; + +import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.Test; + +public class JdbcStepFactoryInitTest { + @AutoClose final PgContainer pgContainer = new PgContainer(); + + static boolean validateSchema(Connection conn, String schema) throws SQLException { + Objects.requireNonNull(schema); + try (var rs = conn.getMetaData().getSchemas()) { + while (rs.next()) { + if (schema.equalsIgnoreCase(rs.getString("TABLE_SCHEM"))) { + return true; + } + } + } + return false; + } + + static boolean validateTxStepOutputsTable(Connection conn, String schema) throws SQLException { + try (var rs = + conn.getMetaData() + .getTables(null, Objects.requireNonNull(schema), "tx_step_outputs", null)) { + if (rs.next()) { + if (schema.equals(rs.getString("TABLE_SCHEM")) + && "tx_step_outputs".equals(rs.getString("TABLE_NAME")) + && "TABLE".equalsIgnoreCase(rs.getString("TABLE_TYPE"))) { + return true; + } + } + } + return false; + } + + @Test + public void sameDbDefaultSchema() throws Exception { + var config = pgContainer.dbosConfig(); + try (var dbos = new DBOS(config); + var dataSource = pgContainer.dataSource()) { + var schema = Constants.DB_SCHEMA; + + // ensure step factory schema/table do not exist + try (var conn = dataSource.getConnection()) { + assertFalse(validateSchema(conn, schema)); + assertFalse(validateTxStepOutputsTable(conn, schema)); + } + + // create step factory to initialize the app db tables + new JdbcStepFactory(dbos, dataSource); + + // ensure step factory schema/table do exist + try (var conn = dataSource.getConnection()) { + assertTrue(validateSchema(conn, schema)); + assertTrue(validateTxStepOutputsTable(conn, schema)); + } + dbos.launch(); + } + } + + @Test + public void sameDbCustomDbosSchema() throws Exception { + var schema = "custom"; + var config = pgContainer.dbosConfig().withDatabaseSchema(schema); + try (var dbos = new DBOS(config); + var dataSource = pgContainer.dataSource()) { + new JdbcStepFactory(dbos, dataSource); + try (var conn = dataSource.getConnection()) { + assertTrue(validateSchema(conn, schema)); + assertTrue(validateTxStepOutputsTable(conn, schema)); + } + dbos.launch(); + } + } + + @Test + public void sameDbCustomFactorySchema() throws Exception { + var schema = "custom"; + var config = pgContainer.dbosConfig(); + try (var dbos = new DBOS(config); + var dataSource = pgContainer.dataSource()) { + new JdbcStepFactory(dbos, dataSource, schema); + try (var conn = dataSource.getConnection()) { + assertFalse(validateSchema(conn, Constants.DB_SCHEMA)); + assertTrue(validateSchema(conn, schema)); + assertTrue(validateTxStepOutputsTable(conn, schema)); + } + dbos.launch(); + try (var conn = dataSource.getConnection()) { + assertTrue(validateSchema(conn, Constants.DB_SCHEMA)); + } + } + } + + @Test + public void sameDbCustomDbosAndFactorySchema() throws Exception { + var dbosSchema = "custom_a"; + var factorySchema = "custom_b"; + var config = pgContainer.dbosConfig().withDatabaseSchema(dbosSchema); + try (var dbos = new DBOS(config); + var dataSource = pgContainer.dataSource()) { + new JdbcStepFactory(dbos, dataSource, factorySchema); + try (var conn = dataSource.getConnection()) { + assertFalse(validateSchema(conn, dbosSchema)); + assertTrue(validateSchema(conn, factorySchema)); + assertTrue(validateTxStepOutputsTable(conn, factorySchema)); + } + dbos.launch(); + try (var conn = dataSource.getConnection()) { + assertTrue(validateSchema(conn, dbosSchema)); + } + } + } + + @Test + public void nonPostgresDataSource() throws Exception { + var config = pgContainer.dbosConfig(); + try (var dbos = new DBOS(config)) { + var sqliteDs = new org.sqlite.SQLiteDataSource(); + sqliteDs.setUrl("jdbc:sqlite::memory:"); + assertThrows(IllegalArgumentException.class, () -> new JdbcStepFactory(dbos, sqliteDs)); + } + } + + @Test + public void separateDBs() throws Exception { + // create a 2nd database in the container's PG instance + var appDbName = "factory_test_db"; + try (var conn = + DriverManager.getConnection( + pgContainer.jdbcUrl(), pgContainer.username(), pgContainer.password()); + var stmt = conn.createStatement()) { + stmt.execute("CREATE DATABASE " + appDbName); + } + var appDbJdbcUrl = pgContainer.jdbcUrl().replaceFirst("/[^/]+$", "/" + appDbName); + + var config = pgContainer.dbosConfig(); + try (var dbos = new DBOS(config); + var dataSource = + SystemDatabase.createDataSource( + appDbJdbcUrl, pgContainer.username(), pgContainer.password())) { + new JdbcStepFactory(dbos, dataSource); + dbos.launch(); + + var appDbTables = DBUtils.getTables(dataSource, "dbos"); + assertEquals(1, appDbTables.size()); + assertTrue(appDbTables.contains("tx_step_outputs")); + + var sysDbTables = DBUtils.getTables(pgContainer, "dbos"); + assertTrue(sysDbTables.size() >= 10); + assertFalse(sysDbTables.contains("tx_step_outputs")); + assertTrue(sysDbTables.contains("dbos_migrations")); + assertTrue(sysDbTables.contains("workflow_status")); + assertTrue(sysDbTables.contains("operation_outputs")); + } + } +} diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java new file mode 100644 index 00000000..61db3688 --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java @@ -0,0 +1,429 @@ +package dev.dbos.transact.txstep; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.context.WorkflowOptions; +import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.utils.DBUtils; +import dev.dbos.transact.utils.PgContainer; +import dev.dbos.transact.workflow.Workflow; +import dev.dbos.transact.workflow.WorkflowHandle; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Objects; + +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +interface FactoryTestService { + record TestResult(String user, int greetCount) {} + + TestResult insertWorkflow(String user) throws SQLException; + + TestResult errorWorkflow(String user) throws SQLException; + + TestResult readWorkflow(String user) throws SQLException; + + TestResult insertThenReadWorkflow(String user) throws SQLException; +} + +class FactoryTestServiceImpl implements FactoryTestService { + + private final JdbcStepFactory stepFactory; + + public FactoryTestServiceImpl(JdbcStepFactory stepFactory) { + this.stepFactory = stepFactory; + } + + TestResult insertGreeting(Connection conn, String user) throws SQLException { + var sql = + """ + INSERT INTO greetings(name, greet_count) + VALUES (?, 1) + ON CONFLICT(name) + DO UPDATE SET greet_count = greetings.greet_count + 1 + RETURNING greet_count + """; + + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, Objects.requireNonNull(user)); + try (var rs = stmt.executeQuery()) { + var greetCount = rs.next() ? rs.getInt("greet_count") : 0; + return new TestResult(user, greetCount); + } + } + } + + TestResult errorGreeting(Connection conn, String user) throws SQLException { + insertGreeting(conn, user); + throw new RuntimeException("Test Exception %d".formatted(System.currentTimeMillis())); + } + + TestResult readGreeting(Connection conn, String user) throws SQLException { + var sql = + """ + SELECT greet_count + FROM greetings + WHERE name = ? + """; + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, Objects.requireNonNull(user)); + try (var rs = stmt.executeQuery()) { + var greetCount = rs.next() ? rs.getInt("greet_count") : 0; + return new TestResult(user, greetCount); + } + } + } + + @Override + @Workflow + public TestResult insertWorkflow(String user) throws SQLException { + return stepFactory.txStep((Connection c) -> insertGreeting(c, user), "insertGreeting"); + } + + @Override + @Workflow + public TestResult errorWorkflow(String user) throws SQLException { + return stepFactory.txStep((Connection c) -> errorGreeting(c, user), "errorGreeting"); + } + + @Override + @Workflow + public TestResult readWorkflow(String user) throws SQLException { + return stepFactory.txStep((Connection c) -> readGreeting(c, user), "readGreeting"); + } + + @Override + @Workflow + public TestResult insertThenReadWorkflow(String user) throws SQLException { + stepFactory.txStep((Connection c) -> insertGreeting(c, user), "insertGreeting"); + return stepFactory.txStep((Connection c) -> readGreeting(c, user), "readGreeting"); + } +} + +public class JdbcStepFactoryTest { + @AutoClose final PgContainer pgContainer = new PgContainer(); + + DBOSConfig dbosConfig; + @AutoClose DBOS dbos; + @AutoClose HikariDataSource dataSource; + JdbcStepFactory stepFactory; + FactoryTestService proxy; + FactoryTestServiceImpl impl; + + @BeforeEach + void beforeEach() throws SQLException { + + dbosConfig = pgContainer.dbosConfig(); + dataSource = pgContainer.dataSource(); + + try (var conn = dataSource.getConnection(); + var stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE greetings(name text NOT NULL, greet_count integer DEFAULT 0, PRIMARY KEY(name))"); + } + + dbos = new DBOS(dbosConfig); + stepFactory = new JdbcStepFactory(dbos, dataSource); + + impl = new FactoryTestServiceImpl(stepFactory); + proxy = dbos.registerProxy(FactoryTestService.class, impl); + + dbos.launch(); + } + + private int getGreetCount(String user) throws SQLException { + var sql = "SELECT greet_count FROM greetings WHERE name = ?"; + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, user); + try (var rs = stmt.executeQuery()) { + return rs.next() ? rs.getInt("greet_count") : 0; + } + } + } + + @Test + public void testInsert() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(wfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNotNull(row.output()); + assertNull(row.error()); + assertEquals(SerializationUtil.NATIVE, row.serialization()); + var output = SerializationUtil.deserializeValue(row.output(), row.serialization(), null); + assertEquals(new FactoryTestService.TestResult(user, 1), output); + + assertEquals(1, getGreetCount(user)); + } + + @Test + public void testError() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + assertThrows(RuntimeException.class, () -> proxy.errorWorkflow(user)); + } + + // Transaction rolled back — no greeting inserted + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(wfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNull(row.output()); + assertNotNull(row.error()); + + assertEquals(0, getGreetCount(user)); + } + + @Test + public void testRead() throws Exception { + var insertWfid = "wf1"; + var readWfid = "wf2"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(insertWfid).setContext()) { + proxy.insertWorkflow(user); + } + + try (var _o = new WorkflowOptions(readWfid).setContext()) { + var result = proxy.readWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + var rows = DBUtils.getTxStepRows(dataSource, readWfid); + assertEquals(1, rows.size()); + var row = rows.get(0); + assertEquals(readWfid, row.workflowId()); + assertEquals(0, row.stepId()); + assertNotNull(row.output()); + assertNull(row.error()); + assertEquals(SerializationUtil.NATIVE, row.serialization()); + var output = SerializationUtil.deserializeValue(row.output(), row.serialization(), null); + assertEquals(new FactoryTestService.TestResult(user, 1), output); + + assertEquals(1, getGreetCount(user)); + } + + @Test + public void testIdempotency() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + // Second call with same wfid — txStep output is cached, insert not re-executed + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + assertEquals(1, getGreetCount(user)); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + } + + @Test + public void testRetryError() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + assertThrows(RuntimeException.class, () -> proxy.errorWorkflow(user)); + } + assertEquals(0, getGreetCount(user)); + dbos.close(); + + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + dbos.launch(); + WorkflowHandle handle = + dbos.retrieveWorkflow(wfid); + assertThrows(RuntimeException.class, handle::getResult); + + // Cached error replayed — insert still not committed + assertEquals(0, getGreetCount(user)); + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, txSteps.size()); + assertNull(txSteps.get(0).output()); + assertNotNull(txSteps.get(0).error()); + } + + @Test + public void testMultipleTxSteps() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertThenReadWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + + assertEquals(1, getGreetCount(user)); + + var rows = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(2, rows.size()); + assertEquals(0, rows.get(0).stepId()); + assertNotNull(rows.get(0).output()); + assertNull(rows.get(0).error()); + assertEquals(1, rows.get(1).stepId()); + assertNotNull(rows.get(1).output()); + assertNull(rows.get(1).error()); + } + + @Test + public void testDistinctWorkflows() throws Exception { + var wfid1 = "wf1"; + var wfid2 = "wf2"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid1).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + } + + try (var _o = new WorkflowOptions(wfid2).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(2, result.greetCount()); + } + + assertEquals(2, getGreetCount(user)); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid1).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid2).size()); + } + + @Test + public void testRetryPartialMultipleSteps() throws Exception { + var wfid = "wf1"; + var user = "testUser"; + + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertThenReadWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + assertEquals(1, getGreetCount(user)); + dbos.close(); + + // Simulate crash after step 0 wrote tx_step_outputs but before step 1 ran: + // both operation_outputs rows are gone, and step 1 has no tx_step_outputs entry + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement( + "DELETE FROM dbos.tx_step_outputs WHERE workflow_id = ? AND step_id = 1")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + var relaunchTimestamp = System.currentTimeMillis(); + dbos.launch(); + WorkflowHandle handle = + dbos.retrieveWorkflow(wfid); + var result = (FactoryTestService.TestResult) handle.getResult(); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + + // Step 0 cache hit — insert not re-executed + assertEquals(1, getGreetCount(user)); + + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(2, txSteps.size()); + assertTrue(txSteps.get(0).createdAt() < relaunchTimestamp); // step 0: original run + assertTrue(txSteps.get(1).createdAt() >= relaunchTimestamp); // step 1: re-executed on retry + } + + @Test + public void testRetryInsert() throws Exception { + var timestamp = System.currentTimeMillis(); + + var wfid = "wf1"; + var user = "testUser"; + try (var _o = new WorkflowOptions(wfid).setContext()) { + var result = proxy.insertWorkflow(user); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + } + dbos.close(); + + try (var conn = dataSource.getConnection(); + var stmt = + conn.prepareStatement("DELETE FROM dbos.operation_outputs WHERE workflow_uuid = ?")) { + stmt.setString(1, wfid); + stmt.executeUpdate(); + } + DBUtils.setWorkflowState(dataSource, wfid, "PENDING"); + + assertEquals(0, DBUtils.getStepRows(dataSource, wfid).size()); + assertEquals(1, DBUtils.getTxStepRows(dataSource, wfid).size()); + + var relaunchTimestamp = System.currentTimeMillis(); + dbos.launch(); + var handle = dbos.retrieveWorkflow(wfid); + var result = (FactoryTestService.TestResult) handle.getResult(); + assertEquals(1, result.greetCount()); + assertEquals(user, result.user()); + + var steps = DBUtils.getStepRows(dataSource, wfid); + var txSteps = DBUtils.getTxStepRows(dataSource, wfid); + assertEquals(1, steps.size()); + assertEquals(1, txSteps.size()); + + var step = steps.get(0); + var txStep = txSteps.get(0); + assertEquals(step.output(), txStep.output()); + assertEquals(step.error(), txStep.error()); + + assertTrue(txStep.createdAt() < step.startedAt()); + assertTrue(timestamp < txStep.createdAt()); + assertTrue(txStep.createdAt() < relaunchTimestamp); + assertTrue(relaunchTimestamp < step.startedAt()); + + // Retry reads from tx_step_outputs cache — insert not re-executed + assertEquals(1, getGreetCount(user)); + } +} diff --git a/transact/src/test/java/dev/dbos/transact/utils/DBUtils.java b/transact/src/test/java/dev/dbos/transact/utils/DBUtils.java index 450b6455..b84eb80b 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/DBUtils.java +++ b/transact/src/test/java/dev/dbos/transact/utils/DBUtils.java @@ -15,7 +15,11 @@ import java.sql.Statement; import java.time.Instant; import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Objects; import javax.sql.DataSource; @@ -444,4 +448,87 @@ public static List getStreamEntries(DataSource ds, String workflowId, } } } + + public static List getAllTxStepRows(DataSource ds) throws SQLException { + return getAllTxStepRows(ds, null); + } + + public static List getAllTxStepRows(DataSource ds, String schema) + throws SQLException { + var sql = + "SELECT * FROM \"%s\".tx_step_outputs ORDER BY created_at" + .formatted(SystemDatabase.sanitizeSchema(schema)); + try (var conn = ds.getConnection(); + var stmt = conn.createStatement(); + var rs = stmt.executeQuery(sql)) { + List rows = new ArrayList<>(); + while (rs.next()) { + rows.add(new TxStepOutputRow(rs)); + } + return rows; + } + } + + public static List getTxStepRows(DataSource ds, String workflowId) + throws SQLException { + return getTxStepRows(ds, workflowId, null); + } + + public static List getTxStepRows(DataSource ds, String workflowId, String schema) + throws SQLException { + var sql = + "SELECT * FROM \"%s\".tx_step_outputs WHERE workflow_id = ? ORDER BY step_id" + .formatted(SystemDatabase.sanitizeSchema(schema)); + try (var conn = ds.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, Objects.requireNonNull(workflowId)); + try (var rs = stmt.executeQuery()) { + List rows = new ArrayList<>(); + while (rs.next()) { + rows.add(new TxStepOutputRow(rs)); + } + return rows; + } + } + } + + public static List> dumpResultSet(ResultSet rs) throws SQLException { + List> results = new ArrayList<>(); + var metaData = rs.getMetaData(); + var columnCount = metaData.getColumnCount(); + while (rs.next()) { + Map map = new HashMap<>(); + for (var i = 1; i <= columnCount; i++) { + map.put(metaData.getColumnLabel(i), rs.getObject(i)); + } + results.add(map); + } + return results; + } + + public static Collection getTables(PgContainer pg, String schema) throws SQLException { + try (var ds = pg.dataSource()) { + return getTables(ds, schema); + } + } + + public static Collection getTables(DataSource ds, String schema) throws SQLException { + try (var conn = ds.getConnection()) { + return getTables(conn, schema); + } + } + + public static Collection getTables(Connection conn, String schema) throws SQLException { + List tables = new ArrayList<>(); + try (var rs = conn.getMetaData().getTables(null, schema, null, null)) { + while (rs.next()) { + var name = rs.getString("TABLE_NAME"); + var type = rs.getString("TABLE_TYPE"); + if ("TABLE".equals(type)) { + tables.add(name); + } + } + } + return tables; + } } diff --git a/transact/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java b/transact/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java new file mode 100644 index 00000000..7472fe31 --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/utils/TxStepOutputRow.java @@ -0,0 +1,23 @@ +package dev.dbos.transact.utils; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public record TxStepOutputRow( + String workflowId, + int stepId, + String output, + String error, + String serialization, + Long createdAt) { + + public TxStepOutputRow(ResultSet rs) throws SQLException { + this( + rs.getString("workflow_id"), + rs.getInt("step_id"), + rs.getString("output"), + rs.getString("error"), + rs.getString("serialization"), + rs.getObject("created_at", Long.class)); + } +}