Merge "Additional tests for refresh key behavior" into androidx-main
diff --git a/room/room-paging/src/androidTest/kotlin/androidx/room/paging/LimitOffsetPagingSourceTest.kt b/room/room-paging/src/androidTest/kotlin/androidx/room/paging/LimitOffsetPagingSourceTest.kt
index 9cdc433..e96da43 100644
--- a/room/room-paging/src/androidTest/kotlin/androidx/room/paging/LimitOffsetPagingSourceTest.kt
+++ b/room/room-paging/src/androidTest/kotlin/androidx/room/paging/LimitOffsetPagingSourceTest.kt
@@ -78,9 +78,8 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         runBlocking {
             // count query is executed on first load
-            pagingSource.load(
-                createLoadParam(LoadType.REFRESH)
-            )
+            pagingSource.refresh()
+
             assertThat(pagingSource.itemCount.get()).isEqualTo(100)
         }
     }
@@ -94,9 +93,7 @@
         )
         runBlocking {
             // count query is executed on first load
-            pagingSource.load(
-                createLoadParam(LoadType.REFRESH)
-            )
+            pagingSource.refresh()
             // should be 60 instead of 100
             assertThat(pagingSource.itemCount.get()).isEqualTo(60)
         }
@@ -108,7 +105,7 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         runBlocking {
             // load once to register db observers
-            pagingSource.load(createLoadParam(LoadType.REFRESH))
+            pagingSource.refresh()
             assertThat(pagingSource.invalid).isFalse()
             // paging source should be invalidated when insert into db
             val result = dao.addTestItem(TestItem(101))
@@ -124,7 +121,7 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         runBlocking {
             // load once to register db observers
-            pagingSource.load(createLoadParam(LoadType.REFRESH))
+            pagingSource.refresh()
             assertThat(pagingSource.invalid).isFalse()
             // paging source should be invalidated when delete from db
             dao.deleteTestItem(TestItem(50))
@@ -139,7 +136,7 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         runBlocking {
             // load once to register db observers
-            pagingSource.load(createLoadParam(LoadType.REFRESH))
+            pagingSource.refresh()
             assertThat(pagingSource.invalid).isFalse()
 
             val result = dao.deleteTestItem(TestItem(1000))
@@ -155,18 +152,14 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         runBlocking {
             // test empty load
-            var result = pagingSource.load(
-                createLoadParam(LoadType.REFRESH)
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh()
 
             assertTrue(result.data.isEmpty())
             // now add data
             dao.addAllItems(itemsList)
-            result = pagingSource.load(
-                createLoadParam(LoadType.REFRESH)
-            ) as PagingSource.LoadResult.Page
+            val result2 = pagingSource.refresh()
 
-            assertThat(result.data).containsExactlyElementsIn(
+            assertThat(result2.data).containsExactlyElementsIn(
                 itemsList.subList(0, 15)
             )
         }
@@ -178,12 +171,8 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         // refresh with initial key = 20
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 20,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh(key = 20)
+
             // item in pos 21-35 (TestItemId 20-34) loaded
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(20, 35)
@@ -199,12 +188,8 @@
             queryString = "SELECT * FROM $tableName ORDER BY id ASC LIMIT 10 OFFSET 30",
         )
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = null,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh()
+
             // default initial loadSize = 15 starting from index 0.
             // user supplied limit offset should cause initial loadSize = 10, starting from index 30
             assertThat(result.data).containsExactlyElementsIn(
@@ -228,12 +213,8 @@
         )
         // refresh with initial key = 40
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 40,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh(key = 40)
+
             // initial loadSize = 15, but limited by id < 50, should only load items 40 - 50
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(40, 50)
@@ -255,12 +236,8 @@
                     "ORDER BY id ASC",
         )
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = null,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh()
+
             assertThat(result.data).containsExactly(itemsList[90])
             assertThat(pagingSource.itemCount.get()).isEqualTo(1)
         }
@@ -274,12 +251,8 @@
             queryString = "SELECT * FROM $tableName ORDER BY id ASC LIMIT 10 OFFSET 500",
         )
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = null,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh()
+
             // invalid OFFSET = 500 should return empty data
             assertThat(result.data).isEmpty()
 
@@ -300,12 +273,8 @@
             queryString = "SELECT * FROM $tableName ORDER BY id ASC LIMIT -1",
         )
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = null,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh()
+
             // ensure that it respects SQLite's default behavior for negative LIMIT
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(0, 15)
@@ -323,12 +292,8 @@
     fun invalidInitialKey_dbEmpty_returnsEmpty() {
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 101,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh(key = 101)
+
             assertThat(result.data).isEmpty()
         }
     }
@@ -338,12 +303,8 @@
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         dao.addAllItems(itemsList)
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 101,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh(key = 101)
+
             // should load the last page
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(85, 100)
@@ -358,12 +319,7 @@
         runBlocking {
             // should throw error when initial key is negative
             val expectedException = assertFailsWith<IllegalArgumentException> {
-                pagingSource.load(
-                    createLoadParam(
-                        LoadType.REFRESH,
-                        key = -1,
-                    )
-                )
+                pagingSource.refresh(key = -1)
             }
             // default message from Paging 3 for negative initial key
             assertThat(expectedException.message).isEqualTo(
@@ -379,12 +335,8 @@
         // to bypass check for initial load and run as non-initial load
         pagingSource.itemCount.set(100)
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = 20,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.append(key = 20)
+
             // item in pos 21-25 (TestItemId 20-24) loaded
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(20, 25)
@@ -401,12 +353,8 @@
         // to bypass check for initial load and run as non-initial load
         pagingSource.itemCount.set(100)
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = 97,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.append(key = 97)
+
             // item in pos 98-100 (TestItemId 97-99) loaded
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(97, 100)
@@ -424,23 +372,15 @@
         pagingSource.itemCount.set(100)
         runBlocking {
             // first prepend
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = 30,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.append(key = 30)
+
             // TestItemId 30-34 loaded
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(30, 35)
             )
             // second prepend using nextKey from previous load
-            val result2 = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = result.nextKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result2 = pagingSource.append(key = result.nextKey)
+
             // TestItemId 35 - 39 loaded
             assertThat(result2.data).containsExactlyElementsIn(
                 itemsList.subList(35, 40)
@@ -455,12 +395,7 @@
         // to bypass check for initial load and run as non-initial load
         pagingSource.itemCount.set(100)
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.PREPEND,
-                    key = 30,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.prepend(key = 30)
 
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(25, 30)
@@ -477,12 +412,8 @@
         // to bypass check for initial load and run as non-initial load
         pagingSource.itemCount.set(100)
         runBlocking {
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.PREPEND,
-                    key = 3,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.prepend(key = 3)
+
             // items in pos 0 - 2 (TestItemId 0 - 2) loaded
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(0, 3)
@@ -500,23 +431,15 @@
         pagingSource.itemCount.set(100)
         runBlocking {
             // first prepend
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.PREPEND,
-                    key = 20,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.prepend(key = 20)
+
             // items pos 16-20 (TestItemId 15-19) loaded
             assertThat(result.data).containsExactlyElementsIn(
                 itemsList.subList(15, 20)
             )
             // second prepend using prevKey from previous load
-            val result2 = pagingSource.load(
-                createLoadParam(
-                    LoadType.PREPEND,
-                    key = result.prevKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result2 = pagingSource.prepend(key = result.prevKey)
+
             // items pos 11-15 (TestItemId 10 - 14) loaded
             assertThat(result2.data).containsExactlyElementsIn(
                 itemsList.subList(10, 15)
@@ -530,32 +453,20 @@
         dao.addAllItems(itemsList)
         runBlocking {
             // for initial load
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 50,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh(key = 50)
+
             // initial loads items in pos 51 - 65, should have 50 items before
             assertThat(result.itemsBefore).isEqualTo(50)
 
             // prepend from initial load
-            val result2 = pagingSource.load(
-                createLoadParam(
-                    LoadType.PREPEND,
-                    key = result.prevKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result2 = pagingSource.prepend(key = result.prevKey)
+
             // prepend loads items in pos 46 - 50, should have 45 item before
             assertThat(result2.itemsBefore).isEqualTo(45)
 
             // append from initial load
-            val result3 = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = result.nextKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result3 = pagingSource.append(key = result.nextKey)
+
             // append loads items in position 66 - 70 , should have 65 item before
             assertThat(result3.itemsBefore).isEqualTo(65)
         }
@@ -567,32 +478,20 @@
         dao.addAllItems(itemsList)
         runBlocking {
             // for initial load
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 30,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh(key = 30)
+
             // initial loads items in position 31 - 45, should have 55 items after
             assertThat(result.itemsAfter).isEqualTo(55)
 
             // prepend from initial load
-            val result2 = pagingSource.load(
-                createLoadParam(
-                    LoadType.PREPEND,
-                    key = result.prevKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result2 = pagingSource.prepend(key = result.prevKey)
+
             // prepend loads items in position 26 - 30, should have 70 item after
             assertThat(result2.itemsAfter).isEqualTo(70)
 
             // append from initial load
-            val result3 = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = result.nextKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result3 = pagingSource.append(result.nextKey)
+
             // append loads items in position 46 - 50 , should have 50 item after
             assertThat(result3.itemsAfter).isEqualTo(50)
         }
@@ -604,12 +503,7 @@
         dao.addAllItems(itemsList)
         runBlocking {
             // initial load
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = null,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result = pagingSource.refresh()
             // 15 items loaded, assuming anchorPosition = 14 as the last item loaded
             var refreshKey = pagingSource.getRefreshKey(
                 PagingState(
@@ -625,12 +519,7 @@
             assertThat(refreshKey).isEqualTo(7)
 
             // append after refresh
-            val result2 = pagingSource.load(
-                createLoadParam(
-                    LoadType.APPEND,
-                    key = result.nextKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            val result2 = pagingSource.append(key = result.nextKey)
 
             assertThat(result2.data).isEqualTo(
                 itemsList.subList(15, 20)
@@ -651,23 +540,39 @@
     }
 
     @Test
-    fun refreshKey_largerThanItemCount() {
+    fun load_refreshKeyGreaterThanItemCount_lastPage() {
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         dao.addAllItems(itemsList)
         runBlocking {
-            // initial load, assume getRefreshKey returned invalid large key due to large number of
-            // items dropped
-            val result = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = 250,
-                )
-            ) as PagingSource.LoadResult.Page
 
-            // should load last page
-            assertThat(result.data).containsExactlyElementsIn(
-                itemsList.subList(85, 100)
+            pagingSource.refresh(key = 70)
+
+            dao.deleteTestItems(40, 100)
+
+            // assume user was viewing near the end of the refresh load with anchorPosition = 85,
+            // initialLoadSize = 15. This mimics how getRefreshKey() calculates refresh key.
+            val refreshKey = 85 - (15 / 2)
+            assertThat(refreshKey).isEqualTo(78)
+
+            val pagingSource2 = LimitOffsetPagingSourceImpl(database)
+            val result2 = pagingSource2.refresh(key = refreshKey)
+
+            // database should only have 40 items left. Refresh key is invalid at this point
+            // (greater than item count after deletion)
+            assertThat(pagingSource2.itemCount.get()).isEqualTo(40)
+            // ensure that paging source can handle invalid refresh key properly
+            // should load last page with items 25 - 39
+            assertThat(result2.data).containsExactlyElementsIn(
+                itemsList.subList(25, 40)
             )
+
+            // should account for updated item count to return correct itemsBefore, itemsAfter,
+            // prevKey, nextKey
+            assertThat(result2.itemsBefore).isEqualTo(25)
+            assertThat(result2.itemsAfter).isEqualTo(0)
+            // no append can be triggered
+            assertThat(result2.prevKey).isEqualTo(25)
+            assertThat(result2.nextKey).isEqualTo(null)
         }
     }
 
@@ -683,38 +588,23 @@
      * Ideally, in the future Paging will be able to handle this case better.
      */
     @Test
-    fun refreshKey_topItemsDeleted_loadFromBeginning() {
+    fun load_refreshKeyGreaterThanItemCount_firstPage() {
         val pagingSource = LimitOffsetPagingSourceImpl(database)
         dao.addAllItems(itemsList)
         runBlocking {
-            val result1 = pagingSource.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = null,
-                )
-            ) as PagingSource.LoadResult.Page
+            pagingSource.refresh()
 
             assertThat(pagingSource.itemCount.get()).isEqualTo(100)
 
             // items id 0 - 29 deleted (30 items removed)
-            dao.deleteTestItems()
+            dao.deleteTestItems(0, 29)
 
             val pagingSource2 = LimitOffsetPagingSourceImpl(database)
-            // assume user was viewing first few items with anchorPosition = 0
-            val refreshKey = pagingSource.getRefreshKey(
-                PagingState(
-                    pages = listOf(result1),
-                    anchorPosition = 0,
-                    config = CONFIG,
-                    leadingPlaceholderCount = 0,
-                )
-            )
-            val result2 = pagingSource2.load(
-                createLoadParam(
-                    LoadType.REFRESH,
-                    key = refreshKey,
-                )
-            ) as PagingSource.LoadResult.Page
+            // assume user was viewing first few items with anchorPosition = 0 and refresh key
+            // clips to 0
+            val refreshKey = 0
+
+            val result2 = pagingSource2.refresh(key = refreshKey)
 
             // database should only have 70 items left
             assertThat(pagingSource2.itemCount.get()).isEqualTo(70)
@@ -722,6 +612,48 @@
             assertThat(result2.data).containsExactlyElementsIn(
                 itemsList.subList(30, 45)
             )
+
+            // should account for updated item count to return correct itemsBefore, itemsAfter,
+            // prevKey, nextKey
+            assertThat(result2.itemsBefore).isEqualTo(0)
+            assertThat(result2.itemsAfter).isEqualTo(55)
+            // no prepend can be triggered
+            assertThat(result2.prevKey).isEqualTo(null)
+            assertThat(result2.nextKey).isEqualTo(15)
+        }
+    }
+
+    @Test
+    fun load_loadSizeAndRefreshKeyGreaterThanItemCount() {
+        val pagingSource = LimitOffsetPagingSourceImpl(database)
+        dao.addAllItems(itemsList)
+        runBlocking {
+
+            pagingSource.refresh(key = 30)
+
+            assertThat(pagingSource.itemCount.get()).isEqualTo(100)
+            // items id 0 - 94 deleted (95 items removed)
+            dao.deleteTestItems(0, 94)
+
+            val pagingSource2 = LimitOffsetPagingSourceImpl(database)
+            // assume refresh key clips to 0. Note initialLoadSize = 15 is still greater than
+            // the 5 items left in the database
+            val refreshKey = 0
+
+            val result2 = pagingSource2.refresh(key = refreshKey)
+
+            // database should only have 5 items left
+            assertThat(pagingSource2.itemCount.get()).isEqualTo(5)
+            // only 5 items should be loaded with offset = 0
+            assertThat(result2.data).containsExactlyElementsIn(
+                itemsList.subList(95, 100)
+            )
+
+            // should recognize that this is a terminal load
+            assertThat(result2.itemsBefore).isEqualTo(0)
+            assertThat(result2.itemsAfter).isEqualTo(0)
+            assertThat(result2.prevKey).isEqualTo(null)
+            assertThat(result2.nextKey).isEqualTo(null)
         }
     }
 
@@ -771,6 +703,39 @@
         }
     }
 
+    private suspend fun PagingSource<Int, TestItem>.refresh(
+        key: Int? = null,
+    ): PagingSource.LoadResult.Page<Int, TestItem> {
+        return this.load(
+            createLoadParam(
+                loadType = LoadType.REFRESH,
+                key = key,
+            )
+        ) as PagingSource.LoadResult.Page
+    }
+
+    private suspend fun PagingSource<Int, TestItem>.append(
+        key: Int? = -1,
+    ): PagingSource.LoadResult.Page<Int, TestItem> {
+        return this.load(
+            createLoadParam(
+                loadType = LoadType.APPEND,
+                key = key,
+            )
+        ) as PagingSource.LoadResult.Page
+    }
+
+    private suspend fun PagingSource<Int, TestItem>.prepend(
+        key: Int? = -1,
+    ): PagingSource.LoadResult.Page<Int, TestItem> {
+        return this.load(
+            createLoadParam(
+                loadType = LoadType.PREPEND,
+                key = key,
+            )
+        ) as PagingSource.LoadResult.Page
+    }
+
     companion object {
         val CONFIG = PagingConfig(
             pageSize = 5,
diff --git a/room/room-paging/src/androidTest/kotlin/androidx/room/paging/TestItemDao.kt b/room/room-paging/src/androidTest/kotlin/androidx/room/paging/TestItemDao.kt
index 6b50fcda..8863f6f 100644
--- a/room/room-paging/src/androidTest/kotlin/androidx/room/paging/TestItemDao.kt
+++ b/room/room-paging/src/androidTest/kotlin/androidx/room/paging/TestItemDao.kt
@@ -32,6 +32,6 @@
     @Delete
     fun deleteTestItem(testItem: TestItem): Int
 
-    @Query("DELETE FROM TestItem WHERE id < 30")
-    fun deleteTestItems(): Int
+    @Query("DELETE FROM TestItem WHERE id >= :start AND id <= :end")
+    fun deleteTestItems(start: Int, end: Int): Int
 }
\ No newline at end of file