Make SnapshotMutableStateImpl parcelable

And remove the custom serialization logic from DisposableSaveableStateRegistry.

Fixes: 182453625
Test: SaveableStateHolderTest, ActivityRecreationTest ParcelableMutableStateTests
Change-Id: I66a669cb3c9117f968d0b71464b12252370973e2
diff --git a/compose/runtime/runtime-saveable/src/androidAndroidTest/AndroidManifest.xml b/compose/runtime/runtime-saveable/src/androidAndroidTest/AndroidManifest.xml
index 391101b..4148f68 100644
--- a/compose/runtime/runtime-saveable/src/androidAndroidTest/AndroidManifest.xml
+++ b/compose/runtime/runtime-saveable/src/androidAndroidTest/AndroidManifest.xml
@@ -30,5 +30,6 @@
         <activity android:name="androidx.compose.runtime.saveable.RecreationTest6Activity" />
         <activity android:name="androidx.compose.runtime.saveable.RecreationTest7Activity" />
         <activity android:name="androidx.compose.runtime.saveable.RecreationTest8Activity" />
+        <activity android:name="androidx.compose.runtime.saveable.SaveableStateHolderTest$Activity" />
     </application>
 </manifest>
diff --git a/compose/runtime/runtime-saveable/src/androidAndroidTest/kotlin/androidx/compose/runtime/saveable/SaveableStateHolderTest.kt b/compose/runtime/runtime-saveable/src/androidAndroidTest/kotlin/androidx/compose/runtime/saveable/SaveableStateHolderTest.kt
index 86f5647..7170fd85 100644
--- a/compose/runtime/runtime-saveable/src/androidAndroidTest/kotlin/androidx/compose/runtime/saveable/SaveableStateHolderTest.kt
+++ b/compose/runtime/runtime-saveable/src/androidAndroidTest/kotlin/androidx/compose/runtime/saveable/SaveableStateHolderTest.kt
@@ -16,12 +16,15 @@
 
 package androidx.compose.runtime.saveable
 
+import android.os.Bundle
+import androidx.activity.ComponentActivity
+import androidx.compose.runtime.MutableState
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.test.junit4.StateRestorationTester
-import androidx.compose.ui.test.junit4.createComposeRule
+import androidx.compose.ui.test.junit4.createAndroidComposeRule
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.MediumTest
 import com.google.common.truth.Truth.assertThat
@@ -34,7 +37,7 @@
 class SaveableStateHolderTest {
 
     @get:Rule
-    val rule = createComposeRule()
+    val rule = createAndroidComposeRule<Activity>()
 
     private val restorationTester = StateRestorationTester(rule)
 
@@ -256,6 +259,44 @@
             assertThat(restorableNumberOnScreen1).isEqualTo(1)
         }
     }
+
+    @Test
+    fun restoringStateOfThePreviousPageAfterCreatingBundle() {
+        var showFirstPage by mutableStateOf(true)
+        var firstPageState: MutableState<Int>? = null
+
+        rule.setContent {
+            val holder = rememberSaveableStateHolder()
+            holder.SaveableStateProvider(showFirstPage) {
+                if (showFirstPage) {
+                    firstPageState = rememberSaveable { mutableStateOf(0) }
+                }
+            }
+        }
+
+        rule.runOnIdle {
+            assertThat(firstPageState!!.value).isEqualTo(0)
+            // change the value, so we can assert this change will be restored
+            firstPageState!!.value = 1
+            firstPageState = null
+            showFirstPage = false
+        }
+
+        rule.runOnIdle {
+            rule.activity.doFakeSave()
+            showFirstPage = true
+        }
+
+        rule.runOnIdle {
+            assertThat(firstPageState!!.value).isEqualTo(1)
+        }
+    }
+
+    class Activity : ComponentActivity() {
+        fun doFakeSave() {
+            onSaveInstanceState(Bundle())
+        }
+    }
 }
 
 enum class Screens {
diff --git a/compose/runtime/runtime/src/androidAndroidTest/kotlin/androidx/compose/runtime/snapshots/ParcelableMutableStateTests.kt b/compose/runtime/runtime/src/androidAndroidTest/kotlin/androidx/compose/runtime/snapshots/ParcelableMutableStateTests.kt
new file mode 100644
index 0000000..0cb895a
--- /dev/null
+++ b/compose/runtime/runtime/src/androidAndroidTest/kotlin/androidx/compose/runtime/snapshots/ParcelableMutableStateTests.kt
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.runtime.snapshots
+
+import android.os.Parcel
+import android.os.Parcelable
+import androidx.compose.runtime.SnapshotMutationPolicy
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.neverEqualPolicy
+import androidx.compose.runtime.referentialEqualityPolicy
+import androidx.compose.runtime.structuralEqualityPolicy
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import kotlin.test.assertEquals
+
+@RunWith(Parameterized::class)
+class ParcelableMutableStateTests(
+    private val policy: SnapshotMutationPolicy<Int>
+) {
+    @Test
+    fun saveAndRestoreTheMutableStateOf() {
+        val a = mutableStateOf(0, policy)
+        a.value = 1
+
+        val parcel = Parcel.obtain()
+        parcel.writeParcelable(a as Parcelable, 0)
+        parcel.setDataPosition(0)
+        @Suppress("UNCHECKED_CAST")
+        val restored =
+            parcel.readParcelable<Parcelable>(javaClass.classLoader) as SnapshotMutableState<Int>
+
+        assertEquals(1, restored.value)
+        assertEquals(policy, restored.policy)
+    }
+
+    companion object {
+        @JvmStatic
+        @Parameterized.Parameters(name = "{0}")
+        fun initParameters(): Array<SnapshotMutationPolicy<Int>> =
+            arrayOf(
+                structuralEqualityPolicy(),
+                referentialEqualityPolicy(),
+                neverEqualPolicy()
+            )
+    }
+}
diff --git a/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ActualAndroid.android.kt b/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ActualAndroid.android.kt
index 0043c50..7b49ff7 100644
--- a/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ActualAndroid.android.kt
+++ b/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ActualAndroid.android.kt
@@ -18,6 +18,7 @@
 
 import android.os.Looper
 import android.view.Choreographer
+import androidx.compose.runtime.snapshots.SnapshotMutableState
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.runBlocking
@@ -80,3 +81,8 @@
     if (Looper.getMainLooper() != null) DefaultChoreographerFrameClock
     else SdkStubsFallbackFrameClock
 }
+
+internal actual fun <T> createSnapshotMutableState(
+    value: T,
+    policy: SnapshotMutationPolicy<T>
+): SnapshotMutableState<T> = ParcelableSnapshotMutableState(value, policy)
diff --git a/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ParcelableSnapshotMutableState.kt b/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ParcelableSnapshotMutableState.kt
new file mode 100644
index 0000000..05ccc29
--- /dev/null
+++ b/compose/runtime/runtime/src/androidMain/kotlin/androidx/compose/runtime/ParcelableSnapshotMutableState.kt
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.runtime
+
+import android.annotation.SuppressLint
+import android.os.Parcel
+import android.os.Parcelable
+
+@SuppressLint("BanParcelableUsage")
+internal class ParcelableSnapshotMutableState<T>(
+    value: T,
+    policy: SnapshotMutationPolicy<T>
+) : SnapshotMutableStateImpl<T>(value, policy), Parcelable {
+
+    override fun writeToParcel(parcel: Parcel, flags: Int) {
+        parcel.writeValue(value)
+        parcel.writeInt(
+            when (policy) {
+                neverEqualPolicy<Any?>() -> PolicyNeverEquals
+                structuralEqualityPolicy<Any?>() -> PolicyStructuralEquality
+                referentialEqualityPolicy<Any?>() -> PolicyReferentialEquality
+                else -> throw IllegalStateException(
+                    "Only known types of MutableState's SnapshotMutationPolicy are supported"
+                )
+            }
+        )
+    }
+
+    override fun describeContents(): Int {
+        return 0
+    }
+
+    companion object {
+        private const val PolicyNeverEquals = 0
+        private const val PolicyStructuralEquality = 1
+        private const val PolicyReferentialEquality = 2
+
+        @Suppress("unused")
+        @JvmField
+        val CREATOR: Parcelable.Creator<ParcelableSnapshotMutableState<Any?>> =
+            object : Parcelable.ClassLoaderCreator<ParcelableSnapshotMutableState<Any?>> {
+                override fun createFromParcel(
+                    parcel: Parcel,
+                    loader: ClassLoader?
+                ): ParcelableSnapshotMutableState<Any?> {
+                    val value = parcel.readValue(loader ?: javaClass.classLoader)
+                    val policyIndex = parcel.readInt()
+                    return ParcelableSnapshotMutableState(
+                        value,
+                        when (policyIndex) {
+                            PolicyNeverEquals -> neverEqualPolicy()
+                            PolicyStructuralEquality -> structuralEqualityPolicy()
+                            PolicyReferentialEquality -> referentialEqualityPolicy()
+                            else -> throw IllegalStateException(
+                                "Unsupported MutableState policy $policyIndex was restored"
+                            )
+                        }
+                    )
+                }
+
+                override fun createFromParcel(parcel: Parcel) = createFromParcel(parcel, null)
+
+                override fun newArray(size: Int) =
+                    arrayOfNulls<ParcelableSnapshotMutableState<Any?>?>(size)
+            }
+    }
+}
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SnapshotState.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SnapshotState.kt
index 0622ea6..3e8396f 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SnapshotState.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SnapshotState.kt
@@ -63,7 +63,7 @@
 fun <T> mutableStateOf(
     value: T,
     policy: SnapshotMutationPolicy<T> = structuralEqualityPolicy()
-): MutableState<T> = SnapshotMutableStateImpl(value, policy)
+): MutableState<T> = createSnapshotMutableState(value, policy)
 
 /**
  * A value holder where reads to the [value] property during the execution of a [Composable]
@@ -113,17 +113,25 @@
 }
 
 /**
+ * Returns platform specific implementation based on [SnapshotMutableStateImpl].
+ */
+internal expect fun <T> createSnapshotMutableState(
+    value: T,
+    policy: SnapshotMutationPolicy<T>
+): SnapshotMutableState<T>
+
+/**
  * A single value holder whose reads and writes are observed by Compose.
  *
  * Additionally, writes to it are transacted as part of the [Snapshot] system.
  *
- * @property value the wrapped value
- * @property policy a policy to control how changes are handled in a mutable snapshot.
+ * @param value the wrapped value
+ * @param policy a policy to control how changes are handled in a mutable snapshot.
  *
  * @see mutableStateOf
  * @see SnapshotMutationPolicy
  */
-private class SnapshotMutableStateImpl<T>(
+internal open class SnapshotMutableStateImpl<T>(
     value: T,
     override val policy: SnapshotMutationPolicy<T>
 ) : StateObject, SnapshotMutableState<T> {
@@ -243,6 +251,8 @@
 
 private object ReferentialEqualityPolicy : SnapshotMutationPolicy<Any?> {
     override fun equivalent(a: Any?, b: Any?) = a === b
+
+    override fun toString() = "ReferentialEqualityPolicy"
 }
 
 /**
@@ -258,6 +268,8 @@
 
 private object StructuralEqualityPolicy : SnapshotMutationPolicy<Any?> {
     override fun equivalent(a: Any?, b: Any?) = a == b
+
+    override fun toString() = "StructuralEqualityPolicy"
 }
 
 /**
@@ -273,6 +285,8 @@
 
 private object NeverEqualPolicy : SnapshotMutationPolicy<Any?> {
     override fun equivalent(a: Any?, b: Any?) = false
+
+    override fun toString() = "NeverEqualPolicy"
 }
 
 /**
diff --git a/compose/runtime/runtime/src/desktopMain/kotlin/androidx/compose/runtime/ActualDesktop.desktop.kt b/compose/runtime/runtime/src/desktopMain/kotlin/androidx/compose/runtime/ActualDesktop.desktop.kt
index 06f7e35..b704879 100644
--- a/compose/runtime/runtime/src/desktopMain/kotlin/androidx/compose/runtime/ActualDesktop.desktop.kt
+++ b/compose/runtime/runtime/src/desktopMain/kotlin/androidx/compose/runtime/ActualDesktop.desktop.kt
@@ -16,6 +16,7 @@
 
 package androidx.compose.runtime
 
+import androidx.compose.runtime.snapshots.SnapshotMutableState
 import kotlinx.coroutines.delay
 
 internal actual object Trace {
@@ -71,3 +72,8 @@
         return onFrame(System.nanoTime())
     }
 }
+
+internal actual fun <T> createSnapshotMutableState(
+    value: T,
+    policy: SnapshotMutationPolicy<T>
+): SnapshotMutableState<T> = SnapshotMutableStateImpl(value, policy)
diff --git a/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt b/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt
index 1515b5d..04c94f8 100644
--- a/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt
+++ b/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt
@@ -18,22 +18,18 @@
 
 package androidx.compose.ui.platform
 
-import android.annotation.SuppressLint
 import android.os.Binder
 import android.os.Bundle
-import android.os.Parcel
 import android.os.Parcelable
 import android.util.Size
 import android.util.SizeF
 import android.util.SparseArray
 import android.view.View
-import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.neverEqualPolicy
 import androidx.compose.runtime.referentialEqualityPolicy
 import androidx.compose.runtime.saveable.SaveableStateRegistry
 import androidx.compose.runtime.snapshots.SnapshotMutableState
 import androidx.compose.runtime.structuralEqualityPolicy
-import androidx.compose.ui.util.fastForEachIndexed
 import androidx.savedstate.SavedStateRegistry
 import androidx.savedstate.SavedStateRegistryOwner
 import java.io.Serializable
@@ -118,11 +114,7 @@
  * Checks that [value] can be stored inside [Bundle].
  */
 private fun canBeSavedToBundle(value: Any): Boolean {
-    for (cl in AcceptableClasses) {
-        if (cl.isInstance(value)) {
-            return true
-        }
-    }
+    // SnapshotMutableStateImpl is Parcelable, but we do extra checks
     if (value is SnapshotMutableState<*>) {
         if (value.policy === neverEqualPolicy<Any?>() ||
             value.policy === structuralEqualityPolicy<Any?>() ||
@@ -130,6 +122,13 @@
         ) {
             val stateValue = value.value
             return if (stateValue == null) true else canBeSavedToBundle(stateValue)
+        } else {
+            return false
+        }
+    }
+    for (cl in AcceptableClasses) {
+        if (cl.isInstance(value)) {
+            return true
         }
     }
     return false
@@ -165,7 +164,6 @@
     val map = mutableMapOf<String, List<Any?>>()
     this.keySet().forEach { key ->
         val list = getParcelableArrayList<Parcelable?>(key) as ArrayList<Any?>
-        unwrapMutableStatesIn(list)
         map[key] = list
     }
     return map
@@ -175,7 +173,6 @@
     val bundle = Bundle()
     forEach { (key, list) ->
         val arrayList = if (list is ArrayList<Any?>) list else ArrayList(list)
-        wrapMutableStatesIn(arrayList)
         bundle.putParcelableArrayList(
             key,
             arrayList as ArrayList<Parcelable?>
@@ -183,142 +180,3 @@
     }
     return bundle
 }
-
-private fun wrapMutableStatesIn(list: MutableList<Any?>) {
-    list.fastForEachIndexed { index, value ->
-        if (value is SnapshotMutableState<*>) {
-            list[index] = ParcelableMutableStateHolder(value)
-        } else {
-            wrapMutableStatesInListOrMap(value)
-        }
-    }
-}
-
-private fun wrapMutableStatesIn(map: MutableMap<Any?, Any?>) {
-    map.forEach { (key, value) ->
-        if (value is SnapshotMutableState<*>) {
-            map[key] = ParcelableMutableStateHolder(value)
-        } else {
-            wrapMutableStatesInListOrMap(value)
-        }
-    }
-}
-
-private fun wrapMutableStatesInListOrMap(value: Any?) {
-    when (value) {
-        is MutableList<*> -> {
-            wrapMutableStatesIn(value as MutableList<Any?>)
-        }
-        is List<*> -> {
-            value.forEach {
-                check(it !is SnapshotMutableState<*>) {
-                    "Unexpected immutable list containing MutableState!"
-                }
-            }
-        }
-        is MutableMap<*, *> -> {
-            wrapMutableStatesIn(value as MutableMap<Any?, Any?>)
-        }
-        is Map<*, *> -> {
-            value.forEach {
-                check(it.value !is SnapshotMutableState<*>) {
-                    "Unexpected immutable map containing MutableState!"
-                }
-            }
-        }
-    }
-}
-
-private fun unwrapMutableStatesIn(list: MutableList<Any?>) {
-    list.fastForEachIndexed { index, value ->
-        if (value is ParcelableMutableStateHolder) {
-            list[index] = value.state
-        } else {
-            unwrapMutableStatesInListOrMap(value)
-        }
-    }
-}
-
-private fun unwrapMutableStatesIn(map: MutableMap<Any?, Any?>) {
-    map.forEach { (key, value) ->
-        if (value is ParcelableMutableStateHolder) {
-            map[key] = value.state
-        } else {
-            unwrapMutableStatesInListOrMap(value)
-        }
-    }
-}
-
-private fun unwrapMutableStatesInListOrMap(value: Any?) {
-    when (value) {
-        is MutableList<*> -> {
-            unwrapMutableStatesIn(value as MutableList<Any?>)
-        }
-        is MutableMap<*, *> -> {
-            unwrapMutableStatesIn(value as MutableMap<Any?, Any?>)
-        }
-    }
-}
-
-@SuppressLint("BanParcelableUsage")
-private class ParcelableMutableStateHolder : Parcelable {
-
-    val state: SnapshotMutableState<*>
-
-    constructor(state: SnapshotMutableState<*>) {
-        this.state = state
-    }
-
-    private constructor(parcel: Parcel, loader: ClassLoader?) {
-        val value = parcel.readValue(loader ?: javaClass.classLoader)
-        val policyIndex = parcel.readInt()
-        state = mutableStateOf(
-            value,
-            when (policyIndex) {
-                PolicyNeverEquals -> neverEqualPolicy()
-                PolicyStructuralEquality -> structuralEqualityPolicy()
-                PolicyReferentialEquality -> referentialEqualityPolicy()
-                else -> throw IllegalStateException(
-                    "Restored an incorrect MutableState policy $policyIndex"
-                )
-            }
-        ) as SnapshotMutableState
-    }
-
-    override fun writeToParcel(parcel: Parcel, flags: Int) {
-        parcel.writeValue(state.value)
-        parcel.writeInt(
-            when (state.policy) {
-                neverEqualPolicy<Any?>() -> PolicyNeverEquals
-                structuralEqualityPolicy<Any?>() -> PolicyStructuralEquality
-                referentialEqualityPolicy<Any?>() -> PolicyReferentialEquality
-                else -> throw IllegalStateException(
-                    "Only known types of MutableState's SnapshotMutationPolicy are supported"
-                )
-            }
-        )
-    }
-
-    override fun describeContents(): Int {
-        return 0
-    }
-
-    companion object {
-        private const val PolicyNeverEquals = 0
-        private const val PolicyStructuralEquality = 1
-        private const val PolicyReferentialEquality = 2
-
-        @Suppress("unused")
-        @JvmField
-        val CREATOR: Parcelable.Creator<ParcelableMutableStateHolder> =
-            object : Parcelable.ClassLoaderCreator<ParcelableMutableStateHolder> {
-                override fun createFromParcel(parcel: Parcel, loader: ClassLoader) =
-                    ParcelableMutableStateHolder(parcel, loader)
-
-                override fun createFromParcel(parcel: Parcel) =
-                    ParcelableMutableStateHolder(parcel, null)
-
-                override fun newArray(size: Int) = arrayOfNulls<ParcelableMutableStateHolder?>(size)
-            }
-    }
-}