Add NonDeclared types for KSP

In KSP, all types are modeled as KSType but that does not match
the model of Java's TypeMirror, DeclaredType and ArrayType.

This CL is an attempt to mimic the java types in KSP.

To do this, I've created sub classes for KspType for declared
and type arguments. The TypeArgument one is specifically necessary
to create the right typeName for them and also distinguish them
from declared types.

This CL is not fully complete, we need to make the KspTypeTests
run with KAPT to ensure we match KAPT as much as possible where it
makes sense (e.g. for nullability, KSP is better than KAPT and we
should keep it).
Similarly, we need to make all of XTypeTests run with KSP which
requires implementing primitive types (somehow) in KSP.
I'll implement those in followups.

Bug: 160322705
Test: XTypeTest, KSTypeExtTest, KspTypeTest

Change-Id: Icd0e91ca14ce9fd34924e5c4cb90b7ac26aaff84
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
index defa443..2af4314 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
@@ -19,10 +19,12 @@
 import com.squareup.javapoet.ClassName
 import com.squareup.javapoet.ParameterizedTypeName
 import com.squareup.javapoet.TypeName
+import com.squareup.javapoet.TypeVariableName
 import com.squareup.javapoet.WildcardTypeName
 import org.jetbrains.kotlin.ksp.symbol.KSDeclaration
 import org.jetbrains.kotlin.ksp.symbol.KSType
 import org.jetbrains.kotlin.ksp.symbol.KSTypeArgument
+import org.jetbrains.kotlin.ksp.symbol.KSTypeParameter
 import org.jetbrains.kotlin.ksp.symbol.KSTypeReference
 import org.jetbrains.kotlin.ksp.symbol.Variance
 
@@ -42,7 +44,13 @@
     return if (this == null) {
         ERROR_TYPE_NAME
     } else {
-        requireType().typeName()
+        val resolvedType = try {
+            requireType()
+        } catch (illegalState: IllegalStateException) {
+            // workaround for https://github.com/google/ksp/issues/101
+            null
+        }
+        resolvedType?.typeName() ?: ERROR_TYPE_NAME
     }
 }
 
@@ -62,15 +70,32 @@
     return ClassName.get(pkg, shortNames.first(), *(shortNames.drop(1).toTypedArray()))
 }
 
+internal fun KSTypeArgument.typeName(
+    param: KSTypeParameter
+): TypeName {
+    return when (variance) {
+        Variance.CONTRAVARIANT -> WildcardTypeName.supertypeOf(type.typeName())
+        Variance.COVARIANT -> WildcardTypeName.subtypeOf(type.typeName())
+        Variance.STAR -> {
+            // for star projected types, JavaPoet uses the name from the declaration if
+            // * is not given explicitly
+            if (type == null) {
+                // explicit *
+                WildcardTypeName.subtypeOf(TypeName.OBJECT)
+            } else {
+                TypeVariableName.get(param.name.asString(), type.typeName())
+            }
+        }
+        else -> type.typeName()
+    }
+}
+
 internal fun KSType.typeName(): TypeName {
     return if (this.arguments.isNotEmpty()) {
-        val args: Array<TypeName> = this.arguments.map {
-            when (it.variance) {
-                Variance.CONTRAVARIANT -> WildcardTypeName.supertypeOf(it.type.typeName())
-                Variance.COVARIANT -> WildcardTypeName.subtypeOf(it.type.typeName())
-                Variance.STAR -> WildcardTypeName.subtypeOf(TypeName.OBJECT)
-                else -> it.type.typeName()
-            }
+        val args: Array<TypeName> = this.arguments.mapIndexed { index, typeArg ->
+            typeArg.typeName(
+                this.declaration.typeParameters[index]
+            )
         }.toTypedArray()
         val className = declaration.typeName()
         ParameterizedTypeName.get(
@@ -95,6 +120,13 @@
     }
 }
 
+/**
+ * see: https://github.com/google/ksp/issues/101
+ * Wildcard resolution might throw. We are not catching it here as we don't have a good fallback,
+ * instead, catching it in the caller when we have an option to handle. And callers which do not
+ * have a way to handle will just crash for now until the issue is resolved.
+ */
+@Throws(IllegalStateException::class)
 internal fun KSTypeReference.requireType(): KSType {
     return checkNotNull(resolve()) {
         "Resolve in type reference should not have returned null, please file a bug. $this"
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt
index df89b7f..8e613b2 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt
@@ -25,15 +25,15 @@
 internal class KspArrayType(
     env: KspProcessingEnv,
     ksType: KSType
-) : KspType(
+) : KspDeclaredType( // in kotlin, array types are also declared
     env, ksType
 ),
     XArrayType {
     override val componentType: XType by lazy {
-        typeArguments.first()
+        typeArguments.first().extendsBoundOrSelf()
     }
 
     override val typeName: TypeName by lazy {
-        ArrayTypeName.of(typeArguments.first().typeName)
+        ArrayTypeName.of(componentType.typeName)
     }
-}
+}
\ No newline at end of file
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspDeclaredType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspDeclaredType.kt
new file mode 100644
index 0000000..1e75aea
--- /dev/null
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspDeclaredType.kt
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.room.compiler.processing.ksp
+
+import androidx.room.compiler.processing.XDeclaredType
+import androidx.room.compiler.processing.XType
+import com.squareup.javapoet.TypeName
+import org.jetbrains.kotlin.ksp.symbol.KSType
+
+internal open class KspDeclaredType(
+    env: KspProcessingEnv,
+    ksType: KSType
+) : KspType(env, ksType), XDeclaredType {
+    override val typeName: TypeName by lazy {
+        ksType.typeName()
+    }
+
+    override val typeArguments: List<XType> by lazy {
+        ksType.arguments.mapIndexed { index, arg ->
+            env.wrap(ksType.declaration.typeParameters[index], arg)
+        }
+    }
+}
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt
index fae37ac..d0cd59f 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodElement.kt
@@ -58,7 +58,7 @@
     }
 
     override fun asMemberOf(other: XDeclaredType): XMethodType {
-        check(other is KspType)
+        check(other is KspDeclaredType)
         return KspMethodType.create(
             env = env,
             origin = this,
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt
index 15a1487..cca29f31 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspMethodType.kt
@@ -24,7 +24,7 @@
 internal sealed class KspMethodType(
     val env: KspProcessingEnv,
     val origin: KspMethodElement,
-    val containing: KspType
+    val containing: KspDeclaredType
 ) : XMethodType {
     override val parameterTypes: List<XType> by lazy {
         origin.parameters.map {
@@ -48,7 +48,7 @@
     private class KspNormalMethodType(
         env: KspProcessingEnv,
         origin: KspMethodElement,
-        containing: KspType
+        containing: KspDeclaredType
     ) : KspMethodType(env, origin, containing) {
         override val returnType: XType by lazy {
             env.wrap(
@@ -63,7 +63,7 @@
     private class KspSuspendMethodType(
         env: KspProcessingEnv,
         origin: KspMethodElement,
-        containing: KspType
+        containing: KspDeclaredType
     ) : KspMethodType(env, origin, containing), XSuspendMethodType {
         override val returnType: XType
             // suspend functions always return Any?, no need to call asMemberOf
@@ -83,7 +83,7 @@
         fun create(
             env: KspProcessingEnv,
             origin: KspMethodElement,
-            containing: KspType
+            containing: KspDeclaredType
         ) = if (origin.isSuspendFunction()) {
             KspSuspendMethodType(env, origin, containing)
         } else {
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
index 9137b71..41043be 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
@@ -29,6 +29,8 @@
 import org.jetbrains.kotlin.ksp.processing.Resolver
 import org.jetbrains.kotlin.ksp.symbol.KSClassDeclaration
 import org.jetbrains.kotlin.ksp.symbol.KSType
+import org.jetbrains.kotlin.ksp.symbol.KSTypeArgument
+import org.jetbrains.kotlin.ksp.symbol.KSTypeParameter
 import org.jetbrains.kotlin.ksp.symbol.KSTypeReference
 import org.jetbrains.kotlin.ksp.symbol.Variance
 
@@ -73,7 +75,21 @@
     }
 
     override fun getDeclaredType(type: XTypeElement, vararg types: XType): XDeclaredType {
-        TODO("Not yet implemented")
+        check(type is KspTypeElement) {
+            "Unexpected type element type: $type"
+        }
+        val typeArguments = types.map { argType ->
+            check(argType is KspType) {
+                "$argType is not an instance of KspType"
+            }
+            resolver.getTypeArgument(
+                argType.ksType.createTypeReference(),
+                variance = Variance.INVARIANT
+            )
+        }
+        return wrap(
+            type.declaration.asType(typeArguments)
+        )
     }
 
     override fun getArrayType(type: XType): XArrayType {
@@ -93,24 +109,29 @@
         )
     }
 
-    fun wrap(ksType: KSType): KspType {
+    fun wrap(ksType: KSType): KspDeclaredType {
         return if (ksType.declaration.qualifiedName?.asString() == KOTLIN_ARRAY_Q_NAME) {
             KspArrayType(
                 env = this,
                 ksType = ksType
             )
         } else {
-            KspType(
-                env = this,
-                ksType = ksType
-            )
+            KspDeclaredType(this, ksType)
         }
     }
 
-    fun wrap(ksTypeReference: KSTypeReference): KspType {
+    fun wrap(ksTypeReference: KSTypeReference): KspDeclaredType {
         return wrap(ksTypeReference.requireType())
     }
 
+    fun wrap(ksTypeParam: KSTypeParameter, ksTypeArgument: KSTypeArgument): KspTypeArgumentType {
+        return KspTypeArgumentType(
+            env = this,
+            typeArg = ksTypeArgument,
+            typeParam = ksTypeParam
+        )
+    }
+
     fun wrapClassDeclaration(declaration: KSClassDeclaration): KspTypeElement {
         return KspTypeElement(
             env = this,
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspRawType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspRawType.kt
index 7dc0ba6..c5367d0 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspRawType.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspRawType.kt
@@ -42,4 +42,8 @@
     override fun hashCode(): Int {
         return typeName.hashCode()
     }
+
+    override fun toString(): String {
+        return typeName.toString()
+    }
 }
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt
index 1fecce7..ed13785 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt
@@ -16,12 +16,10 @@
 
 package androidx.room.compiler.processing.ksp
 
-import androidx.room.compiler.processing.XDeclaredType
 import androidx.room.compiler.processing.XEquality
 import androidx.room.compiler.processing.XNullability
 import androidx.room.compiler.processing.XType
 import androidx.room.compiler.processing.XTypeElement
-import com.squareup.javapoet.TypeName
 import org.jetbrains.kotlin.ksp.symbol.KSClassDeclaration
 import org.jetbrains.kotlin.ksp.symbol.KSType
 import org.jetbrains.kotlin.ksp.symbol.KSTypeReference
@@ -36,23 +34,14 @@
  * We don't necessarily have a [KSTypeReference] (e.g. if we are getting it from an element).
  * Similarly, we may not be able to get a [KSType] (e.g. if it resolves to error).
  */
-internal open class KspType(
+internal abstract class KspType(
     private val env: KspProcessingEnv,
     val ksType: KSType
-) : XDeclaredType, XEquality {
+) : XType, XEquality {
     override val rawType by lazy {
         KspRawType(this)
     }
 
-    override val typeArguments: List<XType> by lazy {
-        ksType.arguments.map {
-            env.wrap(it.type!!)
-        }
-    }
-    override val typeName: TypeName by lazy {
-        ksType.typeName()
-    }
-
     override val nullability by lazy {
         when (ksType.nullability) {
             Nullability.NULLABLE -> XNullability.NULLABLE
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt
new file mode 100644
index 0000000..d3adffa
--- /dev/null
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.room.compiler.processing.ksp
+
+import com.squareup.javapoet.TypeName
+import org.jetbrains.kotlin.ksp.symbol.KSTypeArgument
+import org.jetbrains.kotlin.ksp.symbol.KSTypeParameter
+
+/**
+ * The typeName for type arguments requires the type parameter, hence we have a special type
+ * for them when we produce them.
+ */
+internal class KspTypeArgumentType(
+    env: KspProcessingEnv,
+    val typeParam: KSTypeParameter,
+    val typeArg: KSTypeArgument
+) : KspType(
+    env = env,
+    ksType = typeArg.requireType()
+) {
+    override val typeName: TypeName by lazy {
+        typeArg.typeName(typeParam)
+    }
+}
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeElement.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeElement.kt
index e9c5e76..bc86345 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeElement.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeElement.kt
@@ -71,7 +71,7 @@
         }
     }
 
-    override val type: KspType by lazy {
+    override val type: KspDeclaredType by lazy {
         env.wrap(declaration.asStarProjectedType())
     }
 
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeTest.kt
index 9c74318..0258404 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XTypeTest.kt
@@ -20,13 +20,17 @@
 import androidx.room.compiler.processing.util.getDeclaredMethod
 import androidx.room.compiler.processing.util.getField
 import androidx.room.compiler.processing.util.getMethod
+import androidx.room.compiler.processing.util.runKspTest
 import androidx.room.compiler.processing.util.runProcessorTest
 import androidx.room.compiler.processing.util.runProcessorTestForFailedCompilation
+import androidx.room.compiler.processing.util.runProcessorTestIncludingKsp
+import com.google.common.truth.Truth
 import com.google.common.truth.Truth.assertThat
 import com.squareup.javapoet.ClassName
 import com.squareup.javapoet.ParameterizedTypeName
 import com.squareup.javapoet.TypeName
 import com.squareup.javapoet.TypeVariableName
+import org.jetbrains.kotlin.ksp.getClassDeclarationByName
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
@@ -46,7 +50,7 @@
             }
             """.trimIndent()
         )
-        runProcessorTest(
+        runProcessorTestIncludingKsp(
             sources = listOf(parent)
         ) {
             val type = it.processingEnv.requireType("foo.bar.Parent") as XDeclaredType
@@ -61,13 +65,21 @@
 
             val typeArguments = type.typeArguments
             assertThat(typeArguments).hasSize(1)
+            val inputStreamClassName = ClassName.get("java.io", "InputStream")
             typeArguments.first().let { firstType ->
                 assertThat(firstType.isDeclared()).isFalse()
                 val expected = TypeVariableName.get(
                     "InputStreamType",
-                    ClassName.get("java.io", "InputStream")
+                    inputStreamClassName
                 )
                 assertThat(firstType.typeName).isEqualTo(expected)
+                // equals in TypeVariableName just checks the string representation but we want
+                // to assert the upper bound as well
+                assertThat(
+                    (firstType.typeName as TypeVariableName).bounds
+                ).containsExactly(
+                    inputStreamClassName
+                )
             }
 
             type.asTypeElement().getMethod("wildcardParam").let { method ->
@@ -75,7 +87,7 @@
                 val extendsBoundOrSelf = wildcardParam.type.extendsBoundOrSelf()
                 assertThat(extendsBoundOrSelf.rawType)
                     .isEqualTo(
-                        it.processingEnv.requireType("java.util.Set").rawType
+                        it.processingEnv.requireType(it.types.mutableSet).rawType
                     )
             }
         }
@@ -95,6 +107,7 @@
                 }
             """.trimIndent()
         )
+        // TODO run with KSP as well once https://github.com/google/ksp/issues/107 is resolved
         runProcessorTestForFailedCompilation(
             sources = listOf(missingTypeRef)
         ) {
@@ -125,7 +138,7 @@
             }
             """.trimIndent()
         )
-        runProcessorTest(
+        runProcessorTestIncludingKsp(
             sources = listOf(subject)
         ) {
             val type = it.processingEnv.requireType("foo.bar.Baz")
@@ -143,6 +156,10 @@
             it.processingEnv.requireType("java.util.List").let { list ->
                 assertThat(list.isCollection()).isTrue()
             }
+            it.processingEnv.requireType("java.util.ArrayList").let { list ->
+                // isCollection is overloaded name, it is actually just checking list or set.
+                assertThat(list.isCollection()).isFalse()
+            }
             it.processingEnv.requireType("java.util.Set").let { list ->
                 assertThat(list.isCollection()).isTrue()
             }
@@ -153,15 +170,33 @@
     }
 
     @Test
-    fun toStringMatchesUnderlyingElement() {
-        runProcessorTest {
-            it.processingEnv.requireType("java.lang.Integer").let { map ->
-                assertThat(map.toString()).isEqualTo("java.lang.Integer")
+    fun isCollection_kotlin() {
+        runKspTest(sources = emptyList(), succeed = true) { invocation ->
+            val subjects = listOf("Map" to false, "List" to true, "Set" to true)
+            subjects.forEach { (subject, expected) ->
+                invocation.processingEnv.requireType("kotlin.collections.$subject").let { type ->
+                    Truth.assertWithMessage(type.typeName.toString())
+                        .that(type.isCollection()).isEqualTo(expected)
+                }
             }
         }
     }
 
     @Test
+    fun toStringMatchesUnderlyingElement() {
+        runProcessorTestIncludingKsp {
+            val subject = "java.lang.String"
+            val expected = if (it.isKsp) {
+                it.kspResolver.getClassDeclarationByName(subject)?.toString()
+            } else {
+                it.javaElementUtils.getTypeElement(subject)?.toString()
+            }
+            val actual = it.processingEnv.requireType(subject).toString()
+            assertThat(actual).isEqualTo(expected)
+        }
+    }
+
+    @Test
     fun errorTypeForSuper() {
         val missingTypeRef = Source.java(
             "foo.bar.Baz",
@@ -174,6 +209,7 @@
                 }
             """.trimIndent()
         )
+        // TODO run with KSP as well once https://github.com/google/ksp/issues/107 is resolved
         runProcessorTestForFailedCompilation(
             sources = listOf(missingTypeRef)
         ) {
@@ -217,7 +253,7 @@
 
     @Test
     fun rawType() {
-        runProcessorTest {
+        runProcessorTestIncludingKsp {
             val subject = it.processingEnv.getDeclaredType(
                 it.processingEnv.requireTypeElement(List::class),
                 it.processingEnv.requireType(String::class)
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSTypeExtTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSTypeExtTest.kt
index 9797393..d95171c 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSTypeExtTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KSTypeExtTest.kt
@@ -208,7 +208,7 @@
         }
         // make sure we grabbed some values to ensure test is working
         assertThat(golden).isNotEmpty()
-        assertThat(golden).containsExactlyEntriesIn(kspResults)
+        assertThat(kspResults).containsExactlyEntriesIn(golden)
     }
 
     private fun runTest(
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeTest.kt
index 08c9de1..5be5351 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeTest.kt
@@ -23,6 +23,7 @@
 import androidx.room.compiler.processing.util.runKspTest
 import com.google.common.truth.Truth.assertThat
 import com.squareup.javapoet.ClassName
+import com.squareup.javapoet.WildcardTypeName
 import org.jetbrains.kotlin.ksp.getClassDeclarationByName
 import org.jetbrains.kotlin.ksp.getDeclaredFunctions
 import org.jetbrains.kotlin.ksp.symbol.KSPropertyDeclaration
@@ -488,12 +489,16 @@
             val paramType = invocation.wrap(method.parameters.first().type!!)
             val arg1 = paramType.typeArguments.single()
             assertThat(arg1.typeName)
-                .isEqualTo(ClassName.get("kotlin", "Number"))
+                .isEqualTo(
+                    WildcardTypeName.subtypeOf(
+                        ClassName.get("kotlin", "Number")
+                    )
+                )
             assertThat(arg1.extendsBound()).isNull()
         }
     }
 
-    private fun TestInvocation.requirePropertyType(name: String): KspType {
+    private fun TestInvocation.requirePropertyType(name: String): KspDeclaredType {
         (processingEnv as KspProcessingEnv).resolver.getAllFiles().forEach { file ->
             val prop = file.declarations.first {
                 it.simpleName.asString() == name
@@ -509,7 +514,7 @@
         throw IllegalStateException("cannot find any property with name $name")
     }
 
-    private fun TestInvocation.wrap(typeRef: KSTypeReference): KspType {
+    private fun TestInvocation.wrap(typeRef: KSTypeReference): KspDeclaredType {
         return (processingEnv as KspProcessingEnv).wrap(typeRef)
     }
 }
\ No newline at end of file
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/KotlinTypeNames.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/KotlinTypeNames.kt
index 6b4198d..337ef103 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/KotlinTypeNames.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/KotlinTypeNames.kt
@@ -26,6 +26,7 @@
     val STRING_CLASS_NAME = ClassName.get("kotlin", "String")
     val LIST_CLASS_NAME = ClassName.get("kotlin.collections", "List")
     val MUTABLELIST_CLASS_NAME = ClassName.get("kotlin.collections", "MutableList")
+    val MUTABLESET_CLASS_NAME = ClassName.get("kotlin.collections", "MutableSet")
     val MAP_CLASS_NAME = ClassName.get("kotlin.collections", "Map")
     val PAIR_CLASS_NAME = ClassName.get(Pair::class.java)
     val CONTINUATION_CLASS_NAME = ClassName.get("kotlin.coroutines", "Continuation")
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt
index 5db3d82..ae79f1d 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt
@@ -115,7 +115,7 @@
         sources
     }
     // we can compile w/ javac only if all code is in java
-    if (sources.all { it is Source.JavaSource }) {
+    if (sources.canCompileWithJava()) {
         runJavaProcessorTest(sources = sources, handler = handler, succeed = true)
     }
     runKaptTest(sources = sources, handler = handler, succeed = true)
@@ -139,12 +139,23 @@
     sources: List<Source>,
     handler: (TestInvocation) -> Unit
 ) {
-    // run with java processor
-    runJavaProcessorTest(sources = sources, handler = handler, succeed = false)
+    if (sources.canCompileWithJava()) {
+        // run with java processor
+        runJavaProcessorTest(sources = sources, handler = handler, succeed = false)
+    }
     // now run with kapt
     runKaptTest(sources = sources, handler = handler, succeed = false)
 }
 
+fun runProcessorTestForFailedCompilationIncludingKsp(
+    sources: List<Source>,
+    handler: (TestInvocation) -> Unit
+) {
+    runProcessorTestForFailedCompilation(sources = sources, handler = handler)
+    // now run with ksp
+    runKspTest(sources = sources, handler = handler, succeed = false)
+}
+
 fun runJavaProcessorTest(
     sources: List<Source>,
     succeed: Boolean,
@@ -189,3 +200,5 @@
     }
     kspProcessor.throwIfFailed()
 }
+
+private fun List<Source>.canCompileWithJava() = all { it is Source.JavaSource }
\ No newline at end of file
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt
index d6778c1..2f2d228 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/util/TestInvocation.kt
@@ -17,10 +17,12 @@
 package androidx.room.compiler.processing.util
 
 import androidx.room.compiler.processing.XProcessingEnv
+import androidx.room.compiler.processing.javac.JavacProcessingEnv
 import androidx.room.compiler.processing.ksp.KspProcessingEnv
 import com.squareup.javapoet.ClassName
 import com.squareup.javapoet.TypeName
 import org.jetbrains.kotlin.ksp.processing.Resolver
+import javax.lang.model.util.Elements
 
 class TestInvocation(
     val processingEnv: XProcessingEnv
@@ -30,6 +32,9 @@
     val kspResolver: Resolver
         get() = (processingEnv as KspProcessingEnv).resolver
 
+    val javaElementUtils: Elements
+        get() = (processingEnv as JavacProcessingEnv).elementUtils
+
     val types by lazy {
         if (processingEnv is KspProcessingEnv) {
             Types(
@@ -39,7 +44,8 @@
                 boxedInt = KotlinTypeNames.INT_CLASS_NAME,
                 int = KotlinTypeNames.INT_CLASS_NAME,
                 long = KotlinTypeNames.LONG_CLASS_NAME,
-                list = KotlinTypeNames.LIST_CLASS_NAME
+                list = KotlinTypeNames.LIST_CLASS_NAME,
+                mutableSet = KotlinTypeNames.MUTABLESET_CLASS_NAME
             )
         } else {
             Types(
@@ -49,7 +55,8 @@
                 boxedInt = TypeName.INT.box(),
                 int = TypeName.INT,
                 long = TypeName.LONG,
-                list = ClassName.get("java.util", "List")
+                list = ClassName.get("java.util", "List"),
+                mutableSet = ClassName.get("java.util", "Set")
             )
         }
     }
@@ -65,6 +72,7 @@
         val boxedInt: TypeName,
         val int: TypeName,
         val long: TypeName,
-        val list: ClassName
+        val list: ClassName,
+        val mutableSet: TypeName
     )
 }