Synchronize adding and removing table trackers.

Replace the usage of a 'pending sync' flag, onSyncCompleted() callback and a while(true) loop with a sync block to protect adding and removing table triggers from race conditions. Specifically the strategy of the flag along with the loop made sure only one thread would proceed to update triggers and would release other threads from the responsibility. However, in certain conditions, a trigger sync might be already occurring and a new write transaction is started, but does not wait for the triggers sync to finish, possibly causing invalidation of newly added observers to be missed.

Using a sync block is simple but it does introduce contingency, as now every add / remove observers call along with beginTransaction() might need to wait if there is another thread within the block.

Bug: 215583326
Test: SyncTriggersConcurrencyTest
Change-Id: I3f18daf2a61eca2c872668c818f75b3b03204988
diff --git a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SyncTriggersConcurrencyTest.kt b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SyncTriggersConcurrencyTest.kt
new file mode 100644
index 0000000..43da129
--- /dev/null
+++ b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SyncTriggersConcurrencyTest.kt
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.room.androidx.room.integration.kotlintestapp.test
+
+import androidx.arch.core.executor.testing.CountingTaskExecutorRule
+import androidx.room.Dao
+import androidx.room.Database
+import androidx.room.Delete
+import androidx.room.Entity
+import androidx.room.Insert
+import androidx.room.InvalidationTracker
+import androidx.room.OnConflictStrategy
+import androidx.room.PrimaryKey
+import androidx.room.Room
+import androidx.room.RoomDatabase
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.LargeTest
+import androidx.test.platform.app.InstrumentationRegistry
+import java.util.UUID
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
+import kotlin.test.assertTrue
+import org.junit.After
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+
+// Verifies b/215583326
+@LargeTest
+@RunWith(AndroidJUnit4::class)
+class SyncTriggersConcurrencyTest {
+
+    @Rule
+    @JvmField
+    val countingTaskExecutorRule = CountingTaskExecutorRule()
+
+    private lateinit var executor: ExecutorService
+    private lateinit var database: SampleDatabase
+    private lateinit var terminationSignal: AtomicBoolean
+
+    @Before
+    fun setup() {
+        val applicationContext = InstrumentationRegistry.getInstrumentation().targetContext
+        val threadId = AtomicInteger()
+        executor = Executors.newCachedThreadPool { runnable ->
+            Thread(runnable).apply {
+                name = "invalidation_tracker_test_worker_${threadId.getAndIncrement()}"
+            }
+        }
+        database = Room
+            .databaseBuilder(applicationContext, SampleDatabase::class.java, DB_NAME)
+            .build()
+        terminationSignal = AtomicBoolean()
+    }
+
+    @After
+    fun tearDown() {
+        terminationSignal.set(true)
+        executor.shutdown()
+        val terminated = executor.awaitTermination(1L, TimeUnit.SECONDS)
+        countingTaskExecutorRule.drainTasks(500, TimeUnit.MILLISECONDS)
+        database.close()
+        InstrumentationRegistry.getInstrumentation().targetContext.deleteDatabase(DB_NAME)
+        check(terminated)
+        check(countingTaskExecutorRule.isIdle)
+    }
+
+    @Test
+    fun test() {
+        val invalidationTracker = database.invalidationTracker
+
+        // Launch CONCURRENCY number of tasks which stress the InvalidationTracker by repeatedly
+        // registering and unregistering observers.
+        repeat(CONCURRENCY) {
+            executor.execute(StressRunnable(invalidationTracker, terminationSignal))
+        }
+
+        // Repeatedly, CHECK_ITERATIONS number of times:
+        // 1. Add an observer
+        // 2. Insert an entity
+        // 4. Remove the observer
+        // 5. Assert that the observer received an invalidation call.
+        val dao = database.sampleDao
+        repeat(CHECK_ITERATIONS) { iteration ->
+            val checkObserver = TestObserver(
+                expectedInvalidationCount = 1
+            )
+            invalidationTracker.addObserver(checkObserver)
+            try {
+                val entity = SampleEntity(UUID.randomUUID().toString())
+                dao.insert(entity)
+                val countedDown = checkObserver.latch.await(10L, TimeUnit.SECONDS)
+                assertTrue(countedDown, "iteration $iteration timed out")
+            } finally {
+                invalidationTracker.removeObserver(checkObserver)
+            }
+        }
+    }
+
+    /**
+     * Stresses the invalidation tracker by repeatedly adding and removing an observer.
+     * @property invalidationTracker the invalidation tracker
+     * @property terminationSignal when set to true, signals the loop to terminate
+     */
+    private class StressRunnable(
+        private val invalidationTracker: InvalidationTracker,
+        private val terminationSignal: AtomicBoolean,
+    ) : Runnable {
+
+        val observer = TestObserver()
+
+        override fun run() {
+            while (!terminationSignal.get()) {
+                invalidationTracker.addObserver(observer)
+                invalidationTracker.removeObserver(observer)
+            }
+        }
+    }
+
+    private class TestObserver(
+        expectedInvalidationCount: Int = 0
+    ) : InvalidationTracker.Observer(SampleEntity::class.java.simpleName) {
+
+        val latch = CountDownLatch(expectedInvalidationCount)
+
+        override fun onInvalidated(tables: MutableSet<String>) {
+            latch.countDown()
+        }
+    }
+
+    @Database(entities = [SampleEntity::class], version = 1, exportSchema = false)
+    abstract class SampleDatabase : RoomDatabase() {
+        abstract val sampleDao: SampleDao
+    }
+
+    @Dao
+    interface SampleDao {
+
+        @Insert(onConflict = OnConflictStrategy.REPLACE)
+        fun insert(count: SampleEntity)
+
+        @Delete
+        fun delete(count: SampleEntity)
+    }
+
+    @Entity
+    class SampleEntity(
+        @PrimaryKey val id: String,
+    )
+
+    companion object {
+
+        private const val DB_NAME = "sample.db"
+
+        private const val CONCURRENCY = 4
+        private const val CHECK_ITERATIONS = 200
+    }
+}
\ No newline at end of file
diff --git a/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java b/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java
index fe0ae9a..3a0c5ca 100644
--- a/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java
+++ b/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java
@@ -104,7 +104,7 @@
     @SuppressWarnings("WeakerAccess") /* synthetic access */
     volatile SupportSQLiteStatement mCleanupStatement;
 
-    private ObservedTableTracker mObservedTableTracker;
+    private final ObservedTableTracker mObservedTableTracker;
 
     private final InvalidationLiveDataContainer mInvalidationLiveDataContainer;
 
@@ -115,6 +115,8 @@
 
     private MultiInstanceInvalidationClient mMultiInstanceInvalidationClient;
 
+    private final Object mSyncTriggersLock = new Object();
+
     /**
      * Used by the generated code.
      *
@@ -537,14 +539,13 @@
             return;
         }
         try {
-            // This method runs in a while loop because while changes are synced to db, another
-            // runnable may be skipped. If we cause it to skip, we need to do its work.
-            while (true) {
-                Lock closeLock = mDatabase.getCloseLock();
-                closeLock.lock();
-                try {
-                    // there is a potential race condition where another mSyncTriggers runnable
-                    // can start running right after we get the tables list to sync.
+            Lock closeLock = mDatabase.getCloseLock();
+            closeLock.lock();
+            try {
+                // Serialize adding and removing table trackers, this is specifically important
+                // to avoid missing invalidation before a transaction starts but there are
+                // pending (possibly concurrent) observer changes.
+                synchronized (mSyncTriggersLock) {
                     final int[] tablesToSync = mObservedTableTracker.getTablesToSync();
                     if (tablesToSync == null) {
                         return;
@@ -566,10 +567,9 @@
                     } finally {
                         database.endTransaction();
                     }
-                    mObservedTableTracker.onSyncCompleted();
-                } finally {
-                    closeLock.unlock();
                 }
+            } finally {
+                closeLock.unlock();
             }
         } catch (IllegalStateException | SQLiteException exception) {
             // may happen if db is closed. just log.
@@ -789,13 +789,6 @@
 
         boolean mNeedsSync;
 
-        /**
-         * After we return non-null value from getTablesToSync, we expect a onSyncCompleted before
-         * returning any non-null value from getTablesToSync.
-         * This allows us to workaround any multi-threaded state syncing issues.
-         */
-        boolean mPendingSync;
-
         ObservedTableTracker(int tableCount) {
             mTableObservers = new long[tableCount];
             mTriggerStates = new boolean[tableCount];
@@ -852,7 +845,7 @@
         }
 
         /**
-         * If this returns non-null, you must call onSyncCompleted.
+         * If this returns non-null there are no pending sync operations.
          *
          * @return int[] An int array where the index for each tableId has the action for that
          * table.
@@ -860,7 +853,7 @@
         @Nullable
         int[] getTablesToSync() {
             synchronized (this) {
-                if (!mNeedsSync || mPendingSync) {
+                if (!mNeedsSync) {
                     return null;
                 }
                 final int tableCount = mTableObservers.length;
@@ -873,19 +866,8 @@
                     }
                     mTriggerStates[i] = newState;
                 }
-                mPendingSync = true;
                 mNeedsSync = false;
-                return mTriggerStateChanges;
-            }
-        }
-
-        /**
-         * if getTablesToSync returned non-null, the called should call onSyncCompleted once it
-         * is done.
-         */
-        void onSyncCompleted() {
-            synchronized (this) {
-                mPendingSync = false;
+                return mTriggerStateChanges.clone();
             }
         }
     }
diff --git a/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java b/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java
index 224815f..2106fdb 100644
--- a/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java
+++ b/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java
@@ -64,17 +64,6 @@
     }
 
     @Test
-    public void returnNullUntilSync() {
-        initState(1, 3);
-        mTracker.onAdded(4);
-        assertThat(mTracker.getTablesToSync(), is(createResponse(4, ADD)));
-        mTracker.onAdded(0);
-        assertThat(mTracker.getTablesToSync(), is(nullValue()));
-        mTracker.onSyncCompleted();
-        assertThat(mTracker.getTablesToSync(), is(createResponse(0, ADD)));
-    }
-
-    @Test
     public void multipleAdditionsDeletions() {
         initState(2, 4);
         mTracker.onAdded(2);
@@ -94,7 +83,6 @@
     private void initState(int... tableIds) {
         mTracker.onAdded(tableIds);
         mTracker.getTablesToSync();
-        mTracker.onSyncCompleted();
     }
 
     private static int[] createResponse(int... tuples) {