diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b53339c..4e25066 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,6 @@ jobs: - name: Set up JDK uses: actions/setup-java@v1 with: - java-version: 16 + java-version: 25 - name: Build with Gradle run: ./gradlew build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1308358..204cefd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,10 +14,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v1 with: - java-version: 14 + java-version: 25 - name: Deploy with Gradle env: GRADLE_PUBLISH_REPO_URL: ${{ secrets.GRADLE_PUBLISH_REPO_URL }} GRADLE_PUBLISH_MAVEN_USER: ${{ secrets.GRADLE_PUBLISH_MAVEN_USER }} GRADLE_PUBLISH_MAVEN_PASSWORD: ${{ secrets.GRADLE_PUBLISH_MAVEN_PASSWORD }} - run: ./gradlew --no-daemon -Pversion=$(git tag --points-at HEAD) publish \ No newline at end of file + run: ./gradlew --no-daemon -Pversion=$(git tag --points-at HEAD) publish diff --git a/.gitignore b/.gitignore index fd0c699..704574d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ out/ .project .classpath -.settings/ \ No newline at end of file +.settings/ +/bin/ diff --git a/README.md b/README.md index 398073e..e6e35be 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Jabel - use modern Java 9-14 syntax when targeting Java 8 +# Jabel - use modern Java 9-25 syntax when targeting Java 8 > Because life is too short to wait for your users to upgrade their Java! @@ -74,8 +74,8 @@ Jabel has to be enabled as a Javac plugin in your maven-compiler-plugin: 8 - 14 - 14 + 25 + 25 -Xplugin:jabel diff --git a/example/build.gradle b/example/build.gradle index 73d4c1d..24cf60f 100644 --- a/example/build.gradle +++ b/example/build.gradle @@ -2,12 +2,13 @@ plugins { id "java" } -configure([tasks.compileJava]) { - sourceCompatibility = 16 +compileJava { + sourceCompatibility = 25 + targetCompatibility = 8 options.release = 8 javaCompiler = javaToolchains.compilerFor { - languageVersion = JavaLanguageVersion.of(21) + languageVersion = JavaLanguageVersion.of(25) } } @@ -22,8 +23,7 @@ test { } dependencies { - annotationProcessor project(":jabel-javac-plugin") - compileOnly project(":jabel-javac-plugin") + annotationProcessor project(":jabel-javac-plugin")//.files("build/libs/jabel-javac-plugin.jar") testImplementation 'junit:junit:4.13.2' } diff --git a/example/src/main/java/com/example/Java10FeaturesExample.java b/example/src/main/java/com/example/Java10FeaturesExample.java new file mode 100644 index 0000000..d749463 --- /dev/null +++ b/example/src/main/java/com/example/Java10FeaturesExample.java @@ -0,0 +1,44 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.ArrayList; +import java.util.HashMap; + +/** + * Examples of Java 10 features desugared by the compiler:
+ * + * LOCAL_VARIABLE_TYPE_INFERENCE (var) + *
+ * // Source (Java 10+):
+ * var str = "hello";
+ * var list = new ArrayList<String>();
+ * for (var i = 0; i < 10; i++) { }
+ * for (var item : list) { }
+ *
+ * // Decompiled (Java 8):
+ * String str = "hello";
+ * ArrayList<String> list = new ArrayList<String>();
+ * for (int i = 0; i < 10; i++) { }
+ * for (String item : list) { }
+ * 
+ */ +public class Java10FeaturesExample { + + void varExamples() { + var str = "hello"; + var num = 42; + var list = new ArrayList(); + var map = new HashMap(); + + list.add("item"); + map.put("key", 1); + + for (var i = 0; i < 3; i++) { + list.add("item" + i); + } + for (var item : list) { + System.out.println(item); + } + } +} diff --git a/example/src/main/java/com/example/Java11FeaturesExample.java b/example/src/main/java/com/example/Java11FeaturesExample.java new file mode 100644 index 0000000..7cf0ffb --- /dev/null +++ b/example/src/main/java/com/example/Java11FeaturesExample.java @@ -0,0 +1,29 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.function.BiFunction; + +/** + * Examples of Java 11 features desugared by the compiler:
+ * + * VAR_SYNTAX_IMPLICIT_LAMBDAS + *
+ * // Source (Java 11+):
+ * BiFunction<String, String, String> f = (var a, var b) -> a + b;
+ * // Allows annotations:
+ * Consumer<String> c = (@NonNull var s) -> System.out.println(s);
+ *
+ * // Decompiled (Java 8):
+ * BiFunction<String, String, String> f = (a, b) -> a + b;
+ * // or with explicit types:
+ * BiFunction<String, String, String> f = (String a, String b) -> a + b;
+ * 
+ */ +public class Java11FeaturesExample { + + BiFunction concat = (var a, var b) -> a + b; + + BiFunction add = + (var x, var y) -> x + y; +} diff --git a/example/src/main/java/com/example/Java14FeaturesExample.java b/example/src/main/java/com/example/Java14FeaturesExample.java new file mode 100644 index 0000000..e9b82fa --- /dev/null +++ b/example/src/main/java/com/example/Java14FeaturesExample.java @@ -0,0 +1,133 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 14 features desugared by the compiler:
+ * + * SWITCH_MULTIPLE_CASE_LABELS + *
+ * // Source (Java 14+):
+ * case SAT, SUN -> "weekend";
+ *
+ * // Decompiled (Java 8):
+ * case SAT:
+ * case SUN:
+ *     return "weekend";
+ * 
+ *

+ * SWITCH_RULE (arrow syntax) + *

+ * // Source (Java 14+):
+ * case MON -> "monday";
+ *
+ * // Decompiled (Java 8):
+ * case MON:
+ *     return "monday";
+ * 
+ *

+ * SWITCH_EXPRESSION + *

+ * // Source (Java 14+):
+ * String result = switch (day) {
+ *     case MON -> "monday";
+ *     default -> { yield "other"; }
+ * };
+ *
+ * // Decompiled (Java 8):
+ * String result;
+ * switch (day) {
+ *     case MON:
+ *         result = "monday";
+ *         break;
+ *     default:
+ *         result = "other";
+ * }
+ * 
+ */ +public class Java14FeaturesExample { + + enum Day { MON, TUE, WED, THU, FRI, SAT, SUN } + + class Builder{ + public int n; + public Builder build(Object o){ + n++; + return this; + } + } + + String switchMultipleLabels(Day day) { + return switch (day) { + case SAT, SUN -> "weekend"; + case MON, TUE, WED, THU, FRI -> "weekday"; + }; + } + + String switchRule(Day day) { + return (switch (day) { + case MON -> "monday"; + case TUE -> "tuesday"; + default -> "other"; + }).toLowerCase(); + } + + int switchExpressionWithYield(Day day) { + return switch (day) { + case MON -> 1; + case TUE -> 2; + default -> { + int result = day.ordinal(); + yield result; + } + }; + } + + void statementSwitchWithArrow(Day day) { + switch (day) { + case SAT -> System.out.println("Saturday"); + case SUN -> System.out.println("Sunday"); + } + System.out.println("Weekday"); + } + + void noDefaultInjectedCase(int a) { + switch (a) { + case 1 -> { + System.out.println("Wrong value!!"); + return; + } + } + System.out.println("Passed!"); + } + + int[] asArrayElements(int a, int b) { + return new int[]{ + switch (a) { case 1 -> 10; default -> 0; }, + switch (b) { case 2 -> 20; default -> 0; } + }; + } + + String chained(String str) { + return switch ( + switch (str) { + case "test": { + str = "pok"; + yield "A"; + } + default: yield "B"; + } + ) { + case "A" -> str.isEmpty() ? "got-A" : "err"; + default -> "got-B"; + }; + } + + int twiceOf(int n) { + return new Builder().build(1).build(switch (n) { + case 1 -> 10; + case 2 -> 20; + default -> n * n; + }).build(4).n; + } +} diff --git a/example/src/main/java/com/example/Java15FeaturesExample.java b/example/src/main/java/com/example/Java15FeaturesExample.java new file mode 100644 index 0000000..46be666 --- /dev/null +++ b/example/src/main/java/com/example/Java15FeaturesExample.java @@ -0,0 +1,48 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 15 features desugared by the compiler:
+ * + * TEXT_BLOCKS + *
+ * // Source (Java 15+):
+ * String s = """
+ *     Hello
+ *     World
+ *     """;
+ * String cont = """
+ *     Single \
+ *     line""";
+ *
+ * // Decompiled (Java 8):
+ * String s = "Hello\nWorld\n";
+ * String cont = "Single line";
+ * 
+ */ +public class Java15FeaturesExample { + + String basic = """ + Hello + World + """; + + String json = """ + {"name": "test", "value": 42} + """; + + String withTrailingSpace = """ + Line with space \s + Next line + """; + + String lineContinuation = """ + Single \ + line"""; + + String formatted = String.format(""" + Name: %s + Age: %d + """, "Alice", 25); +} diff --git a/example/src/main/java/com/example/Java16FeaturesExample.java b/example/src/main/java/com/example/Java16FeaturesExample.java new file mode 100644 index 0000000..e5ce86c --- /dev/null +++ b/example/src/main/java/com/example/Java16FeaturesExample.java @@ -0,0 +1,113 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.List; +import java.util.Objects; + +/** + * Examples of Java 16 features desugared by the compiler:
+ * + * PATTERN_MATCHING_IN_INSTANCEOF + *
+ * // Source (Java 16+):
+ * if (obj instanceof String s) {
+ *     System.out.println(s.length());
+ * }
+ *
+ * // Decompiled (Java 8):
+ * if (obj instanceof String) {
+ *     String s = (String) obj;
+ *     System.out.println(s.length());
+ * }
+ * 
+ *

+ * RECORDS + *

+ * // Source (Java 16+):
+ * record Point(int x, int y) { }
+ *
+ * // Decompiled (Java 8):
+ * class Point {
+ *     private final int x;
+ *     private final int y;
+ *     Point(int x, int y) { this.x = x; this.y = y; }
+ *     int x() { return x; }
+ *     int y() { return y; }
+ *     public boolean equals(Object o) { ... }
+ *     public int hashCode() { ... }
+ *     public String toString() { ... }
+ * }
+ * 
+ */ +public class Java16FeaturesExample { + record Pair(A first, B second) {} + + record Point(int x, int y) { + public Point(int x, int y) { + if (x > Short.MAX_VALUE) throw new IllegalArgumentException(); + this.x = x; + if (y > Short.MAX_VALUE) throw new IllegalArgumentException(); + this.y = y; + } + } + + record Person(String name, int age) { + static Object planet; + + Person { + Objects.requireNonNull(name); + if (age < 0) throw new IllegalArgumentException(); + } + + } + + record MultipleTypes(boolean bool, byte b, short s, int i, long l, + float f, double d, String str, Point p) { + @Override + public boolean equals(Object o) { + return this == o; + } + + @Override + public String str() { + return "different string: " + str; + } + } + + void patternMatchingInstanceof(Object obj) { + if (obj instanceof String s) { + System.out.println(s.length()); + } + + if (obj instanceof String s && s.length() > 3) { + System.out.println(s.toUpperCase()); + } + + int len = obj instanceof String s ? s.length() : -1; + Integer.valueOf(len); // use variable to get expected byte code + } + + int inlinePatternMatching(Object obj) { + return obj instanceof String s ? s.length() : obj instanceof Integer i ? i : -1; + } + + void reifiableTypesInstanceof(Object obj) { + if (!(obj instanceof List list)) { + System.out.println("Not a list"); + return; + } + System.out.println(list.size()); + } + + void recordsExample() { + var p1 = new Point(3, 4); + var p2 = new Point(3, 4); + + int x = p1.x(); + int y = p1.y(); + boolean equal = p1.equals(p2); + int hash = p1.hashCode(); + String str = p1.toString(); + } +} diff --git a/example/src/main/java/com/example/Java17FeaturesExample.java b/example/src/main/java/com/example/Java17FeaturesExample.java new file mode 100644 index 0000000..05846e2 --- /dev/null +++ b/example/src/main/java/com/example/Java17FeaturesExample.java @@ -0,0 +1,68 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 17 features desugared by the compiler:
+ * + * SEALED_CLASSES + *
+ * // Source (Java 17+):
+ * public sealed class Shape permits Circle, Square { }
+ * public final class Circle extends Shape { }
+ * public non-sealed class Square extends Shape { }
+ *
+ * // Decompiled (Java 8):
+ * public abstract class Shape { }  // sealed, permits removed
+ * public final class Circle extends Shape { }
+ * public class Square extends Shape { }  // non-sealed removed
+ * 
+ *

+ * REDUNDANT_STRICTFP + *

+ * // Source (Java 17+):
+ * public strictfp class Math { }
+ *
+ * // Decompiled (Java 8): same code (strictfp kept but redundant since Java 17)
+ * 
+ *

+ * PATTERN_SWITCH (preview since JDK 17) + *

+ * // Source (Java 17+ preview):
+ * switch (obj) {
+ *     case String s -> s.length();
+ *     case Integer i -> i;
+ *     default -> 0;
+ * }
+ * 
+ */ +public class Java17FeaturesExample { + + sealed class Shape permits Circle, Square, Rectangle { + private final String name; + Shape(String name) { this.name = name; } + String getName() { return name; } + } + + final class Circle extends Shape { + Circle() { super("Circle"); } + void test(){ + Square.test(); + } + } + + final class Square extends Shape { + Square() { super("Square"); } + static void test() {} + } + + non-sealed class Rectangle extends Shape { + Rectangle() { super("Rectangle"); } + } + + class SpecialRectangle extends Rectangle {} + + strictfp double strictMethod(double a, double b) { + return a * b + a / b; + } +} diff --git a/example/src/main/java/com/example/Java21FeaturesExample.java b/example/src/main/java/com/example/Java21FeaturesExample.java new file mode 100644 index 0000000..f531130 --- /dev/null +++ b/example/src/main/java/com/example/Java21FeaturesExample.java @@ -0,0 +1,362 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 21 features with manual desugaring:
+ * + * CASE_NULL + * + *
+ * // Source (Java 21+):
+ * switch (str) {
+ *     case null -> "null";
+ *     case "a" -> "A";
+ *     default -> "other";
+ * }
+ *
+ * // Decompiled (Java 8):
+ * switch (str == null ? -1 : str.equals("a") ? 0 : 1) {
+ *     case -1: return "null";
+ *     case 0: return "A";
+ *     default: return "other";
+ * }
+ * 
+ *

+ * PATTERN_SWITCH + * + *

+ * // Source (Java 21+):
+ * switch (obj) {
+ *     case String s -> s.length();
+ *     case Integer i -> i;
+ *     default -> 0;
+ * }
+ *
+ * // Decompiled (Java 8):
+ * switch (Objects.requireNonNull(obj) instanceof String ? 0 : obj instanceof Integer ? 1 : 2) {
+ *     case 0: {
+ *         String s = (String) obj;
+ *         return s.length();
+ *     }
+ *     case 1: {
+ *         Integer i = (Integer) obj;
+ *         return i;
+ *     }
+ *     default: return 0;
+ * }
+ * 
+ *

+ * RECORD_PATTERNS + * + *

+ * // Source (Java 21+):
+ * if (obj instanceof Point(int x, int y)) {
+ *     return x + y;
+ * }
+ *
+ * // Decompiled (Java 8):
+ * if (obj instanceof Point) {
+ *     Point p = (Point) obj;
+ *     int x = p.x();
+ *     int y = p.y();
+ *     return x + y;
+ * }
+ * 
+ *

+ * UNCONDITIONAL_PATTERN_IN_INSTANCEOF + * + *

+ * // Source (Java 21+):
+ * if (str instanceof CharSequence cs) {
+ * } // always true for non-null
+ *
+ * // Decompiled (Java 8):
+ * if (str != null) {
+ *     CharSequence cs = str;
+ * }
+ * 
+ */ +public class Java21FeaturesExample { + public enum InitDeadlockTest { + START, END; + + static final String MSG = describe(START); + + static String describe(InitDeadlockTest dz) { + return switch (dz) { + case START -> "starting"; + case END -> "ending"; + case null -> "unknown"; + }; + } + } + + enum Day { + MON, TUE, WED, THU, FRI, SAT, SUN + } + + sealed interface Geometry permits Point, Triangle, Shape { + } + + record Point(int x, int y) implements Geometry { + } + + record Triangle(Point a, Point b, Point c) implements Geometry { + } + + record Shape(String name, Triangle triangle) implements Geometry { + } + + class Builder { + public int n = InitDeadlockTest.START.ordinal(); + + public Builder build(Object o) { + n++; + return this; + } + } + + String enumWithNull(Day day) { + return switch (day) { + case SAT, Day.SUN -> "weekend"; + case null -> ""; + case MON, TUE, WED, THU, FRI -> "weekday"; + }; + } + + String enumWithNullAndPattern(Day day) { + return switch (day) { + case SAT, Day.SUN -> "weekend"; + case null -> ""; + case MON, TUE, WED, THU, FRI -> "weekday"; + case Day d -> "not possible: " + d; + }; + } + + String numberWithNull(Float day) { + return switch (day) { + case 6f, 7f -> "weekend"; + case null -> ""; + case 1f, 2f, 3f, 4f, 5f -> "weekday"; + case Float f when f < 0 -> (day = 10f) + " "; + default -> ""; + }; + } + + String number(Float day) { + return switch (day) { + case 6f, 7f -> "weekend"; + case 1f, 2f, 3f, 4f, 5f -> "weekday"; + default -> ""; + }; + } + + String classicWithNull(Integer day) { + return switch (day) { + case 6, 7 -> "weekend"; + case null -> ""; + case 1, 2, 3, 4, 5 -> "weekday"; + default -> ""; + }; + } + + String caseNull(String input) { + return switch (input) { + case "hello" -> "hello"; + case null -> "null"; + default -> "other"; + }; + } + + String caseNullWithDefault(String input) { + return switch (input.toString()) { + case "specific" -> "specific"; + case null, default -> "null or default"; + }; + } + + String mixedCaseType(Object msg) { + String v; + switch (msg) { + case Day.MON: + v = "Monday?"; + break; + case Point(int x, int y) when x > 100: + v = "Far away at " + x + "," + y; + case null: + throw new RuntimeException(); + case String s: + v = "Message: " + s; + default: + v = "Other"; + } + return v; + } + + String patternSwitch(Object obj) { + return switch (obj) { + case String s -> "String: " + s; + case Integer i -> "Integer: " + i; + default -> throw new IllegalArgumentException(obj.toString()); + }; + } + + String patternSwitchClass(Object obj) { + return switch (obj) { + case Class c when c == Integer.class -> "Int"; + case Class c -> "Class: " + c.getName(); + default -> throw new IllegalArgumentException(obj.toString()); + }; + } + + String patternSwitchVariable(Object obj) { + String var = "ignored part:" + switch (obj) { + case String s -> var = "String: " + s; + case null -> var = "null"; + case Integer i -> var = "Integer: " + i; + default -> var = "Other"; + }; + var.toString(); // fake a use to avoid optimizations + return var; + } + + String patternSwitchWithGuard(Object obj) { + return (switch (obj) { + case String s when s.isEmpty() -> "empty"; + case String s when s.length() < 5 -> "short"; + case Integer i -> "int"; + case String s when s.length() > 3 && s.startsWith("a") -> "long-a"; + case String[] s -> "string list"; + case String s -> s; + case int[] i -> "int list"; + default -> "not a string"; + }).trim(); + } + + String patternSwitchExpressionWithGuard(Object obj) { + String var; + switch (obj) { + case String s when s.isEmpty() -> var = "empty"; + case String s when s.length() < 5 -> var = "short"; + case Integer i -> var = "int"; + case String s -> var = "long"; + default -> var = "not a string"; + } + var.toString(); // fake a use to avoid optimizations + return var; + } + + String patternSwitchStatement(Object obj) { + String result; + switch (obj) { + case String s: + result = "String: " + s; + break; + case Integer i: + result = "Int: " + i; + break; + default: + result = "Other"; + break; + } + return result; + } + + int switchYield(Object obj) { + return switch (obj) { + case String s -> { + int len = s.length(); + yield len * 2; + } + case Integer i -> { + yield i + 1; + } + default -> 0; + }; + } + + String nestedSwitch(Object outer, Object inner) { + return switch (outer) { + case String s -> switch (inner) { + case Integer i -> "String+Int: " + s + i; + default -> switch (inner) { + case Float f -> "String+Float: " + s + f; + default -> "String+Other"; + }; + }; + case Integer i -> "int"; + default -> "Other"; + }; + } + + int inMethod(Object obj) { + return new Builder().build(1).build(switch (obj) { + case String s when s.length() > 5 -> s.length() - 5; + case String s -> s.length(); + case Integer i -> i; + default -> obj.hashCode(); + }).build(4).n; + } + + void switchInIf(Object obj) { + if (switch (obj) { + case null -> -1; + case String s when s.length() > 5 -> s.length() - 5; + case String s -> s.length(); + case Integer i -> i; + default -> obj.hashCode(); + } > 0) { + System.out.println("Working"); + } + } + + // Record patterns in instanceof + void recordPatternInstanceof(Object obj) { + if (obj instanceof Point(int x, int y)) { + System.out.println(x + y); + } else if (obj instanceof Shape(String name, Triangle triangle) && name.equals("line")) { + System.out.println("A line cannot be a triangle!"); + } + } + + void nestedRecordPattern(Object obj) { + if (obj instanceof Shape(var name, Triangle(var p, Point(int bx, int by), Point(int cx, int cy)))) { + System.out.println("Shape " + name + ": " + + "(" + p.x() + "," + p.y() + "), (" + bx + "," + by + "), (" + cx + "," + cy + ")"); + } + } + + int recordPatternSwitchStatement(Geometry obj) { + return switch (obj) { + case Point(int x, int y) when x == 0 && y == 0: { + yield 0; + } + case Point(int x, int y): + yield x + y; + case Triangle(Point p, Point(int bx, int by), Point(int cx, int cy)): { + int sum = p.x() * p.y(); + sum += bx * by; + sum += cx * cy; + yield sum; + } + case Shape s: + yield 0; + // No default case since class is sealed. + }; + } + + void unconditionalPattern(String str) { + // Unconditional pattern - String is always a CharSequence (for non-null) + if (str instanceof CharSequence cs) { + System.out.println(cs.length()); + } + } + + // Unconditional with inferred type + void unconditionalPatternObject(T str) { + if (str instanceof Shape o) { + System.out.println(o.hashCode()); + } + } +} diff --git a/example/src/main/java/com/example/Java22FeaturesExample.java b/example/src/main/java/com/example/Java22FeaturesExample.java new file mode 100644 index 0000000..ff11c7c --- /dev/null +++ b/example/src/main/java/com/example/Java22FeaturesExample.java @@ -0,0 +1,48 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.Arrays; +import java.util.List; + +/** + * Examples of Java 22 features desugared by the compiler:
+ * + * UNNAMED_VARIABLES + *
+ * // Source (Java 22+):
+ * for (var _ : list) { count++; }
+ * try { } catch (Exception _) { }
+ * list.forEach(_ -> count++);
+ *
+ * // Decompiled (Java 8):
+ * for (Object $unused : list) { count++; }
+ * try { } catch (Exception $unused) { }
+ * list.forEach($unused -> count++);
+ * 
+ */ +public class Java22FeaturesExample { + + void unnamedInForEach() { + var list = Arrays.asList("a", "b", "c"); + int count = 0; + + for (var _ : list) { + count++; + } + } + + void unnamedInCatch() { + try { + throw new RuntimeException(); + } catch (RuntimeException _) { + System.out.println("caught"); + } + } + + void unnamedInLambda() { + var list = Arrays.asList("a", "b", "c"); + var count = new int[]{0}; + list.forEach(_ -> count[0]++); + } +} diff --git a/example/src/main/java/com/example/Java25FeaturesExample.java b/example/src/main/java/com/example/Java25FeaturesExample.java new file mode 100644 index 0000000..3c64375 --- /dev/null +++ b/example/src/main/java/com/example/Java25FeaturesExample.java @@ -0,0 +1,100 @@ +// Examples made by Claude Opus 4.6 + +package com.example; + +/** + * Examples of Java 25 flexible constructors with manual desugaring:
+ * + * FLEXIBLE_CONSTRUCTORS + *
+ * // Source (Java 25+):
+ * class Child extends Parent {
+ *     Child(int value) {
+ *         if (value < 0) throw new IllegalArgumentException();
+ *         int processed = value * 2;
+ *         super(processed);
+ *     }
+ * }
+ *
+ * // Decompiled (Java 8):
+ * class Child extends Parent {
+ *     Child(int value) {
+ *         super($prologue$0(value));
+ *     }
+ *     private static int $prologue$0(int value) {
+ *         if (value < 0) throw new IllegalArgumentException();
+ *         return value * 2;
+ *     }
+ * }
+ * 
+ */ +public class Java25FeaturesExample { + String str; + + void main() { + System.out.println("Bridged entry point"); + str = ""; + } + + static class Parent { + final int value; + Parent(int value) { this.value = value; } + } + + // Complex loop in prologue + static class ComplexPrologue extends Parent { + ComplexPrologue(int n) { + if (n < 0) throw new IllegalArgumentException(); + int sum = 0; + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + sum += i; + } + } + super(sum > 100 ? 100 : sum); + } + } + + // Epilogue reuses multiple prologue variables + static class ManyLocals extends Parent { + final String label; + final double ratio; + ManyLocals(int a, int b) { + String lbl = a + "/" + b; + double r = b != 0 ? (double) a / b : 0.0; + int sum = a + b; + super(sum); + this.label = lbl; + this.ratio = r; + } + } + + // Existing Object[] / (Object[], Void) constructors must not clash + static class SpoofChild extends Parent { + boolean spoofed; + + SpoofChild(int value) { + super(value); + } + + // Spoof: (Object[]) + SpoofChild(Object[] args) { + boolean s = false; + if (args.length > 0) s = true; + this(0); + spoofed = s; + } + + // Spoof: (Object[], Void) + SpoofChild(Object[] args, Void v) { + this((int)(Integer) args[0]); + this.spoofed = true; + } + + // Another flex in the same spoofed class + SpoofChild(String s) { + int parsed = Integer.parseInt(s); + super(parsed); + } + } +} \ No newline at end of file diff --git a/example/src/main/java/com/example/Java25FeaturesExample2.java b/example/src/main/java/com/example/Java25FeaturesExample2.java new file mode 100644 index 0000000..1d8678f --- /dev/null +++ b/example/src/main/java/com/example/Java25FeaturesExample2.java @@ -0,0 +1,49 @@ +/** + * Examples of Java 25 features with manual desugaring:
+ * IMPLICIT_CLASSES + *
+ * // Source (Java 25+):
+ * // File: Main.java (no class declaration)
+ * void main() {
+ *     System.out.println("Hello");
+ * }
+ *
+ * // Decompiled (Java 8):
+ * public class Main {
+ *     public static void main(String[] args) {
+ *         new Main().main();
+ *     }
+ *     void main() {
+ *         System.out.println("Hello");
+ *     }
+ * }
+ * 
+ */ + +int variable = 1; + +static void main() { + System.out/*IO*/.println("Implicit classes work!"); + new ArrayList<>(); // Implicitly imported +} + +// Cannot be adapted as a standard entry point +void main(String... args) { + System.out/*IO*/.println("instance main with args."); + new Test(); +} + +private void privateMethod() { + main(); +} + + +interface TestInterface { + +} + +class Test implements TestInterface { + public Test() { + Integer.toString(variable); + } +} diff --git a/example/src/main/java/com/example/Java9FeaturesExample.java b/example/src/main/java/com/example/Java9FeaturesExample.java new file mode 100644 index 0000000..e563600 --- /dev/null +++ b/example/src/main/java/com/example/Java9FeaturesExample.java @@ -0,0 +1,77 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.io.Closeable; +import java.util.ArrayList; +import java.util.List; + +/** + * Examples of Java 9 features desugared by the compiler:
+ * + * PRIVATE_SAFE_VARARGS + *
+ * // Source (Java 9+):
+ * @SafeVarargs
+ * private void method(List<String>... lists) { }
+ *
+ * // Decompiled (Java 8): same code, annotation preserved
+ * 
+ *

+ * DIAMOND_WITH_ANONYMOUS_CLASS_CREATION + *

+ * // Source (Java 9+):
+ * List<String> list = new ArrayList<>() { };
+ *
+ * // Decompiled (Java 8):
+ * List<String> list = new ArrayList<String>() { };
+ * 
+ *

+ * EFFECTIVELY_FINAL_VARIABLES_IN_TRY_WITH_RESOURCES + *

+ * // Source (Java 9+):
+ * Closeable c = ...;
+ * try (c) { }
+ *
+ * // Decompiled (Java 8):
+ * Closeable c = ...;
+ * try (Closeable c2 = c) { }
+ * 
+ *

+ * PRIVATE_INTERFACE_METHODS + *

+ * // Source (Java 9+):
+ * interface I {
+ *     private String helper() { return "!"; }
+ *     default String greet() { return "Hi" + helper(); }
+ * }
+ *
+ * // Decompiled (Java 8): same code (private methods in interfaces supported in bytecode)
+ * 
+ */ +public class Java9FeaturesExample { + + @SafeVarargs + private final void safeVarargsMethod(List... lists) { + for (List list : lists) System.out.println(list); + } + + List diamondWithAnonymous = new ArrayList<>() { + @Override + public boolean add(String s) { + return super.add(s.toUpperCase()); + } + }; + + void effectivelyFinalTryWithResources() throws Exception { + Closeable resource = () -> System.out.println("closed"); + try (resource) { + System.out.println("using resource"); + } + } + + interface WithPrivateMethods { + private String helper() { return "!"; } + default String greet(String name) { return "Hello " + name + helper(); } + } +} diff --git a/example/src/main/java/com/example/Main.java b/example/src/main/java/com/example/Main.java new file mode 100644 index 0000000..ce1134b --- /dev/null +++ b/example/src/main/java/com/example/Main.java @@ -0,0 +1,86 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** Main class to demonstrate all Java 9-25 features desugared by Jabel. */ +class Main { + public static void main(String[] args) { + System.out.println("=== Jabel Feature Examples ==="); + + System.out.println("\n--- Java 9 Features ---"); + Java9FeaturesExample java9 = new Java9FeaturesExample(); + java9.diamondWithAnonymous.add("test"); + System.out.println("Diamond with anonymous: " + java9.diamondWithAnonymous); + Java9FeaturesExample.WithPrivateMethods impl = new Java9FeaturesExample.WithPrivateMethods() {}; + System.out.println("Private interface method: " + impl.greet("World")); + + System.out.println("\n--- Java 10 Features ---"); + Java10FeaturesExample java10 = new Java10FeaturesExample(); + java10.varExamples(); + System.out.println("var keyword works!"); + + System.out.println("\n--- Java 11 Features ---"); + Java11FeaturesExample java11 = new Java11FeaturesExample(); + System.out.println("var in lambda: " + java11.concat.apply("hello", "world")); + + System.out.println("\n--- Java 14 Features ---"); + var java14 = new Java14FeaturesExample(); + System.out.println("Switch expression: " + java14.switchMultipleLabels(Java14FeaturesExample.Day.SAT)); + System.out.println("Switch with yield: " + java14.switchExpressionWithYield(Java14FeaturesExample.Day.WED)); + java14.noDefaultInjectedCase(16); + + System.out.println("\n--- Java 15 Features ---"); + var java15 = new Java15FeaturesExample(); + System.out.println("Text block: " + java15.basic.replace("\n", "\\n")); + + System.out.println("\n--- Java 16 Features ---"); + var java16 = new Java16FeaturesExample(); + java16.patternMatchingInstanceof("Hello"); + var point = new Java16FeaturesExample.Point(3, 4); + System.out.println("Record: " + point + " (" + point.x() + ", " + point.y() + ")"); + + System.out.println("\n--- Java 17 Features ---"); + var java17 = new Java17FeaturesExample(); + System.out.println("sealed classes works!"); + System.out.println("strictfp method: " + java17.strictMethod(3.14, 2.71)); + + System.out.println("\n--- Java 21 Features ---"); + var java21 = new Java21FeaturesExample(); + System.out.println("case null: " + java21.caseNull(null)); + System.out.println("pattern switch: " + java21.patternSwitch("test")); + System.out.println("pattern switch statement: " + java21.patternSwitchVariable("test")); + System.out.println("record pattern: " + java21.recordPatternSwitchStatement(new Java21FeaturesExample.Point(5, 7))); + System.out.println("switch yield: " + java21.switchYield("hello")); + + System.out.println("\n--- Java 22 Features ---"); + var java22 = new Java22FeaturesExample(); + java22.unnamedInForEach(); + System.out.println("Unnamed variables work!"); + + System.out.println("\n--- Java 25 Features ---"); + // Inline: complex prologue + var dp = new Java25FeaturesExample.ComplexPrologue(10); + assert dp.value == 20 : "ComplexPrologue"; // 0+2+4+6+8 = 20 + System.out.println("ComplexPrologue: value=" + dp.value); + // General: multiple shared vars in epilogue + var ms = new Java25FeaturesExample.ManyLocals(10, 4); + assert ms.value == 14 && "10/4".equals(ms.label) && ms.ratio == 2.5 : "ManyLocals"; + System.out.println("ManyLocals: value=" + ms.value + " label=" + ms.label + " ratio=" + ms.ratio); + // Edge: tests constructor collisions + var sp = new Java25FeaturesExample.SpoofChild(99); + assert sp.value == 99 && !sp.spoofed : "SpoofChild"; + System.out.println("SpoofChild: value=" + sp.value + " spoofed=" + sp.spoofed); + //Implicit classes + // Need reflection since implicit classes cannot be referenced + try { + java.lang.reflect.Method main = Class.forName("Java25FeaturesExample2").getDeclaredMethod("main"); + main.setAccessible(true); + main.invoke(null); + } catch (Exception e) { + System.err.println("Implicit classes error: " + e.toString()); + return; + } + + System.out.println("\n=== All features work! ==="); + } +} diff --git a/example/src/main/java/com/example/RecordExample.java b/example/src/main/java/com/example/RecordExample.java index 3c36f36..f8be6dc 100644 --- a/example/src/main/java/com/example/RecordExample.java +++ b/example/src/main/java/com/example/RecordExample.java @@ -1,8 +1,6 @@ package com.example; -import com.github.bsideup.jabel.Desugar; -@Desugar public record RecordExample(int i, String s, long l, float f, double d, String[] arr, boolean b) { public static RecordExample DUMMY = new RecordExample(0, null, 0, 0, 0, null, true); diff --git a/example/src/test/java/com/example/JavaFeaturesTest.java b/example/src/test/java/com/example/JavaFeaturesTest.java new file mode 100644 index 0000000..0303a0b --- /dev/null +++ b/example/src/test/java/com/example/JavaFeaturesTest.java @@ -0,0 +1,13 @@ +package com.example; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public class JavaFeaturesTest { + + @Test + public void shouldWork() { + Main.main(new String[0]); + } +} \ No newline at end of file diff --git a/example/src/test/java/com/example/RecordExampleTest.java b/example/src/test/java/com/example/RecordExampleTest.java index 6e01cfb..5d2f290 100644 --- a/example/src/test/java/com/example/RecordExampleTest.java +++ b/example/src/test/java/com/example/RecordExampleTest.java @@ -21,7 +21,7 @@ public void testToString() { RecordExample r = new RecordExample(42, "yeah", 100500, 0.5f, 5d, new String[]{"Hello", "World!"}, true); assertEquals( - "RecordExample[i=42,s=yeah,l=100500,f=0.5,d=5.0,arr=[Ljava.lang.String;@hash,b=true]", + "RecordExample[i=42, s=yeah, l=100500, f=0.5, d=5.0, arr=[Ljava.lang.String;@hash, b=true]", Objects.toString(r).replaceAll(";@[a-f0-9]+", ";@hash") ); } @@ -31,7 +31,7 @@ public void testToStringWithNulls() { RecordExample r = new RecordExample(42, null, 100500, 0.5f, 5d, null, true); assertEquals( - "RecordExample[i=42,s=null,l=100500,f=0.5,d=5.0,arr=null,b=true]", + "RecordExample[i=42, s=null, l=100500, f=0.5, d=5.0, arr=null, b=true]", Objects.toString(r) ); } diff --git a/gradle/publishing.gradle b/gradle/publishing.gradle index ffc3b3b..cb1e056 100644 --- a/gradle/publishing.gradle +++ b/gradle/publishing.gradle @@ -3,7 +3,7 @@ plugins.withType(MavenPublishPlugin) { publications { mavenJava(MavenPublication) { publication -> pom { - description = 'Jabel - use modern Java 9-14 syntax when targeting Java 8.' + description = 'Jabel - use modern Java 9-25 syntax when targeting Java 8.' name = project.description ?: description url = 'https://github.com/bsideup/jabel' licenses { diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 8049c68..03b32a2 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.4-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/jabel-javac-plugin/build.gradle b/jabel-javac-plugin/build.gradle index 1725fc5..860c3e9 100644 --- a/jabel-javac-plugin/build.gradle +++ b/jabel-javac-plugin/build.gradle @@ -6,16 +6,16 @@ plugins { sourceCompatibility = targetCompatibility = 8 dependencies { - implementation platform('net.bytebuddy:byte-buddy-parent:1.14.9') + implementation platform('net.bytebuddy:byte-buddy-parent:1.18.7') implementation 'net.bytebuddy:byte-buddy' implementation 'net.bytebuddy:byte-buddy-agent' - implementation 'net.java.dev.jna:jna:5.13.0' + implementation 'net.java.dev.jna:jna:5.18.0' + implementation 'net.java.dev.jna:jna-platform:5.18.0' } - task sourcesJar(type: Jar) { - classifier 'sources' from sourceSets.main.allJava + archiveClassifier = 'sources' } javadoc { @@ -24,7 +24,7 @@ javadoc { task javadocJar(type: Jar) { from javadoc - classifier = 'javadoc' + archiveClassifier = 'javadoc' } publishing { diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java deleted file mode 100644 index 1a11531..0000000 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.github.bsideup.jabel; - -import java.lang.annotation.Documented; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -import static java.lang.annotation.ElementType.TYPE; - -@Documented -@Retention(RetentionPolicy.SOURCE) -@Target(value=TYPE) -public @interface Desugar { -} diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java new file mode 100644 index 0000000..826f2e1 --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java @@ -0,0 +1,218 @@ +package com.github.bsideup.jabel; + +import java.util.*; + +import javax.tools.*; + +import com.sun.source.util.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; +import com.sun.tools.javac.util.List; + + +/** + * Will adapts flexible entry points + * ({@code void main();, void main(String[]);, static void main();, + * static void main(String[])}) by attempting to create a bridge. + *

+ * This is pretty limited as non-static main entry points with arguments cannot + * be adapted, due to a signature duplication with the static one. + *

+ * Moreover, classes containing multiple entry points with ones that are + * private, will not work properly on older JVMs.
+ * For example this code: + * + *

{@code
+ * class Main {
+ *     static void main() {
+ *     }
+ *
+ *     private static void main(String[] args) {
+ *     }
+ * }
+ * }
+ * + * The first main method will be selected by Java 25's JVM, because the standard + * one is private, whereas on an older JVM, an error will be displayed.
+ * And Jabel cannot create a bridge for the first main, because there would be a + * signature duplication with the second declared main. + */ +public class FlexibleMainRetrofittingTaskListener implements TaskListener { + final TreeMaker make; + final Names names; + final Symtab syms; + final Log log; + final JCDiagnostic.Factory diagFactory; + final Name mainName; + + FlexibleMainRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + names = Names.instance(context); + syms = Symtab.instance(context); + log = Log.instance(context); + diagFactory = JCDiagnostic.Factory.instance(context); + mainName = names.fromString("main"); // syms.main; + + // Proper way to make warnings + JavacMessages.instance(context).add(locale -> new ResourceBundle() { + final Map keys = new HashMap<>(2); + { + // Act like it's linting + keys.put( + "jabel.warn.possible.signature.duplication", + "[jabel] possible entry point cannot be adapted " + + "due to a signature duplication. " + + "''{0}'' cannot therefore be used as an entry point in a JVM bellow Java25." + ); + keys.put( + "jabel.warn.no.default.constructor.found", + "[jabel] possible entry point cannot be adapted " + + "because no instanciable default constructor was found. " + + "''{0}'' cannot therefore be used as an entry point in a JVM bellow Java25." + ); + } + + @Override + protected Object handleGetObject(String key) { + return keys.get(key); + } + + @Override + public Enumeration getKeys() { + return Collections.enumeration(keys.keySet()); + } + }); + } + + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + JCCompilationUnit jcu = (JCCompilationUnit) e.getCompilationUnit(); + JavaFileObject old = log.useSource(jcu.sourcefile); + + for (JCTree def : jcu.defs) { + if (!(def instanceof JCClassDecl)) continue; + transformClass((JCClassDecl) def); + } + + log.useSource(old); + } + + @Override + public void finished(TaskEvent e) { + } + + public void transformClass(JCClassDecl classDecl) { + make.at(classDecl.pos); + + // Slots: [0]=static+args, [1]=static, [2]=instance+args, [3]=instance + JCMethodDecl[] mains = new JCMethodDecl[4]; + boolean canInstanciate = false, hasConstructor = false; + int ep = mains.length; + + for (JCTree def : classDecl.defs) { + if (!(def instanceof JCMethodDecl)) continue; + JCMethodDecl method = (JCMethodDecl) def; + + if (!isMain(method)) { + if (method.name == names.init) { + hasConstructor = true; + if (!isPrivate(method) && method.params.isEmpty()) { + canInstanciate = true; + } + } + continue; + } + + int slot = (isStatic(method) ? 0 : 2) | (method.params.isEmpty() ? 1 : 0); + if (mains[slot] != null) continue; + mains[slot] = method; + if (slot < ep && !isPrivate(method)) ep = slot; + } + + // To avoid a signature duplication + if (mains[0] == null && mains[2] != null) ep = 2; + + switch (ep) { + case 1: + addMainBridge(classDecl, false); + break; + case 2: + warn(mains[ep], "possible.signature.duplication", classDecl.name); + break; + case 3: + if (canInstanciate || !hasConstructor) addMainBridge(classDecl, true); + else warn(mains[ep], "no.default.constructor.found", classDecl.name); + } + } + + public boolean isPrivate(JCMethodDecl m) { + return (m.mods.flags & Flags.PRIVATE) != 0; + } + + public boolean isStatic(JCMethodDecl m) { + return (m.mods.flags & Flags.STATIC) != 0; + } + + private void warn(JCMethodDecl method, String key, Object arg) { + log.report(diagFactory.create(log.currentSource(), method, new JCDiagnostic.Warning("jabel", key, arg))); + } + + public boolean isMain(JCMethodDecl method) { + if (mainName != method.name) return false; + if (!(method.restype instanceof JCPrimitiveTypeTree)) return false; + if (((JCPrimitiveTypeTree) method.restype).typetag != TypeTag.VOID) return false; + if (method.params.isEmpty()) return true; + if (method.params.size() != 1) return false; + JCTree vartype = method.params.get(0).vartype; + if(!(vartype instanceof JCArrayTypeTree)) return false; + // TODO find a better way? + switch(((JCArrayTypeTree)vartype).getType().toString()){ + case "java.lang.String": + case "String": + return true; + default: + return false; + } + } + + /** If {@code toLocalMain} is {@code true}, a zero-arg constructor must be present. */ + public JCMethodDecl makeMainBridge(JCClassDecl classDecl, boolean toLocalMain) { + return make.MethodDef( + make.Modifiers(Flags.PUBLIC | Flags.STATIC), + mainName, + make.TypeIdent(TypeTag.VOID), + List.nil(), + List.of(make.VarDef( + make.Modifiers(Flags.PARAMETER), + names.fromString("args"), + make.TypeArray(make.Type(syms.stringType)), + null + )), + List.nil(), + make.Block(0, List.of(make.Exec(make.Apply( + List.nil(), + make.Select( + toLocalMain ? make.NewClass( + null, + List.nil(), + make.Ident(classDecl.name), + List.nil(), + null + ) : make.Ident(classDecl.name), + mainName + ), + List.nil() + )))), + null + ); + } + + /** Creates a main bridge in the specified class. */ + public void addMainBridge(JCClassDecl classDecl, boolean toLocalMain) { + classDecl.defs = classDecl.defs.append(makeMainBridge(classDecl, toLocalMain)); + } +} \ No newline at end of file diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java new file mode 100644 index 0000000..41eee1b --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java @@ -0,0 +1,80 @@ +package com.github.bsideup.jabel; + +import com.sun.source.util.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; + + +/** + * Because implicit classes import the {@code java.base} module, and modules are + * not yet supported, we need to convert it by adding star import of all + * exported and existing packages. + */ +public class ImplicitClassesFixerTaskListener implements TaskListener { + final TreeMaker make; + final Names names; + final Symtab syms; + List javaBaseImports; + + ImplicitClassesFixerTaskListener(Context context) { + make = TreeMaker.instance(context); + names = Names.instance(context); + syms = Symtab.instance(context); + } + + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + JCCompilationUnit jcu = (JCCompilationUnit) e.getCompilationUnit(); + + for (JCTree def : jcu.defs) { + if (!(def instanceof JCClassDecl)) continue; + JCClassDecl clazz = (JCClassDecl) def; + if ((clazz.mods.flags & Flags.IMPLICIT_CLASS) == 0) continue; + clazz.mods.flags &= ~Flags.IMPLICIT_CLASS; // Avoid implicit importation of java.base + injectJavaBaseImports(jcu); + } + } + + @Override + public void finished(TaskEvent e) { + } + + @SuppressWarnings("unchecked") + public void injectJavaBaseImports(JCCompilationUnit jcu) { + if (javaBaseImports == null) javaBaseImports = makeStarImports(getJavaBasePackages()); + // Duplicate imports are not important + jcu.defs = ((List) (List) javaBaseImports).appendList(jcu.defs); + } + + /** @return the list of packages that are exported by the {@code java.base} modules. */ + public List getJavaBasePackages() { + // Since modules are not enabled and initialized, {@code syms.java_base.exports} is not populated + return ModuleLayer.boot().findModule("java.base").map(mod -> + mod.getPackages().stream() + .filter(mod::isExported) + .map(p -> syms.enterPackage(syms.java_base, names.fromString(p))) + .filter(p -> { + try { + p.complete(); + } catch(Exception ignored) {} + return p.exists(); + }) + .collect(List.collector()) + ).orElse(List.nil()); + } + + public List makeStarImports(List packages) { + List imports = List.nil(); + for (Symbol.PackageSymbol pkg : packages) { + imports = imports.append(make.Import(make.Select( + make.QualIdent(pkg), + names.asterisk + ), false)); + } + return imports; + } +} \ No newline at end of file diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java index 0a5aeb1..c3f86ae 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java @@ -1,64 +1,38 @@ package com.github.bsideup.jabel; -import com.sun.source.util.JavacTask; -import com.sun.source.util.Plugin; -import com.sun.tools.javac.api.BasicJavacTask; -import com.sun.tools.javac.code.Source; -import com.sun.tools.javac.util.Context; -import com.sun.tools.javac.util.JavacMessages; -import net.bytebuddy.ByteBuddy; -import net.bytebuddy.agent.ByteBuddyAgent; -import net.bytebuddy.asm.Advice; -import net.bytebuddy.asm.AsmVisitorWrapper; -import net.bytebuddy.description.field.FieldDescription; -import net.bytebuddy.description.field.FieldList; -import net.bytebuddy.description.method.MethodList; -import net.bytebuddy.description.type.TypeDescription; -import net.bytebuddy.dynamic.ClassFileLocator; -import net.bytebuddy.dynamic.loading.ClassInjector; -import net.bytebuddy.dynamic.loading.ClassReloadingStrategy; -import net.bytebuddy.dynamic.scaffold.MethodGraph; -import net.bytebuddy.implementation.Implementation; -import net.bytebuddy.jar.asm.ClassVisitor; -import net.bytebuddy.jar.asm.MethodVisitor; -import net.bytebuddy.jar.asm.Opcodes; -import net.bytebuddy.pool.TypePool; -import net.bytebuddy.utility.JavaModule; - import java.util.*; +import com.sun.source.util.*; +import com.sun.tools.javac.api.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.util.*; + +import net.bytebuddy.*; +import net.bytebuddy.agent.*; +import net.bytebuddy.asm.*; +import net.bytebuddy.description.type.*; +import net.bytebuddy.dynamic.*; +import net.bytebuddy.dynamic.loading.*; +import net.bytebuddy.dynamic.scaffold.*; +import net.bytebuddy.pool.*; +import net.bytebuddy.utility.*; + import static net.bytebuddy.matcher.ElementMatchers.*; -public class JabelCompilerPlugin implements Plugin { - static { - Map visitors = new HashMap() {{ - // Disable the preview feature check - AsmVisitorWrapper checkSourceLevelAdvice = Advice.to(CheckSourceLevelAdvice.class) - .on(named("checkSourceLevel").and(takesArguments(2))); - - // Allow features that were introduced together with Records (local enums, static inner members, ...) - AsmVisitorWrapper allowRecordsEraFeaturesAdvice = new FieldAccessStub("allowRecords", true); - - put("com.sun.tools.javac.parser.JavacParser", - new AsmVisitorWrapper.Compound( - checkSourceLevelAdvice, - allowRecordsEraFeaturesAdvice - ) - ); - put("com.sun.tools.javac.parser.JavaTokenizer", checkSourceLevelAdvice); - put("com.sun.tools.javac.comp.Check", allowRecordsEraFeaturesAdvice); - put("com.sun.tools.javac.comp.Attr", allowRecordsEraFeaturesAdvice); - put("com.sun.tools.javac.comp.Resolve", allowRecordsEraFeaturesAdvice); +public class JabelCompilerPlugin implements Plugin { + static final boolean JABEL_INITIALIZED = initJabel(); - // Lower the source requirement for supported features - put( - "com.sun.tools.javac.code.Source$Feature", - Advice.to(AllowedInSourceAdvice.class) - .on(named("allowedInSource").and(takesArguments(1))) - ); - }}; + @SuppressWarnings("resource") + private static boolean initJabel() { + // We cannot easily force features bellow Java 10.35 + try { + Class.forName("com.sun.tools.javac.code.Source$Feature"); + } catch (Exception e) { + return false; + } + // Install ByteBuddy try { ByteBuddyAgent.install(); } catch (Exception e) { @@ -73,60 +47,66 @@ public class JabelCompilerPlugin implements Plugin { ) ); } + ByteBuddy byteBuddy = new ByteBuddy().with(MethodGraph.Compiler.ForDeclaredMethods.INSTANCE); - ByteBuddy byteBuddy = new ByteBuddy() - .with(MethodGraph.Compiler.ForDeclaredMethods.INSTANCE); + // Hook classes ClassLoader classLoader = JavacTask.class.getClassLoader(); ClassFileLocator classFileLocator = ClassFileLocator.ForClassLoader.of(classLoader); TypePool typePool = TypePool.ClassLoading.of(classLoader); - - visitors.forEach((className, visitor) -> { - byteBuddy - .decorate( - typePool.describe(className).resolve(), - classFileLocator - ) - .visit(visitor) - .make() - .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); - }); - - JavaModule jabelModule = JavaModule.ofType(JabelCompilerPlugin.class); + TypeDescription clazz; + + // Lower features source level + clazz = typePool.describe("com.sun.tools.javac.code.Source$Feature").resolve(); + byteBuddy.decorate(clazz, classFileLocator) + .visit(Advice.to(AllowedInSourceAdvice.class).on(named("allowedInSource").and(takesArguments(1)))) + .make() + .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); + + // Force enable preview features and suppress its warnings + clazz = typePool.describe("com.sun.tools.javac.code.Preview").resolve(); + byteBuddy.decorate(clazz, classFileLocator) + .visit(Advice.to(IsEnabledAdvice.class).on(named("isEnabled").and(takesArguments(0)))) + .visit(Advice.to(IsPreviewAdvice.class).on(named("isPreview").and(takesArguments(1)))) + .visit(Advice.to(WarnPreviewAdvice.class).on(named("warnPreview"))) + .make() + .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); + + + // Open internal compiler packages + Set jabelModule = Collections.singleton(JavaModule.ofType(JabelCompilerPlugin.class)); ClassInjector.UsingInstrumentation.redefineModule( ByteBuddyAgent.getInstrumentation(), JavaModule.ofType(JavacTask.class), Collections.emptySet(), Collections.emptyMap(), - new HashMap>() {{ - put("com.sun.tools.javac.api", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.tree", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.code", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.util", Collections.singleton(jabelModule)); + new HashMap>() {{ + put("com.sun.tools.javac.api", jabelModule); + put("com.sun.tools.javac.tree", jabelModule); + put("com.sun.tools.javac.code", jabelModule); + put("com.sun.tools.javac.comp", jabelModule); + put("com.sun.tools.javac.util", jabelModule); }}, Collections.emptySet(), Collections.emptyMap() ); + + return true; } @Override public void init(JavacTask task, String... args) { - Context context = ((BasicJavacTask) task).getContext(); - JavacMessages.instance(context).add(locale -> new ResourceBundle() { - @Override - protected Object handleGetObject(String key) { - return "{0}"; - } + // Useless to continue if Jabel was not initialized correctly + if (!JABEL_INITIALIZED) return; - @Override - public Enumeration getKeys() { - return Collections.enumeration(Arrays.asList("missing.desugar.on.record")); - } - }); + Context context = ((BasicJavacTask) task).getContext(); + removeUnderscoreWarnings(context); task.addTaskListener(new RecordsRetrofittingTaskListener(context)); - - System.out.println("Jabel: initialized"); + task.addTaskListener(new RecordPatternRetrofittingTaskListener(context)); + task.addTaskListener(new SwitchRetrofittingTaskListener(context)); + task.addTaskListener(new FlexibleMainRetrofittingTaskListener(context)); + task.addTaskListener(new ImplicitClassesFixerTaskListener(context)); } @Override @@ -134,82 +114,68 @@ public String getName() { return "jabel"; } - // Make it auto start on Java 14+ + /** Make it auto starts on Java 14+. */ + @Override public boolean autoStart() { return true; } - static class AllowedInSourceAdvice { + /** Removes warnings about {@code '_'}. */ + static void removeUnderscoreWarnings(Context context){ + Log.instance(context).new DiscardDiagnosticHandler(){ + @Override + public void report(JCDiagnostic diag){ + String code = diag.getCode(); + if (code.contains("underscore.as.identifier") || + code.contains("use.of.underscore.not.allowed")) return; + prev.report(diag); + } + }; + } + /** Makes all {@link Source.Feature} available in all source levels, except few ones. */ + static class AllowedInSourceAdvice { @Advice.OnMethodEnter static void allowedInSource( @Advice.This Source.Feature feature, @Advice.Argument(value = 0, readOnly = false) Source source ) { switch (feature.name()) { - case "PRIVATE_SAFE_VARARGS": - case "SWITCH_EXPRESSION": - case "SWITCH_RULE": - case "SWITCH_MULTIPLE_CASE_LABELS": - case "LOCAL_VARIABLE_TYPE_INFERENCE": - case "VAR_SYNTAX_IMPLICIT_LAMBDAS": - case "DIAMOND_WITH_ANONYMOUS_CLASS_CREATION": - case "EFFECTIVELY_FINAL_VARIABLES_IN_TRY_WITH_RESOURCES": - case "TEXT_BLOCKS": - case "PATTERN_MATCHING_IN_INSTANCEOF": - case "REIFIABLE_TYPES_INSTANCEOF": - case "RECORDS": + case "MODULES": // Extremely difficult as initialization is done very early + case "STRING_TEMPLATES": // Appeared on Java 21 and removed on Java 23 because of a confusing design + case "MODULE_IMPORTS": // Needs the modules system + case "JAVA_BASE_TRANSITIVE": // Needs the modules system + break; + default: //noinspection UnusedAssignment source = Source.DEFAULT; - break; } } } - static class CheckSourceLevelAdvice { - - @Advice.OnMethodEnter - static void checkSourceLevel( - @Advice.Argument(value = 1, readOnly = false) Source.Feature feature - ) { - if (feature.allowedInSource(Source.JDK8)) { - // This must be one of the cases from "AllowedInSourceAdvice" - //noinspection UnusedAssignment - feature = Source.Feature.PRIVATE_SAFE_VARARGS; - } + /** Makes {@link Preview#isEnabled()} always return {@code true}. */ + static class IsEnabledAdvice { + @Advice.OnMethodExit + static void isEnabled(@Advice.Return(readOnly = false) boolean result) { + //noinspection UnusedAssignment + result = true; } } - private static class FieldAccessStub extends AsmVisitorWrapper.AbstractBase { - - final String fieldName; - - final Object value; - - public FieldAccessStub(String fieldName, Object value) { - this.fieldName = fieldName; - this.value = value; + /** Makes {@link Preview#isPreview(Feature)} always return {@code false}. */ + static class IsPreviewAdvice { + @Advice.OnMethodExit + static void isPreview(@Advice.Return(readOnly = false) boolean result) { + //noinspection UnusedAssignment + result = false; } + } - @Override - public ClassVisitor wrap(TypeDescription instrumentedType, ClassVisitor classVisitor, Implementation.Context implementationContext, TypePool typePool, FieldList fields, MethodList methods, int writerFlags, int readerFlags) { - return new ClassVisitor(Opcodes.ASM9, classVisitor) { - @Override - public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { - MethodVisitor methodVisitor = super.visitMethod(access, name, descriptor, signature, exceptions); - return new MethodVisitor(Opcodes.ASM9, methodVisitor) { - @Override - public void visitFieldInsn(int opcode, String owner, String name, String descriptor) { - if (opcode == Opcodes.GETFIELD && fieldName.equalsIgnoreCase(name)) { - super.visitInsn(Opcodes.POP); - super.visitLdcInsn(value); - } else { - super.visitFieldInsn(opcode, owner, name, descriptor); - } - } - }; - } - }; + /** Makes {@link Preview#warnPreview(DiagnosticPosition, Feature)} a no-op. */ + static class WarnPreviewAdvice { + @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) + static boolean warnPreview() { + return true; // skip method body } } } diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java new file mode 100644 index 0000000..d3cd568 --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java @@ -0,0 +1,155 @@ +package com.github.bsideup.jabel; + +import com.sun.source.util.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.code.Symbol.*; +import com.sun.tools.javac.code.Type.*; +import com.sun.tools.javac.comp.Operators; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; + + +/** + * Replaces {@link MatchException} catch clauses, generated by the compiler for + * record patterns unpacking, by {@link RuntimeException}. + */ +public class RecordPatternRetrofittingTaskListener implements TaskListener { + private boolean MATCH_EXCEPTION_PRESENT = SwitchRetrofittingTaskListener.MATCH_EXCEPTION_PRESENT; + + final TreeMaker make; + final Symtab syms; + final Names names; + final MethodSymbol runtimeExceptionConstructor; + final OperatorSymbol opCONCAT; + + public RecordPatternRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + syms = Symtab.instance(context); + names = Names.instance(context); + + runtimeExceptionConstructor = getRuntimeExceptionConstructor(); + opCONCAT = SwitchRetrofittingTaskListener.resolveBinary( + Operators.instance(context), + make.Literal(0), + Tag.PLUS, + syms.stringType, + syms.stringType + ); + } + + @Override + public void started(TaskEvent e) { + if (!MATCH_EXCEPTION_PRESENT) return; + if (e.getKind() != TaskEvent.Kind.GENERATE) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + new MatchExceptionFixer().translate((JCCompilationUnit) e.getCompilationUnit()); + } + + @Override + public void finished(TaskEvent e) { + if (!MATCH_EXCEPTION_PRESENT) return; + if (e.getKind() != TaskEvent.Kind.ANALYZE) return; + // Creates the virtual symbol after analyze, so the user-code cannot reference it + injectMatchException(); + } + + public class MatchExceptionFixer extends TreeTranslator { + @Override + public void visitBlock(JCBlock tree) { + JCCatch c = SwitchRetrofittingTaskListener.getPatternMatchingCatchHandler(tree); + if (c != null) { + if (isMatchException(c.param.vartype)) { + c.param.vartype = make.QualIdent(syms.runtimeExceptionType.tsym) + .setType(syms.runtimeExceptionType); + if (c.param.sym != null) c.param.sym.type = syms.runtimeExceptionType; + } + JCTree last = result; + c.body = translate(c.body); + result = last; + } + super.visitBlock(tree); + } + + @Override + public void visitNewClass(JCNewClass tree) { + super.visitNewClass(tree); + if (!isMatchException(tree.clazz)) return; + + if (runtimeExceptionConstructor != null) { + tree.clazz = make.QualIdent(syms.runtimeExceptionType.tsym).setType(syms.runtimeExceptionType); + tree.type = syms.runtimeExceptionType; + tree.constructor = runtimeExceptionConstructor; + } + if (opCONCAT != null) { + tree.args.head = make.Binary( + Tag.PLUS, + make.Literal("MatchException: Record-pattern unpacking failed: "), + tree.args.head + ).setType(syms.stringType); + ((JCBinary) tree.args.head).operator = opCONCAT; + } else { + tree.args.head = make.Literal("MatchException: record-pattern unpacking failed"); + } + } + } + + public boolean isMatchException(JCTree t) { + if (t == null) return false; + // TODO find a better way? + switch (t.toString()) { + case "java.lang.MatchException": + case "MatchException": + return true; + default: + return false; + } + } + + /** + * Registers a virtual {@link MatchException} symbol to make the compiler happy. + */ + public void injectMatchException() { + ClassSymbol cs = (ClassSymbol) syms.matchExceptionType.tsym; + if (cs == null) return; + if (cs.classfile != null) { + // Do nothing if already exists in the classpath to avoid changing the user code + MATCH_EXCEPTION_PRESENT = false; + return; + } + if (cs.members_field != null && !cs.members_field.isEmpty()) return; + + // Fill the current symbol + cs.flags_field = Flags.PUBLIC; + ClassType ct = new ClassType(Type.noType, List.nil(), cs); + ct.supertype_field = syms.runtimeExceptionType; + cs.type = ct; + cs.members_field = Scope.WriteableScope.create(cs); + + // Only create the constructor used by TransPatterns + cs.members_field.enter(new MethodSymbol( + Flags.PUBLIC, + names.init, + new MethodType( + List.of(syms.stringType, syms.throwableType), + syms.voidType, + List.nil(), + syms.methodClass + ), + cs + )); + cs.completer = Completer.NULL_COMPLETER; + } + + public MethodSymbol getRuntimeExceptionConstructor() { + for (Symbol sym : syms.runtimeExceptionType.tsym.members().getSymbols()) { + if (sym.kind != Kinds.Kind.MTH) continue; + if (sym.name != names.init) continue; + MethodSymbol m = (MethodSymbol) sym; + if (m.params().size() != 2) continue; + // TODO: check for actual types? + return m; + } + return null; + } +} diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java index cbadd4a..0fcd6b2 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java @@ -1,191 +1,163 @@ package com.github.bsideup.jabel; -import com.sun.source.tree.ClassTree; -import com.sun.source.tree.CompilationUnitTree; -import com.sun.source.util.TaskEvent; -import com.sun.source.util.TaskListener; +import java.util.Iterator; +import java.util.stream.*; + +import javax.lang.model.element.Modifier; + +import com.sun.source.tree.*; +import com.sun.source.util.*; import com.sun.source.util.TreeScanner; import com.sun.tools.javac.code.*; -import com.sun.tools.javac.tree.JCTree; -import com.sun.tools.javac.tree.TreeMaker; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; import com.sun.tools.javac.util.*; -import javax.lang.model.element.Modifier; -import javax.tools.JavaFileObject; -import java.util.Iterator; -import java.util.stream.Stream; - -class RecordsRetrofittingTaskListener implements TaskListener { +/** + * Will generate {@code hashCode()}, {@code equals()} and {@code toString()} + * methods, and remove {@link Flags#RECORD}. + */ +public class RecordsRetrofittingTaskListener implements TaskListener { final TreeMaker make; - final Symtab syms; - + final Types types; final Names names; - final Log log; - - TreeScanner recordsScanner = new TreeScanner() { - @Override - public Void visitClass(ClassTree node, Void aVoid) { - if (!"RECORD".equals(node.getKind().toString())) { - return super.visitClass(node, aVoid); - } + public RecordsRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + syms = Symtab.instance(context); + types = Types.instance(context); + names = Names.instance(context); + } - JCTree.JCClassDecl classDecl = (JCTree.JCClassDecl) node; + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + new RecordsScanner().scan(e.getCompilationUnit(), false); + } - if (classDecl.extending == null) { - // Prevent implicit "extends java.lang.Record" - classDecl.extending = make.Type(syms.objectType); - } + /** Remove {@link Flags#RECORD} to avoid invalid ASM reading. */ + @Override + public void finished(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ANALYZE) return; + new RecordsScanner().scan(e.getCompilationUnit(), true); + } - { - Name methodName = names.toString; - List argTypes = List.nil(); - if (!containsMethod(classDecl, methodName)) { - JCTree.JCMethodDecl methodDecl = make.MethodDef( - new Symbol.MethodSymbol( - Flags.PUBLIC, - methodName, - new Type.MethodType( - argTypes, - syms.stringType, - List.nil(), - syms.methodClass - ), - syms.objectType.tsym - ), - make.Block(0, generateToString(classDecl)) - ); - classDecl.defs = classDecl.defs.append(methodDecl); - } - } + public class RecordsScanner extends TreeScanner { + @Override + public Void visitClass(ClassTree node, Boolean endPhase) { + if ("RECORD".equals(node.getKind().toString())) { + JCClassDecl classDecl = (JCClassDecl) node; - { - Name methodName = names.hashCode; - List argTypes = List.nil(); - if (!containsMethod(classDecl, methodName)) { - classDecl.defs = classDecl.defs.append(make.MethodDef( - new Symbol.MethodSymbol( - Flags.PUBLIC, - methodName, - new Type.MethodType( - argTypes, - syms.intType, - List.nil(), - syms.methodClass - ), - syms.objectType.tsym - ), - make.Block(0, generateHashCode(classDecl)) - )); - } - } + if (endPhase != null && endPhase) { + if (classDecl.sym != null) { + classDecl.sym.flags_field &= ~Flags.RECORD; + } - { - Name methodName = names.equals; - List argTypes = List.of(syms.objectType); - if (!containsMethod(classDecl, methodName)) { - Symbol.MethodSymbol methodSymbol = new Symbol.MethodSymbol( - Flags.PUBLIC | Flags.FINAL, - methodName, - new Type.MethodType( - argTypes, - syms.booleanType, - List.nil(), - syms.methodClass - ), - syms.objectType.tsym - ); - Symbol.VarSymbol firstParameter = methodSymbol.params().head; - - JCTree.JCMethodDecl methodDecl = make.MethodDef( - methodSymbol, - make.Block(0, generateEquals(classDecl, firstParameter.name)) - ); - // THIS ONE IS IMPORTANT! Otherwise, Flow.AssignAnalyzer#visitVarDef will have track=false - methodDecl.params.head.pos = classDecl.pos; - classDecl.defs = classDecl.defs.append(methodDecl); + } else { + if (classDecl.extending == null) { + // Prevent implicit "extends java.lang.Record" + classDecl.extending = make.Type(syms.objectType); + } + generateToStringIfNeeded(classDecl); + generateHashcodeIfNeeded(classDecl); + generateEqualsIfNeeded(classDecl); } } - return super.visitClass(node, aVoid); + return super.visitClass(node, endPhase); } + } - private boolean containsMethod(JCTree.JCClassDecl classDecl, Name name) { - return classDecl.defs.stream() - .filter(JCTree.JCMethodDecl.class::isInstance) - .map(JCTree.JCMethodDecl.class::cast) - .anyMatch(def -> { - if (def.getName() != name) { - return false; - } - - if (name == names.equals) { - if (def.params.size() != 1) { - return false; - } - - // TODO find a better way? - JCTree.JCVariableDecl param = def.params.get(0); - switch (param.getType().toString()) { - case "java.lang.Object": - case "Object": - return true; - default: - return false; - } - } - - return true; - }); - } - }; + public void generateToStringIfNeeded(JCClassDecl classDecl) { + if (containsMethod(classDecl, names.toString)) return; + classDecl.defs = classDecl.defs.append(make.MethodDef( + new Symbol.MethodSymbol( + Flags.PUBLIC | Flags.FINAL, + names.toString, + new Type.MethodType( + List.nil(), + syms.stringType, + List.nil(), + syms.methodClass + ), + syms.objectType.tsym + ), + make.Block(0, generateToString(classDecl)) + )); + } - public RecordsRetrofittingTaskListener(Context context) { - make = TreeMaker.instance(context); - syms = Symtab.instance(context); - names = Names.instance(context); - log = Log.instance(context); + public void generateHashcodeIfNeeded(JCClassDecl classDecl) { + if (containsMethod(classDecl, names.hashCode)) return; + classDecl.defs = classDecl.defs.append(make.MethodDef( + new Symbol.MethodSymbol( + Flags.PUBLIC | Flags.FINAL, + names.hashCode, + new Type.MethodType( + List.nil(), + syms.intType, + List.nil(), + syms.methodClass + ), + syms.objectType.tsym + ), + make.Block(0, generateHashCode(classDecl)) + )); } - @Override - public void started(TaskEvent e) { - switch (e.getKind()) { - case ENTER: - recordsScanner.scan(e.getCompilationUnit(), null); - new TreeScanner() { - @Override - public Void visitClass(ClassTree node, Void aVoid) { - if ("RECORD".equals(node.getKind().toString())) { - JCTree.JCClassDecl classDecl = (JCTree.JCClassDecl) node; - - if (classDecl.extending == null) { - // Prevent implicit "extends java.lang.Record" - classDecl.extending = make.Type(syms.objectType); - } - } - return super.visitClass(node, aVoid); - } - }.scan(e.getCompilationUnit(), null); - break; - case ANALYZE: - new MandatoryDesugarAnnotationTreeScanner(log, e.getCompilationUnit()).scan(e.getCompilationUnit(), null); - } + public void generateEqualsIfNeeded(JCClassDecl classDecl) { + if (containsMethod(classDecl, names.equals)) return; + Symbol.MethodSymbol methodSymbol = new Symbol.MethodSymbol( + Flags.PUBLIC | Flags.FINAL, + names.equals, + new Type.MethodType( + List.of(syms.objectType), + syms.booleanType, + List.nil(), + syms.methodClass + ), + syms.objectType.tsym + ); + JCTree.JCMethodDecl methodDecl = make.MethodDef( + methodSymbol, + make.Block(0, generateEquals( + classDecl, + methodSymbol.params().head.name + )) + ); + + // THIS ONE IS IMPORTANT! Otherwise, Flow.AssignAnalyzer#visitVarDef will have track=false + methodDecl.params.head.pos = classDecl.pos; + classDecl.defs = classDecl.defs.append(methodDecl); } - @Override - public void finished(TaskEvent e) { + /** Can only search for a method with no or one argument. */ + private boolean containsMethod(JCClassDecl classDecl, Name name) { + for (JCTree next : classDecl.defs) { + if (!(next instanceof JCMethodDecl)) continue; + JCMethodDecl def = (JCMethodDecl) next; + if (def.getName() != name) continue; + if (name != names.equals) return true; + if (def.params.size() != 1) continue; + // TODO find a better way? + switch(def.params.get(0).getType().toString()){ + case "java.lang.Object": + case "Object": + return true; + } + } + return false; } - private Stream getRecordComponents(JCTree.JCClassDecl classDecl) { + public Stream getRecordComponents(JCClassDecl classDecl) { return classDecl.getMembers().stream() - .filter(JCTree.JCVariableDecl.class::isInstance) - .map(JCTree.JCVariableDecl.class::cast) - .filter(it -> !it.getModifiers().getFlags().contains(Modifier.STATIC)); + .filter(JCVariableDecl.class::isInstance) + .map(JCVariableDecl.class::cast) + .filter(it -> !it.getModifiers().getFlags().contains(Modifier.STATIC)); } - private List generateToString(JCTree.JCClassDecl classDecl) { - JCTree.JCExpression stringBuilder = make.NewClass( + public List generateToString(JCClassDecl classDecl){ + JCExpression stringBuilder = make.NewClass( null, null, make.QualIdent(syms.stringBuilderType.tsym), @@ -193,303 +165,220 @@ private List generateToString(JCTree.JCClassDecl classDecl) null ); - for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); - iterator.hasNext(); - ) { - JCTree.JCVariableDecl fieldDecl = iterator.next(); + for(Iterator iterator = getRecordComponents(classDecl).iterator(); iterator.hasNext();){ + JCVariableDecl fieldDecl = iterator.next(); Name fieldName = fieldDecl.name; - stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), - List.of(make.Literal(fieldName + "=")) - ); - - stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), - List.of( - make.Select( - make.This(Type.noType), - fieldName - ) - ) - ); - - if (iterator.hasNext()) { - stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), - List.of(make.Literal(",")) - ); + stringBuilder = stringAppend(stringBuilder, make.Literal(fieldName + "=")); + stringBuilder = stringAppend(stringBuilder, make.Select(make.This(Type.noType), fieldName)); + if(iterator.hasNext()){ + stringBuilder = stringAppend(stringBuilder, make.Literal(", ")); } } + stringBuilder = stringAppend(stringBuilder, make.Literal("]")); - stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), - List.of(make.Literal("]")) - ); + return List.of(make.Return(make.App(make.Select(stringBuilder, names.toString).setType(syms.stringType)))); + } - return List.of(make.Return( - make.App( - make.Select(stringBuilder, names.toString).setType(syms.stringType) - ) - )); + private JCMethodInvocation stringAppend(JCExpression builder, JCExpression arg) { + return make.App(make.Select(builder, names.append).setType(syms.stringBuilderType), List.of(arg)); } - private List generateEquals(JCTree.JCClassDecl classDecl, Name otherName) { - ListBuffer statements = new ListBuffer<>(); + public List generateEquals(JCClassDecl classDecl, Name otherName) { + ListBuffer statements = new ListBuffer<>(); // if (o == this) return true; - { - statements.add(make.If( - make.Binary( - JCTree.Tag.EQ, - make.This(Type.noType), - make.Ident(otherName) - ), - make.Return(make.Literal(true)), - null - )); - } + statements.add(make.If( + make.Binary( + Tag.EQ, + make.This(Type.noType), + make.Ident(otherName) + ), + make.Return(make.Literal(true)), + null + )); // if (o == null) return false; - { - statements.add(make.If( - make.Binary( - JCTree.Tag.EQ, - make.Ident(otherName), - make.Literal(TypeTag.BOT, null) - ), - make.Return(make.Literal(false)), - null - )); - } + statements.add(make.If( + make.Binary( + Tag.EQ, + make.Ident(otherName), + make.Literal(TypeTag.BOT, null) + ), + make.Return(make.Literal(false)), + null + )); // if (o.getClass() != getClass()) return false; - { + statements.add(make.If( + make.Binary( + Tag.EQ, + make.App(make.Select(make.Ident(otherName), names.getClass).setType(syms.classType)), + make.App(make.Select(make.This(Type.noType), names.getClass).setType(syms.classType)) + ), + make.Block(0, List.nil()), + make.Return(make.Literal(false)) + )); + + // Create casted variable: ClassName other = (ClassName)o; + Name thatName = names.fromString("other"); + statements.add(make.VarDef( + make.Modifiers(0), + thatName, + make.Ident(classDecl.name), + make.TypeCast(make.Ident(classDecl.name), make.Ident(otherName)) + )); + + // fields - use the casted variable + for ( + Iterator iterator = getRecordComponents(classDecl).iterator(); + iterator.hasNext(); + ) { + JCVariableDecl fieldDecl = iterator.next(); + JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); + JCExpression otherFieldAccess = make.Select(make.Ident(thatName), fieldDecl.name); + + final JCExpression condition; + if (fieldDecl.getType() instanceof JCPrimitiveTypeTree) { + condition = make.Binary(Tag.EQ, otherFieldAccess, myFieldAccess); + } else { + condition = make.App( + // call Objects.equals + make.Select( + make.QualIdent(syms.objectsType.tsym), + names.equals + ).setType(syms.objectsType), + List.of(otherFieldAccess, myFieldAccess) + ); + } statements.add(make.If( - make.Binary( - JCTree.Tag.EQ, - make.App(make.Select(make.Ident(otherName), names.getClass).setType(syms.classType)), - make.App(make.Select(make.This(Type.noType), names.getClass).setType(syms.classType)) - ), + condition, make.Block(0, List.nil()), make.Return(make.Literal(false)) )); } - // fields - { - for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); - iterator.hasNext(); - ) { - JCTree.JCVariableDecl fieldDecl = iterator.next(); - - JCTree.JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); - JCTree.JCExpression otherFieldAccess = make.Select( - make.TypeCast(make.Ident(classDecl.name), make.Ident(otherName)), - fieldDecl.name - ); - - final JCTree.JCExpression condition; - if (fieldDecl.getType() instanceof JCTree.JCPrimitiveTypeTree) { - condition = make.Binary(JCTree.Tag.EQ, otherFieldAccess, myFieldAccess); - } else { - condition = make.App( - // call Objects.equals - make.Select( - make.QualIdent(syms.objectsType.tsym), - names.equals - ).setType(syms.objectsType), - List.of(otherFieldAccess, myFieldAccess) - ); - } - statements.add(make.If( - condition, - make.Block(0, List.nil()), - make.Return(make.Literal(false)) - )); - } - } - + // return true; statements.add(make.Return(make.Literal(true))); + return statements.toList(); } - private List generateHashCode(JCTree.JCClassDecl classDecl) { - ListBuffer expressions = new ListBuffer<>(); + public List generateHashCode(JCClassDecl classDecl) { + ListBuffer expressions = new ListBuffer<>(); for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); + Iterator iterator = getRecordComponents(classDecl).iterator(); iterator.hasNext(); ) { - JCTree.JCVariableDecl fieldDecl = iterator.next(); + JCVariableDecl fieldDecl = iterator.next(); JCTree fType = fieldDecl.getType(); + JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); - JCTree.JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); - - if (fType instanceof JCTree.JCPrimitiveTypeTree) { - switch (((JCTree.JCPrimitiveTypeTree) fType).getPrimitiveTypeKind()) { + if (fType instanceof JCPrimitiveTypeTree) { + //TODO simplify that? + switch (((JCPrimitiveTypeTree) fType).getPrimitiveTypeKind()) { case BOOLEAN: /* this.fieldName ? 1 : 0 */ - expressions.append( - make.Conditional( - myFieldAccess, - make.Literal(TypeTag.INT, 1), - make.Literal(TypeTag.INT, 0) - ) - ); + expressions.append(make.Conditional( + myFieldAccess, + make.Literal(TypeTag.INT, 1), + make.Literal(TypeTag.INT, 0) + )); break; + case LONG: - expressions.append(longToIntForHashCode(myFieldAccess)); + expressions.append(make.TypeCast( + make.TypeIdent(syms.intType.getTag()), + make.Parens(make.Binary( + Tag.BITXOR, + myFieldAccess, + make.Parens(make.Binary( + Tag.USR, + myFieldAccess, + make.Literal(32) + )) + )) + )); break; + case FLOAT: /* this.fieldName != 0f ? Float.floatToIntBits(this.fieldName) : 0 */ - expressions.append( - make.Conditional( - make.Binary(JCTree.Tag.NE, myFieldAccess, make.Literal(0f)), - make.App( - make.Select( - make.Ident(names.fromString("Float")), - names.fromString("floatToIntBits")).setType(syms.intType), - List.of(myFieldAccess) - ), - make.Literal(TypeTag.INT, 0) - ) - ); + expressions.append(make.Conditional( + make.Binary(Tag.NE, myFieldAccess, make.Literal(0f)), + make.App( + make.Select( + make.QualIdent(types.boxedClass(syms.floatType)), + names.fromString("floatToIntBits") + ).setType(syms.intType), + List.of(myFieldAccess) + ), + make.Literal(TypeTag.INT, 0) + )); break; + case DOUBLE: - /* longToIntForHashCode(Double.doubleToLongBits(this.fieldName)) */ - expressions.append( - longToIntForHashCode( - make.App( - make.Select( - make.Ident(names.fromString("Double")), - names.fromString("doubleToLongBits")).setType(syms.intType), - List.of(myFieldAccess) - ) - ) - ); + /* Double.hashCode(this.fieldName) */ + expressions.append(make.App( + make.Select( + make.QualIdent(types.boxedClass(syms.doubleType)), + names.hashCode + ).setType(syms.intType), + List.of(myFieldAccess) + )); break; - default: + case BYTE: case SHORT: case INT: case CHAR: + default: /* just the field */ expressions.append(myFieldAccess); break; } - } else if (fType instanceof JCTree.JCArrayTypeTree) { - expressions.append( - make.App( - make.Select( - make.Select( - make.Select( - make.Ident(names.fromString("java")), - names.fromString("util") - ), - names.fromString("Arrays") - ), - names.fromString("hashCode") - ).setType(syms.intType), - List.of(myFieldAccess) - ) - ); + + } else if (fType instanceof JCArrayTypeTree) { + expressions.append(make.App( + make.Select( + make.QualIdent(syms.arraysType.tsym), + names.hashCode + ).setType(syms.intType), + List.of(myFieldAccess) + )); } else { /* (this.fieldName != null ? this.fieldName.hashCode() : 0) */ - expressions.append( - make.Conditional( - make.Binary(JCTree.Tag.NE, myFieldAccess, make.Literal(TypeTag.BOT, null)), - make.App(make.Select(myFieldAccess, names.hashCode).setType(syms.intType)), - make.Literal(0) - ) - ); + expressions.append(make.Conditional( + make.Binary(Tag.NE, myFieldAccess, make.Literal(TypeTag.BOT, null)), + make.App(make.Select(myFieldAccess, names.hashCode).setType(syms.intType)), + make.Literal(0) + )); } } - ListBuffer statements = new ListBuffer<>(); - + ListBuffer statements = new ListBuffer<>(); Name resultName = names.fromString("result"); + statements.append(make.VarDef( + make.Modifiers(0), + resultName, + make.TypeIdent(syms.intType.getTag()), + make.Literal(0) + )); - statements.append( - make.VarDef( - make.Modifiers(0L), - resultName, - make.TypeIdent(syms.intType.getTag()), - make.Literal(0) - ) - ); - for (JCTree.JCExpression expression : expressions) { + for (JCExpression expression : expressions) { // result = 31 * result + ${expr} - statements.append(make.Exec( - make.Assign( - make.Ident(resultName), - make.Binary( - JCTree.Tag.PLUS, - make.Binary(JCTree.Tag.MUL, make.Literal(TypeTag.INT, 31), make.Ident(resultName)), - expression - ) + statements.append(make.Exec(make.Assign( + make.Ident(resultName), + make.Binary( + Tag.PLUS, + make.Binary(Tag.MUL, make.Literal(TypeTag.INT, 31), make.Ident(resultName)), + expression ) - )); + ))); } statements.append(make.Return(make.Ident(resultName))); return statements.toList(); } - - public JCTree.JCExpression longToIntForHashCode(JCTree.JCExpression ref) { - /* (int) (ref ^ ref >>> 32) */ - return make.TypeCast( - make.TypeIdent(syms.intType.getTag()), - make.Parens( - make.Binary( - JCTree.Tag.BITXOR, - ref, - make.Parens(make.Binary(JCTree.Tag.USR, ref, make.Literal(32))) - ) - ) - ); - } - - private static class MandatoryDesugarAnnotationTreeScanner extends TreeScanner { - - private final Log log; - - private final CompilationUnitTree compilationUnit; - - public MandatoryDesugarAnnotationTreeScanner(Log log, CompilationUnitTree compilationUnit) { - this.log = log; - this.compilationUnit = compilationUnit; - } - - @Override - public Void visitClass(ClassTree node, Void aVoid) { - if ("RECORD".equals(node.getKind().toString())) { - if ( - node.getModifiers().getAnnotations().stream() - .noneMatch(annotation -> { - Type type = ((JCTree.JCAnnotation) annotation).type; - return Desugar.class.getName().equals(type.toString()); - }) - ) { - JavaFileObject oldSource = log.useSource(compilationUnit.getSourceFile()); - try { - log.error( - (JCTree.JCClassDecl) node, - new JCDiagnostic.Error( - "jabel", - "missing.desugar.on.record", - "Must be annotated with @Desugar" - ) - ); - } finally { - log.useSource(oldSource); - } - } - } - return super.visitClass(node, aVoid); - } - } } diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java new file mode 100644 index 0000000..1ee71d0 --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java @@ -0,0 +1,1534 @@ +package com.github.bsideup.jabel; + +import java.lang.reflect.Method; +import java.util.*; + +import com.sun.source.tree.*; +import com.sun.source.util.*; +import com.sun.source.util.TreeScanner; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.code.Symbol.*; +import com.sun.tools.javac.comp.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; +import com.sun.tools.javac.util.JCDiagnostic.*; +import com.sun.tools.javac.util.List; + + +/** + * Transforms modern switch constructs (Java 17+) into Java 8 compatible code. + *


+ * This listener will rewrite entirely the switch for a better one. + *

+ * Instead of placing the guard in the case body, it will be evaluated in the switch condition.
+ * Avoiding restarting the switch everytimes a guard fail, + * and storing an index to not reevaluate previous cases. + *

+ * So this method will convert cases labels into a ternary-chain + * that will be placed in the switch condition.
+ * And case labels will be changed to match their position in the switch. + *

+ * Also, to avoid double method calls, if the switch condition was a method call, + * an uninitialized variable is inserted in the previous line, + * and then initialized in the ternary-chain.
+ * Same way goes for cases using a record-pattern with a guard using a component from that record. + */ +public class SwitchRetrofittingTaskListener implements TaskListener { + // region compiler compatibility + + // Because we're compiling with JDK 25, the old method (without guards) doens't exists. + private static Method LEGACY_MAKE_CASE; + // Some internal states to avoid getting errors everytimes we trying to use a found feature. + // true means that the feature is not present, by default assuming it is. + private static boolean GUARD, LABELS, BODY, DEFAULT_CASE, CONSTANT_CASE, PATTERN_MATCHING_CATCH, SWITCH_PATTERN; + /* private */ static boolean MATCH_EXCEPTION_PRESENT; + + static { + try { + Symtab.class.getDeclaredField("matchExceptionType"); + MATCH_EXCEPTION_PRESENT = true; + } catch (Exception ignored) {} + } + + /** Get the guard expression from a case, or null if guards are unsupported. */ + private static JCExpression getGuard(JCCase caseTree) { + if (GUARD) return null; + try { + return caseTree.getGuard(); + } catch (NoSuchMethodError ignored) { + GUARD = true; + return null; + } + } + + /** Get the labels from a case, handling both old and new compiler APIs. */ + @SuppressWarnings("unchecked") + private static List getLabels(JCCase caseTree) { + if (!LABELS) { + try { + return (List) (List) caseTree.getLabels(); + } catch (NoSuchMethodError ignored) { + LABELS = true; + } + } + + List labels = caseTree.getExpressions(); + if (labels == null) return List.nil(); + List result = List.nil(); + for (Object label : labels) { + if (!(label instanceof JCTree)) continue; + result = result.append((JCTree) label); + } + return result; + } + + /** Get the arrow-style body of a case, or null if unsupported. */ + private static JCTree getBody(JCCase caseTree) { + if (BODY) return null; + try { + return caseTree.getBody(); + } catch (NoSuchMethodError ignored) { + BODY = true; + return null; + } + } + + // TODO: CaseTree.CaseKind doesn't exists on Java 12- + // EDIT: Seems to not be a real problem... + /** + * Create a new case tree, handling different JDK signatures.
+ * JDK 21+: Case(CaseKind, List labels, JCExpression guard, List stats, JCTree body)
+ * JDK 17-20: Case(CaseKind, List labels, List stats, JCTree body) + */ + @SuppressWarnings("unchecked") + private JCCase makeCase( + CaseTree.CaseKind kind, List labels, List stats, JCTree body + ) { + // A default case is one that have no labels on JDK < 17 + if (!labels.isEmpty() && labels.head == null) labels = List.nil(); + + if (!GUARD) { + try { + return make.Case(kind, (List) labels, null, stats, body); + } catch (NoSuchMethodError ignored) { + GUARD = true; + } + } + + try { + if (LEGACY_MAKE_CASE == null) { + LEGACY_MAKE_CASE = TreeMaker.class.getMethod( + "Case", CaseTree.CaseKind.class, List.class, List.class, JCTree.class + ); + } + return (JCCase) LEGACY_MAKE_CASE.invoke(make, kind, labels, stats, body); + } catch (Exception ignored) {// Should never fail, hope this never happen... + System.err.println("[jabel] Failed to make case with labels: " + labels); + } + return null; + } + + // Because the return type of these two methods may not exist, we need delay the call to an inner class. + // Like that, the class initialization error can be catched easily. + private static final class DefaultCaseLabelFactory { + static JCTree make(TreeMaker m) { + return m.DefaultCaseLabel(); + } + } + + private static final class ConstantCaseLabelFactory { + static JCTree make(TreeMaker m, JCExpression lit) { + return m.ConstantCaseLabel(lit); + } + } + + /** Create a default case label. Returns null if unsupported (JDK < 17). */ + private JCTree makeDefaultCaseLabel() { + if (DEFAULT_CASE) return null; + try { + return DefaultCaseLabelFactory.make(make); + } catch (NoSuchMethodError | NoClassDefFoundError ignored) { + DEFAULT_CASE = true; + return null; + } + } + + /** Creates a case label for an {@code Integer}, handling JDK 17-20 and 21+ APIs. */ + private JCTree makeLabel(int i) { + JCLiteral lit = make.Literal(i); + if (CONSTANT_CASE) return lit; + try { + return ConstantCaseLabelFactory.make(make, lit); // JDK 21+ + } catch (NoSuchMethodError | NoClassDefFoundError ignored) { + CONSTANT_CASE = true; + return lit; // JDK 17-20 + } + } + + private static final class PatternMatchingCatchAccess { + static void attach(JCBlock block, JCCatch handler, Set calls) { + block.patternMatchingCatch = new JCBlock.PatternMatchingCatch(handler, calls); + } + + static JCCatch handler(JCBlock block) { + if (block.patternMatchingCatch == null) return null; + return block.patternMatchingCatch.handler(); + } + } + + /* private */ static JCCatch getPatternMatchingCatchHandler(JCBlock block) { + if (PATTERN_MATCHING_CATCH) return null; + try { + return PatternMatchingCatchAccess.handler(block); + } catch (NoSuchFieldError | NoClassDefFoundError ignored) { + PATTERN_MATCHING_CATCH = true; + return null; + } + } + + private static void attachPatternMatchingCatch(JCBlock block, JCCatch body, Set calls) { + if (PATTERN_MATCHING_CATCH) return; + try { + PatternMatchingCatchAccess.attach(block, body, calls); + } catch (NoSuchFieldError | NoClassDefFoundError ignored) { + PATTERN_MATCHING_CATCH = true; + } + } + + /** Sets {@code patternSwitch = false} on a switch, silently ignoring JDK < 17. */ + private static void clearPatternSwitchStatus(JCTree tree) { + if (SWITCH_PATTERN) return; + try { + switch (getClassName(tree)) { + case "JCSwitch": + ((JCSwitch) tree).patternSwitch = false; + break; + case "JCSwitchExpression": + ((JCSwitchExpression) tree).patternSwitch = false; + break; + } + } catch (NoSuchFieldError | NoClassDefFoundError ignored) { + SWITCH_PATTERN = true; + } + } + + private static String getClassName(Object obj) { + return obj == null ? "" : obj.getClass().getSimpleName(); + } + + private static boolean isBindingPattern(JCTree pattern) { + return pattern != null && getClassName(pattern).equals("JCBindingPattern"); + } + + private static boolean isRecordPattern(JCTree pattern) { + return pattern != null && getClassName(pattern).equals("JCRecordPattern"); + } + + /** Check if this tree node is a pattern (binding, record, case label, or any). */ + private static boolean isPattern(JCTree label) { + if (label == null) return false; + String name = getClassName(label); + return name.contains("Pattern") || name.contains("Binding"); + } + + private static boolean isSwitchExpression(JCTree tree) { + return tree != null && getClassName(tree).equals("JCSwitchExpression"); + } + + private static boolean isDefault(JCTree label) { + return label != null && getClassName(label).equals("JCDefaultCaseLabel"); + } + + private static boolean isConstant(JCTree label) { + return label != null && getClassName(label).equals("JCConstantCaseLabel"); + } + + private static boolean isNull(JCTree label) { + if (label instanceof JCLiteral) return ((JCLiteral) label).typetag == TypeTag.BOT; + if (!isConstant(label)) return false; + JCExpression expr = ((JCConstantCaseLabel) label).expr; + return expr instanceof JCLiteral && ((JCLiteral) expr).typetag == TypeTag.BOT; + } + + private boolean isComplex(JCExpression e) { + if (e instanceof JCIdent || e instanceof JCLiteral) return false; + if (e instanceof JCFieldAccess) return isComplex(((JCFieldAccess) e).selected); + if (e instanceof JCParens) return isComplex(((JCParens) e).expr); + return true; + } + + /** @return {@code true} if {@code expr} is structurally an enum constant reference. */ + private static boolean isEnumConstant(JCExpression expr) { + return (expr instanceof JCIdent || expr instanceof JCFieldAccess) + && !(expr instanceof JCLiteral); + } + + private JCExpression getExpression(JCTree body) { + if (body instanceof JCExpressionStatement) return ((JCExpressionStatement) body).expr; + if (body instanceof JCExpression) return (JCExpression) body; + return null; + } + + private static JCExpression getLabelExpression(JCTree label) { + if (label instanceof JCExpression) return (JCExpression) label; + if (isConstant(label)) return ((JCConstantCaseLabel) label).expr; + return null; + } + + /** + * Because {@link JCPatternCaseLabel#pat} type is {@link JCPattern}, + * we need to delay access to an inner class.
+ * Like that, the class initialization error can be catched easily. + */ + private static final class PatternCaseLabelAccess { + static JCTree pat(JCTree label) { + return ((JCPatternCaseLabel) label).pat; + } + } + + /** Get the binding variable from a pattern (binding pattern or pattern case label). */ + private static JCVariableDecl getPatternVar(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCBindingPattern": + return ((JCBindingPattern) label).var; + case "JCPatternCaseLabel": + return getPatternVar(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + /** + * Get the type from a pattern label (for instanceof check).
+ * Works for binding patterns, pattern case labels, and unnamed patterns. + */ + private static JCExpression getPatternType(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCBindingPattern": + JCVariableDecl var = ((JCBindingPattern) label).var; + return var != null ? var.vartype : null; + case "JCPatternCaseLabel": + return getPatternType(PatternCaseLabelAccess.pat(label)); + // case "JCAnyPattern": + // return getAnyPatternType(label); + default: + return null; + } + } + + /** Get the record type from a record pattern or pattern case label. */ + private static JCExpression getRecordType(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCRecordPattern": + return ((JCRecordPattern) label).deconstructor; + case "JCPatternCaseLabel": + return getRecordType(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + /** Get nested patterns from a record pattern or pattern case label. */ + private static List getRecordNested(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCRecordPattern": + return ((JCRecordPattern) label).nested; + case "JCPatternCaseLabel": + return getRecordNested(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + private static boolean hasDefault(List cases) { + if (cases == null) return true; + for (JCCase c : cases) { + if (c == null) continue; + List labels = getLabels(c); + if (labels.isEmpty()) return true; + for (JCTree label : labels) { + if (isDefault(label)) return true; + } + } + return false; + } + + private static boolean hasRecordPatterns(List cases) { + if (cases == null) return false; + for (JCCase c : cases) { + if (c == null) continue; + for (JCTree l : getLabels(c)) { + if (getRecordType(l) != null) return true; + } + } + return false; + } + + private static boolean hasGuardRecordPatterns(List cases) { + if (cases == null) return false; + for (JCCase c : cases) { + if (c == null || getGuard(c) == null) continue; + for (JCTree l : getLabels(c)) { + if (getRecordType(l) != null) return true; + } + } + return false; + } + + /** @return whether case labels contains things not handled by a standard switch. */ + private static boolean needsTransform(List cases) { + if (cases == null) return false; + for (JCCase c : cases) { + if (c == null) continue; + if (getGuard(c) != null) return true; + for (JCTree label : getLabels(c)) { + if (isPattern(label) || isNull(label)) return true; + JCExpression expr = getLabelExpression(label); + if (!(expr instanceof JCLiteral)) continue; + switch (((JCLiteral) expr).typetag) { + case FLOAT: + case DOUBLE: + case LONG: + return true; + default: + } + } + } + return false; + } + + /** + * @return whether all cases are just enum constants with a {@code null} + * and/or an unconditional pattern. + */ + private static boolean isEnumSwitch(List cases) { + if (cases == null) return false; + boolean hasEnumConstant = false, hasUnguardedPattern = false; + for (JCCase c : cases) { + if (c == null) continue; + if (getGuard(c) != null) return false; + for (JCTree label : getLabels(c)) { + if (isDefault(label) || isNull(label)) continue; + if (isPattern(label)) { + if (hasUnguardedPattern) return false; + hasUnguardedPattern = true; + continue; + } + JCExpression expr = getLabelExpression(label); + if (expr == null || !isEnumConstant(expr)) return false; + hasEnumConstant = true; + } + } + return hasEnumConstant; + } + + // Forced to use reflection here, methods are package-private. + static Method resolveBinary; + static { + try { + resolveBinary = Operators.class.getDeclaredMethod( + "resolveBinary", DiagnosticPosition.class, Tag.class, Type.class, Type.class + ); + resolveBinary.setAccessible(true); + } catch (Exception ignored) {} + } + + /* private */ static OperatorSymbol resolveBinary( + Operators ops, DiagnosticPosition pos, Tag tag, Type op1, Type op2 + ) { + if (resolveBinary != null) { + try { + return (OperatorSymbol) resolveBinary.invoke(ops, pos, tag, op1, op2); + } catch (Exception e) { + //TODO use Log instead? but this should never fail... + System.err.println("[jabel] Failed to find operator " + tag + " with " + op1 + "," + op2); + } + } + return null; + } + + // end region + // region task listener + + final TreeMaker make; + final Symtab syms; + final Names names; + final Types types; + final Operators ops; + final Attr attr; + final SymTreeCopier copier; + + private int tempVarCounter; + private Set pendingAccessorCalls = null; + private Map> guardPreVars = null; + private MethodSymbol currentMethodSym; + private boolean cacheResolved, useMatchException; + private MethodSymbol object_equals, object_toString, object_getClass, enum_ordinal, class_getName, + float_floatToIntBits, double_doubleToLongBits, icce_ctor, me_ctor; + + public SwitchRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + syms = Symtab.instance(context); + names = Names.instance(context); + types = Types.instance(context); + ops = Operators.instance(context); + attr = Attr.instance(context); + copier = new SymTreeCopier(make); + } + + @Override + public void started(TaskEvent e) { + } + + @Override + public void finished(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ANALYZE) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + new SwitchTranslator().translate((JCCompilationUnit) e.getCompilationUnit()); + } + + public class SwitchTranslator extends TreeTranslator { + /** Captures the original selector of a switch expression before replacement by a temp var. */ + private final Map captures = new HashMap<>(); + private Set blockAccessorCalls = null; + private ClassSymbol currentClass; + + // Symbol tracking + @Override + public void visitClassDef(JCClassDecl tree) { + ClassSymbol prevClass = currentClass; + MethodSymbol prevMethod = currentMethodSym; + try { + currentClass = tree.sym; + currentMethodSym = null; + super.visitClassDef(tree); + } finally { + currentClass = prevClass; + currentMethodSym = prevMethod; + } + } + + @Override + public void visitMethodDef(JCMethodDecl tree) { + MethodSymbol prev = currentMethodSym; + try { + currentMethodSym = tree.sym; + super.visitMethodDef(tree); + } finally { + currentMethodSym = prev; + } + } + + @Override + public void visitBlock(JCBlock tree) { + MethodSymbol prev = currentMethodSym; + // Save and reset blockAccessorCalls for this block scope + Set prevCalls = blockAccessorCalls; + blockAccessorCalls = null; + try { + if (currentMethodSym == null && currentClass != null) { + currentMethodSym = new MethodSymbol( + tree.flags | Flags.BLOCK, + names.empty, + null, + currentClass + ); + } + ListBuffer buffer = null; + for (JCStatement stmt : tree.stats) { + buffer = prepareExprSwitchPrefix(stmt, buffer, tree.stats); + if (buffer != null) buffer.append(stmt); + } + if (buffer != null) tree.stats = buffer.toList(); + super.visitBlock(tree); + + // Attach patternMatchingCatch to this block if accessor calls were collected + if (blockAccessorCalls != null) { + pendingAccessorCalls = blockAccessorCalls; + blockAccessorCalls = null; + attachPatternMatchingCatch(tree); + } + } finally { + // Propagate remaining calls to the parent scope + if (blockAccessorCalls != null) { + if (prevCalls == null) prevCalls = blockAccessorCalls; + else prevCalls.addAll(blockAccessorCalls); + } + blockAccessorCalls = prevCalls; + currentMethodSym = prev; + } + } + + @Override + public void visitVarDef(JCVariableDecl tree) { + MethodSymbol prev = currentMethodSym; + try { + if (currentMethodSym == null && currentClass != null) { + currentMethodSym = new MethodSymbol( + (tree.mods.flags & Flags.STATIC) | Flags.BLOCK, + names.empty, + null, + currentClass + ); + } + super.visitVarDef(tree); + } finally { + currentMethodSym = prev; + } + } + + @Override + public void visitSwitch(JCSwitch tree) { + super.visitSwitch(tree); + if (!needsTransform(tree.cases)) return; + clearPatternSwitchStatus(tree); // TransPatterns must not re-process our output + + make.at(tree.pos); + ListBuffer prefix = new ListBuffer<>(); + JCExpression sel = captureSelector(tree.selector, prefix, true); + buildGuardPreDecls(tree.cases, prefix); + JCSwitch ns = transformSwitch(tree, sel, tree.cases, false, null); + tree.selector = ns.selector; + tree.cases = ns.cases; + + collectAccessorCalls(); + if (prefix.isEmpty()) return; + prefix.append(tree); + result = make.Block(0, prefix.toList()); + } + + private ListBuffer prepareExprSwitchPrefix( + JCStatement stmt, ListBuffer buffer, List allStats + ) { + List found = findPatternSwitches(stmt); + if (found.isEmpty()) return buffer; + + boolean needsBuffer = false; + for (JCTree tree : found) { + JCSwitchExpression sw = (JCSwitchExpression) tree; + if (isComplex(sw.selector) || hasGuardRecordPatterns(sw.cases)) { + needsBuffer = true; + break; + } + } + if (!needsBuffer) return buffer; + + if (buffer == null) { + ListBuffer buf = new ListBuffer<>(); + for (JCStatement s : allStats) { + if (s == stmt) break; + buf.append(s); + } + buffer = buf; + } + + for (JCTree tree : found) { + JCSwitchExpression sw = (JCSwitchExpression) tree; + JCExpression sel = captureSelector(sw.selector, buffer, false); + if (sel != sw.selector) { + captures.put(tree, sw.selector); + sw.selector = sel; + } + if (needsTransform(sw.cases)) buildGuardPreDecls(sw.cases, buffer); + } + return buffer; + } + + @Override + public T translate(T tree) { + if (tree == null) return null; + if (!isSwitchExpression(tree)) return super.translate(tree); + clearPatternSwitchStatus(tree); + + JCSwitchExpression sw = (JCSwitchExpression) tree; + make.at(sw.pos); + JCExpression rawSel = captures.remove(sw); + sw.selector = translate(sw.selector); + sw.cases = translate(sw.cases); + + if (needsTransform(sw.cases)) { + JCSwitch ns = transformSwitch(tree, sw.selector, sw.cases, true, rawSel); + sw.selector = ns.selector; + sw.cases = ns.cases; + collectAccessorCalls(); + } else { + sw.cases = injectDefault(sw.selector, sw.cases); + } + return tree; + } + + /** TreeTranslator skips the arrow body field. */ + @Override + public void visitCase(JCCase tree) { + super.visitCase(tree); + JCTree body = getBody(tree); + if (body == null) return; + JCTree saved = result; + tree.body = translate(body); + result = saved; + } + + /** Moves pendingAccessorCalls into blockAccessorCalls. */ + private void collectAccessorCalls() { + if (pendingAccessorCalls == null) return; + if (blockAccessorCalls == null) blockAccessorCalls = new HashSet<>(); + blockAccessorCalls.addAll(pendingAccessorCalls); + pendingAccessorCalls = null; + } + } + + @SuppressWarnings("unchecked") + public List findPatternSwitches(JCTree node) { + ListBuffer out = new ListBuffer<>(); + new TreeScanner() { + @Override + public Void scan(Tree t, Void v) { + if (t == null) return null; + if (!isSwitchExpression((JCTree) t)) return super.scan(t, v); + JCSwitchExpression se = (JCSwitchExpression) t; + if (needsTransform(se.cases)) out.append(se); + return null; + } + }.scan(node, null); + return (List) (List) out.toList(); + } + + /** Builds the standard switch from a pattern/null switch. */ + public JCSwitch transformSwitch( + JCTree currentSwitch, JCExpression sel, List cases, boolean expression, + JCExpression rawSel + ) { + pendingAccessorCalls = hasRecordPatterns(cases) ? new HashSet<>() : null; + + boolean isEnum = isEnumSwitch(cases), hasNull = false; + List labels = List.nil(); + List slotCases = List.nil(); + + // Analyze case and labels + for (JCCase c : cases) { + JCTree last = null; + for (JCTree label : getLabels(c)) { + if (isNull(label)) { + hasNull = true; + continue; + } + if (isDefault(label) || (isEnum && isPattern(label))) continue; + if (last != null) { + labels = labels.append(last); + slotCases = slotCases.append(null); + } + last = label; + } + if (last != null) { + labels = labels.append(last); + slotCases = slotCases.append(c); + } + } + + return make.Switch( + isEnum + ? buildEnumSwitchCondition(sel, hasNull, rawSel) + : buildTypeSwitchCondition(sel, hasNull, labels, slotCases, rawSel), + prepareSwichCases(currentSwitch, sel, expression, isEnum, cases, labels) + ); + } + + public List prepareSwichCases( + JCTree currentSwitch, JCExpression sel, boolean isExpression, boolean isEnum, + List cases, List labels + ) { + ListBuffer newCases = new ListBuffer<>(); + JCCase nc; + JCStatement body; + boolean defaultEmitted = false; + int i = 0; + + for (JCCase c : cases) { + JCTree lastNormal = null; + boolean seenNull = false, seenDefault = false; + List caseLabels = getLabels(c); + + // Scan for "null" and "default" labels + for (JCTree label : caseLabels) { + if (isNull(label)) seenNull = true; + else if (isDefault(label) || (isEnum && isPattern(label))) seenDefault = true; + else lastNormal = label; + } + + // Emit fall-through slots + for (JCTree label : caseLabels) { + if (isNull(label) || isDefault(label) || (isEnum && isPattern(label))) continue; + if (label == lastNormal) break; + int lv = isEnum ? getEnumOrdinal(getLabelExpression(label)) : i++; + if (lv < 0) continue; + nc = makeStatementCase(makeLabel(lv), List.nil()); + if (nc != null) newCases.append(nc); + } + + // Emit the last normal label + if (lastNormal != null) { + int lv = isEnum ? getEnumOrdinal(getLabelExpression(lastNormal)) : i++; + if (lv >= 0) { + body = make.Block(0, buildCaseBody(currentSwitch, c, sel, isExpression)); + nc = makeStatementCase(makeLabel(lv), List.of(body)); + if (nc != null) newCases.append(nc); + } + } + + // Emit null slot at its original position + if (seenNull) { + // In case of null and default are on the same case + body = seenDefault + ? null + : make.Block(0, buildCaseBody(currentSwitch, c, sel, isExpression)); + nc = makeStatementCase(makeLabel(-1), body == null ? List.nil() : List.of(body)); + if (nc != null) newCases.append(nc); + } + + // Emit default slot at its original position + if (seenDefault && !defaultEmitted) { + defaultEmitted = true; + body = make.Block(0, buildCaseBody(currentSwitch, c, sel, isExpression)); + nc = makeStatementCase(makeDefaultCaseLabel(), List.of(body)); + if (nc != null) newCases.append(nc); + } + } + + // Emit default case for exhaustive switches + if (!defaultEmitted) { + body = makeMatchExceptionThrow(sel, !isEnum); + nc = makeStatementCase(makeDefaultCaseLabel(), List.of(body)); + if (nc != null) newCases.append(nc); + } + + return newCases.toList(); + } + + /** Builds a simple {@code sel == null ? -1 : sel.ordinal()} condition. */ + public JCExpression buildEnumSwitchCondition(JCExpression sel, boolean hasNull, JCExpression rawSel) { + resolveMethodsCache(); + VarSymbol selSym = (sel instanceof JCIdent && ((JCIdent) sel).sym instanceof VarSymbol) + ? (VarSymbol) ((JCIdent) sel).sym + : null; + JCExpression base = makeMethodCall(sel, enum_ordinal); + if (hasNull) base = makeConditional(makeBinary(Tag.EQ, sel, makeNull()), make.Literal(-1), base); + return rawSel != null && selSym != null ? assignSwitchSelector(base, selSym, rawSel) : base; + } + + /** Builds a ternary chain and groups the same type-pattern. */ + public JCExpression buildTypeSwitchCondition( + JCExpression sel, boolean hasNull, List labels, List slotCases, + JCExpression rawSel + ) { + int n = labels.size(); + Map group = new HashMap<>(); + JCExpression[] anchor = new JCExpression[n]; + + for (int i = 0; i < n; i++, labels = labels.tail, slotCases = slotCases.tail) { + JCExpression pt = getPatternTypeForGrouping(labels.head); + Symbol tsym = pt != null && pt.type != null ? types.erasure(pt.type).tsym : null; + + if (tsym == null) { + anchor[i] = buildLabelCondition(labels.head, sel); + continue; + } + + JCExpression guard = buildGuard(labels.head, sel, slotCases.head); + JCConditional open = group.get(tsym); + + if (open == null) { + JCExpression thenpart; + if (guard != null) { + thenpart = makeConditional(guard, make.Literal(i), make.Literal(n)); + group.put(tsym, (JCConditional) thenpart); + } else { + thenpart = make.Literal(i); + } + anchor[i] = makeConditional(makeTypeTest(sel, pt), thenpart, make.Literal(n)); + } else { + anchor[i] = null; + if (guard != null) { + JCConditional next = makeConditional(guard, make.Literal(i), make.Literal(n)); + open.falsepart = next; + open.type = syms.intType; + group.put(tsym, next); + } else { + open.falsepart = make.Literal(i); + open.type = syms.intType; + group.remove(tsym); + } + } + } + + // Assemble ternary chain + sel = TreeInfo.skipParens(sel); + VarSymbol selSym = (sel instanceof JCIdent && ((JCIdent) sel).sym instanceof VarSymbol) + ? (VarSymbol) ((JCIdent) sel).sym + : null; + boolean inlineInNull = hasNull && rawSel != null && selSym != null; + JCExpression ternary = make.Literal(n); + + // Null check + if (!hasNull && sel.type != null && !sel.type.isPrimitive() && selSym != null) { + rawSel = attr.makeNullCheck(rawSel != null ? rawSel : sel); + } + + for (int i = n - 1; i >= 0; i--) { + JCExpression cond = anchor[i]; + if (cond == null) continue; + if (i == 0 && !inlineInNull && rawSel != null && selSym != null) { + cond = assignSwitchSelector(cond, selSym, rawSel); + } + if (cond instanceof JCConditional) { + ((JCConditional) cond).falsepart = ternary; + ternary = cond; + } else { + ternary = makeConditional(cond, make.Literal(i), ternary); + } + } + // In case of + if (rawSel != null && selSym != null && !inlineInNull && n > 0 && anchor[0] == null) { + ternary = assignSwitchSelector(ternary, selSym, rawSel); + } + + if (hasNull) { + JCExpression nullSel = inlineInNull ? makeAssignParens(selSym, rawSel) : sel; + ternary = makeConditional(makeBinary(Tag.EQ, nullSel, makeNull()), make.Literal(-1), ternary); + } + + return (n == 0 && rawSel != null && selSym != null) ? makeAssignParens(selSym, rawSel) : ternary; + } + + private JCExpression getPatternTypeForGrouping(JCTree label) { + if (!isPattern(label)) return null; + JCVariableDecl pv = getPatternVar(label); + if (pv != null) return pv.vartype; + return getRecordType(label); + } + + /** Cache enum values. */ + // TODO: maybe find a better way to get Enum#ordinal()? + private final Map enumOrdinalCache = new HashMap<>(); + public int getEnumOrdinal(JCExpression expr) { + Symbol sym = null; + if (expr instanceof JCIdent) sym = ((JCIdent) expr).sym; + else if (expr instanceof JCFieldAccess) sym = ((JCFieldAccess) expr).sym; + if (!(sym instanceof VarSymbol) || (sym.flags() & Flags.ENUM) == 0) return -1; + + Integer cached = enumOrdinalCache.get(sym); + if (cached != null) return cached; + + int ordinal = 0; + for (Symbol s : sym.owner.getEnclosedElements()) { + if ((s.flags() & Flags.ENUM) == 0) continue; + enumOrdinalCache.put(s, ordinal); + ordinal++; + } + cached = enumOrdinalCache.get(sym); + return cached != null ? cached : -1; + } + + public JCExpression captureSelector(JCExpression sel, ListBuffer prefix, boolean init) { + if (!isComplex(sel)) return sel; + JCVariableDecl v = makeVarDef(switchTempName(), sel.type, init ? sel : null); + prefix.append(v); + return make.QualIdent(v.sym); + } + + private void attachPatternMatchingCatch(JCBlock block) { + if (pendingAccessorCalls == null || !MATCH_EXCEPTION_PRESENT) return; + // TODO: JDK 19 generates proxy methods, but it was also the first preview of record-pattern-matching. + // Should we also generate proxy methods or leave as is, without error handling? + resolveMethodsCache(); + JCVariableDecl ctch = makeVarDef(recordTempName(), syms.throwableType, null, Flags.PARAMETER); + JCCatch handler = make.Catch( + ctch, + make.Block(0, List.of(make.Throw(makeNewClass( + me_ctor, + List.of(makeMethodCall(make.Ident(ctch.sym), object_toString), make.Ident(ctch.sym)) + )))) + ); + attachPatternMatchingCatch(block, handler, pendingAccessorCalls); + pendingAccessorCalls = null; + } + + public List buildCaseBody(JCTree cSwitch, JCCase c, JCExpression sel, boolean expression) { + ListBuffer out = new ListBuffer<>(); + addBindings(c, sel, out); + JCTree body = getBody(c); + if (body instanceof JCBlock) { + for (JCStatement s : ((JCBlock) body).stats) out.append(s); + } else if (body != null) { + JCExpression expr = getExpression(body); + if (expr != null) { + if (expression) { + JCYield y = make.Yield(expr); + y.target = cSwitch; + out.append(y); + } else { + out.append(make.Exec(expr)); + JCBreak b = make.Break(null); + b.target = cSwitch; + out.append(b); + } + } else if (body instanceof JCStatement) { + out.append((JCStatement) body); + } + } else if (c.stats != null) { + for (JCStatement s : c.stats) out.append(s); + } + return out.toList(); + } + + /** Returns the guard expression for a pattern slot (with binding substitution), or null if none. */ + public JCExpression buildGuard(JCTree label, JCExpression sel, JCCase src) { + if (src == null) return null; + JCExpression guard = getGuard(src); + if (guard == null) return null; + Map bindings = collectBindings(label, sel, src); + if (bindings.isEmpty()) return guard; + return new TreeTranslator() { + @Override + public void visitIdent(JCIdent id) { + JCExpression r = bindings.get(id.name); + result = r != null ? r : id; + } + }.translate(guard); + } + + public JCExpression buildLabelCondition(JCTree label, JCExpression sel) { + if (isDefault(label)) return null; + + if (isPattern(label)) { + // sel instanceof Type + JCVariableDecl pv = getPatternVar(label); + if (pv != null) return makeTypeTest(sel, pv.vartype); + JCExpression rt = getRecordType(label); + if (rt != null) return makeTypeTest(sel, rt); + JCExpression pt = getPatternType(label); + if (pt != null) return makeTypeTest(sel, pt); + } + + JCExpression expr = getLabelExpression(label); + if (expr == null) return null; + if (expr instanceof JCLiteral) { + JCLiteral lit = (JCLiteral) expr; + switch (lit.typetag) { + case FLOAT: + // TODO: use Float#floatToIntBits() result as a selector? + case DOUBLE: { + // Boxed.xxxxToxxxBits((Boxed)sel) == Boxed.xxxxToxxxBits(expr) + // TODO: cache sel result + resolveMethodsCache(); + boolean isFloat = lit.typetag == TypeTag.FLOAT; + ClassSymbol type = types.boxedClass(isFloat ? syms.floatType : syms.doubleType); + MethodSymbol method = isFloat ? float_floatToIntBits : double_doubleToLongBits; + return makeBinary( + Tag.EQ, + makeMethodCall(make.QualIdent(type), method, List.of(makeCast(sel, type.type))), + makeMethodCall(make.QualIdent(type), method, List.of(expr)) + ); + } + case CLASS: //$FALL-THROUGH$ + break; + case INT: + // TODO: we lost O(1) if there is only a null case. Let the compiler decide? + // TODO2: process the same ways as JDK 17? (sel == null ? (sel != n ? sel : n+1) : n) + // where the null case is a value not used in labels, instead of -1. + default: + // sel == expr + return makeBinary(Tag.EQ, sel, expr); + } + + } else if (isEnumConstant(expr)) { + // sel == Enum.expr + if (expr instanceof JCIdent) { + JCIdent id = (JCIdent) expr; + if (id.sym != null && id.sym.owner != null && id.sym.owner.type != null) { + JCFieldAccess fa = make.Select(make.Type(id.sym.owner.type), id.name); + fa.sym = id.sym; + fa.type = id.type; + expr = fa; + } + } + return makeBinary(Tag.EQ, sel, expr); + } + + // ((Object)sel).equals(expr) (Implicit null check) + resolveMethodsCache(); + return makeMethodCall(sel, object_equals, List.of(expr)); + } + + /** Injects {@code default: throw new IncompatibleClassChangeError(...)} if no default exists. */ + public List injectDefault(JCExpression sel, List cases) { + if (hasDefault(cases)) return cases; + CaseTree.CaseKind kind = CaseTree.CaseKind.STATEMENT; + for (JCCase c : cases) { + if (c != null) { + kind = c.getCaseKind(); + break; + } + } + + JCStatement throwStmt = makeMatchExceptionThrow(sel, false); + JCStatement body = kind == CaseTree.CaseKind.RULE ? throwStmt : null; + JCCase dc = makeCase(kind, List.of(makeDefaultCaseLabel()), List.of(throwStmt), body); + return dc != null ? cases.append(dc) : cases; + } + + public List getRecordComponentNames(JCTree pattern) { + JCExpression deconstructor = getRecordType(pattern); + if (deconstructor == null) return List.nil(); + Type recordType = deconstructor.type; + if (recordType == null || !(recordType.tsym instanceof ClassSymbol)) return List.nil(); + List result = List.nil(); + for (RecordComponent rc : ((ClassSymbol) recordType.tsym).getRecordComponents()) { + result = result.append(rc.name); + } + return result; + } + + public Map collectBindings(JCTree label, JCExpression sel, JCCase caseTree) { + Map preVars = guardPreVars != null ? guardPreVars.get(caseTree) : null; + Map map = new HashMap<>(); + + JCVariableDecl pv = getPatternVar(label); + if (pv != null) { + map.put(pv.name, makeCast(sel, symTypeOf(pv))); + return map; + } + + JCExpression rt = getRecordType(label); + List nested = getRecordNested(label); + if (rt == null || nested == null) return map; + + Type recType = rt.type != null ? rt.type : syms.objectType; + List componentNames = getRecordComponentNames(label); + + for (JCTree np : nested) { + Name cn = componentNames.head; + componentNames = componentNames.tail; + JCVariableDecl nv = getPatternVar(np); + if (nv == null) continue; + + cn = cn != null ? cn : nv.name; + JCMethodInvocation accessorCall = makeMethodCall(makeCast(sel, recType), cn); + if (pendingAccessorCalls != null) pendingAccessorCalls.add(accessorCall); + + JCExpression val = preVars != null && preVars.containsKey(nv.name) + ? makeAssignParens(preVars.get(nv.name), accessorCall) + : accessorCall; + map.put(nv.name, val); + } + return map; + } + + /** Emits binding variable declarations at the top of a case body. */ + public void addBindings(JCCase caseTree, JCExpression sel, ListBuffer out) { + Map preVars = guardPreVars != null ? guardPreVars.get(caseTree) : null; + + for (JCTree label : getLabels(caseTree)) { + if (!isPattern(label)) continue; + + JCVariableDecl pv = getPatternVar(label); + if (pv != null) { + Type castTy = symTypeOf(pv); + out.append(make.VarDef(pv.sym, makeCast(sel, castTy))); + continue; + } + + JCExpression rt = getRecordType(label); + List nested = getRecordNested(label); + if (rt == null || nested == null) continue; + + // RecordType $record$N = (RecordType) sel; + JCVariableDecl v = makeVarDef(recordTempName(), rt.type, makeCast(sel, rt.type)); + out.append(v); + + if (preVars == null || preVars.isEmpty()) { + ListBuffer bindings = new ListBuffer<>(); + extractRecordBindings(nested, make.Ident(v.sym), getRecordComponentNames(label), bindings); + for (JCVariableDecl decl : bindings) { + if (pendingAccessorCalls != null) { + collectMethodInvocations(decl.init, pendingAccessorCalls); + } + out.append(decl); + } + return; + } + + List componentNames = getRecordComponentNames(label); + for (JCTree np : nested) { + Name name = componentNames.head; + componentNames = componentNames.tail; + if (!isBindingPattern(np)) continue; + JCVariableDecl npv = ((JCBindingPattern) np).var; + if (npv == null) continue; + + VarSymbol preVar = preVars.get(npv.name); + JCExpression init; + if (preVar != null) { + init = make.Ident(preVar); // alias the pre-var assigned in the guard + } else { + name = name != null ? name : npv.name; + JCMethodInvocation call = makeMethodCall(make.Ident(v.sym), name); + if (pendingAccessorCalls != null) pendingAccessorCalls.add(call); + init = call; + } + out.append(make.VarDef(npv.sym, init)); + } + } + } + + /** + * Recursively extract bindings from nested record pattern components. + *

+ * Generates variable declarations like: + * {@code final Point $tmp = parent.b(); final int bx = $tmp.x(); final int by = $tmp.y();} + */ + public void extractRecordBindings( + List nested, JCExpression baseAccessor, List componentNames, + ListBuffer out + ) { + for (JCTree pattern : nested) { + Name name = componentNames.head; + componentNames = componentNames.tail; + + if (isBindingPattern(pattern)) { + JCVariableDecl var = ((JCBindingPattern) pattern).var; + if (var == null) continue; + name = name != null ? name : var.name; + out.append(make.VarDef(var.sym, makeMethodCall(baseAccessor, name))); + continue; + } + + if (!isRecordPattern(pattern)) continue; + JCExpression nestedRecordType = getRecordType(pattern); + List deepNested = getRecordNested(pattern); + if (nestedRecordType == null || deepNested == null || name == null) continue; + + JCVariableDecl tmpDecl = makeVarDef( + recordTempName(), + nestedRecordType.type, + makeMethodCall(baseAccessor, name) + ); + out.append(tmpDecl); + extractRecordBindings( + deepNested, + make.Ident(tmpDecl.sym), + getRecordComponentNames(pattern), + out + ); + } + } + + /** + * Pre-declares variables for record components referenced in a guard.
+ * These are assigned inline in the ternary-chain via {@link #collectBindings}, + * and aliased in the case body via {@link #addBindings}. + */ + public void buildGuardPreDecls(List cases, ListBuffer out) { + guardPreVars = null; + + for (JCCase c : cases) { + JCExpression guard = getGuard(c); + if (guard == null) continue; + + for (JCTree label : getLabels(c)) { + if (getRecordType(label) == null) continue; + List nested = getRecordNested(label); + if (nested == null) continue; + + // Collect names referenced in the guard + Set guardNames = new HashSet<>(); + new TreeScanner() { + @Override + public Void visitIdentifier(IdentifierTree node, Void v) { + if (node instanceof JCIdent) guardNames.add(((JCIdent) node).name); + return null; + } + }.scan(guard, null); + + Map varMap = new HashMap<>(); + if (guardPreVars == null) guardPreVars = new HashMap<>(); + guardPreVars.put(c, varMap); + + for (JCTree np : nested) { + JCVariableDecl pv = getPatternVar(np); + if (pv == null || pv.name == null || pv.vartype == null) continue; + if (!guardNames.contains(pv.name)) continue; + + JCVariableDecl v = makeVarDef( + recordTempName(), + symTypeOf(pv), + makeDefaultValue(pv.vartype), + 0 + ); + varMap.put(pv.name, v.sym); + out.append(v); + } + } + } + } + + // end region + // region making + + private JCMethodInvocation makeMethodCall(JCExpression receiver, Name methodName) { + return makeMethodCall(receiver, findMethod(receiver.type, methodName, List.nil()), List.nil()); + } + + private JCMethodInvocation makeMethodCall(JCExpression receiver, MethodSymbol meth) { + return makeMethodCall(receiver, meth, List.nil()); + } + + private JCMethodInvocation makeMethodCall(JCExpression receiver, MethodSymbol meth, List args) { + JCFieldAccess fa = make.Select(copier.copy(receiver), meth.name); + fa.sym = meth; + fa.type = meth.erasure(types); + JCMethodInvocation call = make.Apply(List.nil(), fa, args); + call.type = types.erasure(meth.getReturnType()); + return call; + } + + private JCNewClass makeNewClass(MethodSymbol ctor, List args) { + JCNewClass tree = make.NewClass(null, null, make.QualIdent(ctor.owner), args, null); + tree.constructor = ctor; + tree.constructorType = ctor.erasure(types); + tree.type = ctor.owner.type; + return tree; + } + + private JCInstanceOf makeTypeTest(JCExpression lhs, JCExpression type) { + JCInstanceOf tree = make.TypeTest(copier.copy(lhs), type); + tree.type = syms.booleanType; + return tree; + } + + /** Only with int types. */ + private JCConditional makeConditional(JCExpression cond, JCExpression thenpart, JCExpression elsepart) { + JCConditional c = make.Conditional(cond, thenpart, elsepart); + c.type = syms.intType; + return c; + } + + private JCBinary makeBinary(JCTree.Tag optag, JCExpression lhs, JCExpression rhs) { + JCBinary tree = make.Binary(optag, copier.copy(lhs), rhs); // only copy left side + tree.operator = resolveBinary(ops, tree, optag, lhs.type, rhs.type); + if (tree.operator != null) tree.type = types.erasure(tree.operator.type.getReturnType()); + return tree; + } + + private JCVariableDecl makeVarDef(Name name, Type type, JCExpression init) { + return makeVarDef(name, type, init, Flags.FINAL); + } + + private JCVariableDecl makeVarDef(Name name, Type type, JCExpression init, long addFlags) { + return make.VarDef(new VarSymbol(Flags.SYNTHETIC | addFlags, name, type, currentMethodSym), init); + } + + private JCExpression makeCast(JCExpression expr, Type type) { + if (expr.type != null && types.isSubtype(expr.type, type)) return expr; + return make.at(expr.pos()).TypeCast(make.Type(type), copier.copy(expr)).setType(type); + } + + private JCParens makeAssignParens(VarSymbol sym, JCExpression expr) { + JCExpression assign = make.Assign(make.Ident(sym), expr).setType(sym.type); + JCExpression parens = make.Parens(assign).setType(assign.type); + return (JCParens) parens; + } + + private JCExpression makeDefaultValue(JCExpression type) { + if (!(type instanceof JCPrimitiveTypeTree)) return makeNull(); + switch (((JCPrimitiveTypeTree) type).getPrimitiveTypeKind()) { + case BOOLEAN: + return make.Literal(false); + case LONG: + return make.Literal(0L); + case FLOAT: + return make.Literal(0.0f); + case DOUBLE: + return make.Literal(0.0d); + default: + return make.Literal(0); + } + } + + private JCLiteral makeNull() { + return make.Literal(TypeTag.BOT, null).setType(syms.botType); + } + + private JCStatement makeMatchExceptionThrow(JCExpression sel, boolean useType) { + resolveMethodsCache(); + return make.Throw(useMatchException + ? makeNewClass(me_ctor, List.of(makeNull(), makeNull())) + : makeNewClass( + icce_ctor, + List.of(makeBinary( + Tag.PLUS, + make.Literal("MatchException: Unhandled case: "), + useType + ? makeMethodCall(makeMethodCall(sel, object_getClass), class_getName) + : sel + )) + ) + ); + } + + private JCCase makeStatementCase(JCTree label, List body) { + // Use an empty label list for JDK < 17 + List labels = label != null ? List.of(label) : List.nil(); + return makeCase(CaseTree.CaseKind.STATEMENT, labels, body, null); + } + + // end region + // region other + + // TODO: need to find a way to get an Env, to use Resolve instead + private MethodSymbol findMethod(Type type, Name name, List paramTypes) { + if (type == null || type.tsym == null) return null; + int pSize = paramTypes.size(); + outer: + for (Symbol s : type.tsym.members().getSymbolsByName(name)) { + if (!(s instanceof MethodSymbol)) continue; + MethodSymbol ms = (MethodSymbol) s; + if (ms.params().size() != pSize) continue; + List n = paramTypes; + for (VarSymbol p : ms.params()) { + if (!types.isSameType(p.type, n.head)) continue outer; + n = n.tail; + } + return ms; + } + //TODO use Log instead? But this should never fail as it's only used to resolve things from Java API, + // and record components, which have already been resolved by the compiler. + System.err.println("[jabel] Failed to resolve " + type + "." + name + "(" + paramTypes + ")"); + return null; + } + + private void resolveMethodsCache() { + if (cacheResolved) return; + cacheResolved = true; + + object_equals = findMethod(syms.objectType, names.equals, List.of(syms.objectType)); + object_toString = findMethod(syms.objectType, names.toString, List.nil()); + object_getClass = findMethod(syms.objectType, names.getClass, List.nil()); + class_getName = findMethod(syms.classType, names.fromString("getName"), List.nil()); + enum_ordinal = findMethod(syms.enumSym.type, names.ordinal, List.nil()); + float_floatToIntBits = findMethod( + types.boxedClass(syms.floatType).type, + names.fromString("floatToIntBits"), + List.of(syms.floatType) + ); + double_doubleToLongBits = findMethod( + types.boxedClass(syms.doubleType).type, + names.fromString("doubleToLongBits"), + List.of(syms.doubleType) + ); + icce_ctor = findMethod(syms.incompatibleClassChangeErrorType, names.init, List.of(syms.stringType)); + if (!MATCH_EXCEPTION_PRESENT) return; + me_ctor = findMethod(syms.matchExceptionType, names.init, List.of(syms.stringType, syms.throwableType)); + // Use MatchException if available in classpath + useMatchException = ((ClassSymbol) syms.matchExceptionType.tsym).classfile != null; + } + + private Type symTypeOf(JCVariableDecl pv) { + if (pv.sym != null && pv.sym.type != null) return pv.sym.type; + if (pv.vartype != null && pv.vartype.type != null) return pv.vartype.type; + return syms.objectType; + } + + private JCExpression assignSwitchSelector(JCExpression cond, VarSymbol selSym, JCExpression rawSel) { + return new TreeTranslator() { + boolean first = true; + @Override + public void visitIdent(JCIdent id) { + result = first && id.sym == selSym ? makeAssignParens(selSym, rawSel) : id; + first = false; + } + }.translate(cond); + } + + /** Recursively collects {@link JCMethodInvocation} nodes from an expression init. */ + private static void collectMethodInvocations(JCTree tree, Set out) { + if (tree instanceof JCMethodInvocation) { + out.add((JCMethodInvocation) tree); + } else if (tree instanceof JCTypeCast) { + collectMethodInvocations(((JCTypeCast) tree).expr, out); + } + } + + private Name switchTempName() { + return names.fromString("$switch$" + tempVarCounter++); + } + + private Name recordTempName() { + return names.fromString("$record$" + tempVarCounter++); + } + + // end region + + /** {@link TreeCopier} that preserves types and symbols. This does not includes declarations. */ + public static class SymTreeCopier extends TreeCopier { + public SymTreeCopier(TreeMaker M) { + super(M); + } + + @Override + public Z copy(Z tree, T p) { + Z fresh = super.copy(tree, p); + if (fresh != null && fresh != tree && tree instanceof JCExpression && fresh instanceof JCExpression) { + ((JCExpression) fresh).type = ((JCExpression) tree).type; + } + return fresh; + } + + @Override + public JCTree visitBinary(BinaryTree node, T p) { + JCTree fresh = super.visitBinary(node, p); + ((JCBinary) fresh).operator = ((JCBinary) node).operator; + return fresh; + } + + @Override + public JCTree visitUnary(UnaryTree node, T p) { + JCTree fresh = super.visitUnary(node, p); + ((JCUnary) fresh).operator = ((JCUnary) node).operator; + return fresh; + } + + @Override + public JCTree visitIdentifier(IdentifierTree node, T p) { + JCTree fresh = super.visitIdentifier(node, p); + ((JCIdent) fresh).sym = ((JCIdent) node).sym; + return fresh; + } + + @Override + public JCTree visitNewClass(NewClassTree node, T p) { + JCTree fresh = super.visitNewClass(node, p); + ((JCNewClass) fresh).constructor = ((JCNewClass) node).constructor; + ((JCNewClass) fresh).constructorType = ((JCNewClass) node).constructorType; + return fresh; + } + + @Override + public JCTree visitMemberSelect(MemberSelectTree node, T p) { + JCTree fresh = super.visitMemberSelect(node, p); + ((JCFieldAccess) fresh).sym = ((JCFieldAccess) node).sym; + return fresh; + } + + @Override + public JCTree visitMemberReference(MemberReferenceTree node, T p) { + JCTree fresh = super.visitMemberReference(node, p); + ((JCMemberReference) fresh).sym = ((JCMemberReference) node).sym; + return fresh; + } + } +}