@@ -76,40 +76,78 @@ class ThreadContextElementTest : TestBase() {
7676 assertNull(threadContextElementThreadLocal.get())
7777 }
7878
79+ class JobCaptor (val capturees : MutableList <String > = mutableListOf()) : ThreadContextElement<Unit> {
80+
81+ companion object Key : CoroutineContext.Key<MyElement>
82+
83+ override val key: CoroutineContext .Key <* > get() = Key
84+
85+ override fun updateThreadContext (context : CoroutineContext ) {
86+ capturees.add(" Update: ${context.job} " )
87+ }
88+
89+ override fun restoreThreadContext (context : CoroutineContext , oldState : Unit ) {
90+ capturees.add(" Restore: ${context.job} " )
91+ }
92+ }
93+
94+ /* *
95+ * For stability of the test, it is important to make sure that
96+ * the parent job actually suspends when calling
97+ * `withContext(dispatcher2 + CoroutineName("dispatched"))`.
98+ *
99+ * Here this requirement is fulfilled by forcing execution on a single thread.
100+ * However, dispatching is performed with two non-equal dispatchers to force dispatching.
101+ *
102+ * Suspend of the parent coroutine [kotlinx.coroutines.DispatchedCoroutine.trySuspend] is out of the control of the test,
103+ * while being executed concurrently with resume of the child coroutine [kotlinx.coroutines.DispatchedCoroutine.tryResume].
104+ */
79105 @Test
80106 fun testWithContextJobAccess () = runTest {
107+ // Emulate non-equal dispatchers
108+ val dispatcher = Dispatchers .Default .limitedParallelism(1 )
109+ val dispatcher1 = dispatcher.limitedParallelism(1 , " dispatcher1" )
110+ val dispatcher2 = dispatcher.limitedParallelism(1 , " dispatcher2" )
81111 val captor = JobCaptor ()
82- val manuallyCaptured = ArrayList <Job >()
83- withContext(captor) {
84- manuallyCaptured + = coroutineContext.job
112+ val manuallyCaptured = mutableListOf<String >()
113+
114+ fun registerUpdate (job : Job ? ) = manuallyCaptured.add(" Update: $job " )
115+ fun registerRestore (job : Job ? ) = manuallyCaptured.add(" Restore: $job " )
116+
117+ var rootJob: Job ? = null
118+ withContext(captor + dispatcher1) {
119+ rootJob = coroutineContext.job
120+ registerUpdate(rootJob)
121+ var undispatchedJob: Job ? = null
85122 withContext(CoroutineName (" undispatched" )) {
86- manuallyCaptured + = coroutineContext.job
87- withContext(Dispatchers .Default ) {
88- manuallyCaptured + = coroutineContext.job
123+ undispatchedJob = coroutineContext.job
124+ registerUpdate(undispatchedJob)
125+ // These 2 restores and the corresponding next 2 updates happen only if the following `withContext`
126+ // call actually suspends.
127+ registerRestore(undispatchedJob)
128+ registerRestore(rootJob)
129+ // Without forcing of single backing thread the code inside `withContext`
130+ // may already complete at the moment when the parent coroutine decides
131+ // whether it needs to suspend or not.
132+ var dispatchedJob: Job ? = null
133+ withContext(dispatcher2 + CoroutineName (" dispatched" )) {
134+ dispatchedJob = coroutineContext.job
135+ registerUpdate(dispatchedJob)
89136 }
137+ registerRestore(dispatchedJob)
90138 // Context restored, captured again
91- manuallyCaptured + = coroutineContext.job
139+ registerUpdate(undispatchedJob)
92140 }
141+ registerRestore(undispatchedJob)
93142 // Context restored, captured again
94- manuallyCaptured + = coroutineContext.job
143+ registerUpdate(rootJob)
95144 }
96- assertEquals(manuallyCaptured, captor.capturees)
97- }
98- }
99-
100- private class JobCaptor () : ThreadContextElement<Unit> {
101-
102- val capturees: MutableList <Job > = mutableListOf ()
103-
104- companion object Key : CoroutineContext.Key<MyElement>
105-
106- override val key: CoroutineContext .Key <* > get() = Key
107-
108- override fun updateThreadContext (context : CoroutineContext ) {
109- capturees.add(context.job)
110- }
145+ registerRestore(rootJob)
111146
112- override fun restoreThreadContext (context : CoroutineContext , oldState : Unit ) {
147+ // Restores may be called concurrently to the update calls in other threads, so their order is not checked.
148+ val expected = manuallyCaptured.filter { it.startsWith(" Update: " ) }.joinToString(separator = " \n " )
149+ val actual = captor.capturees.filter { it.startsWith(" Update: " ) }.joinToString(separator = " \n " )
150+ assertEquals(expected, actual)
113151 }
114152}
115153
0 commit comments