@@ -215,21 +215,28 @@ def tick(self):
215
215
216
216
class Thread :
217
217
task : Task
218
+ index : int
218
219
future : Optional [asyncio .Future ]
219
220
on_resume : Optional [asyncio .Future ]
220
221
on_suspend_or_exit : Optional [asyncio .Future ]
222
+ context : list [int ]
223
+
224
+ CONTEXT_LENGTH = 2
221
225
222
226
def __init__ (self , task , thread_func ):
223
227
self .task = task
228
+ self .index = task .inst .table .add (self )
224
229
self .future = None
225
230
self .on_resume = asyncio .Future ()
226
231
self .on_suspend_or_exit = None
232
+ self .context = [0 ] * Thread .CONTEXT_LENGTH
227
233
async def thread_start ():
228
234
await self .on_resume
229
235
self .on_resume = None
230
236
await thread_func (task , self )
231
237
self .on_suspend_or_exit .set_result (None )
232
238
self .task .thread = None
239
+ self .task .inst .table .remove (self .index )
233
240
asyncio .create_task (thread_start ())
234
241
235
242
async def resume (self ):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254
261
await self .on_resume
255
262
self .on_resume = None
256
263
264
+ async def switch (self , other : Thread ):
265
+ assert (not self .future and not other .future )
266
+ assert (self .on_suspend_or_exit and not other .on_suspend_or_exit )
267
+ other .on_suspend_or_exit = self .on_suspend_or_exit
268
+ self .on_suspend_or_exit = None
269
+ other .on_resume .set_result (Cancelled .FALSE )
270
+ assert (not self .on_resume )
271
+ self .on_resume = asyncio .Future ()
272
+ await self .on_resume
273
+ self .on_resume = None
274
+
275
+ def yield_to (self , other : Thread ):
276
+ # deterministically switch to other, but leave this thread unblocked
277
+ TODO
278
+
279
+ def block (self ):
280
+ # perform just the first half of switch
281
+ TODO
282
+
283
+ def unblock (self , other : Thread ):
284
+ # unblock other, but deterministically keep running here
285
+ TODO
286
+
257
287
258
288
### Lifting and Lowering Context
259
289
@@ -431,22 +461,6 @@ def write(self, vs):
431
461
assert (all (v == () for v in vs ))
432
462
self .progress += len (vs )
433
463
434
- #### Context-Local Storage
435
-
436
- class ContextLocalStorage :
437
- LENGTH = 1
438
- array : list [int ]
439
-
440
- def __init__ (self ):
441
- self .array = [0 ] * ContextLocalStorage .LENGTH
442
-
443
- def set (self , i , v ):
444
- assert (types_match_values (['i32' ], [v ]))
445
- self .array [i ] = v
446
-
447
- def get (self , i ):
448
- return self .array [i ]
449
-
450
464
#### Waitable State
451
465
452
466
class EventCode (IntEnum ):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457
471
FUTURE_READ = 4
458
472
FUTURE_WRITE = 5
459
473
TASK_CANCELLED = 6
474
+ THREAD_RESUMED = 7
460
475
461
476
EventTuple = tuple [EventCode , int , int ]
462
477
@@ -546,7 +561,6 @@ class State(Enum):
546
561
thread : Thread
547
562
cancellable : bool
548
563
num_borrows : int
549
- context : ContextLocalStorage
550
564
551
565
def __init__ (self , opts , inst , ft , supertask , on_resolve , thread_func ):
552
566
self .state = Task .State .INITIAL
@@ -558,7 +572,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
558
572
self .thread = Thread (self , thread_func )
559
573
self .cancellable = False
560
574
self .num_borrows = 0
561
- self .context = ContextLocalStorage ()
562
575
563
576
def trap_if_on_the_stack (self , inst ):
564
577
c = self .supertask
@@ -638,6 +651,7 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
638
651
waitable_set .num_waiting += 1
639
652
e = None
640
653
while not e :
654
+ # TODO: somehow get a THREAD_RESUME event...
641
655
maybe_event = waitable_set .maybe_has_pending_event .wait ()
642
656
await self .block_on (thread , maybe_event , cancellable , unlock )
643
657
if self .deliver_cancel ():
@@ -650,6 +664,7 @@ async def yield_(self, thread, cancellable, unlock) -> EventTuple:
650
664
assert (self .thread is thread and self is thread .task )
651
665
if cancellable and self .deliver_cancel ():
652
666
return (EventCode .TASK_CANCELLED , 0 , 0 )
667
+ # TODO: somehow get a THREAD_RESUME event...
653
668
await self .block_on (thread , asyncio .sleep (0 ), cancellable , unlock )
654
669
if cancellable and self .deliver_cancel ():
655
670
return (EventCode .TASK_CANCELLED , 0 , 0 )
@@ -681,7 +696,7 @@ def cancel(self):
681
696
self .state = Task .State .RESOLVED
682
697
683
698
def exit (self ):
684
- trap_if (self .state != Task .State .RESOLVED )
699
+ trap_if (self .state != Task .State .RESOLVED ) # TODO: move this to empty-threads case
685
700
assert (self .num_borrows == 0 )
686
701
if self .needs_lock ():
687
702
self .inst .lock .release ()
@@ -2102,25 +2117,76 @@ async def canon_resource_rep(rt, thread, i):
2102
2117
trap_if (h .rt is not rt )
2103
2118
return [h .rep ]
2104
2119
2120
+ ### 🧵 `canon thread.index`
2121
+
2122
+ async def canon_thread_index (shared , thread ):
2123
+ assert (not shared )
2124
+ return [thread .index ]
2125
+
2126
+ ### 🧵 `canon thread.new_indirect`
2127
+
2128
+ async def canon_thread_new_indirect (ft , ftbl , thread , i , c ):
2129
+ trap_if (not thread .task .inst .may_leave )
2130
+ f = thread .task .inst .ftbl .get (i )
2131
+ trap_if (f .type != ft )
2132
+ thread = Thread (thread .task , f (c ))
2133
+ return [thread .index ]
2134
+
2135
+ ### 🧵 `canon thread.switch`
2136
+
2137
+ async def canon_thread_switch (thread , i ):
2138
+ trap_if (not thread .task .inst .may_leave )
2139
+ other = thread .task .inst .table .get (i )
2140
+ trap_if (not isinstance (other , Thread ))
2141
+ cancelled = await thread .switch (other )
2142
+ return [ 1 if cancelled else 0 ]
2143
+
2144
+ ### 🧵 `canon thread.yield-to`
2145
+
2146
+ async def canon_thread_yield_to (thread , i ):
2147
+ trap_if (not thread .task .inst .may_leave )
2148
+ other = thread .task .inst .table .get (i )
2149
+ trap_if (not isinstance (other , Thread ))
2150
+ other .yield_to (other )
2151
+ return []
2152
+
2153
+ ### 🧵 `canon thread.block`
2154
+
2155
+ async def canon_thread_block (thread , i ):
2156
+ trap_if (not thread .task .inst .may_leave )
2157
+ other = thread .task .inst .table .get (i )
2158
+ trap_if (not isinstance (other , Thread ))
2159
+ cancelled = await thread .block ()
2160
+ return [ 1 if cancelled else 0 ]
2161
+
2162
+ ### 🧵 `canon thread.unblock`
2163
+
2164
+ async def canon_thread_unblock (thread , i ):
2165
+ trap_if (not thread .task .inst .may_leave )
2166
+ other = thread .task .inst .table .get (i )
2167
+ trap_if (not isinstance (other , Thread ))
2168
+ thread .unblock ()
2169
+ return []
2170
+
2105
2171
### 🔀 `canon context.get`
2106
2172
2107
2173
async def canon_context_get (t , i , thread ):
2108
2174
assert (t == 'i32' )
2109
- assert (i < ContextLocalStorage . LENGTH )
2110
- return [thread .task . context . get ( i ) ]
2175
+ assert (i < Thread . CONTEXT_LENGTH )
2176
+ return [thread .context [ i ] ]
2111
2177
2112
2178
### 🔀 `canon context.set`
2113
2179
2114
2180
async def canon_context_set (t , i , thread , v ):
2115
2181
assert (t == 'i32' )
2116
- assert (i < ContextLocalStorage . LENGTH )
2117
- thread .task . context . set ( i , v )
2182
+ assert (i < Thread . CONTEXT_LENGTH )
2183
+ thread .context [ i ] = v
2118
2184
return []
2119
2185
2120
2186
### 🔀 `canon backpressure.set`
2121
2187
2122
2188
async def canon_backpressure_set (thread , flat_args ):
2123
- trap_if (thread .task .opts .sync )
2189
+ # TODO: remove trap_if(thread.task.opts.sync)
2124
2190
assert (len (flat_args ) == 1 )
2125
2191
if flat_args [0 ] == 0 :
2126
2192
thread .task .inst .no_backpressure .set ()
0 commit comments