Skip to content
1 change: 1 addition & 0 deletions arrow-reflect-annotations/src/main/kotlin/MetaModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ interface MetaModule: Module {
val decorator: Decorator
val pure: Pure
val immutable: Immutable
val disallowLambdaCapture: DisallowLambdaCapture
}

10 changes: 5 additions & 5 deletions arrow-reflect-annotations/src/main/kotlin/arrow/meta/Meta.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package arrow.meta

import org.jetbrains.kotlin.fir.FirAnnotationContainer
import org.jetbrains.kotlin.fir.FirLabel
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.contracts.*
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.expressions.*
Expand All @@ -28,7 +29,7 @@ annotation class Meta {
fun <In, Out> intercept(args: List<In>, func: (List<In>) -> Out): Out

override fun FirMetaCheckerContext.functionCall(functionCall: FirFunctionCall): FirStatement {
val newCall = if (isDecorated(functionCall, session)) {
val newCall = if (session.isDecorated(functionCall)) {
//language=kotlin
val call: FirCall = decoratedCall(functionCall)
call
Expand All @@ -37,15 +38,14 @@ annotation class Meta {
}

@OptIn(SymbolInternals::class)
private fun isDecorated(newElement: FirFunctionCall, session: FirSession): Boolean =
private fun FirSession.isDecorated(newElement: FirFunctionCall): Boolean =
newElement.toResolvedCallableSymbol()?.fir?.annotations?.hasAnnotation(
classId = ClassId.topLevel(
FqName(
annotation.java.canonicalName
)
),
session = session
) == true
)
, this) == true

private fun FirMetaContext.decoratedCall(
newElement: FirFunctionCall
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package arrow.meta.samples

import arrow.meta.Diagnostics
import arrow.meta.FirMetaCheckerContext
import arrow.meta.Meta
import arrow.meta.samples.DisallowLambdaCaptureErrors.UnsafeCaptureDetected
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction
import org.jetbrains.kotlin.fir.declarations.InlineStatus
import org.jetbrains.kotlin.fir.declarations.findArgumentByName
import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId
import org.jetbrains.kotlin.fir.expressions.FirAnnotation
import org.jetbrains.kotlin.fir.expressions.FirConstExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.toResolvedCallableSymbol
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.types.ConstantValueKind

object DisallowLambdaCaptureErrors : Diagnostics.Error {
val UnsafeCaptureDetected by error1()
}

@Meta
@Target(AnnotationTarget.FUNCTION)
annotation class DisallowLambdaCapture(val msg: String = "") {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this not be more precise as msg: String??

companion object : Meta.Checker.Expression<FirFunctionCall>,
Diagnostics(UnsafeCaptureDetected) {

val annotation = DisallowLambdaCapture::class.java

override fun FirMetaCheckerContext.check(expression: FirFunctionCall) {
val nameArg = expression
.disallowLambdaCaptureAnnotation(session)?.findArgumentByName(Name.identifier(DisallowLambdaCapture::msg.name))
val userMsg =
if (nameArg is FirConstExpression<*> && nameArg.kind == ConstantValueKind.String) nameArg.value as? String
else null
scopeDeclarations.filterIsInstance<FirAnonymousFunction>().forEach { scope ->
if (scope.inlineStatus != InlineStatus.Inline) {
expression.report(
UnsafeCaptureDetected,
userMsg
?: "detected call to member @DisallowLambdaCapture `${+expression}` in non-inline anonymous function"
)
}
}
}

private fun FirFunctionCall.disallowLambdaCaptureAnnotation(session: FirSession): FirAnnotation? =
toResolvedCallableSymbol()?.fir?.getAnnotationByClassId(
ClassId(
FqName(annotation.`package`.name),
Name.identifier(annotation.simpleName)
),
session
)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,42 @@ class FirMetaAdditionalCheckersExtension(
}

private inline fun <reified E : FirElement> invokeChecker(
superType: KClass<*>,
element: E,
session: FirSession,
context: CheckerContext,
reporter: DiagnosticReporter
superType: KClass<*>,
element: E,
session: FirSession,
context: CheckerContext,
reporter: DiagnosticReporter
) {
if (element is FirAnnotationContainer && element.isMetaAnnotated(session)) {
val annotations = element.metaAnnotations(session)
val metaContext = FirMetaCheckerContext(templateCompiler, session, context, reporter)
invokeMeta<E, Unit>(
false,
metaContext,
annotations,
superType = superType,
methodName = "check",
element
)
}
if ((element is FirAnnotationContainer && element.isMetaAnnotated(session)) || (element is FirFunctionCall && element.isCallToAnnotatedFunction(
session
))
) {
val annotations =
when (element) {
is FirFunctionCall ->
element.metaAnnotations(session) + element.toResolvedCallableSymbol()?.fir?.metaAnnotations(
session
).orEmpty()
is FirAnnotationContainer -> element.metaAnnotations(session)
else -> emptyList()
}
val metaContext = FirMetaCheckerContext(templateCompiler, session, context, reporter)
invokeMeta<E, Unit>(
false,
metaContext,
annotations,
superType = superType,
methodName = "check",
element
)
}
}

private inline fun <reified E : FirFunctionCall> E.isCallToAnnotatedFunction(
session: FirSession
): Boolean {
return toResolvedCallableSymbol()?.fir?.isMetaAnnotated(session) == true
}

override val typeCheckers: TypeCheckers
get() = super.typeCheckers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
FILE: capture_test.kt
package foo.bar

public abstract interface Raise<in E> : R|kotlin/Any| {
@R|arrow/meta/samples/DisallowLambdaCapture|(msg = String(It's unsafe to capture `raise` inside non-inline anonymous functions)) public abstract fun raise(e: R|E|): R|kotlin/Nothing|

}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun shouldNotCapture(): R|() -> kotlin/Unit| {
^shouldNotCapture fun <anonymous>(): R|kotlin/Unit| <inline=Unknown> {
this@R|foo/bar/shouldNotCapture|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}

}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun inlineCaptureOk(): R|kotlin/Unit| {
R|kotlin/collections/listOf|<R|kotlin/Int|>(vararg(Int(1), Int(2), Int(3))).R|kotlin/collections/map|<R|kotlin/Int|, R|kotlin/Nothing|>(<L> = map@fun <anonymous>(it: R|kotlin/Int|): R|kotlin/Nothing| <inline=Inline, kind=UNKNOWN> {
this@R|foo/bar/inlineCaptureOk|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}
)
}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun leakedNotOk(): R|() -> kotlin/Unit| {
^leakedNotOk fun <anonymous>(): R|kotlin/Unit| <inline=Unknown> {
R|kotlin/collections/listOf|<R|kotlin/Int|>(vararg(Int(1), Int(2), Int(3))).R|kotlin/collections/map|<R|kotlin/Int|, R|kotlin/Nothing|>(<L> = map@fun <anonymous>(it: R|kotlin/Int|): R|kotlin/Nothing| <inline=Inline, kind=UNKNOWN> {
this@R|foo/bar/leakedNotOk|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}
)
}

}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun ok(): R|kotlin/Unit| {
this@R|foo/bar/ok|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package foo.bar

import arrow.meta.samples.DisallowLambdaCapture

interface Raise<in E> {
@DisallowLambdaCapture("It's unsafe to capture `raise` inside non-inline anonymous functions") fun raise(e: E): Nothing
}

context(Raise<String>)
fun shouldNotCapture(): () -> Unit {
return { <!UnsafeCaptureDetected!>raise("boom")<!> }
}

context(Raise<String>)
fun inlineCaptureOk(): Unit {
listOf(1, 2, 3).map { raise("boom") }
}

context(Raise<String>)
fun leakedNotOk(): () -> Unit = {
listOf(1, 2, 3).map { <!UnsafeCaptureDetected!>raise("boom")<!> }
}

context(Raise<String>)
fun ok(): Unit {
raise("boom")
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ public void testAllFilesPresentInDiagnostics() throws Exception {
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("src/testData/diagnostics"), Pattern.compile("^(.+)\\.kt$"), null, true);
}

@Test
@TestMetadata("capture_test.kt")
public void testCapture_test() throws Exception {
runTest("src/testData/diagnostics/capture_test.kt");
}

@Test
@TestMetadata("immutable_test.kt")
public void testImmutable_test() throws Exception {
Expand Down
40 changes: 33 additions & 7 deletions sandbox/src/main/kotlin/Sample.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
package example
package foo.bar

import arrow.meta.samples.Product
import arrow.meta.samples.DisallowLambdaCapture
import kotlin.contracts.*

@Product
data class Sample(val name: String, val age: Int)
interface Raise<in E> {
@DisallowLambdaCapture("It's unsafe to capture `raise` inside non-inline anonymous functions")
fun raise(e: E): Nothing
}

context(Raise<String>)
fun shouldNotCapture(): () -> Unit {
return { raise("boom") }
}

context(Raise<String>)
fun inlineCaptureOk(): Unit {
listOf(1, 2, 3).map { raise("boom") }
}

@OptIn(ExperimentalContracts::class)
fun exactlyOne(f: () -> Unit): Unit {
contract {
callsInPlace(f, InvocationKind.EXACTLY_ONCE)
}
}

@OptIn(ExperimentalContracts::class)
fun exactlyOnce(f: () -> Unit): Unit {
contract {
callsInPlace(f, InvocationKind.EXACTLY_ONCE)
}
Comment on lines +30 to +32
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure where this is checked in DisallowLambdaCapture.kt? 🤔

}

fun main() {
val properties = Sample("j", 12).product()
println(properties)
context(Raise<String>)
fun ok(): () -> Unit = {
exactlyOnce { raise("boom") }
}