Skip to content

Commit 3eb817a

Browse files
committed
Add cooperative threads
1 parent 868cd9b commit 3eb817a

File tree

3 files changed

+206
-118
lines changed

3 files changed

+206
-118
lines changed

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 114 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,28 @@ def tick(self):
215215

216216
class Thread:
217217
task: Task
218+
index: int
218219
future: Optional[asyncio.Future]
219220
on_resume: Optional[asyncio.Future]
220221
on_suspend_or_exit: Optional[asyncio.Future]
222+
context: list[int]
223+
224+
CONTEXT_LENGTH = 2
221225

222226
def __init__(self, task, thread_func):
223227
self.task = task
228+
self.index = task.inst.table.add(self)
224229
self.future = None
225230
self.on_resume = asyncio.Future()
226231
self.on_suspend_or_exit = None
232+
self.context = [0] * Thread.CONTEXT_LENGTH
227233
async def thread_start():
228234
await self.on_resume
229235
self.on_resume = None
230236
await thread_func(task, self)
231237
self.on_suspend_or_exit.set_result(None)
232-
self.task.thread = None
238+
self.task.inst.table.remove(self.index)
239+
self.task.thread_return(self)
233240
asyncio.create_task(thread_start())
234241

235242
async def resume(self):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254261
await self.on_resume
255262
self.on_resume = None
256263

264+
async def switch(self, other: Thread):
265+
assert(not self.future and not other.future)
266+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
267+
other.on_suspend_or_exit = self.on_suspend_or_exit
268+
self.on_suspend_or_exit = None
269+
other.on_resume.set_result(Cancelled.FALSE)
270+
assert(not self.on_resume)
271+
self.on_resume = asyncio.Future()
272+
await self.on_resume
273+
self.on_resume = None
274+
275+
def yield_to(self, other: Thread):
276+
# deterministically switch to other, but leave this thread unblocked
277+
TODO
278+
279+
def block(self):
280+
# perform just the first half of switch
281+
TODO
282+
283+
def unblock(self, other: Thread):
284+
# unblock other, but deterministically keep running here
285+
TODO
286+
257287

258288
### Lifting and Lowering Context
259289

@@ -431,22 +461,6 @@ def write(self, vs):
431461
assert(all(v == () for v in vs))
432462
self.progress += len(vs)
433463

434-
#### Context-Local Storage
435-
436-
class ContextLocalStorage:
437-
LENGTH = 1
438-
array: list[int]
439-
440-
def __init__(self):
441-
self.array = [0] * ContextLocalStorage.LENGTH
442-
443-
def set(self, i, v):
444-
assert(types_match_values(['i32'], [v]))
445-
self.array[i] = v
446-
447-
def get(self, i):
448-
return self.array[i]
449-
450464
#### Waitable State
451465

452466
class EventCode(IntEnum):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457471
FUTURE_READ = 4
458472
FUTURE_WRITE = 5
459473
TASK_CANCELLED = 6
474+
THREAD_RESUMED = 7
460475

461476
EventTuple = tuple[EventCode, int, int]
462477

@@ -543,10 +558,9 @@ class State(Enum):
543558
ft: FuncType
544559
supertask: Optional[Task]
545560
on_resolve: OnResolve
546-
thread: Thread
561+
threads: list[Thread]
547562
cancellable: bool
548563
num_borrows: int
549-
context: ContextLocalStorage
550564

551565
def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
552566
self.state = Task.State.INITIAL
@@ -555,10 +569,13 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
555569
self.ft = ft
556570
self.supertask = supertask
557571
self.on_resolve = on_resolve
558-
self.thread = Thread(self, thread_func)
572+
self.threads = [Thread(self, thread_func)]
559573
self.cancellable = False
560574
self.num_borrows = 0
561-
self.context = ContextLocalStorage()
575+
576+
async def start(self):
577+
assert(len(self.threads) == 1)
578+
await self.threads[0].resume()
562579

563580
def trap_if_on_the_stack(self, inst):
564581
c = self.supertask
@@ -569,8 +586,10 @@ def trap_if_on_the_stack(self, inst):
569586
async def request_cancellation(self):
570587
assert(self.state == Task.State.INITIAL)
571588
self.state = Task.State.PENDING_CANCEL
589+
# TODO: move cancellability to the Thread and then search
590+
# for a cancellable one here...
572591
if self.cancellable:
573-
await self.thread.resume()
592+
await self.threads[0].resume()
574593

575594
def deliver_cancel(self) -> bool:
576595
if self.state == Task.State.PENDING_CANCEL:
@@ -583,7 +602,7 @@ def needs_lock(self):
583602
return self.opts.sync or self.opts.callback
584603

585604
async def enter(self, thread):
586-
assert(thread is self.thread and thread.task is self)
605+
assert(thread in self.threads and thread.task is self)
587606
if (self.inst.no_backpressure.is_set() and
588607
self.inst.num_pending == 0 and
589608
(not self.needs_lock() or not self.inst.lock.locked())):
@@ -599,6 +618,7 @@ async def enter(self, thread):
599618
self.inst.num_pending -= 1
600619
if self.deliver_cancel():
601620
self.on_resolve(None)
621+
self.state = Task.State.RESOLVED
602622
return False
603623
if not self.inst.no_backpressure.is_set():
604624
continue
@@ -611,6 +631,7 @@ async def enter(self, thread):
611631
else:
612632
acquired.cancel()
613633
self.on_resolve(None)
634+
self.state = Task.State.RESOLVED
614635
return False
615636
if not self.inst.no_backpressure.is_set():
616637
self.inst.lock.release()
@@ -619,7 +640,7 @@ async def enter(self, thread):
619640
return True
620641

621642
async def block_on(self, thread, awaitable, cancellable = False, unlock = False) -> Cancelled:
622-
assert(thread is self.thread and thread.task is self)
643+
assert(thread in self.threads and thread.task is self)
623644
f = asyncio.ensure_future(awaitable)
624645
if f.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
625646
return
@@ -633,12 +654,13 @@ async def block_on(self, thread, awaitable, cancellable = False, unlock = False)
633654
await thread.suspend(asyncio.create_task(self.inst.lock.acquire()))
634655

635656
async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> EventTuple:
636-
assert(thread is self.thread and thread.task is self)
657+
assert(thread in self.threads and thread.task is self)
637658
if cancellable and self.deliver_cancel():
638659
return (EventCode.TASK_CANCELLED, 0, 0)
639660
waitable_set.num_waiting += 1
640661
e = None
641662
while not e:
663+
# TODO: somehow get a THREAD_RESUME event...
642664
maybe_event = waitable_set.maybe_has_pending_event.wait()
643665
await self.block_on(thread, maybe_event, cancellable, unlock)
644666
if self.deliver_cancel():
@@ -648,16 +670,17 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
648670
return e
649671

650672
async def yield_(self, thread, cancellable, unlock) -> EventTuple:
651-
assert(thread is self.thread and thread.task is self)
673+
assert(thread in self.threads and thread.task is self)
652674
if cancellable and self.deliver_cancel():
653675
return (EventCode.TASK_CANCELLED, 0, 0)
676+
# TODO: somehow get a THREAD_RESUME event...
654677
await self.block_on(thread, asyncio.sleep(0), cancellable, unlock)
655678
if cancellable and self.deliver_cancel():
656679
return (EventCode.TASK_CANCELLED, 0, 0)
657680
return (EventCode.NONE, 0, 0)
658681

659682
async def poll_for_event(self, thread, waitable_set, cancellable, unlock) -> Optional[EventTuple]:
660-
assert(thread is self.thread and thread.task is self)
683+
assert(thread in self.threads and thread.task is self)
661684
waitable_set.num_waiting += 1
662685
event_code,_,_ = e = await self.yield_(thread, cancellable, unlock)
663686
waitable_set.num_waiting -= 1
@@ -682,11 +705,16 @@ def cancel(self):
682705
self.state = Task.State.RESOLVED
683706

684707
def exit(self):
685-
trap_if(self.state != Task.State.RESOLVED)
686-
assert(self.num_borrows == 0)
687708
if self.needs_lock():
688709
self.inst.lock.release()
689710

711+
def thread_return(self, thread):
712+
assert(thread in self.threads and thread.task is self)
713+
self.threads.remove(thread)
714+
if len(self.threads) == 0:
715+
trap_if(self.state != Task.State.RESOLVED)
716+
assert(self.num_borrows == 0)
717+
690718
#### Subtask State
691719

692720
class Subtask(Waitable):
@@ -1965,7 +1993,7 @@ async def thread_func(task, thread):
19651993
[packed] = await call_and_trap_on_throw(opts.callback, thread, [event_code, p1, p2])
19661994

19671995
task = Task(opts, inst, ft, caller, on_resolve, thread_func)
1968-
await task.thread.resume()
1996+
await task.start()
19691997
return task
19701998

19711999
class CallbackCode(IntEnum):
@@ -2103,25 +2131,76 @@ async def canon_resource_rep(rt, thread, i):
21032131
trap_if(h.rt is not rt)
21042132
return [h.rep]
21052133

2134+
### 🧵 `canon thread.index`
2135+
2136+
async def canon_thread_index(shared, thread):
2137+
assert(not shared)
2138+
return [thread.index]
2139+
2140+
### 🧵 `canon thread.new_indirect`
2141+
2142+
async def canon_thread_new_indirect(ft, ftbl, thread, i, c):
2143+
trap_if(not thread.task.inst.may_leave)
2144+
f = thread.task.inst.ftbl.get(i)
2145+
trap_if(f.type != ft)
2146+
thread = Thread(thread.task, f(c))
2147+
return [thread.index]
2148+
2149+
### 🧵 `canon thread.switch`
2150+
2151+
async def canon_thread_switch(thread, i):
2152+
trap_if(not thread.task.inst.may_leave)
2153+
other = thread.task.inst.table.get(i)
2154+
trap_if(not isinstance(other, Thread))
2155+
cancelled = await thread.switch(other)
2156+
return [ 1 if cancelled else 0 ]
2157+
2158+
### 🧵 `canon thread.yield-to`
2159+
2160+
async def canon_thread_yield_to(thread, i):
2161+
trap_if(not thread.task.inst.may_leave)
2162+
other = thread.task.inst.table.get(i)
2163+
trap_if(not isinstance(other, Thread))
2164+
other.yield_to(other)
2165+
return []
2166+
2167+
### 🧵 `canon thread.block`
2168+
2169+
async def canon_thread_block(thread, i):
2170+
trap_if(not thread.task.inst.may_leave)
2171+
other = thread.task.inst.table.get(i)
2172+
trap_if(not isinstance(other, Thread))
2173+
cancelled = await thread.block()
2174+
return [ 1 if cancelled else 0 ]
2175+
2176+
### 🧵 `canon thread.unblock`
2177+
2178+
async def canon_thread_unblock(thread, i):
2179+
trap_if(not thread.task.inst.may_leave)
2180+
other = thread.task.inst.table.get(i)
2181+
trap_if(not isinstance(other, Thread))
2182+
thread.unblock()
2183+
return []
2184+
21062185
### 🔀 `canon context.get`
21072186

21082187
async def canon_context_get(t, i, thread):
21092188
assert(t == 'i32')
2110-
assert(i < ContextLocalStorage.LENGTH)
2111-
return [thread.task.context.get(i)]
2189+
assert(i < Thread.CONTEXT_LENGTH)
2190+
return [thread.context[i]]
21122191

21132192
### 🔀 `canon context.set`
21142193

21152194
async def canon_context_set(t, i, thread, v):
21162195
assert(t == 'i32')
2117-
assert(i < ContextLocalStorage.LENGTH)
2118-
thread.task.context.set(i, v)
2196+
assert(i < Thread.CONTEXT_LENGTH)
2197+
thread.context[i] = v
21192198
return []
21202199

21212200
### 🔀 `canon backpressure.set`
21222201

21232202
async def canon_backpressure_set(thread, flat_args):
2124-
trap_if(thread.task.opts.sync)
2203+
# TODO: remove trap_if(thread.task.opts.sync)
21252204
assert(len(flat_args) == 1)
21262205
if flat_args[0] == 0:
21272206
thread.task.inst.no_backpressure.set()

0 commit comments

Comments
 (0)