Reduce the overhead of find the current snapshot

Now uses a custom implementaiton of a thread local variable
to reduce the overhead of access the current snapshot of a
thread.

Fixes: b/182168191
Test: ./gradlew :compose:r:r:tDUT
Relnote: N/A
Change-Id: Iab53ce4ae2c4396e63581f36014c3318036368e0
diff --git a/compose/runtime/runtime/api/1.0.0-beta02.txt b/compose/runtime/runtime/api/1.0.0-beta02.txt
index a1c11df..d73a573 100644
--- a/compose/runtime/runtime/api/1.0.0-beta02.txt
+++ b/compose/runtime/runtime/api/1.0.0-beta02.txt
@@ -456,6 +456,9 @@
     method public static boolean isLiveLiteralsEnabled();
   }
 
+  public final class ThreadMapKt {
+  }
+
 }
 
 package androidx.compose.runtime.snapshots {
diff --git a/compose/runtime/runtime/api/current.txt b/compose/runtime/runtime/api/current.txt
index a1c11df..d73a573 100644
--- a/compose/runtime/runtime/api/current.txt
+++ b/compose/runtime/runtime/api/current.txt
@@ -456,6 +456,9 @@
     method public static boolean isLiveLiteralsEnabled();
   }
 
+  public final class ThreadMapKt {
+  }
+
 }
 
 package androidx.compose.runtime.snapshots {
diff --git a/compose/runtime/runtime/api/public_plus_experimental_1.0.0-beta02.txt b/compose/runtime/runtime/api/public_plus_experimental_1.0.0-beta02.txt
index f900999d..c803652 100644
--- a/compose/runtime/runtime/api/public_plus_experimental_1.0.0-beta02.txt
+++ b/compose/runtime/runtime/api/public_plus_experimental_1.0.0-beta02.txt
@@ -544,6 +544,9 @@
     property public abstract int parameters;
   }
 
+  public final class ThreadMapKt {
+  }
+
 }
 
 package androidx.compose.runtime.snapshots {
diff --git a/compose/runtime/runtime/api/public_plus_experimental_current.txt b/compose/runtime/runtime/api/public_plus_experimental_current.txt
index f900999d..c803652 100644
--- a/compose/runtime/runtime/api/public_plus_experimental_current.txt
+++ b/compose/runtime/runtime/api/public_plus_experimental_current.txt
@@ -544,6 +544,9 @@
     property public abstract int parameters;
   }
 
+  public final class ThreadMapKt {
+  }
+
 }
 
 package androidx.compose.runtime.snapshots {
diff --git a/compose/runtime/runtime/api/restricted_1.0.0-beta02.txt b/compose/runtime/runtime/api/restricted_1.0.0-beta02.txt
index a334aa4..d80e38d 100644
--- a/compose/runtime/runtime/api/restricted_1.0.0-beta02.txt
+++ b/compose/runtime/runtime/api/restricted_1.0.0-beta02.txt
@@ -483,6 +483,9 @@
     method public static boolean isLiveLiteralsEnabled();
   }
 
+  public final class ThreadMapKt {
+  }
+
 }
 
 package androidx.compose.runtime.snapshots {
diff --git a/compose/runtime/runtime/api/restricted_current.txt b/compose/runtime/runtime/api/restricted_current.txt
index a334aa4..d80e38d 100644
--- a/compose/runtime/runtime/api/restricted_current.txt
+++ b/compose/runtime/runtime/api/restricted_current.txt
@@ -483,6 +483,9 @@
     method public static boolean isLiveLiteralsEnabled();
   }
 
+  public final class ThreadMapKt {
+  }
+
 }
 
 package androidx.compose.runtime.snapshots {
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Expect.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Expect.kt
index 2189a52..91193ee 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Expect.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Expect.kt
@@ -26,6 +26,20 @@
 
 internal fun <T> ThreadLocal() = ThreadLocal<T?> { null }
 
+/**
+ * This is similar to a [ThreadLocal] but has lower overhead because it avoids a weak reference.
+ * This should only be used when the writes are delimited by a try...finally call that will clean
+ * up the reference such as [androidx.compose.runtime.snapshots.Snapshot.enter] else the reference
+ * could get pinned by the thread local causing a leak.
+ *
+ * [ThreadLocal] can be used to implement the actual for platforms that do not exhibit the same
+ * overhead for thread locals as the JVM and ART.
+ */
+internal expect class SnapshotThreadLocal<T>() {
+    fun get(): T?
+    fun set(value: T?)
+}
+
 internal expect fun identityHashCode(instance: Any?): Int
 
 @PublishedApi
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
index 2f73c00..2927d96 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
@@ -20,7 +20,7 @@
 
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.InternalComposeApi
-import androidx.compose.runtime.ThreadLocal
+import androidx.compose.runtime.SnapshotThreadLocal
 import androidx.compose.runtime.synchronized
 
 /**
@@ -1355,7 +1355,10 @@
  */
 private const val INVALID_SNAPSHOT = 0
 
-private val threadSnapshot = ThreadLocal<Snapshot>()
+/**
+ * Current thread snapshot
+ */
+private val threadSnapshot = SnapshotThreadLocal<Snapshot>()
 
 // A global synchronization object. This synchronization object should be taken before modifying any
 // of the fields below.
diff --git a/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/ActualJvm.jvm.kt b/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/ActualJvm.jvm.kt
index 7aa0d96..63dbcb1 100644
--- a/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/ActualJvm.jvm.kt
+++ b/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/ActualJvm.jvm.kt
@@ -16,6 +16,9 @@
 
 package androidx.compose.runtime
 
+import androidx.compose.runtime.internal.ThreadMap
+import androidx.compose.runtime.internal.emptyThreadMap
+
 internal actual typealias AtomicReference<V> = java.util.concurrent.atomic.AtomicReference<V>
 
 internal actual open class ThreadLocal<T> actual constructor(
@@ -35,6 +38,23 @@
     }
 }
 
+internal actual class SnapshotThreadLocal<T> {
+    private val map = AtomicReference<ThreadMap>(emptyThreadMap)
+    private val writeMutex = Any()
+
+    @Suppress("UNCHECKED_CAST")
+    actual fun get(): T? = map.get().get(Thread.currentThread().id) as T?
+
+    actual fun set(value: T?) {
+        val key = Thread.currentThread().id
+        synchronized(writeMutex) {
+            val current = map.get()
+            if (current.trySet(key, value)) return
+            map.set(current.newWith(key, value))
+        }
+    }
+}
+
 internal actual fun identityHashCode(instance: Any?): Int = System.identityHashCode(instance)
 
 @PublishedApi
diff --git a/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/internal/ThreadMap.kt b/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/internal/ThreadMap.kt
new file mode 100644
index 0000000..d4645eb
--- /dev/null
+++ b/compose/runtime/runtime/src/jvmMain/kotlin/androidx/compose/runtime/internal/ThreadMap.kt
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.runtime.internal
+
+internal class ThreadMap(
+    private val size: Int,
+    private val keys: LongArray,
+    private val values: Array<Any?>
+) {
+    fun get(key: Long): Any? {
+        val index = find(key)
+        return if (index >= 0) values[index] else null
+    }
+
+    /**
+     * Set the value if it is already in the map. Otherwise a new map must be allocated to contain
+     * the new entry.
+     */
+    fun trySet(key: Long, value: Any?): Boolean {
+        val index = find(key)
+        if (index < 0) return false
+        values[index] = value
+        return true
+    }
+
+    fun newWith(key: Long, value: Any?): ThreadMap {
+        val size = size
+        val newSize = values.count { it != null } + 1
+        val newKeys = LongArray(newSize)
+        val newValues = arrayOfNulls<Any?>(newSize)
+        if (newSize > 1) {
+            var dest = 0
+            var source = 0
+            while (dest < newSize && source < size) {
+                val oldKey = keys[source]
+                val oldValue = values[source]
+                if (oldKey > key) {
+                    newKeys[dest] = key
+                    newValues[dest] = value
+                    dest++
+                    // Continue with a loop without this check
+                    break
+                }
+                if (oldValue != null) {
+                    newKeys[dest] = oldKey
+                    newValues[dest] = oldValue
+                    dest++
+                }
+                source++
+            }
+            if (source == size) {
+                // Appending a value to the end.
+                newKeys[newSize - 1] = key
+                newValues[newSize - 1] = value
+            } else {
+                while (dest < newSize) {
+                    val oldKey = keys[source]
+                    val oldValue = values[source]
+                    if (oldValue != null) {
+                        newKeys[dest] = oldKey
+                        newValues[dest] = oldValue
+                        dest++
+                    }
+                    source++
+                }
+            }
+        } else {
+            // The only element
+            newKeys[0] = key
+            newValues[0] = value
+        }
+        return ThreadMap(newSize, newKeys, newValues)
+    }
+
+    private fun find(key: Long): Int {
+        var high = size - 1
+        when (high) {
+            -1 -> return -1
+            0 -> return if (keys[0] == key) 0 else if (keys[0] > key) -2 else -1
+        }
+        var low = 0
+
+        while (low <= high) {
+            val mid = (low + high).ushr(1)
+            val midVal = keys[mid]
+            val comparison = midVal - key
+            when {
+                comparison < 0 -> low = mid + 1
+                comparison > 0 -> high = mid - 1
+                else -> return mid
+            }
+        }
+        return -(low + 1)
+    }
+}
+
+internal val emptyThreadMap = ThreadMap(0, LongArray(0), emptyArray())
\ No newline at end of file
diff --git a/compose/runtime/runtime/src/test/kotlin/androidx/compose/runtime/snapshots/SnapshotThreadMapTests.kt b/compose/runtime/runtime/src/test/kotlin/androidx/compose/runtime/snapshots/SnapshotThreadMapTests.kt
new file mode 100644
index 0000000..a50c9a8
--- /dev/null
+++ b/compose/runtime/runtime/src/test/kotlin/androidx/compose/runtime/snapshots/SnapshotThreadMapTests.kt
@@ -0,0 +1,162 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.runtime.snapshots
+
+import androidx.compose.runtime.SnapshotThreadLocal
+import androidx.compose.runtime.internal.ThreadMap
+import kotlin.random.Random
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFalse
+import kotlin.test.assertNotEquals
+import kotlin.test.assertNotNull
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+
+/**
+ * Test the internal ThreadMap
+ */
+class SnapshotThreadMapTests {
+    @Test
+    fun canCreateAMap() {
+        val map = emptyThreadMap()
+        assertNotNull(map)
+    }
+
+    @Test
+    fun setOfEmptyFails() {
+        val map = emptyThreadMap()
+        val added = map.trySet(1, 1)
+        assertFalse(added)
+    }
+
+    @Test
+    fun canAddOneToEmpty() {
+        val map = emptyThreadMap()
+        val newMap = map.newWith(1, 1)
+        assertNotEquals(map, newMap)
+        assertEquals(1, newMap.get(1))
+    }
+
+    @Test
+    fun canCreateForward() {
+        val map = testMap(0 until 100)
+        assertNotNull(map)
+        for (i in 0 until 100) {
+            assertEquals(i, map.get(i.toLong()))
+        }
+        for (i in -100 until 0) {
+            assertNull(map.get(i.toLong()))
+        }
+        for (i in 100 until 200) {
+            assertNull(map.get(i.toLong()))
+        }
+    }
+
+    @Test
+    fun canCreateBackward() {
+        val map = testMap((0 until 100).reversed())
+        assertNotNull(map)
+        for (i in 0 until 100) {
+            assertEquals(i, map.get(i.toLong()))
+        }
+        for (i in -100 until 0) {
+            assertNull(map.get(i.toLong()))
+        }
+        for (i in 100 until 200) {
+            assertNull(map.get(i.toLong()))
+        }
+    }
+
+    @Test
+    fun canCreateRandom() {
+        val list = Array<Long>(100) { it.toLong() }
+        val rand = Random(1337)
+        list.shuffle(rand)
+        var map = emptyThreadMap()
+        for (item in list) {
+            map = map.newWith(item, item)
+        }
+        for (i in 0 until 100) {
+            assertEquals(i.toLong(), map.get(i.toLong()))
+        }
+        for (i in -100 until 0) {
+            assertNull(map.get(i.toLong()))
+        }
+        for (i in 100 until 200) {
+            assertNull(map.get(i.toLong()))
+        }
+    }
+
+    @Test
+    fun canRemoveOne() {
+        val map = testMap(1..10)
+        val set = map.trySet(5, null)
+        assertTrue(set)
+        for (i in 1..10) {
+            if (i == 5) {
+                assertNull(map.get(i.toLong()))
+            } else {
+                assertEquals(i, map.get(i.toLong()))
+            }
+        }
+    }
+
+    @Test
+    fun canRemoveOneThenAddOne() {
+        val map = testMap(1..10)
+        val set = map.trySet(5, null)
+        assertTrue(set)
+        val newMap = map.newWith(11, 11)
+        assertNull(newMap.get(5))
+        assertEquals(11, newMap.get(11))
+    }
+
+    private fun emptyThreadMap() = ThreadMap(0, LongArray(0), arrayOfNulls(0))
+
+    private fun testMap(intProgression: IntProgression): ThreadMap {
+        var result = emptyThreadMap()
+        for (i in intProgression) {
+            result = result.newWith(i.toLong(), i)
+        }
+        return result
+    }
+}
+
+/**
+ * Test the thread lcoal variable
+ */
+class SnapshotThreadLocalTests {
+    @Test
+    fun canCreate() {
+        val local = SnapshotThreadLocal<Int>()
+        assertNotNull(local)
+    }
+
+    @Test
+    fun initalValueIsNull() {
+        val local = SnapshotThreadLocal<Int>()
+        assertNull(local.get())
+    }
+
+    @Test
+    fun canSetAndGetTheValue() {
+        val local = SnapshotThreadLocal<Int>()
+        local.set(100)
+        assertEquals(100, local.get())
+    }
+}
\ No newline at end of file