Fix InvalidationTracker.Observer registration on main thread

Ensure that observer resgiration happens on RoomDatabase's query
executor instead of the thread that calls loadFuture which could happen
on main thread

Test: ./gradlew room:room-paging-guava:cC
Fixes: 229875775
Change-Id: Ia299eeacc84887807540a56246f42ea5b8fc0163
diff --git a/room/room-paging-guava/src/androidTest/kotlin/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSourceTest.kt b/room/room-paging-guava/src/androidTest/kotlin/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSourceTest.kt
index a9fd319..cf122e9 100644
--- a/room/room-paging-guava/src/androidTest/kotlin/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSourceTest.kt
+++ b/room/room-paging-guava/src/androidTest/kotlin/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSourceTest.kt
@@ -30,6 +30,7 @@
 import androidx.room.Room
 import androidx.room.RoomDatabase
 import androidx.room.RoomSQLiteQuery
+import androidx.room.paging.util.ThreadSafeInvalidationObserver
 import androidx.room.util.getColumnIndexOrThrow
 import androidx.sqlite.db.SimpleSQLiteQuery
 import androidx.test.core.app.ApplicationProvider
@@ -45,6 +46,7 @@
 import java.util.concurrent.CancellationException
 import java.util.concurrent.Executor
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
 import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
 import kotlin.test.assertTrue
@@ -67,8 +69,32 @@
     val countingTaskExecutorRule = CountingTaskExecutorRule()
 
     @Test
+    fun initialLoad_registersInvalidationObserver() =
+        setupAndRunWithTestExecutor { db, queryExecutor, _ ->
+            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(
+                db = db,
+                isInitialLoad = true
+            )
+
+            val listenableFuture = pagingSource.refresh()
+            assertFalse(pagingSource.privateObserver().privateRegisteredState().get())
+
+            // observer registration is queued up on queryExecutor by refresh() call
+            queryExecutor.executeAll()
+
+            assertTrue(pagingSource.privateObserver().privateRegisteredState().get())
+            // note that listenableFuture is not done yet
+            // The future has been transformed into a ListenableFuture<LoadResult> whose result
+            // is still pending
+            assertFalse(listenableFuture.isDone)
+        }
+
+    @Test
     fun initialEmptyLoad_futureIsDone() = setupAndRun { db ->
-        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
+        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(
+            db = db,
+            isInitialLoad = true
+        )
 
         runTest {
             val listenableFuture = pagingSource.refresh()
@@ -81,16 +107,19 @@
 
     @Test
     fun initialLoad_returnsFutureImmediately() =
-        setupAndRunWithTestExecutor { db, _, transactionExecutor ->
-            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
+        setupAndRunWithTestExecutor { db, queryExecutor, transactionExecutor ->
+            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(
+                db = db,
+                isInitialLoad = true
+            )
 
             val listenableFuture = pagingSource.refresh()
             // ensure future is returned even as its result is still pending
             assertFalse(listenableFuture.isDone)
             assertThat(pagingSource.itemCount.get()).isEqualTo(-1)
 
-            // now execute db queries
-            transactionExecutor.executeAll() // initial transactional refresh load
+            queryExecutor.executeAll() // run loadFuture
+            transactionExecutor.executeAll() // start initialLoad callable + load data
 
             val page = listenableFuture.await() as LoadResult.Page
             assertThat(page.data).containsExactlyElementsIn(
@@ -103,14 +132,15 @@
     fun append_returnsFutureImmediately() =
         setupAndRunWithTestExecutor { db, queryExecutor, _ ->
             val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
-            pagingSource.itemCount.set(100) // bypass check for initial load
+
+            pagingSource.itemCount.set(100)
 
             val listenableFuture = pagingSource.append(key = 20)
             // ensure future is returned even as its result is still pending
             assertFalse(listenableFuture.isDone)
 
-            // load append
-            queryExecutor.executeNext()
+            // run transformAsync and async function
+            queryExecutor.executeAll()
 
             val page = listenableFuture.await() as LoadResult.Page
             assertThat(page.data).containsExactlyElementsIn(
@@ -129,8 +159,8 @@
             // ensure future is returned even as its result is still pending
             assertFalse(listenableFuture.isDone)
 
-            // load prepend
-            queryExecutor.executeNext()
+            // run transformAsync and async function
+            queryExecutor.executeAll()
 
             val page = listenableFuture.await() as LoadResult.Page
             assertThat(page.data).containsExactlyElementsIn(
@@ -150,8 +180,7 @@
             pagingSource.invalidate() // imitate refreshVersionsAsync invalidating the PagingSource
             assertTrue(pagingSource.invalid)
 
-            // executing the load Callable
-            queryExecutor.executeNext()
+            queryExecutor.executeAll() // run transformAsync and async function
 
             val result = listenableFuture.await()
             assertThat(result).isInstanceOf(LoadResult.Invalid::class.java)
@@ -169,8 +198,7 @@
             pagingSource.invalidate() // imitate refreshVersionsAsync invalidating the PagingSource
             assertTrue(pagingSource.invalid)
 
-            // executing the load Callable
-            queryExecutor.executeNext()
+            queryExecutor.executeAll() // run transformAsync and async function
 
             val result = listenableFuture.await()
             assertThat(result).isInstanceOf(LoadResult.Invalid::class.java)
@@ -180,8 +208,8 @@
     @Test
     fun refresh_consecutively() = setupAndRun { db ->
         db.dao.addAllItems(ITEMS_LIST)
-        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
-        val pagingSource2 = LimitOffsetListenableFuturePagingSourceImpl(db)
+        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
+        val pagingSource2 = LimitOffsetListenableFuturePagingSourceImpl(db, true)
 
         val listenableFuture1 = pagingSource.refresh(key = 10)
         val listenableFuture2 = pagingSource2.refresh(key = 15)
@@ -210,22 +238,33 @@
             val listenableFuture1 = pagingSource.append(key = 10)
             val listenableFuture2 = pagingSource.append(key = 15)
 
-            // both appends should be queued
+            // both load futures are queued
+            assertThat(queryExecutor.queuedSize()).isEqualTo(2)
+            queryExecutor.executeNext() // first transformAsync
+            queryExecutor.executeNext() // second transformAsync
+
+            // both async functions are queued
+            assertThat(queryExecutor.queuedSize()).isEqualTo(2)
+            queryExecutor.executeNext() // first async function
+            queryExecutor.executeNext() // second async function
+
+            // both nonInitial loads are queued
             assertThat(queryExecutor.queuedSize()).isEqualTo(2)
 
-            // run next append in queue and make sure it is the first append
-            queryExecutor.executeNext()
+            queryExecutor.executeNext() // first db load
             val page1 = listenableFuture1.await() as LoadResult.Page
             assertThat(page1.data).containsExactlyElementsIn(
                 ITEMS_LIST.subList(10, 15)
             )
 
-            // now run the second append
-            queryExecutor.executeNext()
+            queryExecutor.executeNext() // second db load
             val page2 = listenableFuture2.await() as LoadResult.Page
             assertThat(page2.data).containsExactlyElementsIn(
                 ITEMS_LIST.subList(15, 20)
             )
+
+            assertTrue(listenableFuture1.isDone)
+            assertTrue(listenableFuture2.isDone)
         }
 
     @Test
@@ -236,31 +275,41 @@
 
             assertThat(queryExecutor.queuedSize()).isEqualTo(0)
 
-            val listenableFuture1 = pagingSource.prepend(key = 30)
-            val listenableFuture2 = pagingSource.prepend(key = 25)
+            val listenableFuture1 = pagingSource.prepend(key = 25)
+            val listenableFuture2 = pagingSource.prepend(key = 20)
 
-            // both prepends should be queued
+            // both load futures are queued
+            assertThat(queryExecutor.queuedSize()).isEqualTo(2)
+            queryExecutor.executeNext() // first transformAsync
+            queryExecutor.executeNext() // second transformAsync
+
+            // both async functions are queued
+            assertThat(queryExecutor.queuedSize()).isEqualTo(2)
+            queryExecutor.executeNext() // first async function
+            queryExecutor.executeNext() // second async function
+
+            // both nonInitial loads are queued
             assertThat(queryExecutor.queuedSize()).isEqualTo(2)
 
-            // run next prepend in queue and make sure it is the first prepend
-            queryExecutor.executeNext()
+            queryExecutor.executeNext() // first db load
             val page1 = listenableFuture1.await() as LoadResult.Page
             assertThat(page1.data).containsExactlyElementsIn(
-                ITEMS_LIST.subList(25, 30)
-            )
-
-            // now run the second prepend
-            queryExecutor.executeNext()
-            val page2 = listenableFuture2.await() as LoadResult.Page
-            assertThat(page2.data).containsExactlyElementsIn(
                 ITEMS_LIST.subList(20, 25)
             )
-        }
 
+            queryExecutor.executeNext() // second db load
+            val page2 = listenableFuture2.await() as LoadResult.Page
+            assertThat(page2.data).containsExactlyElementsIn(
+                ITEMS_LIST.subList(15, 20)
+            )
+
+            assertTrue(listenableFuture1.isDone)
+            assertTrue(listenableFuture2.isDone)
+        }
     @Test
     fun refresh_onSuccess() = setupAndRun { db ->
         db.dao.addAllItems(ITEMS_LIST)
-        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
+        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
 
         val listenableFuture = pagingSource.refresh(key = 30)
 
@@ -342,21 +391,91 @@
     }
 
     @Test
-    fun refresh_awaitThrowsCancellationException() =
+    fun refresh_cancelBeforeObserverRegistered_CancellationException() =
         setupAndRunWithTestExecutor { db, queryExecutor, transactionExecutor ->
-            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
+            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
 
             val listenableFuture = pagingSource.refresh(key = 50)
-            // the initial runInTransaction load
-            assertThat(transactionExecutor.queuedSize()).isEqualTo(1)
+            assertThat(queryExecutor.queuedSize()).isEqualTo(1) // transformAsync
+
+            // cancel before observer has been registered. This queues up another task which is
+            // the cancelled async function
+            listenableFuture.cancel(true)
+
+            // even though future is cancelled, transformAsync was already queued up which means
+            // observer will still get registered
+            assertThat(queryExecutor.queuedSize()).isEqualTo(2)
+            // start async function but doesn't proceed further
+            queryExecutor.executeAll()
+
+            // ensure initial load is not queued up
+            assertThat(transactionExecutor.queuedSize()).isEqualTo(0)
+
+            // await() should throw after cancellation
+            assertFailsWith<CancellationException> {
+                listenableFuture.await()
+            }
+
+            // executors should be idle
+            assertThat(queryExecutor.queuedSize()).isEqualTo(0)
+            assertThat(transactionExecutor.queuedSize()).isEqualTo(0)
+            assertTrue(listenableFuture.isDone)
+            // even though initial refresh load is cancelled, the paging source itself
+            // is NOT invalidated
+            assertFalse(pagingSource.invalid)
+        }
+
+    @Test
+    fun refresh_cancelAfterObserverRegistered_CancellationException() =
+        setupAndRunWithTestExecutor { db, queryExecutor, transactionExecutor ->
+            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
+
+            val listenableFuture = pagingSource.refresh(key = 50)
+
+            // start transformAsync and register observer
+            queryExecutor.executeNext()
+
+            // cancel after observer registration
+            listenableFuture.cancel(true)
+
+            // start the async function but it has been cancelled so it doesn't queue up
+            // initial load
+            queryExecutor.executeNext()
+
+            // initialLoad not queued
+            assertThat(transactionExecutor.queuedSize()).isEqualTo(0)
+
+            // await() should throw after cancellation
+            assertFailsWith<CancellationException> {
+                listenableFuture.await()
+            }
+
+            // executors should be idle
+            assertThat(queryExecutor.queuedSize()).isEqualTo(0)
+            assertThat(transactionExecutor.queuedSize()).isEqualTo(0)
+            assertTrue(listenableFuture.isDone)
+            // even though initial refresh load is cancelled, the paging source itself
+            // is NOT invalidated
+            assertFalse(pagingSource.invalid)
+        }
+
+    @Test
+    fun refresh_cancelAfterLoadIsQueued_CancellationException() =
+        setupAndRunWithTestExecutor { db, queryExecutor, transactionExecutor ->
+            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
+
+            val listenableFuture = pagingSource.refresh(key = 50)
+
+            queryExecutor.executeAll() // run loadFuture and queue up initial load
 
             listenableFuture.cancel(true)
 
-            assertThat(queryExecutor.queuedSize()).isEqualTo(0)
+            // initialLoad has been queued
             assertThat(transactionExecutor.queuedSize()).isEqualTo(1)
+            assertThat(queryExecutor.queuedSize()).isEqualTo(0)
 
-            transactionExecutor.executeNext() // initial load
-            queryExecutor.executeNext() // refreshVersionsAsync from the end runInTransaction
+            transactionExecutor.executeAll() // room starts transaction but doesn't complete load
+            queryExecutor.executeAll() // InvalidationTracker from end of transaction
 
             // await() should throw after cancellation
             assertFailsWith<CancellationException> {
@@ -383,7 +502,7 @@
             assertThat(queryExecutor.queuedSize()).isEqualTo(1)
 
             listenableFuture.cancel(true)
-            queryExecutor.executeNext()
+            queryExecutor.executeAll()
 
             // await() should throw after cancellation
             assertFailsWith<CancellationException> {
@@ -407,7 +526,7 @@
             assertThat(queryExecutor.queuedSize()).isEqualTo(1)
 
             listenableFuture.cancel(true)
-            queryExecutor.executeNext()
+            queryExecutor.executeAll()
 
             // await() should throw after cancellation
             assertFailsWith<CancellationException> {
@@ -422,10 +541,12 @@
 
     @Test
     fun refresh_canceledFutureRunsOnFailureCallback() =
-        setupAndRunWithTestExecutor { db, _, transactionExecutor ->
-            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
+        setupAndRunWithTestExecutor { db, queryExecutor, transactionExecutor ->
+            val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
 
             val listenableFuture = pagingSource.refresh(key = 30)
+
+            queryExecutor.executeAll() // start transformAsync & async function
             assertThat(transactionExecutor.queuedSize()).isEqualTo(1)
 
             val callbackExecutor = TestExecutor()
@@ -437,7 +558,7 @@
 
             // now cancel future and execute the refresh load. The refresh should not complete.
             listenableFuture.cancel(true)
-            transactionExecutor.executeNext()
+            transactionExecutor.executeAll()
             assertThat(transactionExecutor.queuedSize()).isEqualTo(0)
 
             callbackExecutor.executeAll()
@@ -448,12 +569,11 @@
         }
 
     @Test
-    fun append_canceledFutureRunsOnFailureCallback() =
+    fun append_canceledFutureRunsOnFailureCallback2() =
         setupAndRunWithTestExecutor { db, queryExecutor, _ ->
             val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
             pagingSource.itemCount.set(100) // bypass check for initial load
 
-            // queue up the append first
             val listenableFuture = pagingSource.append(key = 20)
             assertThat(queryExecutor.queuedSize()).isEqualTo(1)
 
@@ -466,7 +586,9 @@
 
             // now cancel future and execute the append load. The append should not complete.
             listenableFuture.cancel(true)
-            queryExecutor.executeNext()
+
+            queryExecutor.executeNext() // transformAsync
+            queryExecutor.executeNext() // nonInitialLoad
             // if load was erroneously completed, InvalidationTracker would be queued
             assertThat(queryExecutor.queuedSize()).isEqualTo(0)
 
@@ -475,7 +597,7 @@
             // make sure onFailure callback was executed
             assertTrue(onFailureReceived)
             assertTrue(listenableFuture.isDone)
-    }
+        }
 
     @Test
     fun prepend_canceledFutureRunsOnFailureCallback() =
@@ -496,7 +618,8 @@
 
             // now cancel future and execute the prepend which should not complete.
             listenableFuture.cancel(true)
-            queryExecutor.executeNext()
+            queryExecutor.executeNext() // transformAsync
+            queryExecutor.executeNext() // nonInitialLoad
             // if load was erroneously completed, InvalidationTracker would be queued
             assertThat(queryExecutor.queuedSize()).isEqualTo(0)
 
@@ -510,7 +633,7 @@
     @Test
     fun refresh_AfterCancellation() = setupAndRun { db ->
         db.dao.addAllItems(ITEMS_LIST)
-        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db)
+        val pagingSource = LimitOffsetListenableFuturePagingSourceImpl(db, true)
         pagingSource.itemCount.set(100) // bypass check for initial load
 
         val listenableFuture = pagingSource.prepend(key = 50)
@@ -521,7 +644,7 @@
         }
 
         // new gen after query from previous gen was cancelled
-        val pagingSource2 = LimitOffsetListenableFuturePagingSourceImpl(db)
+        val pagingSource2 = LimitOffsetListenableFuturePagingSourceImpl(db, true)
         val listenableFuture2 = pagingSource2.refresh()
         val result = listenableFuture2.await() as LoadResult.Page
 
@@ -685,7 +808,7 @@
 
         runTest {
             db.dao.addAllItems(ITEMS_LIST)
-            queryExecutor.executeNext() // InvalidationTracker from the addAllItems
+            queryExecutor.executeAll() // InvalidationTracker from the addAllItems
           test(db, queryExecutor, transactionExecutor)
         }
         tearDown(db)
@@ -700,6 +823,7 @@
 
 private class LimitOffsetListenableFuturePagingSourceImpl(
     db: RoomDatabase,
+    isInitialLoad: Boolean = false,
     queryString: String = "SELECT * FROM $tableName ORDER BY id ASC",
 ) : LimitOffsetListenableFuturePagingSource<TestItem>(
     sourceQuery = RoomSQLiteQuery.acquire(
@@ -709,6 +833,14 @@
     db = db,
     tables = arrayOf(tableName)
 ) {
+
+   init {
+       // bypass register check and avoid registering observer
+       if (!isInitialLoad) {
+           privateObserver().privateRegisteredState().set(true)
+       }
+   }
+
     override fun convertRows(cursor: Cursor): List<TestItem> {
         return convertRowsHelper(cursor)
     }
@@ -747,6 +879,27 @@
     return tasks.size
 }
 
+@Suppress("UNCHECKED_CAST")
+private fun ThreadSafeInvalidationObserver.privateRegisteredState(): AtomicBoolean {
+    return ThreadSafeInvalidationObserver::class.java
+        .getDeclaredField("registered")
+        .let {
+            it.isAccessible = true
+            it.get(this)
+        } as AtomicBoolean
+}
+
+@Suppress("UNCHECKED_CAST")
+private fun LimitOffsetListenableFuturePagingSource<TestItem>.privateObserver():
+    ThreadSafeInvalidationObserver {
+    return LimitOffsetListenableFuturePagingSource::class.java
+        .getDeclaredField("observer")
+        .let {
+            it.isAccessible = true
+            it.get(this)
+        } as ThreadSafeInvalidationObserver
+}
+
 private fun LimitOffsetListenableFuturePagingSource<TestItem>.refresh(
     key: Int? = null,
 ): ListenableFuture<LoadResult<Int, TestItem>> {
diff --git a/room/room-paging-guava/src/main/java/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSource.kt b/room/room-paging-guava/src/main/java/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSource.kt
index fb2d209..23fbec0 100644
--- a/room/room-paging-guava/src/main/java/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSource.kt
+++ b/room/room-paging-guava/src/main/java/androidx/room/paging/guava/LimitOffsetListenableFuturePagingSource.kt
@@ -19,6 +19,7 @@
 import android.database.Cursor
 import androidx.annotation.NonNull
 import androidx.annotation.RestrictTo
+import androidx.annotation.VisibleForTesting
 import androidx.paging.ListenableFuturePagingSource
 import androidx.paging.PagingState
 import androidx.room.RoomDatabase
@@ -32,6 +33,7 @@
 import androidx.room.paging.util.queryItemCount
 import androidx.room.util.createCancellationSignal
 import androidx.sqlite.db.SupportSQLiteQuery
+import com.google.common.util.concurrent.Futures
 import com.google.common.util.concurrent.ListenableFuture
 import java.util.concurrent.Callable
 import java.util.concurrent.atomic.AtomicInteger
@@ -53,7 +55,7 @@
         tables = tables,
     )
 
-    // internal for testing visibility
+    @VisibleForTesting
     internal val itemCount: AtomicInteger = AtomicInteger(INITIAL_ITEM_COUNT)
     private val observer = ThreadSafeInvalidationObserver(tables = tables, ::invalidate)
 
@@ -65,13 +67,18 @@
     * cancellation of await() will transitively cancel this future as well.
     */
     override fun loadFuture(params: LoadParams<Int>): ListenableFuture<LoadResult<Int, Value>> {
-        observer.registerIfNecessary(db)
-        val tempCount = itemCount.get()
-        return if (tempCount == INITIAL_ITEM_COUNT) {
-            initialLoad(params)
-        } else {
-            nonInitialLoad(params, tempCount)
-        }
+        return Futures.transformAsync(
+            createListenableFuture(db, false) { observer.registerIfNecessary(db) },
+            {
+                val tempCount = itemCount.get()
+                if (tempCount == INITIAL_ITEM_COUNT) {
+                    initialLoad(params)
+                } else {
+                    nonInitialLoad(params, tempCount)
+                }
+            },
+            db.queryExecutor
+        )
     }
 
     /**