diff --git a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitRunner.java b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitRunner.java index 6d7ebed1..577218f5 100644 --- a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitRunner.java +++ b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitRunner.java @@ -138,7 +138,13 @@ private void verifyExpected(DbUnitTestContext testContext, DatabaseConnections c if (logger.isDebugEnabled()) { logger.debug("Veriftying @DatabaseTest expectation using " + annotation.value()); } - DatabaseAssertion assertion = annotation.assertionMode().getDatabaseAssertion(); + + DatabaseAssertion assertion; + if (StringUtils.hasLength(annotation.assertionBean())) { + assertion = testContext.getDatabaseAssertion(annotation.assertionBean()); + } else { + assertion = annotation.assertionMode().getDatabaseAssertion(); + } List columnFilters = getColumnFilters(annotation); if (StringUtils.hasLength(query)) { Assert.hasLength(table, "The table name must be specified when using a SQL query"); diff --git a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestContext.java b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestContext.java index fca0e15f..92462381 100644 --- a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestContext.java +++ b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestContext.java @@ -18,6 +18,7 @@ import java.lang.reflect.Method; +import com.github.springtestdbunit.assertion.DatabaseAssertion; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.IDataSet; @@ -49,6 +50,13 @@ public interface DbUnitTestContext { */ DatabaseOperationLookup getDatbaseOperationLookup(); + /** + * Returns the {@link DatabaseAssertion} implemented by the bean with the given name. + * @param databaseAssertionBeanName name of the database assertion bean + * @return database assertion + */ + DatabaseAssertion getDatabaseAssertion(String databaseAssertionBeanName); + /** * Returns the class that is under test. * @return The class under test diff --git a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestExecutionListener.java b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestExecutionListener.java index 185974cf..cac49ba6 100644 --- a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestExecutionListener.java +++ b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/DbUnitTestExecutionListener.java @@ -21,6 +21,7 @@ import javax.sql.DataSource; +import com.github.springtestdbunit.assertion.DatabaseAssertion; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.dbunit.database.IDatabaseConnection; @@ -236,6 +237,10 @@ public DatabaseOperationLookup getDatbaseOperationLookup() { return (DatabaseOperationLookup) getAttribute(DATABASE_OPERATION_LOOKUP_ATTRIBUTE); } + public DatabaseAssertion getDatabaseAssertion(String databaseAssertionBeanName) { + return testContext.getApplicationContext().getBean(databaseAssertionBeanName, DatabaseAssertion.class); + } + public Class getTestClass() { return (Class) ReflectionUtils.invokeMethod(GET_TEST_CLASS, this.testContext); } diff --git a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/annotation/ExpectedDatabase.java b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/annotation/ExpectedDatabase.java index 14811fce..08563d10 100644 --- a/spring-test-dbunit/src/main/java/com/github/springtestdbunit/annotation/ExpectedDatabase.java +++ b/spring-test-dbunit/src/main/java/com/github/springtestdbunit/annotation/ExpectedDatabase.java @@ -24,6 +24,7 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import com.github.springtestdbunit.assertion.DatabaseAssertion; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.filter.IColumnFilter; @@ -66,6 +67,14 @@ */ DatabaseAssertionMode assertionMode() default DatabaseAssertionMode.DEFAULT; + /** + * The name of the database assertion bean to use for comparing datasets. + * This bean must implement the {@link DatabaseAssertion} interface. + * If defined, this supersedes the {@link ExpectedDatabase#assertionMode()} element. + * @return Database assertion bean to use + */ + String assertionBean() default ""; + /** * Optional table name that can be used to limit the comparison to a specific table. * @return the table name