diff --git a/core/src/main/java/google/registry/bsa/persistence/Queries.java b/core/src/main/java/google/registry/bsa/persistence/Queries.java index e22c13f7115..187d0043c70 100644 --- a/core/src/main/java/google/registry/bsa/persistence/Queries.java +++ b/core/src/main/java/google/registry/bsa/persistence/Queries.java @@ -27,7 +27,6 @@ import java.time.Instant; import java.util.List; import java.util.Optional; -import java.util.stream.Collectors; import java.util.stream.Stream; /** Helpers for querying BSA JPA entities. */ @@ -114,21 +113,33 @@ static ImmutableList batchReadUnblockables( } static ImmutableSet queryUnblockablesByNames(ImmutableSet domains) { - String labelTldParis = - domains.stream() - .map( - domain -> { - List parts = DOMAIN_SPLITTER.splitToList(domain); - verify(parts.size() == 2, "Invalid domain name %s", domain); - return String.format("('%s','%s')", parts.get(0), parts.get(1)); - }) - .collect(Collectors.joining(",")); - String sql = - String.format( - "SELECT CONCAT(d.label, '.', d.tld) FROM \"BsaUnblockableDomain\" d " - + "WHERE (d.label, d.tld) IN (%s)", - labelTldParis); - return ImmutableSet.copyOf(tm().getEntityManager().createNativeQuery(sql).getResultList()); + if (domains.isEmpty()) { + return ImmutableSet.of(); + } + ImmutableList domainList = ImmutableList.copyOf(domains); + StringBuilder sqlBuilder = + new StringBuilder( + "SELECT CONCAT(d.label, '.', d.tld) FROM \"BsaUnblockableDomain\" d WHERE (d.label," + + " d.tld) IN ("); + for (int i = 0; i < domainList.size(); i++) { + if (i > 0) { + sqlBuilder.append(","); + } + sqlBuilder.append("(:label").append(i).append(", :tld").append(i).append(")"); + } + sqlBuilder.append(")"); + + var query = tm().getEntityManager().createNativeQuery(sqlBuilder.toString()); + for (int i = 0; i < domainList.size(); i++) { + List parts = DOMAIN_SPLITTER.splitToList(domainList.get(i)); + verify(parts.size() == 2, "Invalid domain name %s", domainList.get(i)); + query.setParameter("label" + i, parts.get(0)); + query.setParameter("tld" + i, parts.get(1)); + } + + @SuppressWarnings("unchecked") + List resultList = (List) query.getResultList(); + return ImmutableSet.copyOf(resultList); } static ImmutableSet queryNewlyCreatedDomains( diff --git a/core/src/test/java/google/registry/bsa/persistence/QueriesTest.java b/core/src/test/java/google/registry/bsa/persistence/QueriesTest.java index f185aeb7cde..c7b0e3c1d07 100644 --- a/core/src/test/java/google/registry/bsa/persistence/QueriesTest.java +++ b/core/src/test/java/google/registry/bsa/persistence/QueriesTest.java @@ -316,4 +316,21 @@ void batchReadUnblockables_multiBatch() { .domainName()) .isEqualTo("label3.app"); } + + @Test + void testQueryUnblockablesByNames_sqlInjectionSafe() { + setupUnblockableDomains(); + // Verify standard lookup works + assertThat(tm().transact(() -> queryUnblockablesByNames(ImmutableSet.of("a.tld1")))) + .containsExactly("a.tld1"); + + // Attempt SQL Injection payload in the name. Should be treated as literal and return empty. + assertThat(tm().transact(() -> queryUnblockablesByNames(ImmutableSet.of("a' OR '1'='1.tld1")))) + .isEmpty(); + } + + @Test + void testQueryUnblockablesByNames_emptySet() { + assertThat(tm().transact(() -> queryUnblockablesByNames(ImmutableSet.of()))).isEmpty(); + } }