Merge "Fix KSTypeVarianceResolver for type aliases with type parameters." into androidx-main
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/DefaultKspType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/DefaultKspType.kt
index 77606d3..a24eeb0 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/DefaultKspType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/DefaultKspType.kt
@@ -16,7 +16,6 @@
 
 package androidx.room.compiler.processing.ksp
 
-import androidx.room.compiler.processing.XNullability
 import androidx.room.compiler.processing.tryBox
 import com.google.devtools.ksp.symbol.KSType
 import com.squareup.kotlinpoet.javapoet.JTypeName
@@ -25,8 +24,9 @@
 internal class DefaultKspType(
     env: KspProcessingEnv,
     ksType: KSType,
-    scope: KSTypeVarianceResolverScope?
-) : KspType(env, ksType, scope) {
+    scope: KSTypeVarianceResolverScope? = null,
+    typeAlias: KSType? = null,
+) : KspType(env, ksType, scope, typeAlias) {
 
     override fun resolveJTypeName(): JTypeName {
         // always box these. For primitives, typeName might return the primitive type but if we
@@ -42,19 +42,10 @@
         return this
     }
 
-    override fun copyWithNullability(nullability: XNullability): KspType {
-        return DefaultKspType(
-            env = env,
-            ksType = ksType.withNullability(nullability),
-            scope = scope
-        )
-    }
-
-    override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-        return DefaultKspType(
-            env = env,
-            ksType = ksType,
-            scope = scope
-        )
-    }
+    override fun copy(
+        env: KspProcessingEnv,
+        ksType: KSType,
+        scope: KSTypeVarianceResolverScope?,
+        typeAlias: KSType?
+    ) = DefaultKspType(env, ksType, scope, typeAlias)
 }
\ No newline at end of file
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
index 1b31225..83e3f36 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
@@ -17,6 +17,8 @@
 package androidx.room.compiler.processing.ksp
 
 import androidx.room.compiler.processing.XNullability
+import androidx.room.compiler.processing.rawTypeName
+import com.google.devtools.ksp.processing.Resolver
 import com.google.devtools.ksp.symbol.KSAnnotated
 import com.google.devtools.ksp.symbol.KSAnnotation
 import com.google.devtools.ksp.symbol.KSDeclaration
@@ -26,6 +28,86 @@
 import com.google.devtools.ksp.symbol.KSTypeArgument
 import com.google.devtools.ksp.symbol.KSTypeParameter
 import com.google.devtools.ksp.symbol.KSTypeReference
+import com.google.devtools.ksp.symbol.Nullability
+import com.google.devtools.ksp.symbol.Variance
+import com.squareup.kotlinpoet.javapoet.JClassName
+
+internal fun KSType.replaceSuspendFunctionTypes(resolver: Resolver): KSType {
+    return if (!isSuspendFunctionType) {
+        this // not a suspend function type: returned unchanged
+    } else {
+        // Find the JVM FunctionN type that will replace the suspend function and use that.
+        val functionN = resolver.requireType(
+            (declaration.asJTypeName(resolver).rawTypeName() as JClassName).canonicalName()
+        )
+        functionN.replace(
+            buildList {
+                addAll(arguments.dropLast(1)) // every argument but the last is kept as-is
+                val continuationTypeRef = resolver.requireType("kotlin.coroutines.Continuation")
+                    .replace(arguments.takeLast(1)) // Continuation<last argument>
+                    .createTypeReference()
+                add(resolver.getTypeArgument(continuationTypeRef, Variance.INVARIANT))
+                val objTypeRef = resolver.requireType("java.lang.Object").createTypeReference()
+                add(resolver.getTypeArgument(objTypeRef, Variance.INVARIANT))
+            } // resulting arguments: leading arguments, Continuation, Object
+        )
+    }
+}
+
+internal fun KSType.replaceTypeAliases(resolver: Resolver): KSType {
+    return if (declaration is KSTypeAlias) {
+        // Note: KSP only gives us access to the typealias through the declaration, so the type
+        // arguments written at the use site are not applied to the aliased type. We apply them
+        // manually: map each of the alias's type parameter names to the use-site argument at the
+        // same index, then substitute the type parameters wherever they appear.
+        val typeParamNameToTypeArgs = declaration.typeParameters.indices.associate { i ->
+            declaration.typeParameters[i].name.asString() to arguments[i]
+        }
+        (declaration as KSTypeAlias).type.resolve()
+            .replaceTypeArgs(resolver, typeParamNameToTypeArgs)
+    } else {
+        this
+    }.let {
+        it.replace(it.arguments.map { typeArg -> typeArg.replaceTypeAliases(resolver) })
+    }.let {
+        // `nullability` is the receiver's, not `it`'s; if it was nullable, carry that over.
+        if (nullability == Nullability.NULLABLE) {
+            it.makeNullable()
+        } else {
+            it
+        }
+    }
+}
+
+private fun KSTypeArgument.replaceTypeAliases(resolver: Resolver): KSTypeArgument {
+    val type = type?.resolve() ?: return this
+    return resolver.getTypeArgument(
+        type.replaceTypeAliases(resolver).createTypeReference(),
+        variance
+    )
+}
+
+private fun KSType.replaceTypeArgs(
+    resolver: Resolver,
+    typeArgsMap: Map<String, KSTypeArgument>
+): KSType = replace(arguments.map { it.replaceTypeArgs(resolver, typeArgsMap) })
+
+private fun KSTypeArgument.replaceTypeArgs(
+    resolver: Resolver,
+    typeArgsMap: Map<String, KSTypeArgument>
+): KSTypeArgument {
+    val type = type?.resolve() ?: return this
+    if (type.isTypeParameter()) {
+        val name = (type.declaration as KSTypeParameter).name.asString()
+        if (typeArgsMap.containsKey(name)) {
+            return typeArgsMap[name]!!
+        }
+    }
+    return resolver.getTypeArgument(
+        type.replaceTypeArgs(resolver, typeArgsMap).createTypeReference(),
+        variance
+    )
+}
 
 /**
  * Root package comes as <root> instead of "" so we work around it here.
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeJavaPoetExt.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeJavaPoetExt.kt
index a40fe41..1aaaee0 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeJavaPoetExt.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeJavaPoetExt.kt
@@ -168,7 +168,9 @@
     resolver: Resolver,
     typeArgumentTypeLookup: JTypeArgumentTypeLookup
 ): JTypeName {
-    return if (this.arguments.isNotEmpty() && !resolver.isJavaRawType(this)) {
+    return if (declaration is KSTypeAlias) {
+        replaceTypeAliases(resolver).asJTypeName(resolver, typeArgumentTypeLookup)
+    } else if (this.arguments.isNotEmpty() && !resolver.isJavaRawType(this)) {
         val args: Array<JTypeName> = this.arguments
             .map { typeArg -> typeArg.asJTypeName(resolver, typeArgumentTypeLookup) }
             .map { it.tryBox() }
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeKotlinPoetExt.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeKotlinPoetExt.kt
index affa165..bf9c0d6 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeKotlinPoetExt.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeKotlinPoetExt.kt
@@ -140,7 +140,9 @@
     resolver: Resolver,
     typeArgumentTypeLookup: KTypeArgumentTypeLookup
 ): KTypeName {
-    return if (this.arguments.isNotEmpty() && !resolver.isJavaRawType(this)) {
+    return if (declaration is KSTypeAlias) {
+        replaceTypeAliases(resolver).asKTypeName(resolver, typeArgumentTypeLookup)
+    } else if (this.arguments.isNotEmpty() && !resolver.isJavaRawType(this)) {
         val args: List<KTypeName> = this.arguments
             .map { typeArg ->
                 typeArg.asKTypeName(
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeVarianceResolver.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeVarianceResolver.kt
index 56c3455..2687c5e 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeVarianceResolver.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeVarianceResolver.kt
@@ -16,7 +16,6 @@
 
 package androidx.room.compiler.processing.ksp
 
-import androidx.room.compiler.processing.rawTypeName
 import com.google.devtools.ksp.KspExperimental
 import com.google.devtools.ksp.isOpen
 import com.google.devtools.ksp.processing.Resolver
@@ -28,7 +27,6 @@
 import com.google.devtools.ksp.symbol.KSTypeParameter
 import com.google.devtools.ksp.symbol.Modifier
 import com.google.devtools.ksp.symbol.Variance
-import com.squareup.kotlinpoet.javapoet.JClassName
 
 /**
  * When kotlin generates java code, it has some interesting rules on how variance is handled.
@@ -53,10 +51,7 @@
      */
     @OptIn(KspExperimental::class)
     fun applyTypeVariance(type: KSType, scope: KSTypeVarianceResolverScope?): KSType {
-        if (type.isError ||
-            type.arguments.isEmpty() ||
-            resolver.isJavaRawType(type) ||
-            scope?.needsWildcardResolution == false) {
+        if (type.isError || scope?.needsWildcardResolution == false) {
             // There's nothing to resolve in this case, so just return the original type.
             return type
         }
@@ -64,6 +59,10 @@
         // First wrap types/arguments in our own wrappers so that we can keep track of the original
         // type, which is needed to get annotations.
         return KSTypeWrapper(resolver, type)
+            // Next, replace all type aliases with their resolved types
+            .replaceTypeAliases()
+            // Next, replace all suspend functions with their JVM types.
+            .replaceSuspendFunctionTypes()
             // Next, resolve wildcards based on the scope of the type
             .resolveWildcards(scope)
             // Next, apply any additional variance changes based on the @JvmSuppressWildcards or
@@ -76,6 +75,73 @@
             .unwrap()
     }
 
+    private fun KSTypeWrapper.replaceTypeAliases(): KSTypeWrapper {
+        return if (declaration is KSTypeAlias) {
+            // Note: KSP only gives us access to the type alias through the declaration. This means
+            // that any type arguments on the type alias won't be resolved.
+            // For example, if we have a type alias, MyAlias<T> = Foo<Bar<T>>, and a property,
+            // MyAlias<Baz>, then calling KSTypeAlias#type on the property will give Foo<Bar<T>>
+            // rather than Foo<Bar<Baz>>, so the arguments are substituted by parameter name below.
+            val typeParamNameToTypeArgs = declaration.typeParameters.indices.associate { i ->
+                declaration.typeParameters[i].name.asString() to arguments[i]
+            }
+            replaceType(declaration.type.resolve()).replaceTypeArgs(typeParamNameToTypeArgs)
+        } else {
+            this
+        }.let {
+            it.replace(it.arguments.map { typeArg -> typeArg.replaceTypeAliases() })
+        }
+    }
+
+    private fun KSTypeArgumentWrapper.replaceTypeAliases(): KSTypeArgumentWrapper {
+        val type = type ?: return this
+        return replace(type.replaceTypeAliases(), variance)
+    }
+
+    private fun KSTypeWrapper.replaceTypeArgs(
+        typeArgsMap: Map<String, KSTypeArgumentWrapper>
+    ): KSTypeWrapper = replace(arguments.map { it.replaceTypeArgs(typeArgsMap) })
+
+    private fun KSTypeArgumentWrapper.replaceTypeArgs(
+        typeArgsMap: Map<String, KSTypeArgumentWrapper>
+    ): KSTypeArgumentWrapper {
+        val type = type ?: return this // no wrapped type: nothing to substitute
+        if (type.isTypeParameter()) {
+            val name = (type.declaration as KSTypeParameter).name.asString()
+            if (typeArgsMap.containsKey(name)) {
+                return replace(typeArgsMap[name]?.type!!, variance)
+            } // NOTE(review): `!!` above throws if the mapped argument's `type` is null; confirm.
+        }
+        return replace(type.replaceTypeArgs(typeArgsMap), variance) // else recurse into nested args
+    }
+
+    private fun KSTypeWrapper.replaceSuspendFunctionTypes(): KSTypeWrapper {
+        return if (!newType.isSuspendFunctionType) {
+            this
+        } else {
+            val newKSType = newType.replaceSuspendFunctionTypes(resolver)
+            val newType = KSTypeWrapper(resolver, newKSType) // shadows the `newType` property
+            replaceType(newKSType).replace(
+                buildList {
+                    addAll(arguments.dropLast(1))
+                    val originalArg = arguments.last() // last argument before the replacement
+                    val continuationArg = newType.arguments[newType.arguments.lastIndex - 1]
+                    add( // second-to-last new argument, rebuilt around originalArg's wrapper
+                        continuationArg.replace(
+                            continuationArg.type!!.replace(
+                                continuationArg.type!!.arguments.map {
+                                    it.replace(originalArg.type!!, originalArg.variance)
+                                }
+                            ),
+                            continuationArg.variance
+                        )
+                    )
+                    add(newType.arguments.last()) // last argument of the replaced type, as-is
+                }
+            )
+        }
+    }
+
     private fun KSTypeWrapper.resolveWildcards(
         scope: KSTypeVarianceResolverScope?
     ) = if (scope == null) {
@@ -249,8 +315,7 @@
 
     private fun KSTypeWrapper.applyJvmWildcardAnnotations(
         scope: KSTypeVarianceResolverScope?
-    ) =
-        replace(arguments.map { it.applyJvmWildcardAnnotations(scope) })
+    ) = replace(arguments.map { it.applyJvmWildcardAnnotations(scope) })
 
     private fun KSTypeArgumentWrapper.applyJvmWildcardAnnotations(
         scope: KSTypeVarianceResolverScope?
@@ -282,23 +347,26 @@
  * [IllegalStateException] since KSP tries to cast to its own implementation of [KSTypeArgument].
  */
 private class KSTypeWrapper constructor(
-    private val resolver: Resolver,
-    private val originalType: KSType,
-    private val newType: KSType =
-        originalType.replaceTypeAliases().replaceSuspendFunctionTypes(resolver),
-    newTypeArguments: List<KSTypeArgumentWrapper>? = null,
-    private val typeStack: List<KSTypeWrapper> = emptyList(),
-    private val typeArgStack: List<KSTypeArgumentWrapper> = emptyList(),
-    private val typeParamStack: List<KSTypeParameter> = emptyList(),
+    val resolver: Resolver,
+    val originalType: KSType,
+    val newType: KSType = originalType,
+    val newTypeArguments: List<KSTypeArgumentWrapper>? = null,
+    val typeStack: List<KSTypeWrapper> = emptyList(),
+    val typeArgStack: List<KSTypeArgumentWrapper> = emptyList(),
+    val typeParamStack: List<KSTypeParameter> = emptyList(),
 ) {
-    val declaration = originalType.declaration
+    val declaration = newType.declaration
 
     val arguments: List<KSTypeArgumentWrapper> by lazy {
-        newTypeArguments ?: newType.arguments.indices.map { i ->
+        val arguments = newTypeArguments ?: newType.arguments.indices.map { i ->
             KSTypeArgumentWrapper(
                 originalTypeArg = newType.arguments[i],
                 typeParam = newType.declaration.typeParameters[i],
                 resolver = resolver,
+            )
+        }
+        arguments.map { newTypeArg ->
+            newTypeArg.copy(
                 typeStack = typeStack + this,
                 typeArgStack = typeArgStack,
                 typeParamStack = typeParamStack,
@@ -306,11 +374,23 @@
         }
     }
 
-    fun replace(newTypeArguments: List<KSTypeArgumentWrapper>) = KSTypeWrapper(
+    fun replaceType(newType: KSType): KSTypeWrapper = copy(newType = newType)
+
+    fun replace(newTypeArguments: List<KSTypeArgumentWrapper>) =
+        copy(newTypeArguments = newTypeArguments)
+
+    fun copy(
+        originalType: KSType = this.originalType,
+        newType: KSType = this.newType,
+        newTypeArguments: List<KSTypeArgumentWrapper>? = this.newTypeArguments,
+        typeStack: List<KSTypeWrapper> = this.typeStack,
+        typeArgStack: List<KSTypeArgumentWrapper> = this.typeArgStack,
+        typeParamStack: List<KSTypeParameter> = this.typeParamStack,
+    ) = KSTypeWrapper(
+        resolver = resolver,
         originalType = originalType,
         newType = newType,
         newTypeArguments = newTypeArguments,
-        resolver = resolver,
         typeStack = typeStack,
         typeArgStack = typeArgStack,
         typeParamStack = typeParamStack,
@@ -328,32 +408,7 @@
         }
         append(newType.declaration.simpleName.asString())
         if (arguments.isNotEmpty()) {
-            append("$arguments")
-        }
-    }
-
-    private companion object {
-        fun KSType.replaceTypeAliases() = (declaration as? KSTypeAlias)?.type?.resolve() ?: this
-
-        fun KSType.replaceSuspendFunctionTypes(resolver: Resolver) = if (!isSuspendFunctionType) {
-            this
-        } else {
-            // Find the JVM FunctionN type that will replace the suspend function and use that.
-            val functionN = resolver.requireType(
-                (declaration.asJTypeName(resolver).rawTypeName() as JClassName).canonicalName()
-            )
-            functionN.replace(
-                buildList {
-                    addAll(arguments.dropLast(1))
-                    val continuationArgs = arguments.takeLast(1)
-                    val continuationTypeRef = resolver.requireType("kotlin.coroutines.Continuation")
-                        .replace(continuationArgs)
-                        .createTypeReference()
-                    val objTypeRef = resolver.requireType("java.lang.Object").createTypeReference()
-                    add(resolver.getTypeArgument(continuationTypeRef, Variance.INVARIANT))
-                    add(resolver.getTypeArgument(objTypeRef, Variance.INVARIANT))
-                }
-            )
+            append("<${arguments.joinToString(", ")}>")
         }
     }
 }
@@ -368,36 +423,47 @@
  * type argument.
  */
 private class KSTypeArgumentWrapper constructor(
-    private val originalTypeArg: KSTypeArgument,
-    private val newType: KSTypeWrapper? = null,
-    private val resolver: Resolver,
+    val originalTypeArg: KSTypeArgument,
+    val newType: KSTypeWrapper? = null,
+    val resolver: Resolver,
     val typeParam: KSTypeParameter,
     val variance: Variance = originalTypeArg.variance,
-    val typeStack: List<KSTypeWrapper>,
-    val typeArgStack: List<KSTypeArgumentWrapper>,
-    val typeParamStack: List<KSTypeParameter>,
+    val typeStack: List<KSTypeWrapper> = emptyList(),
+    val typeArgStack: List<KSTypeArgumentWrapper> = emptyList(),
+    val typeParamStack: List<KSTypeParameter> = emptyList(),
 ) {
     val type: KSTypeWrapper? by lazy {
         if (variance == Variance.STAR || originalTypeArg.type == null) {
             // Return null for star projections, otherwise we'll end up in an infinite loop.
-            null
-        } else {
-            newType ?: KSTypeWrapper(
-                originalType = originalTypeArg.type!!.resolve(),
-                resolver = resolver,
-                typeStack = typeStack,
-                typeArgStack = typeArgStack + this,
-                typeParamStack = typeParamStack + typeParam,
-            )
+            return@lazy null
         }
+        val type = newType ?: KSTypeWrapper(resolver, originalTypeArg.type!!.resolve())
+        type.copy(
+            typeStack = typeStack,
+            typeArgStack = typeArgStack + this,
+            typeParamStack = typeParamStack + typeParam,
+        )
     }
 
-    fun replace(newType: KSTypeWrapper, newVariance: Variance) = KSTypeArgumentWrapper(
-        originalTypeArg = originalTypeArg,
-        typeParam = typeParam,
+    fun replace(newType: KSTypeWrapper, newVariance: Variance) = copy(
         newType = newType,
         variance = newVariance,
+    )
+
+    fun copy(
+        originalTypeArg: KSTypeArgument = this.originalTypeArg,
+        newType: KSTypeWrapper? = this.newType,
+        typeParam: KSTypeParameter = this.typeParam,
+        variance: Variance = this.variance,
+        typeStack: List<KSTypeWrapper> = this.typeStack,
+        typeArgStack: List<KSTypeArgumentWrapper> = this.typeArgStack,
+        typeParamStack: List<KSTypeParameter> = this.typeParamStack,
+    ) = KSTypeArgumentWrapper(
         resolver = resolver,
+        originalTypeArg = originalTypeArg,
+        newType = newType,
+        variance = variance,
+        typeParam = typeParam,
         typeStack = typeStack,
         typeArgStack = typeArgStack,
         typeParamStack = typeParamStack,
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt
index 21a2942..74a142a 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspArrayType.kt
@@ -28,8 +28,9 @@
 internal sealed class KspArrayType(
     env: KspProcessingEnv,
     ksType: KSType,
-    scope: KSTypeVarianceResolverScope?
-) : KspType(env, ksType, scope), XArrayType {
+    scope: KSTypeVarianceResolverScope? = null,
+    typeAlias: KSType? = null,
+) : KspType(env, ksType, scope, typeAlias), XArrayType {
 
     abstract override val componentType: KspType
 
@@ -52,8 +53,9 @@
     private class BoxedArray(
         env: KspProcessingEnv,
         ksType: KSType,
-        scope: KSTypeVarianceResolverScope?
-    ) : KspArrayType(env, ksType, scope) {
+        scope: KSTypeVarianceResolverScope? = null,
+        typeAlias: KSType? = null,
+    ) : KspArrayType(env, ksType, scope, typeAlias) {
         override fun resolveJTypeName(): JTypeName {
             return JArrayTypeName.of(componentType.asTypeName().java.box())
         }
@@ -72,21 +74,12 @@
             )
         }
 
-        override fun copyWithNullability(nullability: XNullability): BoxedArray {
-            return BoxedArray(
-                env = env,
-                ksType = ksType.withNullability(nullability),
-                scope = scope,
-            )
-        }
-
-        override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-            return BoxedArray(
-                env = env,
-                ksType = ksType,
-                scope = scope
-            )
-        }
+        override fun copy(
+            env: KspProcessingEnv,
+            ksType: KSType,
+            scope: KSTypeVarianceResolverScope?,
+            typeAlias: KSType?
+        ) = BoxedArray(env, ksType, scope, typeAlias)
     }
 
     /**
@@ -95,9 +88,10 @@
     private class PrimitiveArray(
         env: KspProcessingEnv,
         ksType: KSType,
-        scope: KSTypeVarianceResolverScope?,
-        override val componentType: KspType
-    ) : KspArrayType(env, ksType, scope) {
+        scope: KSTypeVarianceResolverScope? = null,
+        typeAlias: KSType? = null,
+        override val componentType: KspType,
+    ) : KspArrayType(env, ksType, scope, typeAlias) {
         override fun resolveJTypeName(): JTypeName {
             return JArrayTypeName.of(componentType.asTypeName().java.unbox())
         }
@@ -106,23 +100,12 @@
             return ksType.asKTypeName(env.resolver)
         }
 
-        override fun copyWithNullability(nullability: XNullability): PrimitiveArray {
-            return PrimitiveArray(
-                env = env,
-                ksType = ksType.withNullability(nullability),
-                componentType = componentType,
-                scope = scope
-            )
-        }
-
-        override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-            return PrimitiveArray(
-                env = env,
-                ksType = ksType,
-                componentType = componentType,
-                scope = scope
-            )
-        }
+        override fun copy(
+            env: KspProcessingEnv,
+            ksType: KSType,
+            scope: KSTypeVarianceResolverScope?,
+            typeAlias: KSType?
+        ) = PrimitiveArray(env, ksType, scope, typeAlias, componentType)
     }
 
     /**
@@ -131,14 +114,14 @@
     internal class Factory(private val env: KspProcessingEnv) {
         // map of built in array type to its component type
         private val builtInArrays = mapOf(
-            "kotlin.BooleanArray" to KspPrimitiveType(env, env.resolver.builtIns.booleanType, null),
-            "kotlin.ByteArray" to KspPrimitiveType(env, env.resolver.builtIns.byteType, null),
-            "kotlin.CharArray" to KspPrimitiveType(env, env.resolver.builtIns.charType, null),
-            "kotlin.DoubleArray" to KspPrimitiveType(env, env.resolver.builtIns.doubleType, null),
-            "kotlin.FloatArray" to KspPrimitiveType(env, env.resolver.builtIns.floatType, null),
-            "kotlin.IntArray" to KspPrimitiveType(env, env.resolver.builtIns.intType, null),
-            "kotlin.LongArray" to KspPrimitiveType(env, env.resolver.builtIns.longType, null),
-            "kotlin.ShortArray" to KspPrimitiveType(env, env.resolver.builtIns.shortType, null),
+            "kotlin.BooleanArray" to KspPrimitiveType(env, env.resolver.builtIns.booleanType),
+            "kotlin.ByteArray" to KspPrimitiveType(env, env.resolver.builtIns.byteType),
+            "kotlin.CharArray" to KspPrimitiveType(env, env.resolver.builtIns.charType),
+            "kotlin.DoubleArray" to KspPrimitiveType(env, env.resolver.builtIns.doubleType),
+            "kotlin.FloatArray" to KspPrimitiveType(env, env.resolver.builtIns.floatType),
+            "kotlin.IntArray" to KspPrimitiveType(env, env.resolver.builtIns.intType),
+            "kotlin.LongArray" to KspPrimitiveType(env, env.resolver.builtIns.longType),
+            "kotlin.ShortArray" to KspPrimitiveType(env, env.resolver.builtIns.shortType),
         )
 
         // map from the primitive to its array
@@ -156,7 +139,6 @@
                             primitiveArrayEntry.key
                         ),
                         componentType = primitiveArrayEntry.value,
-                        scope = null
                     )
                 }
             }
@@ -171,7 +153,6 @@
                         )
                     )
                 ),
-                scope = null
             )
         }
 
@@ -185,7 +166,6 @@
                 return BoxedArray(
                     env = env,
                     ksType = ksType,
-                    scope = null
                 )
             }
             builtInArrays[qName]?.let { primitiveType ->
@@ -193,7 +173,6 @@
                     env = env,
                     ksType = ksType,
                     componentType = primitiveType,
-                    scope = null
                 )
             }
             return null
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspPrimitiveType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspPrimitiveType.kt
index 73d675c..2035e3a 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspPrimitiveType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspPrimitiveType.kt
@@ -16,7 +16,6 @@
 
 package androidx.room.compiler.processing.ksp
 
-import androidx.room.compiler.processing.XNullability
 import androidx.room.compiler.processing.tryUnbox
 import com.google.devtools.ksp.symbol.KSType
 import com.squareup.kotlinpoet.javapoet.JTypeName
@@ -32,8 +31,8 @@
 internal class KspPrimitiveType(
     env: KspProcessingEnv,
     ksType: KSType,
-    scope: KSTypeVarianceResolverScope?
-) : KspType(env, ksType, scope) {
+    typeAlias: KSType? = null,
+) : KspType(env, ksType, null, typeAlias) {
     override fun resolveJTypeName(): JTypeName {
         return ksType.asJTypeName(env.resolver).tryUnbox()
     }
@@ -49,28 +48,10 @@
         )
     }
 
-    override fun copyWithNullability(nullability: XNullability): KspType {
-        return when (nullability) {
-            XNullability.NONNULL -> {
-                this
-            }
-            XNullability.NULLABLE -> {
-                // primitive types cannot be nullable hence we box them.
-                boxed().makeNullable()
-            }
-            else -> {
-                // this should actually never happens as the only time this is called is from
-                // make nullable-make nonnull but we have this error here for completeness.
-                error("cannot set nullability to unknown in KSP")
-            }
-        }
-    }
-
-    override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-        return KspPrimitiveType(
-            env = env,
-            ksType = ksType,
-            scope = scope
-        )
-    }
+    override fun copy(
+        env: KspProcessingEnv,
+        ksType: KSType,
+        scope: KSTypeVarianceResolverScope?,
+        typeAlias: KSType?
+    ) = KspPrimitiveType(env, ksType, typeAlias)
 }
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
index ae5a43cb..1081550 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
@@ -105,7 +105,6 @@
             env = this,
             ksType = resolver.builtIns.unitType,
             boxed = false,
-            scope = null
         )
 
     override fun findTypeElement(qName: String): KspTypeElement? {
@@ -227,7 +226,6 @@
         return KspTypeArgumentType(
             env = this,
             typeArg = ksTypeArgument,
-            scope = null
         )
     }
 
@@ -242,41 +240,27 @@
     fun wrap(ksType: KSType, allowPrimitives: Boolean): KspType {
         val declaration = ksType.declaration
         if (declaration is KSTypeAlias) {
-            val actual = wrap(
-                ksType = declaration.type.resolve().replace(ksType.arguments),
+            return wrap(
+                ksType = ksType.replaceTypeAliases(resolver),
                 allowPrimitives = allowPrimitives && ksType.nullability == Nullability.NOT_NULL
-            )
-            // if this type is nullable, carry it over
-            return if (ksType.nullability == Nullability.NULLABLE) {
-                actual.makeNullable()
-            } else {
-                actual
-            }
+            ).copyWithTypeAlias(ksType)
         }
         val qName = ksType.declaration.qualifiedName?.asString()
         if (declaration is KSTypeParameter) {
-            return KspTypeVariableType(
-                env = this,
-                ksType = ksType,
-                scope = null
-            )
+            return KspTypeVariableType(this, ksType)
         }
         if (allowPrimitives && qName != null && ksType.nullability == Nullability.NOT_NULL) {
             // check for primitives
             val javaPrimitive = KspTypeMapper.getPrimitiveJavaTypeName(qName)
             if (javaPrimitive != null) {
-                return KspPrimitiveType(this, ksType, scope = null)
+                return KspPrimitiveType(this, ksType)
             }
             // special case for void
             if (qName == "kotlin.Unit") {
                 return voidType
             }
         }
-        return arrayTypeFactory.createIfArray(ksType) ?: DefaultKspType(
-            this,
-            ksType,
-            scope = null
-        )
+        return arrayTypeFactory.createIfArray(ksType) ?: DefaultKspType(this, ksType)
     }
 
     fun wrapClassDeclaration(declaration: KSClassDeclaration): KspTypeElement {
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt
index dc5cf7c..b347008 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspType.kt
@@ -49,10 +49,10 @@
 internal abstract class KspType(
     env: KspProcessingEnv,
     val ksType: KSType,
-    /**
-     * Type resolver to convert KSType into its JVM representation.
-     */
-    val scope: KSTypeVarianceResolverScope?
+    /** Type resolver to convert KSType into its JVM representation. */
+    val scope: KSTypeVarianceResolverScope?,
+    /** The `typealias` that was resolved to get the [ksType], or null if none exists. */
+    val typeAlias: KSType?,
 ) : KspAnnotated(env), XType, XEquality {
     override val rawType by lazy {
         KspRawType(this)
@@ -69,7 +69,7 @@
      * The [XTypeName] represents those differences as [JTypeName] and [KTypeName], respectively.
      */
     private val xTypeName: XTypeName by lazy {
-        val jvmWildcardType = env.resolveWildcards(ksType, scope).let {
+        val jvmWildcardType = env.resolveWildcards(typeAlias ?: ksType, scope).let {
             if (it == ksType) {
                 this
             } else {
@@ -266,14 +266,23 @@
 
     abstract override fun boxed(): KspType
 
-    abstract fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType
+    abstract fun copy(
+        env: KspProcessingEnv,
+        ksType: KSType,
+        scope: KSTypeVarianceResolverScope?,
+        typeAlias: KSType?,
+    ): KspType
 
-    /**
-     * Create a copy of this type with the given nullability.
-     * This method is not called if the nullability of the type is already equal to the given
-     * nullability.
-     */
-    protected abstract fun copyWithNullability(nullability: XNullability): KspType
+    fun copyWithScope(scope: KSTypeVarianceResolverScope) = copy(env, ksType, scope, typeAlias)
+
+    fun copyWithTypeAlias(typeAlias: KSType) = copy(env, ksType, scope, typeAlias)
+
+    private fun copyWithNullability(nullability: XNullability): KspType = boxed().copy(
+        env = env,
+        ksType = ksType.withNullability(nullability),
+        scope = scope,
+        typeAlias = typeAlias,
+    )
 
     final override fun makeNullable(): KspType {
         if (nullability == XNullability.NULLABLE) {
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt
index 589244d..bd262c6 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeArgumentType.kt
@@ -16,8 +16,8 @@
 
 package androidx.room.compiler.processing.ksp
 
-import androidx.room.compiler.processing.XNullability
 import androidx.room.compiler.processing.XType
+import com.google.devtools.ksp.symbol.KSType
 import com.google.devtools.ksp.symbol.KSTypeArgument
 import com.google.devtools.ksp.symbol.KSTypeParameter
 import com.google.devtools.ksp.symbol.KSTypeReference
@@ -32,11 +32,13 @@
 internal class KspTypeArgumentType(
     env: KspProcessingEnv,
     val typeArg: KSTypeArgument,
-    scope: KSTypeVarianceResolverScope?
+    scope: KSTypeVarianceResolverScope? = null,
+    typeAlias: KSType? = null,
 ) : KspType(
     env = env,
     ksType = typeArg.requireType(),
-    scope = scope
+    scope = scope,
+    typeAlias = typeAlias,
 ) {
     /**
      * When KSP resolves classes, it always resolves to the upper bound. Hence, the ksType we
@@ -72,24 +74,17 @@
         return _extendsBound
     }
 
-    override fun copyWithNullability(nullability: XNullability): KspTypeArgumentType {
-        return KspTypeArgumentType(
-            env = env,
-            typeArg = DelegatingTypeArg(
-                original = typeArg,
-                type = ksType.withNullability(nullability).createTypeReference()
-            ),
-            scope = scope
-        )
-    }
-
-    override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-        return KspTypeArgumentType(
-            env = env,
-            typeArg = typeArg,
-            scope = scope
-        )
-    }
+    override fun copy(
+        env: KspProcessingEnv,
+        ksType: KSType,
+        scope: KSTypeVarianceResolverScope?,
+        typeAlias: KSType?
+    ) = KspTypeArgumentType(
+        env = env,
+        typeArg = DelegatingTypeArg(typeArg, type = ksType.createTypeReference()),
+        scope = scope,
+        typeAlias = typeAlias
+    )
 
     private class DelegatingTypeArg(
         val original: KSTypeArgument,
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeVariableType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeVariableType.kt
index cd2b9e6..43d1d17 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeVariableType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspTypeVariableType.kt
@@ -16,7 +16,6 @@
 
 package androidx.room.compiler.processing.ksp
 
-import androidx.room.compiler.processing.XNullability
 import androidx.room.compiler.processing.XType
 import androidx.room.compiler.processing.XTypeVariableType
 import com.google.devtools.ksp.symbol.KSType
@@ -27,8 +26,8 @@
 internal class KspTypeVariableType(
     env: KspProcessingEnv,
     ksType: KSType,
-    scope: KSTypeVarianceResolverScope?
-) : KspType(env, ksType, scope), XTypeVariableType {
+    scope: KSTypeVarianceResolverScope? = null,
+) : KspType(env, ksType, scope, null), XTypeVariableType {
     private val typeVariable: KSTypeParameter by lazy {
         // Note: This is a workaround for a bug in KSP where we may get ERROR_TYPE in the bounds
         // (https://github.com/google/ksp/issues/1250). To work around it we get the matching
@@ -52,19 +51,10 @@
         return this
     }
 
-    override fun copyWithNullability(nullability: XNullability): KspTypeVariableType {
-        return KspTypeVariableType(
-            env = env,
-            ksType = ksType,
-            scope = scope
-        )
-    }
-
-    override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-        return KspTypeVariableType(
-            env = env,
-            ksType = ksType,
-            scope = scope
-        )
-    }
+    override fun copy(
+        env: KspProcessingEnv,
+        ksType: KSType,
+        scope: KSTypeVarianceResolverScope?,
+        typeAlias: KSType?
+    ) = KspTypeVariableType(env, ksType, scope)
 }
\ No newline at end of file
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspVoidType.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspVoidType.kt
index a68b298..9af6ba0 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspVoidType.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspVoidType.kt
@@ -32,8 +32,9 @@
     env: KspProcessingEnv,
     ksType: KSType,
     val boxed: Boolean,
-    scope: KSTypeVarianceResolverScope?
-) : KspType(env, ksType, scope) {
+    scope: KSTypeVarianceResolverScope? = null,
+    typeAlias: KSType? = null,
+) : KspType(env, ksType, scope, typeAlias) {
     override fun resolveJTypeName(): JTypeName {
         return if (boxed || nullability == XNullability.NULLABLE) {
             JTypeName.VOID.box()
@@ -54,26 +55,16 @@
                 env = env,
                 ksType = ksType,
                 boxed = true,
-                scope = scope
+                scope = scope,
+                typeAlias = typeAlias,
             )
         }
     }
 
-    override fun copyWithNullability(nullability: XNullability): KspType {
-        return KspVoidType(
-            env = env,
-            ksType = ksType.withNullability(nullability),
-            boxed = boxed || nullability == XNullability.NULLABLE,
-            scope = scope
-        )
-    }
-
-    override fun copyWithScope(scope: KSTypeVarianceResolverScope): KspType {
-        return KspVoidType(
-            env = env,
-            ksType = ksType,
-            boxed = boxed,
-            scope = scope
-        )
-    }
+    override fun copy(
+        env: KspProcessingEnv,
+        ksType: KSType,
+        scope: KSTypeVarianceResolverScope?,
+        typeAlias: KSType?,
+    ) = KspVoidType(env, ksType, boxed, scope, typeAlias)
 }
diff --git a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeNamesGoldenTest.kt b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeNamesGoldenTest.kt
index 35f0b23..b0c73fd 100644
--- a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeNamesGoldenTest.kt
+++ b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/ksp/KspTypeNamesGoldenTest.kt
@@ -158,10 +158,15 @@
                 class MyGeneric<T>
                 class MyGenericIn<in T>
                 class MyGenericOut<out T>
+                class MyGeneric2Parameters<in T1, out T2>
                 class MyGenericMultipleParameters<T1: MyGeneric<*>, T2: MyGeneric<T1>>
                 interface MyInterface
                 typealias MyInterfaceAlias = MyInterface
                 typealias MyGenericAlias = MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>
+                typealias MyGenericOutAlias<T> = MyGenericOut<T>
+                typealias MyGenericOutAliasWithJSW<T> = MyGenericOut<@JSW T>
+                typealias MyGeneric2ParametersAlias<T1, T2> = MyGeneric2Parameters<MyGenericOut<T1>, MyGeneric<MyGenericOut<T2>>>
+                typealias MyGeneric1ParameterAlias<T> = MyGeneric2Parameters<MyGenericOut<T>, MyGeneric<MyGenericOut<T>>>
                 typealias MyLambdaAlias1 = (List<MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>>) -> List<MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>>
                 typealias MyLambdaAlias2 = @JSW (List<MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>>) -> List<MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>>
                 typealias MyLambdaAlias3 = (@JSW List<MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>>) -> @JSW List<MyGenericIn<MyGenericOut<MyGenericOut<MyType>>>>
@@ -296,6 +301,69 @@
                         fun method31(
                             param: MyGenericOut<MyGeneric<out MyGenericOut<MyGeneric<MyType>>>>
                         ): MyGenericOut<MyGeneric<out MyGenericOut<MyGeneric<MyType>>>> = TODO()
+                        fun method32(
+                            param: MyGenericOutAlias<MyInterface>
+                        ): MyGenericOutAlias<MyInterface> = TODO()
+                        fun method33(
+                            param: MyGenericOut<MyGenericOutAlias<MyInterface>>
+                        ): MyGenericOut<MyGenericOutAlias<MyInterface>> = TODO()
+                        fun method34(
+                            param: MyGenericIn<MyGenericOutAlias<MyInterface>>
+                        ): MyGenericIn<MyGenericOutAlias<MyInterface>> = TODO()
+                        fun method35(
+                            param: MyGenericOutAlias<MyGenericOut<MyInterface>>
+                        ): MyGenericOutAlias<MyGenericOut<MyInterface>> = TODO()
+                        fun method36(
+                            param: MyGenericIn<MyGenericOutAlias<MyGenericOut<MyInterface>>>
+                        ): MyGenericIn<MyGenericOutAlias<MyGenericOut<MyInterface>>> = TODO()
+                        fun method37(
+                            param: MyGenericIn<MyGenericOutAlias<MyGenericIn<MyInterface>>>
+                        ): MyGenericIn<MyGenericOutAlias<MyGenericIn<MyInterface>>> = TODO()
+                        fun method38(
+                            param: MyGenericOutAliasWithJSW<MyInterface>
+                        ): MyGenericOutAliasWithJSW<MyInterface> = TODO()
+                        fun method39(
+                            param: MyGenericOut<MyGenericOutAliasWithJSW<MyInterface>>
+                        ): MyGenericOut<MyGenericOutAliasWithJSW<MyInterface>> = TODO()
+                        fun method40(
+                            param: MyGenericIn<MyGenericOutAliasWithJSW<MyInterface>>
+                        ): MyGenericIn<MyGenericOutAliasWithJSW<MyInterface>> = TODO()
+                        fun method41(
+                            param: MyGenericOutAliasWithJSW<MyGenericOut<MyInterface>>
+                        ): MyGenericOutAliasWithJSW<MyGenericOut<MyInterface>> = TODO()
+                        fun method42(
+                            param: MyGenericIn<MyGenericOutAliasWithJSW<MyGenericOut<MyInterface>>>
+                        ): MyGenericIn<MyGenericOutAliasWithJSW<MyGenericOut<MyInterface>>> = TODO()
+                        fun method43(
+                            param: MyGenericIn<MyGenericOutAliasWithJSW<MyGenericIn<MyInterface>>>
+                        ): MyGenericIn<MyGenericOutAliasWithJSW<MyGenericIn<MyInterface>>> = TODO()
+                        fun method44(
+                            param: MyGenericOutAliasWithJSW<MyType>
+                        ): MyGenericOutAliasWithJSW<MyType> = TODO()
+                        fun method45(
+                            param: MyGenericOut<MyGenericOutAliasWithJSW<MyType>>
+                        ): MyGenericOut<MyGenericOutAliasWithJSW<MyType>> = TODO()
+                        fun method46(
+                            param: MyGenericIn<MyGenericOutAliasWithJSW<MyType>>
+                        ): MyGenericIn<MyGenericOutAliasWithJSW<MyType>> = TODO()
+                        fun method47(
+                            param: MyGeneric2ParametersAlias<MyType, MyType>
+                        ): MyGeneric2ParametersAlias<MyType, MyType> = TODO()
+                        fun method48(
+                            param: MyGenericOut<MyGeneric2ParametersAlias<MyType, MyType>>
+                        ): MyGenericOut<MyGeneric2ParametersAlias<MyType, MyType>> = TODO()
+                        fun method49(
+                            param: MyGenericIn<MyGeneric2ParametersAlias<MyType, MyType>>
+                        ): MyGenericIn<MyGeneric2ParametersAlias<MyType, MyType>> = TODO()
+                        fun method50(
+                            param: MyGeneric2ParametersAlias<MyGenericOut<MyType>, MyGenericIn<MyType>>
+                        ): MyGeneric2ParametersAlias<MyGenericOut<MyType>, MyGenericIn<MyType>> = TODO()
+                        fun method51(
+                            param: MyGenericOut<MyGeneric2ParametersAlias<MyGenericOut<MyType>, MyGenericIn<MyType>>>
+                        ): MyGenericOut<MyGeneric2ParametersAlias<MyGenericOut<MyType>, MyGenericIn<MyType>>> = TODO()
+                        fun method52(
+                            param: MyGenericIn<MyGeneric2ParametersAlias<MyGenericOut<MyType>, MyGenericIn<MyType>>>
+                        ): MyGenericIn<MyGeneric2ParametersAlias<MyGenericOut<MyType>, MyGenericIn<MyType>>> = TODO()
                     }
                 """.trimIndent()
             ), listOf("Subject")