Cache field access on internal sets and maps in runtime

Most iteration methods in identity maps and sets are accessing internal storage fields very frequently (e.g. on each iteration of removeIf loop). This change caches field accesses to local variables, reducing the cost.

Test: existing collection tests
Change-Id: I842cb2e016adfe68cf2b003e7b421f92a6321b9f
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayIntMap.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayIntMap.kt
index 272450c..5644c62 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayIntMap.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayIntMap.kt
@@ -17,18 +17,14 @@
 package androidx.compose.runtime.collection
 
 import androidx.compose.runtime.identityHashCode
-import kotlin.contracts.ExperimentalContracts
 
-@OptIn(ExperimentalContracts::class)
 internal class IdentityArrayIntMap {
-    @PublishedApi
     internal var size = 0
-
-    @PublishedApi
+        private set
     internal var keys: Array<Any?> = arrayOfNulls(4)
-
-    @PublishedApi
+        private set
     internal var values: IntArray = IntArray(4)
+        private set
 
     operator fun get(key: Any): Int {
         val index = find(key)
@@ -38,6 +34,8 @@
      * Add [value] to the map and return `-1` if it was added or previous value if it already existed.
      */
     fun add(key: Any, value: Int): Int {
+        val values = values
+
         val index: Int
         if (size > 0) {
             index = find(key)
@@ -52,6 +50,8 @@
 
         val insertIndex = -(index + 1)
 
+        val keys = keys
+        val size = size
         if (size == keys.size) {
             val newKeys = arrayOfNulls<Any>(keys.size * 2)
             val newValues = IntArray(keys.size * 2)
@@ -75,8 +75,8 @@
                 destination = newValues,
                 endIndex = insertIndex
             )
-            keys = newKeys
-            values = newValues
+            this.keys = newKeys
+            this.values = newValues
         } else {
             keys.copyInto(
                 destination = keys,
@@ -91,9 +91,9 @@
                 endIndex = size
             )
         }
-        keys[insertIndex] = key
-        values[insertIndex] = value
-        size++
+        this.keys[insertIndex] = key
+        this.values[insertIndex] = value
+        this.size++
 
         return -1
     }
@@ -103,6 +103,10 @@
      */
     fun remove(key: Any): Boolean {
         val index = find(key)
+
+        val keys = keys
+        val values = values
+        val size = size
         if (index >= 0) {
             if (index < size - 1) {
                 keys.copyInto(
@@ -118,8 +122,9 @@
                     endIndex = size
                 )
             }
-            size--
-            keys[size] = null
+            val newSize = size - 1
+            keys[newSize] = null
+            this.size = newSize
             return true
         }
         return false
@@ -129,6 +134,10 @@
      * Removes all values that match [predicate].
      */
     inline fun removeValueIf(predicate: (Any, Int) -> Boolean) {
+        val keys = keys
+        val values = values
+        val size = size
+
         var destinationIndex = 0
         for (i in 0 until size) {
             @Suppress("UNCHECKED_CAST")
@@ -145,10 +154,14 @@
         for (i in destinationIndex until size) {
             keys[i] = null
         }
-        size = destinationIndex
+        this.size = destinationIndex
     }
 
     inline fun any(predicate: (Any, Int) -> Boolean): Boolean {
+        val keys = keys
+        val values = values
+        val size = size
+
         for (i in 0 until size) {
             if (predicate(keys[i] as Any, values[i])) return true
         }
@@ -156,6 +169,10 @@
     }
 
     inline fun forEach(block: (Any, Int) -> Unit) {
+        val keys = keys
+        val values = values
+        val size = size
+
         for (i in 0 until size) {
             block(keys[i] as Any, values[i])
         }
@@ -170,6 +187,7 @@
         var high = size - 1
         val valueIdentity = identityHashCode(key)
 
+        val keys = keys
         while (low <= high) {
             val mid = (low + high).ushr(1)
             val midVal = keys[mid]
@@ -192,6 +210,9 @@
      * be returned, which is always after the last item with the same [identityHashCode].
      */
     private fun findExactIndex(midIndex: Int, value: Any?, valueHash: Int): Int {
+        val keys = keys
+        val size = size
+
         // hunt down first
         for (i in midIndex - 1 downTo 0) {
             val v = keys[i]
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayMap.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayMap.kt
index 13d31df..cb6d619 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayMap.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArrayMap.kt
@@ -20,8 +20,11 @@
 
 internal class IdentityArrayMap<Key : Any, Value : Any?>(capacity: Int = 16) {
     internal var keys = arrayOfNulls<Any?>(capacity)
+        private set
     internal var values = arrayOfNulls<Any?>(capacity)
+        private set
     internal var size = 0
+        private set
 
     fun isEmpty() = size == 0
     fun isNotEmpty() = size > 0
@@ -35,6 +38,10 @@
     }
 
     operator fun set(key: Key, value: Value) {
+        val keys = keys
+        val values = values
+        val size = size
+
         val index = find(key)
         if (index >= 0) {
             values[index] = value
@@ -57,7 +64,7 @@
                 )
             }
             destKeys[insertIndex] = key
-            keys = destKeys
+            this.keys = destKeys
             val destValues = if (resize) {
                 arrayOfNulls(size * 2)
             } else values
@@ -74,8 +81,8 @@
                 )
             }
             destValues[insertIndex] = value
-            values = destValues
-            size++
+            this.values = destValues
+            this.size++
         }
     }
 
@@ -158,6 +165,7 @@
         var low = 0
         var high = size - 1
 
+        val keys = keys
         while (low <= high) {
             val mid = (low + high).ushr(1)
             val midKey = keys[mid]
@@ -180,6 +188,9 @@
      * be returned, which is always after the last key with the same [identityHashCode].
      */
     private fun findExactIndex(midIndex: Int, key: Any?, keyHash: Int): Int {
+        val keys = keys
+        val size = size
+
         // hunt down first
         for (i in midIndex - 1 downTo 0) {
             val k = keys[i]
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt
index 35e9a2b..29721ca 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityArraySet.kt
@@ -53,6 +53,9 @@
      */
     fun add(value: T): Boolean {
         val index: Int
+        val size = size
+        val values = values
+
         if (size > 0) {
             index = find(value)
 
@@ -77,7 +80,7 @@
                 destination = newSorted,
                 endIndex = insertIndex
             )
-            values = newSorted
+            this.values = newSorted
         } else {
             values.copyInto(
                 destination = values,
@@ -86,8 +89,8 @@
                 endIndex = size
             )
         }
-        values[insertIndex] = value
-        size++
+        this.values[insertIndex] = value
+        this.size++
         return true
     }
 
@@ -96,7 +99,6 @@
      */
     fun clear() {
         values.fill(null)
-
         size = 0
     }
 
@@ -144,7 +146,7 @@
             } else {
                 // slow path, merge this and other values
                 val newArray = if (needsResize) {
-                    arrayOfNulls(combinedSize)
+                    arrayOfNulls(if (thisSize > otherSize) thisSize * 2 else otherSize * 2)
                 } else {
                     thisValues
                 }
@@ -242,6 +244,9 @@
      */
     fun remove(value: T): Boolean {
         val index = find(value)
+        val values = values
+        val size = size
+
         if (index >= 0) {
             if (index < size - 1) {
                 values.copyInto(
@@ -251,8 +256,8 @@
                     endIndex = size
                 )
             }
-            size--
-            values[size] = null
+            values[size - 1] = null
+            this.size--
             return true
         }
         return false
@@ -262,6 +267,9 @@
      * Removes all values that match [predicate].
      */
     inline fun removeValueIf(predicate: (T) -> Boolean) {
+        val values = values
+        val size = size
+
         var destinationIndex = 0
         for (i in 0 until size) {
             @Suppress("UNCHECKED_CAST")
@@ -276,7 +284,7 @@
         for (i in destinationIndex until size) {
             values[i] = null
         }
-        size = destinationIndex
+        this.size = destinationIndex
     }
 
     /**
@@ -287,10 +295,11 @@
         var low = 0
         var high = size - 1
         val valueIdentity = identityHashCode(value)
+        val values = values
 
         while (low <= high) {
             val mid = (low + high).ushr(1)
-            val midVal = get(mid)
+            val midVal = values[mid]
             val midIdentity = identityHashCode(midVal)
             when {
                 midIdentity < valueIdentity -> low = mid + 1
@@ -309,7 +318,14 @@
      * If no match is found, the negative index - 1 of the position in which it would be will
      * be returned, which is always after the last item with the same [identityHashCode].
      */
-    private fun findExactIndex(midIndex: Int, value: Any?, valueHash: Int): Int {
+    private fun findExactIndex(
+        midIndex: Int,
+        value: Any?,
+        valueHash: Int
+    ): Int {
+        val values = values
+        val size = size
+
         // hunt down first
         for (i in midIndex - 1 downTo 0) {
             val v = values[i]
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt
index cf102a2..f94b85f 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt
@@ -54,14 +54,6 @@
     internal var size = 0
 
     /**
-     * Returns the value at the given [index] order in the map.
-     */
-    @Suppress("NOTHING_TO_INLINE")
-    private inline fun valueAt(index: Int): Any {
-        return values[valueOrder[index]]!!
-    }
-
-    /**
      * Returns the [IdentityArraySet] for the value at the given [index] order in the map.
      */
     private fun scopeSetAt(index: Int): IdentityArraySet<T> {
@@ -97,6 +89,11 @@
      * and insertes it into the map and returns it.
      */
     private fun getOrCreateIdentitySet(value: Any): IdentityArraySet<T> {
+        val size = size
+        val valueOrder = valueOrder
+        val values = values
+        val scopeSets = scopeSets
+
         val index: Int
         if (size > 0) {
             index = find(value)
@@ -127,18 +124,18 @@
                 )
             }
             valueOrder[insertIndex] = valueIndex
-            size++
+            this.size++
             return scopeSet
         }
 
         // We have to increase the size of all arrays
         val newSize = valueOrder.size * 2
         val valueIndex = size
-        scopeSets = scopeSets.copyOf(newSize)
+        val newScopeSets = scopeSets.copyOf(newSize)
         val scopeSet = IdentityArraySet<T>()
-        scopeSets[valueIndex] = scopeSet
-        values = values.copyOf(newSize)
-        values[valueIndex] = value
+        newScopeSets[valueIndex] = scopeSet
+        val newValues = values.copyOf(newSize)
+        newValues[valueIndex] = value
 
         val newKeyOrder = IntArray(newSize)
         for (i in size + 1 until newSize) {
@@ -160,8 +157,10 @@
                 endIndex = insertIndex
             )
         }
-        valueOrder = newKeyOrder
-        size++
+        this.scopeSets = newScopeSets
+        this.values = newValues
+        this.valueOrder = newKeyOrder
+        this.size++
         return scopeSet
     }
 
@@ -169,7 +168,11 @@
      * Removes all values and scopes from the map
      */
     fun clear() {
-        for (i in 0 until scopeSets.size) {
+        val scopeSets = scopeSets
+        val valueOrder = valueOrder
+        val values = values
+
+        for (i in scopeSets.indices) {
             scopeSets[i]?.clear()
             valueOrder[i] = i
             values[i] = null
@@ -188,6 +191,11 @@
      */
     fun remove(value: Any, scope: T): Boolean {
         val index = find(value)
+
+        val valueOrder = valueOrder
+        val scopeSets = scopeSets
+        val values = values
+        val size = size
         if (index >= 0) {
             val valueOrderIndex = valueOrder[index]
             val set = scopeSets[valueOrderIndex] ?: return false
@@ -203,9 +211,10 @@
                         endIndex = endIndex
                     )
                 }
-                valueOrder[size - 1] = valueOrderIndex
+                val newSize = size - 1
+                valueOrder[newSize] = valueOrderIndex
                 values[valueOrderIndex] = null
-                size--
+                this.size = newSize
             }
             return removed
         }
@@ -233,6 +242,9 @@
     }
 
     private inline fun removingScopes(removalOperation: (IdentityArraySet<T>) -> Unit) {
+        val valueOrder = valueOrder
+        val scopeSets = scopeSets
+        val values = values
         var destinationIndex = 0
         for (i in 0 until size) {
             val valueIndex = valueOrder[i]
@@ -265,9 +277,11 @@
         var low = 0
         var high = size - 1
 
+        val values = values
+        val valueOrder = valueOrder
         while (low <= high) {
             val mid = (low + high).ushr(1)
-            val midValue = valueAt(mid)
+            val midValue = values[valueOrder[mid]]
             val midValHash = identityHashCode(midValue)
             when {
                 midValHash < valueIdentity -> low = mid + 1
@@ -287,9 +301,12 @@
      * be returned, which is always after the last item with the same [identityHashCode].
      */
     private fun findExactIndex(midIndex: Int, value: Any?, valueHash: Int): Int {
+        val values = values
+        val valueOrder = valueOrder
+
         // hunt down first
         for (i in midIndex - 1 downTo 0) {
-            val v = valueAt(i)
+            val v = values[valueOrder[i]]
             if (v === value) {
                 return i
             }
@@ -299,7 +316,7 @@
         }
 
         for (i in midIndex + 1 until size) {
-            val v = valueAt(i)
+            val v = values[valueOrder[i]]
             if (v === value) {
                 return i
             }