Add transform for `SuspendFunction`s as parameters
Currently, `KSType.typeName` doesn't consider suspend functional types
a special case. This means that `suspend () -> Unit` would get resolved
as `Function1<Unit>` and not as `Function1<Continuation<Unit>, ?>`.
This CL adds a check if the `KSType` is a suspend functional type
and adds a continuation parameter with the correct return type.
Test: MethodSpecHelperTest#kotlinParametersAsFunction, SuspendingQueryTest
Bug: b/201674894
Change-Id: I02c97b4ee87888c1494605d8fce2442e9e0ce017
diff --git a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt
index 7842512..2f097a2 100644
--- a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt
+++ b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/dao/BooksDao.kt
@@ -46,6 +46,7 @@
import io.reactivex.Maybe
import io.reactivex.Single
import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.runBlocking
import java.util.Date
@Dao
@@ -434,6 +435,18 @@
suspend fun concreteSuspendFunctionWithParams(num: Int, text: String) = "$num - $text"
+ @Transaction
+ fun functionWithSuspendFunctionalParam(
+ input: Book,
+ action: suspend (input: Book) -> Book
+ ): Book = runBlocking { action(input) }
+
+ @Transaction
+ suspend fun suspendFunctionWithSuspendFunctionalParam(
+ input: Book,
+ action: suspend (input: Book) -> Book
+ ): Book = action(input)
+
// This is a private method to validate b/194706278
private fun getNullAuthor(): Author? = null
}
diff --git a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt
index 1bae59f..4ac4f3a 100644
--- a/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt
+++ b/room/integration-tests/kotlintestapp/src/androidTest/java/androidx/room/integration/kotlintestapp/test/SuspendingQueryTest.kt
@@ -31,6 +31,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.LargeTest
import com.google.common.truth.Truth.assertThat
+import com.google.common.truth.Truth.assertWithMessage
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ObsoleteCoroutinesApi
import kotlinx.coroutines.async
@@ -875,4 +876,100 @@
}
}
}
-}
\ No newline at end of file
+
+ @Test
+ fun transactionFunctionWithSuspendFunctionalParamCommits() = runBlocking {
+ // GIVEN a database with a book
+ val bookPublisher = TestUtil.PUBLISHER
+ val addedBook = TestUtil.BOOK_1.copy(bookPublisherId = bookPublisher.publisherId)
+ booksDao.addPublishers(bookPublisher)
+ booksDao.addBooks(addedBook)
+
+ // WHEN a transaction is run
+ val output = kotlin.runCatching {
+ booksDao.functionWithSuspendFunctionalParam(addedBook) { book ->
+ booksDao.deleteBookSuspend(book)
+ return@functionWithSuspendFunctionalParam book
+ }
+ }
+
+ // THEN the change has been committed
+ assertWithMessage("The higher-order fun ran successfully")
+ .that(output.isSuccess)
+ .isEqualTo(true)
+ assertThat(booksDao.getBooksSuspend())
+ .doesNotContain(addedBook)
+ }
+
+ @Test
+ fun transactionFunctionWithSuspendFunctionalParamDoesntCommitWhenError() = runBlocking {
+ // GIVEN a database with a book
+ val bookPublisher = TestUtil.PUBLISHER
+ val addedBook = TestUtil.BOOK_1.copy(bookPublisherId = bookPublisher.publisherId)
+ booksDao.addPublishers(bookPublisher)
+ booksDao.addBooks(addedBook)
+
+ // WHEN a transaction is started and then fails before completing
+ val output = kotlin.runCatching {
+ booksDao.functionWithSuspendFunctionalParam(addedBook) { book ->
+ booksDao.deleteBookSuspend(book)
+ error("Fake error in transaction")
+ }
+ }
+
+ // THEN the change hasn't been committed
+ assertWithMessage("RunCatching caught the thrown error")
+ .that(output.isFailure)
+ .isEqualTo(true)
+ assertThat(booksDao.getBooksSuspend())
+ .contains(addedBook)
+ }
+
+ @Test
+ fun suspendTransactionFunctionWithSuspendFunctionalParamCommits() = runBlocking {
+ // GIVEN a database with a book
+ val bookPublisher = TestUtil.PUBLISHER
+ val addedBook = TestUtil.BOOK_1.copy(bookPublisherId = bookPublisher.publisherId)
+ booksDao.addPublishers(bookPublisher)
+ booksDao.addBooks(addedBook)
+
+ // WHEN a transaction is run
+ val output = kotlin.runCatching {
+ booksDao.functionWithSuspendFunctionalParam(addedBook) { book ->
+ booksDao.deleteBookSuspend(book)
+ return@functionWithSuspendFunctionalParam book
+ }
+ }
+
+ // THEN the change has been committed
+ assertWithMessage("The higher-order fun ran successfully")
+ .that(output.isSuccess)
+ .isEqualTo(true)
+ assertThat(booksDao.getBooksSuspend())
+ .doesNotContain(addedBook)
+ }
+
+ @Test
+ fun suspendTransactionFunctionWithSuspendFunctionalParamDoesntCommitWhenError() = runBlocking {
+ // GIVEN a database with a book
+ val bookPublisher = TestUtil.PUBLISHER
+ val addedBook = TestUtil.BOOK_1.copy(bookPublisherId = bookPublisher.publisherId)
+ booksDao.addPublishers(bookPublisher)
+ booksDao.addBooks(addedBook)
+
+ // WHEN a transaction is started and then fails before completing
+ val output = runCatching {
+ booksDao.suspendFunctionWithSuspendFunctionalParam(addedBook) { book ->
+ booksDao.deleteBookSuspend(book)
+ error("Fake error in transaction")
+ }
+ }
+
+ // THEN the change hasn't been committed
+ assertWithMessage("RunCatching caught the thrown error")
+ .that(output.isFailure)
+ .isEqualTo(true)
+ assertThat(booksDao.getBooksSuspend())
+ .contains(addedBook)
+ }
+}
diff --git a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
index f8c7623..8e89adfe 100644
--- a/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
+++ b/room/room-compiler-processing/src/main/java/androidx/room/compiler/processing/ksp/KSTypeExt.kt
@@ -37,7 +37,7 @@
import com.squareup.javapoet.TypeName
import com.squareup.javapoet.TypeVariableName
import com.squareup.javapoet.WildcardTypeName
-import kotlin.IllegalStateException
+import kotlin.coroutines.Continuation
// Catch-all type name when we cannot resolve to anything. This is what KAPT uses as error type
// and we use the same type in KSP for consistency.
@@ -184,15 +184,21 @@
typeArgumentTypeLookup: TypeArgumentTypeLookup
): TypeName {
return if (this.arguments.isNotEmpty()) {
- val args: Array<TypeName> = this.arguments.mapIndexed { index, typeArg ->
- typeArg.typeName(
- param = this.declaration.typeParameters[index],
- resolver = resolver,
- typeArgumentTypeLookup = typeArgumentTypeLookup
- )
- }.map {
- it.tryBox()
- }.toTypedArray()
+ val args: Array<TypeName> = this.arguments
+ .mapIndexed { index, typeArg ->
+ typeArg.typeName(
+ param = this.declaration.typeParameters[index],
+ resolver = resolver,
+ typeArgumentTypeLookup = typeArgumentTypeLookup
+ )
+ }
+ .map { it.tryBox() }
+ .let { args ->
+ if (this.isSuspendFunctionType) args.convertToSuspendSignature()
+ else args
+ }
+ .toTypedArray()
+
when (
val typeName = declaration
.typeName(resolver, typeArgumentTypeLookup).tryBox()
@@ -210,6 +216,29 @@
}
/**
+ * Transforms [this] list of arguments to a suspend signature. For a [suspend] functional type, we
+ * need to transform it to be a FunctionX with a [Continuation] with the correct return type. A
+ * transformed SuspendFunction looks like this:
+ *
+ * FunctionX<[? super $params], ? super Continuation<? super $ReturnType>, ?>
+ */
+private fun List<TypeName>.convertToSuspendSignature(): List<TypeName> {
+ val args = this
+
+ // The last arg is the return type, so take everything except the last arg
+ val actualArgs = args.subList(0, args.size - 1)
+ val continuationReturnType = WildcardTypeName.supertypeOf(args.last())
+ val continuationType = ParameterizedTypeName.get(
+ ClassName.get(Continuation::class.java),
+ continuationReturnType
+ )
+ return actualArgs + listOf(
+ WildcardTypeName.supertypeOf(continuationType),
+ WildcardTypeName.subtypeOf(TypeName.OBJECT)
+ )
+}
+
+/**
* Root package comes as <root> instead of "" so we work around it here.
*/
internal fun KSDeclaration.getNormalizedPackageName(): String {
@@ -275,4 +304,4 @@
): TypeVariableName = typeVarNameConstructor.newInstance(
name,
bounds
-) as TypeVariableName
\ No newline at end of file
+) as TypeVariableName
diff --git a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt
index f76ad64..43a0b27 100644
--- a/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt
+++ b/room/room-compiler-processing/src/test/java/androidx/room/compiler/processing/MethodSpecHelperTest.kt
@@ -190,6 +190,22 @@
}
fun singleArg_returnsInterface(operation: (Int) -> MyInterface) {
}
+
+ fun noArg_suspend_returnsUnit(operation: suspend () -> Unit) {
+ }
+
+ suspend fun suspend_noArg_suspend_returnsUnit(operation: suspend () -> Unit) {
+ }
+ suspend fun suspend_no_arg_suspend_returnsString(operation: suspend () -> String) {
+ }
+ suspend fun suspend_singleArg_suspend_returnsUnit(operation: suspend (arg: String) -> Unit) {
+ }
+ suspend fun suspend_threeArgs_suspend_returnsUnit(operation: suspend (one: String, two: Int, three: Boolean) -> Unit) {
+ }
+ suspend fun suspend_singleArg_suspend_returnsString(operation: suspend (arg: String) -> String) {
+ }
+ suspend fun suspend_threeArgs_suspend_returnsString(operation: suspend (one: String, two: Int, three: Boolean) -> String) {
+ }
}
""".trimIndent()
)