Make sure there is only 1 instance of each type element

When traversing classes, we'll frequently wrap the same type
element in different cases (e.g. finding all methods of a type
element). Previously, we only cached results in findTypeElement
which meant we could wrap the same type element multiple times
if XProcessingEnv.wrap is called.

This CL updates the type element store to also accept the backing
type as key. When invoked, it will first get the qualified name
to query the cache before wrapping a new item.

Note that java does not allow conflicts on qualified names so
there is no uniqueness concerns (e.g. you cannot have foo.Bar.Baz
and foo.Bar class with Baz static inner class)

Bug: 160322705
Test: XProcessingEnvTest.typeElementsAreCached
Change-Id: Ic0fba8ea8ce80946afc642cd49ab48aa18a0e70b
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt
index 06c844a..8e6e544 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacProcessingEnv.kt
@@ -42,10 +42,17 @@
     val typeUtils: Types = delegate.typeUtils
 
     private val typeElementStore =
-        XTypeElementStore { qName ->
-            val result = delegate.elementUtils.getTypeElement(qName)
-            result?.let(this::wrapTypeElement)
-        }
+        XTypeElementStore(
+            findElement = { qName ->
+                delegate.elementUtils.getTypeElement(qName)
+            },
+            wrap = {
+                JavacTypeElement(this, it)
+            },
+            getQName = {
+                it.qualifiedName.toString()
+            }
+        )
 
     override val messager: XMessager by lazy {
         JavacProcessingEnvMessager(delegate)
@@ -113,7 +120,7 @@
     }
 
     // maybe cache here ?
-    fun wrapTypeElement(element: TypeElement) = JavacTypeElement(this, element)
+    fun wrapTypeElement(element: TypeElement) = typeElementStore[element]
 
     /**
      * Wraps the given java processing type into an XType.
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt
index 229951e..a44e6b4 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/javac/XTypeElementStore.kt
@@ -22,8 +22,10 @@
 /**
  * Utility class to cache type element wrappers.
  */
-internal class XTypeElementStore<T : XTypeElement>(
-    val findElement: (qName: String) -> T?
+internal class XTypeElementStore<BackingType, T : XTypeElement>(
+    private val findElement: (qName: String) -> BackingType?,
+    private val getQName: (BackingType) -> String?,
+    private val wrap: (type: BackingType) -> T
 ) {
     // instead of something like a Guava cache, we use a map of weak references here because our
     // main goal is avoiding to re-parse type elements as we go up & down in the hierarchy while
@@ -31,12 +33,30 @@
     // could possibly hold a lot more information than we desire.
     private val typeCache = mutableMapOf<String, WeakReference<T>>()
 
+    operator fun get(backingType: BackingType): T {
+        val qName = getQName(backingType)
+        @Suppress("FoldInitializerAndIfToElvis")
+        if (qName == null) {
+            // just wrap without caching, likely an error or local type in kotlin
+            return wrap(backingType)
+        }
+        get(qName)?.let {
+            return it
+        }
+        val wrapped = wrap(backingType)
+        return cache(qName, wrapped)
+    }
+
     operator fun get(qName: String): T? {
         typeCache[qName]?.get()?.let {
             return it
         }
-        val result = findElement(qName) ?: return null
-        typeCache[qName] = WeakReference(result)
-        return result
+        val result = findElement(qName)?.let(wrap) ?: return null
+        return cache(qName, result)
+    }
+
+    private fun cache(qName: String, element: T): T {
+        typeCache[qName] = WeakReference(element)
+        return element
     }
 }
\ No newline at end of file
diff --git a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
index 5e0b43a..6533f64 100644
--- a/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
+++ b/room/compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspProcessingEnv.kt
@@ -43,16 +43,21 @@
     override val backend: XProcessingEnv.Backend = XProcessingEnv.Backend.KSP
 
     private val typeElementStore =
-        XTypeElementStore { qName ->
-            resolver.getClassDeclarationByName(
-                KspTypeMapper.swapWithKotlinType(qName)
-            )?.let {
-                KspTypeElement(
-                    env = this,
-                    declaration = it
+        XTypeElementStore(
+            findElement = {
+                resolver.getClassDeclarationByName(
+                    KspTypeMapper.swapWithKotlinType(it)
                 )
+            },
+            getQName = {
+                // for error types or local types, qualified name is null.
+                // it is best to just not cache them
+                it.qualifiedName?.asString()
+            },
+            wrap = {
+                KspTypeElement(this, it)
             }
-        }
+        )
 
     override val messager: XMessager = KspMessager(logger)
 
@@ -194,10 +199,7 @@
     }
 
     fun wrapClassDeclaration(declaration: KSClassDeclaration): KspTypeElement {
-        return KspTypeElement(
-            env = this,
-            declaration = declaration
-        )
+        return typeElementStore[declaration]
     }
 
     class CommonTypes(resolver: Resolver) {
diff --git a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt
index 6083bd4..a2057fa 100644
--- a/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt
+++ b/room/compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingEnvTest.kt
@@ -239,6 +239,28 @@
         }
     }
 
+    @Test
+    fun typeElementsAreCached() {
+        val src = Source.java(
+            "JavaSubject",
+            """
+            class JavaSubject {
+                NestedClass nestedClass;
+                class NestedClass {
+                    int x;
+                }
+            }
+            """.trimIndent()
+        )
+        runProcessorTest(
+            sources = listOf(src)
+        ) { invocation ->
+            val parent = invocation.processingEnv.requireTypeElement("JavaSubject")
+            val nested = invocation.processingEnv.requireTypeElement("JavaSubject.NestedClass")
+            assertThat(nested.enclosingTypeElement).isSameInstanceAs(parent)
+        }
+    }
+
     companion object {
         val PRIMITIVE_TYPES = listOf(
             TypeName.BOOLEAN,