Merge "Make sure there is only 1 instance of each type element" into androidx-main
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt
index 06c844a..8e6e544 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt
@@ -42,10 +42,19 @@
     val typeUtils: Types = delegate.typeUtils
 
     private val typeElementStore =
-        XTypeElementStore { qName ->
-            val result = delegate.elementUtils.getTypeElement(qName)
-            result?.let(this::wrapTypeElement)
-        }
+        XTypeElementStore(
+            findElement = { qName ->
+                delegate.elementUtils.getTypeElement(qName)
+            },
+            getQName = {
+                // javac reports an empty qualified name for local and anonymous classes. Map it
+                // to null so those are wrapped without caching instead of sharing the "" key.
+                it.qualifiedName.toString().ifEmpty { null }
+            },
+            wrap = {
+                JavacTypeElement(this, it)
+            }
+        )
 
     override val messager: XMessager by lazy {
         JavacProcessingEnvMessager(delegate)
@@ -113,7 +122,10 @@
     }
 
-    // maybe cache here ?
-    fun wrapTypeElement(element: TypeElement) = JavacTypeElement(this, element)
+    /**
+     * Returns the [JavacTypeElement] for [element]. Wrappers of classes that have a qualified
+     * name are cached, so repeated calls normally return the same instance.
+     */
+    fun wrapTypeElement(element: TypeElement) = typeElementStore[element]
 
     /**
      * Wraps the given java processing type into an XType.
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt
index 229951e..a44e6b4 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt
@@ -22,8 +22,13 @@
 /**
  * Utility class to cache type element wrappers.
  */
-internal class XTypeElementStore<T : XTypeElement>(
-    val findElement: (qName: String) -> T?
+internal class XTypeElementStore<BackingType, T : XTypeElement>(
+    /** Resolves a qualified name to its backing element, or returns null if there is none. */
+    private val findElement: (qName: String) -> BackingType?,
+    /** Returns the cache key of a backing element, or null to wrap it without caching. */
+    private val getQName: (BackingType) -> String?,
+    /** Creates the wrapper object for a backing element. */
+    private val wrap: (type: BackingType) -> T
 ) {
     // instead of something like a Guava cache, we use a map of weak references here because our
     // main goal is avoiding to re-parse type elements as we go up & down in the hierarchy while
@@ -31,12 +36,48 @@
     // could possibly hold a lot more information than we desire.
     private val typeCache = mutableMapOf<String, WeakReference<T>>()
 
+    /**
+     * Returns the wrapper for [backingType].
+     *
+     * Elements for which [getQName] returns a name are cached under that name, so repeated
+     * calls hand out the same instance. Elements without a name (e.g. error or local types)
+     * are wrapped on every call and never cached.
+     */
+    operator fun get(backingType: BackingType): T {
+        val qName = getQName(backingType)
+        @Suppress("FoldInitializerAndIfToElvis")
+        if (qName == null) {
+            // just wrap without caching, likely an error or local type in kotlin
+            return wrap(backingType)
+        }
+        // Consult the cache directly instead of going through get(qName): that would call
+        // findElement again and could wrap a different backing element than the one given.
+        typeCache[qName]?.get()?.let {
+            return it
+        }
+        return cache(qName, wrap(backingType))
+    }
+
+    /**
+     * Looks up the element named [qName] and returns its wrapper, or null if [findElement]
+     * cannot find it.
+     *
+     * The requested name is not necessarily the element's own qualified name ([findElement] may
+     * translate it), so the wrapper is obtained through the backing-type overload, which keys
+     * the cache by [getQName]. [qName] is then stored as an alias for the same instance, so both
+     * entry points return a single wrapper per element.
+     */
     operator fun get(qName: String): T? {
         typeCache[qName]?.get()?.let {
             return it
         }
-        val result = findElement(qName) ?: return null
-        typeCache[qName] = WeakReference(result)
-        return result
+        val backingType = findElement(qName) ?: return null
+        return cache(qName, get(backingType))
+    }
+
+    /** Records [element] in the cache under [qName] and returns it. */
+    private fun cache(qName: String, element: T): T {
+        typeCache[qName] = WeakReference(element)
+        return element
     }
 }
\ No newline at end of file
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
index 5e0b43a..6533f64 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
@@ -43,16 +43,17 @@
     override val backend: XProcessingEnv.Backend = XProcessingEnv.Backend.KSP
 
     private val typeElementStore =
-        XTypeElementStore { qName ->
-            resolver.getClassDeclarationByName(
-                KspTypeMapper.swapWithKotlinType(qName)
-            )?.let {
-                KspTypeElement(
-                    env = this,
-                    declaration = it
-                )
-            }
-        }
+        XTypeElementStore(
+            // Requested names go through KspTypeMapper.swapWithKotlinType before being resolved.
+            findElement = { qName ->
+                resolver.getClassDeclarationByName(KspTypeMapper.swapWithKotlinType(qName))
+            },
+            // Error types and local types have no qualified name, so they are never cached.
+            getQName = { declaration -> declaration.qualifiedName?.asString() },
+            wrap = { declaration ->
+                KspTypeElement(env = this, declaration = declaration)
+            }
+        )
 
     override val messager: XMessager = KspMessager(logger)
 
@@ -194,10 +195,10 @@
     }
 
-    fun wrapClassDeclaration(declaration: KSClassDeclaration): KspTypeElement {
-        return KspTypeElement(
-            env = this,
-            declaration = declaration
-        )
-    }
+    /**
+     * Returns the [KspTypeElement] for [declaration]; declarations with a qualified name share a
+     * single cached wrapper.
+     */
+    fun wrapClassDeclaration(declaration: KSClassDeclaration): KspTypeElement =
+        typeElementStore[declaration]
 
     class CommonTypes(resolver: Resolver) {
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt
index 6083bd4..a2057fa 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt
@@ -239,6 +239,33 @@
         }
     }
 
+    @Test
+    fun typeElementsAreCached() {
+        val src = Source.java(
+            "JavaSubject",
+            """
+            class JavaSubject {
+                NestedClass nestedClass;
+                class NestedClass {
+                    int x;
+                }
+            }
+            """.trimIndent()
+        )
+        runProcessorTest(
+            sources = listOf(src)
+        ) { invocation ->
+            val env = invocation.processingEnv
+            val parent = env.requireTypeElement("JavaSubject")
+            val nested = env.requireTypeElement("JavaSubject.NestedClass")
+            // looking up the same name again must hit the cache
+            assertThat(env.requireTypeElement("JavaSubject")).isSameInstanceAs(parent)
+            assertThat(env.requireTypeElement("JavaSubject.NestedClass")).isSameInstanceAs(nested)
+            // wrappers created while navigating the hierarchy must be the same instances
+            assertThat(nested.enclosingTypeElement).isSameInstanceAs(parent)
+        }
+    }
+
     companion object {
         val PRIMITIVE_TYPES = listOf(
             TypeName.BOOLEAN,