@@ -215,21 +215,28 @@ def tick(self):
215
215
216
216
class Thread :
217
217
task : Task
218
+ index : int
218
219
future : Optional [asyncio .Future ]
219
220
on_resume : Optional [asyncio .Future ]
220
221
on_suspend_or_exit : Optional [asyncio .Future ]
222
+ context : list [int ]
223
+
224
+ CONTEXT_LENGTH = 2
221
225
222
226
def __init__ (self , task , thread_func ):
223
227
self .task = task
228
+ self .index = task .inst .table .add (self )
224
229
self .future = None
225
230
self .on_resume = asyncio .Future ()
226
231
self .on_suspend_or_exit = None
232
+ self .context = [0 ] * Thread .CONTEXT_LENGTH
227
233
async def thread_start ():
228
234
await self .on_resume
229
235
self .on_resume = None
230
236
await thread_func (task , self )
231
237
self .on_suspend_or_exit .set_result (None )
232
- self .task .thread = None
238
+ self .task .inst .table .remove (self .index )
239
+ self .task .thread_return (self )
233
240
asyncio .create_task (thread_start ())
234
241
235
242
async def resume (self ):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254
261
await self .on_resume
255
262
self .on_resume = None
256
263
264
+ async def switch (self , other : Thread ):
265
+ assert (not self .future and not other .future )
266
+ assert (self .on_suspend_or_exit and not other .on_suspend_or_exit )
267
+ other .on_suspend_or_exit = self .on_suspend_or_exit
268
+ self .on_suspend_or_exit = None
269
+ other .on_resume .set_result (Cancelled .FALSE )
270
+ assert (not self .on_resume )
271
+ self .on_resume = asyncio .Future ()
272
+ await self .on_resume
273
+ self .on_resume = None
274
+
275
+ def yield_to (self , other : Thread ):
276
+ # deterministically switch to other, but leave this thread unblocked
277
+ TODO
278
+
279
+ def block (self ):
280
+ # perform just the first half of switch
281
+ TODO
282
+
283
+ def unblock (self , other : Thread ):
284
+ # unblock other, but deterministically keep running here
285
+ TODO
286
+
257
287
258
288
### Lifting and Lowering Context
259
289
@@ -431,22 +461,6 @@ def write(self, vs):
431
461
assert (all (v == () for v in vs ))
432
462
self .progress += len (vs )
433
463
434
- #### Context-Local Storage
435
-
436
- class ContextLocalStorage :
437
- LENGTH = 1
438
- array : list [int ]
439
-
440
- def __init__ (self ):
441
- self .array = [0 ] * ContextLocalStorage .LENGTH
442
-
443
- def set (self , i , v ):
444
- assert (types_match_values (['i32' ], [v ]))
445
- self .array [i ] = v
446
-
447
- def get (self , i ):
448
- return self .array [i ]
449
-
450
464
#### Waitable State
451
465
452
466
class EventCode (IntEnum ):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457
471
FUTURE_READ = 4
458
472
FUTURE_WRITE = 5
459
473
TASK_CANCELLED = 6
474
+ THREAD_RESUMED = 7
460
475
461
476
EventTuple = tuple [EventCode , int , int ]
462
477
@@ -543,10 +558,9 @@ class State(Enum):
543
558
ft : FuncType
544
559
supertask : Optional [Task ]
545
560
on_resolve : OnResolve
546
- thread : Thread
561
+ threads : list [ Thread ]
547
562
cancellable : bool
548
563
num_borrows : int
549
- context : ContextLocalStorage
550
564
551
565
def __init__ (self , opts , inst , ft , supertask , on_resolve , thread_func ):
552
566
self .state = Task .State .INITIAL
@@ -555,10 +569,13 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
555
569
self .ft = ft
556
570
self .supertask = supertask
557
571
self .on_resolve = on_resolve
558
- self .thread = Thread (self , thread_func )
572
+ self .threads = [ Thread (self , thread_func )]
559
573
self .cancellable = False
560
574
self .num_borrows = 0
561
- self .context = ContextLocalStorage ()
575
+
576
+ async def start (self ):
577
+ assert (len (self .threads ) == 1 )
578
+ await self .threads [0 ].resume ()
562
579
563
580
def trap_if_on_the_stack (self , inst ):
564
581
c = self .supertask
@@ -569,8 +586,10 @@ def trap_if_on_the_stack(self, inst):
569
586
async def request_cancellation (self ):
570
587
assert (self .state == Task .State .INITIAL )
571
588
self .state = Task .State .PENDING_CANCEL
589
+ # TODO: move cancellability to the Thread and then search
590
+ # for a cancellable one here...
572
591
if self .cancellable :
573
- await self .thread .resume ()
592
+ await self .threads [ 0 ] .resume ()
574
593
575
594
def deliver_cancel (self ) -> bool :
576
595
if self .state == Task .State .PENDING_CANCEL :
@@ -582,8 +601,9 @@ def deliver_cancel(self) -> bool:
582
601
def needs_lock (self ):
583
602
return self .opts .sync or self .opts .callback
584
603
604
+ # TODO: somehow break this up...
585
605
async def enter (self , thread ):
586
- assert (thread is self .thread and thread .task is self )
606
+ assert (thread in self .threads and thread .task is self )
587
607
if (self .inst .no_backpressure .is_set () and
588
608
self .inst .num_pending == 0 and
589
609
(not self .needs_lock () or not self .inst .lock .locked ())):
@@ -599,6 +619,7 @@ async def enter(self, thread):
599
619
self .inst .num_pending -= 1
600
620
if self .deliver_cancel ():
601
621
self .on_resolve (None )
622
+ self .state = Task .State .RESOLVED
602
623
return False
603
624
if not self .inst .no_backpressure .is_set ():
604
625
continue
@@ -611,6 +632,7 @@ async def enter(self, thread):
611
632
else :
612
633
acquired .cancel ()
613
634
self .on_resolve (None )
635
+ self .state = Task .State .RESOLVED
614
636
return False
615
637
if not self .inst .no_backpressure .is_set ():
616
638
self .inst .lock .release ()
@@ -619,7 +641,7 @@ async def enter(self, thread):
619
641
return True
620
642
621
643
async def block_on (self , thread , awaitable , cancellable = False , unlock = False ) -> Cancelled :
622
- assert (thread is self .thread and thread .task is self )
644
+ assert (thread in self .threads and thread .task is self )
623
645
f = asyncio .ensure_future (awaitable )
624
646
if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
625
647
return
@@ -633,12 +655,13 @@ async def block_on(self, thread, awaitable, cancellable = False, unlock = False)
633
655
await thread .suspend (asyncio .create_task (self .inst .lock .acquire ()))
634
656
635
657
async def wait_for_event (self , thread , waitable_set , cancellable , unlock ) -> EventTuple :
636
- assert (thread is self .thread and thread .task is self )
658
+ assert (thread in self .threads and thread .task is self )
637
659
if cancellable and self .deliver_cancel ():
638
660
return (EventCode .TASK_CANCELLED , 0 , 0 )
639
661
waitable_set .num_waiting += 1
640
662
e = None
641
663
while not e :
664
+ # TODO: somehow get a THREAD_RESUME event...
642
665
maybe_event = waitable_set .maybe_has_pending_event .wait ()
643
666
await self .block_on (thread , maybe_event , cancellable , unlock )
644
667
if self .deliver_cancel ():
@@ -648,16 +671,17 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
648
671
return e
649
672
650
673
async def yield_ (self , thread , cancellable , unlock ) -> EventTuple :
651
- assert (thread is self .thread and thread .task is self )
674
+ assert (thread in self .threads and thread .task is self )
652
675
if cancellable and self .deliver_cancel ():
653
676
return (EventCode .TASK_CANCELLED , 0 , 0 )
677
+ # TODO: somehow get a THREAD_RESUME event...
654
678
await self .block_on (thread , asyncio .sleep (0 ), cancellable , unlock )
655
679
if cancellable and self .deliver_cancel ():
656
680
return (EventCode .TASK_CANCELLED , 0 , 0 )
657
681
return (EventCode .NONE , 0 , 0 )
658
682
659
683
async def poll_for_event (self , thread , waitable_set , cancellable , unlock ) -> Optional [EventTuple ]:
660
- assert (thread is self .thread and thread .task is self )
684
+ assert (thread in self .threads and thread .task is self )
661
685
waitable_set .num_waiting += 1
662
686
event_code ,_ ,_ = e = await self .yield_ (thread , cancellable , unlock )
663
687
waitable_set .num_waiting -= 1
@@ -682,11 +706,16 @@ def cancel(self):
682
706
self .state = Task .State .RESOLVED
683
707
684
708
def exit (self ):
685
- trap_if (self .state != Task .State .RESOLVED )
686
- assert (self .num_borrows == 0 )
687
709
if self .needs_lock ():
688
710
self .inst .lock .release ()
689
711
712
+ def thread_return (self , thread ):
713
+ assert (thread in self .threads and thread .task is self )
714
+ self .threads .remove (thread )
715
+ if len (self .threads ) == 0 :
716
+ trap_if (self .state != Task .State .RESOLVED )
717
+ assert (self .num_borrows == 0 )
718
+
690
719
#### Subtask State
691
720
692
721
class Subtask (Waitable ):
@@ -1965,7 +1994,7 @@ async def thread_func(task, thread):
1965
1994
[packed ] = await call_and_trap_on_throw (opts .callback , thread , [event_code , p1 , p2 ])
1966
1995
1967
1996
task = Task (opts , inst , ft , caller , on_resolve , thread_func )
1968
- await task .thread . resume ()
1997
+ await task .start ()
1969
1998
return task
1970
1999
1971
2000
class CallbackCode (IntEnum ):
@@ -2103,25 +2132,76 @@ async def canon_resource_rep(rt, thread, i):
2103
2132
trap_if (h .rt is not rt )
2104
2133
return [h .rep ]
2105
2134
2135
+ ### 🧵 `canon thread.index`
2136
+
2137
+ async def canon_thread_index (shared , thread ):
2138
+ assert (not shared )
2139
+ return [thread .index ]
2140
+
2141
+ ### 🧵 `canon thread.new_indirect`
2142
+
2143
+ async def canon_thread_new_indirect (ft , ftbl , thread , i , c ):
2144
+ trap_if (not thread .task .inst .may_leave )
2145
+ f = thread .task .inst .ftbl .get (i )
2146
+ trap_if (f .type != ft )
2147
+ thread = Thread (thread .task , f (c ))
2148
+ return [thread .index ]
2149
+
2150
+ ### 🧵 `canon thread.switch`
2151
+
2152
+ async def canon_thread_switch (thread , i ):
2153
+ trap_if (not thread .task .inst .may_leave )
2154
+ other = thread .task .inst .table .get (i )
2155
+ trap_if (not isinstance (other , Thread ))
2156
+ cancelled = await thread .switch (other )
2157
+ return [ 1 if cancelled else 0 ]
2158
+
2159
+ ### 🧵 `canon thread.yield-to`
2160
+
2161
+ async def canon_thread_yield_to (thread , i ):
2162
+ trap_if (not thread .task .inst .may_leave )
2163
+ other = thread .task .inst .table .get (i )
2164
+ trap_if (not isinstance (other , Thread ))
2165
+ other .yield_to (other )
2166
+ return []
2167
+
2168
+ ### 🧵 `canon thread.block`
2169
+
2170
+ async def canon_thread_block (thread , i ):
2171
+ trap_if (not thread .task .inst .may_leave )
2172
+ other = thread .task .inst .table .get (i )
2173
+ trap_if (not isinstance (other , Thread ))
2174
+ cancelled = await thread .block ()
2175
+ return [ 1 if cancelled else 0 ]
2176
+
2177
+ ### 🧵 `canon thread.unblock`
2178
+
2179
+ async def canon_thread_unblock (thread , i ):
2180
+ trap_if (not thread .task .inst .may_leave )
2181
+ other = thread .task .inst .table .get (i )
2182
+ trap_if (not isinstance (other , Thread ))
2183
+ thread .unblock ()
2184
+ return []
2185
+
2106
2186
### 🔀 `canon context.get`
2107
2187
2108
2188
async def canon_context_get (t , i , thread ):
2109
2189
assert (t == 'i32' )
2110
- assert (i < ContextLocalStorage . LENGTH )
2111
- return [thread .task . context . get ( i ) ]
2190
+ assert (i < Thread . CONTEXT_LENGTH )
2191
+ return [thread .context [ i ] ]
2112
2192
2113
2193
### 🔀 `canon context.set`
2114
2194
2115
2195
async def canon_context_set (t , i , thread , v ):
2116
2196
assert (t == 'i32' )
2117
- assert (i < ContextLocalStorage . LENGTH )
2118
- thread .task . context . set ( i , v )
2197
+ assert (i < Thread . CONTEXT_LENGTH )
2198
+ thread .context [ i ] = v
2119
2199
return []
2120
2200
2121
2201
### 🔀 `canon backpressure.set`
2122
2202
2123
2203
async def canon_backpressure_set (thread , flat_args ):
2124
- trap_if (thread .task .opts .sync )
2204
+ # TODO: remove trap_if(thread.task.opts.sync)
2125
2205
assert (len (flat_args ) == 1 )
2126
2206
if flat_args [0 ] == 0 :
2127
2207
thread .task .inst .no_backpressure .set ()
0 commit comments