Mark destination navigating away from as transitioning

When navigating forward with transitions, although we don't hold the
lifecycle of the entry being kept on the back stack above CREATED during
the transition, we should still mark it as transitioning until its
transition is actually complete.

RelNote: "When navigating away from a `NavBackStackEntry` and using the
`pushWithTransition` API, the `NavigatorState` will now properly mark
the previous entry as transitioning."
Test: modified test
Bug: 172112072
Bug: 194301889
Fixes: 191870023

Change-Id: If0543dd1c20e7338078115e98b5585623f9b8f1c
diff --git a/navigation/navigation-common/api/current.txt b/navigation/navigation-common/api/current.txt
index faa0cd4..5ccc622 100644
--- a/navigation/navigation-common/api/current.txt
+++ b/navigation/navigation-common/api/current.txt
@@ -487,17 +487,14 @@
     ctor public NavigatorState();
     method public abstract androidx.navigation.NavBackStackEntry createBackStackEntry(androidx.navigation.NavDestination destination, android.os.Bundle? arguments);
     method public final kotlinx.coroutines.flow.StateFlow<java.util.List<androidx.navigation.NavBackStackEntry>> getBackStack();
-    method public final kotlinx.coroutines.flow.StateFlow<java.util.Map<androidx.navigation.NavBackStackEntry,androidx.navigation.NavigatorState.OnTransitionCompleteListener>> getTransitionsInProgress();
+    method public final kotlinx.coroutines.flow.StateFlow<java.util.Set<androidx.navigation.NavBackStackEntry>> getTransitionsInProgress();
+    method public void markTransitionComplete(androidx.navigation.NavBackStackEntry entry);
     method public void pop(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
-    method public androidx.navigation.NavigatorState.OnTransitionCompleteListener popWithTransition(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
+    method public void popWithTransition(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
     method public void push(androidx.navigation.NavBackStackEntry backStackEntry);
-    method public androidx.navigation.NavigatorState.OnTransitionCompleteListener pushWithTransition(androidx.navigation.NavBackStackEntry backStackEntry);
+    method public void pushWithTransition(androidx.navigation.NavBackStackEntry backStackEntry);
     property public final kotlinx.coroutines.flow.StateFlow<java.util.List<androidx.navigation.NavBackStackEntry>> backStack;
-    property public final kotlinx.coroutines.flow.StateFlow<java.util.Map<androidx.navigation.NavBackStackEntry,androidx.navigation.NavigatorState.OnTransitionCompleteListener>> transitionsInProgress;
-  }
-
-  public static fun interface NavigatorState.OnTransitionCompleteListener {
-    method public void onTransitionComplete();
+    property public final kotlinx.coroutines.flow.StateFlow<java.util.Set<androidx.navigation.NavBackStackEntry>> transitionsInProgress;
   }
 
   @androidx.navigation.NavOptionsDsl public final class PopUpToBuilder {
diff --git a/navigation/navigation-common/api/public_plus_experimental_current.txt b/navigation/navigation-common/api/public_plus_experimental_current.txt
index faa0cd4..5ccc622 100644
--- a/navigation/navigation-common/api/public_plus_experimental_current.txt
+++ b/navigation/navigation-common/api/public_plus_experimental_current.txt
@@ -487,17 +487,14 @@
     ctor public NavigatorState();
     method public abstract androidx.navigation.NavBackStackEntry createBackStackEntry(androidx.navigation.NavDestination destination, android.os.Bundle? arguments);
     method public final kotlinx.coroutines.flow.StateFlow<java.util.List<androidx.navigation.NavBackStackEntry>> getBackStack();
-    method public final kotlinx.coroutines.flow.StateFlow<java.util.Map<androidx.navigation.NavBackStackEntry,androidx.navigation.NavigatorState.OnTransitionCompleteListener>> getTransitionsInProgress();
+    method public final kotlinx.coroutines.flow.StateFlow<java.util.Set<androidx.navigation.NavBackStackEntry>> getTransitionsInProgress();
+    method public void markTransitionComplete(androidx.navigation.NavBackStackEntry entry);
     method public void pop(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
-    method public androidx.navigation.NavigatorState.OnTransitionCompleteListener popWithTransition(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
+    method public void popWithTransition(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
     method public void push(androidx.navigation.NavBackStackEntry backStackEntry);
-    method public androidx.navigation.NavigatorState.OnTransitionCompleteListener pushWithTransition(androidx.navigation.NavBackStackEntry backStackEntry);
+    method public void pushWithTransition(androidx.navigation.NavBackStackEntry backStackEntry);
     property public final kotlinx.coroutines.flow.StateFlow<java.util.List<androidx.navigation.NavBackStackEntry>> backStack;
-    property public final kotlinx.coroutines.flow.StateFlow<java.util.Map<androidx.navigation.NavBackStackEntry,androidx.navigation.NavigatorState.OnTransitionCompleteListener>> transitionsInProgress;
-  }
-
-  public static fun interface NavigatorState.OnTransitionCompleteListener {
-    method public void onTransitionComplete();
+    property public final kotlinx.coroutines.flow.StateFlow<java.util.Set<androidx.navigation.NavBackStackEntry>> transitionsInProgress;
   }
 
   @androidx.navigation.NavOptionsDsl public final class PopUpToBuilder {
diff --git a/navigation/navigation-common/api/restricted_current.txt b/navigation/navigation-common/api/restricted_current.txt
index faa0cd4..5ccc622 100644
--- a/navigation/navigation-common/api/restricted_current.txt
+++ b/navigation/navigation-common/api/restricted_current.txt
@@ -487,17 +487,14 @@
     ctor public NavigatorState();
     method public abstract androidx.navigation.NavBackStackEntry createBackStackEntry(androidx.navigation.NavDestination destination, android.os.Bundle? arguments);
     method public final kotlinx.coroutines.flow.StateFlow<java.util.List<androidx.navigation.NavBackStackEntry>> getBackStack();
-    method public final kotlinx.coroutines.flow.StateFlow<java.util.Map<androidx.navigation.NavBackStackEntry,androidx.navigation.NavigatorState.OnTransitionCompleteListener>> getTransitionsInProgress();
+    method public final kotlinx.coroutines.flow.StateFlow<java.util.Set<androidx.navigation.NavBackStackEntry>> getTransitionsInProgress();
+    method public void markTransitionComplete(androidx.navigation.NavBackStackEntry entry);
     method public void pop(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
-    method public androidx.navigation.NavigatorState.OnTransitionCompleteListener popWithTransition(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
+    method public void popWithTransition(androidx.navigation.NavBackStackEntry popUpTo, boolean saveState);
     method public void push(androidx.navigation.NavBackStackEntry backStackEntry);
-    method public androidx.navigation.NavigatorState.OnTransitionCompleteListener pushWithTransition(androidx.navigation.NavBackStackEntry backStackEntry);
+    method public void pushWithTransition(androidx.navigation.NavBackStackEntry backStackEntry);
     property public final kotlinx.coroutines.flow.StateFlow<java.util.List<androidx.navigation.NavBackStackEntry>> backStack;
-    property public final kotlinx.coroutines.flow.StateFlow<java.util.Map<androidx.navigation.NavBackStackEntry,androidx.navigation.NavigatorState.OnTransitionCompleteListener>> transitionsInProgress;
-  }
-
-  public static fun interface NavigatorState.OnTransitionCompleteListener {
-    method public void onTransitionComplete();
+    property public final kotlinx.coroutines.flow.StateFlow<java.util.Set<androidx.navigation.NavBackStackEntry>> transitionsInProgress;
   }
 
   @androidx.navigation.NavOptionsDsl public final class PopUpToBuilder {
diff --git a/navigation/navigation-common/src/main/java/androidx/navigation/NavigatorState.kt b/navigation/navigation-common/src/main/java/androidx/navigation/NavigatorState.kt
index eef5128..c031139 100644
--- a/navigation/navigation-common/src/main/java/androidx/navigation/NavigatorState.kt
+++ b/navigation/navigation-common/src/main/java/androidx/navigation/NavigatorState.kt
@@ -18,7 +18,6 @@
 
 import android.os.Bundle
 import androidx.annotation.RestrictTo
-import androidx.navigation.NavigatorState.OnTransitionCompleteListener
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
@@ -32,9 +31,8 @@
 public abstract class NavigatorState {
     private val backStackLock = ReentrantLock(true)
     private val _backStack: MutableStateFlow<List<NavBackStackEntry>> = MutableStateFlow(listOf())
-    private val _transitionsInProgress:
-        MutableStateFlow<Map<NavBackStackEntry, OnTransitionCompleteListener>> =
-            MutableStateFlow(mapOf())
+    private val _transitionsInProgress: MutableStateFlow<Set<NavBackStackEntry>> =
+        MutableStateFlow(setOf())
 
     /**
      * @hide
@@ -52,13 +50,11 @@
     public val backStack: StateFlow<List<NavBackStackEntry>> = _backStack.asStateFlow()
 
     /**
-     * This is the map of currently running transitions to their individual
-     * [OnTransitionCompleteListener]s. Use this map to retrieve the listener and execute the
-     * callback once the transition is complete.
+     * This is the set of currently running transitions. Use this set to retrieve the entry and call
+     * [markTransitionComplete] once the transition is complete.
      */
-    public val transitionsInProgress:
-        StateFlow<Map<NavBackStackEntry, OnTransitionCompleteListener>> =
-            _transitionsInProgress.asStateFlow()
+    public val transitionsInProgress: StateFlow<Set<NavBackStackEntry>> =
+        _transitionsInProgress.asStateFlow()
 
     /**
      * Adds the given [backStackEntry] to the [backStack].
@@ -70,15 +66,25 @@
     }
 
     /**
-     * Provides listener that once activated, adds the given [backStackEntry] to the [backStack].
+     * Adds the given [backStackEntry] to the [backStack]. This also adds the given and
+     * previous entry to the [set of in progress transitions][transitionsInProgress].
+     * Added entries have their [Lifecycle] capped at [Lifecycle.State.STARTED] until an entry is
+     * passed into the [markTransitionComplete] callback, when they are allowed to go to
+     * [Lifecycle.State.RESUMED].
+     *
+     * @see transitionsInProgress
+     * @see markTransitionComplete
+     * @see popWithTransition
      */
-    public open fun pushWithTransition(
-        backStackEntry: NavBackStackEntry
-    ): OnTransitionCompleteListener {
-        push(backStackEntry)
-        return OnTransitionCompleteListener {
-            removeInProgressTransition(backStackEntry)
+    public open fun pushWithTransition(backStackEntry: NavBackStackEntry) {
+        val previousEntry = backStack.value.lastOrNull()
+        // When navigating, we need to mark the outgoing entry as transitioning until it
+        // finishes its outgoing animation.
+        if (previousEntry != null) {
+            _transitionsInProgress.value = _transitionsInProgress.value + previousEntry
         }
+        _transitionsInProgress.value = _transitionsInProgress.value + backStackEntry
+        push(backStackEntry)
     }
 
     /**
@@ -100,50 +106,46 @@
     }
 
     /**
-     * Provides listener that once activated, Pops all destinations up to and including [popUpTo].
+     * Pops all destinations up to and including [popUpTo]. This also adds the given and
+     * incoming entry to the [set of in progress transitions][transitionsInProgress]. Added
+     * entries have their [Lifecycle] held at [Lifecycle.State.CREATED] until an entry is
+     * passed into the [markTransitionComplete] callback, when they are allowed to go to
+     * [Lifecycle.State.DESTROYED] and have their state cleared.
      *
      * This will remove those destinations from the [backStack], saving their state if
      * [saveState] is `true`.
+     *
+     * @see transitionsInProgress
+     * @see markTransitionComplete
+     * @see pushWithTransition
      */
-    public open fun popWithTransition(
-        popUpTo: NavBackStackEntry,
-        saveState: Boolean
-    ): OnTransitionCompleteListener {
-        val listener = OnTransitionCompleteListener {
-            removeInProgressTransition(popUpTo)
+    public open fun popWithTransition(popUpTo: NavBackStackEntry, saveState: Boolean) {
+        _transitionsInProgress.value = _transitionsInProgress.value + popUpTo
+        val incomingEntry = backStack.value.lastOrNull { entry ->
+            entry != popUpTo &&
+                backStack.value.lastIndexOf(entry) < backStack.value.lastIndexOf(popUpTo)
+        }
+        // When popping, we need to mark the incoming entry as transitioning so we keep it
+        // STARTED until the transition completes at which point we can move it to RESUMED
+        if (incomingEntry != null) {
+            _transitionsInProgress.value = _transitionsInProgress.value + incomingEntry
         }
         pop(popUpTo, saveState)
-        return listener
     }
 
     /**
-     * Adds a transition listener to the group of in progress transitions.
+     * This removes the given [NavBackStackEntry] from the [set of the transitions in
+     * progress][transitionsInProgress]. This should be called in conjunction with
+     * [pushWithTransition] and [popWithTransition] as those call are responsible for adding
+     * entries to [transitionsInProgress].
      *
-     * @hide
+     * Failing to call this method could result in entries being prevented from reaching their
+     * final [Lifecycle.State]}.
+     *
+     * @see pushWithTransition
+     * @see popWithTransition
      */
-    @RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
-    public fun addInProgressTransition(
-        entry: NavBackStackEntry,
-        listener: OnTransitionCompleteListener
-    ) {
-        _transitionsInProgress.value = _transitionsInProgress.value + (entry to listener)
-    }
-
-    /**
-     * @hide
-     */
-    @RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
-    public fun removeInProgressTransition(entry: NavBackStackEntry) {
+    public open fun markTransitionComplete(entry: NavBackStackEntry) {
         _transitionsInProgress.value = _transitionsInProgress.value - entry
     }
-
-    /**
-     * OnTransitionCompleteListener receives a callback when a destination transition is complete.
-     */
-    public fun interface OnTransitionCompleteListener {
-        /**
-         * Callback for when the transition has completed.
-         */
-        public fun onTransitionComplete()
-    }
 }
diff --git a/navigation/navigation-compose/src/main/java/androidx/navigation/compose/ComposeNavigator.kt b/navigation/navigation-compose/src/main/java/androidx/navigation/compose/ComposeNavigator.kt
index a35f2d0..84a8da4 100644
--- a/navigation/navigation-compose/src/main/java/androidx/navigation/compose/ComposeNavigator.kt
+++ b/navigation/navigation-compose/src/main/java/androidx/navigation/compose/ComposeNavigator.kt
@@ -60,6 +60,18 @@
     }
 
     /**
+     * Callback that removes the given [NavBackStackEntry] from the [map of the transitions in
+     * progress][transitionsInProgress]. This should be called in conjunction with [navigate] and
+     * [popBackStack] as those call are responsible for adding entries to [transitionsInProgress].
+     *
+     * Failing to call this method could result in entries being prevented from reaching their
+     * final [Lifecycle.State]}.
+     */
+    internal fun onTransitionComplete(entry: NavBackStackEntry) {
+        state.markTransitionComplete(entry)
+    }
+
+    /**
      * NavDestination specific to [ComposeNavigator]
      */
     @NavDestination.ClassType(Composable::class)
diff --git a/navigation/navigation-compose/src/main/java/androidx/navigation/compose/NavHost.kt b/navigation/navigation-compose/src/main/java/androidx/navigation/compose/NavHost.kt
index 62dc370..986422c 100644
--- a/navigation/navigation-compose/src/main/java/androidx/navigation/compose/NavHost.kt
+++ b/navigation/navigation-compose/src/main/java/androidx/navigation/compose/NavHost.kt
@@ -118,7 +118,7 @@
     val backStack by composeNavigator.backStack.collectAsState()
     val transitionsInProgress by composeNavigator.transitionsInProgress.collectAsState()
 
-    val backStackEntry = transitionsInProgress.keys.lastOrNull { entry ->
+    val backStackEntry = transitionsInProgress.lastOrNull { entry ->
         entry.lifecycle.currentState.isAtLeast(Lifecycle.State.STARTED)
     } ?: backStack.lastOrNull { entry ->
         entry.lifecycle.currentState.isAtLeast(Lifecycle.State.STARTED)
@@ -137,13 +137,13 @@
                     // There's no animation for the initial crossfade,
                     // so we can instantly mark the transition as complete
                     transitionsInProgress.forEach { entry ->
-                        entry.value.onTransitionComplete()
+                        composeNavigator.onTransitionComplete(entry)
                     }
                     initialCrossfade = false
                 }
                 onDispose {
                     transitionsInProgress.forEach { entry ->
-                        entry.value.onTransitionComplete()
+                        composeNavigator.onTransitionComplete(entry)
                     }
                 }
             }
diff --git a/navigation/navigation-runtime/src/androidTest/java/androidx/navigation/NavBackStackEntryTest.kt b/navigation/navigation-runtime/src/androidTest/java/androidx/navigation/NavBackStackEntryTest.kt
index b061ee6..63c95d2 100644
--- a/navigation/navigation-runtime/src/androidTest/java/androidx/navigation/NavBackStackEntryTest.kt
+++ b/navigation/navigation-runtime/src/androidTest/java/androidx/navigation/NavBackStackEntryTest.kt
@@ -388,10 +388,55 @@
             .isTrue()
     }
 
-    private fun createNavController(): NavController {
+    @Suppress("DEPRECATION")
+    @UiThreadTest
+    @Test
+    fun testOnClearedWhenHostClearedAfterSaveStateWithTransitions() {
+        val hostStore = ViewModelStore()
+        val navController = createNavController(true)
+        navController.setViewModelStore(hostStore)
+        val navGraph = navController.navigatorProvider.navigation(
+            id = 1,
+            startDestination = R.id.start_test
+        ) {
+            test(R.id.start_test)
+        }
+        navController.setGraph(navGraph, null)
+
+        val owner = navController.getBackStackEntry(R.id.start_test)
+        assertThat(owner).isNotNull()
+        val viewModel: TestAndroidViewModel = ViewModelProvider(owner).get()
+        assertThat(viewModel.isCleared).isFalse()
+
+        // Navigate to a new instance of start_test, popping the previous one and saving state
+        navController.navigate(
+            R.id.start_test,
+            null,
+            navOptions {
+                popUpTo(R.id.start_test) {
+                    inclusive = true
+                    saveState = true
+                }
+            }
+        )
+        val newEntry = navController.getBackStackEntry(R.id.start_test)
+        navController.navigatorProvider[TestNavigator::class].onTransitionComplete(newEntry)
+
+        assertWithMessage("ViewModel should be saved when the destination is saved")
+            .that(viewModel.isCleared)
+            .isFalse()
+
+        hostStore.clear()
+
+        assertWithMessage("ViewModel should be cleared when the host is cleared")
+            .that(viewModel.isCleared)
+            .isTrue()
+    }
+
+    private fun createNavController(withTransitions: Boolean = false): NavController {
         val navController = NavHostController(ApplicationProvider.getApplicationContext())
         navController.setLifecycleOwner(TestLifecycleOwner())
-        val navigator = TestNavigator()
+        val navigator = TestNavigator(withTransitions)
         navController.navigatorProvider.addNavigator(navigator)
         return navController
     }
diff --git a/navigation/navigation-runtime/src/main/java/androidx/navigation/NavController.kt b/navigation/navigation-runtime/src/main/java/androidx/navigation/NavController.kt
index de23adb..b80a225 100644
--- a/navigation/navigation-runtime/src/main/java/androidx/navigation/NavController.kt
+++ b/navigation/navigation-runtime/src/main/java/androidx/navigation/NavController.kt
@@ -181,6 +181,7 @@
         mutableMapOf<Navigator<out NavDestination>, NavControllerNavigatorState>()
     private var addToBackStackHandler: ((backStackEntry: NavBackStackEntry) -> Unit)? = null
     private var popFromBackStackHandler: ((popUpTo: NavBackStackEntry) -> Unit)? = null
+    private val entrySavedState = mutableMapOf<NavBackStackEntry, Boolean>()
 
     /**
      * Call [Navigator.navigate] while setting up a [handler] that receives callbacks
@@ -269,43 +270,28 @@
             }
         }
 
-        override fun pushWithTransition(
-            backStackEntry: NavBackStackEntry
-        ): OnTransitionCompleteListener {
-            val innerListener = super.pushWithTransition(backStackEntry)
-            val listener = OnTransitionCompleteListener {
-                innerListener.onTransitionComplete()
-                if (!this@NavControllerNavigatorState.isNavigating) {
-                    updateBackStackLifecycle()
-                }
-            }
-            addInProgressTransition(backStackEntry, listener)
-            return listener
+        override fun popWithTransition(popUpTo: NavBackStackEntry, saveState: Boolean) {
+            super.popWithTransition(popUpTo, saveState)
+            entrySavedState[popUpTo] = saveState
         }
 
-        override fun popWithTransition(
-            popUpTo: NavBackStackEntry,
-            saveState: Boolean
-        ): OnTransitionCompleteListener {
-            // we need to mark the entry as transitioning before making the super call to pop so
-            // we don't move its lifecycle to DESTROYED.
-            addInProgressTransition(popUpTo) { }
-            val innerListener = super.popWithTransition(popUpTo, saveState)
-            val listener = OnTransitionCompleteListener {
-                innerListener.onTransitionComplete()
-                if (backQueue.contains(popUpTo)) {
-                    updateBackStackLifecycle()
-                } else {
-                    // If the entry is no longer part of the backStack, we need to manually move
-                    // it to DESTROYED, and clear its view model
-                    popUpTo.maxLifecycle = Lifecycle.State.DESTROYED
-                    if (!saveState) {
-                        viewModel?.clear(popUpTo.id)
-                    }
+        override fun markTransitionComplete(entry: NavBackStackEntry) {
+            val savedState = entrySavedState[entry] == true
+            super.markTransitionComplete(entry)
+            entrySavedState.remove(entry)
+            if (!backQueue.contains(entry)) {
+                // If the entry is no longer part of the backStack, we need to manually move
+                // it to DESTROYED, and clear its view model
+                entry.maxLifecycle = Lifecycle.State.DESTROYED
+                if (!savedState) {
+                    viewModel?.clear(entry.id)
                 }
+                updateBackStackLifecycle()
+            } else if (!this@NavControllerNavigatorState.isNavigating) {
+                updateBackStackLifecycle()
             }
-            addInProgressTransition(popUpTo, listener)
-            return listener
+            // else, updateBackStackLifecycle() will be called after any ongoing navigate() call
+            // completes
         }
     }
 
@@ -604,17 +590,7 @@
         val navigator = navigatorProvider
             .getNavigator<Navigator<NavDestination>>(entry.destination.navigatorName)
         val state = navigatorState[navigator]
-        val transitioning = state?.transitionsInProgress?.value?.containsKey(entry)
-        // When popping, we need to mark the incoming entry as transitioning so we keep it
-        // STARTED until the transition completes at which point we can move it to RESUMED
-        if (backQueue.isNotEmpty() && transitioning == true) {
-            state.addInProgressTransition(backQueue.last()) {
-                state.removeInProgressTransition(backQueue.last())
-                if (!state.isNavigating) {
-                    updateBackStackLifecycle()
-                }
-            }
-        }
+        val transitioning = state?.transitionsInProgress?.value?.contains(entry)
         if (entry.lifecycle.currentState.isAtLeast(Lifecycle.State.CREATED)) {
             if (saveState) {
                 // Move the state through STOPPED
@@ -862,7 +838,7 @@
                     val navigator = navigatorProvider
                         .getNavigator<Navigator<*>>(entry.destination.navigatorName)
                     val state = navigatorState[navigator]
-                    val transitioning = state?.transitionsInProgress?.value?.containsKey(entry)
+                    val transitioning = state?.transitionsInProgress?.value?.contains(entry)
                     if (transitioning != true) {
                         upwardStateTransitions[entry] = Lifecycle.State.RESUMED
                     } else {
diff --git a/navigation/navigation-testing/src/androidTest/java/androidx/navigation/testing/TestNavigatorStateTest.kt b/navigation/navigation-testing/src/androidTest/java/androidx/navigation/testing/TestNavigatorStateTest.kt
index 0685776..7be84d4 100644
--- a/navigation/navigation-testing/src/androidTest/java/androidx/navigation/testing/TestNavigatorStateTest.kt
+++ b/navigation/navigation-testing/src/androidTest/java/androidx/navigation/testing/TestNavigatorStateTest.kt
@@ -103,7 +103,7 @@
         assertThat(firstEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.STARTED)
 
-        state.transitionsInProgress.value[firstEntry]?.onTransitionComplete()
+        state.markTransitionComplete(firstEntry)
         assertThat(firstEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.RESUMED)
 
@@ -113,23 +113,38 @@
             .isEqualTo(Lifecycle.State.CREATED)
         assertThat(secondEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.STARTED)
+        assertThat(state.transitionsInProgress.value.contains(firstEntry)).isTrue()
 
-        state.transitionsInProgress.value[secondEntry]?.onTransitionComplete()
+        state.markTransitionComplete(firstEntry)
+        state.markTransitionComplete(secondEntry)
         assertThat(secondEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.RESUMED)
 
-        navigator.popBackStack(secondEntry, false)
+        navigator.popBackStack(secondEntry, true)
         assertThat(secondEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.CREATED)
         assertThat(firstEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.STARTED)
 
-        state.transitionsInProgress.value[firstEntry]?.onTransitionComplete()
+        state.markTransitionComplete(firstEntry)
         assertThat(firstEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.RESUMED)
-        state.transitionsInProgress.value[secondEntry]?.onTransitionComplete()
+        state.markTransitionComplete(secondEntry)
         assertThat(secondEntry.lifecycle.currentState)
             .isEqualTo(Lifecycle.State.DESTROYED)
+
+        val restoredSecondEntry = state.restoreBackStackEntry(secondEntry)
+        navigator.navigate(listOf(restoredSecondEntry), null, null)
+        assertThat(firstEntry.lifecycle.currentState)
+            .isEqualTo(Lifecycle.State.CREATED)
+        assertThat(restoredSecondEntry.lifecycle.currentState)
+            .isEqualTo(Lifecycle.State.STARTED)
+        assertThat(state.transitionsInProgress.value.contains(firstEntry)).isTrue()
+
+        state.markTransitionComplete(firstEntry)
+        state.markTransitionComplete(restoredSecondEntry)
+        assertThat(restoredSecondEntry.lifecycle.currentState)
+            .isEqualTo(Lifecycle.State.RESUMED)
     }
 
     @Navigator.Name("test")
diff --git a/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt b/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt
index eef9c77..c146ce31 100644
--- a/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt
+++ b/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt
@@ -27,7 +27,6 @@
 import androidx.navigation.NavDestination
 import androidx.navigation.NavViewModelStoreProvider
 import androidx.navigation.NavigatorState
-import androidx.navigation.NavigatorState.OnTransitionCompleteListener
 import kotlinx.coroutines.CoroutineDispatcher
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.runBlocking
@@ -66,6 +65,7 @@
     }
 
     private val savedStates = mutableMapOf<String, Bundle>()
+    private val entrySavedState = mutableMapOf<NavBackStackEntry, Boolean>()
 
     override fun createBackStackEntry(
         destination: NavDestination,
@@ -96,19 +96,6 @@
         updateMaxLifecycle()
     }
 
-    override fun pushWithTransition(
-        backStackEntry: NavBackStackEntry
-    ): OnTransitionCompleteListener {
-        val innerListener = super.pushWithTransition(backStackEntry)
-        val listener = OnTransitionCompleteListener {
-            innerListener.onTransitionComplete()
-            updateMaxLifecycle()
-        }
-        addInProgressTransition(backStackEntry, listener)
-        updateMaxLifecycle()
-        return listener
-    }
-
     override fun pop(popUpTo: NavBackStackEntry, saveState: Boolean) {
         val beforePopList = backStack.value
         val poppedList = beforePopList.subList(beforePopList.indexOf(popUpTo), beforePopList.size)
@@ -116,35 +103,20 @@
         updateMaxLifecycle(poppedList, saveState)
     }
 
-    override fun popWithTransition(
-        popUpTo: NavBackStackEntry,
-        saveState: Boolean
-    ): OnTransitionCompleteListener {
-        // Get the entry that will be incoming after we have popped all the way up to the desired
-        // entry.
-        // We need to do this before we call popWithTransition because for the TestNavigatorState
-        // pop is called immediately, which would cause the entry to immediately go to RESUMED.
-        val incomingEntry = backStack.value.lastOrNull { entry ->
-            entry != popUpTo &&
-                backStack.value.lastIndexOf(entry) < backStack.value.lastIndexOf(popUpTo)
+    override fun popWithTransition(popUpTo: NavBackStackEntry, saveState: Boolean) {
+        super.popWithTransition(popUpTo, saveState)
+        entrySavedState[popUpTo] = saveState
+    }
+
+    override fun markTransitionComplete(entry: NavBackStackEntry) {
+        val savedState = entrySavedState[entry] == true
+        super.markTransitionComplete(entry)
+        entrySavedState.remove(entry)
+        if (!backStack.value.contains(entry)) {
+            updateMaxLifecycle(listOf(entry), savedState)
+        } else {
+            updateMaxLifecycle()
         }
-        // When popping, we need to mark the incoming entry as transitioning so we keep it
-        // STARTED until the transition completes at which point we can move it to RESUMED
-        if (incomingEntry != null) {
-            addInProgressTransition(incomingEntry) {
-                removeInProgressTransition(incomingEntry)
-                updateMaxLifecycle()
-            }
-        }
-        addInProgressTransition(popUpTo) { }
-        val innerListener = super.popWithTransition(popUpTo, saveState)
-        val listener = OnTransitionCompleteListener {
-            innerListener.onTransitionComplete()
-            updateMaxLifecycle(listOf(popUpTo))
-        }
-        addInProgressTransition(popUpTo, listener)
-        updateMaxLifecycle()
-        return listener
     }
 
     private fun updateMaxLifecycle(
@@ -158,16 +130,22 @@
             withContext(Dispatchers.Main.immediate) {
                 // Mark all removed NavBackStackEntries as DESTROYED
                 for (entry in poppedList.reversed()) {
-                    if (saveState) {
+                    if (
+                        saveState &&
+                        entry.lifecycle.currentState.isAtLeast(Lifecycle.State.STARTED)
+                    ) {
                         // Move the NavBackStackEntry to the stopped state, then save its state
                         entry.maxLifecycle = Lifecycle.State.CREATED
                         val savedState = Bundle()
                         entry.saveState(savedState)
                         savedStates[entry.id] = savedState
                     }
-                    val transitioning = transitionsInProgress.value.containsKey(entry)
+                    val transitioning = transitionsInProgress.value.contains(entry)
                     if (!transitioning) {
                         entry.maxLifecycle = Lifecycle.State.DESTROYED
+                        if (!saveState) {
+                            savedStates.remove(entry.id)
+                        }
                     } else {
                         entry.maxLifecycle = Lifecycle.State.CREATED
                     }
@@ -176,7 +154,7 @@
                 val currentList = backStack.value
                 var previousEntry: NavBackStackEntry? = null
                 for (entry in currentList.reversed()) {
-                    val transitioning = transitionsInProgress.value.containsKey(entry)
+                    val transitioning = transitionsInProgress.value.contains(entry)
                     entry.maxLifecycle = when {
                         previousEntry == null ->
                             if (!transitioning) {
diff --git a/testutils/testutils-navigation/src/main/java/androidx/testutils/TestNavigator.kt b/testutils/testutils-navigation/src/main/java/androidx/testutils/TestNavigator.kt
index f4b4933..6adce96 100644
--- a/testutils/testutils-navigation/src/main/java/androidx/testutils/TestNavigator.kt
+++ b/testutils/testutils-navigation/src/main/java/androidx/testutils/TestNavigator.kt
@@ -18,13 +18,15 @@
 
 import androidx.navigation.NavBackStackEntry
 import androidx.navigation.NavDestination
+import androidx.navigation.NavOptions
 import androidx.navigation.Navigator
 
 /**
  * A simple Navigator that doesn't actually navigate anywhere, but does dispatch correctly
  */
 @Navigator.Name("test")
-open class TestNavigator : Navigator<TestNavigator.Destination>() {
+open class TestNavigator(private val hasTransitions: Boolean = false) :
+    Navigator<TestNavigator.Destination>() {
 
     val backStack: List<NavBackStackEntry>
         get() = state.backStack.value
@@ -41,6 +43,32 @@
         return Destination(this)
     }
 
+    override fun navigate(
+        entries: List<NavBackStackEntry>,
+        navOptions: NavOptions?,
+        navigatorExtras: Extras?
+    ) {
+        entries.forEach { entry ->
+            if (hasTransitions) {
+                state.pushWithTransition(entry)
+            } else {
+                state.push(entry)
+            }
+        }
+    }
+
+    override fun popBackStack(popUpTo: NavBackStackEntry, savedState: Boolean) {
+        if (hasTransitions) {
+            state.popWithTransition(popUpTo, savedState)
+        } else {
+            super.popBackStack(popUpTo, savedState)
+        }
+    }
+
+    public fun onTransitionComplete(entry: NavBackStackEntry) {
+        state.markTransitionComplete(entry)
+    }
+
     /**
      * A simple Test destination
      */