diff --git a/build.mill b/build.mill index 83484c3460..17bbedff6d 100644 --- a/build.mill +++ b/build.mill @@ -112,6 +112,8 @@ object runner extends Cross[Runner](Scala.runnerScalaVersions) with CrossScalaDefaultToRunner object `test-runner` extends Cross[TestRunner](Scala.runnerScalaVersions) with CrossScalaDefaultToRunner +object `java-test-runner` extends JavaTestRunner + with LocatedInModules object `tasty-lib` extends Cross[TastyLib](Scala.scala3MainVersions) with CrossScalaDefaultToInternal @@ -452,12 +454,18 @@ trait Core extends ScalaCliCrossSbtModule val runnerMainClass = build.runner(crossScalaVersion) .mainClass() .getOrElse(sys.error("No main class defined for runner")) + val javaTestRunnerMainClass = `java-test-runner` + .mainClass() + .getOrElse(sys.error("No main class defined for java-test-runner")) val detailedVersionValue = if (`local-repo`.developingOnStubModules) s"""Some("${vcsState()}")""" else "None" val testRunnerOrganization = `test-runner`(crossScalaVersion) .pomSettings() .organization + val javaTestRunnerOrganization = `java-test-runner` + .pomSettings() + .organization val code = s"""package scala.build.internal | @@ -479,6 +487,11 @@ trait Core extends ScalaCliCrossSbtModule | def testRunnerVersion = "${`test-runner`(crossScalaVersion).publishVersion()}" | def testRunnerMainClass = "$testRunnerMainClass" | + | def javaTestRunnerOrganization = "$javaTestRunnerOrganization" + | def javaTestRunnerModuleName = "${`java-test-runner`.artifactName()}" + | def javaTestRunnerVersion = "${`java-test-runner`.publishVersion()}" + | def javaTestRunnerMainClass = "$javaTestRunnerMainClass" + | | def runnerOrganization = "${build.runner(crossScalaVersion).pomSettings().organization}" | def runnerModuleName = "${build.runner(crossScalaVersion).artifactName()}" | def runnerVersion = "${build.runner(crossScalaVersion).publishVersion()}" @@ -1323,6 +1336,16 @@ trait TestRunner extends CrossSbtModule override def mainClass: T[Option[String]] = Some("scala.build.testrunner.DynamicTestRunner") } +trait JavaTestRunner extends JavaModule + with ScalaCliPublishModule + with LocatedInModules { + override def mvnDeps: T[Seq[Dep]] = super.mvnDeps() ++ Seq( + Deps.asm, + Deps.testInterface + ) + override def mainClass: T[Option[String]] = Some("scala.build.testrunner.JavaDynamicTestRunner") +} + trait TastyLib extends ScalaCliCrossSbtModule with ScalaCliPublishModule with ScalaCliScalafixModule @@ -1357,7 +1380,7 @@ object `local-repo` extends LocalRepo { def developingOnStubModules = false override def stubsModules: Seq[PublishLocalNoFluff] = - Seq(runner(Scala.runnerScala3), `test-runner`(Scala.runnerScala3)) + Seq(runner(Scala.runnerScala3), `test-runner`(Scala.runnerScala3), `java-test-runner`) override def version: T[String] = runner(Scala.runnerScala3).publishVersion() } diff --git a/modules/build/src/main/scala/scala/build/Build.scala b/modules/build/src/main/scala/scala/build/Build.scala index 04816c8299..1a52b9bc6d 100644 --- a/modules/build/src/main/scala/scala/build/Build.scala +++ b/modules/build/src/main/scala/scala/build/Build.scala @@ -1105,8 +1105,7 @@ object Build { either { val options0 = - // FIXME: don't add Scala to pure Java test builds (need to add pure Java test runner) - if sources.hasJava && !sources.hasScala && scope != Scope.Test + if sources.hasJava && !sources.hasScala then options.copy( scalaOptions = options.scalaOptions.copy( diff --git a/modules/build/src/test/scala/scala/build/tests/JavaTestRunnerTests.scala b/modules/build/src/test/scala/scala/build/tests/JavaTestRunnerTests.scala new file mode 100644 index 0000000000..3cfc8f155e --- /dev/null +++ b/modules/build/src/test/scala/scala/build/tests/JavaTestRunnerTests.scala @@ -0,0 +1,51 @@ +package scala.build.tests + +import com.eed3si9n.expecty.Expecty.assert as expect + +import scala.build.options.* + +class JavaTestRunnerTests extends TestUtil.ScalaCliBuildSuite { + + private def makeOptions( + scalaVersionOpt: Option[MaybeScalaVersion], + addTestRunner: Boolean + ): BuildOptions = + BuildOptions( + scalaOptions = ScalaOptions( + scalaVersion = scalaVersionOpt + ), + internalDependencies = InternalDependenciesOptions( + addTestRunnerDependencyOpt = Some(addTestRunner) + ) + ) + + test("pure Java build has no scalaParams") { + val opts = makeOptions(Some(MaybeScalaVersion.none), addTestRunner = false) + val params = opts.scalaParams.toOption.flatten + expect(params.isEmpty, "Pure Java build should have no scalaParams") + } + + test("Scala build has scalaParams") { + val opts = makeOptions(None, addTestRunner = false) + val params = opts.scalaParams.toOption.flatten + expect(params.isDefined, "Scala build should have scalaParams") + } + + test("pure Java test build gets addJvmJavaTestRunner=true in Artifacts params") { + val opts = makeOptions(Some(MaybeScalaVersion.none), addTestRunner = true) + val isJava = opts.scalaParams.toOption.flatten.isEmpty + expect(isJava, "Expected pure Java build to have no scalaParams") + } + + test("Scala test build gets addJvmTestRunner=true in Artifacts params") { + val opts = makeOptions(None, addTestRunner = true) + val isJava = opts.scalaParams.toOption.flatten.isEmpty + expect(!isJava, "Expected Scala build to have scalaParams") + } + + test("mixed Scala+Java build still gets Scala test runner") { + val opts = makeOptions(None, addTestRunner = true) + val isJava = opts.scalaParams.toOption.flatten.isEmpty + expect(!isJava, "Mixed Scala+Java build should still use Scala test runner") + } +} diff --git a/modules/cli/src/main/scala/scala/cli/commands/test/Test.scala b/modules/cli/src/main/scala/scala/cli/commands/test/Test.scala index a236a5a93c..80b36b2b54 100644 --- a/modules/cli/src/main/scala/scala/cli/commands/test/Test.scala +++ b/modules/cli/src/main/scala/scala/cli/commands/test/Test.scala @@ -256,11 +256,16 @@ object Test extends ScalaCommand[TestOptions] { testOnly.map(to => s"--test-only=$to").toSeq ++ Seq("--") ++ args + val testRunnerMainClass = + if build.artifacts.hasJavaTestRunner + then Constants.javaTestRunnerMainClass + else Constants.testRunnerMainClass + Runner.runJvm( build.options.javaHome().value.javaCommand, build.options.javaOptions.javaOpts.toSeq.map(_.value.value), classPath, - Constants.testRunnerMainClass, + testRunnerMainClass, extraArgs, logger, allowExecve = allowExecve diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala index 1d3517e2d6..ec29db80a4 100755 --- a/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala +++ b/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala @@ -2510,4 +2510,24 @@ abstract class RunTestDefinitions processes.foreach { case (p, _) => expect(p.exitCode() == 0) } } } + + test("pure Java run has no Scala on classpath") { + TestInputs( + os.rel / "Main.java" -> + """public class Main { + | public static void main(String[] args) { + | try { + | Class.forName("scala.Predef"); + | throw new RuntimeException("Scala should not be on the classpath"); + | } catch (ClassNotFoundException e) { + | System.out.println("No Scala on classpath!"); + | } + | } + |} + |""".stripMargin + ).fromRoot { root => + val res = os.proc(TestUtil.cli, "run", extraOptions, ".").call(cwd = root) + expect(res.out.text().contains("No Scala on classpath!")) + } + } } diff --git a/modules/integration/src/test/scala/scala/cli/integration/TestTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/TestTestDefinitions.scala index 363fd3ed1a..51358bd772 100644 --- a/modules/integration/src/test/scala/scala/cli/integration/TestTestDefinitions.scala +++ b/modules/integration/src/test/scala/scala/cli/integration/TestTestDefinitions.scala @@ -857,6 +857,63 @@ abstract class TestTestDefinitions extends ScalaCliSuite with TestScalaVersionAr } } + test("pure Java test with JUnit has no Scala on classpath") { + TestInputs( + os.rel / "test" / "MyTests.java" -> + """//> using test.dep junit:junit:4.13.2 + |//> using test.dep com.novocode:junit-interface:0.11 + |import org.junit.Test; + |import static org.junit.Assert.assertEquals; + | + |public class MyTests { + | @Test + | public void foo() { + | try { + | Class.forName("scala.Predef"); + | throw new AssertionError("Scala should not be on the classpath"); + | } catch (ClassNotFoundException e) { + | // expected + | } + | assertEquals(4, 2 + 2); + | System.out.println("No Scala on classpath!"); + | } + |} + |""".stripMargin + ).fromRoot { root => + val res = os.proc(TestUtil.cli, "test", extraOptions, ".").call(cwd = root) + expect(res.out.text().contains("No Scala on classpath!")) + } + } + + test("pure Java test with JUnit and --server=false has no Scala on classpath") { + TestInputs( + os.rel / "test" / "MyTests.java" -> + """//> using test.dep junit:junit:4.13.2 + |//> using test.dep com.novocode:junit-interface:0.11 + |import org.junit.Test; + |import static org.junit.Assert.assertEquals; + | + |public class MyTests { + | @Test + | public void foo() { + | try { + | Class.forName("scala.Predef"); + | throw new AssertionError("Scala should not be on the classpath"); + | } catch (ClassNotFoundException e) { + | // expected + | } + | assertEquals(4, 2 + 2); + | System.out.println("No Scala on classpath (no server)!"); + | } + |} + |""".stripMargin + ).fromRoot { root => + val res = + os.proc(TestUtil.cli, "test", "--server=false", extraOptions, ".").call(cwd = root) + expect(res.out.text().contains("No Scala on classpath (no server)!")) + } + } + test(s"zio-test warning when zio-test-sbt was not passed") { TestUtil.retryOnCi() { val expectedMessage = "Hello from zio" diff --git a/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaAsmTestRunner.java b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaAsmTestRunner.java new file mode 100644 index 0000000000..4649b15235 --- /dev/null +++ b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaAsmTestRunner.java @@ -0,0 +1,315 @@ +package scala.build.testrunner; + +import org.objectweb.asm.*; +import sbt.testing.*; + +import java.io.*; +import java.nio.charset.StandardCharsets; +import java.nio.file.*; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; +import java.util.zip.*; + +public class JavaAsmTestRunner { + + public static class ParentInspector { + private final List classPath; + private final ConcurrentHashMap> cache = new ConcurrentHashMap<>(); + + public ParentInspector(List classPath) { + this.classPath = classPath; + } + + private List parents(String className) { + return cache.computeIfAbsent(className, name -> { + byte[] byteCode = findInClassPath(classPath, name + ".class"); + if (byteCode == null) return Collections.emptyList(); + TestClassChecker checker = new TestClassChecker(); + ClassReader reader = new ClassReader(byteCode); + reader.accept(checker, 0); + return checker.getImplements(); + }); + } + + public List allParents(String className) { + List result = new ArrayList<>(); + Set done = new HashSet<>(); + Deque todo = new ArrayDeque<>(); + todo.add(className); + while (!todo.isEmpty()) { + String current = todo.poll(); + if (!done.add(current)) continue; + result.add(current); + todo.addAll(parents(current)); + } + return result; + } + } + + public static Optional matchFingerprints( + String className, + InputStream byteCodeStream, + List fingerprints, + ParentInspector parentInspector, + ClassLoader loader + ) throws IOException { + TestClassChecker checker = new TestClassChecker(); + ClassReader reader = new ClassReader(byteCodeStream); + reader.accept(checker, 0); + + boolean isModule = className.endsWith("$"); + boolean hasPublicConstructors = checker.getPublicConstructorCount() > 0; + boolean definitelyNoTests = checker.isAbstract() || + checker.isInterface() || + checker.getPublicConstructorCount() > 1 || + isModule == hasPublicConstructors; + + if (definitelyNoTests) return Optional.empty(); + + for (Fingerprint fp : fingerprints) { + if (fp instanceof SubclassFingerprint) { + SubclassFingerprint sf = (SubclassFingerprint) fp; + if (sf.isModule() != isModule) continue; + String superName = sf.superclassName().replace('.', '/'); + if (parentInspector.allParents(checker.getName()).contains(superName)) { + return Optional.of(fp); + } + } else if (fp instanceof AnnotatedFingerprint) { + AnnotatedFingerprint af = (AnnotatedFingerprint) fp; + if (af.isModule() != isModule) continue; + // Use classloader-based reflection for annotation matching (proven approach) + if (loader != null) { + try { + String rawName = className.replace('/', '.').replace('\\', '.'); + String clsNameForLoad = rawName.endsWith("$") ? rawName.substring(0, rawName.length() - 1) : rawName; + Class cls = loader.loadClass(clsNameForLoad); + Optional result = + JavaFrameworkUtils.matchFingerprints(loader, cls, new Fingerprint[]{fp}); + if (result.isPresent()) return Optional.of(fp); + } catch (ClassNotFoundException | NoClassDefFoundError | + UnsupportedClassVersionError | IncompatibleClassChangeError e) { + // fall through + } + } + } + } + return Optional.empty(); + } + + public static List findFrameworkServices(List classPath) { + List result = new ArrayList<>(); + byte[] content = findInClassPath(classPath, "META-INF/services/sbt.testing.Framework"); + if (content != null) { + parseServiceFileContent(new String(content, StandardCharsets.UTF_8), result); + } + return result; + } + + private static void parseServiceFileContent(String content, List result) { + for (String line : content.split("[\r\n]+")) { + String trimmed = line.trim(); + if (!trimmed.isEmpty() && !trimmed.startsWith("#")) { + result.add(trimmed); + } + } + } + + public static List findFrameworks( + List classPath, + List preferredClasses, + ParentInspector parentInspector + ) { + List result = new ArrayList<>(); + // first check preferred classes + for (String preferred : preferredClasses) { + String resourceName = preferred.replace('.', '/') + ".class"; + byte[] bytes = findInClassPath(classPath, resourceName); + if (bytes != null) { + TestClassChecker checker = new TestClassChecker(); + new ClassReader(bytes).accept(checker, 0); + if (!checker.isAbstract() && checker.getPublicConstructorCount() == 1) { + String internalName = preferred.replace('.', '/'); + if (parentInspector.allParents(internalName).contains("sbt/testing/Framework")) { + result.add(internalName); + } + } + } + } + if (!result.isEmpty()) return result; + + // scan all classes in classpath + for (Map.Entry entry : listClassesByteCode(classPath, true).entrySet()) { + String name = entry.getKey(); + if (name.contains("module-info")) continue; + TestClassChecker checker = new TestClassChecker(); + new ClassReader(entry.getValue()).accept(checker, 0); + if (!checker.isAbstract() && checker.getPublicConstructorCount() == 1) { + if (parentInspector.allParents(name).contains("sbt/testing/Framework")) { + result.add(name); + } + } + } + return result; + } + + public static List taskDefs( + List classPath, + boolean keepJars, + List fingerprints, + ParentInspector parentInspector, + ClassLoader loader + ) { + List result = new ArrayList<>(); + for (Map.Entry entry : listClassesByteCode(classPath, keepJars).entrySet()) { + String name = entry.getKey(); + if (name.contains("module-info")) continue; + try { + Optional fp = matchFingerprints( + name, + new ByteArrayInputStream(entry.getValue()), + fingerprints, + parentInspector, + loader + ); + if (fp.isPresent()) { + String stripped = name.endsWith("$") ? name.substring(0, name.length() - 1) : name; + String clsName = stripped.replace('/', '.').replace('\\', '.'); + result.add(new TaskDef(clsName, fp.get(), false, new Selector[]{new SuiteSelector()})); + } + } catch (IOException e) { + // skip + } + } + return result; + } + + private static Map listClassesByteCode(List classPath, boolean keepJars) { + Map result = new LinkedHashMap<>(); + for (Path entry : classPath) { + result.putAll(listClassesByteCode(entry, keepJars)); + } + return result; + } + + private static Map listClassesByteCode(Path entry, boolean keepJars) { + Map result = new LinkedHashMap<>(); + if (Files.isDirectory(entry)) { + try (Stream stream = Files.walk(entry, Integer.MAX_VALUE)) { + stream.filter(p -> p.getFileName().toString().endsWith(".class")) + .forEach(p -> { + String rel = entry.relativize(p).toString().replace('\\', '/'); + String name = rel.endsWith(".class") ? rel.substring(0, rel.length() - 6) : rel; + try { + result.put(name, Files.readAllBytes(p)); + } catch (IOException e) { + // skip + } + }); + } catch (IOException e) { + // skip + } + } else if (keepJars && Files.isRegularFile(entry)) { + byte[] buf = new byte[16384]; + try (ZipFile zf = new ZipFile(entry.toFile())) { + Enumeration entries = zf.entries(); + while (entries.hasMoreElements()) { + ZipEntry ze = entries.nextElement(); + if (!ze.getName().endsWith(".class")) continue; + String name = ze.getName(); + name = name.substring(0, name.length() - 6); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (InputStream is = zf.getInputStream(ze)) { + int read; + while ((read = is.read(buf)) >= 0) { + baos.write(buf, 0, read); + } + } + result.put(name, baos.toByteArray()); + } + } catch (IOException e) { + // skip + } + } + return result; + } + + private static byte[] findInClassPath(List classPath, String name) { + for (Path entry : classPath) { + byte[] found = findInClassPathEntry(entry, name); + if (found != null) return found; + } + return null; + } + + private static byte[] findInClassPathEntry(Path entry, String name) { + if (Files.isDirectory(entry)) { + Path p = entry.resolve(name); + if (Files.isRegularFile(p)) { + try { + return Files.readAllBytes(p); + } catch (IOException e) { + return null; + } + } + } else if (Files.isRegularFile(entry)) { + byte[] buf = new byte[16384]; + try (ZipFile zf = new ZipFile(entry.toFile())) { + ZipEntry ze = zf.getEntry(name); + if (ze == null) return null; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (InputStream is = zf.getInputStream(ze)) { + int read; + while ((read = is.read(buf)) >= 0) { + baos.write(buf, 0, read); + } + } + return baos.toByteArray(); + } catch (IOException e) { + return null; + } + } + return null; + } + + public static class TestClassChecker extends ClassVisitor { + private String name; + private int publicConstructorCount = 0; + private boolean isInterface = false; + private boolean isAbstract = false; + private List implementsList = new ArrayList<>(); + + public TestClassChecker() { + super(Opcodes.ASM9); + } + + @Override + public void visit(int version, int access, String name, String signature, + String superName, String[] interfaces) { + this.name = name; + this.isInterface = (access & Opcodes.ACC_INTERFACE) != 0; + this.isAbstract = (access & Opcodes.ACC_ABSTRACT) != 0; + if (superName != null) implementsList.add(superName); + if (interfaces != null) { + for (String iface : interfaces) { + implementsList.add(iface); + } + } + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, + String signature, String[] exceptions) { + if ("".equals(name) && (access & Opcodes.ACC_PUBLIC) != 0) { + publicConstructorCount++; + } + return null; + } + + public String getName() { return name; } + public int getPublicConstructorCount() { return publicConstructorCount; } + public boolean isInterface() { return isInterface; } + public boolean isAbstract() { return isAbstract; } + public List getImplements() { return implementsList; } + } +} diff --git a/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaDynamicTestRunner.java b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaDynamicTestRunner.java new file mode 100644 index 0000000000..68b89ed2f7 --- /dev/null +++ b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaDynamicTestRunner.java @@ -0,0 +1,161 @@ +package scala.build.testrunner; + +import sbt.testing.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.regex.Pattern; + +public class JavaDynamicTestRunner { + + /** + * Based on junit-interface GlobFilter.compileGlobPattern: + * https://github.com/sbt/junit-interface/blob/f8c6372ed01ce86f15393b890323d96afbe6d594/src/main/java/com/novocode/junit/GlobFilter.java#L37 + * + * Converts a glob expression (only * supported) into a regex Pattern. + */ + private static Pattern globPattern(String expr) { + String[] parts = expr.split("\\*", -1); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < parts.length; i++) { + if (i != 0) sb.append(".*"); + if (!parts[i].isEmpty()) sb.append(Pattern.quote(parts[i].replace("\n", "\\n"))); + } + return Pattern.compile(sb.toString()); + } + + public static void main(String[] args) { + List testFrameworks = new ArrayList<>(); + List remainingArgs = new ArrayList<>(); + boolean requireTests = false; + int verbosity = 0; + Optional testOnly = Optional.empty(); + + boolean pastDashDash = false; + for (String arg : args) { + if (pastDashDash) { + remainingArgs.add(arg); + } else if ("--".equals(arg)) { + pastDashDash = true; + } else if (arg.startsWith("--test-framework=")) { + testFrameworks.add(arg.substring("--test-framework=".length())); + } else if (arg.startsWith("--test-only=")) { + testOnly = Optional.of(arg.substring("--test-only=".length())); + } else if (arg.startsWith("--verbosity=")) { + try { + verbosity = Integer.parseInt(arg.substring("--verbosity=".length())); + } catch (NumberFormatException e) { + // ignore malformed + } + } else if ("--require-tests".equals(arg)) { + requireTests = true; + } else { + remainingArgs.add(arg); + } + } + + JavaTestLogger logger = new JavaTestLogger(verbosity, System.err); + + if (!testFrameworks.isEmpty()) { + logger.debug("Directly passed " + testFrameworks.size() + " test frameworks:\n - " + + String.join("\n - ", testFrameworks)); + } + + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + java.util.List classPath0 = JavaTestRunner.classPath(classLoader); + + List frameworks; + if (!testFrameworks.isEmpty()) { + frameworks = new ArrayList<>(); + for (String fw : testFrameworks) { + try { + frameworks.add(JavaFrameworkUtils.loadFramework(classLoader, fw)); + } catch (Exception e) { + System.err.println("Could not load test framework: " + fw); + System.err.println(e.getMessage()); + System.exit(1); + } + } + } else { + List frameworkServices = JavaFrameworkUtils.findFrameworkServices(classLoader); + List scannedFrameworks = JavaFrameworkUtils.findFrameworks( + classPath0, classLoader, JavaTestRunner.commonTestFrameworks() + ); + List toRun = JavaFrameworkUtils.getFrameworksToRun( + frameworkServices, scannedFrameworks, logger + ); + if (toRun.isEmpty()) { + if (verbosity >= 2) { + throw new RuntimeException("No test framework found"); + } else { + System.err.println("No test framework found"); + System.exit(1); + } + } + frameworks = toRun; + } + + String[] runnerArgs = remainingArgs.toArray(new String[0]); + final Optional testOnlyFinal = testOnly; + final boolean requireTestsFinal = requireTests; + + boolean anyFailed = false; + for (Framework framework : frameworks) { + logger.log("Running test framework: " + framework.name()); + Fingerprint[] fingerprints = framework.fingerprints(); + Runner runner = framework.runner(runnerArgs, new String[0], classLoader); + + List> classes = new ArrayList<>(); + for (String name : JavaFrameworkUtils.listClasses(classPath0, false)) { + try { + classes.add(classLoader.loadClass(name)); + } catch (ClassNotFoundException | NoClassDefFoundError | + UnsupportedClassVersionError | IncompatibleClassChangeError e) { + // skip + } + } + + List taskDefs = new ArrayList<>(); + for (Class cls : classes) { + Optional fp = JavaFrameworkUtils.matchFingerprints( + classLoader, cls, fingerprints + ); + if (!fp.isPresent()) continue; + String clsName = cls.getName().endsWith("$") + ? cls.getName().substring(0, cls.getName().length() - 1) + : cls.getName(); + if (testOnlyFinal.isPresent()) { + Pattern pat = globPattern(testOnlyFinal.get()); + if (!pat.matcher(clsName).matches()) continue; + } + taskDefs.add(new TaskDef(clsName, fp.get(), false, new Selector[]{new SuiteSelector()})); + } + + Task[] initialTasks = runner.tasks(taskDefs.toArray(new TaskDef[0])); + List events = JavaTestRunner.runTasks(Arrays.asList(initialTasks), System.out); + + boolean failed = events.stream().anyMatch(ev -> + ev.status() == Status.Error || + ev.status() == Status.Failure || + ev.status() == Status.Canceled + ); + + String doneMsg = runner.done(); + if (doneMsg != null && !doneMsg.isEmpty()) System.out.println(doneMsg); + + if (requireTestsFinal && events.isEmpty()) { + logger.error("Error: no tests were run for " + framework.name() + "."); + anyFailed = true; + } else if (failed) { + logger.error("Error: " + framework.name() + " tests failed."); + anyFailed = true; + } else { + logger.log(framework.name() + " tests ran successfully."); + } + } + + System.exit(anyFailed ? 1 : 0); + } +} diff --git a/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaFrameworkUtils.java b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaFrameworkUtils.java new file mode 100644 index 0000000000..fc921e4521 --- /dev/null +++ b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaFrameworkUtils.java @@ -0,0 +1,196 @@ +package scala.build.testrunner; + +import sbt.testing.*; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Modifier; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; +import java.util.stream.Stream; + +public class JavaFrameworkUtils { + + public static List findFrameworkServices(ClassLoader loader) { + List result = new ArrayList<>(); + ServiceLoader serviceLoader = ServiceLoader.load(Framework.class, loader); + for (Framework f : serviceLoader) { + result.add(f); + } + return result; + } + + public static Framework loadFramework(ClassLoader loader, String className) throws Exception { + Class cls = loader.loadClass(className); + return (Framework) cls.getConstructor().newInstance(); + } + + public static List findFrameworks( + List classPath, + ClassLoader loader, + List preferredClasses + ) { + Class frameworkCls = Framework.class; + List result = new ArrayList<>(); + Set seen = new LinkedHashSet<>(); + + // first try preferred classes, then scan classpath + List candidates = new ArrayList<>(preferredClasses); + for (String name : listClasses(classPath, true)) { + if (!seen.contains(name)) { + candidates.add(name); + } + } + + for (String name : candidates) { + if (!seen.add(name)) continue; + Class cls; + try { + cls = loader.loadClass(name); + } catch (ClassNotFoundException | UnsupportedClassVersionError | + NoClassDefFoundError | IncompatibleClassChangeError e) { + continue; + } + if (!frameworkCls.isAssignableFrom(cls)) continue; + if (Modifier.isAbstract(cls.getModifiers())) continue; + long publicNoArgCtors = Arrays.stream(cls.getConstructors()) + .filter(c -> Modifier.isPublic(c.getModifiers()) && c.getParameterCount() == 0) + .count(); + if (publicNoArgCtors != 1) continue; + try { + Framework instance = (Framework) cls.getConstructor().newInstance(); + result.add(instance); + } catch (Exception e) { + // skip + } + } + return result; + } + + public static Optional matchFingerprints( + ClassLoader loader, + Class cls, + Fingerprint[] fingerprints + ) { + boolean isModule = cls.getName().endsWith("$"); + long publicCtorCount = Arrays.stream(cls.getConstructors()) + .filter(c -> Modifier.isPublic(c.getModifiers())) + .count(); + boolean noPublicConstructors = publicCtorCount == 0; + boolean definitelyNoTests = Modifier.isAbstract(cls.getModifiers()) || + cls.isInterface() || + publicCtorCount > 1 || + isModule != noPublicConstructors; + if (definitelyNoTests) return Optional.empty(); + + for (Fingerprint fp : fingerprints) { + if (fp instanceof SubclassFingerprint) { + SubclassFingerprint sf = (SubclassFingerprint) fp; + if (sf.isModule() != isModule) continue; + try { + Class superCls = loader.loadClass(sf.superclassName()); + if (superCls.isAssignableFrom(cls)) return Optional.of(fp); + } catch (ClassNotFoundException e) { + // skip + } + } else if (fp instanceof AnnotatedFingerprint) { + AnnotatedFingerprint af = (AnnotatedFingerprint) fp; + if (af.isModule() != isModule) continue; + try { + @SuppressWarnings("unchecked") + Class annotationCls = + (Class) loader.loadClass(af.annotationName()); + boolean matches = + cls.isAnnotationPresent(annotationCls) || + Arrays.stream(cls.getDeclaredMethods()) + .anyMatch(m -> m.isAnnotationPresent(annotationCls)) || + Arrays.stream(cls.getMethods()) + .anyMatch(m -> m.isAnnotationPresent(annotationCls) && + Modifier.isPublic(m.getModifiers())); + if (matches) return Optional.of(fp); + } catch (ClassNotFoundException e) { + // skip + } + } + } + return Optional.empty(); + } + + public static List getFrameworksToRun( + List frameworkServices, + List frameworks, + JavaTestLogger logger + ) { + List all = new ArrayList<>(frameworkServices); + all.addAll(frameworks); + return getFrameworksToRun(all, logger); + } + + public static List getFrameworksToRun( + List allFrameworks, + JavaTestLogger logger + ) { + // dedup by name + Map byName = new LinkedHashMap<>(); + for (Framework f : allFrameworks) { + byName.putIfAbsent(f.name(), f); + } + List distinct = new ArrayList<>(byName.values()); + + // filter out frameworks that are superclasses of another framework in the list + List finalFrameworks = new ArrayList<>(); + for (Framework f1 : distinct) { + boolean isInherited = distinct.stream() + .filter(f2 -> f2 != f1) + .anyMatch(f2 -> f1.getClass().isAssignableFrom(f2.getClass())); + if (!isInherited) finalFrameworks.add(f1); + } + return finalFrameworks; + } + + public static List listClasses(List classPath, boolean keepJars) { + List result = new ArrayList<>(); + for (Path entry : classPath) { + result.addAll(listClasses(entry, keepJars)); + } + return result; + } + + public static List listClasses(Path entry, boolean keepJars) { + List result = new ArrayList<>(); + if (Files.isDirectory(entry)) { + try (Stream stream = Files.walk(entry, Integer.MAX_VALUE)) { + stream.filter(p -> p.getFileName().toString().endsWith(".class")) + .map(entry::relativize) + .map(p -> { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < p.getNameCount(); i++) { + if (i > 0) sb.append("."); + sb.append(p.getName(i).toString()); + } + String name = sb.toString(); + return name.endsWith(".class") ? name.substring(0, name.length() - 6) : name; + }) + .forEach(result::add); + } catch (Exception e) { + // skip + } + } else if (keepJars && Files.isRegularFile(entry)) { + try (ZipFile zf = new ZipFile(entry.toFile())) { + Enumeration entries = zf.entries(); + while (entries.hasMoreElements()) { + ZipEntry ze = entries.nextElement(); + String name = ze.getName(); + if (name.endsWith(".class")) { + result.add(name.substring(0, name.length() - 6).replace("/", ".")); + } + } + } catch (Exception e) { + // skip + } + } + return result; + } +} diff --git a/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaTestLogger.java b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaTestLogger.java new file mode 100644 index 0000000000..a02166de3b --- /dev/null +++ b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaTestLogger.java @@ -0,0 +1,29 @@ +package scala.build.testrunner; + +import java.io.PrintStream; + +public class JavaTestLogger { + private final int verbosity; + private final PrintStream out; + + public JavaTestLogger(int verbosity, PrintStream out) { + this.verbosity = verbosity; + this.out = out; + } + + public void error(String message) { + out.println(message); + } + + public void message(String message) { + if (verbosity >= 0) out.println(message); + } + + public void log(String message) { + if (verbosity >= 1) out.println(message); + } + + public void debug(String message) { + if (verbosity >= 2) out.println(message); + } +} diff --git a/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaTestRunner.java b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaTestRunner.java new file mode 100644 index 0000000000..cbd46f1539 --- /dev/null +++ b/modules/java-test-runner/src/main/java/scala/build/testrunner/JavaTestRunner.java @@ -0,0 +1,87 @@ +package scala.build.testrunner; + +import sbt.testing.*; + +import java.io.File; +import java.io.PrintStream; +import java.net.URLClassLoader; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; + +public class JavaTestRunner { + + public static List commonTestFrameworks() { + return Arrays.asList( + "munit.Framework", + "utest.runner.Framework", + "org.scalacheck.ScalaCheckFramework", + "zio.test.sbt.ZTestFramework", + "org.scalatest.tools.Framework", + "com.novocode.junit.JUnitFramework", + "org.scalajs.junit.JUnitFramework", + "weaver.framework.CatsEffect" + ); + } + + public static List classPath(ClassLoader loader) { + List result = new ArrayList<>(); + collectClassPath(loader, result); + return result; + } + + private static void collectClassPath(ClassLoader loader, List result) { + if (loader == null) return; + if (loader instanceof URLClassLoader) { + URLClassLoader urlLoader = (URLClassLoader) loader; + for (java.net.URL url : urlLoader.getURLs()) { + if ("file".equals(url.getProtocol())) { + try { + result.add(Paths.get(url.toURI()).toAbsolutePath()); + } catch (Exception e) { + // skip + } + } + } + } else if (loader.getClass().getName().equals("jdk.internal.loader.ClassLoaders$AppClassLoader")) { + String cp = System.getProperty("java.class.path", ""); + for (String entry : cp.split(File.pathSeparator)) { + if (!entry.isEmpty()) { + result.add(Paths.get(entry)); + } + } + } + collectClassPath(loader.getParent(), result); + } + + public static List runTasks(List initialTasks, PrintStream out) { + Deque tasks = new ArrayDeque<>(initialTasks); + List events = new ArrayList<>(); + + sbt.testing.Logger logger = new sbt.testing.Logger() { + public boolean ansiCodesSupported() { return true; } + public void error(String msg) { out.println(msg); } + public void warn(String msg) { out.println(msg); } + public void info(String msg) { out.println(msg); } + public void debug(String msg) { out.println(msg); } + public void trace(Throwable t) { t.printStackTrace(out); } + }; + + EventHandler eventHandler = event -> events.add(event); + sbt.testing.Logger[] loggers = new sbt.testing.Logger[]{logger}; + + while (!tasks.isEmpty()) { + Task task = tasks.poll(); + Task[] newTasks = task.execute(eventHandler, loggers); + for (Task t : newTasks) { + tasks.add(t); + } + } + + return events; + } +} diff --git a/modules/options/src/main/scala/scala/build/Artifacts.scala b/modules/options/src/main/scala/scala/build/Artifacts.scala index bb430a5d81..c3cfa4c8ea 100644 --- a/modules/options/src/main/scala/scala/build/Artifacts.scala +++ b/modules/options/src/main/scala/scala/build/Artifacts.scala @@ -49,6 +49,7 @@ final case class Artifacts( extraSourceJars: Seq[os.Path], scalaOpt: Option[ScalaArtifacts], hasJvmRunner: Boolean, + hasJavaTestRunner: Boolean, resolution: Option[Resolution] ) { @@ -131,6 +132,7 @@ object Artifacts { jvmVersion: Int, addJvmRunner: Option[Boolean], addJvmTestRunner: Boolean, + addJvmJavaTestRunner: Boolean, addJmhDependencies: Option[String], extraRepositories: Seq[Repository], keepResolution: Boolean, @@ -189,11 +191,19 @@ object Artifacts { } else Nil + val jvmJavaTestRunnerDependencies = + if addJvmJavaTestRunner then + Seq( + dep"${Constants.javaTestRunnerOrganization}:${Constants.javaTestRunnerModuleName}:${Constants.javaTestRunnerVersion}" + ) + else Nil + val jmhDependencies = addJmhDependencies.toSeq .map(version => dep"${Constants.jmhOrg}:${Constants.jmhGeneratorBytecodeModule}:$version") val maybeSnapshotRepo = { val hasSnapshots = jvmTestRunnerDependencies.exists(_.version.endsWith("SNAPSHOT")) || + jvmJavaTestRunnerDependencies.exists(_.version.endsWith("SNAPSHOT")) || scalaArtifactsParamsOpt.flatMap(_.scalaNativeCliVersion).exists(_.endsWith("SNAPSHOT")) val hasNightlies = scalaArtifactsParamsOpt.exists(a => a.params.scalaVersion.endsWith("-NIGHTLY") || @@ -409,6 +419,7 @@ object Artifacts { val internalDependencies = jvmTestRunnerDependencies.map(Positioned.none) ++ + jvmJavaTestRunnerDependencies.map(Positioned.none) ++ scalaOpt.toSeq.flatMap(_.internalDependencies).map(Positioned.none) ++ jmhDependencies.map(Positioned.none) val updatedDependencies = dependencies ++ @@ -582,6 +593,7 @@ object Artifacts { extraSourceJars, scalaOpt, hasRunner, + addJvmJavaTestRunner, if (keepResolution) Some(fetchRes.resolution) else None ) } diff --git a/modules/options/src/main/scala/scala/build/options/BuildOptions.scala b/modules/options/src/main/scala/scala/build/options/BuildOptions.scala index 49b9a4cf5c..aa96da1d41 100644 --- a/modules/options/src/main/scala/scala/build/options/BuildOptions.scala +++ b/modules/options/src/main/scala/scala/build/options/BuildOptions.scala @@ -222,6 +222,10 @@ final case class BuildOptions( private def addJvmTestRunner: Boolean = platform.value == Platform.JVM && internalDependencies.addTestRunnerDependency + + private def addJvmJavaTestRunner: Boolean = + platform.value == Platform.JVM && + internalDependencies.addTestRunnerDependency private def addJsTestBridge: Option[String] = if (platform.value == Platform.JS && internalDependencies.addTestRunnerDependency) Some(scalaJsOptions.finalVersion) @@ -476,6 +480,7 @@ final case class BuildOptions( if (scalaArtifactsParamsOpt.isDefined) None else Some(false) // no runner in pure Java mode } + val isJavaBuild = scalaArtifactsParamsOpt.isEmpty val extraRepositories: Seq[Repository] = value(finalRepositories) val maybeArtifacts = Artifacts( scalaArtifactsParamsOpt = scalaArtifactsParamsOpt, @@ -490,7 +495,8 @@ final case class BuildOptions( fetchSources = classPathOptions.fetchSources.getOrElse(false), jvmVersion = javaHome().value.version, addJvmRunner = addRunnerDependency0, - addJvmTestRunner = isTests && addJvmTestRunner, + addJvmTestRunner = isTests && addJvmTestRunner && !isJavaBuild, + addJvmJavaTestRunner = isTests && addJvmJavaTestRunner && isJavaBuild, addJmhDependencies = jmhOptions.finalJmhVersion, extraRepositories = extraRepositories, keepResolution = internal.keepResolution,