Merge "Synchronize adding and removing table trackers." into androidx-main
diff --git a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SyncTriggersConcurrencyTest.kt b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SyncTriggersConcurrencyTest.kt
new file mode 100644
index 0000000..43da129
--- /dev/null
+++ b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SyncTriggersConcurrencyTest.kt
@@ -0,0 +1,187 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.room.integration.kotlintestapp.test
+
+import androidx.arch.core.executor.testing.CountingTaskExecutorRule
+import androidx.room.Dao
+import androidx.room.Database
+import androidx.room.Delete
+import androidx.room.Entity
+import androidx.room.Insert
+import androidx.room.InvalidationTracker
+import androidx.room.OnConflictStrategy
+import androidx.room.PrimaryKey
+import androidx.room.Room
+import androidx.room.RoomDatabase
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.LargeTest
+import androidx.test.platform.app.InstrumentationRegistry
+import java.util.UUID
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
+import kotlin.test.assertTrue
+import org.junit.After
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+
+/**
+ * Verifies b/215583326: an observer registered while other threads concurrently add and remove
+ * observers on the same [InvalidationTracker] must still be notified of a subsequent insert.
+ */
+@LargeTest
+@RunWith(AndroidJUnit4::class)
+class SyncTriggersConcurrencyTest {
+
+    @Rule
+    @JvmField
+    val countingTaskExecutorRule = CountingTaskExecutorRule()
+
+    private lateinit var executor: ExecutorService
+    private lateinit var database: SampleDatabase
+    private lateinit var terminationSignal: AtomicBoolean
+
+    @Before
+    fun setup() {
+        val applicationContext = InstrumentationRegistry.getInstrumentation().targetContext
+        val threadId = AtomicInteger()
+        executor = Executors.newCachedThreadPool { runnable ->
+            Thread(runnable).apply {
+                name = "invalidation_tracker_test_worker_${threadId.getAndIncrement()}"
+            }
+        }
+        database = Room
+            .databaseBuilder(applicationContext, SampleDatabase::class.java, DB_NAME)
+            .build()
+        terminationSignal = AtomicBoolean()
+    }
+
+    @After
+    fun tearDown() {
+        // Stop the stress workers first so nothing touches the tracker during cleanup.
+        terminationSignal.set(true)
+        executor.shutdown()
+        val terminated = executor.awaitTermination(1L, TimeUnit.SECONDS)
+        try {
+            // drainTasks can throw (e.g. on timeout); the database must still be closed and
+            // deleted in that case so later runs start from a clean state.
+            countingTaskExecutorRule.drainTasks(500, TimeUnit.MILLISECONDS)
+        } finally {
+            database.close()
+            InstrumentationRegistry.getInstrumentation().targetContext.deleteDatabase(DB_NAME)
+        }
+        check(terminated) { "stress workers did not terminate in time" }
+        check(countingTaskExecutorRule.isIdle) { "background tasks still pending after close" }
+    }
+
+    @Test
+    fun test() {
+        val invalidationTracker = database.invalidationTracker
+
+        // Launch CONCURRENCY number of tasks which stress the InvalidationTracker by repeatedly
+        // registering and unregistering observers.
+        repeat(CONCURRENCY) {
+            executor.execute(StressRunnable(invalidationTracker, terminationSignal))
+        }
+
+        // Repeatedly, CHECK_ITERATIONS number of times:
+        // 1. Add an observer
+        // 2. Insert an entity
+        // 3. Assert that the observer received an invalidation call (within a timeout)
+        // 4. Remove the observer, even if the assertion failed
+        val dao = database.sampleDao
+        repeat(CHECK_ITERATIONS) { iteration ->
+            val checkObserver = TestObserver(expectedInvalidationCount = 1)
+            invalidationTracker.addObserver(checkObserver)
+            try {
+                dao.insert(SampleEntity(UUID.randomUUID().toString()))
+                val countedDown = checkObserver.latch.await(10L, TimeUnit.SECONDS)
+                assertTrue(countedDown, "iteration $iteration timed out")
+            } finally {
+                invalidationTracker.removeObserver(checkObserver)
+            }
+        }
+    }
+
+    /**
+     * Stresses the invalidation tracker by repeatedly adding and removing an observer.
+     * @property invalidationTracker the invalidation tracker
+     * @property terminationSignal when set to true, signals the loop to terminate
+     */
+    private class StressRunnable(
+        private val invalidationTracker: InvalidationTracker,
+        private val terminationSignal: AtomicBoolean,
+    ) : Runnable {
+
+        private val observer = TestObserver()
+
+        override fun run() {
+            while (!terminationSignal.get()) {
+                invalidationTracker.addObserver(observer)
+                invalidationTracker.removeObserver(observer)
+            }
+        }
+    }
+
+    /**
+     * Observer of the [SampleEntity] table that counts down [latch] on every invalidation.
+     * @param expectedInvalidationCount initial latch count; [StressRunnable] uses the default 0.
+     */
+    private class TestObserver(
+        expectedInvalidationCount: Int = 0
+    ) : InvalidationTracker.Observer(SampleEntity::class.java.simpleName) {
+
+        val latch = CountDownLatch(expectedInvalidationCount)
+
+        override fun onInvalidated(tables: MutableSet<String>) {
+            latch.countDown()
+        }
+    }
+
+    @Database(entities = [SampleEntity::class], version = 1, exportSchema = false)
+    abstract class SampleDatabase : RoomDatabase() {
+        abstract val sampleDao: SampleDao
+    }
+
+    @Dao
+    interface SampleDao {
+
+        @Insert(onConflict = OnConflictStrategy.REPLACE)
+        fun insert(entity: SampleEntity)
+
+        @Delete
+        fun delete(entity: SampleEntity)
+    }
+
+    @Entity
+    class SampleEntity(
+        @PrimaryKey val id: String,
+    )
+
+    companion object {
+
+        private const val DB_NAME = "sample.db"
+
+        private const val CONCURRENCY = 4
+        private const val CHECK_ITERATIONS = 200
+    }
+}
diff --git a/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java b/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java
index fe0ae9a..3a0c5ca 100644
--- a/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java
+++ b/room/room-runtime/src/main/java/androidx/room/InvalidationTracker.java
@@ -104,7 +104,7 @@
     @SuppressWarnings("WeakerAccess") /* synthetic access */
     volatile SupportSQLiteStatement mCleanupStatement;
 
-    private ObservedTableTracker mObservedTableTracker;
+    private final ObservedTableTracker mObservedTableTracker;
 
     private final InvalidationLiveDataContainer mInvalidationLiveDataContainer;
 
@@ -115,6 +115,8 @@
 
     private MultiInstanceInvalidationClient mMultiInstanceInvalidationClient;
 
+    private final Object mSyncTriggersLock = new Object();
+
     /**
      * Used by the generated code.
      *
@@ -537,14 +539,13 @@
             return;
         }
         try {
-            // This method runs in a while loop because while changes are synced to db, another
-            // runnable may be skipped. If we cause it to skip, we need to do its work.
-            while (true) {
-                Lock closeLock = mDatabase.getCloseLock();
-                closeLock.lock();
-                try {
-                    // there is a potential race condition where another mSyncTriggers runnable
-                    // can start running right after we get the tables list to sync.
+            Lock closeLock = mDatabase.getCloseLock();
+            closeLock.lock();
+            try {
+                // Serialize adding and removing table trackers. This is important to avoid
+                // missing an invalidation when a transaction starts while there are pending
+                // (possibly concurrent) observer changes.
+                synchronized (mSyncTriggersLock) {
                     final int[] tablesToSync = mObservedTableTracker.getTablesToSync();
                     if (tablesToSync == null) {
                         return;
@@ -566,10 +567,9 @@
                     } finally {
                         database.endTransaction();
                     }
-                    mObservedTableTracker.onSyncCompleted();
-                } finally {
-                    closeLock.unlock();
                 }
+            } finally {
+                closeLock.unlock();
             }
         } catch (IllegalStateException | SQLiteException exception) {
             // may happen if db is closed. just log.
@@ -789,13 +789,6 @@
 
         boolean mNeedsSync;
 
-        /**
-         * After we return non-null value from getTablesToSync, we expect a onSyncCompleted before
-         * returning any non-null value from getTablesToSync.
-         * This allows us to workaround any multi-threaded state syncing issues.
-         */
-        boolean mPendingSync;
-
         ObservedTableTracker(int tableCount) {
             mTableObservers = new long[tableCount];
             mTriggerStates = new boolean[tableCount];
@@ -852,7 +845,7 @@
         }
 
         /**
-         * If this returns non-null, you must call onSyncCompleted.
+         * Returns null if no sync is needed, otherwise clears mNeedsSync and returns a copy.
          *
          * @return int[] An int array where the index for each tableId has the action for that
          * table.
@@ -860,7 +853,7 @@
         @Nullable
         int[] getTablesToSync() {
             synchronized (this) {
-                if (!mNeedsSync || mPendingSync) {
+                if (!mNeedsSync) {
                     return null;
                 }
                 final int tableCount = mTableObservers.length;
@@ -873,19 +866,8 @@
                     }
                     mTriggerStates[i] = newState;
                 }
-                mPendingSync = true;
                 mNeedsSync = false;
-                return mTriggerStateChanges;
-            }
-        }
-
-        /**
-         * if getTablesToSync returned non-null, the called should call onSyncCompleted once it
-         * is done.
-         */
-        void onSyncCompleted() {
-            synchronized (this) {
-                mPendingSync = false;
+                return mTriggerStateChanges.clone();
             }
         }
     }
diff --git a/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java b/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java
index 224815f..2106fdb 100644
--- a/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java
+++ b/room/room-runtime/src/test/java/androidx/room/ObservedTableTrackerTest.java
@@ -64,17 +64,6 @@
     }
 
     @Test
-    public void returnNullUntilSync() {
-        initState(1, 3);
-        mTracker.onAdded(4);
-        assertThat(mTracker.getTablesToSync(), is(createResponse(4, ADD)));
-        mTracker.onAdded(0);
-        assertThat(mTracker.getTablesToSync(), is(nullValue()));
-        mTracker.onSyncCompleted();
-        assertThat(mTracker.getTablesToSync(), is(createResponse(0, ADD)));
-    }
-
-    @Test
     public void multipleAdditionsDeletions() {
         initState(2, 4);
         mTracker.onAdded(2);
@@ -94,7 +83,6 @@
     private void initState(int... tableIds) {
         mTracker.onAdded(tableIds);
         mTracker.getTablesToSync();
-        mTracker.onSyncCompleted();
     }
 
     private static int[] createResponse(int... tuples) {