[GH] Change JavacBasicAnnotationProcessor to wrap each step individually.

## Proposed Changes

This commit changes the behavior of JavacBasicAnnotationProcessor to wrap each XProcessingStep individually in its own Step rather than have a single Step that wraps all XProcessingSteps. The issue with wrapping all XProcessingSteps in a single Step is that the underlying BasicAnnotationProcessor triggers the Step, and thus all wrapped XProcessingSteps, when any annotation is found rather than only triggering the Step that handles the annotation.

In order to share the XProcessingEnv for all XProcessingSteps in a round, the first step to process will initialize the XProcessingEnv for other steps. In addition, the postRound call also now shares the XProcessingEnv too. At the end of each postRound I've set the XProcessingEnv to null to allow GC.

## Testing

Test: I've added a test that adds 2 XProcessingSteps to JavacBasicAnnotationProcessor that each handle a different annotation, and then verify that if only one annotation is present in the sources then only the step that handles that annotation is triggered.

## Issues Fixed

Fixes: 192658371

This is an imported pull request from https://github.com/androidx/androidx/pull/208.

Resolves #208
Github-Pr-Head-Sha: a517295af65047cdf5e762defd7b7b8df28dffed
GitOrigin-RevId: 61aa2f99ffc4dc53aca7e1fe487871afd90faac1
Change-Id: If5700b7626b933c52963802f48728548b9c7439e
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacBasicAnnotationProcessor.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacBasicAnnotationProcessor.kt
index f8926ad..f41a368 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacBasicAnnotationProcessor.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/javac/JavacBasicAnnotationProcessor.kt
@@ -18,6 +18,7 @@
 
 import androidx.room.compiler.processing.XBasicAnnotationProcessor
 import androidx.room.compiler.processing.XElement
+import androidx.room.compiler.processing.XProcessingStep
 import androidx.room.compiler.processing.XRoundEnv
 import com.google.auto.common.BasicAnnotationProcessor
 import com.google.common.collect.ImmutableSetMultimap
@@ -31,43 +32,45 @@
 abstract class JavacBasicAnnotationProcessor :
     BasicAnnotationProcessor(), XBasicAnnotationProcessor {
 
-    final override fun steps(): Iterable<Step> {
-        // Execute all processing steps in a single auto-common Step. This is done to share the
-        // XProcessingEnv and its cached across steps in the same round.
-        val steps = processingSteps()
-        val parentStep = object : Step {
-            override fun annotations() = steps.flatMap { it.annotations() }.toSet()
+    // This state is cached here so that it can be shared by all steps in a given processing round.
+    // The state is initialized at beginning of each round using the InitializingStep, and
+    // the state is cleared at the end of each round in BasicAnnotationProcessor#postRound()
+    private var cachedXEnv: JavacProcessingEnv? = null
 
-            override fun process(
-                elementsByAnnotation: ImmutableSetMultimap<String, Element>
-            ): Set<Element> {
-                val xEnv = JavacProcessingEnv(processingEnv)
-                val convertedElementsByAnnotation = mutableMapOf<String, Set<XElement>>()
-                annotations().forEach { annotation ->
-                    convertedElementsByAnnotation[annotation] =
-                        elementsByAnnotation[annotation].mapNotNull { element ->
-                            xEnv.wrapAnnotatedElement(element, annotation)
-                        }.toSet()
-                }
-                val results = steps.flatMap { step ->
-                    step.process(
-                        env = xEnv,
-                        elementsByAnnotation = step.annotations().associateWith {
-                            convertedElementsByAnnotation[it] ?: emptySet()
-                        }
-                    )
-                }
-                return results.map { (it as JavacElement).element }.toSet()
+    final override fun steps(): Iterable<Step> {
+        return processingSteps().map { DelegatingStep(it) }
+    }
+
+    /** A [Step] that delegates to an [XProcessingStep]. */
+    private inner class DelegatingStep(val xStep: XProcessingStep) : Step {
+        override fun annotations() = xStep.annotations()
+
+        override fun process(
+            elementsByAnnotation: ImmutableSetMultimap<String, Element>
+        ): Set<Element> {
+            // The first step in a round initializes the cachedXEnv. Note: the "first" step can
+            // change each round depending on which annotations are present in the current round and
+            // which elements were deferred in the previous round.
+            val xEnv = cachedXEnv ?: JavacProcessingEnv(processingEnv).also { cachedXEnv = it }
+            val xElementsByAnnotation = mutableMapOf<String, Set<XElement>>()
+            xStep.annotations().forEach { annotation ->
+                xElementsByAnnotation[annotation] =
+                    elementsByAnnotation[annotation].mapNotNull { element ->
+                        xEnv.wrapAnnotatedElement(element, annotation)
+                    }.toSet()
             }
+            return xStep.process(xEnv, xElementsByAnnotation).map {
+                (it as JavacElement).element
+            }.toSet()
         }
-        return listOf(parentStep)
     }
 
     final override fun postRound(roundEnv: RoundEnvironment) {
-        // Due to BasicAnnotationProcessor taking over AbstractProcessor#process() we can't
-        // share the same XProcessingEnv from the steps, but that might be ok...
-        val xEnv = JavacProcessingEnv(processingEnv)
+        // The cachedXEnv can be null if none of the steps were processed in the round.
+        // In this case, we just create a new one since there is no cached one to share.
+        val xEnv = cachedXEnv ?: JavacProcessingEnv(processingEnv)
         val xRound = XRoundEnv.create(xEnv, roundEnv)
         postRound(xEnv, xRound)
+        cachedXEnv = null // Reset after every round to allow GC
     }
 }
\ No newline at end of file
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspBasicAnnotationProcessor.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspBasicAnnotationProcessor.kt
index 1fc7b4d..43ea091 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspBasicAnnotationProcessor.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KspBasicAnnotationProcessor.kt
@@ -46,14 +46,24 @@
         val round = XRoundEnv.create(processingEnv)
         val deferredElements = processingSteps().flatMap { step ->
             val invalidElements = mutableSetOf<XElement>()
-            val elementsByAnnotation = step.annotations().associateWith { annotation ->
+            val elementsByAnnotation = step.annotations().mapNotNull { annotation ->
                 val annotatedElements = round.getElementsAnnotatedWith(annotation)
-                annotatedElements
+                val validElements = annotatedElements
                     .filter { (it as KspElement).declaration.validateExceptLocals() }
-                    .also { invalidElements.addAll(annotatedElements - it) }
                     .toSet()
+                invalidElements.addAll(annotatedElements - validElements)
+                if (validElements.isNotEmpty()) {
+                    annotation to validElements
+                } else {
+                    null
+                }
+            }.toMap()
+            // Only process the step if there are annotated elements found for this step.
+            if (elementsByAnnotation.isNotEmpty()) {
+                invalidElements + step.process(processingEnv, elementsByAnnotation)
+            } else {
+                invalidElements
             }
-            invalidElements + step.process(processingEnv, elementsByAnnotation)
         }
         postRound(processingEnv, round)
         return deferredElements.map { (it as KspElement).declaration }
diff --git a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingStepTest.kt b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingStepTest.kt
index b0af4d3..fbd36fc 100644
--- a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingStepTest.kt
+++ b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/XProcessingStepTest.kt
@@ -268,6 +268,68 @@
     }
 
     @Test
+    fun cachingBetweenSteps() {
+        val main = JavaFileObjects.forSourceString(
+            "foo.bar.Main",
+            """
+            package foo.bar;
+            import androidx.room.compiler.processing.testcode.*;
+            @MainAnnotation(
+                typeList = {},
+                singleType = Object.class,
+                intMethod = 3,
+                singleOtherAnnotation = @OtherAnnotation("y")
+            )
+            class Main {}
+            """.trimIndent()
+        )
+        val other = JavaFileObjects.forSourceString(
+            "foo.bar.Other",
+            """
+            package foo.bar;
+            import androidx.room.compiler.processing.testcode.*;
+            @OtherAnnotation("x")
+            class Other {
+            }
+            """.trimIndent()
+        )
+        val elementsByStep = mutableMapOf<XProcessingStep, XTypeElement>()
+        // create a scenario where we can test caching between steps
+        val mainStep = object : XProcessingStep {
+            override fun annotations(): Set<String> = setOf(MainAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XTypeElement> {
+                elementsByStep[this] = env.requireTypeElement("foo.bar.Main")
+                return emptySet()
+            }
+        }
+        val otherStep = object : XProcessingStep {
+            override fun annotations(): Set<String> = setOf(OtherAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XTypeElement> {
+                elementsByStep[this] = env.requireTypeElement("foo.bar.Main")
+                return emptySet()
+            }
+        }
+        assertAbout(
+            JavaSourcesSubjectFactory.javaSources()
+        ).that(
+            listOf(main, other)
+        ).processedWith(
+            object : JavacBasicAnnotationProcessor() {
+                override fun processingSteps() = listOf(mainStep, otherStep)
+            }
+        ).compilesWithoutError()
+        assertThat(elementsByStep.keys).containsExactly(mainStep, otherStep)
+        // make sure elements between steps are the same instances
+        assertThat(elementsByStep[mainStep]).isSameInstanceAs(elementsByStep[otherStep])
+    }
+
+    @Test
     fun kspReturnsUnprocessed() {
         CompilationTestCapabilities.assumeKspIsEnabled()
         var returned: Set<XElement>? = null
@@ -388,6 +450,112 @@
     }
 
     @Test
+    fun javacDeferredStep() {
+        // create a scenario where we defer the first round of processing
+        val main = JavaFileObjects.forSourceString(
+            "foo.bar.Main",
+            """
+            package foo.bar;
+            import androidx.room.compiler.processing.testcode.*;
+            @MainAnnotation(
+                typeList = {},
+                singleType = Object.class,
+                intMethod = 3,
+                singleOtherAnnotation = @OtherAnnotation("y")
+            )
+            class Main {}
+            """.trimIndent()
+        )
+        val stepsProcessed = mutableListOf<XProcessingStep>()
+        val mainStep = object : XProcessingStep {
+            var round = 0
+            override fun annotations() = setOf(MainAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XElement> {
+                stepsProcessed.add(this)
+                val deferredElements = if (round++ == 0) {
+                    // Generate a random class to trigger another processing round
+                    val className = ClassName.get("foo.bar", "Main_Impl")
+                    val spec = TypeSpec.classBuilder(className).build()
+                    JavaFile.builder(className.packageName(), spec)
+                        .build()
+                        .writeTo(env.filer)
+
+                    // Defer all processing to the next round
+                    elementsByAnnotation.values.flatten().toSet()
+                } else {
+                    emptySet()
+                }
+                return deferredElements
+            }
+        }
+        assertAbout(
+            JavaSourcesSubjectFactory.javaSources()
+        ).that(
+            listOf(main)
+        ).processedWith(
+            object : JavacBasicAnnotationProcessor() {
+                override fun processingSteps() = listOf(mainStep)
+            }
+        ).compilesWithoutError()
+
+        // Assert that mainStep was processed twice due to deferring
+        assertThat(stepsProcessed).containsExactly(mainStep, mainStep)
+    }
+
+    @Test
+    fun javacStepOnlyCalledIfElementsToProcess() {
+        val main = JavaFileObjects.forSourceString(
+            "foo.bar.Main",
+            """
+            package foo.bar;
+            import androidx.room.compiler.processing.testcode.*;
+            @MainAnnotation(
+                typeList = {},
+                singleType = Object.class,
+                intMethod = 3,
+                singleOtherAnnotation = @OtherAnnotation("y")
+            )
+            class Main {
+            }
+            """.trimIndent()
+        )
+        val stepsProcessed = mutableListOf<XProcessingStep>()
+        val mainStep = object : XProcessingStep {
+            override fun annotations() = setOf(MainAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XElement> {
+                stepsProcessed.add(this)
+                return emptySet()
+            }
+        }
+        val otherStep = object : XProcessingStep {
+            override fun annotations() = setOf(OtherAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XElement> {
+                stepsProcessed.add(this)
+                return emptySet()
+            }
+        }
+        assertAbout(
+            JavaSourcesSubjectFactory.javaSources()
+        ).that(
+            listOf(main)
+        ).processedWith(
+            object : JavacBasicAnnotationProcessor() {
+                override fun processingSteps() = listOf(mainStep, otherStep)
+            }
+        ).compilesWithoutError()
+        assertThat(stepsProcessed).containsExactly(mainStep)
+    }
+
+    @Test
     fun kspAnnotatedElementsByStep() {
         val main = SourceFile.kotlin(
             "Classes.kt",
@@ -449,4 +617,121 @@
         assertThat(elementsByStep[otherStep])
             .containsExactly("foo.bar.Other")
     }
+
+    @Test
+    fun kspDeferredStep() {
+        // create a scenario where we defer the first round of processing
+        val main = SourceFile.kotlin(
+            "Classes.kt",
+            """
+            package foo.bar
+            import androidx.room.compiler.processing.testcode.*
+            @MainAnnotation(
+                typeList = [],
+                singleType = Any::class,
+                intMethod = 3,
+                singleOtherAnnotation = OtherAnnotation("y")
+            )
+            class Main {}
+            """.trimIndent()
+        )
+        val stepsProcessed = mutableListOf<XProcessingStep>()
+        val mainStep = object : XProcessingStep {
+            var round = 0
+            override fun annotations() = setOf(MainAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XElement> {
+                stepsProcessed.add(this)
+                val deferredElements = if (round++ == 0) {
+                    // Generate a random class to trigger another processing round
+                    val className = ClassName.get("foo.bar", "Main_Impl")
+                    val spec = TypeSpec.classBuilder(className).build()
+                    JavaFile.builder(className.packageName(), spec)
+                        .build()
+                        .writeTo(env.filer)
+
+                    // Defer all processing to the next round
+                    elementsByAnnotation.values.flatten().toSet()
+                } else {
+                    emptySet()
+                }
+                return deferredElements
+            }
+        }
+
+        val processorProvider = object : SymbolProcessorProvider {
+            override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
+                return object : KspBasicAnnotationProcessor(environment) {
+                    override fun processingSteps() = listOf(mainStep)
+                }
+            }
+        }
+        KotlinCompilation().apply {
+            workingDir = temporaryFolder.root
+            inheritClassPath = true
+            symbolProcessorProviders = listOf(processorProvider)
+            sources = listOf(main)
+            verbose = false
+        }.compile()
+
+        // Assert that mainStep was processed twice due to deferring
+        assertThat(stepsProcessed).containsExactly(mainStep, mainStep)
+    }
+
+    @Test
+    fun kspStepOnlyCalledIfElementsToProcess() {
+        val main = SourceFile.kotlin(
+            "Classes.kt",
+            """
+            package foo.bar
+            import androidx.room.compiler.processing.testcode.*
+            @MainAnnotation(
+                typeList = [],
+                singleType = Any::class,
+                intMethod = 3,
+                singleOtherAnnotation = OtherAnnotation("y")
+            )
+            class Main {
+            }
+            """.trimIndent()
+        )
+        val stepsProcessed = mutableListOf<XProcessingStep>()
+        val mainStep = object : XProcessingStep {
+            override fun annotations() = setOf(MainAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XElement> {
+                stepsProcessed.add(this)
+                return emptySet()
+            }
+        }
+        val otherStep = object : XProcessingStep {
+            override fun annotations() = setOf(OtherAnnotation::class.qualifiedName!!)
+            override fun process(
+                env: XProcessingEnv,
+                elementsByAnnotation: Map<String, Set<XElement>>
+            ): Set<XElement> {
+                stepsProcessed.add(this)
+                return emptySet()
+            }
+        }
+        val processorProvider = object : SymbolProcessorProvider {
+            override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
+                return object : KspBasicAnnotationProcessor(environment) {
+                    override fun processingSteps() = listOf(mainStep, otherStep)
+                }
+            }
+        }
+        KotlinCompilation().apply {
+            workingDir = temporaryFolder.root
+            inheritClassPath = true
+            symbolProcessorProviders = listOf(processorProvider)
+            sources = listOf(main)
+            verbose = false
+        }.compile()
+        assertThat(stepsProcessed).containsExactly(mainStep)
+    }
 }
\ No newline at end of file