Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions core/codesig/src/mill/codesig/CodeSig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,63 @@ package mill.codesig
import mill.codesig.JvmModel.*

object CodeSig {
def compute(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean,
logger: Logger,
prevTransitiveCallGraphHashesOpt: () => Option[Map[String, Int]]
): CallGraphAnalysis = {
implicit val st: SymbolTable = new SymbolTable()

private def callGraphAnalysis(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean
)(implicit st: SymbolTable): CallGraphAnalysis = {
val localSummary = LocalSummary.apply(classFiles.iterator.map(os.read.inputStream(_)))
logger.log(localSummary)

val externalSummary = ExternalSummary.apply(localSummary, upstreamClasspath)
logger.log(externalSummary)

val resolvedMethodCalls = ResolvedCalls.apply(localSummary, externalSummary)
logger.log(resolvedMethodCalls)

new CallGraphAnalysis(
localSummary,
resolvedMethodCalls,
externalSummary,
ignoreCall,
logger,
prevTransitiveCallGraphHashesOpt
ignoreCall
)
}

def getCallGraphAnalysis(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean
): CallGraphAnalysis = {
implicit val st: SymbolTable = new SymbolTable()

callGraphAnalysis(classFiles, upstreamClasspath, ignoreCall)
}

def compute(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean,
logger: Logger,
prevTransitiveCallGraphHashesOpt: () => Option[Map[String, Int]]
): CallGraphAnalysis = {
implicit val st: SymbolTable = new SymbolTable()

val callAnalysis = callGraphAnalysis(classFiles, upstreamClasspath, ignoreCall)

logger.log(callAnalysis.localSummary)
logger.log(callAnalysis.externalSummary)
logger.log(callAnalysis.resolved)

logger.mandatoryLog(callAnalysis.methodCodeHashes)
logger.mandatoryLog(callAnalysis.prettyCallGraph)
logger.mandatoryLog(callAnalysis.transitiveCallGraphHashes0)

logger.log(callAnalysis.transitiveCallGraphHashes)

val spanningInvalidationTree = callAnalysis.calculateSpanningInvalidationTree {
prevTransitiveCallGraphHashesOpt()
}

logger.mandatoryLog(spanningInvalidationTree)

callAnalysis
}
}
6 changes: 4 additions & 2 deletions core/codesig/src/mill/codesig/ExternalSummary.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import mill.codesig.JvmModel.*
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes}

import java.net.URLClassLoader
import scala.util.Try

case class ExternalSummary(
directMethods: Map[JCls, Map[MethodSig, Boolean]],
Expand Down Expand Up @@ -47,7 +48,8 @@ object ExternalSummary {

def load(cls: JCls): Unit = methodsPerCls.getOrElse(cls, load0(cls))

def load0(cls: JCls): Unit = {
// Some macros implementations will fail the ClassReader, we can skip them
def load0(cls: JCls): Unit = Try {
val visitor = new MyClassVisitor()
val resourcePath =
os.resource(upstreamClassloader) / os.SubPath(cls.name.replace('.', '/') + ".class")
Expand All @@ -61,7 +63,7 @@ object ExternalSummary {
methodsPerCls(cls) = visitor.methods
ancestorsPerCls(cls) = visitor.ancestors
ancestorsPerCls(cls).foreach(load)
}
}.getOrElse(())

(allDirectAncestors ++ allMethodCallParamClasses)
.filter(!localSummary.contains(_))
Expand Down
128 changes: 88 additions & 40 deletions core/codesig/src/mill/codesig/ReachabilityAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@ package mill.codesig

import mill.codesig.JvmModel.*
import mill.internal.{SpanningForest, Tarjans}
import ujson.Obj
import ujson.{Obj, Arr}
import upickle.default.{Writer, writer}

import scala.collection.immutable.SortedMap
import scala.collection.mutable

class CallGraphAnalysis(
localSummary: LocalSummary,
resolved: ResolvedCalls,
externalSummary: ExternalSummary,
ignoreCall: (Option[MethodDef], MethodSig) => Boolean,
logger: Logger,
prevTransitiveCallGraphHashesOpt: () => Option[Map[String, Int]]
val localSummary: LocalSummary,
val resolved: ResolvedCalls,
val externalSummary: ExternalSummary,
ignoreCall: (Option[MethodDef], MethodSig) => Boolean
)(implicit st: SymbolTable) {

val methods: Map[MethodDef, LocalSummary.MethodInfo] = for {
Expand All @@ -40,17 +39,13 @@ class CallGraphAnalysis(
lazy val methodCodeHashes: SortedMap[String, Int] =
methods.map { case (k, vs) => (k.toString, vs.codeHash) }.to(SortedMap)

logger.mandatoryLog(methodCodeHashes)

lazy val prettyCallGraph: SortedMap[String, Array[CallGraphAnalysis.Node]] = {
indexGraphEdges.zip(indexToNodes).map { case (vs, k) =>
(k.toString, vs.map(indexToNodes))
}
.to(SortedMap)
}

logger.mandatoryLog(prettyCallGraph)

def transitiveCallGraphValues[V: scala.reflect.ClassTag](
nodeValues: Array[V],
reduce: (V, V) => V,
Expand Down Expand Up @@ -78,44 +73,45 @@ class CallGraphAnalysis(
.collect { case (CallGraphAnalysis.LocalDef(d), v) => (d.toString, v) }
.to(SortedMap)

logger.mandatoryLog(transitiveCallGraphHashes0)
logger.log(transitiveCallGraphHashes)

lazy val spanningInvalidationTree: Obj = prevTransitiveCallGraphHashesOpt() match {
case Some(prevTransitiveCallGraphHashes) =>
CallGraphAnalysis.spanningInvalidationTree(
prevTransitiveCallGraphHashes,
transitiveCallGraphHashes0,
indexToNodes,
indexGraphEdges
)
case None => ujson.Obj()
def calculateSpanningInvalidationTree(
prevTransitiveCallGraphHashesOpt: => Option[Map[String, Int]]
): Obj = {
prevTransitiveCallGraphHashesOpt match {
case Some(prevTransitiveCallGraphHashes) =>
CallGraphAnalysis.spanningInvalidationTree(
prevTransitiveCallGraphHashes,
transitiveCallGraphHashes0,
indexToNodes,
indexGraphEdges
)
case None => ujson.Obj()
}
}

logger.mandatoryLog(spanningInvalidationTree)
def calculateInvalidatedClassNames(
prevTransitiveCallGraphHashesOpt: => Option[Map[String, Int]]
): Set[String] = {
prevTransitiveCallGraphHashesOpt match {
case Some(prevTransitiveCallGraphHashes) =>
CallGraphAnalysis.invalidatedClassNames(
prevTransitiveCallGraphHashes,
transitiveCallGraphHashes0,
indexToNodes,
indexGraphEdges
)
case None => Set.empty
}
}
}

object CallGraphAnalysis {

/**
* Computes the minimal spanning forest of the that covers the nodes in the
* call graph whose transitive call graph hashes has changed since the last
* run, rendered as a JSON dictionary tree. This provides a great "debug
* view" that lets you easily Cmd-F to find a particular node and then trace
* it up the JSON hierarchy to figure out what upstream node was the root
* cause of the change in the callgraph.
*
* There are typically multiple possible spanning forests for a given graph;
* one is chosen arbitrarily. This is usually fine, since when debugging you
* typically are investigating why there's a path to a node at all where none
* should exist, rather than trying to fully analyse all possible paths
*/
def spanningInvalidationTree(
private def getSpanningForest(
prevTransitiveCallGraphHashes: Map[String, Int],
transitiveCallGraphHashes0: Array[(CallGraphAnalysis.Node, Int)],
indexToNodes: Array[Node],
indexGraphEdges: Array[Array[Int]]
): ujson.Obj = {
) = {
val transitiveCallGraphHashes0Map = transitiveCallGraphHashes0.toMap

val nodesWithChangedHashes = indexGraphEdges
Expand All @@ -135,12 +131,64 @@ object CallGraphAnalysis {
val reverseGraphEdges =
indexGraphEdges.indices.map(reverseGraphMap.getOrElse(_, Array[Int]())).toArray

SpanningForest.apply(reverseGraphEdges, nodesWithChangedHashes, false)
}

/**
* Computes the minimal spanning forest of the that covers the nodes in the
* call graph whose transitive call graph hashes has changed since the last
* run, rendered as a JSON dictionary tree. This provides a great "debug
* view" that lets you easily Cmd-F to find a particular node and then trace
* it up the JSON hierarchy to figure out what upstream node was the root
* cause of the change in the callgraph.
*
* There are typically multiple possible spanning forests for a given graph;
* one is chosen arbitrarily. This is usually fine, since when debugging you
* typically are investigating why there's a path to a node at all where none
* should exist, rather than trying to fully analyse all possible paths
*/
def spanningInvalidationTree(
prevTransitiveCallGraphHashes: Map[String, Int],
transitiveCallGraphHashes0: Array[(CallGraphAnalysis.Node, Int)],
indexToNodes: Array[Node],
indexGraphEdges: Array[Array[Int]]
): ujson.Obj = {
SpanningForest.spanningTreeToJsonTree(
SpanningForest.apply(reverseGraphEdges, nodesWithChangedHashes, false),
getSpanningForest(prevTransitiveCallGraphHashes, transitiveCallGraphHashes0, indexToNodes, indexGraphEdges),
k => indexToNodes(k).toString
)
}

/**
* Get all class names that have their hashcode changed compared to prevTransitiveCallGraphHashes
*/
def invalidatedClassNames(
prevTransitiveCallGraphHashes: Map[String, Int],
transitiveCallGraphHashes0: Array[(CallGraphAnalysis.Node, Int)],
indexToNodes: Array[Node],
indexGraphEdges: Array[Array[Int]]
): Set[String] = {
val rootNode = getSpanningForest(prevTransitiveCallGraphHashes, transitiveCallGraphHashes0, indexToNodes, indexGraphEdges)

val jsonValueQueue = mutable.ArrayDeque[(Int, SpanningForest.Node)]()
jsonValueQueue.appendAll(rootNode.values.toSeq)
val builder = Set.newBuilder[String]

while (jsonValueQueue.nonEmpty) {
val (nodeIndex, node) = jsonValueQueue.removeHead()
node.values.foreach { case (childIndex, childNode) =>
jsonValueQueue.append((childIndex, childNode))
}
indexToNodes(nodeIndex) match {
case CallGraphAnalysis.LocalDef(methodDef) => builder.addOne(methodDef.cls.name)
case CallGraphAnalysis.Call(methodCall) => builder.addOne(methodCall.cls.name)
case CallGraphAnalysis.ExternalClsCall(externalCls) => builder.addOne(externalCls.name)
}
}

builder.result()
}

def indexGraphEdges(
indexToNodes: Array[Node],
methods: Map[MethodDef, LocalSummary.MethodInfo],
Expand Down
21 changes: 16 additions & 5 deletions core/define/src/mill/define/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,16 @@ object Task extends TaskBase {
inline def Command[T](inline t: Result[T])(implicit
inline w: W[T],
inline ctx: mill.define.ModuleCtx
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }) }
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }, persistent = '{ false }) }


/**
* This version allow [[Command]] to be persistent
*/
inline def Command[T](inline persistent: Boolean)(inline t: Result[T])(implicit
inline w: W[T],
inline ctx: mill.define.ModuleCtx
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }, persistent = '{ persistent }) }

/**
* @param exclusive Exclusive commands run serially at the end of an evaluation,
Expand All @@ -142,7 +151,7 @@ object Task extends TaskBase {
inline def apply[T](inline t: Result[T])(implicit
inline w: W[T],
inline ctx: mill.define.ModuleCtx
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, '{ this.exclusive }) }
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, '{ this.exclusive }, '{ false }) }
}

/**
Expand Down Expand Up @@ -396,7 +405,8 @@ class Command[+T](
val ctx0: mill.define.ModuleCtx,
val writer: W[?],
val isPrivate: Option[Boolean],
val exclusive: Boolean
val exclusive: Boolean,
override val persistent: Boolean
) extends NamedTask[T] {

override def asCommand: Some[Command[T]] = Some(this)
Expand Down Expand Up @@ -543,12 +553,13 @@ private object TaskMacros {
)(t: Expr[Result[T]])(
w: Expr[W[T]],
ctx: Expr[mill.define.ModuleCtx],
exclusive: Expr[Boolean]
exclusive: Expr[Boolean],
persistent: Expr[Boolean]
): Expr[Command[T]] = {
appImpl[Command, T](
(in, ev) =>
'{
new Command[T]($in, $ev, $ctx, $w, ${ taskIsPrivate() }, exclusive = $exclusive)
new Command[T]($in, $ev, $ctx, $w, ${ taskIsPrivate() }, exclusive = $exclusive, persistent = $persistent)
},
t
)
Expand Down
33 changes: 33 additions & 0 deletions example/javalib/testing/7-test-quick/build.mill
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//// SNIPPET:BUILD1
package build
import mill._, javalib._
import os._

object foo extends JavaModule {
object test extends JavaTests {
def testFramework = "com.novocode.junit.JUnitFramework" // Use JUnit 4 framework interface
def mvnDeps = Seq(
mvn"junit:junit:4.13.2", // JUnit 4 itself
mvn"com.novocode:junit-interface:0.11" // sbt-compatible JUnit interface
)
}
// Ultilities for replacing text in files
def replaceBar(args: String*) = Task.Command {
val relativePath = os.RelPath("../../../foo/src/Bar.java")
val filePath = Task.dest() / relativePath
os.write.over(filePath, os.read(filePath).replace(
"""return String.format("Hi, %s!", name);""",
"""return String.format("Ciao, %s!", name);"""
))
}

def replaceFooTest2(args: String*) = Task.Command {
val relativePath = os.RelPath("../../../foo/test/src/FooTest2.java")
val filePath = Task.dest() / relativePath
os.write.over(filePath, os.read(filePath).replace(
"""assertEquals("Hi, " + name + "!", greeted);""",
"""assertEquals("Ciao, " + name + "!", greeted);""",
))
}
}
//// SNIPPET:END
11 changes: 11 additions & 0 deletions example/javalib/testing/7-test-quick/foo/src/Bar.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package foo;

public class Bar {
public static String greet(String name) {
return String.format("Hello, %s!", name);
}

public static String greet2(String name) {
return String.format("Hi, %s!", name);
}
}
Loading