From b95e71f218541794da71d5a6bc7e89aedf402342 Mon Sep 17 00:00:00 2001 From: ZetaMap Date: Thu, 26 Mar 2026 20:18:59 +0100 Subject: [PATCH 1/3] added syntax support for java 17-25 Additionnal changes: * @Desugar annotation removed as it's completely useless and annoying * ByteBuddy updated to support Java 25 * Fixed bugs with record retrofitting * Added lot of example of usecase (need help include them in README, and to add them properly to tests) * Updated description * Updated java version in build.gradle files * Plugin tested on JDK 8, 9, 12, 16, 17, 19, 21, 25 --- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 4 +- .gitignore | 3 +- README.md | 6 +- example/build.gradle | 5 +- .../com/example/Java10FeaturesExample.java | 44 ++ .../com/example/Java11FeaturesExample.java | 29 + .../com/example/Java14FeaturesExample.java | 123 ++++ .../com/example/Java15FeaturesExample.java | 48 ++ .../com/example/Java16FeaturesExample.java | 112 +++ .../com/example/Java17FeaturesExample.java | 59 ++ .../com/example/Java21FeaturesExample.java | 256 +++++++ .../com/example/Java22FeaturesExample.java | 48 ++ .../com/example/Java25FeaturesExample.java | 100 +++ .../com/example/Java25FeaturesExample2.java | 50 ++ .../com/example/Java9FeaturesExample.java | 77 ++ example/src/main/java/com/example/Main.java | 76 ++ .../main/java/com/example/RecordExample.java | 2 - .../java/com/example/JavaFeaturesTest.java | 14 + .../java/com/example/RecordExampleTest.java | 4 +- gradle/publishing.gradle | 2 +- gradle/wrapper/gradle-wrapper.properties | 2 +- jabel-javac-plugin/build.gradle | 10 +- .../com/github/bsideup/jabel/Desugar.java | 14 - .../FlexibleMainRetrofittingTaskListener.java | 212 ++++++ .../ImplicitClassesFixerTaskListener.java | 80 ++ .../InstanceofRetrofittingTaskListener.java | 144 ++++ .../bsideup/jabel/JabelCompilerPlugin.java | 157 ++-- .../bsideup/jabel/RecordPatternHelper.java | 240 ++++++ .../RecordsRetrofittingTaskListener.java | 678 ++++++++--------- .../jabel/SwitchRetrofittingTaskListener.java | 683 ++++++++++++++++++ 31 files changed, 2794 insertions(+), 490 deletions(-) create mode 100644 example/src/main/java/com/example/Java10FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java11FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java14FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java15FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java16FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java17FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java21FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java22FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java25FeaturesExample.java create mode 100644 example/src/main/java/com/example/Java25FeaturesExample2.java create mode 100644 example/src/main/java/com/example/Java9FeaturesExample.java create mode 100644 example/src/main/java/com/example/Main.java create mode 100644 example/src/test/java/com/example/JavaFeaturesTest.java delete mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java create mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java create mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java create mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java create mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java create mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b53339c..4e25066 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,6 @@ jobs: - name: Set up JDK uses: actions/setup-java@v1 with: - java-version: 16 + java-version: 25 - name: Build with Gradle run: ./gradlew build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1308358..204cefd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,10 +14,10 @@ jobs: - name: Set up JDK uses: actions/setup-java@v1 with: - java-version: 14 + java-version: 25 - name: Deploy with Gradle env: GRADLE_PUBLISH_REPO_URL: ${{ secrets.GRADLE_PUBLISH_REPO_URL }} GRADLE_PUBLISH_MAVEN_USER: ${{ secrets.GRADLE_PUBLISH_MAVEN_USER }} GRADLE_PUBLISH_MAVEN_PASSWORD: ${{ secrets.GRADLE_PUBLISH_MAVEN_PASSWORD }} - run: ./gradlew --no-daemon -Pversion=$(git tag --points-at HEAD) publish \ No newline at end of file + run: ./gradlew --no-daemon -Pversion=$(git tag --points-at HEAD) publish diff --git a/.gitignore b/.gitignore index fd0c699..704574d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ out/ .project .classpath -.settings/ \ No newline at end of file +.settings/ +/bin/ diff --git a/README.md b/README.md index 398073e..e6e35be 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Jabel - use modern Java 9-14 syntax when targeting Java 8 +# Jabel - use modern Java 9-25 syntax when targeting Java 8 > Because life is too short to wait for your users to upgrade their Java! @@ -74,8 +74,8 @@ Jabel has to be enabled as a Javac plugin in your maven-compiler-plugin: 8 - 14 - 14 + 25 + 25 -Xplugin:jabel diff --git a/example/build.gradle b/example/build.gradle index 73d4c1d..a361ac5 100644 --- a/example/build.gradle +++ b/example/build.gradle @@ -3,11 +3,11 @@ plugins { } configure([tasks.compileJava]) { - sourceCompatibility = 16 + sourceCompatibility = 25 options.release = 8 javaCompiler = javaToolchains.compilerFor { - languageVersion = JavaLanguageVersion.of(21) + languageVersion = JavaLanguageVersion.of(25) } } @@ -23,7 +23,6 @@ test { dependencies { annotationProcessor project(":jabel-javac-plugin") - compileOnly project(":jabel-javac-plugin") testImplementation 'junit:junit:4.13.2' } diff --git a/example/src/main/java/com/example/Java10FeaturesExample.java b/example/src/main/java/com/example/Java10FeaturesExample.java new file mode 100644 index 0000000..d749463 --- /dev/null +++ b/example/src/main/java/com/example/Java10FeaturesExample.java @@ -0,0 +1,44 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.ArrayList; +import java.util.HashMap; + +/** + * Examples of Java 10 features desugared by the compiler:
+ * + * LOCAL_VARIABLE_TYPE_INFERENCE (var) + *
+ * // Source (Java 10+):
+ * var str = "hello";
+ * var list = new ArrayList<String>();
+ * for (var i = 0; i < 10; i++) { }
+ * for (var item : list) { }
+ *
+ * // Decompiled (Java 8):
+ * String str = "hello";
+ * ArrayList<String> list = new ArrayList<String>();
+ * for (int i = 0; i < 10; i++) { }
+ * for (String item : list) { }
+ * 
+ */ +public class Java10FeaturesExample { + + void varExamples() { + var str = "hello"; + var num = 42; + var list = new ArrayList(); + var map = new HashMap(); + + list.add("item"); + map.put("key", 1); + + for (var i = 0; i < 3; i++) { + list.add("item" + i); + } + for (var item : list) { + System.out.println(item); + } + } +} diff --git a/example/src/main/java/com/example/Java11FeaturesExample.java b/example/src/main/java/com/example/Java11FeaturesExample.java new file mode 100644 index 0000000..7cf0ffb --- /dev/null +++ b/example/src/main/java/com/example/Java11FeaturesExample.java @@ -0,0 +1,29 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.function.BiFunction; + +/** + * Examples of Java 11 features desugared by the compiler:
+ * + * VAR_SYNTAX_IMPLICIT_LAMBDAS + *
+ * // Source (Java 11+):
+ * BiFunction<String, String, String> f = (var a, var b) -> a + b;
+ * // Allows annotations:
+ * Consumer<String> c = (@NonNull var s) -> System.out.println(s);
+ *
+ * // Decompiled (Java 8):
+ * BiFunction<String, String, String> f = (a, b) -> a + b;
+ * // or with explicit types:
+ * BiFunction<String, String, String> f = (String a, String b) -> a + b;
+ * 
+ */ +public class Java11FeaturesExample { + + BiFunction concat = (var a, var b) -> a + b; + + BiFunction add = + (var x, var y) -> x + y; +} diff --git a/example/src/main/java/com/example/Java14FeaturesExample.java b/example/src/main/java/com/example/Java14FeaturesExample.java new file mode 100644 index 0000000..c477b73 --- /dev/null +++ b/example/src/main/java/com/example/Java14FeaturesExample.java @@ -0,0 +1,123 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 14 features desugared by the compiler:
+ * + * SWITCH_MULTIPLE_CASE_LABELS + *
+ * // Source (Java 14+):
+ * case SAT, SUN -> "weekend";
+ *
+ * // Decompiled (Java 8):
+ * case SAT:
+ * case SUN:
+ *     return "weekend";
+ * 
+ *

+ * SWITCH_RULE (arrow syntax) + *

+ * // Source (Java 14+):
+ * case MON -> "monday";
+ *
+ * // Decompiled (Java 8):
+ * case MON:
+ *     return "monday";
+ * 
+ *

+ * SWITCH_EXPRESSION + *

+ * // Source (Java 14+):
+ * String result = switch (day) {
+ *     case MON -> "monday";
+ *     default -> { yield "other"; }
+ * };
+ *
+ * // Decompiled (Java 8):
+ * String result;
+ * switch (day) {
+ *     case MON:
+ *         result = "monday";
+ *         break;
+ *     default:
+ *         result = "other";
+ * }
+ * 
+ */ +public class Java14FeaturesExample { + + enum Day { MON, TUE, WED, THU, FRI, SAT, SUN } + + class Builder{ + public int n; + public Builder build(Object o){ + n++; + return this; + } + } + + String switchMultipleLabels(Day day) { + return switch (day) { + case SAT, SUN -> "weekend"; + case MON, TUE, WED, THU, FRI -> "weekday"; + }; + } + + String switchRule(Day day) { + return (switch (day) { + case MON -> "monday"; + case TUE -> "tuesday"; + default -> "other"; + }).toLowerCase(); + } + + int switchExpressionWithYield(Day day) { + return switch (day) { + case MON -> 1; + case TUE -> 2; + default -> { + int result = day.ordinal(); + yield result; + } + }; + } + + void statementSwitchWithArrow(Day day) { + switch (day) { + case SAT -> System.out.println("Saturday"); + case SUN -> System.out.println("Sunday"); + default -> System.out.println("Weekday"); + } + } + + int[] asArrayElements(int a, int b) { + return new int[]{ + switch (a) { case 1 -> 10; default -> 0; }, + switch (b) { case 2 -> 20; default -> 0; } + }; + } + + String chained(String str) { + return switch ( + switch (str) { + case "test": { + str = "pok"; + yield "A"; + } + default: yield "B"; + } + ) { + case "A" -> str.isEmpty() ? "got-A" : "err"; + default -> "got-B"; + }; + } + + int twiceOf(int n) { + return new Builder().build(1).build(switch (n) { + case 1 -> 10; + case 2 -> 20; + default -> n * n; + }).build(4).n; + } +} diff --git a/example/src/main/java/com/example/Java15FeaturesExample.java b/example/src/main/java/com/example/Java15FeaturesExample.java new file mode 100644 index 0000000..46be666 --- /dev/null +++ b/example/src/main/java/com/example/Java15FeaturesExample.java @@ -0,0 +1,48 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 15 features desugared by the compiler:
+ * + * TEXT_BLOCKS + *
+ * // Source (Java 15+):
+ * String s = """
+ *     Hello
+ *     World
+ *     """;
+ * String cont = """
+ *     Single \
+ *     line""";
+ *
+ * // Decompiled (Java 8):
+ * String s = "Hello\nWorld\n";
+ * String cont = "Single line";
+ * 
+ */ +public class Java15FeaturesExample { + + String basic = """ + Hello + World + """; + + String json = """ + {"name": "test", "value": 42} + """; + + String withTrailingSpace = """ + Line with space \s + Next line + """; + + String lineContinuation = """ + Single \ + line"""; + + String formatted = String.format(""" + Name: %s + Age: %d + """, "Alice", 25); +} diff --git a/example/src/main/java/com/example/Java16FeaturesExample.java b/example/src/main/java/com/example/Java16FeaturesExample.java new file mode 100644 index 0000000..4ec031d --- /dev/null +++ b/example/src/main/java/com/example/Java16FeaturesExample.java @@ -0,0 +1,112 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.List; +import java.util.Objects; + +/** + * Examples of Java 16 features desugared by the compiler:
+ * + * PATTERN_MATCHING_IN_INSTANCEOF + *
+ * // Source (Java 16+):
+ * if (obj instanceof String s) {
+ *     System.out.println(s.length());
+ * }
+ *
+ * // Decompiled (Java 8):
+ * if (obj instanceof String) {
+ *     String s = (String) obj;
+ *     System.out.println(s.length());
+ * }
+ * 
+ *

+ * RECORDS + *

+ * // Source (Java 16+):
+ * record Point(int x, int y) { }
+ *
+ * // Decompiled (Java 8):
+ * class Point {
+ *     private final int x;
+ *     private final int y;
+ *     Point(int x, int y) { this.x = x; this.y = y; }
+ *     int x() { return x; }
+ *     int y() { return y; }
+ *     public boolean equals(Object o) { ... }
+ *     public int hashCode() { ... }
+ *     public String toString() { ... }
+ * }
+ * 
+ */ +public class Java16FeaturesExample { + record Pair(A first, B second) {} + + record Point(int x, int y) { + public Point(int x, int y) { + if (x > Short.MAX_VALUE) throw new IllegalArgumentException(); + this.x = x; + if (y > Short.MAX_VALUE) throw new IllegalArgumentException(); + this.y = y; + } + } + + record Person(String name, int age) { + static Object planet; + + Person { + Objects.requireNonNull(name); + if (age < 0) throw new IllegalArgumentException(); + } + } + + record MultipleTypes(boolean bool, byte b, short s, int i, long l, + float f, double d, String str, Point p) { + @Override + public boolean equals(java.lang.Object o) { + return this == o; + } + + @Override + public String str() { + return "nothing"; + } + } + + void patternMatchingInstanceof(Object obj) { + if (obj instanceof String s) { + System.out.println(s.length()); + } + + if (obj instanceof String s && s.length() > 3) { + System.out.println(s.toUpperCase()); + } + + int len = obj instanceof String s ? s.length() : -1; + Integer.valueOf(len); // use variable to get expected byte code + } + + int inlinePatternMatching(Object obj) { + return obj instanceof String s ? s.length() : obj instanceof Integer i ? i : -1; + } + + void reifiableTypesInstanceof(Object obj) { + if (!(obj instanceof List list)) { + System.out.println("Not a list"); + return; + } + System.out.println(list.size()); + } + + void recordsExample() { + var p1 = new Point(3, 4); + var p2 = new Point(3, 4); + + int x = p1.x(); + int y = p1.y(); + boolean equal = p1.equals(p2); + int hash = p1.hashCode(); + String str = p1.toString(); + } +} diff --git a/example/src/main/java/com/example/Java17FeaturesExample.java b/example/src/main/java/com/example/Java17FeaturesExample.java new file mode 100644 index 0000000..1c1c55b --- /dev/null +++ b/example/src/main/java/com/example/Java17FeaturesExample.java @@ -0,0 +1,59 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** + * Examples of Java 17 features desugared by the compiler:
+ * + * SEALED_CLASSES + *
+ * // Source (Java 17+):
+ * public sealed class Shape permits Circle, Square { }
+ * public final class Circle extends Shape { }
+ * public non-sealed class Square extends Shape { }
+ *
+ * // Decompiled (Java 8):
+ * public abstract class Shape { }  // sealed, permits removed
+ * public final class Circle extends Shape { }
+ * public class Square extends Shape { }  // non-sealed removed
+ * // Note: PermittedSubclasses bytecode attribute must be removed
+ * 
+ *

+ * REDUNDANT_STRICTFP + *

+ * // Source (Java 17+):
+ * public strictfp class Math { }
+ *
+ * // Decompiled (Java 8): same code (strictfp kept but redundant since Java 17)
+ * 
+ */ +public class Java17FeaturesExample { + + sealed class Shape permits Circle, Square, Rectangle { + private final String name; + Shape(String name) { this.name = name; } + String getName() { return name; } + } + + final class Circle extends Shape { + Circle() { super("Circle"); } + void test(){ + Square.test(); + } + } + + final class Square extends Shape { + Square() { super("Square"); } + static void test() {} + } + + non-sealed class Rectangle extends Shape { + Rectangle() { super("Rectangle"); } + } + + class SpecialRectangle extends Rectangle {} + + strictfp double strictMethod(double a, double b) { + return a * b + a / b; + } +} diff --git a/example/src/main/java/com/example/Java21FeaturesExample.java b/example/src/main/java/com/example/Java21FeaturesExample.java new file mode 100644 index 0000000..cc29c8a --- /dev/null +++ b/example/src/main/java/com/example/Java21FeaturesExample.java @@ -0,0 +1,256 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + + +/** + * Examples of Java 21 features with manual desugaring:
+ * + * CASE_NULL + *
+ * // Source (Java 21+):
+ * switch (str) {
+ *     case null -> "null";
+ *     case "a" -> "A";
+ *     default -> "other";
+ * }
+ *
+ * // Decompiled (Java 8):
+ * if (str == null) {
+ *     return "null";
+ * }
+ * switch (str) {
+ *     case "a": return "A";
+ *     default: return "other";
+ * }
+ * 
+ *

+ * PATTERN_SWITCH + *

+ * // Source (Java 21+):
+ * switch (obj) {
+ *     case String s -> s.length();
+ *     case Integer i -> i;
+ *     default -> 0;
+ * }
+ *
+ * // Decompiled (Java 8):
+ * if (obj instanceof String) {
+ *     String s = (String) obj;
+ *     return s.length();
+ * } else if (obj instanceof Integer) {
+ *     Integer i = (Integer) obj;
+ *     return i;
+ * } else {
+ *     return 0;
+ * }
+ * 
+ *

+ * RECORD_PATTERNS + *

+ * // Source (Java 21+):
+ * if (obj instanceof Point(int x, int y)) {
+ *     return x + y;
+ * }
+ *
+ * // Decompiled (Java 8):
+ * if (obj instanceof Point) {
+ *     Point p = (Point) obj;
+ *     int x = p.x();
+ *     int y = p.y();
+ *     return x + y;
+ * }
+ * 
+ *

+ * UNCONDITIONAL_PATTERN_IN_INSTANCEOF + *

+ * // Source (Java 21+):
+ * if (str instanceof CharSequence cs) { }  // always true for non-null
+ *
+ * // Decompiled (Java 8):
+ * if (str != null) {
+ *     CharSequence cs = str;
+ * }
+ * 
+ */ +public class Java21FeaturesExample { + + sealed interface Geometry permits Point, Triangle, Shape{} + record Point(int x, int y) implements Geometry {} + record Triangle(Point a, Point b, Point c) implements Geometry {} + record Shape(String name, Triangle triangle) implements Geometry {} + + class Builder{ + public int n; + public Builder build(Object o){ + n++; + return this; + } + } + + String caseNull(String input) { + return switch (input) { + case "hello" -> "hello"; + case null -> "null"; + default -> "other"; + }; + } + + String caseNullWithDefault(String input) { + return switch (input.toString()) { + case "specific" -> "specific"; + case null, default -> "null or default"; + }; + } + + String patternSwitch(Object obj) { + return switch (obj) { + case String s -> "String: " + s; + case Integer i -> "Integer: " + i; + default -> throw new IllegalArgumentException(obj.toString()); + }; + } + + String patternSwitchVariable(Object obj) { + String var = "ignored part:" + switch (obj) { + case String s -> var = "String: " + s; + case null -> var = "null"; + case Integer i -> var = "Integer: " + i; + default -> var = "Other"; + }; + var.toString(); // fake a use to avoid optimizations + return var; + } + + String patternSwitchWithGuard(Object obj) { + return (switch (obj) { + case String s when s.isEmpty() -> "empty"; + case String s when s.length() < 5 -> "short"; + case Integer i -> "int"; + case String s when s.length() > 3 && s.startsWith("a") -> "long-a"; + case String s -> "long"; + default -> "not a string"; + }).trim(); + } + + String patternSwitchExpressionWithGuard(Object obj) { + String var; + switch (obj) { + case String s when s.isEmpty() -> var = "empty"; + case String s when s.length() < 5 -> var = "short"; + case Integer i -> var = "int"; + case String s -> var = "long"; + default -> var = "not a string"; + } + var.toString(); // fake a use to avoid optimizations + return var; + } + + String patternSwitchStatement(Object obj) { + String result; + switch (obj) { + case String s: + result = "String: " + s; + break; + case Integer i: + result = "Int: " + i; + break; + default: + result = "Other"; + break; + } + return result; + } + + int switchYield(Object obj) { + return switch (obj) { + case String s -> { + int len = s.length(); + yield len * 2; + } + case Integer i -> { + yield i + 1; + } + default -> 0; + }; + } + + String nestedSwitch(Object outer, Object inner) { + return switch (outer) { + case String s -> switch (inner) { + case Integer i -> "String+Int: " + s + i; + default -> switch (inner) { + case Float f -> "String+Float: " + s + f; + default -> "String+Other"; + }; + }; + case Integer i -> "int"; + default -> "Other"; + }; + } + + int inMethod(Object obj) { + return new Builder().build(1).build(switch (obj) { + case null -> -1; + case String s when s.length() > 5 -> s.length()-5; + case String s -> s.length(); + case Integer i -> i; + default -> obj.hashCode(); + }).build(4).n; + } + + void switchInIf(Object obj) { + if (switch (obj) { + case null -> -1; + case String s when s.length() > 5 -> s.length()-5; + case String s -> s.length(); + case Integer i -> i; + default -> obj.hashCode(); + } > 0) { + System.out.println("Working"); + } + } + + // Record patterns in instanceof + void recordPatternInstanceof(Object obj) { + if (obj instanceof Point(int x, int y)) { + System.out.println(x + y); + } + } + + void nestedRecordPattern(Object obj) { + if (obj instanceof Shape(var name, Triangle(var p, Point(int bx, int by), Point(int cx, int cy)))) { + System.out.println("Shape " + name + " has triangle with vertices: " + + "(" + p.x() + "," + p.y() + "), (" + bx + "," + by + "), (" + cx + "," + cy + ")"); + } + } + + int recordPatternSwitchStatement(Geometry obj) { + return switch (obj) { + case Point(int x, int y): yield x + y; + case Triangle(Point p, Point(int bx, int by), Point(int cx, int cy)): { + int sum = p.x() * p.y(); + sum += bx * by; + sum += cx * cy; + yield sum; + } + case Shape s: yield 0; + // No default case since class is sealed. + }; + } + + void unconditionalPattern(String str) { + // Unconditional pattern - String is always a CharSequence (for non-null) + if (str instanceof CharSequence cs) { + System.out.println(cs.length()); + } + } + + // Unconditional with inferred type + void unconditionalPatternObject(T str) { + if (str instanceof Shape o) { + System.out.println(o.hashCode()); + } + } +} + diff --git a/example/src/main/java/com/example/Java22FeaturesExample.java b/example/src/main/java/com/example/Java22FeaturesExample.java new file mode 100644 index 0000000..ff11c7c --- /dev/null +++ b/example/src/main/java/com/example/Java22FeaturesExample.java @@ -0,0 +1,48 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.util.Arrays; +import java.util.List; + +/** + * Examples of Java 22 features desugared by the compiler:
+ * + * UNNAMED_VARIABLES + *
+ * // Source (Java 22+):
+ * for (var _ : list) { count++; }
+ * try { } catch (Exception _) { }
+ * list.forEach(_ -> count++);
+ *
+ * // Decompiled (Java 8):
+ * for (Object $unused : list) { count++; }
+ * try { } catch (Exception $unused) { }
+ * list.forEach($unused -> count++);
+ * 
+ */ +public class Java22FeaturesExample { + + void unnamedInForEach() { + var list = Arrays.asList("a", "b", "c"); + int count = 0; + + for (var _ : list) { + count++; + } + } + + void unnamedInCatch() { + try { + throw new RuntimeException(); + } catch (RuntimeException _) { + System.out.println("caught"); + } + } + + void unnamedInLambda() { + var list = Arrays.asList("a", "b", "c"); + var count = new int[]{0}; + list.forEach(_ -> count[0]++); + } +} diff --git a/example/src/main/java/com/example/Java25FeaturesExample.java b/example/src/main/java/com/example/Java25FeaturesExample.java new file mode 100644 index 0000000..3c64375 --- /dev/null +++ b/example/src/main/java/com/example/Java25FeaturesExample.java @@ -0,0 +1,100 @@ +// Examples made by Claude Opus 4.6 + +package com.example; + +/** + * Examples of Java 25 flexible constructors with manual desugaring:
+ * + * FLEXIBLE_CONSTRUCTORS + *
+ * // Source (Java 25+):
+ * class Child extends Parent {
+ *     Child(int value) {
+ *         if (value < 0) throw new IllegalArgumentException();
+ *         int processed = value * 2;
+ *         super(processed);
+ *     }
+ * }
+ *
+ * // Decompiled (Java 8):
+ * class Child extends Parent {
+ *     Child(int value) {
+ *         super($prologue$0(value));
+ *     }
+ *     private static int $prologue$0(int value) {
+ *         if (value < 0) throw new IllegalArgumentException();
+ *         return value * 2;
+ *     }
+ * }
+ * 
+ */ +public class Java25FeaturesExample { + String str; + + void main() { + System.out.println("Bridged entry point"); + str = ""; + } + + static class Parent { + final int value; + Parent(int value) { this.value = value; } + } + + // Complex loop in prologue + static class ComplexPrologue extends Parent { + ComplexPrologue(int n) { + if (n < 0) throw new IllegalArgumentException(); + int sum = 0; + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + sum += i; + } + } + super(sum > 100 ? 100 : sum); + } + } + + // Epilogue reuses multiple prologue variables + static class ManyLocals extends Parent { + final String label; + final double ratio; + ManyLocals(int a, int b) { + String lbl = a + "/" + b; + double r = b != 0 ? (double) a / b : 0.0; + int sum = a + b; + super(sum); + this.label = lbl; + this.ratio = r; + } + } + + // Existing Object[] / (Object[], Void) constructors must not clash + static class SpoofChild extends Parent { + boolean spoofed; + + SpoofChild(int value) { + super(value); + } + + // Spoof: (Object[]) + SpoofChild(Object[] args) { + boolean s = false; + if (args.length > 0) s = true; + this(0); + spoofed = s; + } + + // Spoof: (Object[], Void) + SpoofChild(Object[] args, Void v) { + this((int)(Integer) args[0]); + this.spoofed = true; + } + + // Another flex in the same spoofed class + SpoofChild(String s) { + int parsed = Integer.parseInt(s); + super(parsed); + } + } +} \ No newline at end of file diff --git a/example/src/main/java/com/example/Java25FeaturesExample2.java b/example/src/main/java/com/example/Java25FeaturesExample2.java new file mode 100644 index 0000000..a419644 --- /dev/null +++ b/example/src/main/java/com/example/Java25FeaturesExample2.java @@ -0,0 +1,50 @@ +/** + * Examples of Java 25 features with manual desugaring:
+ + * IMPLICIT_CLASSES + *
+ * // Source (Java 25+):
+ * // File: Main.java (no class declaration)
+ * void main() {
+ *     System.out.println("Hello");
+ * }
+ *
+ * // Decompiled (Java 8):
+ * public class Main {
+ *     public static void main(String[] args) {
+ *         new Main().main();
+ *     }
+ *     void main() {
+ *         System.out.println("Hello");
+ *     }
+ * }
+ * 
+ */ + +int variable = 1; + +static void main() { + System.out/*IO*/.println("Implicit classes work!"); + new ArrayList<>(); // Implicitly imported +} + +// Cannot be adapted as a standard entry point +void main(String... args) { + System.out/*IO*/.println("instance main with args."); + new Test(); +} + +private void privateMethod() { + main(); +} + + +interface TestInterface { + +} + +class Test implements TestInterface { + public Test() { + Integer.toString(variable); + } +} diff --git a/example/src/main/java/com/example/Java9FeaturesExample.java b/example/src/main/java/com/example/Java9FeaturesExample.java new file mode 100644 index 0000000..1cbf816 --- /dev/null +++ b/example/src/main/java/com/example/Java9FeaturesExample.java @@ -0,0 +1,77 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +import java.io.Closeable; +import java.util.ArrayList; +import java.util.List; + +/** + * Examples of Java 9 features desugared by the compiler:
+ * + * PRIVATE_SAFE_VARARGS + *
+ * // Source (Java 9+):
+ * @SafeVarargs
+ * private void method(List<String>... lists) { }
+ *
+ * // Decompiled (Java 8): same code, annotation preserved
+ * 
+ *

+ * DIAMOND_WITH_ANONYMOUS_CLASS_CREATION + *

+ * // Source (Java 9+):
+ * List<String> list = new ArrayList<>() { };
+ *
+ * // Decompiled (Java 8):
+ * List<String> list = new ArrayList<String>() { };
+ * 
+ *

+ * EFFECTIVELY_FINAL_VARIABLES_IN_TRY_WITH_RESOURCES + *

+ * // Source (Java 9+):
+ * Closeable c = ...;
+ * try (c) { }
+ *
+ * // Decompiled (Java 8):
+ * Closeable c = ...;
+ * try (Closeable c2 = c) { }
+ * 
+ *

+ * PRIVATE_INTERFACE_METHODS + *

+ * // Source (Java 9+):
+ * interface I {
+ *     private String helper() { return "!"; }
+ *     default String greet() { return "Hi" + helper(); }
+ * }
+ *
+ * // Decompiled (Java 8): same code (private methods in interfaces supported in bytecode)
+ * 
+ */ +public class Java9FeaturesExample { + + @SafeVarargs + private final void safeVarargsMethod(List... lists) { + for (var list : lists) System.out.println(list); + } + + List diamondWithAnonymous = new ArrayList<>() { + @Override + public boolean add(String s) { + return super.add(s.toUpperCase()); + } + }; + + void effectivelyFinalTryWithResources() throws Exception { + Closeable resource = () -> System.out.println("closed"); + try (resource) { + System.out.println("using resource"); + } + } + + interface WithPrivateMethods { + private String helper() { return "!"; } + default String greet(String name) { return "Hello " + name + helper(); } + } +} diff --git a/example/src/main/java/com/example/Main.java b/example/src/main/java/com/example/Main.java new file mode 100644 index 0000000..400fe80 --- /dev/null +++ b/example/src/main/java/com/example/Main.java @@ -0,0 +1,76 @@ +// Examples made by Claude Opus 4.5 + +package com.example; + +/** Main class to demonstrate all Java 9-25 features desugared by Jabel. */ +class Main { + public static void main(String[] args) { + System.out.println("=== Jabel Feature Examples ==="); + + System.out.println("\n--- Java 9 Features ---"); + Java9FeaturesExample java9 = new Java9FeaturesExample(); + java9.diamondWithAnonymous.add("test"); + System.out.println("Diamond with anonymous: " + java9.diamondWithAnonymous); + Java9FeaturesExample.WithPrivateMethods impl = new Java9FeaturesExample.WithPrivateMethods() {}; + System.out.println("Private interface method: " + impl.greet("World")); + + System.out.println("\n--- Java 10 Features ---"); + Java10FeaturesExample java10 = new Java10FeaturesExample(); + java10.varExamples(); + System.out.println("var keyword works!"); + + System.out.println("\n--- Java 11 Features ---"); + Java11FeaturesExample java11 = new Java11FeaturesExample(); + System.out.println("var in lambda: " + java11.concat.apply("hello", "world")); + + System.out.println("\n--- Java 14 Features ---"); + var java14 = new Java14FeaturesExample(); + System.out.println("Switch expression: " + java14.switchMultipleLabels(Java14FeaturesExample.Day.SAT)); + System.out.println("Switch with yield: " + java14.switchExpressionWithYield(Java14FeaturesExample.Day.WED)); + + System.out.println("\n--- Java 15 Features ---"); + var java15 = new Java15FeaturesExample(); + System.out.println("Text block: " + java15.basic.replace("\n", "\\n")); + + System.out.println("\n--- Java 16 Features ---"); + var java16 = new Java16FeaturesExample(); + java16.patternMatchingInstanceof("Hello"); + var point = new Java16FeaturesExample.Point(3, 4); + System.out.println("Record: " + point); + + System.out.println("\n--- Java 17 Features ---"); + var java17 = new Java17FeaturesExample(); + System.out.println("sealed classes works!"); + System.out.println("strictfp method: " + java17.strictMethod(3.14, 2.71)); + + System.out.println("\n--- Java 21 Features ---"); + var java21 = new Java21FeaturesExample(); + System.out.println("case null: " + java21.caseNull(null)); + System.out.println("pattern switch: " + java21.patternSwitch("test")); + System.out.println("record pattern: " + java21.recordPatternSwitchStatement(new Java21FeaturesExample.Point(5, 7))); + System.out.println("switch yield: " + java21.switchYield("hello")); + + System.out.println("\n--- Java 22 Features ---"); + var java22 = new Java22FeaturesExample(); + java22.unnamedInForEach(); + System.out.println("Unnamed variables work!"); + + System.out.println("\n--- Java 25 Features ---"); + // Inline: complex prologue + var dp = new Java25FeaturesExample.ComplexPrologue(10); + assert dp.value == 20 : "ComplexPrologue"; // 0+2+4+6+8 = 20 + System.out.println("ComplexPrologue: value=" + dp.value); + // General: multiple shared vars in epilogue + var ms = new Java25FeaturesExample.ManyLocals(10, 4); + assert ms.value == 14 && "10/4".equals(ms.label) && ms.ratio == 2.5 : "ManyLocals"; + System.out.println("ManyLocals: value=" + ms.value + " label=" + ms.label + " ratio=" + ms.ratio); + // Edge: tests constructor collisions + var sp = new Java25FeaturesExample.SpoofChild(99); + assert sp.value == 99 && !sp.spoofed : "SpoofChild"; + System.out.println("SpoofChild: value=" + sp.value + " spoofed=" + sp.spoofed); + //Implicit classes + //Java25FeaturesExample2.main(); // in theory we cannot reference the class + + System.out.println("\n=== All features work! ==="); + } +} diff --git a/example/src/main/java/com/example/RecordExample.java b/example/src/main/java/com/example/RecordExample.java index 3c36f36..f8be6dc 100644 --- a/example/src/main/java/com/example/RecordExample.java +++ b/example/src/main/java/com/example/RecordExample.java @@ -1,8 +1,6 @@ package com.example; -import com.github.bsideup.jabel.Desugar; -@Desugar public record RecordExample(int i, String s, long l, float f, double d, String[] arr, boolean b) { public static RecordExample DUMMY = new RecordExample(0, null, 0, 0, 0, null, true); diff --git a/example/src/test/java/com/example/JavaFeaturesTest.java b/example/src/test/java/com/example/JavaFeaturesTest.java new file mode 100644 index 0000000..f4bf334 --- /dev/null +++ b/example/src/test/java/com/example/JavaFeaturesTest.java @@ -0,0 +1,14 @@ +package com.example; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public class JavaFeaturesTest { + + @Test + public void shouldWork() { + Main.main(new String[0]); + //TODO: Java25FeaturesExample2.main(new String[0]); + } +} \ No newline at end of file diff --git a/example/src/test/java/com/example/RecordExampleTest.java b/example/src/test/java/com/example/RecordExampleTest.java index 6e01cfb..5d2f290 100644 --- a/example/src/test/java/com/example/RecordExampleTest.java +++ b/example/src/test/java/com/example/RecordExampleTest.java @@ -21,7 +21,7 @@ public void testToString() { RecordExample r = new RecordExample(42, "yeah", 100500, 0.5f, 5d, new String[]{"Hello", "World!"}, true); assertEquals( - "RecordExample[i=42,s=yeah,l=100500,f=0.5,d=5.0,arr=[Ljava.lang.String;@hash,b=true]", + "RecordExample[i=42, s=yeah, l=100500, f=0.5, d=5.0, arr=[Ljava.lang.String;@hash, b=true]", Objects.toString(r).replaceAll(";@[a-f0-9]+", ";@hash") ); } @@ -31,7 +31,7 @@ public void testToStringWithNulls() { RecordExample r = new RecordExample(42, null, 100500, 0.5f, 5d, null, true); assertEquals( - "RecordExample[i=42,s=null,l=100500,f=0.5,d=5.0,arr=null,b=true]", + "RecordExample[i=42, s=null, l=100500, f=0.5, d=5.0, arr=null, b=true]", Objects.toString(r) ); } diff --git a/gradle/publishing.gradle b/gradle/publishing.gradle index ffc3b3b..cb1e056 100644 --- a/gradle/publishing.gradle +++ b/gradle/publishing.gradle @@ -3,7 +3,7 @@ plugins.withType(MavenPublishPlugin) { publications { mavenJava(MavenPublication) { publication -> pom { - description = 'Jabel - use modern Java 9-14 syntax when targeting Java 8.' + description = 'Jabel - use modern Java 9-25 syntax when targeting Java 8.' name = project.description ?: description url = 'https://github.com/bsideup/jabel' licenses { diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 8049c68..03b32a2 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.4-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/jabel-javac-plugin/build.gradle b/jabel-javac-plugin/build.gradle index 1725fc5..30de699 100644 --- a/jabel-javac-plugin/build.gradle +++ b/jabel-javac-plugin/build.gradle @@ -6,15 +6,15 @@ plugins { sourceCompatibility = targetCompatibility = 8 dependencies { - implementation platform('net.bytebuddy:byte-buddy-parent:1.14.9') + implementation platform('net.bytebuddy:byte-buddy-parent:1.18.7') implementation 'net.bytebuddy:byte-buddy' implementation 'net.bytebuddy:byte-buddy-agent' - implementation 'net.java.dev.jna:jna:5.13.0' + implementation 'net.java.dev.jna:jna:5.18.0' + implementation 'net.java.dev.jna:jna-platform:5.18.0' } - task sourcesJar(type: Jar) { - classifier 'sources' + archiveClassifier = 'sources' from sourceSets.main.allJava } @@ -24,7 +24,7 @@ javadoc { task javadocJar(type: Jar) { from javadoc - classifier = 'javadoc' + archiveClassifier = 'javadoc' } publishing { diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java deleted file mode 100644 index 1a11531..0000000 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/Desugar.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.github.bsideup.jabel; - -import java.lang.annotation.Documented; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -import static java.lang.annotation.ElementType.TYPE; - -@Documented -@Retention(RetentionPolicy.SOURCE) -@Target(value=TYPE) -public @interface Desugar { -} diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java new file mode 100644 index 0000000..a57e47d --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java @@ -0,0 +1,212 @@ +package com.github.bsideup.jabel; + +import java.util.*; + +import javax.tools.*; + +import com.sun.source.util.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; +import com.sun.tools.javac.util.List; + + +/** + * Will adapts flexible entry points + * ({@code void main();, void main(String[]);, static void main();, + * static void main(String[])}) by attempting to create a bridge. + *

+ * This is pretty limited as non-static main entry points with arguments cannot + * be adapted, due to a signature duplication with the static one. + *

+ * Moreover, classes containing multiple entry points with ones that are + * private, will not work properly on older JVMs.
+ * For example this code: + * + *

{@code
+ * class Main {
+ *     static void main() {
+ *     }
+ *
+ *     private static void main(String[] args) {
+ *     }
+ * }
+ * }
+ * + * The first main method will be selected by Java 25's JVM, because the standard + * one is private, whereas on an older JVM, an error will be displayed.
+ * And Jabel cannot create a bridge for the first main, because there would be a + * signature duplication with the second declared main. + */ +public class FlexibleMainRetrofittingTaskListener implements TaskListener { + final TreeMaker make; + final Names names; + final Symtab syms; + final Log log; + final JCDiagnostic.Factory diagFactory; + final Name mainName; + + FlexibleMainRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + names = Names.instance(context); + syms = Symtab.instance(context); + log = Log.instance(context); + diagFactory = JCDiagnostic.Factory.instance(context); + mainName = names.fromString("main"); // syms.main; + + // Proper way to make warnings + JavacMessages.instance(context).add(locale -> new ResourceBundle() { + final Map keys = new HashMap<>(2); + { + // Act like it's linting + keys.put( + "jabel.warn.possible.signature.duplication", + "[jabel] possible entry point cannot be adapted " + + "due to a signature duplication. " + + "''{0}'' cannot therefore be used as an entry point in a JVM bellow Java25." + ); + keys.put( + "jabel.warn.no.default.constructor.found", + "[jabel] possible entry point cannot be adapted " + + "because no instanciable default constructor was found. " + + "''{0}'' cannot therefore be used as an entry point in a JVM bellow Java25." + ); + } + + @Override + protected Object handleGetObject(String key) { + return keys.get(key); + } + + @Override + public Enumeration getKeys() { + return Collections.enumeration(keys.keySet()); + } + }); + } + + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + JCCompilationUnit jcu = (JCCompilationUnit) e.getCompilationUnit(); + JavaFileObject old = log.useSource(jcu.sourcefile); + + for (JCTree def : jcu.defs) { + if (!(def instanceof JCClassDecl)) continue; + transformClass((JCClassDecl) def); + } + + log.useSource(old); + } + + @Override + public void finished(TaskEvent e) { + } + + public void transformClass(JCClassDecl classDecl) { + make.at(classDecl.pos); + + // Slots: [0]=static+args, [1]=static, [2]=instance+args, [3]=instance + JCMethodDecl[] mains = new JCMethodDecl[4]; + boolean canInstanciate = false, hasConstructor = false; + int ep = mains.length; + + for (JCTree def : classDecl.defs) { + if (!(def instanceof JCMethodDecl)) continue; + JCMethodDecl method = (JCMethodDecl) def; + + if (!isMain(method)) { + if (method.name == names.init) { + hasConstructor = true; + if (!isPrivate(method) && method.params.isEmpty()) { + canInstanciate = true; + } + } + continue; + } + + int slot = (isStatic(method) ? 0 : 2) | (method.params.isEmpty() ? 1 : 0); + if (mains[slot] != null) continue; + mains[slot] = method; + if (slot < ep && !isPrivate(method)) ep = slot; + } + + // To avoid a signature duplication + if (mains[0] == null && mains[2] != null) ep = 2; + + switch (ep) { + case 1: + addMainBridge(classDecl, false); + break; + case 2: + warn(mains[ep], "possible.signature.duplication", classDecl.name); + break; + case 3: + if (canInstanciate || !hasConstructor) addMainBridge(classDecl, true); + else warn(mains[ep], "no.default.constructor.found", classDecl.name); + } + } + + public boolean isPrivate(JCMethodDecl m) { + return (m.mods.flags & Flags.PRIVATE) != 0; + } + + public boolean isStatic(JCMethodDecl m) { + return (m.mods.flags & Flags.STATIC) != 0; + } + + private void warn(JCMethodDecl method, String key, Object arg) { + log.report(diagFactory.create(log.currentSource(), method, new JCDiagnostic.Warning("jabel", key, arg))); + } + + public boolean isMain(JCMethodDecl method) { + if (mainName != method.name) return false; + if (!(method.restype instanceof JCPrimitiveTypeTree)) return false; + if (((JCPrimitiveTypeTree) method.restype).typetag != TypeTag.VOID) return false; + if (method.params.isEmpty()) return true; + if (method.params.size() != 1) return false; + JCTree vartype = method.params.get(0).vartype; + if (!(vartype instanceof JCArrayTypeTree)) return false; + String elem = ((JCArrayTypeTree) vartype).elemtype.toString(); + return elem.equals("String") || elem.endsWith(".String"); + } + + /** If {@code toLocalMain} is {@code true}, a zero-arg constructor must be present. */ + public JCMethodDecl makeMainBridge(JCClassDecl classDecl, boolean toLocalMain) { + return make.MethodDef( + make.Modifiers(Flags.PUBLIC | Flags.STATIC), + mainName, + make.TypeIdent(TypeTag.VOID), + List.nil(), + List.of(make.VarDef( + make.Modifiers(Flags.PARAMETER), + names.fromString("args"), + make.TypeArray(make.Type(syms.stringType)), + null + )), + List.nil(), + make.Block(0, List.of(make.Exec(make.Apply( + List.nil(), + make.Select( + toLocalMain ? make.NewClass( + null, + List.nil(), + make.Ident(classDecl.name), + List.nil(), + null + ) : make.Ident(classDecl.name), + mainName + ), + List.nil() + )))), + null + ); + } + + /** Creates a main bridge in the specified class. */ + public void addMainBridge(JCClassDecl classDecl, boolean toLocalMain) { + classDecl.defs = classDecl.defs.append(makeMainBridge(classDecl, toLocalMain)); + } +} \ No newline at end of file diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java new file mode 100644 index 0000000..41eee1b --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/ImplicitClassesFixerTaskListener.java @@ -0,0 +1,80 @@ +package com.github.bsideup.jabel; + +import com.sun.source.util.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; + + +/** + * Because implicit classes import the {@code java.base} module, and modules are + * not yet supported, we need to convert it by adding star import of all + * exported and existing packages. + */ +public class ImplicitClassesFixerTaskListener implements TaskListener { + final TreeMaker make; + final Names names; + final Symtab syms; + List javaBaseImports; + + ImplicitClassesFixerTaskListener(Context context) { + make = TreeMaker.instance(context); + names = Names.instance(context); + syms = Symtab.instance(context); + } + + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + JCCompilationUnit jcu = (JCCompilationUnit) e.getCompilationUnit(); + + for (JCTree def : jcu.defs) { + if (!(def instanceof JCClassDecl)) continue; + JCClassDecl clazz = (JCClassDecl) def; + if ((clazz.mods.flags & Flags.IMPLICIT_CLASS) == 0) continue; + clazz.mods.flags &= ~Flags.IMPLICIT_CLASS; // Avoid implicit importation of java.base + injectJavaBaseImports(jcu); + } + } + + @Override + public void finished(TaskEvent e) { + } + + @SuppressWarnings("unchecked") + public void injectJavaBaseImports(JCCompilationUnit jcu) { + if (javaBaseImports == null) javaBaseImports = makeStarImports(getJavaBasePackages()); + // Duplicate imports are not important + jcu.defs = ((List) (List) javaBaseImports).appendList(jcu.defs); + } + + /** @return the list of packages that are exported by the {@code java.base} modules. */ + public List getJavaBasePackages() { + // Since modules are not enabled and initialized, {@code syms.java_base.exports} is not populated + return ModuleLayer.boot().findModule("java.base").map(mod -> + mod.getPackages().stream() + .filter(mod::isExported) + .map(p -> syms.enterPackage(syms.java_base, names.fromString(p))) + .filter(p -> { + try { + p.complete(); + } catch(Exception ignored) {} + return p.exists(); + }) + .collect(List.collector()) + ).orElse(List.nil()); + } + + public List makeStarImports(List packages) { + List imports = List.nil(); + for (Symbol.PackageSymbol pkg : packages) { + imports = imports.append(make.Import(make.Select( + make.QualIdent(pkg), + names.asterisk + ), false)); + } + return imports; + } +} \ No newline at end of file diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java new file mode 100644 index 0000000..44403e1 --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java @@ -0,0 +1,144 @@ +package com.github.bsideup.jabel; + +import com.sun.source.util.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; + +import static com.github.bsideup.jabel.RecordPatternHelper.*; + + +/** + * Transforms record patterns and binding patterns in instanceof (Java 16+/21+) + * into Java 8 compatible bytecode. + *

+ * Handles: + *

    + *
  • Binding patterns (if obj instanceof String s).
  • + *
  • Record patterns (if obj instanceof Point(int x, int y)).
  • + *
  • Nested record patterns (if obj instanceof Line(Point(int x1, int y1), + * Point end)).
  • + *
  • Unconditional patterns (if str instanceof CharSequence cs).
  • + *
+ */ +public class InstanceofRetrofittingTaskListener implements TaskListener { + final RecordPatternHelper helper; + final TreeMaker make; + + public InstanceofRetrofittingTaskListener(Context context) { + helper = new RecordPatternHelper(context); + make = TreeMaker.instance(context); + } + + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + new InstanceofTranslator().translate((JCCompilationUnit) e.getCompilationUnit()); + } + + @Override + public void finished(TaskEvent e) { + } + + public class InstanceofTranslator extends TreeTranslator { + @Override + public T translate(T tree) { + if (tree == null) return null; + helper.collectRecord(tree); + + if (tree instanceof JCIf) { + JCIf ifStmt = (JCIf) tree; + JCExpression cond = unwrapParenthesis(ifStmt.cond); + if (!(cond instanceof JCInstanceOf)) { + return super.translate(tree); + } + + JCInstanceOf instanceOf = (JCInstanceOf) cond; + JCTree pattern = instanceOf.pattern; + if (pattern == null) { + return super.translate(tree); + } + + if (isRecordPattern(pattern)) { + transformRecordPattern(ifStmt, instanceOf, pattern); + } else if (isBindingPattern(pattern)) { + transformBindingPattern(ifStmt, instanceOf, pattern); + } + } + + return super.translate(tree); + } + } + + public JCExpression unwrapParenthesis(JCExpression expr) { + while (expr instanceof JCParens) { + expr = ((JCParens) expr).expr; + } + return expr; + } + + /** + * Transform: {@code if (obj instanceof Point(int x, int y)) { body } }
+ * Into: + * {@code if (obj instanceof Point) { Point $rec = (Point)obj; int x = $rec.x(); ... body } } + */ + public void transformRecordPattern(JCIf ifStmt, JCInstanceOf instanceOf, JCTree pattern) { + JCExpression recordType = getRecordType(pattern); + if (recordType == null) return; + + make.at(ifStmt.pos); + Name tempVar = helper.tempName(); + + ListBuffer declarations = new ListBuffer<>(); + declarations.append(helper.makeVarDef( + tempVar, + recordType, + helper.makeCast(recordType, instanceOf.expr) + )); + helper.extractRecordBindings( + getRecordNested(pattern), + make.Ident(tempVar), + helper.getRecordComponentNames(pattern), + declarations + ); + + instanceOf.pattern = recordType; + ifStmt.thenpart = buildBlock(declarations.toList(), ifStmt.thenpart); + } + + /** + * Transform: {@code if (obj instanceof String s) { body } }
+ * Into: {@code if (obj instanceof String) { String s = (String)obj; body } } + */ + public void transformBindingPattern(JCIf ifStmt, JCInstanceOf instanceOf, JCTree pattern) { + JCVariableDecl var = ((JCBindingPattern) pattern).var; + if (var == null || var.vartype == null) return; + + make.at(ifStmt.pos); + instanceOf.pattern = helper.copy(var.vartype); + + ListBuffer declarations = new ListBuffer<>(); + declarations.append(helper.makeVarDef( + var.name, + var.vartype, + helper.makeCast(var.vartype, instanceOf.expr) + )); + ifStmt.thenpart = buildBlock(declarations.toList(), ifStmt.thenpart); + } + + public JCBlock buildBlock(List declarations, JCStatement body) { + ListBuffer stmts = new ListBuffer<>(); + for (JCVariableDecl decl : declarations) { + stmts.append(decl); + } + if (body instanceof JCBlock) { + for (JCStatement stmt : ((JCBlock) body).stats) { + stmts.append(stmt); + } + } else if (body != null) { + stmts.append(body); + } + return make.Block(0, stmts.toList()); + } +} \ No newline at end of file diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java index 0a5aeb1..7bf1be1 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java @@ -4,8 +4,8 @@ import com.sun.source.util.Plugin; import com.sun.tools.javac.api.BasicJavacTask; import com.sun.tools.javac.code.Source; -import com.sun.tools.javac.util.Context; -import com.sun.tools.javac.util.JavacMessages; +import com.sun.tools.javac.util.*; + import net.bytebuddy.ByteBuddy; import net.bytebuddy.agent.ByteBuddyAgent; import net.bytebuddy.asm.Advice; @@ -29,34 +29,42 @@ import static net.bytebuddy.matcher.ElementMatchers.*; + public class JabelCompilerPlugin implements Plugin { static { - Map visitors = new HashMap() {{ - // Disable the preview feature check - AsmVisitorWrapper checkSourceLevelAdvice = Advice.to(CheckSourceLevelAdvice.class) - .on(named("checkSourceLevel").and(takesArguments(2))); - - // Allow features that were introduced together with Records (local enums, static inner members, ...) - AsmVisitorWrapper allowRecordsEraFeaturesAdvice = new FieldAccessStub("allowRecords", true); - - put("com.sun.tools.javac.parser.JavacParser", - new AsmVisitorWrapper.Compound( - checkSourceLevelAdvice, - allowRecordsEraFeaturesAdvice - ) - ); - put("com.sun.tools.javac.parser.JavaTokenizer", checkSourceLevelAdvice); - - put("com.sun.tools.javac.comp.Check", allowRecordsEraFeaturesAdvice); - put("com.sun.tools.javac.comp.Attr", allowRecordsEraFeaturesAdvice); - put("com.sun.tools.javac.comp.Resolve", allowRecordsEraFeaturesAdvice); + boolean c = false; + try { + Class.forName("com.sun.tools.javac.code.Source$Feature"); + c = true; + } catch (Exception e) {} + final boolean canPatchSources = c; - // Lower the source requirement for supported features - put( - "com.sun.tools.javac.code.Source$Feature", - Advice.to(AllowedInSourceAdvice.class) - .on(named("allowedInSource").and(takesArguments(1))) - ); + Map visitors = new HashMap() {{ + if (canPatchSources) { + // Disable the preview feature check + AsmVisitorWrapper checkSourceLevelAdvice = Advice.to(CheckSourceLevelAdvice.class) + .on(named("checkSourceLevel").and(takesArguments(2))); + + // Allow features that were introduced together with Records (local enums, static inner members, ...) + AsmVisitorWrapper allowRecordsEraFeaturesAdvice = new FieldAccessStub("allowRecords", true); + + put("com.sun.tools.javac.parser.JavacParser", + new AsmVisitorWrapper.Compound( + checkSourceLevelAdvice, + allowRecordsEraFeaturesAdvice + ) + ); + put("com.sun.tools.javac.parser.JavaTokenizer", checkSourceLevelAdvice); + + put("com.sun.tools.javac.comp.Check", allowRecordsEraFeaturesAdvice); + put("com.sun.tools.javac.comp.Attr", allowRecordsEraFeaturesAdvice); + put("com.sun.tools.javac.comp.Resolve", allowRecordsEraFeaturesAdvice); + + // Lower the source requirement for supported features + AsmVisitorWrapper allowedInSourceAdvice = Advice.to(AllowedInSourceAdvice.class) + .on(named("allowedInSource").and(takesArguments(1))); + put("com.sun.tools.javac.code.Source$Feature", allowedInSourceAdvice); + } }}; try { @@ -92,41 +100,40 @@ public class JabelCompilerPlugin implements Plugin { .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); }); - JavaModule jabelModule = JavaModule.ofType(JabelCompilerPlugin.class); - ClassInjector.UsingInstrumentation.redefineModule( - ByteBuddyAgent.getInstrumentation(), - JavaModule.ofType(JavacTask.class), - Collections.emptySet(), - Collections.emptyMap(), - new HashMap>() {{ - put("com.sun.tools.javac.api", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.tree", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.code", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.util", Collections.singleton(jabelModule)); - }}, - Collections.emptySet(), - Collections.emptyMap() - ); + try { + JavaModule jabelModule = JavaModule.ofType(JabelCompilerPlugin.class); + ClassInjector.UsingInstrumentation.redefineModule( + ByteBuddyAgent.getInstrumentation(), + JavaModule.ofType(JavacTask.class), + Collections.emptySet(), + Collections.emptyMap(), + new HashMap>() {{ + put("com.sun.tools.javac.api", Collections.singleton(jabelModule)); + put("com.sun.tools.javac.tree", Collections.singleton(jabelModule)); + put("com.sun.tools.javac.code", Collections.singleton(jabelModule)); + put("com.sun.tools.javac.comp", Collections.singleton(jabelModule)); + put("com.sun.tools.javac.util", Collections.singleton(jabelModule)); + }}, + Collections.emptySet(), + Collections.emptyMap() + ); + // In case of we are running on Java 8 + } catch (NullPointerException ignored) {} } @Override public void init(JavacTask task, String... args) { Context context = ((BasicJavacTask) task).getContext(); - JavacMessages.instance(context).add(locale -> new ResourceBundle() { - @Override - protected Object handleGetObject(String key) { - return "{0}"; - } - - @Override - public Enumeration getKeys() { - return Collections.enumeration(Arrays.asList("missing.desugar.on.record")); - } - }); + removeUnderscoreWarnings(context); task.addTaskListener(new RecordsRetrofittingTaskListener(context)); - - System.out.println("Jabel: initialized"); + task.addTaskListener(new InstanceofRetrofittingTaskListener(context)); + task.addTaskListener(new SwitchRetrofittingTaskListener(context)); + try { + task.addTaskListener(new FlexibleMainRetrofittingTaskListener(context)); + // Because JCDiagnostic.Warning doesn't exists on Java 8. But we don't care at this point + } catch (NoClassDefFoundError ignored) {} + task.addTaskListener(new ImplicitClassesFixerTaskListener(context)); } @Override @@ -134,40 +141,46 @@ public String getName() { return "jabel"; } - // Make it auto start on Java 14+ + /** Make it auto starts on Java 14+. */ public boolean autoStart() { return true; } - static class AllowedInSourceAdvice { + /** Removes warnings about {@code '_'}. */ + private static void removeUnderscoreWarnings(Context context) { + // Need to inherit a class instead. + // This is due to DeferredDiagnosticHandler(Predicate) being DeferredDiagnosticHandler(Filter) on Java 16- + Log.instance(context).new DiscardDiagnosticHandler() { + @Override + public void report(JCDiagnostic diag) { + String code = diag.getCode(); + if (code.contains("underscore.as.identifier") || + code.contains("use.of.underscore.not.allowed")) return; + prev.report(diag); + } + }; + } + static class AllowedInSourceAdvice { @Advice.OnMethodEnter static void allowedInSource( @Advice.This Source.Feature feature, @Advice.Argument(value = 0, readOnly = false) Source source ) { switch (feature.name()) { - case "PRIVATE_SAFE_VARARGS": - case "SWITCH_EXPRESSION": - case "SWITCH_RULE": - case "SWITCH_MULTIPLE_CASE_LABELS": - case "LOCAL_VARIABLE_TYPE_INFERENCE": - case "VAR_SYNTAX_IMPLICIT_LAMBDAS": - case "DIAMOND_WITH_ANONYMOUS_CLASS_CREATION": - case "EFFECTIVELY_FINAL_VARIABLES_IN_TRY_WITH_RESOURCES": - case "TEXT_BLOCKS": - case "PATTERN_MATCHING_IN_INSTANCEOF": - case "REIFIABLE_TYPES_INSTANCEOF": - case "RECORDS": + case "MODULES": // Extremely difficult as initialization is done very early + case "STRING_TEMPLATES": // Appeared on Java 21 and removed on Java 23 because of a confusing design + case "MODULE_IMPORTS": // Needs the modules system + case "JAVA_BASE_TRANSITIVE": // Needs the modules system + break; + default: //noinspection UnusedAssignment source = Source.DEFAULT; - break; } } } static class CheckSourceLevelAdvice { - @Advice.OnMethodEnter static void checkSourceLevel( @Advice.Argument(value = 1, readOnly = false) Source.Feature feature @@ -181,9 +194,7 @@ static void checkSourceLevel( } private static class FieldAccessStub extends AsmVisitorWrapper.AbstractBase { - final String fieldName; - final Object value; public FieldAccessStub(String fieldName, Object value) { diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java new file mode 100644 index 0000000..35df9fb --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java @@ -0,0 +1,240 @@ +package com.github.bsideup.jabel; + +import java.util.*; + +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; +import com.sun.tools.javac.util.List; + + +/** Utilities for record pattern extraction and tree manipulation. */ +class RecordPatternHelper { + static String getClassName(Object obj) { + return obj == null ? "" : obj.getClass().getSimpleName(); + } + + /** Set the arrow-style body of a case tree. No-op if the field doesn't exist. */ + static void setCaseBody(JCCase caseTree, JCTree body) { + try { + caseTree.body = body; + } catch (NoSuchFieldError ignored) {} + } + + static boolean isBindingPattern(JCTree pattern) { + return pattern != null && getClassName(pattern).equals("JCBindingPattern"); + } + + static boolean isRecordPattern(JCTree pattern) { + return pattern != null && getClassName(pattern).equals("JCRecordPattern"); + } + + /** Check if this tree node is a pattern (binding, record, case label, or any). */ + static boolean isPattern(JCTree label) { + if (label == null) return false; + String name = getClassName(label); + return name.contains("Pattern") || name.contains("Binding"); + } + + /** + * Because {@link JCPatternCaseLabel#pat} type is {@link JCPattern}, we need to + * delay access to an inner class.
+ * Like that, the class initialization error can be catched easily. + */ + private static final class PatternCaseLabelAccess { + static JCTree pat(JCTree label) { + return ((JCPatternCaseLabel) label).pat; + } + } + + /** Get the binding variable from a pattern (binding pattern or pattern case label). */ + static JCVariableDecl getPatternVar(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCBindingPattern": + return ((JCBindingPattern) label).var; + case "JCPatternCaseLabel": + return getPatternVar(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + /** + * Get the type from a pattern label (for instanceof check). + * Works for binding patterns, pattern case labels, and unnamed patterns. + */ + static JCExpression getPatternType(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCBindingPattern": + JCVariableDecl var = ((JCBindingPattern) label).var; + return var != null ? var.vartype : null; + case "JCPatternCaseLabel": + return getPatternType(PatternCaseLabelAccess.pat(label)); + case "JCAnyPattern": + return getAnyPatternType(label); + default: + return null; + } + } + + /** + * Get the type from a {@code JCAnyPattern}. + *

+ * Reflection is needed because APIs varies across JDK versions: + * {@link JCTree#type} and {@link JCTree#getType()} differ between Java 22 and 23+. + */ + static JCExpression getAnyPatternType(JCTree label) { + // TODO: remake + try { + Object type = label.getClass().getField("type").get(label); + if (type instanceof JCExpression) { + return (JCExpression) type; + } + } catch (Exception ignored) {} + try { + Object type = label.getClass().getMethod("getType").invoke(label); + if (type instanceof JCExpression) { + return (JCExpression) type; + } + } catch (Exception ignored) {} + return null; + } + + /** Get the record type (deconstructor) from a record pattern or pattern case label. */ + static JCExpression getRecordType(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCRecordPattern": + return ((JCRecordPattern) label).deconstructor; + case "JCPatternCaseLabel": + return getRecordType(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + /** Get nested patterns from a record pattern or pattern case label. */ + static List getRecordNested(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCRecordPattern": + return ((JCRecordPattern) label).nested; + case "JCPatternCaseLabel": + return getRecordNested(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + /// + + final TreeMaker make; + final Names names; + /** Cache of record declarations for component name lookup. */ + final Map records = new HashMap<>(); + private static int tempVarCounter = 0; + + RecordPatternHelper(Context context) { + make = TreeMaker.instance(context); + names = Names.instance(context); + } + + /** + * Collect a record declaration for later component name lookup. + * Call during tree traversal. + */ + void collectRecord(JCTree tree) { + if (!(tree instanceof JCClassDecl)) return; + JCClassDecl classDecl = (JCClassDecl) tree; + if (!"RECORD".equals(classDecl.getKind().toString())) return; + records.put(classDecl.name.toString(), classDecl); + } + + /** Get record component names from cached record declarations or fallback to binding names. */ + List getRecordComponentNames(JCTree pattern) { + JCExpression deconstructor = getRecordType(pattern); + if (deconstructor == null) return List.nil(); + + // Try to find record declaration + JCClassDecl recordDecl = records.get(deconstructor.toString()); + if (recordDecl != null) { + List result = List.nil(); + for (JCTree def : recordDecl.defs) { + if (!(def instanceof JCVariableDecl)) continue; + JCVariableDecl varDecl = (JCVariableDecl) def; + if ((varDecl.mods.flags & Flags.RECORD) == 0) continue; + result = result.append(varDecl.name); + } + if (!result.isEmpty()) return result; + } + + // Fallback: use binding pattern names + List nested = getRecordNested(pattern); + if (nested == null) return List.nil(); + + List result = List.nil(); + for (JCTree p : nested) { + JCVariableDecl var = getPatternVar(p); + result = result.append(var != null ? var.name : null); + } + return result; + } + + /** + * Recursively extract bindings from nested record pattern components. + *

+ * Generates variable declarations like: + * {@code final Point $tmp = parent.b(); final int bx = $tmp.x(); final int by = $tmp.y();} + */ + void extractRecordBindings( + List nested, JCExpression baseAccessor, + List componentNames, ListBuffer out + ) { + int index = 0; + for (JCTree nestedPattern : nested) { + Name componentName = index < componentNames.size() ? componentNames.get(index++) : null; + + if (isBindingPattern(nestedPattern)) { + JCVariableDecl var = ((JCBindingPattern) nestedPattern).var; + if (var == null) continue; + + Name accessorName = componentName != null ? componentName : var.name; + out.append(makeVarDef(var.name, var.vartype, makeMethodCall(baseAccessor, accessorName))); + continue; + } + + if (!isRecordPattern(nestedPattern)) continue; + JCExpression nestedRecordType = getRecordType(nestedPattern); + List deepNested = getRecordNested(nestedPattern); + if (nestedRecordType == null || deepNested == null || componentName == null) continue; + + Name tempVar = tempName(); + JCExpression accessor = makeMethodCall(baseAccessor, componentName); + out.append(makeVarDef(tempVar, nestedRecordType, accessor)); + extractRecordBindings(deepNested, make.Ident(tempVar), getRecordComponentNames(nestedPattern), out); + } + } + + Name tempName() { + return names.fromString("$record$" + (tempVarCounter++)); + } + + JCVariableDecl makeVarDef(Name name, JCExpression type, JCExpression init) { + return make.VarDef(make.Modifiers(Flags.FINAL), name, copy(type), init); + } + + JCTypeCast makeCast(JCExpression type, JCExpression expr) { + return make.TypeCast(copy(type), copy(expr)); + } + + JCMethodInvocation makeMethodCall(JCExpression receiver, Name method) { + return make.Apply(List.nil(), make.Select(copy(receiver), method), List.nil()); + } + + JCExpression copy(JCExpression expr) { + return new TreeCopier(make).copy(expr); + } +} diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java index cbadd4a..1fa0e9f 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java @@ -1,191 +1,163 @@ package com.github.bsideup.jabel; -import com.sun.source.tree.ClassTree; -import com.sun.source.tree.CompilationUnitTree; -import com.sun.source.util.TaskEvent; -import com.sun.source.util.TaskListener; +import java.util.Iterator; +import java.util.stream.*; + +import javax.lang.model.element.Modifier; + +import com.sun.source.tree.*; +import com.sun.source.util.*; import com.sun.source.util.TreeScanner; import com.sun.tools.javac.code.*; -import com.sun.tools.javac.tree.JCTree; -import com.sun.tools.javac.tree.TreeMaker; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; import com.sun.tools.javac.util.*; -import javax.lang.model.element.Modifier; -import javax.tools.JavaFileObject; -import java.util.Iterator; -import java.util.stream.Stream; - -class RecordsRetrofittingTaskListener implements TaskListener { +/** + * Will generate {@code hashCode()}, {@code equals()} and {@code toString()} + * methods, and remove {@link Flags#RECORD}. + */ +public class RecordsRetrofittingTaskListener implements TaskListener { final TreeMaker make; - final Symtab syms; - + final Types types; final Names names; - final Log log; - - TreeScanner recordsScanner = new TreeScanner() { - @Override - public Void visitClass(ClassTree node, Void aVoid) { - if (!"RECORD".equals(node.getKind().toString())) { - return super.visitClass(node, aVoid); - } + public RecordsRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + syms = Symtab.instance(context); + types = Types.instance(context); + names = Names.instance(context); + } - JCTree.JCClassDecl classDecl = (JCTree.JCClassDecl) node; + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + new RecordsScanner().scan(e.getCompilationUnit(), false); + } - if (classDecl.extending == null) { - // Prevent implicit "extends java.lang.Record" - classDecl.extending = make.Type(syms.objectType); - } + /** Remove {@link Flags#RECORD} to avoid invalid ASM reading. */ + @Override + public void finished(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ANALYZE) return; + new RecordsScanner().scan(e.getCompilationUnit(), true); + } - { - Name methodName = names.toString; - List argTypes = List.nil(); - if (!containsMethod(classDecl, methodName)) { - JCTree.JCMethodDecl methodDecl = make.MethodDef( - new Symbol.MethodSymbol( - Flags.PUBLIC, - methodName, - new Type.MethodType( - argTypes, - syms.stringType, - List.nil(), - syms.methodClass - ), - syms.objectType.tsym - ), - make.Block(0, generateToString(classDecl)) - ); - classDecl.defs = classDecl.defs.append(methodDecl); - } - } + public class RecordsScanner extends TreeScanner { + @Override + public Void visitClass(ClassTree node, Boolean endPhase) { + if ("RECORD".equals(node.getKind().toString())) { + JCClassDecl classDecl = (JCClassDecl) node; - { - Name methodName = names.hashCode; - List argTypes = List.nil(); - if (!containsMethod(classDecl, methodName)) { - classDecl.defs = classDecl.defs.append(make.MethodDef( - new Symbol.MethodSymbol( - Flags.PUBLIC, - methodName, - new Type.MethodType( - argTypes, - syms.intType, - List.nil(), - syms.methodClass - ), - syms.objectType.tsym - ), - make.Block(0, generateHashCode(classDecl)) - )); - } - } + if (endPhase != null && endPhase) { + if (classDecl.sym != null) { + classDecl.sym.flags_field &= ~Flags.RECORD; + } - { - Name methodName = names.equals; - List argTypes = List.of(syms.objectType); - if (!containsMethod(classDecl, methodName)) { - Symbol.MethodSymbol methodSymbol = new Symbol.MethodSymbol( - Flags.PUBLIC | Flags.FINAL, - methodName, - new Type.MethodType( - argTypes, - syms.booleanType, - List.nil(), - syms.methodClass - ), - syms.objectType.tsym - ); - Symbol.VarSymbol firstParameter = methodSymbol.params().head; - - JCTree.JCMethodDecl methodDecl = make.MethodDef( - methodSymbol, - make.Block(0, generateEquals(classDecl, firstParameter.name)) - ); - // THIS ONE IS IMPORTANT! Otherwise, Flow.AssignAnalyzer#visitVarDef will have track=false - methodDecl.params.head.pos = classDecl.pos; - classDecl.defs = classDecl.defs.append(methodDecl); + } else { + if (classDecl.extending == null) { + // Prevent implicit "extends java.lang.Record" + classDecl.extending = make.Type(syms.objectType); + } + generateToStringIfNeeded(classDecl); + generateHashcodeIfNeeded(classDecl); + generateEqualsIfNeeded(classDecl); } } - return super.visitClass(node, aVoid); + return super.visitClass(node, endPhase); } + } - private boolean containsMethod(JCTree.JCClassDecl classDecl, Name name) { - return classDecl.defs.stream() - .filter(JCTree.JCMethodDecl.class::isInstance) - .map(JCTree.JCMethodDecl.class::cast) - .anyMatch(def -> { - if (def.getName() != name) { - return false; - } - - if (name == names.equals) { - if (def.params.size() != 1) { - return false; - } - - // TODO find a better way? - JCTree.JCVariableDecl param = def.params.get(0); - switch (param.getType().toString()) { - case "java.lang.Object": - case "Object": - return true; - default: - return false; - } - } - - return true; - }); - } - }; + public void generateToStringIfNeeded(JCClassDecl classDecl) { + if (containsMethod(classDecl, names.toString)) return; + classDecl.defs = classDecl.defs.append(make.MethodDef( + new Symbol.MethodSymbol( + Flags.PUBLIC | Flags.FINAL, + names.toString, + new Type.MethodType( + List.nil(), + syms.stringType, + List.nil(), + syms.methodClass + ), + syms.objectType.tsym + ), + make.Block(0, generateToString(classDecl)) + )); + } - public RecordsRetrofittingTaskListener(Context context) { - make = TreeMaker.instance(context); - syms = Symtab.instance(context); - names = Names.instance(context); - log = Log.instance(context); + public void generateHashcodeIfNeeded(JCClassDecl classDecl) { + if (containsMethod(classDecl, names.hashCode)) return; + classDecl.defs = classDecl.defs.append(make.MethodDef( + new Symbol.MethodSymbol( + Flags.PUBLIC | Flags.FINAL, + names.hashCode, + new Type.MethodType( + List.nil(), + syms.intType, + List.nil(), + syms.methodClass + ), + syms.objectType.tsym + ), + make.Block(0, generateHashCode(classDecl)) + )); } - @Override - public void started(TaskEvent e) { - switch (e.getKind()) { - case ENTER: - recordsScanner.scan(e.getCompilationUnit(), null); - new TreeScanner() { - @Override - public Void visitClass(ClassTree node, Void aVoid) { - if ("RECORD".equals(node.getKind().toString())) { - JCTree.JCClassDecl classDecl = (JCTree.JCClassDecl) node; - - if (classDecl.extending == null) { - // Prevent implicit "extends java.lang.Record" - classDecl.extending = make.Type(syms.objectType); - } - } - return super.visitClass(node, aVoid); - } - }.scan(e.getCompilationUnit(), null); - break; - case ANALYZE: - new MandatoryDesugarAnnotationTreeScanner(log, e.getCompilationUnit()).scan(e.getCompilationUnit(), null); - } + public void generateEqualsIfNeeded(JCClassDecl classDecl) { + if (containsMethod(classDecl, names.equals)) return; + Symbol.MethodSymbol methodSymbol = new Symbol.MethodSymbol( + Flags.PUBLIC | Flags.FINAL, + names.equals, + new Type.MethodType( + List.of(syms.objectType), + syms.booleanType, + List.nil(), + syms.methodClass + ), + syms.objectType.tsym + ); + JCTree.JCMethodDecl methodDecl = make.MethodDef( + methodSymbol, + make.Block(0, generateEquals( + classDecl, + methodSymbol.params().head.name + )) + ); + + // THIS ONE IS IMPORTANT! Otherwise, Flow.AssignAnalyzer#visitVarDef will have track=false + methodDecl.params.head.pos = classDecl.pos; + classDecl.defs = classDecl.defs.append(methodDecl); } - @Override - public void finished(TaskEvent e) { + /** Can only search for a method with no or one argument. */ + private boolean containsMethod(JCClassDecl classDecl, Name name) { + for (JCTree next : classDecl.defs) { + if (!(next instanceof JCMethodDecl)) continue; + JCMethodDecl def = (JCMethodDecl) next; + if (def.getName() != name) continue; + if (name != names.equals) return true; + if (def.params.size() != 1) continue; + // TODO find a better way? + switch(def.params.get(0).getType().toString()){ + case "java.lang.Object": + case "Object": + return true; + } + } + return false; } - private Stream getRecordComponents(JCTree.JCClassDecl classDecl) { + public Stream getRecordComponents(JCClassDecl classDecl) { return classDecl.getMembers().stream() - .filter(JCTree.JCVariableDecl.class::isInstance) - .map(JCTree.JCVariableDecl.class::cast) - .filter(it -> !it.getModifiers().getFlags().contains(Modifier.STATIC)); + .filter(JCVariableDecl.class::isInstance) + .map(JCVariableDecl.class::cast) + .filter(it -> !it.getModifiers().getFlags().contains(Modifier.STATIC)); } - private List generateToString(JCTree.JCClassDecl classDecl) { - JCTree.JCExpression stringBuilder = make.NewClass( + public List generateToString(JCClassDecl classDecl) { + JCExpression stringBuilder = make.NewClass( null, null, make.QualIdent(syms.stringBuilderType.tsym), @@ -194,302 +166,244 @@ private List generateToString(JCTree.JCClassDecl classDecl) ); for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); + Iterator iterator = getRecordComponents(classDecl).iterator(); iterator.hasNext(); ) { - JCTree.JCVariableDecl fieldDecl = iterator.next(); + JCVariableDecl fieldDecl = iterator.next(); Name fieldName = fieldDecl.name; stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), + make.Select(stringBuilder, names.append) + .setType(syms.stringBuilderType), List.of(make.Literal(fieldName + "=")) ); stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), - List.of( - make.Select( - make.This(Type.noType), - fieldName - ) - ) + make.Select(stringBuilder, names.append) + .setType(syms.stringBuilderType), + List.of(make.Select(make.This(Type.noType), fieldName)) ); - if (iterator.hasNext()) { - stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), - List.of(make.Literal(",")) - ); - } + if (!iterator.hasNext()) break; + stringBuilder = make.App( + make.Select(stringBuilder, names.append) + .setType(syms.stringBuilderType), + List.of(make.Literal(", ")) + ); } stringBuilder = make.App( - make.Select(stringBuilder, names.append).setType(syms.stringBuilderType), + make.Select(stringBuilder, names.append) + .setType(syms.stringBuilderType), List.of(make.Literal("]")) ); - return List.of(make.Return( - make.App( - make.Select(stringBuilder, names.toString).setType(syms.stringType) - ) - )); + return List.of(make.Return(make.App( + make.Select(stringBuilder, names.toString) + .setType(syms.stringType) + ))); } - private List generateEquals(JCTree.JCClassDecl classDecl, Name otherName) { - ListBuffer statements = new ListBuffer<>(); + public List generateEquals(JCClassDecl classDecl, Name otherName) { + ListBuffer statements = new ListBuffer<>(); // if (o == this) return true; - { - statements.add(make.If( - make.Binary( - JCTree.Tag.EQ, - make.This(Type.noType), - make.Ident(otherName) - ), - make.Return(make.Literal(true)), - null - )); - } + statements.add(make.If( + make.Binary( + Tag.EQ, + make.This(Type.noType), + make.Ident(otherName) + ), + make.Return(make.Literal(true)), + null + )); // if (o == null) return false; - { - statements.add(make.If( - make.Binary( - JCTree.Tag.EQ, - make.Ident(otherName), - make.Literal(TypeTag.BOT, null) - ), - make.Return(make.Literal(false)), - null - )); - } + statements.add(make.If( + make.Binary( + Tag.EQ, + make.Ident(otherName), + make.Literal(TypeTag.BOT, null) + ), + make.Return(make.Literal(false)), + null + )); // if (o.getClass() != getClass()) return false; - { + statements.add(make.If( + make.Binary( + Tag.EQ, + make.App(make.Select( + make.Ident(otherName), + names.getClass + ).setType(syms.classType)), + make.App(make.Select( + make.This(Type.noType), + names.getClass + ).setType(syms.classType)) + ), + make.Block(0, List.nil()), + make.Return(make.Literal(false)) + )); + + // Create casted variable: ClassName other = (ClassName)o; + Name thatName = names.fromString("other"); + statements.add(make.VarDef( + make.Modifiers(0L), + thatName, + make.Ident(classDecl.name), + make.TypeCast(make.Ident(classDecl.name), make.Ident(otherName)) + )); + + // fields - use the casted variable + for ( + Iterator iterator = getRecordComponents(classDecl).iterator(); + iterator.hasNext(); + ) { + JCVariableDecl fieldDecl = iterator.next(); + JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); + JCExpression otherFieldAccess = make.Select(make.Ident(thatName), fieldDecl.name); + + final JCExpression condition; + if (fieldDecl.getType() instanceof JCPrimitiveTypeTree) { + condition = make.Binary(Tag.EQ, otherFieldAccess, myFieldAccess); + } else { + condition = make.App( + // call Objects.equals + make.Select( + make.QualIdent(syms.objectsType.tsym), + names.equals + ).setType(syms.objectsType), + List.of(otherFieldAccess, myFieldAccess) + ); + } statements.add(make.If( - make.Binary( - JCTree.Tag.EQ, - make.App(make.Select(make.Ident(otherName), names.getClass).setType(syms.classType)), - make.App(make.Select(make.This(Type.noType), names.getClass).setType(syms.classType)) - ), + condition, make.Block(0, List.nil()), make.Return(make.Literal(false)) )); } - // fields - { - for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); - iterator.hasNext(); - ) { - JCTree.JCVariableDecl fieldDecl = iterator.next(); - - JCTree.JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); - JCTree.JCExpression otherFieldAccess = make.Select( - make.TypeCast(make.Ident(classDecl.name), make.Ident(otherName)), - fieldDecl.name - ); - - final JCTree.JCExpression condition; - if (fieldDecl.getType() instanceof JCTree.JCPrimitiveTypeTree) { - condition = make.Binary(JCTree.Tag.EQ, otherFieldAccess, myFieldAccess); - } else { - condition = make.App( - // call Objects.equals - make.Select( - make.QualIdent(syms.objectsType.tsym), - names.equals - ).setType(syms.objectsType), - List.of(otherFieldAccess, myFieldAccess) - ); - } - statements.add(make.If( - condition, - make.Block(0, List.nil()), - make.Return(make.Literal(false)) - )); - } - } - + // return true; statements.add(make.Return(make.Literal(true))); + return statements.toList(); } - private List generateHashCode(JCTree.JCClassDecl classDecl) { - ListBuffer expressions = new ListBuffer<>(); + public List generateHashCode(JCClassDecl classDecl) { + ListBuffer expressions = new ListBuffer<>(); for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); + Iterator iterator = getRecordComponents(classDecl).iterator(); iterator.hasNext(); ) { - JCTree.JCVariableDecl fieldDecl = iterator.next(); + JCVariableDecl fieldDecl = iterator.next(); JCTree fType = fieldDecl.getType(); + JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); - JCTree.JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); - - if (fType instanceof JCTree.JCPrimitiveTypeTree) { - switch (((JCTree.JCPrimitiveTypeTree) fType).getPrimitiveTypeKind()) { + if (fType instanceof JCPrimitiveTypeTree) { + switch (((JCPrimitiveTypeTree) fType).getPrimitiveTypeKind()) { case BOOLEAN: /* this.fieldName ? 1 : 0 */ - expressions.append( - make.Conditional( - myFieldAccess, - make.Literal(TypeTag.INT, 1), - make.Literal(TypeTag.INT, 0) - ) - ); + expressions.append(make.Conditional( + myFieldAccess, + make.Literal(TypeTag.INT, 1), + make.Literal(TypeTag.INT, 0) + )); break; + case LONG: - expressions.append(longToIntForHashCode(myFieldAccess)); + expressions.append(make.TypeCast( + make.TypeIdent(syms.intType.getTag()), + make.Parens(make.Binary( + Tag.BITXOR, + myFieldAccess, + make.Parens(make.Binary( + Tag.USR, + myFieldAccess, + make.Literal(32) + )) + )) + )); break; + case FLOAT: /* this.fieldName != 0f ? Float.floatToIntBits(this.fieldName) : 0 */ - expressions.append( - make.Conditional( - make.Binary(JCTree.Tag.NE, myFieldAccess, make.Literal(0f)), - make.App( - make.Select( - make.Ident(names.fromString("Float")), - names.fromString("floatToIntBits")).setType(syms.intType), - List.of(myFieldAccess) - ), - make.Literal(TypeTag.INT, 0) - ) - ); + expressions.append(make.Conditional( + make.Binary(Tag.NE, myFieldAccess, make.Literal(0f)), + make.App( + make.Select( + make.QualIdent(types.boxedClass(syms.floatType)), + names.fromString("floatToIntBits") + ).setType(syms.intType), + List.of(myFieldAccess) + ), + make.Literal(TypeTag.INT, 0) + )); break; + case DOUBLE: - /* longToIntForHashCode(Double.doubleToLongBits(this.fieldName)) */ - expressions.append( - longToIntForHashCode( - make.App( - make.Select( - make.Ident(names.fromString("Double")), - names.fromString("doubleToLongBits")).setType(syms.intType), - List.of(myFieldAccess) - ) - ) - ); + /* Double.hashCode(this.fieldName) */ + expressions.append(make.App( + make.Select( + make.QualIdent(types.boxedClass(syms.doubleType)), + names.hashCode + ).setType(syms.intType), + List.of(myFieldAccess) + )); break; - default: + case BYTE: case SHORT: case INT: case CHAR: + default: /* just the field */ expressions.append(myFieldAccess); break; } - } else if (fType instanceof JCTree.JCArrayTypeTree) { - expressions.append( - make.App( - make.Select( - make.Select( - make.Select( - make.Ident(names.fromString("java")), - names.fromString("util") - ), - names.fromString("Arrays") - ), - names.fromString("hashCode") - ).setType(syms.intType), - List.of(myFieldAccess) - ) - ); + + } else if (fType instanceof JCArrayTypeTree) { + expressions.append(make.App( + make.Select( + make.QualIdent(syms.arraysType.tsym), + names.hashCode + ).setType(syms.intType), + List.of(myFieldAccess) + )); } else { /* (this.fieldName != null ? this.fieldName.hashCode() : 0) */ - expressions.append( - make.Conditional( - make.Binary(JCTree.Tag.NE, myFieldAccess, make.Literal(TypeTag.BOT, null)), - make.App(make.Select(myFieldAccess, names.hashCode).setType(syms.intType)), - make.Literal(0) - ) - ); + expressions.append(make.Conditional( + make.Binary(Tag.NE, myFieldAccess, make.Literal(TypeTag.BOT, null)), + make.App(make.Select(myFieldAccess, names.hashCode).setType(syms.intType)), + make.Literal(0) + )); } } - ListBuffer statements = new ListBuffer<>(); - + ListBuffer statements = new ListBuffer<>(); Name resultName = names.fromString("result"); + statements.append(make.VarDef( + make.Modifiers(0L), + resultName, + make.TypeIdent(syms.intType.getTag()), + make.Literal(0) + )); - statements.append( - make.VarDef( - make.Modifiers(0L), - resultName, - make.TypeIdent(syms.intType.getTag()), - make.Literal(0) - ) - ); - for (JCTree.JCExpression expression : expressions) { + for (JCExpression expression : expressions) { // result = 31 * result + ${expr} - statements.append(make.Exec( - make.Assign( - make.Ident(resultName), - make.Binary( - JCTree.Tag.PLUS, - make.Binary(JCTree.Tag.MUL, make.Literal(TypeTag.INT, 31), make.Ident(resultName)), - expression - ) + statements.append(make.Exec(make.Assign( + make.Ident(resultName), + make.Binary( + Tag.PLUS, + make.Binary(Tag.MUL, make.Literal(TypeTag.INT, 31), make.Ident(resultName)), + expression ) - )); + ))); } statements.append(make.Return(make.Ident(resultName))); return statements.toList(); } - - public JCTree.JCExpression longToIntForHashCode(JCTree.JCExpression ref) { - /* (int) (ref ^ ref >>> 32) */ - return make.TypeCast( - make.TypeIdent(syms.intType.getTag()), - make.Parens( - make.Binary( - JCTree.Tag.BITXOR, - ref, - make.Parens(make.Binary(JCTree.Tag.USR, ref, make.Literal(32))) - ) - ) - ); - } - - private static class MandatoryDesugarAnnotationTreeScanner extends TreeScanner { - - private final Log log; - - private final CompilationUnitTree compilationUnit; - - public MandatoryDesugarAnnotationTreeScanner(Log log, CompilationUnitTree compilationUnit) { - this.log = log; - this.compilationUnit = compilationUnit; - } - - @Override - public Void visitClass(ClassTree node, Void aVoid) { - if ("RECORD".equals(node.getKind().toString())) { - if ( - node.getModifiers().getAnnotations().stream() - .noneMatch(annotation -> { - Type type = ((JCTree.JCAnnotation) annotation).type; - return Desugar.class.getName().equals(type.toString()); - }) - ) { - JavaFileObject oldSource = log.useSource(compilationUnit.getSourceFile()); - try { - log.error( - (JCTree.JCClassDecl) node, - new JCDiagnostic.Error( - "jabel", - "missing.desugar.on.record", - "Must be annotated with @Desugar" - ) - ); - } finally { - log.useSource(oldSource); - } - } - } - return super.visitClass(node, aVoid); - } - } } diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java new file mode 100644 index 0000000..18c60a0 --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java @@ -0,0 +1,683 @@ +package com.github.bsideup.jabel; + +import java.lang.reflect.*; +import java.util.*; + +import com.sun.source.tree.*; +import com.sun.source.util.*; +import com.sun.source.util.TreeScanner; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; +import com.sun.tools.javac.util.List; + +import static com.github.bsideup.jabel.RecordPatternHelper.*; + + +//TODO: the class need some reworking since JCSwitchExpression doesn't exists in Java 12- +/** + * Transforms modern switch constructs (Java 17+) into Java 8 compatible code. + *

+ * This will extract {@code null} case to a {@code if(sel==null)} statement. + *
+ * And will move (record) pattern/guard switch labels, to the switch condition. + * Letting the compiler optimize it. + */ +public class SwitchRetrofittingTaskListener implements TaskListener { + // region compiler compatibility + + // Because we're compiling with Java25, the old method (without guards) doens't exists. + private static Method LEGACY_MAKE_CASE; + // Some internal states to avoid getting errors everytimes we trying to use a + // not found feature. + private static Boolean GUARDS, LABELS, BODIES, DEFAULT_CASES, CONSTANT_CASES; + + /** Get the guard expression from a case, or null if guards are unsupported. */ + private static JCExpression getGuard(JCCase caseTree) { + if (GUARDS != null) { + return GUARDS ? caseTree.getGuard() : null; + } + try { + JCExpression guard = caseTree.getGuard(); + GUARDS = true; + return guard; + } catch (NoSuchMethodError ignored) { + GUARDS = false; + return null; + } + } + + /** Get the labels from a case, handling both old and new compiler APIs. */ + @SuppressWarnings("unchecked") + private static List getLabels(JCCase caseTree) { + if (LABELS == null) { + try { + caseTree.getLabels(); + LABELS = true; + } catch (NoSuchMethodError ignored) { + LABELS = false; + } + } + if (LABELS) { + return (List) (List) caseTree.getLabels(); + } + + List labels = caseTree.getExpressions(); + if (labels == null) return List.nil(); + List result = List.nil(); + for (Object label : labels) { + if (!(label instanceof JCTree)) continue; + result = result.append((JCTree) label); + } + return result; + } + + /** Get the arrow-style body of a case, or null if unsupported. */ + private static JCTree getBody(JCCase caseTree) { + if (BODIES != null) { + return BODIES ? caseTree.getBody() : null; + } + try { + JCTree body = caseTree.getBody(); + BODIES = true; + return body; + } catch (NoSuchMethodError ignored) { + BODIES = false; + return null; + } + } + + // TODO: CaseTree.CaseKind doesn't exists in Java 12- + /** + * Create a new case tree, handling different JDK signatures.
+ * JDK 21+: Case(CaseKind, List labels, JCExpression guard, List stats, JCTree body)
+ * JDK 17-20: Case(CaseKind, List labels, List stats, JCTree body) + */ + @SuppressWarnings("unchecked") + private JCCase makeCase( + CaseTree.CaseKind kind, List labels, JCExpression guard, + List stats, JCTree body + ) { + if (GUARDS != false) { + try { + JCCase c = make.Case(kind, (List) labels, guard, stats, body); + GUARDS = true; + return c; + } catch (NoSuchMethodError ignored) {} + } + + try { + GUARDS = false; + if (LEGACY_MAKE_CASE == null) { + LEGACY_MAKE_CASE = TreeMaker.class.getMethod( + "Case", + CaseTree.CaseKind.class, + List.class, + List.class, + JCTree.class + ); + } + return (JCCase) LEGACY_MAKE_CASE.invoke(make, kind, labels, stats, body); + } catch (Exception ignored) {} // Hope this never happen... + return null; + } + + // Because the return type of these two methods may not exist, we need delay the + // call to an inner class. + // Like that, the class initialization error can be catched easily. + private static final class DefaultCaseLabelFactory { + static JCTree make(TreeMaker m) { + return m.DefaultCaseLabel(); + } + } + + private static final class ConstantCaseLabelFactory { + static JCTree make(TreeMaker m, JCExpression lit) { + return m.ConstantCaseLabel(lit); + } + } + + /** Create a default case label. Returns null if unsupported (JDK < 17). */ + private JCTree makeDefaultCaseLabel() { + if (DEFAULT_CASES != null) { + return DEFAULT_CASES ? DefaultCaseLabelFactory.make(make) : null; + } + try { + JCTree c = DefaultCaseLabelFactory.make(make); + DEFAULT_CASES = true; + return c; + } catch (NoSuchMethodError | NoClassDefFoundError ignored) { + DEFAULT_CASES = false; + return null; + } + } + + /** Creates a case label for an {@code Integer}, handling JDK 17-20 and 21+ APIs. */ + private JCTree makeLabel(int i) { + if (CONSTANT_CASES != null) { + return CONSTANT_CASES ? ConstantCaseLabelFactory.make(make, make.Literal(i)) : make.Literal(i); + } + try { + JCTree tree = ConstantCaseLabelFactory.make(make, make.Literal(i)); // JDK 21+ + CONSTANT_CASES = true; + return tree; + } catch (NoSuchMethodError | NoClassDefFoundError ignored) { + CONSTANT_CASES = false; + return make.Literal(i); // JDK 17-20 + } + } + + private static boolean isSwitchExpression(Tree tree) { + return tree != null && getClassName(tree).equals("JCSwitchExpression"); + } + + private static boolean isDefault(JCTree label) { + return label != null && getClassName(label).equals("JCDefaultCaseLabel"); + } + + private static boolean isConstant(JCTree label) { + return label != null && getClassName(label).equals("JCConstantCaseLabel"); + } + + private static boolean isNull(JCTree label) { + if (label instanceof JCLiteral) { + return ((JCLiteral) label).typetag == TypeTag.BOT; + } + if (!isConstant(label)) return false; + JCExpression expr = ((JCConstantCaseLabel) label).expr; + return expr instanceof JCLiteral && ((JCLiteral) expr).typetag == TypeTag.BOT; + } + + private static boolean hasDefault(JCCase caseTree) { + List labels = getLabels(caseTree); + if (labels.isEmpty()) return true; + for (JCTree label : labels) { + if (isDefault(label)) return true; + } + return false; + } + + private static JCExpression getLabelExpression(JCTree label) { + if (label instanceof JCExpression) { + return (JCExpression) label; + } + if (isConstant(label)) { + return ((JCConstantCaseLabel) label).expr; + } + return null; + } + + private static boolean hasPatterns(List cases) { + if (cases == null) return false; + for (JCCase caseTree : cases) { + if (caseTree == null) continue; + if (getGuard(caseTree) != null) return true; + for (JCTree label : getLabels(caseTree)) { + if (isPattern(label)) return true; + } + } + return false; + } + + private static boolean hasNull(List cases) { + if (cases == null) return false; + for (JCCase caseTree : cases) { + if (caseTree == null) continue; + for (JCTree label : getLabels(caseTree)) { + if (isNull(label)) return true; + } + } + return false; + } + + // end region + // region task listener + + final RecordPatternHelper helper; + final TreeMaker make; + final Symtab syms; + final Names names; + private int tempVarCounter = 0; + + public SwitchRetrofittingTaskListener(Context context) { + helper = new RecordPatternHelper(context); + make = TreeMaker.instance(context); + syms = Symtab.instance(context); + names = Names.instance(context); + } + + @Override + public void started(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ENTER) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + new SwitchTranslator().translate((JCCompilationUnit) e.getCompilationUnit()); + } + + @Override + public void finished(TaskEvent e) { + } + + + public class SwitchTranslator extends TreeTranslator { + private final Map captures = new HashMap<>(); + + @Override + public void visitSwitch(JCSwitch tree) { + super.visitSwitch(tree); + if (!needsTransform(tree.cases)) { + tree.cases = injectDefault(tree.cases); + return; + } + + make.at(tree.pos); + if (isComplex(tree.selector)) { + Name sv = names.fromString("$switch$" + (tempVarCounter++)); + result = make.Block(0, List.of( + makeFinalVar(sv, make.Type(syms.objectType), tree.selector), + transformSwitch(make.Ident(sv), tree.cases, false, null) + )); + } else { + result = transformSwitch(tree.selector, tree.cases, false, null); + } + } + + @Override + public void visitBlock(JCBlock tree) { + ListBuffer buffer = null; + for (JCStatement stmt : tree.stats) { + List found = findPatternSwitches(stmt); + + for (JCSwitchExpression sw : found) { + if (!isComplex(sw.selector)) continue; + if (buffer == null) { + ListBuffer buf = new ListBuffer<>(); + for (JCStatement s : tree.stats) { + if (s == stmt) break; + buf.append(s); + } + buffer = buf; + } + + Name sv = names.fromString("$switch$" + (tempVarCounter++)); + buffer.append(make.VarDef(make.Modifiers(0), sv, make.Type(syms.objectType), null)); + captures.put(sw, sw.selector); + sw.selector = make.Ident(sv); + } + if (buffer != null) buffer.append(stmt); + } + + if (buffer != null) { + tree.stats = buffer.toList(); + } + super.visitBlock(tree); + } + + @Override + public T translate(T tree) { + if (tree == null) return null; + helper.collectRecord(tree); + if (!isSwitchExpression(tree)) { + return super.translate(tree); + } + + JCSwitchExpression sw = (JCSwitchExpression) tree; + make.at(sw.pos); + JCExpression rawSel = captures.remove(sw); // consume capture if registered + sw.selector = translate(sw.selector); + sw.cases = translate(sw.cases); + + if (needsTransform(sw.cases)) { + JCSwitch ns = transformSwitch(sw.selector, sw.cases, true, rawSel); + sw.selector = ns.selector; + sw.cases = ns.cases; + } else { + sw.cases = injectDefault(sw.cases); + } + return tree; + } + + /** TreeTranslator skips the arrow body field, translate it manually. */ + @Override + public void visitCase(JCCase tree) { + super.visitCase(tree); + JCTree body = getBody(tree); + if (body == null) return; + JCTree saved = result; + setCaseBody(tree, translate(body)); + result = saved; + } + + /** Recursively collects all JCSwitchExpression nodes inside a statement. */ + private List findPatternSwitches(JCTree node) { + ListBuffer out = new ListBuffer<>(); + new TreeScanner() { + @Override + public Void scan(Tree t, Void v) { + if (t == null) return null; + if (isSwitchExpression(t)) { + JCSwitchExpression se = (JCSwitchExpression) t; + if (needsTransform(se.cases)) out.append(se); + // Don't recurse into the switch itself, nested ones handled separately. + return null; + } + return super.scan(t, v); + } + }.scan(node, null); + return out.toList(); + } + } + + /** + * Replaces a pattern/null switch to a standard switch using a ternary-chain + * dispatcher in the condition. + *

+ * The ternary chain maps each case condition to an index which becomes the new + * selector.
+ * Each case label is rewritten with this index, plus the record pattern, if + * any, is lowered to the case.
+ * + * @param sel the selector expression + * @param cases already-translated case list + * @param expression whether to build a switch expression or statement + * @param rawSel original selector to inject as a capture, or {@code null} + */ + public JCSwitch transformSwitch(JCExpression sel, List cases, boolean expression, JCExpression rawSel) { + java.util.List nonDefs = new ArrayList<>(); + JCCase defCase = null; + for (JCCase c : cases) { + if (hasDefault(c)) defCase = c; + else nonDefs.add(c); + } + int n = nonDefs.size(); + + // Ternary chain + final Name selName = sel instanceof JCIdent ? ((JCIdent) sel).name : null; + JCExpression ternary = make.Literal(n); + for (int i = n - 1; i >= 0; i--) { + JCCase c = nonDefs.get(i); + JCExpression cond = buildCondition(getLabels(c), sel, c); + if (cond == null) continue; + + if (i == 0 && rawSel != null && selName != null) { + // Replace the first reference to the selector ident with (sv = rawSel). + cond = new TreeTranslator() { + boolean a = true; + @Override + public void visitIdent(JCIdent id) { + if (a && id.name == selName) { + a = false; + result = make.Parens(make.Assign(make.Ident(selName), rawSel)); + } else { + result = id; + } + } + }.translate(cond); + } + ternary = make.Conditional(cond, make.Literal(i), ternary); + } + // In case of + if (n == 0 && rawSel != null && selName != null) { + ternary = make.Parens(make.Assign(make.Ident(selName), rawSel)); + } + + // Rebuild cases with int label + ListBuffer newCases = new ListBuffer<>(); + for (int i = 0; i < n; i++) { + JCCase nc = makeCase( + CaseTree.CaseKind.STATEMENT, + List.of(makeLabel(i)), + null, + List.of(make.Block(0, buildCaseBody(nonDefs.get(i), sel, expression))), + null + ); + if (nc != null) newCases.append(nc); + } + + // Default case + JCTree dl = makeDefaultCaseLabel(); + if (dl != null) { + JCCase dc = makeCase( + CaseTree.CaseKind.STATEMENT, + List.of(dl), + null, + List.of(make.Block(0, defCase != null ? + buildCaseBody(defCase, sel, expression) : + List.of(makeMatchExceptionThrow()) + )), + null + ); + if (dc != null) newCases.append(dc); + } + + return make.Switch(ternary, newCases.toList()); + } + + public List buildCaseBody(JCCase c, JCExpression sel, boolean expression) { + ListBuffer out = new ListBuffer<>(); + addBindings(c, sel, out); + JCTree body = getBody(c); + + if (body instanceof JCBlock) { + for (JCStatement s : ((JCBlock) body).stats) { + out.append(s); + } + } else if (body != null) { + JCExpression expr = extractExpression(body); + if (expr != null) { + if (expression) { + out.append(make.Yield(expr)); + } else { + out.append(make.Exec(expr)); + out.append(make.Break(null)); + } + } else if (body instanceof JCStatement) { + out.append((JCStatement) body); + } + } else if (c.stats != null) { + for (JCStatement s : c.stats) { + out.append(s); + } + } + + return out.toList(); + } + + /** + * Injects + * {@code default: throw new UnsupportedOperationException("MatchException");} + * for exhaustive switches that have no explicit default. + */ + public List injectDefault(List cases) { + if (cases == null) return cases; + for (JCCase c : cases) { + if (c != null && hasDefault(c)) { + return cases; + } + } + + CaseTree.CaseKind kind = CaseTree.CaseKind.STATEMENT; + for (JCCase c : cases) { + if (c != null) { + kind = c.getCaseKind(); + break; + } + } + + JCTree defaultLabel = makeDefaultCaseLabel(); + if (defaultLabel == null) return cases; + + JCStatement throwStmt = makeMatchExceptionThrow(); + JCCase defaultCase = makeCase( + kind, + List.of(defaultLabel), + null, + List.of(throwStmt), + kind == CaseTree.CaseKind.RULE ? throwStmt : null + ); + return defaultCase != null ? cases.append(defaultCase) : cases; + } + + public JCExpression buildCondition(List labels, JCExpression sel, JCCase caseTree) { + JCExpression condition = null; + for (JCTree label : labels) { + JCExpression lc = buildLabelCondition(label, sel); + if (lc == null) continue; + condition = condition == null ? lc : make.Binary(JCTree.Tag.OR, condition, lc); + } + + JCExpression guard = getGuard(caseTree); + if (guard == null) return condition; + + Map bindingMap = collectBindings(labels, sel); + if (!bindingMap.isEmpty()) { + guard = new TreeTranslator() { + @Override + public void visitIdent(JCIdent ident) { + JCExpression replacement = bindingMap.get(ident.name); + result = replacement != null ? replacement : ident; + } + }.translate(guard); + } + + return condition == null ? guard : make.Binary(JCTree.Tag.AND, condition, guard); + } + + public JCExpression buildLabelCondition(JCTree label, JCExpression sel) { + if (isNull(label)) { + return make.Binary(JCTree.Tag.EQ, helper.copy(sel), make.Literal(TypeTag.BOT, null)); + } else if (isPattern(label)) { + JCVariableDecl pv = getPatternVar(label); + if (pv != null) { + return make.TypeTest(helper.copy(sel), pv.vartype); + } + JCExpression rt = getRecordType(label); + if (rt != null) { + return make.TypeTest(helper.copy(sel), rt); + } + JCExpression pt = getPatternType(label); + if (pt != null) { + return make.TypeTest(helper.copy(sel), pt); + } + } else if (!isDefault(label)) { + JCExpression expr = getLabelExpression(label); + if (expr != null) { + return make.Apply( + List.nil(), + make.Select(expr, names.equals), + List.of(helper.copy(sel)) + ); + } + } + return null; + } + + public Map collectBindings(List labels, JCExpression sel) { + Map map = new HashMap<>(); + for (JCTree label : labels) { + JCVariableDecl pv = getPatternVar(label); + if (pv != null) { + map.put(pv.name, make.TypeCast(pv.vartype, helper.copy(sel))); + continue; + } + + JCExpression rt = getRecordType(label); + List nested = getRecordNested(label); + if (rt == null || nested == null) continue; + + List componentNames = helper.getRecordComponentNames(label); + int i = 0; + for (JCTree np : nested) { + Name cn = i < componentNames.size() ? componentNames.get(i++) : null; + JCVariableDecl nv = getPatternVar(np); + if (nv == null) continue; + + map.put(nv.name, make.Apply( + List.nil(), + make.Select( + make.TypeCast(rt, helper.copy(sel)), + cn != null ? cn : nv.name + ), + List.nil() + )); + } + } + return map; + } + + private void addBindings(JCCase caseTree, JCExpression sel, ListBuffer out) { + for (JCTree label : getLabels(caseTree)) { + if (!isPattern(label)) continue; + + JCVariableDecl pv = getPatternVar(label); + if (pv != null) { + out.append(make.VarDef( + make.Modifiers(Flags.FINAL), + pv.name, + pv.vartype, + make.TypeCast(pv.vartype, helper.copy(sel)) + )); + continue; + } + + JCExpression rt = getRecordType(label); + List nested = getRecordNested(label); + if (rt == null || nested == null) continue; + + Name cv = helper.tempName(); + out.append(helper.makeVarDef(cv, rt, helper.makeCast(rt, helper.copy(sel)))); + ListBuffer bindings = new ListBuffer<>(); + helper.extractRecordBindings( + nested, + make.Ident(cv), + helper.getRecordComponentNames(label), + bindings + ); + for (JCVariableDecl decl : bindings) { + out.append(decl); + } + } + } + + private boolean isComplex(JCExpression expr) { + if (expr instanceof JCIdent || expr instanceof JCLiteral) return false; + if (expr instanceof JCFieldAccess) { + return isComplex(((JCFieldAccess) expr).selected); + } + if (expr instanceof JCParens) { + return isComplex(((JCParens) expr).expr); + } + return true; // method calls, new, binary ops, etc. + } + + private static boolean needsTransform(List cases) { + return hasPatterns(cases) || hasNull(cases); + } + + private JCExpression extractExpression(JCTree body) { + if (body instanceof JCExpressionStatement) { + return ((JCExpressionStatement) body).expr; + } + if (body instanceof JCExpression) { + return (JCExpression) body; + } + return null; + } + + private JCStatement makeFinalVar(Name name, JCExpression type, JCExpression init) { + return make.VarDef(make.Modifiers(Flags.FINAL), name, type, init); + } + + // TODO: explain expected cases? + private JCStatement makeMatchExceptionThrow() { + return make.Throw(make.NewClass( + null, + List.nil(), + make.Ident(names.fromString("UnsupportedOperationException")), + List.of(make.Literal("MatchException")), + null + )); + } + + // end region +} \ No newline at end of file From a83439e83fead6d0b41f6e3a305849dae55556f9 Mon Sep 17 00:00:00 2001 From: ZetaMap Date: Thu, 9 Apr 2026 16:44:05 +0200 Subject: [PATCH 2/3] updated tests --- example/src/main/java/com/example/Main.java | 10 +++++++++- .../src/test/java/com/example/JavaFeaturesTest.java | 1 - 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/example/src/main/java/com/example/Main.java b/example/src/main/java/com/example/Main.java index 400fe80..f4246e1 100644 --- a/example/src/main/java/com/example/Main.java +++ b/example/src/main/java/com/example/Main.java @@ -69,7 +69,15 @@ public static void main(String[] args) { assert sp.value == 99 && !sp.spoofed : "SpoofChild"; System.out.println("SpoofChild: value=" + sp.value + " spoofed=" + sp.spoofed); //Implicit classes - //Java25FeaturesExample2.main(); // in theory we cannot reference the class + // Need reflection since implicit classes cannot be referenced + try { + java.lang.reflect.Method main = Class.forName("Java25FeaturesExample2").getDeclaredMethod("main"); + main.setAccessible(true); + main.invoke(null); + } catch (Exception e) { + System.err.println("Implicit classes error: " + e.toString()); + return; + } System.out.println("\n=== All features work! ==="); } diff --git a/example/src/test/java/com/example/JavaFeaturesTest.java b/example/src/test/java/com/example/JavaFeaturesTest.java index f4bf334..0303a0b 100644 --- a/example/src/test/java/com/example/JavaFeaturesTest.java +++ b/example/src/test/java/com/example/JavaFeaturesTest.java @@ -9,6 +9,5 @@ public class JavaFeaturesTest { @Test public void shouldWork() { Main.main(new String[0]); - //TODO: Java25FeaturesExample2.main(new String[0]); } } \ No newline at end of file From 1923c07211f59e4cbd48a4f7f097502ba8860c7e Mon Sep 17 00:00:00 2001 From: ZetaMap Date: Sun, 26 Apr 2026 21:35:45 +0200 Subject: [PATCH 3/3] added more examples, fixed bugs, modified plugin initialization, remade java 21 switch retrofitting --- example/build.gradle | 5 +- .../com/example/Java14FeaturesExample.java | 12 +- .../com/example/Java16FeaturesExample.java | 17 +- .../com/example/Java17FeaturesExample.java | 11 +- .../com/example/Java21FeaturesExample.java | 178 +- .../com/example/Java25FeaturesExample2.java | 1 - .../com/example/Java9FeaturesExample.java | 2 +- example/src/main/java/com/example/Main.java | 4 +- jabel-javac-plugin/build.gradle | 2 +- .../FlexibleMainRetrofittingTaskListener.java | 12 +- .../InstanceofRetrofittingTaskListener.java | 144 -- .../bsideup/jabel/JabelCompilerPlugin.java | 225 +-- .../bsideup/jabel/RecordPatternHelper.java | 240 --- ...RecordPatternRetrofittingTaskListener.java | 155 ++ .../RecordsRetrofittingTaskListener.java | 59 +- .../jabel/SwitchRetrofittingTaskListener.java | 1611 +++++++++++++---- 16 files changed, 1682 insertions(+), 996 deletions(-) delete mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java delete mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java create mode 100644 jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java diff --git a/example/build.gradle b/example/build.gradle index a361ac5..24cf60f 100644 --- a/example/build.gradle +++ b/example/build.gradle @@ -2,8 +2,9 @@ plugins { id "java" } -configure([tasks.compileJava]) { +compileJava { sourceCompatibility = 25 + targetCompatibility = 8 options.release = 8 javaCompiler = javaToolchains.compilerFor { @@ -22,7 +23,7 @@ test { } dependencies { - annotationProcessor project(":jabel-javac-plugin") + annotationProcessor project(":jabel-javac-plugin")//.files("build/libs/jabel-javac-plugin.jar") testImplementation 'junit:junit:4.13.2' } diff --git a/example/src/main/java/com/example/Java14FeaturesExample.java b/example/src/main/java/com/example/Java14FeaturesExample.java index c477b73..e9b82fa 100644 --- a/example/src/main/java/com/example/Java14FeaturesExample.java +++ b/example/src/main/java/com/example/Java14FeaturesExample.java @@ -87,8 +87,18 @@ void statementSwitchWithArrow(Day day) { switch (day) { case SAT -> System.out.println("Saturday"); case SUN -> System.out.println("Sunday"); - default -> System.out.println("Weekday"); } + System.out.println("Weekday"); + } + + void noDefaultInjectedCase(int a) { + switch (a) { + case 1 -> { + System.out.println("Wrong value!!"); + return; + } + } + System.out.println("Passed!"); } int[] asArrayElements(int a, int b) { diff --git a/example/src/main/java/com/example/Java16FeaturesExample.java b/example/src/main/java/com/example/Java16FeaturesExample.java index 4ec031d..e5ce86c 100644 --- a/example/src/main/java/com/example/Java16FeaturesExample.java +++ b/example/src/main/java/com/example/Java16FeaturesExample.java @@ -59,19 +59,20 @@ record Person(String name, int age) { Objects.requireNonNull(name); if (age < 0) throw new IllegalArgumentException(); } + } record MultipleTypes(boolean bool, byte b, short s, int i, long l, float f, double d, String str, Point p) { - @Override - public boolean equals(java.lang.Object o) { - return this == o; - } + @Override + public boolean equals(Object o) { + return this == o; + } - @Override - public String str() { - return "nothing"; - } + @Override + public String str() { + return "different string: " + str; + } } void patternMatchingInstanceof(Object obj) { diff --git a/example/src/main/java/com/example/Java17FeaturesExample.java b/example/src/main/java/com/example/Java17FeaturesExample.java index 1c1c55b..05846e2 100644 --- a/example/src/main/java/com/example/Java17FeaturesExample.java +++ b/example/src/main/java/com/example/Java17FeaturesExample.java @@ -16,7 +16,6 @@ * public abstract class Shape { } // sealed, permits removed * public final class Circle extends Shape { } * public class Square extends Shape { } // non-sealed removed - * // Note: PermittedSubclasses bytecode attribute must be removed * *

* REDUNDANT_STRICTFP @@ -26,6 +25,16 @@ * * // Decompiled (Java 8): same code (strictfp kept but redundant since Java 17) * + *

+ * PATTERN_SWITCH (preview since JDK 17) + *

+ * // Source (Java 17+ preview):
+ * switch (obj) {
+ *     case String s -> s.length();
+ *     case Integer i -> i;
+ *     default -> 0;
+ * }
+ * 
*/ public class Java17FeaturesExample { diff --git a/example/src/main/java/com/example/Java21FeaturesExample.java b/example/src/main/java/com/example/Java21FeaturesExample.java index cc29c8a..f531130 100644 --- a/example/src/main/java/com/example/Java21FeaturesExample.java +++ b/example/src/main/java/com/example/Java21FeaturesExample.java @@ -2,11 +2,11 @@ package com.example; - /** * Examples of Java 21 features with manual desugaring:
* * CASE_NULL + * *
  * // Source (Java 21+):
  * switch (str) {
@@ -16,16 +16,15 @@
  * }
  *
  * // Decompiled (Java 8):
- * if (str == null) {
- *     return "null";
- * }
- * switch (str) {
- *     case "a": return "A";
+ * switch (str == null ? -1 : str.equals("a") ? 0 : 1) {
+ *     case -1: return "null";
+ *     case 0: return "A";
  *     default: return "other";
  * }
  * 
*

* PATTERN_SWITCH + * *

  * // Source (Java 21+):
  * switch (obj) {
@@ -35,18 +34,21 @@
  * }
  *
  * // Decompiled (Java 8):
- * if (obj instanceof String) {
- *     String s = (String) obj;
- *     return s.length();
- * } else if (obj instanceof Integer) {
- *     Integer i = (Integer) obj;
- *     return i;
- * } else {
- *     return 0;
+ * switch (Objects.requireNonNull(obj) instanceof String ? 0 : obj instanceof Integer ? 1 : 2) {
+ *     case 0: {
+ *         String s = (String) obj;
+ *         return s.length();
+ *     }
+ *     case 1: {
+ *         Integer i = (Integer) obj;
+ *         return i;
+ *     }
+ *     default: return 0;
  * }
  * 
*

* RECORD_PATTERNS + * *

  * // Source (Java 21+):
  * if (obj instanceof Point(int x, int y)) {
@@ -63,9 +65,11 @@
  * 
*

* UNCONDITIONAL_PATTERN_IN_INSTANCEOF + * *

  * // Source (Java 21+):
- * if (str instanceof CharSequence cs) { }  // always true for non-null
+ * if (str instanceof CharSequence cs) {
+ * } // always true for non-null
  *
  * // Decompiled (Java 8):
  * if (str != null) {
@@ -74,20 +78,89 @@
  * 
*/ public class Java21FeaturesExample { + public enum InitDeadlockTest { + START, END; + + static final String MSG = describe(START); + + static String describe(InitDeadlockTest dz) { + return switch (dz) { + case START -> "starting"; + case END -> "ending"; + case null -> "unknown"; + }; + } + } + + enum Day { + MON, TUE, WED, THU, FRI, SAT, SUN + } - sealed interface Geometry permits Point, Triangle, Shape{} - record Point(int x, int y) implements Geometry {} - record Triangle(Point a, Point b, Point c) implements Geometry {} - record Shape(String name, Triangle triangle) implements Geometry {} + sealed interface Geometry permits Point, Triangle, Shape { + } + + record Point(int x, int y) implements Geometry { + } + + record Triangle(Point a, Point b, Point c) implements Geometry { + } + + record Shape(String name, Triangle triangle) implements Geometry { + } - class Builder{ - public int n; - public Builder build(Object o){ + class Builder { + public int n = InitDeadlockTest.START.ordinal(); + + public Builder build(Object o) { n++; return this; } } + String enumWithNull(Day day) { + return switch (day) { + case SAT, Day.SUN -> "weekend"; + case null -> ""; + case MON, TUE, WED, THU, FRI -> "weekday"; + }; + } + + String enumWithNullAndPattern(Day day) { + return switch (day) { + case SAT, Day.SUN -> "weekend"; + case null -> ""; + case MON, TUE, WED, THU, FRI -> "weekday"; + case Day d -> "not possible: " + d; + }; + } + + String numberWithNull(Float day) { + return switch (day) { + case 6f, 7f -> "weekend"; + case null -> ""; + case 1f, 2f, 3f, 4f, 5f -> "weekday"; + case Float f when f < 0 -> (day = 10f) + " "; + default -> ""; + }; + } + + String number(Float day) { + return switch (day) { + case 6f, 7f -> "weekend"; + case 1f, 2f, 3f, 4f, 5f -> "weekday"; + default -> ""; + }; + } + + String classicWithNull(Integer day) { + return switch (day) { + case 6, 7 -> "weekend"; + case null -> ""; + case 1, 2, 3, 4, 5 -> "weekday"; + default -> ""; + }; + } + String caseNull(String input) { return switch (input) { case "hello" -> "hello"; @@ -103,6 +176,24 @@ String caseNullWithDefault(String input) { }; } + String mixedCaseType(Object msg) { + String v; + switch (msg) { + case Day.MON: + v = "Monday?"; + break; + case Point(int x, int y) when x > 100: + v = "Far away at " + x + "," + y; + case null: + throw new RuntimeException(); + case String s: + v = "Message: " + s; + default: + v = "Other"; + } + return v; + } + String patternSwitch(Object obj) { return switch (obj) { case String s -> "String: " + s; @@ -111,6 +202,14 @@ String patternSwitch(Object obj) { }; } + String patternSwitchClass(Object obj) { + return switch (obj) { + case Class c when c == Integer.class -> "Int"; + case Class c -> "Class: " + c.getName(); + default -> throw new IllegalArgumentException(obj.toString()); + }; + } + String patternSwitchVariable(Object obj) { String var = "ignored part:" + switch (obj) { case String s -> var = "String: " + s; @@ -128,7 +227,9 @@ String patternSwitchWithGuard(Object obj) { case String s when s.length() < 5 -> "short"; case Integer i -> "int"; case String s when s.length() > 3 && s.startsWith("a") -> "long-a"; - case String s -> "long"; + case String[] s -> "string list"; + case String s -> s; + case int[] i -> "int list"; default -> "not a string"; }).trim(); } @@ -191,8 +292,7 @@ String nestedSwitch(Object outer, Object inner) { int inMethod(Object obj) { return new Builder().build(1).build(switch (obj) { - case null -> -1; - case String s when s.length() > 5 -> s.length()-5; + case String s when s.length() > 5 -> s.length() - 5; case String s -> s.length(); case Integer i -> i; default -> obj.hashCode(); @@ -201,13 +301,13 @@ int inMethod(Object obj) { void switchInIf(Object obj) { if (switch (obj) { - case null -> -1; - case String s when s.length() > 5 -> s.length()-5; - case String s -> s.length(); - case Integer i -> i; - default -> obj.hashCode(); + case null -> -1; + case String s when s.length() > 5 -> s.length() - 5; + case String s -> s.length(); + case Integer i -> i; + default -> obj.hashCode(); } > 0) { - System.out.println("Working"); + System.out.println("Working"); } } @@ -215,26 +315,33 @@ void switchInIf(Object obj) { void recordPatternInstanceof(Object obj) { if (obj instanceof Point(int x, int y)) { System.out.println(x + y); + } else if (obj instanceof Shape(String name, Triangle triangle) && name.equals("line")) { + System.out.println("A line cannot be a triangle!"); } } void nestedRecordPattern(Object obj) { if (obj instanceof Shape(var name, Triangle(var p, Point(int bx, int by), Point(int cx, int cy)))) { - System.out.println("Shape " + name + " has triangle with vertices: " + - "(" + p.x() + "," + p.y() + "), (" + bx + "," + by + "), (" + cx + "," + cy + ")"); + System.out.println("Shape " + name + ": " + + "(" + p.x() + "," + p.y() + "), (" + bx + "," + by + "), (" + cx + "," + cy + ")"); } } int recordPatternSwitchStatement(Geometry obj) { return switch (obj) { - case Point(int x, int y): yield x + y; + case Point(int x, int y) when x == 0 && y == 0: { + yield 0; + } + case Point(int x, int y): + yield x + y; case Triangle(Point p, Point(int bx, int by), Point(int cx, int cy)): { int sum = p.x() * p.y(); sum += bx * by; sum += cx * cy; yield sum; } - case Shape s: yield 0; + case Shape s: + yield 0; // No default case since class is sealed. }; } @@ -253,4 +360,3 @@ void unconditionalPatternObject(T str) { } } } - diff --git a/example/src/main/java/com/example/Java25FeaturesExample2.java b/example/src/main/java/com/example/Java25FeaturesExample2.java index a419644..1d8678f 100644 --- a/example/src/main/java/com/example/Java25FeaturesExample2.java +++ b/example/src/main/java/com/example/Java25FeaturesExample2.java @@ -1,6 +1,5 @@ /** * Examples of Java 25 features with manual desugaring:
- * IMPLICIT_CLASSES *
  * // Source (Java 25+):
diff --git a/example/src/main/java/com/example/Java9FeaturesExample.java b/example/src/main/java/com/example/Java9FeaturesExample.java
index 1cbf816..e563600 100644
--- a/example/src/main/java/com/example/Java9FeaturesExample.java
+++ b/example/src/main/java/com/example/Java9FeaturesExample.java
@@ -53,7 +53,7 @@ public class Java9FeaturesExample {
 
     @SafeVarargs
     private final void safeVarargsMethod(List... lists) {
-        for (var list : lists) System.out.println(list);
+        for (List list : lists) System.out.println(list);
     }
 
     List diamondWithAnonymous = new ArrayList<>() {
diff --git a/example/src/main/java/com/example/Main.java b/example/src/main/java/com/example/Main.java
index f4246e1..ce1134b 100644
--- a/example/src/main/java/com/example/Main.java
+++ b/example/src/main/java/com/example/Main.java
@@ -27,6 +27,7 @@ public static void main(String[] args) {
         var java14 = new Java14FeaturesExample();
         System.out.println("Switch expression: " + java14.switchMultipleLabels(Java14FeaturesExample.Day.SAT));
         System.out.println("Switch with yield: " + java14.switchExpressionWithYield(Java14FeaturesExample.Day.WED));
+        java14.noDefaultInjectedCase(16);
 
         System.out.println("\n--- Java 15 Features ---");
         var java15 = new Java15FeaturesExample();
@@ -36,7 +37,7 @@ public static void main(String[] args) {
         var java16 = new Java16FeaturesExample();
         java16.patternMatchingInstanceof("Hello");
         var point = new Java16FeaturesExample.Point(3, 4);
-        System.out.println("Record: " + point);
+        System.out.println("Record: " + point + " (" + point.x() + ", " + point.y() + ")");
 
         System.out.println("\n--- Java 17 Features ---");
         var java17 = new Java17FeaturesExample();
@@ -47,6 +48,7 @@ public static void main(String[] args) {
         var java21 = new Java21FeaturesExample();
         System.out.println("case null: " + java21.caseNull(null));
         System.out.println("pattern switch: " + java21.patternSwitch("test"));
+        System.out.println("pattern switch statement: " + java21.patternSwitchVariable("test"));
         System.out.println("record pattern: " + java21.recordPatternSwitchStatement(new Java21FeaturesExample.Point(5, 7)));
         System.out.println("switch yield: " + java21.switchYield("hello"));
 
diff --git a/jabel-javac-plugin/build.gradle b/jabel-javac-plugin/build.gradle
index 30de699..860c3e9 100644
--- a/jabel-javac-plugin/build.gradle
+++ b/jabel-javac-plugin/build.gradle
@@ -14,8 +14,8 @@ dependencies {
 }
 
 task sourcesJar(type: Jar) {
-    archiveClassifier = 'sources'
     from sourceSets.main.allJava
+    archiveClassifier = 'sources'
 }
 
 javadoc {
diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java
index a57e47d..826f2e1 100644
--- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java
+++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/FlexibleMainRetrofittingTaskListener.java
@@ -168,9 +168,15 @@ public boolean isMain(JCMethodDecl method) {
         if (method.params.isEmpty()) return true;
         if (method.params.size() != 1) return false;
         JCTree vartype = method.params.get(0).vartype;
-        if (!(vartype instanceof JCArrayTypeTree)) return false;
-        String elem = ((JCArrayTypeTree) vartype).elemtype.toString();
-        return elem.equals("String") || elem.endsWith(".String");
+        if(!(vartype instanceof JCArrayTypeTree)) return false;
+        // TODO find a better way?
+        switch(((JCArrayTypeTree)vartype).getType().toString()){
+            case "java.lang.String":
+            case "String":
+                return true;
+            default:
+                return false;
+        }
     }
 
     /** If {@code toLocalMain} is {@code true}, a zero-arg constructor must be present. */
diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java
deleted file mode 100644
index 44403e1..0000000
--- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/InstanceofRetrofittingTaskListener.java
+++ /dev/null
@@ -1,144 +0,0 @@
-package com.github.bsideup.jabel;
-
-import com.sun.source.util.*;
-import com.sun.tools.javac.tree.*;
-import com.sun.tools.javac.tree.JCTree.*;
-import com.sun.tools.javac.util.*;
-
-import static com.github.bsideup.jabel.RecordPatternHelper.*;
-
-
-/**
- * Transforms record patterns and binding patterns in instanceof (Java 16+/21+)
- * into Java 8 compatible bytecode.
- * 

- * Handles: - *

    - *
  • Binding patterns (if obj instanceof String s).
  • - *
  • Record patterns (if obj instanceof Point(int x, int y)).
  • - *
  • Nested record patterns (if obj instanceof Line(Point(int x1, int y1), - * Point end)).
  • - *
  • Unconditional patterns (if str instanceof CharSequence cs).
  • - *
- */ -public class InstanceofRetrofittingTaskListener implements TaskListener { - final RecordPatternHelper helper; - final TreeMaker make; - - public InstanceofRetrofittingTaskListener(Context context) { - helper = new RecordPatternHelper(context); - make = TreeMaker.instance(context); - } - - @Override - public void started(TaskEvent e) { - if (e.getKind() != TaskEvent.Kind.ENTER) return; - if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; - new InstanceofTranslator().translate((JCCompilationUnit) e.getCompilationUnit()); - } - - @Override - public void finished(TaskEvent e) { - } - - public class InstanceofTranslator extends TreeTranslator { - @Override - public T translate(T tree) { - if (tree == null) return null; - helper.collectRecord(tree); - - if (tree instanceof JCIf) { - JCIf ifStmt = (JCIf) tree; - JCExpression cond = unwrapParenthesis(ifStmt.cond); - if (!(cond instanceof JCInstanceOf)) { - return super.translate(tree); - } - - JCInstanceOf instanceOf = (JCInstanceOf) cond; - JCTree pattern = instanceOf.pattern; - if (pattern == null) { - return super.translate(tree); - } - - if (isRecordPattern(pattern)) { - transformRecordPattern(ifStmt, instanceOf, pattern); - } else if (isBindingPattern(pattern)) { - transformBindingPattern(ifStmt, instanceOf, pattern); - } - } - - return super.translate(tree); - } - } - - public JCExpression unwrapParenthesis(JCExpression expr) { - while (expr instanceof JCParens) { - expr = ((JCParens) expr).expr; - } - return expr; - } - - /** - * Transform: {@code if (obj instanceof Point(int x, int y)) { body } }
- * Into: - * {@code if (obj instanceof Point) { Point $rec = (Point)obj; int x = $rec.x(); ... body } } - */ - public void transformRecordPattern(JCIf ifStmt, JCInstanceOf instanceOf, JCTree pattern) { - JCExpression recordType = getRecordType(pattern); - if (recordType == null) return; - - make.at(ifStmt.pos); - Name tempVar = helper.tempName(); - - ListBuffer declarations = new ListBuffer<>(); - declarations.append(helper.makeVarDef( - tempVar, - recordType, - helper.makeCast(recordType, instanceOf.expr) - )); - helper.extractRecordBindings( - getRecordNested(pattern), - make.Ident(tempVar), - helper.getRecordComponentNames(pattern), - declarations - ); - - instanceOf.pattern = recordType; - ifStmt.thenpart = buildBlock(declarations.toList(), ifStmt.thenpart); - } - - /** - * Transform: {@code if (obj instanceof String s) { body } }
- * Into: {@code if (obj instanceof String) { String s = (String)obj; body } } - */ - public void transformBindingPattern(JCIf ifStmt, JCInstanceOf instanceOf, JCTree pattern) { - JCVariableDecl var = ((JCBindingPattern) pattern).var; - if (var == null || var.vartype == null) return; - - make.at(ifStmt.pos); - instanceOf.pattern = helper.copy(var.vartype); - - ListBuffer declarations = new ListBuffer<>(); - declarations.append(helper.makeVarDef( - var.name, - var.vartype, - helper.makeCast(var.vartype, instanceOf.expr) - )); - ifStmt.thenpart = buildBlock(declarations.toList(), ifStmt.thenpart); - } - - public JCBlock buildBlock(List declarations, JCStatement body) { - ListBuffer stmts = new ListBuffer<>(); - for (JCVariableDecl decl : declarations) { - stmts.append(decl); - } - if (body instanceof JCBlock) { - for (JCStatement stmt : ((JCBlock) body).stats) { - stmts.append(stmt); - } - } else if (body != null) { - stmts.append(body); - } - return make.Block(0, stmts.toList()); - } -} \ No newline at end of file diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java index 7bf1be1..c3f86ae 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/JabelCompilerPlugin.java @@ -1,72 +1,38 @@ package com.github.bsideup.jabel; -import com.sun.source.util.JavacTask; -import com.sun.source.util.Plugin; -import com.sun.tools.javac.api.BasicJavacTask; -import com.sun.tools.javac.code.Source; -import com.sun.tools.javac.util.*; +import java.util.*; -import net.bytebuddy.ByteBuddy; -import net.bytebuddy.agent.ByteBuddyAgent; -import net.bytebuddy.asm.Advice; -import net.bytebuddy.asm.AsmVisitorWrapper; -import net.bytebuddy.description.field.FieldDescription; -import net.bytebuddy.description.field.FieldList; -import net.bytebuddy.description.method.MethodList; -import net.bytebuddy.description.type.TypeDescription; -import net.bytebuddy.dynamic.ClassFileLocator; -import net.bytebuddy.dynamic.loading.ClassInjector; -import net.bytebuddy.dynamic.loading.ClassReloadingStrategy; -import net.bytebuddy.dynamic.scaffold.MethodGraph; -import net.bytebuddy.implementation.Implementation; -import net.bytebuddy.jar.asm.ClassVisitor; -import net.bytebuddy.jar.asm.MethodVisitor; -import net.bytebuddy.jar.asm.Opcodes; -import net.bytebuddy.pool.TypePool; -import net.bytebuddy.utility.JavaModule; +import com.sun.source.util.*; +import com.sun.tools.javac.api.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.util.*; -import java.util.*; +import net.bytebuddy.*; +import net.bytebuddy.agent.*; +import net.bytebuddy.asm.*; +import net.bytebuddy.description.type.*; +import net.bytebuddy.dynamic.*; +import net.bytebuddy.dynamic.loading.*; +import net.bytebuddy.dynamic.scaffold.*; +import net.bytebuddy.pool.*; +import net.bytebuddy.utility.*; import static net.bytebuddy.matcher.ElementMatchers.*; public class JabelCompilerPlugin implements Plugin { - static { - boolean c = false; + static final boolean JABEL_INITIALIZED = initJabel(); + + @SuppressWarnings("resource") + private static boolean initJabel() { + // We cannot easily force features bellow Java 10.35 try { Class.forName("com.sun.tools.javac.code.Source$Feature"); - c = true; - } catch (Exception e) {} - final boolean canPatchSources = c; - - Map visitors = new HashMap() {{ - if (canPatchSources) { - // Disable the preview feature check - AsmVisitorWrapper checkSourceLevelAdvice = Advice.to(CheckSourceLevelAdvice.class) - .on(named("checkSourceLevel").and(takesArguments(2))); - - // Allow features that were introduced together with Records (local enums, static inner members, ...) - AsmVisitorWrapper allowRecordsEraFeaturesAdvice = new FieldAccessStub("allowRecords", true); - - put("com.sun.tools.javac.parser.JavacParser", - new AsmVisitorWrapper.Compound( - checkSourceLevelAdvice, - allowRecordsEraFeaturesAdvice - ) - ); - put("com.sun.tools.javac.parser.JavaTokenizer", checkSourceLevelAdvice); - - put("com.sun.tools.javac.comp.Check", allowRecordsEraFeaturesAdvice); - put("com.sun.tools.javac.comp.Attr", allowRecordsEraFeaturesAdvice); - put("com.sun.tools.javac.comp.Resolve", allowRecordsEraFeaturesAdvice); - - // Lower the source requirement for supported features - AsmVisitorWrapper allowedInSourceAdvice = Advice.to(AllowedInSourceAdvice.class) - .on(named("allowedInSource").and(takesArguments(1))); - put("com.sun.tools.javac.code.Source$Feature", allowedInSourceAdvice); - } - }}; + } catch (Exception e) { + return false; + } + // Install ByteBuddy try { ByteBuddyAgent.install(); } catch (Exception e) { @@ -81,58 +47,65 @@ public class JabelCompilerPlugin implements Plugin { ) ); } + ByteBuddy byteBuddy = new ByteBuddy().with(MethodGraph.Compiler.ForDeclaredMethods.INSTANCE); - ByteBuddy byteBuddy = new ByteBuddy() - .with(MethodGraph.Compiler.ForDeclaredMethods.INSTANCE); + // Hook classes ClassLoader classLoader = JavacTask.class.getClassLoader(); ClassFileLocator classFileLocator = ClassFileLocator.ForClassLoader.of(classLoader); TypePool typePool = TypePool.ClassLoading.of(classLoader); + TypeDescription clazz; + + // Lower features source level + clazz = typePool.describe("com.sun.tools.javac.code.Source$Feature").resolve(); + byteBuddy.decorate(clazz, classFileLocator) + .visit(Advice.to(AllowedInSourceAdvice.class).on(named("allowedInSource").and(takesArguments(1)))) + .make() + .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); + + // Force enable preview features and suppress its warnings + clazz = typePool.describe("com.sun.tools.javac.code.Preview").resolve(); + byteBuddy.decorate(clazz, classFileLocator) + .visit(Advice.to(IsEnabledAdvice.class).on(named("isEnabled").and(takesArguments(0)))) + .visit(Advice.to(IsPreviewAdvice.class).on(named("isPreview").and(takesArguments(1)))) + .visit(Advice.to(WarnPreviewAdvice.class).on(named("warnPreview"))) + .make() + .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); + + + // Open internal compiler packages + Set jabelModule = Collections.singleton(JavaModule.ofType(JabelCompilerPlugin.class)); + ClassInjector.UsingInstrumentation.redefineModule( + ByteBuddyAgent.getInstrumentation(), + JavaModule.ofType(JavacTask.class), + Collections.emptySet(), + Collections.emptyMap(), + new HashMap>() {{ + put("com.sun.tools.javac.api", jabelModule); + put("com.sun.tools.javac.tree", jabelModule); + put("com.sun.tools.javac.code", jabelModule); + put("com.sun.tools.javac.comp", jabelModule); + put("com.sun.tools.javac.util", jabelModule); + }}, + Collections.emptySet(), + Collections.emptyMap() + ); - visitors.forEach((className, visitor) -> { - byteBuddy - .decorate( - typePool.describe(className).resolve(), - classFileLocator - ) - .visit(visitor) - .make() - .load(classLoader, ClassReloadingStrategy.fromInstalledAgent()); - }); - - try { - JavaModule jabelModule = JavaModule.ofType(JabelCompilerPlugin.class); - ClassInjector.UsingInstrumentation.redefineModule( - ByteBuddyAgent.getInstrumentation(), - JavaModule.ofType(JavacTask.class), - Collections.emptySet(), - Collections.emptyMap(), - new HashMap>() {{ - put("com.sun.tools.javac.api", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.tree", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.code", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.comp", Collections.singleton(jabelModule)); - put("com.sun.tools.javac.util", Collections.singleton(jabelModule)); - }}, - Collections.emptySet(), - Collections.emptyMap() - ); - // In case of we are running on Java 8 - } catch (NullPointerException ignored) {} + return true; } @Override public void init(JavacTask task, String... args) { + // Useless to continue if Jabel was not initialized correctly + if (!JABEL_INITIALIZED) return; + Context context = ((BasicJavacTask) task).getContext(); removeUnderscoreWarnings(context); task.addTaskListener(new RecordsRetrofittingTaskListener(context)); - task.addTaskListener(new InstanceofRetrofittingTaskListener(context)); + task.addTaskListener(new RecordPatternRetrofittingTaskListener(context)); task.addTaskListener(new SwitchRetrofittingTaskListener(context)); - try { - task.addTaskListener(new FlexibleMainRetrofittingTaskListener(context)); - // Because JCDiagnostic.Warning doesn't exists on Java 8. But we don't care at this point - } catch (NoClassDefFoundError ignored) {} + task.addTaskListener(new FlexibleMainRetrofittingTaskListener(context)); task.addTaskListener(new ImplicitClassesFixerTaskListener(context)); } @@ -142,17 +115,16 @@ public String getName() { } /** Make it auto starts on Java 14+. */ + @Override public boolean autoStart() { return true; } /** Removes warnings about {@code '_'}. */ - private static void removeUnderscoreWarnings(Context context) { - // Need to inherit a class instead. - // This is due to DeferredDiagnosticHandler(Predicate) being DeferredDiagnosticHandler(Filter) on Java 16- - Log.instance(context).new DiscardDiagnosticHandler() { + static void removeUnderscoreWarnings(Context context){ + Log.instance(context).new DiscardDiagnosticHandler(){ @Override - public void report(JCDiagnostic diag) { + public void report(JCDiagnostic diag){ String code = diag.getCode(); if (code.contains("underscore.as.identifier") || code.contains("use.of.underscore.not.allowed")) return; @@ -161,6 +133,7 @@ public void report(JCDiagnostic diag) { }; } + /** Makes all {@link Source.Feature} available in all source levels, except few ones. */ static class AllowedInSourceAdvice { @Advice.OnMethodEnter static void allowedInSource( @@ -180,47 +153,29 @@ static void allowedInSource( } } - static class CheckSourceLevelAdvice { - @Advice.OnMethodEnter - static void checkSourceLevel( - @Advice.Argument(value = 1, readOnly = false) Source.Feature feature - ) { - if (feature.allowedInSource(Source.JDK8)) { - // This must be one of the cases from "AllowedInSourceAdvice" - //noinspection UnusedAssignment - feature = Source.Feature.PRIVATE_SAFE_VARARGS; - } + /** Makes {@link Preview#isEnabled()} always return {@code true}. */ + static class IsEnabledAdvice { + @Advice.OnMethodExit + static void isEnabled(@Advice.Return(readOnly = false) boolean result) { + //noinspection UnusedAssignment + result = true; } } - private static class FieldAccessStub extends AsmVisitorWrapper.AbstractBase { - final String fieldName; - final Object value; - - public FieldAccessStub(String fieldName, Object value) { - this.fieldName = fieldName; - this.value = value; + /** Makes {@link Preview#isPreview(Feature)} always return {@code false}. */ + static class IsPreviewAdvice { + @Advice.OnMethodExit + static void isPreview(@Advice.Return(readOnly = false) boolean result) { + //noinspection UnusedAssignment + result = false; } + } - @Override - public ClassVisitor wrap(TypeDescription instrumentedType, ClassVisitor classVisitor, Implementation.Context implementationContext, TypePool typePool, FieldList fields, MethodList methods, int writerFlags, int readerFlags) { - return new ClassVisitor(Opcodes.ASM9, classVisitor) { - @Override - public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { - MethodVisitor methodVisitor = super.visitMethod(access, name, descriptor, signature, exceptions); - return new MethodVisitor(Opcodes.ASM9, methodVisitor) { - @Override - public void visitFieldInsn(int opcode, String owner, String name, String descriptor) { - if (opcode == Opcodes.GETFIELD && fieldName.equalsIgnoreCase(name)) { - super.visitInsn(Opcodes.POP); - super.visitLdcInsn(value); - } else { - super.visitFieldInsn(opcode, owner, name, descriptor); - } - } - }; - } - }; + /** Makes {@link Preview#warnPreview(DiagnosticPosition, Feature)} a no-op. */ + static class WarnPreviewAdvice { + @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) + static boolean warnPreview() { + return true; // skip method body } } } diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java deleted file mode 100644 index 35df9fb..0000000 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternHelper.java +++ /dev/null @@ -1,240 +0,0 @@ -package com.github.bsideup.jabel; - -import java.util.*; - -import com.sun.tools.javac.code.*; -import com.sun.tools.javac.tree.*; -import com.sun.tools.javac.tree.JCTree.*; -import com.sun.tools.javac.util.*; -import com.sun.tools.javac.util.List; - - -/** Utilities for record pattern extraction and tree manipulation. */ -class RecordPatternHelper { - static String getClassName(Object obj) { - return obj == null ? "" : obj.getClass().getSimpleName(); - } - - /** Set the arrow-style body of a case tree. No-op if the field doesn't exist. */ - static void setCaseBody(JCCase caseTree, JCTree body) { - try { - caseTree.body = body; - } catch (NoSuchFieldError ignored) {} - } - - static boolean isBindingPattern(JCTree pattern) { - return pattern != null && getClassName(pattern).equals("JCBindingPattern"); - } - - static boolean isRecordPattern(JCTree pattern) { - return pattern != null && getClassName(pattern).equals("JCRecordPattern"); - } - - /** Check if this tree node is a pattern (binding, record, case label, or any). */ - static boolean isPattern(JCTree label) { - if (label == null) return false; - String name = getClassName(label); - return name.contains("Pattern") || name.contains("Binding"); - } - - /** - * Because {@link JCPatternCaseLabel#pat} type is {@link JCPattern}, we need to - * delay access to an inner class.
- * Like that, the class initialization error can be catched easily. - */ - private static final class PatternCaseLabelAccess { - static JCTree pat(JCTree label) { - return ((JCPatternCaseLabel) label).pat; - } - } - - /** Get the binding variable from a pattern (binding pattern or pattern case label). */ - static JCVariableDecl getPatternVar(JCTree label) { - if (label == null) return null; - switch (getClassName(label)) { - case "JCBindingPattern": - return ((JCBindingPattern) label).var; - case "JCPatternCaseLabel": - return getPatternVar(PatternCaseLabelAccess.pat(label)); - default: - return null; - } - } - - /** - * Get the type from a pattern label (for instanceof check). - * Works for binding patterns, pattern case labels, and unnamed patterns. - */ - static JCExpression getPatternType(JCTree label) { - if (label == null) return null; - switch (getClassName(label)) { - case "JCBindingPattern": - JCVariableDecl var = ((JCBindingPattern) label).var; - return var != null ? var.vartype : null; - case "JCPatternCaseLabel": - return getPatternType(PatternCaseLabelAccess.pat(label)); - case "JCAnyPattern": - return getAnyPatternType(label); - default: - return null; - } - } - - /** - * Get the type from a {@code JCAnyPattern}. - *

- * Reflection is needed because APIs varies across JDK versions: - * {@link JCTree#type} and {@link JCTree#getType()} differ between Java 22 and 23+. - */ - static JCExpression getAnyPatternType(JCTree label) { - // TODO: remake - try { - Object type = label.getClass().getField("type").get(label); - if (type instanceof JCExpression) { - return (JCExpression) type; - } - } catch (Exception ignored) {} - try { - Object type = label.getClass().getMethod("getType").invoke(label); - if (type instanceof JCExpression) { - return (JCExpression) type; - } - } catch (Exception ignored) {} - return null; - } - - /** Get the record type (deconstructor) from a record pattern or pattern case label. */ - static JCExpression getRecordType(JCTree label) { - if (label == null) return null; - switch (getClassName(label)) { - case "JCRecordPattern": - return ((JCRecordPattern) label).deconstructor; - case "JCPatternCaseLabel": - return getRecordType(PatternCaseLabelAccess.pat(label)); - default: - return null; - } - } - - /** Get nested patterns from a record pattern or pattern case label. */ - static List getRecordNested(JCTree label) { - if (label == null) return null; - switch (getClassName(label)) { - case "JCRecordPattern": - return ((JCRecordPattern) label).nested; - case "JCPatternCaseLabel": - return getRecordNested(PatternCaseLabelAccess.pat(label)); - default: - return null; - } - } - - /// - - final TreeMaker make; - final Names names; - /** Cache of record declarations for component name lookup. */ - final Map records = new HashMap<>(); - private static int tempVarCounter = 0; - - RecordPatternHelper(Context context) { - make = TreeMaker.instance(context); - names = Names.instance(context); - } - - /** - * Collect a record declaration for later component name lookup. - * Call during tree traversal. - */ - void collectRecord(JCTree tree) { - if (!(tree instanceof JCClassDecl)) return; - JCClassDecl classDecl = (JCClassDecl) tree; - if (!"RECORD".equals(classDecl.getKind().toString())) return; - records.put(classDecl.name.toString(), classDecl); - } - - /** Get record component names from cached record declarations or fallback to binding names. */ - List getRecordComponentNames(JCTree pattern) { - JCExpression deconstructor = getRecordType(pattern); - if (deconstructor == null) return List.nil(); - - // Try to find record declaration - JCClassDecl recordDecl = records.get(deconstructor.toString()); - if (recordDecl != null) { - List result = List.nil(); - for (JCTree def : recordDecl.defs) { - if (!(def instanceof JCVariableDecl)) continue; - JCVariableDecl varDecl = (JCVariableDecl) def; - if ((varDecl.mods.flags & Flags.RECORD) == 0) continue; - result = result.append(varDecl.name); - } - if (!result.isEmpty()) return result; - } - - // Fallback: use binding pattern names - List nested = getRecordNested(pattern); - if (nested == null) return List.nil(); - - List result = List.nil(); - for (JCTree p : nested) { - JCVariableDecl var = getPatternVar(p); - result = result.append(var != null ? var.name : null); - } - return result; - } - - /** - * Recursively extract bindings from nested record pattern components. - *

- * Generates variable declarations like: - * {@code final Point $tmp = parent.b(); final int bx = $tmp.x(); final int by = $tmp.y();} - */ - void extractRecordBindings( - List nested, JCExpression baseAccessor, - List componentNames, ListBuffer out - ) { - int index = 0; - for (JCTree nestedPattern : nested) { - Name componentName = index < componentNames.size() ? componentNames.get(index++) : null; - - if (isBindingPattern(nestedPattern)) { - JCVariableDecl var = ((JCBindingPattern) nestedPattern).var; - if (var == null) continue; - - Name accessorName = componentName != null ? componentName : var.name; - out.append(makeVarDef(var.name, var.vartype, makeMethodCall(baseAccessor, accessorName))); - continue; - } - - if (!isRecordPattern(nestedPattern)) continue; - JCExpression nestedRecordType = getRecordType(nestedPattern); - List deepNested = getRecordNested(nestedPattern); - if (nestedRecordType == null || deepNested == null || componentName == null) continue; - - Name tempVar = tempName(); - JCExpression accessor = makeMethodCall(baseAccessor, componentName); - out.append(makeVarDef(tempVar, nestedRecordType, accessor)); - extractRecordBindings(deepNested, make.Ident(tempVar), getRecordComponentNames(nestedPattern), out); - } - } - - Name tempName() { - return names.fromString("$record$" + (tempVarCounter++)); - } - - JCVariableDecl makeVarDef(Name name, JCExpression type, JCExpression init) { - return make.VarDef(make.Modifiers(Flags.FINAL), name, copy(type), init); - } - - JCTypeCast makeCast(JCExpression type, JCExpression expr) { - return make.TypeCast(copy(type), copy(expr)); - } - - JCMethodInvocation makeMethodCall(JCExpression receiver, Name method) { - return make.Apply(List.nil(), make.Select(copy(receiver), method), List.nil()); - } - - JCExpression copy(JCExpression expr) { - return new TreeCopier(make).copy(expr); - } -} diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java new file mode 100644 index 0000000..d3cd568 --- /dev/null +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordPatternRetrofittingTaskListener.java @@ -0,0 +1,155 @@ +package com.github.bsideup.jabel; + +import com.sun.source.util.*; +import com.sun.tools.javac.code.*; +import com.sun.tools.javac.code.Symbol.*; +import com.sun.tools.javac.code.Type.*; +import com.sun.tools.javac.comp.Operators; +import com.sun.tools.javac.tree.*; +import com.sun.tools.javac.tree.JCTree.*; +import com.sun.tools.javac.util.*; + + +/** + * Replaces {@link MatchException} catch clauses, generated by the compiler for + * record patterns unpacking, by {@link RuntimeException}. + */ +public class RecordPatternRetrofittingTaskListener implements TaskListener { + private boolean MATCH_EXCEPTION_PRESENT = SwitchRetrofittingTaskListener.MATCH_EXCEPTION_PRESENT; + + final TreeMaker make; + final Symtab syms; + final Names names; + final MethodSymbol runtimeExceptionConstructor; + final OperatorSymbol opCONCAT; + + public RecordPatternRetrofittingTaskListener(Context context) { + make = TreeMaker.instance(context); + syms = Symtab.instance(context); + names = Names.instance(context); + + runtimeExceptionConstructor = getRuntimeExceptionConstructor(); + opCONCAT = SwitchRetrofittingTaskListener.resolveBinary( + Operators.instance(context), + make.Literal(0), + Tag.PLUS, + syms.stringType, + syms.stringType + ); + } + + @Override + public void started(TaskEvent e) { + if (!MATCH_EXCEPTION_PRESENT) return; + if (e.getKind() != TaskEvent.Kind.GENERATE) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + new MatchExceptionFixer().translate((JCCompilationUnit) e.getCompilationUnit()); + } + + @Override + public void finished(TaskEvent e) { + if (!MATCH_EXCEPTION_PRESENT) return; + if (e.getKind() != TaskEvent.Kind.ANALYZE) return; + // Creates the virtual symbol after analyze, so the user-code cannot reference it + injectMatchException(); + } + + public class MatchExceptionFixer extends TreeTranslator { + @Override + public void visitBlock(JCBlock tree) { + JCCatch c = SwitchRetrofittingTaskListener.getPatternMatchingCatchHandler(tree); + if (c != null) { + if (isMatchException(c.param.vartype)) { + c.param.vartype = make.QualIdent(syms.runtimeExceptionType.tsym) + .setType(syms.runtimeExceptionType); + if (c.param.sym != null) c.param.sym.type = syms.runtimeExceptionType; + } + JCTree last = result; + c.body = translate(c.body); + result = last; + } + super.visitBlock(tree); + } + + @Override + public void visitNewClass(JCNewClass tree) { + super.visitNewClass(tree); + if (!isMatchException(tree.clazz)) return; + + if (runtimeExceptionConstructor != null) { + tree.clazz = make.QualIdent(syms.runtimeExceptionType.tsym).setType(syms.runtimeExceptionType); + tree.type = syms.runtimeExceptionType; + tree.constructor = runtimeExceptionConstructor; + } + if (opCONCAT != null) { + tree.args.head = make.Binary( + Tag.PLUS, + make.Literal("MatchException: Record-pattern unpacking failed: "), + tree.args.head + ).setType(syms.stringType); + ((JCBinary) tree.args.head).operator = opCONCAT; + } else { + tree.args.head = make.Literal("MatchException: record-pattern unpacking failed"); + } + } + } + + public boolean isMatchException(JCTree t) { + if (t == null) return false; + // TODO find a better way? + switch (t.toString()) { + case "java.lang.MatchException": + case "MatchException": + return true; + default: + return false; + } + } + + /** + * Registers a virtual {@link MatchException} symbol to make the compiler happy. + */ + public void injectMatchException() { + ClassSymbol cs = (ClassSymbol) syms.matchExceptionType.tsym; + if (cs == null) return; + if (cs.classfile != null) { + // Do nothing if already exists in the classpath to avoid changing the user code + MATCH_EXCEPTION_PRESENT = false; + return; + } + if (cs.members_field != null && !cs.members_field.isEmpty()) return; + + // Fill the current symbol + cs.flags_field = Flags.PUBLIC; + ClassType ct = new ClassType(Type.noType, List.nil(), cs); + ct.supertype_field = syms.runtimeExceptionType; + cs.type = ct; + cs.members_field = Scope.WriteableScope.create(cs); + + // Only create the constructor used by TransPatterns + cs.members_field.enter(new MethodSymbol( + Flags.PUBLIC, + names.init, + new MethodType( + List.of(syms.stringType, syms.throwableType), + syms.voidType, + List.nil(), + syms.methodClass + ), + cs + )); + cs.completer = Completer.NULL_COMPLETER; + } + + public MethodSymbol getRuntimeExceptionConstructor() { + for (Symbol sym : syms.runtimeExceptionType.tsym.members().getSymbols()) { + if (sym.kind != Kinds.Kind.MTH) continue; + if (sym.name != names.init) continue; + MethodSymbol m = (MethodSymbol) sym; + if (m.params().size() != 2) continue; + // TODO: check for actual types? + return m; + } + return null; + } +} diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java index 1fa0e9f..0fcd6b2 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/RecordsRetrofittingTaskListener.java @@ -156,7 +156,7 @@ public Stream getRecordComponents(JCClassDecl classDecl) { .filter(it -> !it.getModifiers().getFlags().contains(Modifier.STATIC)); } - public List generateToString(JCClassDecl classDecl) { + public List generateToString(JCClassDecl classDecl){ JCExpression stringBuilder = make.NewClass( null, null, @@ -165,43 +165,23 @@ public List generateToString(JCClassDecl classDecl) { null ); - for ( - Iterator iterator = getRecordComponents(classDecl).iterator(); - iterator.hasNext(); - ) { + for(Iterator iterator = getRecordComponents(classDecl).iterator(); iterator.hasNext();){ JCVariableDecl fieldDecl = iterator.next(); Name fieldName = fieldDecl.name; - stringBuilder = make.App( - make.Select(stringBuilder, names.append) - .setType(syms.stringBuilderType), - List.of(make.Literal(fieldName + "=")) - ); - - stringBuilder = make.App( - make.Select(stringBuilder, names.append) - .setType(syms.stringBuilderType), - List.of(make.Select(make.This(Type.noType), fieldName)) - ); - - if (!iterator.hasNext()) break; - stringBuilder = make.App( - make.Select(stringBuilder, names.append) - .setType(syms.stringBuilderType), - List.of(make.Literal(", ")) - ); + stringBuilder = stringAppend(stringBuilder, make.Literal(fieldName + "=")); + stringBuilder = stringAppend(stringBuilder, make.Select(make.This(Type.noType), fieldName)); + if(iterator.hasNext()){ + stringBuilder = stringAppend(stringBuilder, make.Literal(", ")); + } } + stringBuilder = stringAppend(stringBuilder, make.Literal("]")); - stringBuilder = make.App( - make.Select(stringBuilder, names.append) - .setType(syms.stringBuilderType), - List.of(make.Literal("]")) - ); + return List.of(make.Return(make.App(make.Select(stringBuilder, names.toString).setType(syms.stringType)))); + } - return List.of(make.Return(make.App( - make.Select(stringBuilder, names.toString) - .setType(syms.stringType) - ))); + private JCMethodInvocation stringAppend(JCExpression builder, JCExpression arg) { + return make.App(make.Select(builder, names.append).setType(syms.stringBuilderType), List.of(arg)); } public List generateEquals(JCClassDecl classDecl, Name otherName) { @@ -233,14 +213,8 @@ public List generateEquals(JCClassDecl classDecl, Name otherName) { statements.add(make.If( make.Binary( Tag.EQ, - make.App(make.Select( - make.Ident(otherName), - names.getClass - ).setType(syms.classType)), - make.App(make.Select( - make.This(Type.noType), - names.getClass - ).setType(syms.classType)) + make.App(make.Select(make.Ident(otherName), names.getClass).setType(syms.classType)), + make.App(make.Select(make.This(Type.noType), names.getClass).setType(syms.classType)) ), make.Block(0, List.nil()), make.Return(make.Literal(false)) @@ -249,7 +223,7 @@ public List generateEquals(JCClassDecl classDecl, Name otherName) { // Create casted variable: ClassName other = (ClassName)o; Name thatName = names.fromString("other"); statements.add(make.VarDef( - make.Modifiers(0L), + make.Modifiers(0), thatName, make.Ident(classDecl.name), make.TypeCast(make.Ident(classDecl.name), make.Ident(otherName)) @@ -303,6 +277,7 @@ public List generateHashCode(JCClassDecl classDecl) { JCExpression myFieldAccess = make.Select(make.This(Type.noType), fieldDecl.name); if (fType instanceof JCPrimitiveTypeTree) { + //TODO simplify that? switch (((JCPrimitiveTypeTree) fType).getPrimitiveTypeKind()) { case BOOLEAN: /* this.fieldName ? 1 : 0 */ @@ -385,7 +360,7 @@ public List generateHashCode(JCClassDecl classDecl) { ListBuffer statements = new ListBuffer<>(); Name resultName = names.fromString("result"); statements.append(make.VarDef( - make.Modifiers(0L), + make.Modifiers(0), resultName, make.TypeIdent(syms.intType.getTag()), make.Literal(0) diff --git a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java index 18c60a0..1ee71d0 100644 --- a/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java +++ b/jabel-javac-plugin/src/main/java/com/github/bsideup/jabel/SwitchRetrofittingTaskListener.java @@ -1,49 +1,63 @@ package com.github.bsideup.jabel; -import java.lang.reflect.*; +import java.lang.reflect.Method; import java.util.*; import com.sun.source.tree.*; import com.sun.source.util.*; import com.sun.source.util.TreeScanner; import com.sun.tools.javac.code.*; +import com.sun.tools.javac.code.Symbol.*; +import com.sun.tools.javac.comp.*; import com.sun.tools.javac.tree.*; import com.sun.tools.javac.tree.JCTree.*; import com.sun.tools.javac.util.*; +import com.sun.tools.javac.util.JCDiagnostic.*; import com.sun.tools.javac.util.List; -import static com.github.bsideup.jabel.RecordPatternHelper.*; - -//TODO: the class need some reworking since JCSwitchExpression doesn't exists in Java 12- /** * Transforms modern switch constructs (Java 17+) into Java 8 compatible code. + *


+ * This listener will rewrite entirely the switch for a better one. + *

+ * Instead of placing the guard in the case body, it will be evaluated in the switch condition.
+ * Avoiding restarting the switch everytimes a guard fail, + * and storing an index to not reevaluate previous cases. + *

+ * So this method will convert cases labels into a ternary-chain + * that will be placed in the switch condition.
+ * And case labels will be changed to match their position in the switch. *

- * This will extract {@code null} case to a {@code if(sel==null)} statement. - *
- * And will move (record) pattern/guard switch labels, to the switch condition. - * Letting the compiler optimize it. + * Also, to avoid double method calls, if the switch condition was a method call, + * an uninitialized variable is inserted in the previous line, + * and then initialized in the ternary-chain.
+ * Same way goes for cases using a record-pattern with a guard using a component from that record. */ public class SwitchRetrofittingTaskListener implements TaskListener { // region compiler compatibility - // Because we're compiling with Java25, the old method (without guards) doens't exists. + // Because we're compiling with JDK 25, the old method (without guards) doens't exists. private static Method LEGACY_MAKE_CASE; - // Some internal states to avoid getting errors everytimes we trying to use a - // not found feature. - private static Boolean GUARDS, LABELS, BODIES, DEFAULT_CASES, CONSTANT_CASES; + // Some internal states to avoid getting errors everytimes we trying to use a found feature. + // true means that the feature is not present, by default assuming it is. + private static boolean GUARD, LABELS, BODY, DEFAULT_CASE, CONSTANT_CASE, PATTERN_MATCHING_CATCH, SWITCH_PATTERN; + /* private */ static boolean MATCH_EXCEPTION_PRESENT; + + static { + try { + Symtab.class.getDeclaredField("matchExceptionType"); + MATCH_EXCEPTION_PRESENT = true; + } catch (Exception ignored) {} + } /** Get the guard expression from a case, or null if guards are unsupported. */ private static JCExpression getGuard(JCCase caseTree) { - if (GUARDS != null) { - return GUARDS ? caseTree.getGuard() : null; - } + if (GUARD) return null; try { - JCExpression guard = caseTree.getGuard(); - GUARDS = true; - return guard; + return caseTree.getGuard(); } catch (NoSuchMethodError ignored) { - GUARDS = false; + GUARD = true; return null; } } @@ -51,17 +65,13 @@ private static JCExpression getGuard(JCCase caseTree) { /** Get the labels from a case, handling both old and new compiler APIs. */ @SuppressWarnings("unchecked") private static List getLabels(JCCase caseTree) { - if (LABELS == null) { + if (!LABELS) { try { - caseTree.getLabels(); - LABELS = true; + return (List) (List) caseTree.getLabels(); } catch (NoSuchMethodError ignored) { - LABELS = false; + LABELS = true; } } - if (LABELS) { - return (List) (List) caseTree.getLabels(); - } List labels = caseTree.getExpressions(); if (labels == null) return List.nil(); @@ -75,20 +85,17 @@ private static List getLabels(JCCase caseTree) { /** Get the arrow-style body of a case, or null if unsupported. */ private static JCTree getBody(JCCase caseTree) { - if (BODIES != null) { - return BODIES ? caseTree.getBody() : null; - } + if (BODY) return null; try { - JCTree body = caseTree.getBody(); - BODIES = true; - return body; + return caseTree.getBody(); } catch (NoSuchMethodError ignored) { - BODIES = false; + BODY = true; return null; } } - // TODO: CaseTree.CaseKind doesn't exists in Java 12- + // TODO: CaseTree.CaseKind doesn't exists on Java 12- + // EDIT: Seems to not be a real problem... /** * Create a new case tree, handling different JDK signatures.
* JDK 21+: Case(CaseKind, List labels, JCExpression guard, List stats, JCTree body)
@@ -96,35 +103,33 @@ private static JCTree getBody(JCCase caseTree) { */ @SuppressWarnings("unchecked") private JCCase makeCase( - CaseTree.CaseKind kind, List labels, JCExpression guard, - List stats, JCTree body + CaseTree.CaseKind kind, List labels, List stats, JCTree body ) { - if (GUARDS != false) { + // A default case is one that have no labels on JDK < 17 + if (!labels.isEmpty() && labels.head == null) labels = List.nil(); + + if (!GUARD) { try { - JCCase c = make.Case(kind, (List) labels, guard, stats, body); - GUARDS = true; - return c; - } catch (NoSuchMethodError ignored) {} + return make.Case(kind, (List) labels, null, stats, body); + } catch (NoSuchMethodError ignored) { + GUARD = true; + } } try { - GUARDS = false; if (LEGACY_MAKE_CASE == null) { LEGACY_MAKE_CASE = TreeMaker.class.getMethod( - "Case", - CaseTree.CaseKind.class, - List.class, - List.class, - JCTree.class + "Case", CaseTree.CaseKind.class, List.class, List.class, JCTree.class ); } return (JCCase) LEGACY_MAKE_CASE.invoke(make, kind, labels, stats, body); - } catch (Exception ignored) {} // Hope this never happen... + } catch (Exception ignored) {// Should never fail, hope this never happen... + System.err.println("[jabel] Failed to make case with labels: " + labels); + } return null; } - // Because the return type of these two methods may not exist, we need delay the - // call to an inner class. + // Because the return type of these two methods may not exist, we need delay the call to an inner class. // Like that, the class initialization error can be catched easily. private static final class DefaultCaseLabelFactory { static JCTree make(TreeMaker m) { @@ -140,35 +145,94 @@ static JCTree make(TreeMaker m, JCExpression lit) { /** Create a default case label. Returns null if unsupported (JDK < 17). */ private JCTree makeDefaultCaseLabel() { - if (DEFAULT_CASES != null) { - return DEFAULT_CASES ? DefaultCaseLabelFactory.make(make) : null; - } + if (DEFAULT_CASE) return null; try { - JCTree c = DefaultCaseLabelFactory.make(make); - DEFAULT_CASES = true; - return c; + return DefaultCaseLabelFactory.make(make); } catch (NoSuchMethodError | NoClassDefFoundError ignored) { - DEFAULT_CASES = false; + DEFAULT_CASE = true; return null; } } /** Creates a case label for an {@code Integer}, handling JDK 17-20 and 21+ APIs. */ private JCTree makeLabel(int i) { - if (CONSTANT_CASES != null) { - return CONSTANT_CASES ? ConstantCaseLabelFactory.make(make, make.Literal(i)) : make.Literal(i); - } + JCLiteral lit = make.Literal(i); + if (CONSTANT_CASE) return lit; try { - JCTree tree = ConstantCaseLabelFactory.make(make, make.Literal(i)); // JDK 21+ - CONSTANT_CASES = true; - return tree; + return ConstantCaseLabelFactory.make(make, lit); // JDK 21+ } catch (NoSuchMethodError | NoClassDefFoundError ignored) { - CONSTANT_CASES = false; - return make.Literal(i); // JDK 17-20 + CONSTANT_CASE = true; + return lit; // JDK 17-20 + } + } + + private static final class PatternMatchingCatchAccess { + static void attach(JCBlock block, JCCatch handler, Set calls) { + block.patternMatchingCatch = new JCBlock.PatternMatchingCatch(handler, calls); + } + + static JCCatch handler(JCBlock block) { + if (block.patternMatchingCatch == null) return null; + return block.patternMatchingCatch.handler(); + } + } + + /* private */ static JCCatch getPatternMatchingCatchHandler(JCBlock block) { + if (PATTERN_MATCHING_CATCH) return null; + try { + return PatternMatchingCatchAccess.handler(block); + } catch (NoSuchFieldError | NoClassDefFoundError ignored) { + PATTERN_MATCHING_CATCH = true; + return null; + } + } + + private static void attachPatternMatchingCatch(JCBlock block, JCCatch body, Set calls) { + if (PATTERN_MATCHING_CATCH) return; + try { + PatternMatchingCatchAccess.attach(block, body, calls); + } catch (NoSuchFieldError | NoClassDefFoundError ignored) { + PATTERN_MATCHING_CATCH = true; + } + } + + /** Sets {@code patternSwitch = false} on a switch, silently ignoring JDK < 17. */ + private static void clearPatternSwitchStatus(JCTree tree) { + if (SWITCH_PATTERN) return; + try { + switch (getClassName(tree)) { + case "JCSwitch": + ((JCSwitch) tree).patternSwitch = false; + break; + case "JCSwitchExpression": + ((JCSwitchExpression) tree).patternSwitch = false; + break; + } + } catch (NoSuchFieldError | NoClassDefFoundError ignored) { + SWITCH_PATTERN = true; } } - private static boolean isSwitchExpression(Tree tree) { + private static String getClassName(Object obj) { + return obj == null ? "" : obj.getClass().getSimpleName(); + } + + private static boolean isBindingPattern(JCTree pattern) { + return pattern != null && getClassName(pattern).equals("JCBindingPattern"); + } + + private static boolean isRecordPattern(JCTree pattern) { + return pattern != null && getClassName(pattern).equals("JCRecordPattern"); + } + + /** Check if this tree node is a pattern (binding, record, case label, or any). */ + private static boolean isPattern(JCTree label) { + if (label == null) return false; + String name = getClassName(label); + return name.contains("Pattern") || name.contains("Binding"); + } + + private static boolean isSwitchExpression(JCTree tree) { return tree != null && getClassName(tree).equals("JCSwitchExpression"); } @@ -181,321 +245,828 @@ private static boolean isConstant(JCTree label) { } private static boolean isNull(JCTree label) { - if (label instanceof JCLiteral) { - return ((JCLiteral) label).typetag == TypeTag.BOT; - } + if (label instanceof JCLiteral) return ((JCLiteral) label).typetag == TypeTag.BOT; if (!isConstant(label)) return false; JCExpression expr = ((JCConstantCaseLabel) label).expr; return expr instanceof JCLiteral && ((JCLiteral) expr).typetag == TypeTag.BOT; } - private static boolean hasDefault(JCCase caseTree) { - List labels = getLabels(caseTree); - if (labels.isEmpty()) return true; - for (JCTree label : labels) { - if (isDefault(label)) return true; - } - return false; + private boolean isComplex(JCExpression e) { + if (e instanceof JCIdent || e instanceof JCLiteral) return false; + if (e instanceof JCFieldAccess) return isComplex(((JCFieldAccess) e).selected); + if (e instanceof JCParens) return isComplex(((JCParens) e).expr); + return true; + } + + /** @return {@code true} if {@code expr} is structurally an enum constant reference. */ + private static boolean isEnumConstant(JCExpression expr) { + return (expr instanceof JCIdent || expr instanceof JCFieldAccess) + && !(expr instanceof JCLiteral); + } + + private JCExpression getExpression(JCTree body) { + if (body instanceof JCExpressionStatement) return ((JCExpressionStatement) body).expr; + if (body instanceof JCExpression) return (JCExpression) body; + return null; } private static JCExpression getLabelExpression(JCTree label) { - if (label instanceof JCExpression) { - return (JCExpression) label; + if (label instanceof JCExpression) return (JCExpression) label; + if (isConstant(label)) return ((JCConstantCaseLabel) label).expr; + return null; + } + + /** + * Because {@link JCPatternCaseLabel#pat} type is {@link JCPattern}, + * we need to delay access to an inner class.
+ * Like that, the class initialization error can be catched easily. + */ + private static final class PatternCaseLabelAccess { + static JCTree pat(JCTree label) { + return ((JCPatternCaseLabel) label).pat; + } + } + + /** Get the binding variable from a pattern (binding pattern or pattern case label). */ + private static JCVariableDecl getPatternVar(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCBindingPattern": + return ((JCBindingPattern) label).var; + case "JCPatternCaseLabel": + return getPatternVar(PatternCaseLabelAccess.pat(label)); + default: + return null; } - if (isConstant(label)) { - return ((JCConstantCaseLabel) label).expr; + } + + /** + * Get the type from a pattern label (for instanceof check).
+ * Works for binding patterns, pattern case labels, and unnamed patterns. + */ + private static JCExpression getPatternType(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCBindingPattern": + JCVariableDecl var = ((JCBindingPattern) label).var; + return var != null ? var.vartype : null; + case "JCPatternCaseLabel": + return getPatternType(PatternCaseLabelAccess.pat(label)); + // case "JCAnyPattern": + // return getAnyPatternType(label); + default: + return null; + } + } + + /** Get the record type from a record pattern or pattern case label. */ + private static JCExpression getRecordType(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCRecordPattern": + return ((JCRecordPattern) label).deconstructor; + case "JCPatternCaseLabel": + return getRecordType(PatternCaseLabelAccess.pat(label)); + default: + return null; + } + } + + /** Get nested patterns from a record pattern or pattern case label. */ + private static List getRecordNested(JCTree label) { + if (label == null) return null; + switch (getClassName(label)) { + case "JCRecordPattern": + return ((JCRecordPattern) label).nested; + case "JCPatternCaseLabel": + return getRecordNested(PatternCaseLabelAccess.pat(label)); + default: + return null; } - return null; } - private static boolean hasPatterns(List cases) { + private static boolean hasDefault(List cases) { + if (cases == null) return true; + for (JCCase c : cases) { + if (c == null) continue; + List labels = getLabels(c); + if (labels.isEmpty()) return true; + for (JCTree label : labels) { + if (isDefault(label)) return true; + } + } + return false; + } + + private static boolean hasRecordPatterns(List cases) { + if (cases == null) return false; + for (JCCase c : cases) { + if (c == null) continue; + for (JCTree l : getLabels(c)) { + if (getRecordType(l) != null) return true; + } + } + return false; + } + + private static boolean hasGuardRecordPatterns(List cases) { if (cases == null) return false; - for (JCCase caseTree : cases) { - if (caseTree == null) continue; - if (getGuard(caseTree) != null) return true; - for (JCTree label : getLabels(caseTree)) { - if (isPattern(label)) return true; + for (JCCase c : cases) { + if (c == null || getGuard(c) == null) continue; + for (JCTree l : getLabels(c)) { + if (getRecordType(l) != null) return true; } } return false; } - private static boolean hasNull(List cases) { + /** @return whether case labels contains things not handled by a standard switch. */ + private static boolean needsTransform(List cases) { if (cases == null) return false; - for (JCCase caseTree : cases) { - if (caseTree == null) continue; - for (JCTree label : getLabels(caseTree)) { - if (isNull(label)) return true; + for (JCCase c : cases) { + if (c == null) continue; + if (getGuard(c) != null) return true; + for (JCTree label : getLabels(c)) { + if (isPattern(label) || isNull(label)) return true; + JCExpression expr = getLabelExpression(label); + if (!(expr instanceof JCLiteral)) continue; + switch (((JCLiteral) expr).typetag) { + case FLOAT: + case DOUBLE: + case LONG: + return true; + default: + } } } return false; } + /** + * @return whether all cases are just enum constants with a {@code null} + * and/or an unconditional pattern. + */ + private static boolean isEnumSwitch(List cases) { + if (cases == null) return false; + boolean hasEnumConstant = false, hasUnguardedPattern = false; + for (JCCase c : cases) { + if (c == null) continue; + if (getGuard(c) != null) return false; + for (JCTree label : getLabels(c)) { + if (isDefault(label) || isNull(label)) continue; + if (isPattern(label)) { + if (hasUnguardedPattern) return false; + hasUnguardedPattern = true; + continue; + } + JCExpression expr = getLabelExpression(label); + if (expr == null || !isEnumConstant(expr)) return false; + hasEnumConstant = true; + } + } + return hasEnumConstant; + } + + // Forced to use reflection here, methods are package-private. + static Method resolveBinary; + static { + try { + resolveBinary = Operators.class.getDeclaredMethod( + "resolveBinary", DiagnosticPosition.class, Tag.class, Type.class, Type.class + ); + resolveBinary.setAccessible(true); + } catch (Exception ignored) {} + } + + /* private */ static OperatorSymbol resolveBinary( + Operators ops, DiagnosticPosition pos, Tag tag, Type op1, Type op2 + ) { + if (resolveBinary != null) { + try { + return (OperatorSymbol) resolveBinary.invoke(ops, pos, tag, op1, op2); + } catch (Exception e) { + //TODO use Log instead? but this should never fail... + System.err.println("[jabel] Failed to find operator " + tag + " with " + op1 + "," + op2); + } + } + return null; + } + // end region // region task listener - final RecordPatternHelper helper; final TreeMaker make; final Symtab syms; final Names names; - private int tempVarCounter = 0; + final Types types; + final Operators ops; + final Attr attr; + final SymTreeCopier copier; + + private int tempVarCounter; + private Set pendingAccessorCalls = null; + private Map> guardPreVars = null; + private MethodSymbol currentMethodSym; + private boolean cacheResolved, useMatchException; + private MethodSymbol object_equals, object_toString, object_getClass, enum_ordinal, class_getName, + float_floatToIntBits, double_doubleToLongBits, icce_ctor, me_ctor; public SwitchRetrofittingTaskListener(Context context) { - helper = new RecordPatternHelper(context); make = TreeMaker.instance(context); syms = Symtab.instance(context); names = Names.instance(context); + types = Types.instance(context); + ops = Operators.instance(context); + attr = Attr.instance(context); + copier = new SymTreeCopier(make); } @Override public void started(TaskEvent e) { - if (e.getKind() != TaskEvent.Kind.ENTER) return; - if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; - new SwitchTranslator().translate((JCCompilationUnit) e.getCompilationUnit()); } @Override public void finished(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ANALYZE) return; + if (!(e.getCompilationUnit() instanceof JCCompilationUnit)) return; + new SwitchTranslator().translate((JCCompilationUnit) e.getCompilationUnit()); } - public class SwitchTranslator extends TreeTranslator { - private final Map captures = new HashMap<>(); + /** Captures the original selector of a switch expression before replacement by a temp var. */ + private final Map captures = new HashMap<>(); + private Set blockAccessorCalls = null; + private ClassSymbol currentClass; + // Symbol tracking @Override - public void visitSwitch(JCSwitch tree) { - super.visitSwitch(tree); - if (!needsTransform(tree.cases)) { - tree.cases = injectDefault(tree.cases); - return; + public void visitClassDef(JCClassDecl tree) { + ClassSymbol prevClass = currentClass; + MethodSymbol prevMethod = currentMethodSym; + try { + currentClass = tree.sym; + currentMethodSym = null; + super.visitClassDef(tree); + } finally { + currentClass = prevClass; + currentMethodSym = prevMethod; } + } - make.at(tree.pos); - if (isComplex(tree.selector)) { - Name sv = names.fromString("$switch$" + (tempVarCounter++)); - result = make.Block(0, List.of( - makeFinalVar(sv, make.Type(syms.objectType), tree.selector), - transformSwitch(make.Ident(sv), tree.cases, false, null) - )); - } else { - result = transformSwitch(tree.selector, tree.cases, false, null); + @Override + public void visitMethodDef(JCMethodDecl tree) { + MethodSymbol prev = currentMethodSym; + try { + currentMethodSym = tree.sym; + super.visitMethodDef(tree); + } finally { + currentMethodSym = prev; } } @Override public void visitBlock(JCBlock tree) { - ListBuffer buffer = null; - for (JCStatement stmt : tree.stats) { - List found = findPatternSwitches(stmt); - - for (JCSwitchExpression sw : found) { - if (!isComplex(sw.selector)) continue; - if (buffer == null) { - ListBuffer buf = new ListBuffer<>(); - for (JCStatement s : tree.stats) { - if (s == stmt) break; - buf.append(s); - } - buffer = buf; - } + MethodSymbol prev = currentMethodSym; + // Save and reset blockAccessorCalls for this block scope + Set prevCalls = blockAccessorCalls; + blockAccessorCalls = null; + try { + if (currentMethodSym == null && currentClass != null) { + currentMethodSym = new MethodSymbol( + tree.flags | Flags.BLOCK, + names.empty, + null, + currentClass + ); + } + ListBuffer buffer = null; + for (JCStatement stmt : tree.stats) { + buffer = prepareExprSwitchPrefix(stmt, buffer, tree.stats); + if (buffer != null) buffer.append(stmt); + } + if (buffer != null) tree.stats = buffer.toList(); + super.visitBlock(tree); + + // Attach patternMatchingCatch to this block if accessor calls were collected + if (blockAccessorCalls != null) { + pendingAccessorCalls = blockAccessorCalls; + blockAccessorCalls = null; + attachPatternMatchingCatch(tree); + } + } finally { + // Propagate remaining calls to the parent scope + if (blockAccessorCalls != null) { + if (prevCalls == null) prevCalls = blockAccessorCalls; + else prevCalls.addAll(blockAccessorCalls); + } + blockAccessorCalls = prevCalls; + currentMethodSym = prev; + } + } + + @Override + public void visitVarDef(JCVariableDecl tree) { + MethodSymbol prev = currentMethodSym; + try { + if (currentMethodSym == null && currentClass != null) { + currentMethodSym = new MethodSymbol( + (tree.mods.flags & Flags.STATIC) | Flags.BLOCK, + names.empty, + null, + currentClass + ); + } + super.visitVarDef(tree); + } finally { + currentMethodSym = prev; + } + } + + @Override + public void visitSwitch(JCSwitch tree) { + super.visitSwitch(tree); + if (!needsTransform(tree.cases)) return; + clearPatternSwitchStatus(tree); // TransPatterns must not re-process our output + + make.at(tree.pos); + ListBuffer prefix = new ListBuffer<>(); + JCExpression sel = captureSelector(tree.selector, prefix, true); + buildGuardPreDecls(tree.cases, prefix); + JCSwitch ns = transformSwitch(tree, sel, tree.cases, false, null); + tree.selector = ns.selector; + tree.cases = ns.cases; + + collectAccessorCalls(); + if (prefix.isEmpty()) return; + prefix.append(tree); + result = make.Block(0, prefix.toList()); + } - Name sv = names.fromString("$switch$" + (tempVarCounter++)); - buffer.append(make.VarDef(make.Modifiers(0), sv, make.Type(syms.objectType), null)); - captures.put(sw, sw.selector); - sw.selector = make.Ident(sv); + private ListBuffer prepareExprSwitchPrefix( + JCStatement stmt, ListBuffer buffer, List allStats + ) { + List found = findPatternSwitches(stmt); + if (found.isEmpty()) return buffer; + + boolean needsBuffer = false; + for (JCTree tree : found) { + JCSwitchExpression sw = (JCSwitchExpression) tree; + if (isComplex(sw.selector) || hasGuardRecordPatterns(sw.cases)) { + needsBuffer = true; + break; } - if (buffer != null) buffer.append(stmt); } + if (!needsBuffer) return buffer; - if (buffer != null) { - tree.stats = buffer.toList(); + if (buffer == null) { + ListBuffer buf = new ListBuffer<>(); + for (JCStatement s : allStats) { + if (s == stmt) break; + buf.append(s); + } + buffer = buf; + } + + for (JCTree tree : found) { + JCSwitchExpression sw = (JCSwitchExpression) tree; + JCExpression sel = captureSelector(sw.selector, buffer, false); + if (sel != sw.selector) { + captures.put(tree, sw.selector); + sw.selector = sel; + } + if (needsTransform(sw.cases)) buildGuardPreDecls(sw.cases, buffer); } - super.visitBlock(tree); + return buffer; } @Override public T translate(T tree) { if (tree == null) return null; - helper.collectRecord(tree); - if (!isSwitchExpression(tree)) { - return super.translate(tree); - } + if (!isSwitchExpression(tree)) return super.translate(tree); + clearPatternSwitchStatus(tree); JCSwitchExpression sw = (JCSwitchExpression) tree; make.at(sw.pos); - JCExpression rawSel = captures.remove(sw); // consume capture if registered + JCExpression rawSel = captures.remove(sw); sw.selector = translate(sw.selector); sw.cases = translate(sw.cases); if (needsTransform(sw.cases)) { - JCSwitch ns = transformSwitch(sw.selector, sw.cases, true, rawSel); + JCSwitch ns = transformSwitch(tree, sw.selector, sw.cases, true, rawSel); sw.selector = ns.selector; sw.cases = ns.cases; + collectAccessorCalls(); } else { - sw.cases = injectDefault(sw.cases); + sw.cases = injectDefault(sw.selector, sw.cases); } return tree; } - /** TreeTranslator skips the arrow body field, translate it manually. */ + /** TreeTranslator skips the arrow body field. */ @Override public void visitCase(JCCase tree) { super.visitCase(tree); JCTree body = getBody(tree); if (body == null) return; JCTree saved = result; - setCaseBody(tree, translate(body)); + tree.body = translate(body); result = saved; } - /** Recursively collects all JCSwitchExpression nodes inside a statement. */ - private List findPatternSwitches(JCTree node) { - ListBuffer out = new ListBuffer<>(); - new TreeScanner() { - @Override - public Void scan(Tree t, Void v) { - if (t == null) return null; - if (isSwitchExpression(t)) { - JCSwitchExpression se = (JCSwitchExpression) t; - if (needsTransform(se.cases)) out.append(se); - // Don't recurse into the switch itself, nested ones handled separately. - return null; - } - return super.scan(t, v); + /** Moves pendingAccessorCalls into blockAccessorCalls. */ + private void collectAccessorCalls() { + if (pendingAccessorCalls == null) return; + if (blockAccessorCalls == null) blockAccessorCalls = new HashSet<>(); + blockAccessorCalls.addAll(pendingAccessorCalls); + pendingAccessorCalls = null; + } + } + + @SuppressWarnings("unchecked") + public List findPatternSwitches(JCTree node) { + ListBuffer out = new ListBuffer<>(); + new TreeScanner() { + @Override + public Void scan(Tree t, Void v) { + if (t == null) return null; + if (!isSwitchExpression((JCTree) t)) return super.scan(t, v); + JCSwitchExpression se = (JCSwitchExpression) t; + if (needsTransform(se.cases)) out.append(se); + return null; + } + }.scan(node, null); + return (List) (List) out.toList(); + } + + /** Builds the standard switch from a pattern/null switch. */ + public JCSwitch transformSwitch( + JCTree currentSwitch, JCExpression sel, List cases, boolean expression, + JCExpression rawSel + ) { + pendingAccessorCalls = hasRecordPatterns(cases) ? new HashSet<>() : null; + + boolean isEnum = isEnumSwitch(cases), hasNull = false; + List labels = List.nil(); + List slotCases = List.nil(); + + // Analyze case and labels + for (JCCase c : cases) { + JCTree last = null; + for (JCTree label : getLabels(c)) { + if (isNull(label)) { + hasNull = true; + continue; } - }.scan(node, null); - return out.toList(); + if (isDefault(label) || (isEnum && isPattern(label))) continue; + if (last != null) { + labels = labels.append(last); + slotCases = slotCases.append(null); + } + last = label; + } + if (last != null) { + labels = labels.append(last); + slotCases = slotCases.append(c); + } } + + return make.Switch( + isEnum + ? buildEnumSwitchCondition(sel, hasNull, rawSel) + : buildTypeSwitchCondition(sel, hasNull, labels, slotCases, rawSel), + prepareSwichCases(currentSwitch, sel, expression, isEnum, cases, labels) + ); } - /** - * Replaces a pattern/null switch to a standard switch using a ternary-chain - * dispatcher in the condition. - *

- * The ternary chain maps each case condition to an index which becomes the new - * selector.
- * Each case label is rewritten with this index, plus the record pattern, if - * any, is lowered to the case.
- * - * @param sel the selector expression - * @param cases already-translated case list - * @param expression whether to build a switch expression or statement - * @param rawSel original selector to inject as a capture, or {@code null} - */ - public JCSwitch transformSwitch(JCExpression sel, List cases, boolean expression, JCExpression rawSel) { - java.util.List nonDefs = new ArrayList<>(); - JCCase defCase = null; + public List prepareSwichCases( + JCTree currentSwitch, JCExpression sel, boolean isExpression, boolean isEnum, + List cases, List labels + ) { + ListBuffer newCases = new ListBuffer<>(); + JCCase nc; + JCStatement body; + boolean defaultEmitted = false; + int i = 0; + for (JCCase c : cases) { - if (hasDefault(c)) defCase = c; - else nonDefs.add(c); + JCTree lastNormal = null; + boolean seenNull = false, seenDefault = false; + List caseLabels = getLabels(c); + + // Scan for "null" and "default" labels + for (JCTree label : caseLabels) { + if (isNull(label)) seenNull = true; + else if (isDefault(label) || (isEnum && isPattern(label))) seenDefault = true; + else lastNormal = label; + } + + // Emit fall-through slots + for (JCTree label : caseLabels) { + if (isNull(label) || isDefault(label) || (isEnum && isPattern(label))) continue; + if (label == lastNormal) break; + int lv = isEnum ? getEnumOrdinal(getLabelExpression(label)) : i++; + if (lv < 0) continue; + nc = makeStatementCase(makeLabel(lv), List.nil()); + if (nc != null) newCases.append(nc); + } + + // Emit the last normal label + if (lastNormal != null) { + int lv = isEnum ? getEnumOrdinal(getLabelExpression(lastNormal)) : i++; + if (lv >= 0) { + body = make.Block(0, buildCaseBody(currentSwitch, c, sel, isExpression)); + nc = makeStatementCase(makeLabel(lv), List.of(body)); + if (nc != null) newCases.append(nc); + } + } + + // Emit null slot at its original position + if (seenNull) { + // In case of null and default are on the same case + body = seenDefault + ? null + : make.Block(0, buildCaseBody(currentSwitch, c, sel, isExpression)); + nc = makeStatementCase(makeLabel(-1), body == null ? List.nil() : List.of(body)); + if (nc != null) newCases.append(nc); + } + + // Emit default slot at its original position + if (seenDefault && !defaultEmitted) { + defaultEmitted = true; + body = make.Block(0, buildCaseBody(currentSwitch, c, sel, isExpression)); + nc = makeStatementCase(makeDefaultCaseLabel(), List.of(body)); + if (nc != null) newCases.append(nc); + } + } + + // Emit default case for exhaustive switches + if (!defaultEmitted) { + body = makeMatchExceptionThrow(sel, !isEnum); + nc = makeStatementCase(makeDefaultCaseLabel(), List.of(body)); + if (nc != null) newCases.append(nc); } - int n = nonDefs.size(); - // Ternary chain - final Name selName = sel instanceof JCIdent ? ((JCIdent) sel).name : null; + return newCases.toList(); + } + + /** Builds a simple {@code sel == null ? -1 : sel.ordinal()} condition. */ + public JCExpression buildEnumSwitchCondition(JCExpression sel, boolean hasNull, JCExpression rawSel) { + resolveMethodsCache(); + VarSymbol selSym = (sel instanceof JCIdent && ((JCIdent) sel).sym instanceof VarSymbol) + ? (VarSymbol) ((JCIdent) sel).sym + : null; + JCExpression base = makeMethodCall(sel, enum_ordinal); + if (hasNull) base = makeConditional(makeBinary(Tag.EQ, sel, makeNull()), make.Literal(-1), base); + return rawSel != null && selSym != null ? assignSwitchSelector(base, selSym, rawSel) : base; + } + + /** Builds a ternary chain and groups the same type-pattern. */ + public JCExpression buildTypeSwitchCondition( + JCExpression sel, boolean hasNull, List labels, List slotCases, + JCExpression rawSel + ) { + int n = labels.size(); + Map group = new HashMap<>(); + JCExpression[] anchor = new JCExpression[n]; + + for (int i = 0; i < n; i++, labels = labels.tail, slotCases = slotCases.tail) { + JCExpression pt = getPatternTypeForGrouping(labels.head); + Symbol tsym = pt != null && pt.type != null ? types.erasure(pt.type).tsym : null; + + if (tsym == null) { + anchor[i] = buildLabelCondition(labels.head, sel); + continue; + } + + JCExpression guard = buildGuard(labels.head, sel, slotCases.head); + JCConditional open = group.get(tsym); + + if (open == null) { + JCExpression thenpart; + if (guard != null) { + thenpart = makeConditional(guard, make.Literal(i), make.Literal(n)); + group.put(tsym, (JCConditional) thenpart); + } else { + thenpart = make.Literal(i); + } + anchor[i] = makeConditional(makeTypeTest(sel, pt), thenpart, make.Literal(n)); + } else { + anchor[i] = null; + if (guard != null) { + JCConditional next = makeConditional(guard, make.Literal(i), make.Literal(n)); + open.falsepart = next; + open.type = syms.intType; + group.put(tsym, next); + } else { + open.falsepart = make.Literal(i); + open.type = syms.intType; + group.remove(tsym); + } + } + } + + // Assemble ternary chain + sel = TreeInfo.skipParens(sel); + VarSymbol selSym = (sel instanceof JCIdent && ((JCIdent) sel).sym instanceof VarSymbol) + ? (VarSymbol) ((JCIdent) sel).sym + : null; + boolean inlineInNull = hasNull && rawSel != null && selSym != null; JCExpression ternary = make.Literal(n); + + // Null check + if (!hasNull && sel.type != null && !sel.type.isPrimitive() && selSym != null) { + rawSel = attr.makeNullCheck(rawSel != null ? rawSel : sel); + } + for (int i = n - 1; i >= 0; i--) { - JCCase c = nonDefs.get(i); - JCExpression cond = buildCondition(getLabels(c), sel, c); + JCExpression cond = anchor[i]; if (cond == null) continue; - - if (i == 0 && rawSel != null && selName != null) { - // Replace the first reference to the selector ident with (sv = rawSel). - cond = new TreeTranslator() { - boolean a = true; - @Override - public void visitIdent(JCIdent id) { - if (a && id.name == selName) { - a = false; - result = make.Parens(make.Assign(make.Ident(selName), rawSel)); - } else { - result = id; - } - } - }.translate(cond); + if (i == 0 && !inlineInNull && rawSel != null && selSym != null) { + cond = assignSwitchSelector(cond, selSym, rawSel); + } + if (cond instanceof JCConditional) { + ((JCConditional) cond).falsepart = ternary; + ternary = cond; + } else { + ternary = makeConditional(cond, make.Literal(i), ternary); } - ternary = make.Conditional(cond, make.Literal(i), ternary); } // In case of - if (n == 0 && rawSel != null && selName != null) { - ternary = make.Parens(make.Assign(make.Ident(selName), rawSel)); + if (rawSel != null && selSym != null && !inlineInNull && n > 0 && anchor[0] == null) { + ternary = assignSwitchSelector(ternary, selSym, rawSel); } - // Rebuild cases with int label - ListBuffer newCases = new ListBuffer<>(); - for (int i = 0; i < n; i++) { - JCCase nc = makeCase( - CaseTree.CaseKind.STATEMENT, - List.of(makeLabel(i)), - null, - List.of(make.Block(0, buildCaseBody(nonDefs.get(i), sel, expression))), - null - ); - if (nc != null) newCases.append(nc); + if (hasNull) { + JCExpression nullSel = inlineInNull ? makeAssignParens(selSym, rawSel) : sel; + ternary = makeConditional(makeBinary(Tag.EQ, nullSel, makeNull()), make.Literal(-1), ternary); } - // Default case - JCTree dl = makeDefaultCaseLabel(); - if (dl != null) { - JCCase dc = makeCase( - CaseTree.CaseKind.STATEMENT, - List.of(dl), - null, - List.of(make.Block(0, defCase != null ? - buildCaseBody(defCase, sel, expression) : - List.of(makeMatchExceptionThrow()) - )), - null - ); - if (dc != null) newCases.append(dc); + return (n == 0 && rawSel != null && selSym != null) ? makeAssignParens(selSym, rawSel) : ternary; + } + + private JCExpression getPatternTypeForGrouping(JCTree label) { + if (!isPattern(label)) return null; + JCVariableDecl pv = getPatternVar(label); + if (pv != null) return pv.vartype; + return getRecordType(label); + } + + /** Cache enum values. */ + // TODO: maybe find a better way to get Enum#ordinal()? + private final Map enumOrdinalCache = new HashMap<>(); + public int getEnumOrdinal(JCExpression expr) { + Symbol sym = null; + if (expr instanceof JCIdent) sym = ((JCIdent) expr).sym; + else if (expr instanceof JCFieldAccess) sym = ((JCFieldAccess) expr).sym; + if (!(sym instanceof VarSymbol) || (sym.flags() & Flags.ENUM) == 0) return -1; + + Integer cached = enumOrdinalCache.get(sym); + if (cached != null) return cached; + + int ordinal = 0; + for (Symbol s : sym.owner.getEnclosedElements()) { + if ((s.flags() & Flags.ENUM) == 0) continue; + enumOrdinalCache.put(s, ordinal); + ordinal++; } + cached = enumOrdinalCache.get(sym); + return cached != null ? cached : -1; + } - return make.Switch(ternary, newCases.toList()); + public JCExpression captureSelector(JCExpression sel, ListBuffer prefix, boolean init) { + if (!isComplex(sel)) return sel; + JCVariableDecl v = makeVarDef(switchTempName(), sel.type, init ? sel : null); + prefix.append(v); + return make.QualIdent(v.sym); } - public List buildCaseBody(JCCase c, JCExpression sel, boolean expression) { + private void attachPatternMatchingCatch(JCBlock block) { + if (pendingAccessorCalls == null || !MATCH_EXCEPTION_PRESENT) return; + // TODO: JDK 19 generates proxy methods, but it was also the first preview of record-pattern-matching. + // Should we also generate proxy methods or leave as is, without error handling? + resolveMethodsCache(); + JCVariableDecl ctch = makeVarDef(recordTempName(), syms.throwableType, null, Flags.PARAMETER); + JCCatch handler = make.Catch( + ctch, + make.Block(0, List.of(make.Throw(makeNewClass( + me_ctor, + List.of(makeMethodCall(make.Ident(ctch.sym), object_toString), make.Ident(ctch.sym)) + )))) + ); + attachPatternMatchingCatch(block, handler, pendingAccessorCalls); + pendingAccessorCalls = null; + } + + public List buildCaseBody(JCTree cSwitch, JCCase c, JCExpression sel, boolean expression) { ListBuffer out = new ListBuffer<>(); addBindings(c, sel, out); JCTree body = getBody(c); - if (body instanceof JCBlock) { - for (JCStatement s : ((JCBlock) body).stats) { - out.append(s); - } + for (JCStatement s : ((JCBlock) body).stats) out.append(s); } else if (body != null) { - JCExpression expr = extractExpression(body); + JCExpression expr = getExpression(body); if (expr != null) { if (expression) { - out.append(make.Yield(expr)); + JCYield y = make.Yield(expr); + y.target = cSwitch; + out.append(y); } else { out.append(make.Exec(expr)); - out.append(make.Break(null)); + JCBreak b = make.Break(null); + b.target = cSwitch; + out.append(b); } } else if (body instanceof JCStatement) { out.append((JCStatement) body); } } else if (c.stats != null) { - for (JCStatement s : c.stats) { - out.append(s); - } + for (JCStatement s : c.stats) out.append(s); } - return out.toList(); } - /** - * Injects - * {@code default: throw new UnsupportedOperationException("MatchException");} - * for exhaustive switches that have no explicit default. - */ - public List injectDefault(List cases) { - if (cases == null) return cases; - for (JCCase c : cases) { - if (c != null && hasDefault(c)) { - return cases; + /** Returns the guard expression for a pattern slot (with binding substitution), or null if none. */ + public JCExpression buildGuard(JCTree label, JCExpression sel, JCCase src) { + if (src == null) return null; + JCExpression guard = getGuard(src); + if (guard == null) return null; + Map bindings = collectBindings(label, sel, src); + if (bindings.isEmpty()) return guard; + return new TreeTranslator() { + @Override + public void visitIdent(JCIdent id) { + JCExpression r = bindings.get(id.name); + result = r != null ? r : id; + } + }.translate(guard); + } + + public JCExpression buildLabelCondition(JCTree label, JCExpression sel) { + if (isDefault(label)) return null; + + if (isPattern(label)) { + // sel instanceof Type + JCVariableDecl pv = getPatternVar(label); + if (pv != null) return makeTypeTest(sel, pv.vartype); + JCExpression rt = getRecordType(label); + if (rt != null) return makeTypeTest(sel, rt); + JCExpression pt = getPatternType(label); + if (pt != null) return makeTypeTest(sel, pt); + } + + JCExpression expr = getLabelExpression(label); + if (expr == null) return null; + if (expr instanceof JCLiteral) { + JCLiteral lit = (JCLiteral) expr; + switch (lit.typetag) { + case FLOAT: + // TODO: use Float#floatToIntBits() result as a selector? + case DOUBLE: { + // Boxed.xxxxToxxxBits((Boxed)sel) == Boxed.xxxxToxxxBits(expr) + // TODO: cache sel result + resolveMethodsCache(); + boolean isFloat = lit.typetag == TypeTag.FLOAT; + ClassSymbol type = types.boxedClass(isFloat ? syms.floatType : syms.doubleType); + MethodSymbol method = isFloat ? float_floatToIntBits : double_doubleToLongBits; + return makeBinary( + Tag.EQ, + makeMethodCall(make.QualIdent(type), method, List.of(makeCast(sel, type.type))), + makeMethodCall(make.QualIdent(type), method, List.of(expr)) + ); + } + case CLASS: //$FALL-THROUGH$ + break; + case INT: + // TODO: we lost O(1) if there is only a null case. Let the compiler decide? + // TODO2: process the same ways as JDK 17? (sel == null ? (sel != n ? sel : n+1) : n) + // where the null case is a value not used in labels, instead of -1. + default: + // sel == expr + return makeBinary(Tag.EQ, sel, expr); + } + + } else if (isEnumConstant(expr)) { + // sel == Enum.expr + if (expr instanceof JCIdent) { + JCIdent id = (JCIdent) expr; + if (id.sym != null && id.sym.owner != null && id.sym.owner.type != null) { + JCFieldAccess fa = make.Select(make.Type(id.sym.owner.type), id.name); + fa.sym = id.sym; + fa.type = id.type; + expr = fa; + } } + return makeBinary(Tag.EQ, sel, expr); } + // ((Object)sel).equals(expr) (Implicit null check) + resolveMethodsCache(); + return makeMethodCall(sel, object_equals, List.of(expr)); + } + + /** Injects {@code default: throw new IncompatibleClassChangeError(...)} if no default exists. */ + public List injectDefault(JCExpression sel, List cases) { + if (hasDefault(cases)) return cases; CaseTree.CaseKind kind = CaseTree.CaseKind.STATEMENT; for (JCCase c : cases) { if (c != null) { @@ -504,80 +1075,70 @@ public List injectDefault(List cases) { } } - JCTree defaultLabel = makeDefaultCaseLabel(); - if (defaultLabel == null) return cases; - - JCStatement throwStmt = makeMatchExceptionThrow(); - JCCase defaultCase = makeCase( - kind, - List.of(defaultLabel), - null, - List.of(throwStmt), - kind == CaseTree.CaseKind.RULE ? throwStmt : null - ); - return defaultCase != null ? cases.append(defaultCase) : cases; + JCStatement throwStmt = makeMatchExceptionThrow(sel, false); + JCStatement body = kind == CaseTree.CaseKind.RULE ? throwStmt : null; + JCCase dc = makeCase(kind, List.of(makeDefaultCaseLabel()), List.of(throwStmt), body); + return dc != null ? cases.append(dc) : cases; } - public JCExpression buildCondition(List labels, JCExpression sel, JCCase caseTree) { - JCExpression condition = null; - for (JCTree label : labels) { - JCExpression lc = buildLabelCondition(label, sel); - if (lc == null) continue; - condition = condition == null ? lc : make.Binary(JCTree.Tag.OR, condition, lc); + public List getRecordComponentNames(JCTree pattern) { + JCExpression deconstructor = getRecordType(pattern); + if (deconstructor == null) return List.nil(); + Type recordType = deconstructor.type; + if (recordType == null || !(recordType.tsym instanceof ClassSymbol)) return List.nil(); + List result = List.nil(); + for (RecordComponent rc : ((ClassSymbol) recordType.tsym).getRecordComponents()) { + result = result.append(rc.name); } + return result; + } - JCExpression guard = getGuard(caseTree); - if (guard == null) return condition; + public Map collectBindings(JCTree label, JCExpression sel, JCCase caseTree) { + Map preVars = guardPreVars != null ? guardPreVars.get(caseTree) : null; + Map map = new HashMap<>(); - Map bindingMap = collectBindings(labels, sel); - if (!bindingMap.isEmpty()) { - guard = new TreeTranslator() { - @Override - public void visitIdent(JCIdent ident) { - JCExpression replacement = bindingMap.get(ident.name); - result = replacement != null ? replacement : ident; - } - }.translate(guard); + JCVariableDecl pv = getPatternVar(label); + if (pv != null) { + map.put(pv.name, makeCast(sel, symTypeOf(pv))); + return map; } - return condition == null ? guard : make.Binary(JCTree.Tag.AND, condition, guard); - } + JCExpression rt = getRecordType(label); + List nested = getRecordNested(label); + if (rt == null || nested == null) return map; - public JCExpression buildLabelCondition(JCTree label, JCExpression sel) { - if (isNull(label)) { - return make.Binary(JCTree.Tag.EQ, helper.copy(sel), make.Literal(TypeTag.BOT, null)); - } else if (isPattern(label)) { - JCVariableDecl pv = getPatternVar(label); - if (pv != null) { - return make.TypeTest(helper.copy(sel), pv.vartype); - } - JCExpression rt = getRecordType(label); - if (rt != null) { - return make.TypeTest(helper.copy(sel), rt); - } - JCExpression pt = getPatternType(label); - if (pt != null) { - return make.TypeTest(helper.copy(sel), pt); - } - } else if (!isDefault(label)) { - JCExpression expr = getLabelExpression(label); - if (expr != null) { - return make.Apply( - List.nil(), - make.Select(expr, names.equals), - List.of(helper.copy(sel)) - ); - } + Type recType = rt.type != null ? rt.type : syms.objectType; + List componentNames = getRecordComponentNames(label); + + for (JCTree np : nested) { + Name cn = componentNames.head; + componentNames = componentNames.tail; + JCVariableDecl nv = getPatternVar(np); + if (nv == null) continue; + + cn = cn != null ? cn : nv.name; + JCMethodInvocation accessorCall = makeMethodCall(makeCast(sel, recType), cn); + if (pendingAccessorCalls != null) pendingAccessorCalls.add(accessorCall); + + JCExpression val = preVars != null && preVars.containsKey(nv.name) + ? makeAssignParens(preVars.get(nv.name), accessorCall) + : accessorCall; + map.put(nv.name, val); } - return null; + return map; } - public Map collectBindings(List labels, JCExpression sel) { - Map map = new HashMap<>(); - for (JCTree label : labels) { + /** Emits binding variable declarations at the top of a case body. */ + public void addBindings(JCCase caseTree, JCExpression sel, ListBuffer out) { + Map preVars = guardPreVars != null ? guardPreVars.get(caseTree) : null; + + for (JCTree label : getLabels(caseTree)) { + if (!isPattern(label)) continue; + JCVariableDecl pv = getPatternVar(label); if (pv != null) { - map.put(pv.name, make.TypeCast(pv.vartype, helper.copy(sel))); + Type castTy = symTypeOf(pv); + out.append(make.VarDef(pv.sym, makeCast(sel, castTy))); continue; } @@ -585,99 +1146,389 @@ public Map collectBindings(List labels, JCExpression List nested = getRecordNested(label); if (rt == null || nested == null) continue; - List componentNames = helper.getRecordComponentNames(label); - int i = 0; - for (JCTree np : nested) { - Name cn = i < componentNames.size() ? componentNames.get(i++) : null; - JCVariableDecl nv = getPatternVar(np); - if (nv == null) continue; + // RecordType $record$N = (RecordType) sel; + JCVariableDecl v = makeVarDef(recordTempName(), rt.type, makeCast(sel, rt.type)); + out.append(v); + + if (preVars == null || preVars.isEmpty()) { + ListBuffer bindings = new ListBuffer<>(); + extractRecordBindings(nested, make.Ident(v.sym), getRecordComponentNames(label), bindings); + for (JCVariableDecl decl : bindings) { + if (pendingAccessorCalls != null) { + collectMethodInvocations(decl.init, pendingAccessorCalls); + } + out.append(decl); + } + return; + } - map.put(nv.name, make.Apply( - List.nil(), - make.Select( - make.TypeCast(rt, helper.copy(sel)), - cn != null ? cn : nv.name - ), - List.nil() - )); + List componentNames = getRecordComponentNames(label); + for (JCTree np : nested) { + Name name = componentNames.head; + componentNames = componentNames.tail; + if (!isBindingPattern(np)) continue; + JCVariableDecl npv = ((JCBindingPattern) np).var; + if (npv == null) continue; + + VarSymbol preVar = preVars.get(npv.name); + JCExpression init; + if (preVar != null) { + init = make.Ident(preVar); // alias the pre-var assigned in the guard + } else { + name = name != null ? name : npv.name; + JCMethodInvocation call = makeMethodCall(make.Ident(v.sym), name); + if (pendingAccessorCalls != null) pendingAccessorCalls.add(call); + init = call; + } + out.append(make.VarDef(npv.sym, init)); } } - return map; } - private void addBindings(JCCase caseTree, JCExpression sel, ListBuffer out) { - for (JCTree label : getLabels(caseTree)) { - if (!isPattern(label)) continue; - - JCVariableDecl pv = getPatternVar(label); - if (pv != null) { - out.append(make.VarDef( - make.Modifiers(Flags.FINAL), - pv.name, - pv.vartype, - make.TypeCast(pv.vartype, helper.copy(sel)) - )); + /** + * Recursively extract bindings from nested record pattern components. + *

+ * Generates variable declarations like: + * {@code final Point $tmp = parent.b(); final int bx = $tmp.x(); final int by = $tmp.y();} + */ + public void extractRecordBindings( + List nested, JCExpression baseAccessor, List componentNames, + ListBuffer out + ) { + for (JCTree pattern : nested) { + Name name = componentNames.head; + componentNames = componentNames.tail; + + if (isBindingPattern(pattern)) { + JCVariableDecl var = ((JCBindingPattern) pattern).var; + if (var == null) continue; + name = name != null ? name : var.name; + out.append(make.VarDef(var.sym, makeMethodCall(baseAccessor, name))); continue; } - JCExpression rt = getRecordType(label); - List nested = getRecordNested(label); - if (rt == null || nested == null) continue; + if (!isRecordPattern(pattern)) continue; + JCExpression nestedRecordType = getRecordType(pattern); + List deepNested = getRecordNested(pattern); + if (nestedRecordType == null || deepNested == null || name == null) continue; - Name cv = helper.tempName(); - out.append(helper.makeVarDef(cv, rt, helper.makeCast(rt, helper.copy(sel)))); - ListBuffer bindings = new ListBuffer<>(); - helper.extractRecordBindings( - nested, - make.Ident(cv), - helper.getRecordComponentNames(label), - bindings + JCVariableDecl tmpDecl = makeVarDef( + recordTempName(), + nestedRecordType.type, + makeMethodCall(baseAccessor, name) + ); + out.append(tmpDecl); + extractRecordBindings( + deepNested, + make.Ident(tmpDecl.sym), + getRecordComponentNames(pattern), + out ); - for (JCVariableDecl decl : bindings) { - out.append(decl); - } } } - private boolean isComplex(JCExpression expr) { - if (expr instanceof JCIdent || expr instanceof JCLiteral) return false; - if (expr instanceof JCFieldAccess) { - return isComplex(((JCFieldAccess) expr).selected); - } - if (expr instanceof JCParens) { - return isComplex(((JCParens) expr).expr); + /** + * Pre-declares variables for record components referenced in a guard.
+ * These are assigned inline in the ternary-chain via {@link #collectBindings}, + * and aliased in the case body via {@link #addBindings}. + */ + public void buildGuardPreDecls(List cases, ListBuffer out) { + guardPreVars = null; + + for (JCCase c : cases) { + JCExpression guard = getGuard(c); + if (guard == null) continue; + + for (JCTree label : getLabels(c)) { + if (getRecordType(label) == null) continue; + List nested = getRecordNested(label); + if (nested == null) continue; + + // Collect names referenced in the guard + Set guardNames = new HashSet<>(); + new TreeScanner() { + @Override + public Void visitIdentifier(IdentifierTree node, Void v) { + if (node instanceof JCIdent) guardNames.add(((JCIdent) node).name); + return null; + } + }.scan(guard, null); + + Map varMap = new HashMap<>(); + if (guardPreVars == null) guardPreVars = new HashMap<>(); + guardPreVars.put(c, varMap); + + for (JCTree np : nested) { + JCVariableDecl pv = getPatternVar(np); + if (pv == null || pv.name == null || pv.vartype == null) continue; + if (!guardNames.contains(pv.name)) continue; + + JCVariableDecl v = makeVarDef( + recordTempName(), + symTypeOf(pv), + makeDefaultValue(pv.vartype), + 0 + ); + varMap.put(pv.name, v.sym); + out.append(v); + } + } } - return true; // method calls, new, binary ops, etc. } - private static boolean needsTransform(List cases) { - return hasPatterns(cases) || hasNull(cases); + // end region + // region making + + private JCMethodInvocation makeMethodCall(JCExpression receiver, Name methodName) { + return makeMethodCall(receiver, findMethod(receiver.type, methodName, List.nil()), List.nil()); + } + + private JCMethodInvocation makeMethodCall(JCExpression receiver, MethodSymbol meth) { + return makeMethodCall(receiver, meth, List.nil()); + } + + private JCMethodInvocation makeMethodCall(JCExpression receiver, MethodSymbol meth, List args) { + JCFieldAccess fa = make.Select(copier.copy(receiver), meth.name); + fa.sym = meth; + fa.type = meth.erasure(types); + JCMethodInvocation call = make.Apply(List.nil(), fa, args); + call.type = types.erasure(meth.getReturnType()); + return call; + } + + private JCNewClass makeNewClass(MethodSymbol ctor, List args) { + JCNewClass tree = make.NewClass(null, null, make.QualIdent(ctor.owner), args, null); + tree.constructor = ctor; + tree.constructorType = ctor.erasure(types); + tree.type = ctor.owner.type; + return tree; + } + + private JCInstanceOf makeTypeTest(JCExpression lhs, JCExpression type) { + JCInstanceOf tree = make.TypeTest(copier.copy(lhs), type); + tree.type = syms.booleanType; + return tree; + } + + /** Only with int types. */ + private JCConditional makeConditional(JCExpression cond, JCExpression thenpart, JCExpression elsepart) { + JCConditional c = make.Conditional(cond, thenpart, elsepart); + c.type = syms.intType; + return c; + } + + private JCBinary makeBinary(JCTree.Tag optag, JCExpression lhs, JCExpression rhs) { + JCBinary tree = make.Binary(optag, copier.copy(lhs), rhs); // only copy left side + tree.operator = resolveBinary(ops, tree, optag, lhs.type, rhs.type); + if (tree.operator != null) tree.type = types.erasure(tree.operator.type.getReturnType()); + return tree; + } + + private JCVariableDecl makeVarDef(Name name, Type type, JCExpression init) { + return makeVarDef(name, type, init, Flags.FINAL); + } + + private JCVariableDecl makeVarDef(Name name, Type type, JCExpression init, long addFlags) { + return make.VarDef(new VarSymbol(Flags.SYNTHETIC | addFlags, name, type, currentMethodSym), init); + } + + private JCExpression makeCast(JCExpression expr, Type type) { + if (expr.type != null && types.isSubtype(expr.type, type)) return expr; + return make.at(expr.pos()).TypeCast(make.Type(type), copier.copy(expr)).setType(type); } - private JCExpression extractExpression(JCTree body) { - if (body instanceof JCExpressionStatement) { - return ((JCExpressionStatement) body).expr; + private JCParens makeAssignParens(VarSymbol sym, JCExpression expr) { + JCExpression assign = make.Assign(make.Ident(sym), expr).setType(sym.type); + JCExpression parens = make.Parens(assign).setType(assign.type); + return (JCParens) parens; + } + + private JCExpression makeDefaultValue(JCExpression type) { + if (!(type instanceof JCPrimitiveTypeTree)) return makeNull(); + switch (((JCPrimitiveTypeTree) type).getPrimitiveTypeKind()) { + case BOOLEAN: + return make.Literal(false); + case LONG: + return make.Literal(0L); + case FLOAT: + return make.Literal(0.0f); + case DOUBLE: + return make.Literal(0.0d); + default: + return make.Literal(0); } - if (body instanceof JCExpression) { - return (JCExpression) body; + } + + private JCLiteral makeNull() { + return make.Literal(TypeTag.BOT, null).setType(syms.botType); + } + + private JCStatement makeMatchExceptionThrow(JCExpression sel, boolean useType) { + resolveMethodsCache(); + return make.Throw(useMatchException + ? makeNewClass(me_ctor, List.of(makeNull(), makeNull())) + : makeNewClass( + icce_ctor, + List.of(makeBinary( + Tag.PLUS, + make.Literal("MatchException: Unhandled case: "), + useType + ? makeMethodCall(makeMethodCall(sel, object_getClass), class_getName) + : sel + )) + ) + ); + } + + private JCCase makeStatementCase(JCTree label, List body) { + // Use an empty label list for JDK < 17 + List labels = label != null ? List.of(label) : List.nil(); + return makeCase(CaseTree.CaseKind.STATEMENT, labels, body, null); + } + + // end region + // region other + + // TODO: need to find a way to get an Env, to use Resolve instead + private MethodSymbol findMethod(Type type, Name name, List paramTypes) { + if (type == null || type.tsym == null) return null; + int pSize = paramTypes.size(); + outer: + for (Symbol s : type.tsym.members().getSymbolsByName(name)) { + if (!(s instanceof MethodSymbol)) continue; + MethodSymbol ms = (MethodSymbol) s; + if (ms.params().size() != pSize) continue; + List n = paramTypes; + for (VarSymbol p : ms.params()) { + if (!types.isSameType(p.type, n.head)) continue outer; + n = n.tail; + } + return ms; } + //TODO use Log instead? But this should never fail as it's only used to resolve things from Java API, + // and record components, which have already been resolved by the compiler. + System.err.println("[jabel] Failed to resolve " + type + "." + name + "(" + paramTypes + ")"); return null; } - private JCStatement makeFinalVar(Name name, JCExpression type, JCExpression init) { - return make.VarDef(make.Modifiers(Flags.FINAL), name, type, init); + private void resolveMethodsCache() { + if (cacheResolved) return; + cacheResolved = true; + + object_equals = findMethod(syms.objectType, names.equals, List.of(syms.objectType)); + object_toString = findMethod(syms.objectType, names.toString, List.nil()); + object_getClass = findMethod(syms.objectType, names.getClass, List.nil()); + class_getName = findMethod(syms.classType, names.fromString("getName"), List.nil()); + enum_ordinal = findMethod(syms.enumSym.type, names.ordinal, List.nil()); + float_floatToIntBits = findMethod( + types.boxedClass(syms.floatType).type, + names.fromString("floatToIntBits"), + List.of(syms.floatType) + ); + double_doubleToLongBits = findMethod( + types.boxedClass(syms.doubleType).type, + names.fromString("doubleToLongBits"), + List.of(syms.doubleType) + ); + icce_ctor = findMethod(syms.incompatibleClassChangeErrorType, names.init, List.of(syms.stringType)); + if (!MATCH_EXCEPTION_PRESENT) return; + me_ctor = findMethod(syms.matchExceptionType, names.init, List.of(syms.stringType, syms.throwableType)); + // Use MatchException if available in classpath + useMatchException = ((ClassSymbol) syms.matchExceptionType.tsym).classfile != null; + } + + private Type symTypeOf(JCVariableDecl pv) { + if (pv.sym != null && pv.sym.type != null) return pv.sym.type; + if (pv.vartype != null && pv.vartype.type != null) return pv.vartype.type; + return syms.objectType; + } + + private JCExpression assignSwitchSelector(JCExpression cond, VarSymbol selSym, JCExpression rawSel) { + return new TreeTranslator() { + boolean first = true; + @Override + public void visitIdent(JCIdent id) { + result = first && id.sym == selSym ? makeAssignParens(selSym, rawSel) : id; + first = false; + } + }.translate(cond); + } + + /** Recursively collects {@link JCMethodInvocation} nodes from an expression init. */ + private static void collectMethodInvocations(JCTree tree, Set out) { + if (tree instanceof JCMethodInvocation) { + out.add((JCMethodInvocation) tree); + } else if (tree instanceof JCTypeCast) { + collectMethodInvocations(((JCTypeCast) tree).expr, out); + } + } + + private Name switchTempName() { + return names.fromString("$switch$" + tempVarCounter++); } - // TODO: explain expected cases? - private JCStatement makeMatchExceptionThrow() { - return make.Throw(make.NewClass( - null, - List.nil(), - make.Ident(names.fromString("UnsupportedOperationException")), - List.of(make.Literal("MatchException")), - null - )); + private Name recordTempName() { + return names.fromString("$record$" + tempVarCounter++); } // end region -} \ No newline at end of file + + /** {@link TreeCopier} that preserves types and symbols. This does not includes declarations. */ + public static class SymTreeCopier extends TreeCopier { + public SymTreeCopier(TreeMaker M) { + super(M); + } + + @Override + public Z copy(Z tree, T p) { + Z fresh = super.copy(tree, p); + if (fresh != null && fresh != tree && tree instanceof JCExpression && fresh instanceof JCExpression) { + ((JCExpression) fresh).type = ((JCExpression) tree).type; + } + return fresh; + } + + @Override + public JCTree visitBinary(BinaryTree node, T p) { + JCTree fresh = super.visitBinary(node, p); + ((JCBinary) fresh).operator = ((JCBinary) node).operator; + return fresh; + } + + @Override + public JCTree visitUnary(UnaryTree node, T p) { + JCTree fresh = super.visitUnary(node, p); + ((JCUnary) fresh).operator = ((JCUnary) node).operator; + return fresh; + } + + @Override + public JCTree visitIdentifier(IdentifierTree node, T p) { + JCTree fresh = super.visitIdentifier(node, p); + ((JCIdent) fresh).sym = ((JCIdent) node).sym; + return fresh; + } + + @Override + public JCTree visitNewClass(NewClassTree node, T p) { + JCTree fresh = super.visitNewClass(node, p); + ((JCNewClass) fresh).constructor = ((JCNewClass) node).constructor; + ((JCNewClass) fresh).constructorType = ((JCNewClass) node).constructorType; + return fresh; + } + + @Override + public JCTree visitMemberSelect(MemberSelectTree node, T p) { + JCTree fresh = super.visitMemberSelect(node, p); + ((JCFieldAccess) fresh).sym = ((JCFieldAccess) node).sym; + return fresh; + } + + @Override + public JCTree visitMemberReference(MemberReferenceTree node, T p) { + JCTree fresh = super.visitMemberReference(node, p); + ((JCMemberReference) fresh).sym = ((JCMemberReference) node).sym; + return fresh; + } + } +}