Merge "Add transform for `SuspendFunction`s as parameters" into androidx-main
diff --git a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt
index 7842512..2f097a2 100644
--- a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt
+++ b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt
@@ -46,6 +46,7 @@
 import io.reactivex.Maybe
 import io.reactivex.Single
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.runBlocking
 import java.util.Date
 
 @Dao
@@ -434,6 +435,18 @@
 
     suspend fun concreteSuspendFunctionWithParams(num: Int, text: String) = "$num - $text"
 
+    @Transaction // Non-suspending transaction method that accepts a suspend lambda.
+    fun functionWithSuspendFunctionalParam(
+        input: Book,
+        action: suspend (input: Book) -> Book
+    ): Book = runBlocking { action(input) } // Blocks the calling thread until [action] returns.
+
+    @Transaction // Suspending transaction method that accepts a suspend lambda.
+    suspend fun suspendFunctionWithSuspendFunctionalParam(
+        input: Book,
+        action: suspend (input: Book) -> Book
+    ): Book = action(input) // Already suspending, so [action] is invoked directly.
+
     // This is a private method to validate b/194706278
     private fun getNullAuthor(): Author? = null
 }
diff --git a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt
index 1bae59f..4ac4f3a 100644
--- a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt
+++ b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt
@@ -31,6 +31,7 @@
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.LargeTest
 import com.google.common.truth.Truth.assertThat
+import com.google.common.truth.Truth.assertWithMessage
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ObsoleteCoroutinesApi
 import kotlinx.coroutines.async
@@ -875,4 +876,100 @@
             }
         }
     }
-}
\ No newline at end of file
+
+    @Test
+    fun transactionFunctionWithSuspendFunctionalParamCommits() = runBlocking {
+        // GIVEN a publisher and one of its books stored in the database
+        val publisher = TestUtil.PUBLISHER
+        val storedBook = TestUtil.BOOK_1.copy(bookPublisherId = publisher.publisherId)
+        booksDao.addPublishers(publisher)
+        booksDao.addBooks(storedBook)
+
+        // WHEN the blocking transaction function runs a suspend lambda that deletes the book
+        val outcome = kotlin.runCatching {
+            booksDao.functionWithSuspendFunctionalParam(storedBook) { target ->
+                booksDao.deleteBookSuspend(target)
+                target
+            }
+        }
+
+        // THEN the call succeeded and the deletion has been committed
+        assertWithMessage("The higher-order fun ran successfully")
+            .that(outcome.isSuccess)
+            .isEqualTo(true)
+        val remainingBooks = booksDao.getBooksSuspend()
+        assertThat(remainingBooks).doesNotContain(storedBook)
+    }
+
+    @Test
+    fun transactionFunctionWithSuspendFunctionalParamDoesntCommitWhenError() = runBlocking {
+        // GIVEN a publisher and one of its books stored in the database
+        val publisher = TestUtil.PUBLISHER
+        val storedBook = TestUtil.BOOK_1.copy(bookPublisherId = publisher.publisherId)
+        booksDao.addPublishers(publisher)
+        booksDao.addBooks(storedBook)
+
+        // WHEN the suspend lambda deletes the book and then throws inside the transaction
+        val outcome = kotlin.runCatching {
+            booksDao.functionWithSuspendFunctionalParam(storedBook) { target ->
+                booksDao.deleteBookSuspend(target)
+                error("Fake error in transaction")
+            }
+        }
+
+        // THEN the failure was captured and the deletion hasn't been committed
+        assertWithMessage("RunCatching caught the thrown error")
+            .that(outcome.isFailure)
+            .isEqualTo(true)
+        val remainingBooks = booksDao.getBooksSuspend()
+        assertThat(remainingBooks).contains(storedBook)
+    }
+
+    @Test
+    fun suspendTransactionFunctionWithSuspendFunctionalParamCommits() = runBlocking {
+        // GIVEN a database with a book
+        val bookPublisher = TestUtil.PUBLISHER
+        val addedBook = TestUtil.BOOK_1.copy(bookPublisherId = bookPublisher.publisherId)
+        booksDao.addPublishers(bookPublisher)
+        booksDao.addBooks(addedBook)
+
+        // WHEN the *suspend* transaction function is run (not the blocking variant)
+        val output = kotlin.runCatching {
+            booksDao.suspendFunctionWithSuspendFunctionalParam(addedBook) { book ->
+                booksDao.deleteBookSuspend(book)
+                return@suspendFunctionWithSuspendFunctionalParam book
+            }
+        }
+
+        // THEN the change has been committed
+        assertWithMessage("The higher-order fun ran successfully")
+            .that(output.isSuccess)
+            .isEqualTo(true)
+        assertThat(booksDao.getBooksSuspend())
+            .doesNotContain(addedBook)
+    }
+
+    @Test
+    fun suspendTransactionFunctionWithSuspendFunctionalParamDoesntCommitWhenError() = runBlocking {
+        // GIVEN a publisher and one of its books stored in the database
+        val publisher = TestUtil.PUBLISHER
+        val storedBook = TestUtil.BOOK_1.copy(bookPublisherId = publisher.publisherId)
+        booksDao.addPublishers(publisher)
+        booksDao.addBooks(storedBook)
+
+        // WHEN the lambda deletes the book and then throws inside the suspend transaction
+        val outcome = runCatching {
+            booksDao.suspendFunctionWithSuspendFunctionalParam(storedBook) { target ->
+                booksDao.deleteBookSuspend(target)
+                error("Fake error in transaction")
+            }
+        }
+
+        // THEN the failure was captured and the deletion hasn't been committed
+        assertWithMessage("RunCatching caught the thrown error")
+            .that(outcome.isFailure)
+            .isEqualTo(true)
+        val remainingBooks = booksDao.getBooksSuspend()
+        assertThat(remainingBooks).contains(storedBook)
+    }
+}
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
index f8c7623..8e89adfe 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
@@ -37,7 +37,7 @@
 import com.squareup.javapoet.TypeName
 import com.squareup.javapoet.TypeVariableName
 import com.squareup.javapoet.WildcardTypeName
-import kotlin.IllegalStateException
+import kotlin.coroutines.Continuation
 
 // Catch-all type name when we cannot resolve to anything. This is what KAPT uses as error type
 // and we use the same type in KSP for consistency.
@@ -184,15 +184,21 @@
     typeArgumentTypeLookup: TypeArgumentTypeLookup
 ): TypeName {
     return if (this.arguments.isNotEmpty()) {
-        val args: Array<TypeName> = this.arguments.mapIndexed { index, typeArg ->
-            typeArg.typeName(
-                param = this.declaration.typeParameters[index],
-                resolver = resolver,
-                typeArgumentTypeLookup = typeArgumentTypeLookup
-            )
-        }.map {
-            it.tryBox()
-        }.toTypedArray()
+        val args: Array<TypeName> = this.arguments
+            .mapIndexed { index, typeArg ->
+                typeArg.typeName(
+                    param = this.declaration.typeParameters[index],
+                    resolver = resolver,
+                    typeArgumentTypeLookup = typeArgumentTypeLookup
+                )
+            }
+            .map { it.tryBox() }
+            .let { args ->
+                if (this.isSuspendFunctionType) args.convertToSuspendSignature()
+                else args
+            }
+            .toTypedArray()
+
         when (
             val typeName = declaration
                 .typeName(resolver, typeArgumentTypeLookup).tryBox()
@@ -210,6 +216,29 @@
 }
 
 /**
+ * Rewrites [this] list of type arguments of a suspend functional type into the arguments of the
+ * equivalent FunctionX type. The trailing entry of the list is the declared return type; it is
+ * replaced by a [Continuation] parameter followed by a wildcard return type:
+ *
+ * FunctionX<[? super $params], ? super Continuation<? super $ReturnType>, ?>
+ */
+private fun List<TypeName>.convertToSuspendSignature(): List<TypeName> {
+    // Every entry before the final one is a real parameter of the suspend function type.
+    val parameterTypes = subList(0, lastIndex)
+    val declaredReturnType = last()
+
+    // Continuation<? super ReturnType>
+    val continuation = ParameterizedTypeName.get(
+        ClassName.get(Continuation::class.java),
+        WildcardTypeName.supertypeOf(declaredReturnType)
+    )
+    val continuationParam = WildcardTypeName.supertypeOf(continuation)
+    val wildcardReturn = WildcardTypeName.subtypeOf(TypeName.OBJECT)
+
+    return parameterTypes + listOf(continuationParam, wildcardReturn)
+}
+
+/**
  * Root package comes as <root> instead of "" so we work around it here.
  */
 internal fun KSDeclaration.getNormalizedPackageName(): String {
@@ -275,4 +304,4 @@
 ): TypeVariableName = typeVarNameConstructor.newInstance(
     name,
     bounds
-) as TypeVariableName
\ No newline at end of file
+) as TypeVariableName
diff --git a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt
index f76ad64..43a0b27 100644
--- a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt
+++ b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt
@@ -190,6 +190,22 @@
                 }
                 fun singleArg_returnsInterface(operation: (Int) -> MyInterface) {
                 }
+
+                fun noArg_suspend_returnsUnit(operation: suspend () -> Unit) {
+                }
+
+                suspend fun suspend_noArg_suspend_returnsUnit(operation: suspend () -> Unit) {
+                }
+                suspend fun suspend_no_arg_suspend_returnsString(operation: suspend () -> String) {
+                }
+                suspend fun suspend_singleArg_suspend_returnsUnit(operation: suspend (arg: String) -> Unit) {
+                }
+                suspend fun suspend_threeArgs_suspend_returnsUnit(operation: suspend (one: String, two: Int, three: Boolean) -> Unit) {
+                }
+                suspend fun suspend_singleArg_suspend_returnsString(operation: suspend (arg: String) -> String) {
+                }
+                suspend fun suspend_threeArgs_suspend_returnsString(operation: suspend (one: String, two: Int, three: Boolean) -> String) {
+                }
             }
             """.trimIndent()
         )