Propagate and enforce DisallowComposableCall

Change-Id: I65374fe2aa61b1d7ee9d4565f50372a994d7a11d
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt
index bc1a8ea..65fe6c1 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt
@@ -1075,4 +1075,20 @@
         """
         )
     }
+
+    fun testDisallowComposableCallPropagation() = check(
+        """
+        import androidx.compose.runtime.*
+        class Foo
+        @Composable inline fun a(block1: @DisallowComposableCalls () -> Foo): Foo {
+            return block1()
+        }
+        @Composable inline fun b(<!MISSING_DISALLOW_COMPOSABLE_CALLS_ANNOTATION!>block2: () -> Foo<!>): Foo {
+          return a { block2() }
+        }
+        @Composable inline fun c(block2: @DisallowComposableCalls () -> Foo): Foo {
+          return a { block2() }
+        }
+    """
+    )
 }
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableCallChecker.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableCallChecker.kt
index 6bd17a8..14360b3 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableCallChecker.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableCallChecker.kt
@@ -83,12 +83,60 @@
         container.useInstance(this)
     }
 
+    fun checkInlineLambdaCall(
+        resolvedCall: ResolvedCall<*>,
+        reportOn: PsiElement,
+        context: CallCheckerContext
+    ) {
+        if (resolvedCall !is VariableAsFunctionResolvedCall) return
+        val descriptor = resolvedCall.variableCall.resultingDescriptor
+        if (descriptor !is ValueParameterDescriptor) return
+        if (descriptor.type.composablePreventCaptureContract() == true) return
+        val function = descriptor.containingDeclaration
+        if (
+            function is FunctionDescriptor &&
+            function.isInline &&
+            function.isMarkedAsComposable()
+        ) {
+            val bindingContext = context.trace.bindingContext
+            var node: PsiElement? = reportOn
+            loop@while (node != null) {
+                when (node) {
+                    is KtLambdaExpression -> {
+                        val arg = getArgumentDescriptor(node.functionLiteral, bindingContext)
+                        if (arg?.type?.composablePreventCaptureContract() == true) {
+                            val parameterSrc = descriptor.findPsi()
+                            if (parameterSrc != null) {
+                                missingDisallowedComposableCallPropagation(
+                                    context,
+                                    parameterSrc,
+                                    descriptor,
+                                    arg
+                                )
+                            }
+                        }
+                    }
+                    is KtFunction -> {
+                        val fn = bindingContext[BindingContext.FUNCTION, node]
+                        if (fn == function) {
+                            return
+                        }
+                    }
+                }
+                node = node.parent as? KtElement
+            }
+        }
+    }
+
     override fun check(
         resolvedCall: ResolvedCall<*>,
         reportOn: PsiElement,
         context: CallCheckerContext
     ) {
-        if (!resolvedCall.isComposableInvocation()) return
+        if (!resolvedCall.isComposableInvocation()) {
+            checkInlineLambdaCall(resolvedCall, reportOn, context)
+            return
+        }
         val bindingContext = context.trace.bindingContext
         var node: PsiElement? = reportOn
         loop@while (node != null) {
@@ -217,6 +265,22 @@
         }
     }
 
+    private fun missingDisallowedComposableCallPropagation(
+        context: CallCheckerContext,
+        unmarkedParamEl: PsiElement,
+        unmarkedParamDescriptor: ValueParameterDescriptor,
+        markedParamDescriptor: ValueParameterDescriptor
+    ) {
+        context.trace.report(
+            ComposeErrors.MISSING_DISALLOW_COMPOSABLE_CALLS_ANNOTATION.on(
+                unmarkedParamEl,
+                unmarkedParamDescriptor,
+                markedParamDescriptor,
+                markedParamDescriptor.containingDeclaration
+            )
+        )
+    }
+
     private fun illegalCall(
         context: CallCheckerContext,
         callEl: PsiElement,
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrorMessages.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrorMessages.kt
index e0e887a..fee8a4b7 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrorMessages.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrorMessages.kt
@@ -53,6 +53,15 @@
         )
 
         MAP.put(
+            ComposeErrors.MISSING_DISALLOW_COMPOSABLE_CALLS_ANNOTATION,
+            "Parameter {0} cannot be inlined inside of lambda argument {1} of {2} " +
+                "without also being annotated with @DisallowComposableCalls",
+            Renderers.NAME,
+            Renderers.NAME,
+            Renderers.NAME
+        )
+
+        MAP.put(
             ComposeErrors.COMPOSABLE_PROPERTY_BACKING_FIELD,
             "Composable properties are not able to have backing fields"
         )
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrors.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrors.kt
index 8ef2b83..0b76433 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrors.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeErrors.kt
@@ -17,10 +17,13 @@
 package androidx.compose.compiler.plugins.kotlin
 
 import com.intellij.psi.PsiElement
+import org.jetbrains.kotlin.descriptors.CallableDescriptor
 import org.jetbrains.kotlin.descriptors.DeclarationDescriptor
+import org.jetbrains.kotlin.descriptors.ValueParameterDescriptor
 import org.jetbrains.kotlin.diagnostics.DiagnosticFactory0
 import org.jetbrains.kotlin.diagnostics.DiagnosticFactory1
 import org.jetbrains.kotlin.diagnostics.DiagnosticFactory2
+import org.jetbrains.kotlin.diagnostics.DiagnosticFactory3
 import org.jetbrains.kotlin.diagnostics.Errors
 import org.jetbrains.kotlin.diagnostics.PositioningStrategies.DECLARATION_SIGNATURE_OR_DEFAULT
 import org.jetbrains.kotlin.diagnostics.Severity
@@ -74,6 +77,17 @@
             Severity.ERROR
         )
 
+    @JvmField
+    val MISSING_DISALLOW_COMPOSABLE_CALLS_ANNOTATION =
+        DiagnosticFactory3.create<
+            PsiElement,
+            ValueParameterDescriptor, // unmarked
+            ValueParameterDescriptor, // marked
+            CallableDescriptor
+            >(
+            Severity.ERROR
+        )
+
     // This error matches Kotlin's CONFLICTING_OVERLOADS error, except that it renders the
     // annotations with the descriptor. This is important to use for errors where the
     // only difference is whether or not it is annotated with @Composable or not.
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SuspendingEffects.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SuspendingEffects.kt
index ff15ec9..4f4ba72 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SuspendingEffects.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SuspendingEffects.kt
@@ -87,7 +87,7 @@
  */
 @Composable
 inline fun rememberCoroutineScope(
-    getContext: () -> CoroutineContext = { EmptyCoroutineContext }
+    getContext: @DisallowComposableCalls () -> CoroutineContext = { EmptyCoroutineContext }
 ): CoroutineScope {
     val composer = currentComposer
     val wrapper = remember {