Skip to content

Commit b169a61

Browse files
committed
Add cooperative threads
1 parent e574cae commit b169a61

File tree

3 files changed

+150
-76
lines changed

3 files changed

+150
-76
lines changed

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,28 @@ def tick(self):
215215

216216
class Thread:
217217
task: Task
218+
index: int
218219
future: Optional[asyncio.Future]
219220
on_resume: Optional[asyncio.Future]
220221
on_suspend_or_exit: Optional[asyncio.Future]
222+
context: list[int]
223+
224+
CONTEXT_LENGTH = 2
221225

222226
def __init__(self, task, thread_func):
223227
self.task = task
228+
self.index = task.inst.table.add(self)
224229
self.future = None
225230
self.on_resume = asyncio.Future()
226231
self.on_suspend_or_exit = None
232+
self.context = [0] * Thread.CONTEXT_LENGTH
227233
async def thread_start():
228234
await self.on_resume
229235
self.on_resume = None
230236
await thread_func(task, self)
231237
self.on_suspend_or_exit.set_result(None)
232238
self.task.thread = None
239+
self.task.inst.table.remove(self.index)
233240
asyncio.create_task(thread_start())
234241

235242
async def resume(self):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254261
await self.on_resume
255262
self.on_resume = None
256263

264+
async def switch(self, other: Thread):
265+
assert(not self.future and not other.future)
266+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
267+
other.on_suspend_or_exit = self.on_suspend_or_exit
268+
self.on_suspend_or_exit = None
269+
other.on_resume.set_result(Cancelled.FALSE)
270+
assert(not self.on_resume)
271+
self.on_resume = asyncio.Future()
272+
await self.on_resume
273+
self.on_resume = None
274+
275+
def yield_to(self, other: Thread):
276+
# deterministically switch to other, but leave this thread unblocked
277+
TODO
278+
279+
def block(self):
280+
# perform just the first half of switch
281+
TODO
282+
283+
def unblock(self, other: Thread):
284+
# unblock other, but deterministically keep running here
285+
TODO
286+
257287

258288
### Lifting and Lowering Context
259289

@@ -431,22 +461,6 @@ def write(self, vs):
431461
assert(all(v == () for v in vs))
432462
self.progress += len(vs)
433463

434-
#### Context-Local Storage
435-
436-
class ContextLocalStorage:
437-
LENGTH = 1
438-
array: list[int]
439-
440-
def __init__(self):
441-
self.array = [0] * ContextLocalStorage.LENGTH
442-
443-
def set(self, i, v):
444-
assert(types_match_values(['i32'], [v]))
445-
self.array[i] = v
446-
447-
def get(self, i):
448-
return self.array[i]
449-
450464
#### Waitable State
451465

452466
class EventCode(IntEnum):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457471
FUTURE_READ = 4
458472
FUTURE_WRITE = 5
459473
TASK_CANCELLED = 6
474+
THREAD_RESUMED = 7
460475

461476
EventTuple = tuple[EventCode, int, int]
462477

@@ -546,7 +561,6 @@ class State(Enum):
546561
thread: Thread
547562
cancellable: bool
548563
num_borrows: int
549-
context: ContextLocalStorage
550564

551565
def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
552566
self.state = Task.State.INITIAL
@@ -558,7 +572,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
558572
self.thread = Thread(self, thread_func)
559573
self.cancellable = False
560574
self.num_borrows = 0
561-
self.context = ContextLocalStorage()
562575

563576
def trap_if_on_the_stack(self, inst):
564577
c = self.supertask
@@ -638,6 +651,7 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
638651
waitable_set.num_waiting += 1
639652
e = None
640653
while not e:
654+
# TODO: somehow get a THREAD_RESUME event...
641655
maybe_event = waitable_set.maybe_has_pending_event.wait()
642656
await self.block_on(thread, maybe_event, cancellable, unlock)
643657
if self.deliver_cancel():
@@ -650,6 +664,7 @@ async def yield_(self, thread, cancellable, unlock) -> EventTuple:
650664
assert(self.thread is thread and self is thread.task)
651665
if cancellable and self.deliver_cancel():
652666
return (EventCode.TASK_CANCELLED, 0, 0)
667+
# TODO: somehow get a THREAD_RESUME event...
653668
await self.block_on(thread, asyncio.sleep(0), cancellable, unlock)
654669
if cancellable and self.deliver_cancel():
655670
return (EventCode.TASK_CANCELLED, 0, 0)
@@ -681,7 +696,7 @@ def cancel(self):
681696
self.state = Task.State.RESOLVED
682697

683698
def exit(self):
684-
trap_if(self.state != Task.State.RESOLVED)
699+
trap_if(self.state != Task.State.RESOLVED) # TODO: move this to empty-threads case
685700
assert(self.num_borrows == 0)
686701
if self.needs_lock():
687702
self.inst.lock.release()
@@ -2102,25 +2117,76 @@ async def canon_resource_rep(rt, thread, i):
21022117
trap_if(h.rt is not rt)
21032118
return [h.rep]
21042119

2120+
### 🧵 `canon thread.index`
2121+
2122+
async def canon_thread_index(shared, thread):
2123+
assert(not shared)
2124+
return [thread.index]
2125+
2126+
### 🧵 `canon thread.new_indirect`
2127+
2128+
async def canon_thread_new_indirect(ft, ftbl, thread, i, c):
2129+
trap_if(not thread.task.inst.may_leave)
2130+
f = thread.task.inst.ftbl.get(i)
2131+
trap_if(f.type != ft)
2132+
thread = Thread(thread.task, f(c))
2133+
return [thread.index]
2134+
2135+
### 🧵 `canon thread.switch`
2136+
2137+
async def canon_thread_switch(thread, i):
2138+
trap_if(not thread.task.inst.may_leave)
2139+
other = thread.task.inst.table.get(i)
2140+
trap_if(not isinstance(other, Thread))
2141+
cancelled = await thread.switch(other)
2142+
return [ 1 if cancelled else 0 ]
2143+
2144+
### 🧵 `canon thread.yield-to`
2145+
2146+
async def canon_thread_yield_to(thread, i):
2147+
trap_if(not thread.task.inst.may_leave)
2148+
other = thread.task.inst.table.get(i)
2149+
trap_if(not isinstance(other, Thread))
2150+
other.yield_to(other)
2151+
return []
2152+
2153+
### 🧵 `canon thread.block`
2154+
2155+
async def canon_thread_block(thread, i):
2156+
trap_if(not thread.task.inst.may_leave)
2157+
other = thread.task.inst.table.get(i)
2158+
trap_if(not isinstance(other, Thread))
2159+
cancelled = await thread.block()
2160+
return [ 1 if cancelled else 0 ]
2161+
2162+
### 🧵 `canon thread.unblock`
2163+
2164+
async def canon_thread_unblock(thread, i):
2165+
trap_if(not thread.task.inst.may_leave)
2166+
other = thread.task.inst.table.get(i)
2167+
trap_if(not isinstance(other, Thread))
2168+
thread.unblock()
2169+
return []
2170+
21052171
### 🔀 `canon context.get`
21062172

21072173
async def canon_context_get(t, i, thread):
21082174
assert(t == 'i32')
2109-
assert(i < ContextLocalStorage.LENGTH)
2110-
return [thread.task.context.get(i)]
2175+
assert(i < Thread.CONTEXT_LENGTH)
2176+
return [thread.context[i]]
21112177

21122178
### 🔀 `canon context.set`
21132179

21142180
async def canon_context_set(t, i, thread, v):
21152181
assert(t == 'i32')
2116-
assert(i < ContextLocalStorage.LENGTH)
2117-
thread.task.context.set(i, v)
2182+
assert(i < Thread.CONTEXT_LENGTH)
2183+
thread.context[i] = v
21182184
return []
21192185

21202186
### 🔀 `canon backpressure.set`
21212187

21222188
async def canon_backpressure_set(thread, flat_args):
2123-
trap_if(thread.task.opts.sync)
2189+
# TODO: remove trap_if(thread.task.opts.sync)
21242190
assert(len(flat_args) == 1)
21252191
if flat_args[0] == 0:
21262192
thread.task.inst.no_backpressure.set()

0 commit comments

Comments
 (0)