Skip to content

Commit a18ab3d

Browse files
committed
Add cooperative threads
1 parent 868cd9b commit a18ab3d

File tree

3 files changed

+207
-118
lines changed

3 files changed

+207
-118
lines changed

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 115 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,28 @@ def tick(self):
215215

216216
class Thread:
217217
task: Task
218+
index: int
218219
future: Optional[asyncio.Future]
219220
on_resume: Optional[asyncio.Future]
220221
on_suspend_or_exit: Optional[asyncio.Future]
222+
context: list[int]
223+
224+
CONTEXT_LENGTH = 2
221225

222226
def __init__(self, task, thread_func):
223227
self.task = task
228+
self.index = task.inst.table.add(self)
224229
self.future = None
225230
self.on_resume = asyncio.Future()
226231
self.on_suspend_or_exit = None
232+
self.context = [0] * Thread.CONTEXT_LENGTH
227233
async def thread_start():
228234
await self.on_resume
229235
self.on_resume = None
230236
await thread_func(task, self)
231237
self.on_suspend_or_exit.set_result(None)
232-
self.task.thread = None
238+
self.task.inst.table.remove(self.index)
239+
self.task.thread_return(self)
233240
asyncio.create_task(thread_start())
234241

235242
async def resume(self):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254261
await self.on_resume
255262
self.on_resume = None
256263

264+
async def switch(self, other: Thread):
265+
assert(not self.future and not other.future)
266+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
267+
other.on_suspend_or_exit = self.on_suspend_or_exit
268+
self.on_suspend_or_exit = None
269+
other.on_resume.set_result(Cancelled.FALSE)
270+
assert(not self.on_resume)
271+
self.on_resume = asyncio.Future()
272+
await self.on_resume
273+
self.on_resume = None
274+
275+
def yield_to(self, other: Thread):
276+
# deterministically switch to other, but leave this thread unblocked
277+
TODO
278+
279+
def block(self):
280+
# perform just the first half of switch
281+
TODO
282+
283+
def unblock(self, other: Thread):
284+
# unblock other, but deterministically keep running here
285+
TODO
286+
257287

258288
### Lifting and Lowering Context
259289

@@ -431,22 +461,6 @@ def write(self, vs):
431461
assert(all(v == () for v in vs))
432462
self.progress += len(vs)
433463

434-
#### Context-Local Storage
435-
436-
class ContextLocalStorage:
437-
LENGTH = 1
438-
array: list[int]
439-
440-
def __init__(self):
441-
self.array = [0] * ContextLocalStorage.LENGTH
442-
443-
def set(self, i, v):
444-
assert(types_match_values(['i32'], [v]))
445-
self.array[i] = v
446-
447-
def get(self, i):
448-
return self.array[i]
449-
450464
#### Waitable State
451465

452466
class EventCode(IntEnum):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457471
FUTURE_READ = 4
458472
FUTURE_WRITE = 5
459473
TASK_CANCELLED = 6
474+
THREAD_RESUMED = 7
460475

461476
EventTuple = tuple[EventCode, int, int]
462477

@@ -543,10 +558,9 @@ class State(Enum):
543558
ft: FuncType
544559
supertask: Optional[Task]
545560
on_resolve: OnResolve
546-
thread: Thread
561+
threads: list[Thread]
547562
cancellable: bool
548563
num_borrows: int
549-
context: ContextLocalStorage
550564

551565
def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
552566
self.state = Task.State.INITIAL
@@ -555,10 +569,13 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
555569
self.ft = ft
556570
self.supertask = supertask
557571
self.on_resolve = on_resolve
558-
self.thread = Thread(self, thread_func)
572+
self.threads = [Thread(self, thread_func)]
559573
self.cancellable = False
560574
self.num_borrows = 0
561-
self.context = ContextLocalStorage()
575+
576+
async def start(self):
577+
assert(len(self.threads) == 1)
578+
await self.threads[0].resume()
562579

563580
def trap_if_on_the_stack(self, inst):
564581
c = self.supertask
@@ -569,8 +586,10 @@ def trap_if_on_the_stack(self, inst):
569586
async def request_cancellation(self):
570587
assert(self.state == Task.State.INITIAL)
571588
self.state = Task.State.PENDING_CANCEL
589+
# TODO: move cancellability to the Thread and then search
590+
# for a cancellable one here...
572591
if self.cancellable:
573-
await self.thread.resume()
592+
await self.threads[0].resume()
574593

575594
def deliver_cancel(self) -> bool:
576595
if self.state == Task.State.PENDING_CANCEL:
@@ -582,8 +601,9 @@ def deliver_cancel(self) -> bool:
582601
def needs_lock(self):
583602
return self.opts.sync or self.opts.callback
584603

604+
# TODO: somehow break this up...
585605
async def enter(self, thread):
586-
assert(thread is self.thread and thread.task is self)
606+
assert(thread in self.threads and thread.task is self)
587607
if (self.inst.no_backpressure.is_set() and
588608
self.inst.num_pending == 0 and
589609
(not self.needs_lock() or not self.inst.lock.locked())):
@@ -599,6 +619,7 @@ async def enter(self, thread):
599619
self.inst.num_pending -= 1
600620
if self.deliver_cancel():
601621
self.on_resolve(None)
622+
self.state = Task.State.RESOLVED
602623
return False
603624
if not self.inst.no_backpressure.is_set():
604625
continue
@@ -611,6 +632,7 @@ async def enter(self, thread):
611632
else:
612633
acquired.cancel()
613634
self.on_resolve(None)
635+
self.state = Task.State.RESOLVED
614636
return False
615637
if not self.inst.no_backpressure.is_set():
616638
self.inst.lock.release()
@@ -619,7 +641,7 @@ async def enter(self, thread):
619641
return True
620642

621643
async def block_on(self, thread, awaitable, cancellable = False, unlock = False) -> Cancelled:
622-
assert(thread is self.thread and thread.task is self)
644+
assert(thread in self.threads and thread.task is self)
623645
f = asyncio.ensure_future(awaitable)
624646
if f.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
625647
return
@@ -633,12 +655,13 @@ async def block_on(self, thread, awaitable, cancellable = False, unlock = False)
633655
await thread.suspend(asyncio.create_task(self.inst.lock.acquire()))
634656

635657
async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> EventTuple:
636-
assert(thread is self.thread and thread.task is self)
658+
assert(thread in self.threads and thread.task is self)
637659
if cancellable and self.deliver_cancel():
638660
return (EventCode.TASK_CANCELLED, 0, 0)
639661
waitable_set.num_waiting += 1
640662
e = None
641663
while not e:
664+
# TODO: somehow get a THREAD_RESUME event...
642665
maybe_event = waitable_set.maybe_has_pending_event.wait()
643666
await self.block_on(thread, maybe_event, cancellable, unlock)
644667
if self.deliver_cancel():
@@ -648,16 +671,17 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
648671
return e
649672

650673
async def yield_(self, thread, cancellable, unlock) -> EventTuple:
651-
assert(thread is self.thread and thread.task is self)
674+
assert(thread in self.threads and thread.task is self)
652675
if cancellable and self.deliver_cancel():
653676
return (EventCode.TASK_CANCELLED, 0, 0)
677+
# TODO: somehow get a THREAD_RESUME event...
654678
await self.block_on(thread, asyncio.sleep(0), cancellable, unlock)
655679
if cancellable and self.deliver_cancel():
656680
return (EventCode.TASK_CANCELLED, 0, 0)
657681
return (EventCode.NONE, 0, 0)
658682

659683
async def poll_for_event(self, thread, waitable_set, cancellable, unlock) -> Optional[EventTuple]:
660-
assert(thread is self.thread and thread.task is self)
684+
assert(thread in self.threads and thread.task is self)
661685
waitable_set.num_waiting += 1
662686
event_code,_,_ = e = await self.yield_(thread, cancellable, unlock)
663687
waitable_set.num_waiting -= 1
@@ -682,11 +706,16 @@ def cancel(self):
682706
self.state = Task.State.RESOLVED
683707

684708
def exit(self):
685-
trap_if(self.state != Task.State.RESOLVED)
686-
assert(self.num_borrows == 0)
687709
if self.needs_lock():
688710
self.inst.lock.release()
689711

712+
def thread_return(self, thread):
713+
assert(thread in self.threads and thread.task is self)
714+
self.threads.remove(thread)
715+
if len(self.threads) == 0:
716+
trap_if(self.state != Task.State.RESOLVED)
717+
assert(self.num_borrows == 0)
718+
690719
#### Subtask State
691720

692721
class Subtask(Waitable):
@@ -1965,7 +1994,7 @@ async def thread_func(task, thread):
19651994
[packed] = await call_and_trap_on_throw(opts.callback, thread, [event_code, p1, p2])
19661995

19671996
task = Task(opts, inst, ft, caller, on_resolve, thread_func)
1968-
await task.thread.resume()
1997+
await task.start()
19691998
return task
19701999

19712000
class CallbackCode(IntEnum):
@@ -2103,25 +2132,76 @@ async def canon_resource_rep(rt, thread, i):
21032132
trap_if(h.rt is not rt)
21042133
return [h.rep]
21052134

2135+
### 🧵 `canon thread.index`
2136+
2137+
async def canon_thread_index(shared, thread):
2138+
assert(not shared)
2139+
return [thread.index]
2140+
2141+
### 🧵 `canon thread.new_indirect`
2142+
2143+
async def canon_thread_new_indirect(ft, ftbl, thread, i, c):
2144+
trap_if(not thread.task.inst.may_leave)
2145+
f = thread.task.inst.ftbl.get(i)
2146+
trap_if(f.type != ft)
2147+
thread = Thread(thread.task, f(c))
2148+
return [thread.index]
2149+
2150+
### 🧵 `canon thread.switch`
2151+
2152+
async def canon_thread_switch(thread, i):
2153+
trap_if(not thread.task.inst.may_leave)
2154+
other = thread.task.inst.table.get(i)
2155+
trap_if(not isinstance(other, Thread))
2156+
cancelled = await thread.switch(other)
2157+
return [ 1 if cancelled else 0 ]
2158+
2159+
### 🧵 `canon thread.yield-to`
2160+
2161+
async def canon_thread_yield_to(thread, i):
2162+
trap_if(not thread.task.inst.may_leave)
2163+
other = thread.task.inst.table.get(i)
2164+
trap_if(not isinstance(other, Thread))
2165+
other.yield_to(other)
2166+
return []
2167+
2168+
### 🧵 `canon thread.block`
2169+
2170+
async def canon_thread_block(thread, i):
2171+
trap_if(not thread.task.inst.may_leave)
2172+
other = thread.task.inst.table.get(i)
2173+
trap_if(not isinstance(other, Thread))
2174+
cancelled = await thread.block()
2175+
return [ 1 if cancelled else 0 ]
2176+
2177+
### 🧵 `canon thread.unblock`
2178+
2179+
async def canon_thread_unblock(thread, i):
2180+
trap_if(not thread.task.inst.may_leave)
2181+
other = thread.task.inst.table.get(i)
2182+
trap_if(not isinstance(other, Thread))
2183+
thread.unblock()
2184+
return []
2185+
21062186
### 🔀 `canon context.get`
21072187

21082188
async def canon_context_get(t, i, thread):
21092189
assert(t == 'i32')
2110-
assert(i < ContextLocalStorage.LENGTH)
2111-
return [thread.task.context.get(i)]
2190+
assert(i < Thread.CONTEXT_LENGTH)
2191+
return [thread.context[i]]
21122192

21132193
### 🔀 `canon context.set`
21142194

21152195
async def canon_context_set(t, i, thread, v):
21162196
assert(t == 'i32')
2117-
assert(i < ContextLocalStorage.LENGTH)
2118-
thread.task.context.set(i, v)
2197+
assert(i < Thread.CONTEXT_LENGTH)
2198+
thread.context[i] = v
21192199
return []
21202200

21212201
### 🔀 `canon backpressure.set`
21222202

21232203
async def canon_backpressure_set(thread, flat_args):
2124-
trap_if(thread.task.opts.sync)
2204+
# TODO: remove trap_if(thread.task.opts.sync)
21252205
assert(len(flat_args) == 1)
21262206
if flat_args[0] == 0:
21272207
thread.task.inst.no_backpressure.set()

0 commit comments

Comments
 (0)