Merge "Special case overrides() check for suspend methods in KAPT." into androidx-main
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/ElementExt.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/ElementExt.kt
index aed0202..943644b 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/ElementExt.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/ElementExt.kt
@@ -18,7 +18,15 @@
 
 import androidx.room.compiler.processing.XNullability
 import com.google.auto.common.MoreElements
+import com.google.auto.common.MoreTypes
+import com.squareup.javapoet.ParameterizedTypeName
+import com.squareup.javapoet.TypeName
 import javax.lang.model.element.Element
+import javax.lang.model.element.ExecutableElement
+import javax.lang.model.element.Modifier
+import javax.lang.model.element.TypeElement
+import javax.lang.model.util.Types
+import kotlin.coroutines.Continuation
 
 private val NONNULL_ANNOTATIONS = arrayOf(
     "androidx.annotation.NonNull",
@@ -59,4 +67,85 @@
     } else {
         null
     }
+}
+
+/**
+ * Tests whether the suspend function [overrider], as a member of [owner], overrides the suspend
+ * function [overridden].
+ *
+ * Assumes both are compiled suspend methods (they return Object, have at least one parameter,
+ * and end with a Continuation parameter); a non-Continuation last parameter fails a [check].
+ * Unlike MoreElements.overrides(), isSubsignature() is not used because of the variance on
+ * Continuation's type argument; each parameter and that type argument is compared by erasure.
+ */
+internal fun suspendOverrides(
+    overrider: ExecutableElement,
+    overridden: ExecutableElement,
+    owner: TypeElement,
+    typeUtils: Types
+): Boolean {
+    // Cheap structural checks: same name, declared in different types, and an overridden
+    // method that is neither static nor private.
+    if (overrider.simpleName != overridden.simpleName ||
+        overrider.enclosingElement == overridden.enclosingElement ||
+        overridden.modifiers.contains(Modifier.STATIC) ||
+        overridden.modifiers.contains(Modifier.PRIVATE)
+    ) {
+        return false
+    }
+
+    // The overridden method must be declared in owner or one of its supertypes.
+    val overriddenType = overridden.enclosingElement as? TypeElement ?: return false
+    if (!typeUtils.isSubtype(
+            typeUtils.erasure(owner.asType()),
+            typeUtils.erasure(overriddenType.asType()))
+    ) {
+        return false
+    }
+
+    // Resolve both methods as members of owner so that the type variables of the overridden
+    // method's declaring type are substituted before their parameters are compared.
+    val ownerType = MoreTypes.asDeclared(owner.asType())
+    val overriderExecutable = MoreTypes.asExecutable(typeUtils.asMemberOf(ownerType, overrider))
+    val overriddenExecutable = MoreTypes.asExecutable(typeUtils.asMemberOf(ownerType, overridden))
+    if (overriderExecutable.parameterTypes.size != overriddenExecutable.parameterTypes.size) {
+        return false
+    }
+
+    // Both methods must end with a parameterized Continuation; since this function assumes
+    // suspend methods, anything else fails the check below.
+    val continuationTypeName = TypeName.get(Continuation::class.java)
+    listOf(overriderExecutable, overriddenExecutable).forEach { executable ->
+        val lastParamTypeName =
+            (TypeName.get(executable.parameterTypes.last()) as? ParameterizedTypeName)?.rawType
+        check(lastParamTypeName == continuationTypeName) {
+            "Expected $lastParamTypeName to be $continuationTypeName"
+        }
+    }
+
+    // The Continuation type arguments must match by erasure.
+    // NOTE(review): this may reject covariant suspend return types (e.g. Any -> String);
+    // confirm that is intended.
+    val overriderContinuationTypeArg =
+        MoreTypes.asDeclared(overriderExecutable.parameterTypes.last())
+            .typeArguments.single().extendsBound()
+    val overriddenContinuationTypeArg =
+        MoreTypes.asDeclared(overriddenExecutable.parameterTypes.last())
+            .typeArguments.single().extendsBound()
+    if (!typeUtils.isSameType(
+            typeUtils.erasure(overriderContinuationTypeArg),
+            typeUtils.erasure(overriddenContinuationTypeArg))
+    ) {
+        return false
+    }
+    // Every parameter before the trailing Continuation must match by erasure; the parameter
+    // counts are already known to be equal, so zip() drops nothing.
+    return overriderExecutable.parameterTypes.dropLast(1)
+        .zip(overriddenExecutable.parameterTypes.dropLast(1))
+        .all { (overriderParam, overriddenParam) ->
+            typeUtils.isSameType(
+                typeUtils.erasure(overriderParam),
+                typeUtils.erasure(overriddenParam)
+            )
+        }
 }
\ No newline at end of file
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacMethodElement.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacMethodElement.kt
index baa51d1..167315a 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacMethodElement.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacMethodElement.kt
@@ -18,6 +18,7 @@
 
 import androidx.room.compiler.processing.XMethodElement
 import androidx.room.compiler.processing.XMethodType
+import androidx.room.compiler.processing.XProcessingEnv
 import androidx.room.compiler.processing.XType
 import androidx.room.compiler.processing.XTypeElement
 import androidx.room.compiler.processing.javac.kotlin.KmFunction
@@ -121,6 +122,14 @@
     override fun overrides(other: XMethodElement, owner: XTypeElement): Boolean {
         check(other is JavacMethodElement)
         check(owner is JavacTypeElement)
+        val bothSuspendUnderJavac = env.backend == XProcessingEnv.Backend.JAVAC &&
+            isSuspendFunction() &&
+            other.isSuspendFunction()
+        if (bothSuspendUnderJavac) {
+            // b/222240938 - KAPT: Continuation's type argument trips MoreElements.overrides(),
+            // so suspend-to-suspend comparisons go through a dedicated check.
+            return suspendOverrides(element, other.element, owner.element, env.typeUtils)
+        }
         // Use auto-common's overrides, which provides consistency across javac and ejc (Eclipse).
         return MoreElements.overrides(element, other.element, owner.element, env.typeUtils)
     }
diff --git a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeElementTest.kt b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeElementTest.kt
index 5970140..cf909f8 100644
--- a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeElementTest.kt
+++ b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeElementTest.kt
@@ -804,6 +804,64 @@
     }
 
     @Test
+    fun suspendOverride() {
+        val kotlinSource = Source.kotlin(
+            "Foo.kt",
+            """
+            interface Base<T> {
+                suspend fun get(): T
+                suspend fun getAll(): List<T>
+                suspend fun putAll(input: List<T>)
+                suspend fun getAllWithDefault(): List<T>
+            }
+
+            interface DerivedInterface : Base<String> {
+                override suspend fun get(): String
+                override suspend fun getAll(): List<String>
+                override suspend fun putAll(input: List<String>)
+                override suspend fun getAllWithDefault(): List<String> {
+                    return emptyList()
+                }
+            }
+            """.trimIndent()
+        )
+        runProcessorTest(sources = listOf(kotlinSource)) { inv ->
+            val derived = inv.processingEnv.requireTypeElement("DerivedInterface")
+            val names = derived.getAllMethods().toList().jvmNames()
+            assertThat(names).containsExactly("get", "getAll", "putAll", "getAllWithDefault")
+        }
+    }
+
+    @Test
+    fun suspendOverride_abstractClass() {
+        val kotlinSource = Source.kotlin(
+            "Foo.kt",
+            """
+            abstract class Base<T> {
+                abstract suspend fun get(): T
+                abstract suspend fun getAll(): List<T>
+                abstract suspend fun putAll(input: List<T>)
+            }
+
+            abstract class DerivedClass : Base<Int>() {
+                abstract override suspend fun get(): Int
+                abstract override suspend fun getAll(): List<Int>
+                override suspend fun putAll(input: List<Int>) {
+                }
+            }
+            """.trimIndent()
+        )
+        runProcessorTest(sources = listOf(kotlinSource)) { inv ->
+            val derived = inv.processingEnv.requireTypeElement("DerivedClass")
+            val countsByJvmName = derived.getAllMethods().toList().jvmNames()
+                .groupingBy { name -> name }.eachCount()
+            for (methodName in listOf("get", "getAll", "putAll")) {
+                assertThat(countsByJvmName[methodName]).isEqualTo(1)
+            }
+        }
+    }
+
+    @Test
     fun overrideMethodWithCovariantReturnType() {
         val src = Source.kotlin(
             "ParentWithExplicitOverride.kt",