Use IdentityArraySet to store snapshot invalidations

Updates snapshots to use IdentityArraySet for invalidations with faster allocation-less iteration.

Updates iteration in Composition and SnapshotStateObserver to use IdentityArraySet iteration if instance matches.

Partially addresses b/271109624.

Test: IdentityArraySetTest
Change-Id: I476e3c43cabd165ad8edce928aa3b46c06952779
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt
index bc7d310..ddbb91e 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt
@@ -20,6 +20,7 @@
 import androidx.compose.runtime.collection.IdentityArrayMap
 import androidx.compose.runtime.collection.IdentityArraySet
 import androidx.compose.runtime.collection.IdentityScopeMap
+import androidx.compose.runtime.collection.fastForEach
 import androidx.compose.runtime.snapshots.fastAll
 import androidx.compose.runtime.snapshots.fastAny
 import androidx.compose.runtime.snapshots.fastForEach
@@ -693,7 +694,7 @@
             }
         }
 
-        for (value in values) {
+        values.fastForEach { value ->
             if (value is RecomposeScopeImpl) {
                 value.invalidateForResult(null)
             } else {
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt
index b9d3905..8569cbf 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt
@@ -199,7 +199,7 @@
     private var runnerJob: Job? = null
     private var closeCause: Throwable? = null
     private val knownCompositions = mutableListOf<ControlledComposition>()
-    private var snapshotInvalidations = mutableSetOf<Any>()
+    private var snapshotInvalidations = IdentityArraySet<Any>()
     private val compositionInvalidations = mutableListOf<ControlledComposition>()
     private val compositionsAwaitingApply = mutableListOf<ControlledComposition>()
     private val compositionValuesAwaitingInsert = mutableListOf<MovableContentStateReference>()
@@ -280,7 +280,7 @@
     private fun deriveStateLocked(): CancellableContinuation<Unit>? {
         if (_state.value <= State.ShuttingDown) {
             knownCompositions.clear()
-            snapshotInvalidations = mutableSetOf()
+            snapshotInvalidations = IdentityArraySet()
             compositionInvalidations.clear()
             compositionsAwaitingApply.clear()
             compositionValuesAwaitingInsert.clear()
@@ -296,7 +296,7 @@
                 State.Inactive
             }
             runnerJob == null -> {
-                snapshotInvalidations = mutableSetOf()
+                snapshotInvalidations = IdentityArraySet()
                 compositionInvalidations.clear()
                 if (broadcastFrameClock.hasAwaiters) State.InactivePendingWork else State.Inactive
             }
@@ -420,7 +420,7 @@
                     if (_state.value <= State.ShuttingDown) return@run
                 }
             }
-            snapshotInvalidations = mutableSetOf()
+            snapshotInvalidations = IdentityArraySet()
             if (deriveStateLocked() != null) {
                 error("called outside of runRecomposeAndApplyChanges")
             }
@@ -435,7 +435,7 @@
             knownCompositions.fastForEach { composition ->
                 composition.recordModificationsOf(changes)
             }
-            snapshotInvalidations = mutableSetOf()
+            snapshotInvalidations = IdentityArraySet()
         }
         compositionInvalidations.fastForEach(onEachInvalidComposition)
         compositionInvalidations.clear()
@@ -657,7 +657,7 @@
 
                 compositionsAwaitingApply.clear()
                 compositionInvalidations.clear()
-                snapshotInvalidations = mutableSetOf()
+                snapshotInvalidations = IdentityArraySet()
 
                 compositionValuesAwaitingInsert.clear()
                 compositionValuesRemoved.clear()
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt
index 1f2404b..35e9a2b 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt
@@ -27,9 +27,11 @@
 @OptIn(ExperimentalContracts::class)
 internal class IdentityArraySet<T : Any> : Set<T> {
     override var size = 0
+        private set
 
     @PublishedApi
     internal var values: Array<Any?> = arrayOfNulls(16)
+        private set
 
     /**
      * Returns true if the set contains [element]
@@ -103,8 +105,125 @@
      */
     inline fun fastForEach(block: (T) -> Unit) {
         contract { callsInPlace(block) }
+        val values = values
         for (i in 0 until size) {
-            block(this[i])
+            @Suppress("UNCHECKED_CAST")
+            block(values[i] as T)
+        }
+    }
+
+    fun addAll(collection: Collection<T>) {
+        if (collection.isEmpty()) return
+
+        if (collection !is IdentityArraySet<T>) {
+            // Unknown collection, just add repeatedly
+            for (value in collection) {
+                add(value)
+            }
+        } else {
+            // Identity set, merge sorted arrays
+            val thisValues = values
+            val otherValues = collection.values
+            val thisSize = size
+            val otherSize = collection.size
+            val combinedSize = thisSize + otherSize
+
+            val needsResize = values.size < combinedSize
+            val elementsInOrder = thisSize == 0 ||
+                identityHashCode(thisValues[thisSize - 1]) < identityHashCode(otherValues[0])
+
+            if (!needsResize && elementsInOrder) {
+                // fast path, just copy target values
+                otherValues.copyInto(
+                    destination = thisValues,
+                    destinationOffset = thisSize,
+                    startIndex = 0,
+                    endIndex = otherSize
+                )
+                size += otherSize
+            } else {
+                // slow path, merge this and other values
+                val newArray = if (needsResize) {
+                    arrayOfNulls(combinedSize)
+                } else {
+                    thisValues
+                }
+                var thisIndex = thisSize - 1
+                var otherIndex = otherSize - 1
+                var nextInsertIndex = combinedSize - 1
+                while (thisIndex >= 0 || otherIndex >= 0) {
+                    val nextValue = when {
+                        thisIndex < 0 -> otherValues[otherIndex--]
+                        otherIndex < 0 -> thisValues[thisIndex--]
+                        else -> {
+                            val thisValue = thisValues[thisIndex]
+                            val otherValue = otherValues[otherIndex]
+
+                            val thisHash = identityHashCode(thisValue)
+                            val otherHash = identityHashCode(otherValue)
+                            when {
+                                thisHash > otherHash -> {
+                                    thisIndex--
+                                    thisValue
+                                }
+                                thisHash < otherHash -> {
+                                    otherIndex--
+                                    otherValue
+                                }
+                                thisValue === otherValue -> {
+                                    // hash and the value are the same, advance both pointers
+                                    thisIndex--
+                                    otherIndex--
+                                    thisValue
+                                }
+                                else -> {
+                                    // collision, lookup if the same item is in the array
+                                    var i = thisIndex - 1
+                                    var foundDuplicate = false
+                                    while (i >= 0) {
+                                        val value = thisValues[i--]
+                                        if (identityHashCode(value) != otherHash) break
+                                        if (otherValue === value) {
+                                            foundDuplicate = true
+                                            break
+                                        }
+                                    }
+
+                                    if (foundDuplicate) {
+                                        // advance pointer and continue next iteration of outer
+                                        // merge loop.
+                                        otherIndex--
+                                        continue
+                                    } else {
+                                        // didn't find the duplicate, put other item in array.
+                                        otherIndex--
+                                        otherValue
+                                    }
+                                }
+                            }
+                        }
+                    }
+
+                    // insert value and continue
+                    newArray[nextInsertIndex--] = nextValue
+                }
+
+                if (nextInsertIndex >= 0) {
+                    // some values were duplicated, copy the merged part
+                    newArray.copyInto(
+                        newArray,
+                        destinationOffset = 0,
+                        startIndex = nextInsertIndex + 1,
+                        endIndex = combinedSize
+                    )
+                }
+                // newSize = endOffset - startOffset of copy above
+                val newSize = combinedSize - (nextInsertIndex + 1)
+                newArray.fill(null, fromIndex = newSize, toIndex = combinedSize)
+
+                values = newArray
+                size = newSize
+            }
         }
     }
 
@@ -240,4 +359,16 @@
         override fun hasNext(): Boolean = index < size
         override fun next(): T = this@IdentityArraySet.values[index++] as T
     }
+
+    override fun toString(): String {
+        return joinToString(prefix = "[", postfix = "]") { it.toString() }
+    }
 }
+
+internal inline fun <T : Any> Set<T>.fastForEach(block: (T) -> Unit) {
+    if (this is IdentityArraySet<T>) {
+        fastForEach(block)
+    } else {
+        forEach(block)
+    }
+}
\ No newline at end of file
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
index 84f7d21..d66cff4 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
@@ -24,6 +24,7 @@
 import androidx.compose.runtime.ExperimentalComposeApi
 import androidx.compose.runtime.InternalComposeApi
 import androidx.compose.runtime.SnapshotThreadLocal
+import androidx.compose.runtime.collection.IdentityArraySet
 import androidx.compose.runtime.synchronized
 import kotlin.contracts.ExperimentalContracts
 import kotlin.contracts.InvocationKind
@@ -211,7 +212,7 @@
     /**
      * The set of state objects that have been modified in this snapshot.
      */
-    internal abstract val modified: MutableSet<StateObject>?
+    internal abstract val modified: IdentityArraySet<StateObject>?
 
     /**
      * Notify the snapshot that all objects created in this snapshot to this point should be
@@ -905,21 +906,21 @@
                             mergedRecords ?: mutableListOf<Pair<StateObject, StateRecord>>().also {
                                 mergedRecords = it
                             }
-                            ).add(state to current.create())
+                        ).add(state to current.create())
 
                         // If we revert to current then the state is no longer modified.
                         (
                             statesToRemove ?: mutableListOf<StateObject>().also {
                                 statesToRemove = it
                             }
-                            ).add(state)
+                        ).add(state)
                     }
                     else -> {
                         (
                             mergedRecords ?: mutableListOf<Pair<StateObject, StateRecord>>().also {
                                 mergedRecords = it
                             }
-                            ).add(
+                        ).add(
                             if (merged != previous) state to merged
                             else state to previous.create()
                         )
@@ -943,9 +944,9 @@
             }
         }
 
-        statesToRemove?.let {
+        statesToRemove?.fastForEach {
             // Remove from modified any state objects that have reverted to the parent value.
-            modified.removeAll(it)
+            modified.remove(it)
         }
 
         return SnapshotApplyResult.Success
@@ -1003,10 +1004,10 @@
     }
 
     override fun recordModified(state: StateObject) {
-        (modified ?: HashSet<StateObject>().also { modified = it }).add(state)
+        (modified ?: IdentityArraySet<StateObject>().also { modified = it }).add(state)
     }
 
-    override var modified: MutableSet<StateObject>? = null
+    override var modified: IdentityArraySet<StateObject>? = null
 
     /**
      * A set of the id's previously associated with this snapshot. When this snapshot closes
@@ -1206,7 +1207,7 @@
     override fun hasPendingChanges(): Boolean = false
     override val writeObserver: ((Any) -> Unit)? get() = null
 
-    override var modified: HashSet<StateObject>?
+    override var modified: IdentityArraySet<StateObject>?
         get() = null
         @Suppress("UNUSED_PARAMETER")
         set(value) = unsupported()
@@ -1277,7 +1278,7 @@
         }
     }
 
-    override val modified: HashSet<StateObject>? get() = null
+    override val modified: IdentityArraySet<StateObject>? get() = null
     override val writeObserver: ((Any) -> Unit)? get() = null
     override fun recordModified(state: StateObject) = reportReadonlySnapshotWrite()
 
@@ -1401,10 +1402,10 @@
 
                 // Add all modified objects in this set to the parent
                 (
-                    parent.modified ?: HashSet<StateObject>().also {
+                    parent.modified ?: IdentityArraySet<StateObject>().also {
                         parent.modified = it
                     }
-                    ).addAll(modified)
+                ).addAll(modified)
             }
 
             // Ensure the parent is newer than the current snapshot
@@ -1479,7 +1480,7 @@
 
     override fun hasPendingChanges(): Boolean = currentSnapshot.hasPendingChanges()
 
-    override var modified: MutableSet<StateObject>?
+    override var modified: IdentityArraySet<StateObject>?
         get() = currentSnapshot.modified
         @Suppress("UNUSED_PARAMETER")
         set(value) = unsupported()
@@ -1583,7 +1584,7 @@
 
     override fun hasPendingChanges(): Boolean = currentSnapshot.hasPendingChanges()
 
-    override var modified: MutableSet<StateObject>?
+    override var modified: IdentityArraySet<StateObject>?
         get() = currentSnapshot.modified
         @Suppress("UNUSED_PARAMETER")
         set(value) = unsupported()
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
index d8a6081..ad78a63 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
@@ -24,6 +24,7 @@
 import androidx.compose.runtime.collection.IdentityArrayMap
 import androidx.compose.runtime.collection.IdentityArraySet
 import androidx.compose.runtime.collection.IdentityScopeMap
+import androidx.compose.runtime.collection.fastForEach
 import androidx.compose.runtime.collection.mutableVectorOf
 import androidx.compose.runtime.composeRuntimeError
 import androidx.compose.runtime.observeDerivedStateRecalculations
@@ -509,7 +510,7 @@
          */
         fun recordInvalidation(changes: Set<Any>): Boolean {
             var hasValues = false
-            for (value in changes) {
+            changes.fastForEach { value ->
                 if (value in dependencyToDerivedStates) {
                     // Find derived state that is invalidated by this change
                     dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/collection/IdentityArraySetTest.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/collection/IdentityArraySetTest.kt
index cad7e12..617e518 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/collection/IdentityArraySetTest.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/collection/IdentityArraySetTest.kt
@@ -182,6 +182,54 @@
         assertTrue(setOfT.containsAll(listOf(stuff[0], stuff[1], stuff[2])))
     }
 
+    @Test
+    fun addAll_Collection() {
+        set.addAll(list)
+
+        assertEquals(list.size, set.size)
+        for (value in list) {
+            assertTrue(value in set)
+        }
+    }
+
+    @Test
+    fun addAll_IdentityArraySet() {
+        val anotherSet = IdentityArraySet<Stuff>()
+        anotherSet.addAll(list)
+
+        set.addAll(anotherSet)
+
+        for (value in list) {
+            assertTrue(value in set)
+        }
+
+        set.addAll(anotherSet)
+
+        assertEquals(anotherSet.size, set.size)
+        for (value in list) {
+            assertTrue(value in set)
+        }
+
+        val stuff = Array(100) { Stuff(it) }
+        for (i in 0 until 100 step 2) {
+            anotherSet.add(stuff[i])
+        }
+        set.addAll(anotherSet)
+
+        for (i in stuff.indices) {
+            val value = stuff[i]
+            if (i % 2 == 0) {
+                assertTrue(value in set, "Expected to have element $i in $set")
+            } else {
+                assertFalse(value in set, "Didn't expect to have element $i in $set")
+            }
+        }
+
+        for (value in list) {
+            assertTrue(value in set)
+        }
+    }
+
     private fun testRemoveValueAtIndex(index: Int) {
         val value = set[index]
         val initialSize = set.size
diff --git a/external/paparazzi/paparazzi/src/main/java/app/cash/paparazzi/Paparazzi.kt b/external/paparazzi/paparazzi/src/main/java/app/cash/paparazzi/Paparazzi.kt
index 0a96bcf..40ce39e 100644
--- a/external/paparazzi/paparazzi/src/main/java/app/cash/paparazzi/Paparazzi.kt
+++ b/external/paparazzi/paparazzi/src/main/java/app/cash/paparazzi/Paparazzi.kt
@@ -572,10 +572,19 @@
       val snapshotInvalidations = recomposer.javaClass
         .getDeclaredField("snapshotInvalidations")
         .apply { isAccessible = true }
-        .get(recomposer) as MutableCollection<*>
+        .get(recomposer)
       compositionInvalidations.clear()
-      snapshotInvalidations.clear()
       applyObservers.clear()
+
+      if (snapshotInvalidations is MutableCollection<*>) {
+        snapshotInvalidations.clear()
+      } else {
+        // backed by IdentityArraySet
+        snapshotInvalidations.javaClass
+          .getDeclaredMethod("clear")
+          .apply { isAccessible = true }
+          .invoke(snapshotInvalidations)
+      }
     }
 
     val dispatcher =