Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package androidx.compose.runtime

import androidx.collection.mutableObjectListOf
import androidx.compose.runtime.internal.AtomicInt
import androidx.compose.runtime.platform.makeSynchronizedObject
import androidx.compose.runtime.platform.synchronized
import androidx.compose.runtime.snapshots.fastForEach
import kotlin.coroutines.Continuation
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.resumeWithException
import kotlin.jvm.JvmInline
import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.suspendCancellableCoroutine

Expand All @@ -38,24 +42,34 @@ import kotlinx.coroutines.suspendCancellableCoroutine
*/
class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : MonotonicFrameClock {

private class FrameAwaiter<R>(val onFrame: (Long) -> R, val continuation: Continuation<R>) {
private class FrameAwaiter<R>(onFrame: (Long) -> R, continuation: CancellableContinuation<R>) {
private var onFrame: ((Long) -> R)? = onFrame
private var continuation: (CancellableContinuation<R>)? = continuation

fun cancel() {
onFrame = null
continuation = null
}

fun resume(timeNanos: Long) {
continuation.resumeWith(runCatching { onFrame(timeNanos) })
val onFrame = onFrame ?: return
continuation?.resumeWith(runCatching { onFrame(timeNanos) })
}

fun resumeWithException(exception: Throwable) {
continuation?.resumeWithException(exception)
}
}

private val lock = makeSynchronizedObject()
private var failureCause: Throwable? = null
private var awaiters = mutableListOf<FrameAwaiter<*>>()
private var spareList = mutableListOf<FrameAwaiter<*>>()

// Uses AtomicInt to avoid adding AtomicBoolean to the Expect/Actual requirements of the
// runtime.
private val hasAwaitersUnlocked = AtomicInt(0)
private val pendingAwaitersCountUnlocked = AtomicAwaitersCount()
private var awaiters = mutableObjectListOf<FrameAwaiter<*>>()
private var spareList = mutableObjectListOf<FrameAwaiter<*>>()

/** `true` if there are any callers of [withFrameNanos] awaiting to run for a pending frame. */
val hasAwaiters: Boolean
get() = hasAwaitersUnlocked.get() != 0
get() = pendingAwaitersCountUnlocked.hasAwaiters()

/**
* Send a frame for time [timeNanos] to all current callers of [withFrameNanos]. The `onFrame`
Expand All @@ -69,7 +83,7 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
val toResume = awaiters
awaiters = spareList
spareList = toResume
hasAwaitersUnlocked.set(0)
pendingAwaitersCountUnlocked.incrementVersionAndResetCount()

for (i in 0 until toResume.size) {
toResume[i].resume(timeNanos)
Expand All @@ -81,24 +95,24 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
override suspend fun <R> withFrameNanos(onFrame: (Long) -> R): R =
suspendCancellableCoroutine { co ->
val awaiter = FrameAwaiter(onFrame, co)
val hasNewAwaiters =
synchronized(lock) {
val cause = failureCause
if (cause != null) {
co.resumeWithException(cause)
return@suspendCancellableCoroutine
}
val hadAwaiters = awaiters.isNotEmpty()
awaiters.add(awaiter)
if (!hadAwaiters) hasAwaitersUnlocked.set(1)
!hadAwaiters
var hasNewAwaiters = false
var awaitersVersion = -1
synchronized(lock) {
val cause = failureCause
if (cause != null) {
co.resumeWithException(cause)
return@suspendCancellableCoroutine
}
awaitersVersion =
pendingAwaitersCountUnlocked.incrementCountAndGetVersion(
ifFirstAwaiter = { hasNewAwaiters = true }
)
awaiters.add(awaiter)
}

co.invokeOnCancellation {
synchronized(lock) {
awaiters.remove(awaiter)
if (awaiters.isEmpty()) hasAwaitersUnlocked.set(0)
}
awaiter.cancel()
pendingAwaitersCountUnlocked.decrementCount(awaitersVersion)
}

// Wake up anything that was waiting for someone to schedule a frame
Expand All @@ -118,9 +132,9 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
synchronized(lock) {
if (failureCause != null) return
failureCause = cause
awaiters.fastForEach { awaiter -> awaiter.continuation.resumeWithException(cause) }
awaiters.forEach { awaiter -> awaiter.resumeWithException(cause) }
awaiters.clear()
hasAwaitersUnlocked.set(0)
pendingAwaitersCountUnlocked.incrementVersionAndResetCount()
}
}

Expand All @@ -133,4 +147,84 @@ class BroadcastFrameClock(private val onNewAwaiters: (() -> Unit)? = null) : Mon
) {
fail(cancellationException)
}

/**
* [BroadcastFrameClock] tracks the number of pending [FrameAwaiter]s using this atomic type.
* This count is made up of two components: The count itself ([COUNT_BITS] bits) and a version
* ([VERSION_BITS] bits).
*
* The count is incremented when a new awaiter is added, and decremented when an awaiter is
* cancelled. When the pending awaiters are processed, this count is reset to zero. To prevent a
* race condition that can cause an inaccurate count when awaiters are removed, cancelled
* awaiters only decrement their count when the version of the counter has not changed. The
* version is incremented every time the awaiters are dispatched and the count resets to zero.
*
* The number of bits required to track the version is very small, and the version is allowed
* and expected to roll over. By allocating 4 bits for the version, cancellation events can be
* correctly counted as long as the cancellation callback completes within 16 [sendFrame]
* invocations. Most cancelled awaiters will invoke their cancellation logic almost immediately,
* so even a narrow version range can be highly effective.
*/
@Suppress("NOTHING_TO_INLINE")
@JvmInline
private value class AtomicAwaitersCount private constructor(private val value: AtomicInt) {
constructor() : this(AtomicInt(0))

inline fun hasAwaiters(): Boolean = value.get().count > 0

inline fun incrementVersionAndResetCount() {
update { pack(version = it.version + 1, count = 0) }
}

@OptIn(ExperimentalContracts::class)
inline fun incrementCountAndGetVersion(ifFirstAwaiter: () -> Unit): Int {
contract { callsInPlace(ifFirstAwaiter, InvocationKind.AT_MOST_ONCE) }
val newValue = update { it + 1 }
if (newValue.count == 1) ifFirstAwaiter()
return newValue.version
}

inline fun decrementCount(version: Int) {
update { value -> if (value.version == version) value - 1 else value }
}

private inline fun update(calculation: (Int) -> Int): Int {
var oldValue: Int
var newValue: Int
do {
oldValue = value.get()
newValue = calculation(oldValue)
} while (!value.compareAndSet(oldValue, newValue))
return newValue
}

/**
* Bitpacks [version] and [count] together. The topmost bit is always 0 to enforce this
* value always being positive. [version] takes the next [VERSION_BITS] topmost bits, and
* [count] takes the remaining [COUNT_BITS] bits.
*
* `| 0 | version | count |`
*/
private fun pack(version: Int, count: Int): Int {
val versionComponent = (version and (-1 shl VERSION_BITS).inv()) shl COUNT_BITS
val countComponent = count and (-1 shl COUNT_BITS).inv()
return versionComponent or countComponent
}

private inline val Int.version: Int
get() = (this ushr COUNT_BITS) and (-1 shl VERSION_BITS).inv()

private inline val Int.count: Int
get() = this and (-1 shl COUNT_BITS).inv()

override fun toString(): String {
val current = value.get()
return "AtomicAwaitersCount(version = ${current.version}, count = ${current.count})"
}

companion object {
private const val VERSION_BITS = 4
private const val COUNT_BITS = Int.SIZE_BITS - VERSION_BITS - 1
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2025 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.compose.runtime.internal

internal actual fun sleep(millis: UInt) =
Thread.sleep(millis.toLong())
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2025 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.compose.runtime.internal

internal actual fun sleep(millis: UInt) {
platform.posix.usleep(millis)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,36 @@
* limitations under the License.
*/

package androidx.compose.runtime.dispatch
package androidx.compose.runtime

import androidx.compose.runtime.internal.AtomicInt
import androidx.compose.runtime.internal.sleep
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineStart.UNDISPATCHED
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlinx.test.IgnoreJsTarget
import kotlinx.test.IgnoreWasmTarget

@ExperimentalCoroutinesApi
class BroadcastFrameClockTest {
@Test
fun sendAndReceiveFrames() =
runTest(UnconfinedTestDispatcher()) {
val clock = androidx.compose.runtime.BroadcastFrameClock()
val clock = BroadcastFrameClock()

val frameAwaiter = async { clock.withFrameNanos { it } }

Expand All @@ -49,7 +61,7 @@ class BroadcastFrameClockTest {
@Test
fun cancelClock() =
runTest(UnconfinedTestDispatcher()) {
val clock = androidx.compose.runtime.BroadcastFrameClock()
val clock = BroadcastFrameClock()
val frameAwaiter = async { clock.withFrameNanos { it } }

clock.cancel()
Expand All @@ -66,15 +78,48 @@ class BroadcastFrameClockTest {
@Test
fun failClockWhenNewAwaitersNotified() =
runTest(UnconfinedTestDispatcher()) {
val clock =
androidx.compose.runtime.BroadcastFrameClock {
throw CancellationException("failed frame clock")
}
val clock = BroadcastFrameClock { throw CancellationException("failed frame clock") }

val failingAwaiter = async { clock.withFrameNanos { it } }
assertAwaiterCancelled("failingAwaiter", failingAwaiter)

val lateAwaiter = async { clock.withFrameNanos { it } }
assertAwaiterCancelled("lateAwaiter", lateAwaiter)
}

@OptIn(InternalCoroutinesApi::class)
@IgnoreJsTarget
@IgnoreWasmTarget
@Test
fun locklessCancellation() = runTest(timeout = 5_000.milliseconds) {

val clock = BroadcastFrameClock()
val cancellationGate = AtomicInt(1)

var spin = true
async(start = UNDISPATCHED) {
clock.withFrameNanos {
cancellationGate.add(-1)
while (spin) sleep(100u)
}
}

val cancellingJob = async(start = UNDISPATCHED) { clock.withFrameNanos {} }

launch(Dispatchers.Default) { clock.sendFrame(1) }

// Wait for the spinlock to start
while (cancellationGate.get() != 0) yield()

// Assert that this line doesn't deadlock.
cancellingJob.cancelAndJoin()

// Make sure that we can queue up new jobs for subsequent frames
spin = false
assertFalse(clock.hasAwaiters)
async(start = UNDISPATCHED) { clock.withFrameNanos {} }
assertTrue(clock.hasAwaiters)

clock.cancel()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright 2025 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.compose.runtime.internal

internal expect fun sleep(millis: UInt)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2025 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.compose.runtime.internal

internal actual fun sleep(millis: UInt): Unit =
throw UnsupportedOperationException("Sleep is not supported")
Loading