Add support for testing multi round processors

This CL adds new APIs to run multi-round processors.

Most compilation tests need to run 1 round hence, by default, we invoke
the callback once. If a test wants another round, it needs to register
another callback, which is very similar to an existing api we have in
TestProcessor.nextRunHandler.

Also added utility methods to assert partial diagnostic messages.

Bug: 160322705
Test: MultiRoundProcessingTest, TestRunnerTest

Change-Id: Ic3fbef82d88be7beaa88bd1dff8bf4516c62fd27
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticJavacProcessor.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticJavacProcessor.kt
index 1eb4ef8..d80bd31 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticJavacProcessor.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticJavacProcessor.kt
@@ -16,32 +16,27 @@
 
 package androidx.room.compiler.processing
 
-import androidx.room.compiler.processing.util.RecordingXMessager
 import androidx.room.compiler.processing.util.XTestInvocation
 import javax.lang.model.SourceVersion
 
 @Suppress("VisibleForTests")
-class SyntheticJavacProcessor(
-    val handler: (XTestInvocation) -> Unit,
-) : JavacTestProcessor(), SyntheticProcessor {
-    override val invocationInstances = mutableListOf<XTestInvocation>()
-    private var result: Result<Unit>? = null
-    override val messageWatcher = RecordingXMessager()
-
+class SyntheticJavacProcessor private constructor(
+    private val impl: SyntheticProcessorImpl
+) : JavacTestProcessor(), SyntheticProcessor by impl {
+    constructor(handlers: List<(XTestInvocation) -> Unit>) : this(
+        SyntheticProcessorImpl(handlers)
+    )
     override fun doProcess(annotations: Set<XTypeElement>, roundEnv: XRoundEnv): Boolean {
-        val xEnv = XProcessingEnv.create(processingEnv)
-        xEnv.messager.addMessageWatcher(messageWatcher)
-        result = kotlin.runCatching {
-            handler(
-                XTestInvocation(
-                    processingEnv = xEnv,
-                    roundEnv = roundEnv
-                ).also {
-                    invocationInstances.add(it)
-                }
-            )
+        if (!impl.canRunAnotherRound()) {
+            return true
         }
-        return true
+        val xEnv = XProcessingEnv.create(processingEnv)
+        val testInvocation = XTestInvocation(
+            processingEnv = xEnv,
+            roundEnv = roundEnv
+        )
+        impl.runNextRound(testInvocation)
+        return impl.expectsAnotherRound()
     }
 
     override fun getSupportedSourceVersion(): SourceVersion {
@@ -49,15 +44,4 @@
     }
 
     override fun getSupportedAnnotationTypes() = setOf("*")
-
-    override fun getProcessingException(): Throwable? {
-        val result = this.result ?: return AssertionError("processor didn't run")
-        result.exceptionOrNull()?.let {
-            return it
-        }
-        if (result.isFailure) {
-            return AssertionError("processor failed but no exception is reported")
-        }
-        return null
-    }
 }
\ No newline at end of file
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticKspProcessor.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticKspProcessor.kt
index 598e983..d1bceda 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticKspProcessor.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticKspProcessor.kt
@@ -16,7 +16,6 @@
 
 package androidx.room.compiler.processing
 
-import androidx.room.compiler.processing.util.RecordingXMessager
 import androidx.room.compiler.processing.util.XTestInvocation
 import com.google.devtools.ksp.processing.CodeGenerator
 import com.google.devtools.ksp.processing.KSPLogger
@@ -24,15 +23,15 @@
 import com.google.devtools.ksp.processing.SymbolProcessor
 import com.google.devtools.ksp.symbol.KSAnnotated
 
-class SyntheticKspProcessor(
-    private val handler: (XTestInvocation) -> Unit
-) : SymbolProcessor, SyntheticProcessor {
-    override val invocationInstances = mutableListOf<XTestInvocation>()
-    private var result: Result<Unit>? = null
+class SyntheticKspProcessor private constructor(
+    private val impl: SyntheticProcessorImpl
+) : SymbolProcessor, SyntheticProcessor by impl {
+    constructor(handlers: List<(XTestInvocation) -> Unit>) : this(
+        SyntheticProcessorImpl(handlers)
+    )
     private lateinit var options: Map<String, String>
     private lateinit var codeGenerator: CodeGenerator
     private lateinit var logger: KSPLogger
-    override val messageWatcher = RecordingXMessager()
 
     override fun finish() {
     }
@@ -49,34 +48,20 @@
     }
 
     override fun process(resolver: Resolver): List<KSAnnotated> {
+        if (!impl.canRunAnotherRound()) {
+            return emptyList()
+        }
         val xEnv = XProcessingEnv.create(
             options,
             resolver,
             codeGenerator,
             logger
         )
-        xEnv.messager.addMessageWatcher(messageWatcher)
-        result = kotlin.runCatching {
-            handler(
-                XTestInvocation(
-                    processingEnv = xEnv,
-                    roundEnv = XRoundEnv.create(xEnv)
-                ).also {
-                    invocationInstances.add(it)
-                }
-            )
-        }
+        val testInvocation = XTestInvocation(
+            processingEnv = xEnv,
+            roundEnv = XRoundEnv.create(xEnv)
+        )
+        impl.runNextRound(testInvocation)
         return emptyList()
     }
-
-    override fun getProcessingException(): Throwable? {
-        val result = this.result ?: return AssertionError("processor didn't run")
-        result.exceptionOrNull()?.let {
-            return it
-        }
-        if (result.isFailure) {
-            return AssertionError("processor failed but no exception is reported")
-        }
-        return null
-    }
 }
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticProcessor.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticProcessor.kt
index a9872df9a..8cd0076 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticProcessor.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/SyntheticProcessor.kt
@@ -43,4 +43,67 @@
      * dispatch them afterwards.
      */
     fun getProcessingException(): Throwable?
+
+    /**
+     * Returns true if the processor expected to run another round.
+     */
+    fun expectsAnotherRound(): Boolean
+}
+
+/**
+ * Helper class to implement [SyntheticProcessor] processor that handles the communication with
+ * the testing infrastructure.
+ */
+internal class SyntheticProcessorImpl(
+    handlers: List<(XTestInvocation) -> Unit>
+) : SyntheticProcessor {
+    private var result: Result<Unit>? = null
+    override val invocationInstances = mutableListOf<XTestInvocation>()
+    private val nextRunHandlers = handlers.toMutableList()
+    override val messageWatcher = RecordingXMessager()
+
+    override fun expectsAnotherRound(): Boolean {
+        return nextRunHandlers.isNotEmpty()
+    }
+
+    /**
+     * Returns true if this can run another round, which means previous round didn't throw an
+     * exception and there is another handler in the queue.
+     */
+    fun canRunAnotherRound(): Boolean {
+        if (result?.exceptionOrNull() != null) {
+            // if there is an existing failure from a previous run, stop
+            return false
+        }
+        return expectsAnotherRound()
+    }
+
+    override fun getProcessingException(): Throwable? {
+        val result = this.result ?: return AssertionError("processor didn't run")
+        result.exceptionOrNull()?.let {
+            return it
+        }
+        if (result.isFailure) {
+            return AssertionError("processor failed but no exception is reported")
+        }
+        return null
+    }
+
+    /**
+     * Runs the next handler with the given test invocation.
+     */
+    fun runNextRound(
+        invocation: XTestInvocation
+    ) {
+        check(nextRunHandlers.isNotEmpty()) {
+            "Called run next round w/o a runner to handle it. Looks like a testing infra bug"
+        }
+        val handler = nextRunHandlers.removeAt(0)
+        invocationInstances.add(invocation)
+        invocation.processingEnv.messager.addMessageWatcher(messageWatcher)
+        result = kotlin.runCatching {
+            handler(invocation)
+            invocation.dispose()
+        }
+    }
 }
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/CompilationResultSubject.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/CompilationResultSubject.kt
index b3ed7a3..4c94140 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/CompilationResultSubject.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/CompilationResultSubject.kt
@@ -112,26 +112,95 @@
      * Asserts that compilation has a warning with the given text.
      *
      * @see hasError
+     * @see hasNote
      */
     fun hasWarning(expected: String) = chain {
         hasDiagnosticWithMessage(
             kind = Diagnostic.Kind.WARNING,
-            expected = expected
+            expected = expected,
+            acceptPartialMatch = false
         ) {
             "expected warning: $expected"
         }
     }
 
     /**
+     * Asserts that compilation has a warning that contains the given text.
+     *
+     * @see hasErrorContaining
+     * @see hasNoteContaining
+     */
+    fun hasWarningContaining(expected: String) = chain {
+        hasDiagnosticWithMessage(
+            kind = Diagnostic.Kind.WARNING,
+            expected = expected,
+            acceptPartialMatch = true
+        ) {
+            "expected warning: $expected"
+        }
+    }
+
+    /**
+     * Asserts that compilation has a note with the given text.
+     *
+     * @see hasError
+     * @see hasWarning
+     */
+    fun hasNote(expected: String) = chain {
+        hasDiagnosticWithMessage(
+            kind = Diagnostic.Kind.NOTE,
+            expected = expected,
+            acceptPartialMatch = false
+        ) {
+            "expected note: $expected"
+        }
+    }
+
+    /**
+     * Asserts that compilation has a note that contains the given text.
+     *
+     * @see hasErrorContaining
+     * @see hasWarningContaining
+     */
+    fun hasNoteContaining(expected: String) = chain {
+        hasDiagnosticWithMessage(
+            kind = Diagnostic.Kind.NOTE,
+            expected = expected,
+            acceptPartialMatch = true
+        ) {
+            "expected note: $expected"
+        }
+    }
+
+    /**
      * Asserts that compilation has an error with the given text.
      *
      * @see hasWarning
+     * @see hasNote
      */
     fun hasError(expected: String) = chain {
         shouldSucceed = false
         hasDiagnosticWithMessage(
             kind = Diagnostic.Kind.ERROR,
-            expected = expected
+            expected = expected,
+            acceptPartialMatch = false
+        ) {
+            "expected error: $expected"
+        }
+    }
+
+    /**
+     * Asserts that compilation has an error that contains the given text.
+     *
+     * @see hasWarningContaining
+     * @see hasNoteContaining
+     */
+    fun hasErrorContaining(expected: String) = chain {
+        shouldSucceed = false
+        hasDiagnosticWithMessage(
+            kind = Diagnostic.Kind.ERROR,
+            expected = expected,
+            acceptPartialMatch = true
         ) {
             "expected error: $expected"
         }
@@ -194,6 +263,18 @@
         }
     }
 
+    /**
+     * Checks if the processor has any remaining rounds that did not run which would possibly
+     * mean it didn't run assertions it wanted to run.
+     */
+    internal fun assertAllExpectedRoundsAreCompleted() {
+        if (compilationResult.processor.expectsAnotherRound()) {
+            failWithActual(
+                simpleFact("Test runner requested another round but that didn't happen")
+            )
+        }
+    }
+
     internal fun assertNoProcessorAssertionErrors() {
         val processingException = compilationResult.processor.getProcessingException()
         if (processingException != null) {
@@ -209,12 +290,16 @@
     private fun hasDiagnosticWithMessage(
         kind: Diagnostic.Kind,
         expected: String,
+        acceptPartialMatch: Boolean,
         buildErrorMessage: () -> String
     ) {
         val diagnostics = compilationResult.diagnosticsOfKind(kind)
         if (diagnostics.any { it.msg == expected }) {
             return
         }
+        if (acceptPartialMatch && diagnostics.any { it.msg.contains(expected) }) {
+            return
+        }
         failWithActual(simpleFact(buildErrorMessage()))
     }
 
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt
index e74f1b9..b6ca3bc 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/ProcessorTestExt.kt
@@ -46,6 +46,7 @@
             ).isNotEmpty()
 
             subject.assertCompilationResult()
+            subject.assertAllExpectedRoundsAreCompleted()
             true
         } else {
             false
@@ -75,7 +76,7 @@
         params = TestCompilationParameters(
             sources = sources,
             classpath = classpath,
-            handler = handler
+            handlers = listOf(handler)
         ),
         JavacCompilationTestRunner,
         KaptCompilationTestRunner
@@ -86,7 +87,8 @@
  * Runs the compilation test with all 3 backends (javac, kapt, ksp) if possible (e.g. javac
  * cannot test kotlin sources).
  *
- * The [handler] will be invoked for each compilation hence it should be repeatable.
+ * The [handler] will be invoked only for the first round. If you need to test multi round
+ * processing, use `handlers = listOf(..., ...)`.
  *
  * To assert on the compilation results, [handler] can call
  * [XTestInvocation.assertCompilationResult] where it will receive a subject for post compilation
@@ -100,12 +102,21 @@
     sources: List<Source> = emptyList(),
     classpath: List<File> = emptyList(),
     handler: (XTestInvocation) -> Unit
+) = runProcessorTest(sources = sources, classpath = classpath, handlers = listOf(handler))
+
+/**
+ * @see runProcessorTest
+ */
+fun runProcessorTest(
+    sources: List<Source> = emptyList(),
+    classpath: List<File> = emptyList(),
+    handlers: List<(XTestInvocation) -> Unit>
 ) {
     runTests(
         params = TestCompilationParameters(
             sources = sources,
             classpath = classpath,
-            handler = handler
+            handlers = handlers
         ),
         JavacCompilationTestRunner,
         KaptCompilationTestRunner,
@@ -122,12 +133,25 @@
     sources: List<Source>,
     classpath: List<File> = emptyList(),
     handler: (XTestInvocation) -> Unit
+) = runJavaProcessorTest(
+    sources = sources,
+    classpath = classpath,
+    handlers = listOf(handler)
+)
+
+/**
+ * @see runJavaProcessorTest
+ */
+fun runJavaProcessorTest(
+    sources: List<Source>,
+    classpath: List<File> = emptyList(),
+    handlers: List<(XTestInvocation) -> Unit>
 ) {
     runTests(
         params = TestCompilationParameters(
             sources = sources,
             classpath = classpath,
-            handler = handler
+            handlers = handlers
         ),
         JavacCompilationTestRunner
     )
@@ -140,12 +164,25 @@
     sources: List<Source>,
     classpath: List<File> = emptyList(),
     handler: (XTestInvocation) -> Unit
+) = runKaptTest(
+    sources = sources,
+    classpath = classpath,
+    handlers = listOf(handler)
+)
+
+/**
+ * @see runKaptTest
+ */
+fun runKaptTest(
+    sources: List<Source>,
+    classpath: List<File> = emptyList(),
+    handlers: List<(XTestInvocation) -> Unit>
 ) {
     runTests(
         params = TestCompilationParameters(
             sources = sources,
             classpath = classpath,
-            handler = handler
+            handlers = handlers
         ),
         KaptCompilationTestRunner
     )
@@ -158,12 +195,25 @@
     sources: List<Source>,
     classpath: List<File> = emptyList(),
     handler: (XTestInvocation) -> Unit
+) = runKspTest(
+    sources = sources,
+    classpath = classpath,
+    handlers = listOf(handler)
+)
+
+/**
+ * @see runKspTest
+ */
+fun runKspTest(
+    sources: List<Source>,
+    classpath: List<File> = emptyList(),
+    handlers: List<(XTestInvocation) -> Unit>
 ) {
     runTests(
         params = TestCompilationParameters(
             sources = sources,
             classpath = classpath,
-            handler = handler
+            handlers = handlers
         ),
         KspCompilationTestRunner
     )
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/XTestInvocation.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/XTestInvocation.kt
index 8f927343..25faf77 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/XTestInvocation.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/XTestInvocation.kt
@@ -18,21 +18,40 @@
 
 import androidx.room.compiler.processing.XProcessingEnv
 import androidx.room.compiler.processing.XRoundEnv
+import com.google.common.truth.Truth
 import kotlin.reflect.KClass
 
 /**
  * Data holder for XProcessing tests to access the processing environment.
  */
 class XTestInvocation(
-    val processingEnv: XProcessingEnv,
-    val roundEnv: XRoundEnv
+    processingEnv: XProcessingEnv,
+    roundEnv: XRoundEnv
 ) {
+    val processingEnv: XProcessingEnv = processingEnv
+        get() {
+            assertNotDisposed()
+            return field
+        }
+    val roundEnv: XRoundEnv = roundEnv
+        get() {
+            assertNotDisposed()
+            return field
+        }
+
+    /**
+     * Set to true after callback is called to ensure the test does not re-use an invocation that
+     * is no longer usable (no longer in the process method of the processor)
+     */
+    private var disposed = false
+
     /**
      * Extension mechanism to allow putting objects into invocation that can be retrieved later.
      */
     private val userData = mutableMapOf<KClass<*>, Any>()
 
     private val postCompilationAssertions = mutableListOf<CompilationResultSubject.() -> Unit>()
+
     val isKsp: Boolean
         get() = processingEnv.backend == XProcessingEnv.Backend.KSP
 
@@ -43,6 +62,7 @@
      * Note that it is not safe to access the environment in this block.
      */
     fun assertCompilationResult(block: CompilationResultSubject.() -> Unit) {
+        assertNotDisposed()
         postCompilationAssertions.add(block)
     }
 
@@ -55,11 +75,13 @@
     }
 
     fun <T : Any> getUserData(key: KClass<T>): T? {
+        assertNotDisposed()
         @Suppress("UNCHECKED_CAST")
         return userData[key] as T?
     }
 
     fun <T : Any> putUserData(key: KClass<T>, value: T) {
+        assertNotDisposed()
         userData[key] = value
     }
 
@@ -71,4 +93,14 @@
             putUserData(key, it)
         }
     }
+
+    fun dispose() {
+        disposed = true
+    }
+
+    private fun assertNotDisposed() {
+        Truth.assertWithMessage("Cannot use a test invocation after it is disposed.")
+            .that(disposed)
+            .isFalse()
+    }
 }
\ No newline at end of file
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/CompilationTestRunner.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/CompilationTestRunner.kt
index 257e58a..43bca34 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/CompilationTestRunner.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/CompilationTestRunner.kt
@@ -36,5 +36,5 @@
 internal data class TestCompilationParameters(
     val sources: List<Source> = emptyList(),
     val classpath: List<File> = emptyList(),
-    val handler: (XTestInvocation) -> Unit
+    val handlers: List<(XTestInvocation) -> Unit>
 )
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/JavacCompilationTestRunner.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/JavacCompilationTestRunner.kt
index 5fdba61..a70f146 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/JavacCompilationTestRunner.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/JavacCompilationTestRunner.kt
@@ -31,7 +31,7 @@
     }
 
     override fun compile(params: TestCompilationParameters): CompilationResult {
-        val syntheticJavacProcessor = SyntheticJavacProcessor(params.handler)
+        val syntheticJavacProcessor = SyntheticJavacProcessor(params.handlers)
         val sources = if (params.sources.isEmpty()) {
             // synthesize a source to trigger compilation
             listOf(
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KaptCompilationTestRunner.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KaptCompilationTestRunner.kt
index 438ee44..b4082c2 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KaptCompilationTestRunner.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KaptCompilationTestRunner.kt
@@ -31,7 +31,7 @@
     }
 
     override fun compile(params: TestCompilationParameters): CompilationResult {
-        val syntheticJavacProcessor = SyntheticJavacProcessor(params.handler)
+        val syntheticJavacProcessor = SyntheticJavacProcessor(params.handlers)
         val compilation = KotlinCompilationUtil.prepareCompilation(
             sources = params.sources,
             classpaths = params.classpath
diff --git a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KspCompilationTestRunner.kt b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KspCompilationTestRunner.kt
index c562e7d..85916fa 100644
--- a/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KspCompilationTestRunner.kt
+++ b/room/compiler-processing-testing/src/main/java/androidx/room/compiler/processing/util/runner/KspCompilationTestRunner.kt
@@ -46,7 +46,7 @@
         } else {
             params.sources
         }
-        val syntheticKspProcessor = SyntheticKspProcessor(params.handler)
+        val syntheticKspProcessor = SyntheticKspProcessor(params.handlers)
 
         val kspCompilation = KotlinCompilationUtil.prepareCompilation(
             sources,
diff --git a/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/GeneratedCodeMatchTest.kt b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/GeneratedCodeMatchTest.kt
index 84adf25..832937d 100644
--- a/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/GeneratedCodeMatchTest.kt
+++ b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/GeneratedCodeMatchTest.kt
@@ -24,12 +24,10 @@
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 
-typealias TestRunner = (block: (XTestInvocation) -> Unit) -> Unit
-
 @RunWith(Parameterized::class)
 class GeneratedCodeMatchTest internal constructor(
     private val runTest: TestRunner
-) {
+) : MultiBackendTest() {
     @Test
     fun successfulGeneratedCodeMatch() {
         val file = JavaFile.builder(
@@ -113,24 +111,4 @@
         )
         assertThat(result.exceptionOrNull()).hasMessageThat().contains(mismatch.toString())
     }
-
-    companion object {
-        @JvmStatic
-        @Parameterized.Parameters
-        fun runners(): List<TestRunner> = listOfNotNull(
-            { block: (XTestInvocation) -> Unit ->
-                runJavaProcessorTest(sources = emptyList(), handler = block)
-            },
-            { block: (XTestInvocation) -> Unit ->
-                runKaptTest(sources = emptyList(), handler = block)
-            },
-            if (CompilationTestCapabilities.canTestWithKsp) {
-                { block: (XTestInvocation) -> Unit ->
-                    runKspTest(sources = emptyList(), handler = block)
-                }
-            } else {
-                null
-            }
-        )
-    }
 }
diff --git a/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/MultiBackendTest.kt b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/MultiBackendTest.kt
new file mode 100644
index 0000000..80e2465
--- /dev/null
+++ b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/MultiBackendTest.kt
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.room.compiler.processing.util
+
+import org.junit.runners.Parameterized
+
+class TestRunner(
+    private val name: String,
+    private val runner: (List<(XTestInvocation) -> Unit>) -> Unit
+) {
+    operator fun invoke(handlers: List<(XTestInvocation) -> Unit>) = runner(handlers)
+    operator fun invoke(handler: (XTestInvocation) -> Unit) = runner(listOf(handler))
+    override fun toString() = name
+}
+
+/**
+ * Helper test runner class to run tests for each backend in isolation
+ */
+abstract class MultiBackendTest {
+    companion object {
+        @JvmStatic
+        @Parameterized.Parameters(name = "{0}")
+        fun runners(): List<TestRunner> = listOfNotNull(
+            TestRunner("java") {
+                runJavaProcessorTest(sources = emptyList(), handlers = it)
+            },
+            TestRunner("kapt") {
+                runKaptTest(sources = emptyList(), handlers = it)
+            },
+            if (CompilationTestCapabilities.canTestWithKsp) {
+                TestRunner("ksp") {
+                    runKspTest(sources = emptyList(), handlers = it)
+                }
+            } else {
+                null
+            }
+        )
+    }
+}
diff --git a/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/MultiRoundProcessingTest.kt b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/MultiRoundProcessingTest.kt
new file mode 100644
index 0000000..0313f36
--- /dev/null
+++ b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/MultiRoundProcessingTest.kt
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.room.compiler.processing.util
+
+import com.google.common.truth.Truth.assertThat
+import com.squareup.javapoet.ClassName
+import com.squareup.javapoet.JavaFile
+import com.squareup.javapoet.TypeSpec
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import javax.tools.Diagnostic
+
+@RunWith(Parameterized::class)
+class MultiRoundProcessingTest(
+    private val testRunner: TestRunner
+) : MultiBackendTest() {
+    private fun generateCode(index: Int): JavaFile {
+        val typeSpec = TypeSpec.classBuilder(
+            ClassName.bestGuess("foo.bar.Baz$index")
+        ).build()
+        return JavaFile.builder("foo.bar", typeSpec).build()
+    }
+
+    @Test
+    fun dontRequestAnotherRound() {
+        var runCnt = 0
+        testRunner {
+            runCnt++
+        }
+        // only run 1 if a second round is not explicitly requested
+        assertThat(runCnt).isEqualTo(1)
+    }
+
+    @Suppress("NAME_SHADOWING") // intentional to avoid accessing the wrong one
+    @Test
+    fun multipleRounds() {
+        var runCnt = 0
+        fun checkAndIncrementRunCount(expected: Int) {
+            assertThat(runCnt).isEqualTo(expected)
+            runCnt++
+        }
+        testRunner(
+            handlers = listOf(
+                { invocation ->
+                    checkAndIncrementRunCount(0)
+                    invocation.processingEnv.filer.write(generateCode(0))
+                },
+                { invocation ->
+                    checkAndIncrementRunCount(1)
+                    invocation.processingEnv.filer.write(generateCode(1))
+                },
+                {
+                    checkAndIncrementRunCount(2)
+                }
+            )
+        )
+        checkAndIncrementRunCount(3)
+    }
+
+    @Suppress("NAME_SHADOWING") // intentional to avoid accessing the wrong one
+    @Test
+    fun validateMessagesFromDifferentRounds() {
+        var didRunFirstRoundAssertions = false
+        var didRunSecondRoundAssertions = false
+        testRunner(
+            handlers = listOf(
+                { invocation ->
+                    invocation.processingEnv.messager.printMessage(
+                        Diagnostic.Kind.NOTE,
+                        "note from 1"
+                    )
+                    invocation.processingEnv.messager.printMessage(
+                        Diagnostic.Kind.WARNING,
+                        "warning from 1"
+                    )
+                    invocation.processingEnv.filer.write(generateCode(0))
+                    invocation.assertCompilationResult {
+                        // can assert diagnostics from followup rounds
+                        hasWarning("warning from 1")
+                        hasWarning("warning from 2")
+                        hasError("error from 2")
+                        hasNote("note from 1")
+                        hasNote("note from 2")
+                        didRunFirstRoundAssertions = true
+                    }
+                },
+                { invocation ->
+                    check(!didRunFirstRoundAssertions) {
+                        "shouldn't run assertions before all runs are completed"
+                    }
+                    invocation.processingEnv.messager.printMessage(
+                        Diagnostic.Kind.NOTE,
+                        "note from 2"
+                    )
+                    invocation.processingEnv.messager.printMessage(
+                        Diagnostic.Kind.WARNING,
+                        "warning from 2"
+                    )
+                    invocation.processingEnv.messager.printMessage(
+                        Diagnostic.Kind.ERROR,
+                        "error from 2"
+                    )
+                    invocation.assertCompilationResult {
+                        hasWarning("warning from 1")
+                        hasWarning("warning from 2")
+                        hasError("error from 2")
+                        hasNote("note from 1")
+                        hasNote("note from 2")
+                        didRunSecondRoundAssertions = true
+                    }
+                }
+            )
+        )
+        // just to make sure test didn't pass by failing to run assertions.
+        assertThat(didRunFirstRoundAssertions).isTrue()
+        assertThat(didRunSecondRoundAssertions).isTrue()
+    }
+
+    @Test
+    fun validateIfRequestedRoundIsRun() {
+        val result = runCatching {
+            testRunner(
+                handlers = listOf(
+                    {},
+                    {
+                        // this won't happen because no code is generated in the first run
+                    }
+                )
+            )
+        }
+        assertThat(
+            result.isFailure
+        ).isTrue()
+        assertThat(
+            result.exceptionOrNull()
+        ).hasMessageThat()
+            .contains(
+                "Test runner requested another round but that didn't happen"
+            )
+    }
+
+    @Test
+    fun accessingDisposedHandlerIsNotAllowed() {
+        val result = runCatching {
+            lateinit var previousInvocation: XTestInvocation
+            testRunner(
+                handlers = listOf(
+                    { invocation1 ->
+                        invocation1.processingEnv.filer.write(generateCode(0))
+                        previousInvocation = invocation1
+                    },
+                    {
+                        previousInvocation.processingEnv.filer.write(generateCode(1))
+                    }
+                )
+            )
+        }
+        assertThat(
+            result.exceptionOrNull()?.cause
+        ).hasMessageThat()
+            .contains("Cannot use a test invocation after it is disposed")
+    }
+
+    @Test
+    fun checkFailureFromAPreviousRoundIsNotMissed() {
+        val result = runCatching {
+            testRunner(
+                handlers = listOf(
+                    { invocation1 ->
+                        invocation1.processingEnv.filer.write(generateCode(0))
+                        // this will fail
+                        throw AssertionError("i failed")
+                    },
+                    {
+                        // this won't run
+                        throw AssertionError("this shouldn't run as prev one failed")
+                    }
+                )
+            )
+        }
+        assertThat(
+            result.exceptionOrNull()?.cause
+        ).hasMessageThat().contains("i failed")
+    }
+}
\ No newline at end of file
diff --git a/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/TestRunnerTest.kt b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/TestRunnerTest.kt
index 051dac1..7dcc46f 100644
--- a/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/TestRunnerTest.kt
+++ b/room/compiler-processing-testing/src/test/java/androidx/room/compiler/processing/util/TestRunnerTest.kt
@@ -57,6 +57,44 @@
     @Test(expected = AssertionError::class)
     fun reportedError_unexpected() = reportedError(assertFailure = false)
 
+    @Test
+    fun diagnosticsMessages() {
+        runProcessorTest { invocation ->
+            invocation.processingEnv.messager.run {
+                printMessage(Diagnostic.Kind.NOTE, "note 1")
+                printMessage(Diagnostic.Kind.WARNING, "warn 1")
+                printMessage(Diagnostic.Kind.ERROR, "error 1")
+            }
+            invocation.assertCompilationResult {
+                hasNote("note 1")
+                hasWarning("warn 1")
+                hasError("error 1")
+                hasNoteContaining("ote")
+                hasWarningContaining("arn")
+                hasErrorContaining("rror")
+                // these should fail:
+                assertThat(
+                    runCatching { hasNote("note") }.isFailure
+                ).isTrue()
+                assertThat(
+                    runCatching { hasWarning("warn") }.isFailure
+                ).isTrue()
+                assertThat(
+                    runCatching { hasError("error") }.isFailure
+                ).isTrue()
+                assertThat(
+                    runCatching { hasNoteContaining("error") }.isFailure
+                ).isTrue()
+                assertThat(
+                    runCatching { hasWarningContaining("note") }.isFailure
+                ).isTrue()
+                assertThat(
+                    runCatching { hasErrorContaining("warning") }.isFailure
+                ).isTrue()
+            }
+        }
+    }
+
     private fun reportedError(assertFailure: Boolean) {
         runProcessorTest {
             it.processingEnv.messager.printMessage(