Handle primitives overriding generics in KSP

When a kotlin class specifies a non-nullable primitive type for a
super's generic argument, kotlin will duplicate methods that receive the
generic as an argument.
This CL changes the KSP override check to ignore those overrides so that
it also reports two methods. Note that this only happens for paremeters
and not return values. For return values, only the boxed one is generated.

To handle the return type boxing, we now use the overridden desclaration
when wrapping types, which lets us choose between primitive and
non-primitive by looking at the declaration.

Bug: 160258066
Bug: 160322705
Test: XExecutableElementTest

Change-Id: Id2fb76748ff014bcabe4543daed211d82d04bd4c
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt
index 097c3c1..1303247 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt
@@ -96,12 +96,16 @@
         env, containing, declaration
     ) {
         override val returnType: XType by lazy {
+            // b/160258066
+            // we may need to box the return type if it is overriding a generic, hence, we should
+            // use the declaration of the overridee if available when deciding nullability
+            val overridee = declaration.findOverridee()
             env.wrap(
                 ksType = declaration.returnTypeAsMemberOf(
                     resolver = env.resolver,
                     ksType = containing.type.ksType
                 ),
-                originatingReference = checkNotNull(declaration.returnType)
+                originatingReference = checkNotNull(overridee?.returnType ?: declaration.returnType)
             )
         }
         override fun isSuspendFunction() = false
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt
index 16a9cc5..1b84172 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt
@@ -59,8 +59,12 @@
         containing: KspType
     ) : KspMethodType(env, origin, containing) {
         override val returnType: XType by lazy {
+            // b/160258066
+            // we may need to box the return type if it is overriding a generic, hence, we should
+            // use the declaration of the overridee if available when deciding nullability
+            val overridee = origin.declaration.findOverridee()
             env.wrap(
-                originatingReference = origin.declaration.returnType!!,
+                originatingReference = (overridee?.returnType ?: origin.declaration.returnType)!!,
                 ksType = origin.declaration.returnTypeAsMemberOf(
                     resolver = env.resolver,
                     ksType = containing.ksType
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/ResolverExt.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/ResolverExt.kt
index 876eb4a..a285ebf 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/ResolverExt.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/ResolverExt.kt
@@ -25,6 +25,8 @@
 import com.google.devtools.ksp.symbol.KSFunctionDeclaration
 import com.google.devtools.ksp.symbol.KSPropertyAccessor
 import com.google.devtools.ksp.symbol.KSPropertyDeclaration
+import com.google.devtools.ksp.symbol.KSTypeParameter
+import com.google.devtools.ksp.symbol.Nullability
 
 internal fun Resolver.findClass(qName: String) = getClassDeclarationByName(
     getKSNameFromString(qName)
@@ -84,12 +86,44 @@
         // workaround for https://github.com/google/ksp/issues/164
         null
     }
-    if (overridee == other) {
+    // before accepting this override, check if we have a primitive parameter that was a type
+    // reference in overridee. In those cases, kotlin will actually generate two jvm methods.
+    if (overridee == other && this.overridesInJvm(other)) {
         return true
     }
     return overridee?.overrides(other) ?: false
 }
 
+/**
+ * If the overrider specifies a primitive value for a type argument, ignore the override as
+ * kotlin will generate two class methods for them.
+ *
+ * see: b/160258066 for details
+ */
+private fun KSFunctionDeclaration.overridesInJvm(
+    other: KSFunctionDeclaration
+): Boolean {
+    parameters.forEachIndexed { index, myParam ->
+        val myParamType = myParam.type.resolve()
+        if (myParamType.nullability == Nullability.NOT_NULL) {
+            val myParamDecl = myParamType.declaration
+            val paramQName = myParamDecl.qualifiedName?.asString()
+            if (paramQName != null &&
+                KspTypeMapper.getPrimitiveJavaTypeName(paramQName) != null
+            ) {
+                // parameter is a primitive. Check if the parent declared it as a type argument,
+                // in which case, we should ignore the override.
+                val otherParamDeclaration = other.parameters
+                    .getOrNull(index)?.type?.resolve()?.declaration
+                if (otherParamDeclaration is KSTypeParameter) {
+                    return false
+                }
+            }
+        }
+    }
+    return true
+}
+
 private fun KSPropertyDeclaration.overrides(other: KSPropertyDeclaration): Boolean {
     val overridee = findOverridee()
     if (overridee == other) {
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XExecutableElementTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XExecutableElementTest.kt
index 4f4f83c..39fb06c 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XExecutableElementTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XExecutableElementTest.kt
@@ -509,4 +509,124 @@
             }
         }
     }
+
+    @Test
+    fun genericToPrimitiveOverrides_methodElement() {
+        genericToPrimitiveOverrides(asMemberOf = false)
+    }
+
+    @Test
+    fun genericToPrimitiveOverrides_asMemberOf() {
+        genericToPrimitiveOverrides(asMemberOf = true)
+    }
+
+    // see b/160258066
+    private fun genericToPrimitiveOverrides(asMemberOf: Boolean) {
+        val source = Source.kotlin(
+            "Foo.kt",
+            """
+            interface Base<Key> {
+                fun getKey(id: Key): Unit
+                fun getKeyOverridden(id: Key): Unit
+                fun returnKey(): Key
+                fun returnKeyOverridden(): Key
+                fun getAndReturnKey(key: Key): Key
+                fun getAndReturnKeyOverridden(key: Key): Key
+            }
+            interface NonNullPrimitiveOverride : Base<Int> {
+                override fun getKeyOverridden(id: Int): Unit
+                override fun returnKeyOverridden(): Int
+                override fun getAndReturnKeyOverridden(key: Int): Int
+            }
+            interface NullablePrimitiveOverride : Base<Int?> {
+                override fun getKeyOverridden(id: Int?): Unit
+                override fun returnKeyOverridden(): Int?
+                override fun getAndReturnKeyOverridden(key: Int?): Int?
+            }
+            class Item
+            interface ClassOverride : Base<Item> {
+                override fun getKeyOverridden(id: Item): Unit
+                override fun returnKeyOverridden(): Item
+                override fun getAndReturnKeyOverridden(key: Item): Item
+            }
+            """.trimIndent()
+        )
+        runProcessorTest(sources = listOf(source)) { invocation ->
+            val objectMethodNames = invocation.processingEnv.requireTypeElement(TypeName.OBJECT)
+                .getAllNonPrivateInstanceMethods().map { it.name }.toSet()
+
+            fun XTypeElement.methodsSignature(): String {
+                return getAllNonPrivateInstanceMethods()
+                    .filterNot { it.name in objectMethodNames }
+                    .sortedBy {
+                        it.name
+                    }.joinToString("\n") { methodElement ->
+                        buildString {
+                            append(methodElement.name)
+                            append("(")
+                            val paramTypes = if (asMemberOf) {
+                                methodElement.asMemberOf(this@methodsSignature.type).parameterTypes
+                            } else {
+                                methodElement.parameters.map { it.type }
+                            }
+                            val paramsSignature = paramTypes.joinToString(",") {
+                                it.typeName.toString()
+                            }
+                            append(paramsSignature)
+                            append("):")
+                            val returnType = if (asMemberOf) {
+                                methodElement.asMemberOf(this@methodsSignature.type).returnType
+                            } else {
+                                methodElement.returnType
+                            }
+                            append(returnType.typeName)
+                        }
+                    }
+            }
+
+            val nonNullOverride =
+                invocation.processingEnv.requireTypeElement("NonNullPrimitiveOverride")
+            assertThat(
+                nonNullOverride.methodsSignature()
+            ).isEqualTo(
+                """
+                getAndReturnKey(java.lang.Integer):java.lang.Integer
+                getAndReturnKeyOverridden(int):java.lang.Integer
+                getAndReturnKeyOverridden(java.lang.Integer):java.lang.Integer
+                getKey(java.lang.Integer):void
+                getKeyOverridden(int):void
+                getKeyOverridden(java.lang.Integer):void
+                returnKey():java.lang.Integer
+                returnKeyOverridden():java.lang.Integer
+                """.trimIndent()
+            )
+            val nullableOverride =
+                invocation.processingEnv.requireTypeElement("NullablePrimitiveOverride")
+            assertThat(
+                nullableOverride.methodsSignature()
+            ).isEqualTo(
+                """
+                getAndReturnKey(java.lang.Integer):java.lang.Integer
+                getAndReturnKeyOverridden(java.lang.Integer):java.lang.Integer
+                getKey(java.lang.Integer):void
+                getKeyOverridden(java.lang.Integer):void
+                returnKey():java.lang.Integer
+                returnKeyOverridden():java.lang.Integer
+                """.trimIndent()
+            )
+            val classOverride = invocation.processingEnv.requireTypeElement("ClassOverride")
+            assertThat(
+                classOverride.methodsSignature()
+            ).isEqualTo(
+                """
+                getAndReturnKey(Item):Item
+                getAndReturnKeyOverridden(Item):Item
+                getKey(Item):void
+                getKeyOverridden(Item):void
+                returnKey():Item
+                returnKeyOverridden():Item
+                """.trimIndent()
+            )
+        }
+    }
 }