Handle nullability in asMemberOf

This CL fixes a bug in asMemberOf where if a type is declared
as T? but swapped with a non-null type, we would consider it
non-null even though it should still be nullable.

Bug: 160322705
Test: KspAsMemberOfTest

Change-Id: I4e34adfeb0c16efe9d43fb302b5e590404bd64aa
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSPropertyDeclarationExt.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSPropertyDeclarationExt.kt
index 15dfd80..8d60651 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSPropertyDeclarationExt.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSPropertyDeclarationExt.kt
@@ -16,13 +16,18 @@
 
 package androidx.room.compiler.processing.ksp
 
+import org.jetbrains.kotlin.ksp.closestClassDeclaration
 import org.jetbrains.kotlin.ksp.getAllSuperTypes
+import org.jetbrains.kotlin.ksp.processing.Resolver
 import org.jetbrains.kotlin.ksp.symbol.KSClassDeclaration
+import org.jetbrains.kotlin.ksp.symbol.KSDeclaration
 import org.jetbrains.kotlin.ksp.symbol.KSName
 import org.jetbrains.kotlin.ksp.symbol.KSPropertyDeclaration
 import org.jetbrains.kotlin.ksp.symbol.KSType
 import org.jetbrains.kotlin.ksp.symbol.KSTypeArgument
 import org.jetbrains.kotlin.ksp.symbol.KSTypeParameter
+import org.jetbrains.kotlin.ksp.symbol.KSTypeReference
+import org.jetbrains.kotlin.ksp.symbol.Nullability
 
 /**
  * Returns the type of a property as if it is member of the given [ksType].
@@ -31,44 +36,105 @@
  * handle inner classes properly.
  * TODO: remove once https://github.com/android/kotlin/issues/26 is implemented
  */
-internal fun KSPropertyDeclaration.typeAsMemberOf(ksType: KSType): KSType {
+internal fun KSPropertyDeclaration.typeAsMemberOf(resolver: Resolver, ksType: KSType): KSType {
     val myType: KSType = checkNotNull(type?.requireType()) {
         "Cannot find type of Kotlin property: $this"
     }
-    val parent = checkNotNull(findEnclosingAncestorClassDeclaration()) {
-        "Cannot find containing class for property. $this"
-    }
-    // TODO traverse grandparents if parent is an inner class as TypeArguments might be declared
-    //  there as well.
-    val matchingParentType: KSType = (ksType.declaration as? KSClassDeclaration)
+    return myType.asMemberOf(resolver, this, ksType)
+}
+
+/**
+ * Returns `this` type as member of the [other] type.
+ *
+ * @param resolver The KSP resolver instance
+ * @param declaration The KSDeclaration where the owner of this type is defined. Note that this can
+ * be different from [KSType.declaration]. For instance, if you have a class `Foo<T>` with property
+ * `x : List<T>`, `x`'s type declaration is `kotlin.List` whereas the declaration that
+ * should be passed here is `x` (from which the implementation will find `Foo`). On the other hand,
+ * `T` of `List<T>`'s declaration is already in `Foo`.
+ * @param other The new owner for this type. For instance, if you want to resolve `x` in
+ * `Bar<String>`, this would be the star projected type of `Bar`.
+ */
+internal fun KSType.asMemberOf(
+    resolver: Resolver,
+    declaration: KSDeclaration,
+    other: KSType
+): KSType {
+    val parent = declaration.closestClassDeclaration() ?: return this
+    val parentQName = parent.qualifiedName ?: return this
+    val matchingParentType: KSType = (other.declaration as? KSClassDeclaration)
         ?.getAllSuperTypes()
         ?.firstOrNull {
-            it.starProjection().declaration == parentDeclaration
-        } ?: return myType
+            it.starProjection().declaration.qualifiedName == parentQName
+        } ?: return this
     // create a map of replacements.
     val replacements = parent.typeParameters.mapIndexed { index, ksTypeParameter ->
         ksTypeParameter.name to matchingParentType.arguments.getOrNull(index)
     }.toMap()
-    return myType.replaceFromMap(replacements)
+    return replaceFromMap(resolver, replacements)
 }
 
-private fun KSTypeArgument.replaceFromMap(arguments: Map<KSName, KSTypeArgument?>): KSTypeArgument {
-    val myTypeDeclaration = type?.resolve()?.declaration
+private fun KSTypeArgument.replaceFromMap(
+    resolver: Resolver,
+    arguments: Map<KSName, KSTypeArgument?>
+): KSTypeArgument {
+    val resolvedType = type?.resolve()
+    val myTypeDeclaration = resolvedType?.declaration
     if (myTypeDeclaration is KSTypeParameter) {
-        return arguments[myTypeDeclaration.name] ?: this
+        val match = arguments[myTypeDeclaration.name] ?: return this
+        // workaround for https://github.com/google/ksp/issues/82
+        val explicitNullable = resolvedType.makeNullable() == resolvedType
+        return if (explicitNullable) {
+            match.makeNullable(resolver)
+        } else {
+            match
+        }
     }
     return this
 }
 
-private fun KSType.replaceFromMap(arguments: Map<KSName, KSTypeArgument?>): KSType {
+private fun KSType.replaceFromMap(
+    resolver: Resolver,
+    arguments: Map<KSName, KSTypeArgument?>
+): KSType {
     val myDeclaration = this.declaration
     if (myDeclaration is KSTypeParameter) {
-        return arguments[myDeclaration.name]?.type?.resolve() ?: this
+        val match = arguments[myDeclaration.name]?.type?.resolve() ?: return this
+        // workaround for https://github.com/google/ksp/issues/82
+        val explicitNullable = this.makeNullable() == this
+        return if (explicitNullable) {
+            match.makeNullable()
+        } else {
+            match
+        }
     }
     if (this.arguments.isEmpty()) {
         return this
     }
     return replace(this.arguments.map {
-        it.replaceFromMap(arguments)
+        it.replaceFromMap(resolver, arguments)
     })
 }
+
+private fun KSTypeArgument.makeNullable(resolver: Resolver): KSTypeArgument {
+    val myType = type
+    val resolved = myType?.resolve() ?: return this
+    if (resolved.nullability == Nullability.NULLABLE) {
+        return this
+    }
+    return resolver.getTypeArgument(myType.swapResolvedType(resolved.makeNullable()), variance)
+}
+
+private fun KSTypeReference.swapResolvedType(replacement: KSType): KSTypeReference {
+    return DelegatingTypeReference(
+        original = this,
+        resolved = replacement
+    )
+}
+
+private class DelegatingTypeReference(
+    val original: KSTypeReference,
+    val resolved: KSType
+) : KSTypeReference by original {
+    override fun resolve(): KSType? = resolved
+}
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspFieldElement.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspFieldElement.kt
index d2634bb..ca35efb 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspFieldElement.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspFieldElement.kt
@@ -42,7 +42,7 @@
     }
 
     override val type: XType by lazy {
-        env.wrap(declaration.typeAsMemberOf(containing.type.ksType))
+        env.wrap(declaration.typeAsMemberOf(env.resolver, containing.type.ksType))
     }
 
     override fun asMemberOf(other: XDeclaredType): XType {
@@ -50,7 +50,7 @@
             return type
         }
         check(other is KspType)
-        val asMember = declaration.typeAsMemberOf(other.ksType)
+        val asMember = declaration.typeAsMemberOf(env.resolver, other.ksType)
         return env.wrap(asMember)
     }
 }
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSProertyDeclarationExtTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSProertyDeclarationExtTest.kt
index 287c7b7..a4d59c3 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSProertyDeclarationExtTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSProertyDeclarationExtTest.kt
@@ -29,6 +29,7 @@
 import org.jetbrains.kotlin.ksp.getDeclaredProperties
 import org.jetbrains.kotlin.ksp.symbol.KSClassDeclaration
 import org.jetbrains.kotlin.ksp.symbol.KSPropertyDeclaration
+import org.jetbrains.kotlin.ksp.symbol.Nullability
 import org.junit.Test
 
 class KSProertyDeclarationExtTest {
@@ -51,17 +52,18 @@
             val base = invocation.requireClass("BaseClass")
             val sub = invocation.requireClass("SubClass").asStarProjectedType()
             base.requireProperty("genericProp").let { prop ->
-                assertThat(prop.typeAsMemberOf(sub).typeName()).isEqualTo(INT_CLASS_NAME)
+                assertThat(prop.typeAsMemberOf(invocation.kspResolver, sub).typeName())
+                    .isEqualTo(INT_CLASS_NAME)
             }
             base.requireProperty("listOfGeneric").let { prop ->
-                assertThat(prop.typeAsMemberOf(sub).typeName())
+                assertThat(prop.typeAsMemberOf(invocation.kspResolver, sub).typeName())
                     .isEqualTo(ParameterizedTypeName.get(LIST_CLASS_NAME, INT_CLASS_NAME))
             }
 
             val listOfStringsTypeName =
                 ParameterizedTypeName.get(LIST_CLASS_NAME, STRING_CLASS_NAME)
             base.requireProperty("mapOfStringToGeneric2").let { prop ->
-                assertThat(prop.typeAsMemberOf(sub).typeName())
+                assertThat(prop.typeAsMemberOf(invocation.kspResolver, sub).typeName())
                     .isEqualTo(
                         ParameterizedTypeName.get(
                             MAP_CLASS_NAME, STRING_CLASS_NAME, listOfStringsTypeName
@@ -70,7 +72,7 @@
             }
 
             base.requireProperty("pairOfGenerics").let { prop ->
-                assertThat(prop.typeAsMemberOf(sub).typeName())
+                assertThat(prop.typeAsMemberOf(invocation.kspResolver, sub).typeName())
                     .isEqualTo(
                         ParameterizedTypeName.get(
                             PAIR_CLASS_NAME, INT_CLASS_NAME, listOfStringsTypeName
@@ -80,6 +82,70 @@
         }
     }
 
+    @Test
+    fun asMemberOfNullabilityResolution() {
+        val src = Source.kotlin(
+            "Foo.kt", """
+            open class MyInterface<T> {
+                val inheritedProp: T = TODO()
+                var nullableProp: T? = TODO()
+                val inheritedGenericProp: List<T> = TODO()
+                val nullableGenericProp: List<T?> = TODO()
+            }
+            abstract class NonNullSubject : MyInterface<String>()
+            abstract class NullableSubject: MyInterface<String?>()
+        """.trimIndent()
+        )
+        runKspTest(sources = listOf(src), succeed = true) { invocation ->
+            val myInterface = invocation.requireClass("MyInterface")
+            val nonNullSubject = invocation.requireClass("NonNullSubject").asStarProjectedType()
+            val nullableSubject = invocation.requireClass("NullableSubject").asStarProjectedType()
+            val inheritedProp = myInterface.requireProperty("inheritedProp")
+            assertThat(
+                inheritedProp.typeAsMemberOf(invocation.kspResolver, nonNullSubject).nullability
+            ).isEqualTo(Nullability.NOT_NULL)
+            assertThat(
+                inheritedProp.typeAsMemberOf(invocation.kspResolver, nullableSubject).nullability
+            ).isEqualTo(Nullability.NULLABLE)
+
+            val nullableProp = myInterface.requireProperty("nullableProp")
+            assertThat(
+                nullableProp.typeAsMemberOf(invocation.kspResolver, nonNullSubject).nullability
+            ).isEqualTo(Nullability.NULLABLE)
+            assertThat(
+                nullableProp.typeAsMemberOf(invocation.kspResolver, nullableSubject).nullability
+            ).isEqualTo(Nullability.NULLABLE)
+
+            val inheritedGenericProp = myInterface.requireProperty("inheritedGenericProp")
+            inheritedGenericProp.typeAsMemberOf(invocation.kspResolver, nonNullSubject).let {
+                assertThat(it.nullability).isEqualTo(Nullability.NOT_NULL)
+                assertThat(
+                    it.arguments.first().type?.resolve()?.nullability
+                ).isEqualTo(Nullability.NOT_NULL)
+            }
+            inheritedGenericProp.typeAsMemberOf(invocation.kspResolver, nullableSubject).let {
+                assertThat(it.nullability).isEqualTo(Nullability.NOT_NULL)
+                assertThat(
+                    it.arguments.first().type?.resolve()?.nullability
+                ).isEqualTo(Nullability.NULLABLE)
+            }
+
+            val nullableGenericProp = myInterface.requireProperty("nullableGenericProp")
+            nullableGenericProp.typeAsMemberOf(invocation.kspResolver, nonNullSubject).let {
+                assertThat(it.nullability).isEqualTo(Nullability.NOT_NULL)
+                assertThat(
+                    it.arguments.first().type?.resolve()?.nullability
+                ).isEqualTo(Nullability.NULLABLE)
+            }
+            nullableGenericProp.typeAsMemberOf(invocation.kspResolver, nullableSubject).let {
+                assertThat(it.nullability).isEqualTo(Nullability.NOT_NULL)
+                assertThat(
+                    it.arguments.first().type?.resolve()?.nullability
+                ).isEqualTo(Nullability.NULLABLE)
+            }
+        }
+    }
+
     private fun TestInvocation.requireClass(name: String): KSClassDeclaration {
         val resolver = (processingEnv as KspProcessingEnv).resolver
         return resolver.requireClass(name)
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt
index 95e4ead..e3ab7d4 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt
@@ -17,7 +17,12 @@
 package androidx.room.compiler.processing.util
 
 import androidx.room.compiler.processing.XProcessingEnv
+import androidx.room.compiler.processing.ksp.KspProcessingEnv
+import org.jetbrains.kotlin.ksp.processing.Resolver
 
 class TestInvocation(
     val processingEnv: XProcessingEnv
-)
+) {
+    val kspResolver: Resolver
+        get() = (processingEnv as KspProcessingEnv).resolver
+}