Skip to content

Commit cd89ca3

Browse files
committed
Add cooperative threads
1 parent e8fcd93 commit cd89ca3

File tree

2 files changed

+150
-74
lines changed

2 files changed

+150
-74
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -354,22 +354,6 @@ def write(self, vs):
354354
assert(all(v == () for v in vs))
355355
self.progress += len(vs)
356356

357-
#### Context-Local Storage
358-
359-
class ContextLocalStorage:
360-
LENGTH = 1
361-
array: list[int]
362-
363-
def __init__(self):
364-
self.array = [0] * ContextLocalStorage.LENGTH
365-
366-
def set(self, i, v):
367-
assert(types_match_values(['i32'], [v]))
368-
self.array[i] = v
369-
370-
def get(self, i):
371-
return self.array[i]
372-
373357
#### Waitable State
374358

375359
class EventCode(IntEnum):
@@ -471,7 +455,6 @@ class State(Enum):
471455
supertask: Optional[Task]
472456
on_resolve: Callable[[Optional[list[any]]], None]
473457
num_borrows: int
474-
context: ContextLocalStorage
475458

476459
def __init__(self, opts, inst, ft, supertask, on_resolve):
477460
self.state = Task.State.INITIAL
@@ -481,7 +464,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
481464
self.supertask = supertask
482465
self.on_resolve = on_resolve
483466
self.num_borrows = 0
484-
self.context = ContextLocalStorage()
485467

486468
def trap_if_on_the_stack(self, inst):
487469
c = self.supertask
@@ -872,22 +854,29 @@ def drop(self):
872854

873855
class Thread:
874856
task: Task
857+
index: int
875858
future: Optional[asyncio.Future]
876859
on_resume: Optional[asyncio.Future]
877860
on_suspend_or_exit: Optional[asyncio.Future]
878861
returned: bool
862+
context: list[int]
863+
864+
CONTEXT_LENGTH = 1
879865

880866
def __init__(self, task, coro):
881867
self.task = task
868+
self.index = task.inst.table.add(self)
882869
self.future = None
883870
self.on_resume = asyncio.Future()
884871
self.on_suspend_or_exit = None
885872
self.returned = False
873+
self.context = [0] * Thread.CONTEXT_LENGTH
886874
async def async_impl():
887875
assert(await self.on_resume == Cancelled.FALSE)
888876
self.on_resume = None
889877
await coro
890878
self.on_suspend_or_exit.set_result(None)
879+
self.task.inst.table.remove(self.index)
891880
self.returned = True
892881
asyncio.create_task(async_impl())
893882

@@ -915,6 +904,30 @@ async def suspend(self, future) -> Cancelled:
915904
self.on_resume = None
916905
return cancelled
917906

907+
async def switch(self, other: Thread) -> Cancelled:
908+
assert(not self.future and not other.future)
909+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
910+
other.on_suspend_or_exit = self.on_suspend_or_exit
911+
self.on_suspend_or_exit = None
912+
other.on_resume.set_result(Cancelled.FALSE)
913+
assert(not self.on_resume)
914+
self.on_resume = asyncio.Future()
915+
cancelled = await self.on_resume
916+
self.on_resume = None
917+
return cancelled
918+
919+
def yield_to(self, other: Thread) -> Cancelled:
920+
# deterministically switch to other, but leave this thread unblocked
921+
TODO
922+
923+
def unblock(self, other: Thread):
924+
# unblock other, but deterministically keep running here
925+
TODO
926+
927+
def wait(self) -> Cancelled:
928+
# perform just the first half of switch
929+
TODO
930+
918931
#### Store State / Embedding API
919932

920933
class Store:
@@ -2095,19 +2108,76 @@ async def canon_resource_rep(rt, thread, i):
20952108
trap_if(h.rt is not rt)
20962109
return [h.rep]
20972110

2111+
### 🧵 `canon thread.index`
2112+
2113+
async def canon_thread_index(shared, thread):
2114+
assert(not shared)
2115+
return [thread.index]
2116+
2117+
### 🧵 `canon thread.new_indirect`
2118+
2119+
async def canon_thread_new_indirect(shared, ft, ftbl, thread, i, c):
2120+
assert(not shared)
2121+
inst = thread.task.inst
2122+
trap_if(not inst.may_leave)
2123+
f = ftbl.get(i)
2124+
trap_if(f is None)
2125+
trap_if(f.type != ft)
2126+
thread = Thread(thread.task, f(c))
2127+
return [thread.index]
2128+
2129+
### 🧵 `canon thread.switch`
2130+
2131+
async def canon_thread_switch(shared, thread, i):
2132+
assert(not shared)
2133+
trap_if(not thread.task.inst.may_leave)
2134+
other = thread.task.inst.table.get(i)
2135+
trap_if(not isinstance(other, Thread))
2136+
cancelled = await thread.switch(other)
2137+
return [ 1 if cancelled else 0 ]
2138+
2139+
### 🧵 `canon thread.yield-to`
2140+
2141+
async def canon_thread_yield_to(shared, thread, i):
2142+
assert(not shared)
2143+
trap_if(not thread.task.inst.may_leave)
2144+
other = thread.task.inst.table.get(i)
2145+
trap_if(not isinstance(other, Thread))
2146+
other.yield_to(other)
2147+
return []
2148+
2149+
### 🧵 `canon thread.unblock`
2150+
2151+
async def canon_thread_unblock(shared, thread, i):
2152+
trap_if(not thread.task.inst.may_leave)
2153+
other = thread.task.inst.table.get(i)
2154+
trap_if(not isinstance(other, Thread))
2155+
thread.unblock()
2156+
return []
2157+
2158+
### 🧵 `canon thread.wait`
2159+
2160+
async def canon_thread_wait(shared, thread, i):
2161+
assert(not shared)
2162+
trap_if(not thread.task.inst.may_leave)
2163+
other = thread.task.inst.table.get(i)
2164+
trap_if(not isinstance(other, Thread))
2165+
cancelled = await thread.suspend()
2166+
return [ 1 if cancelled else 0 ]
2167+
20982168
### 🔀 `canon context.get`
20992169

21002170
async def canon_context_get(t, i, thread):
21012171
assert(t == 'i32')
2102-
assert(i < ContextLocalStorage.LENGTH)
2103-
return [thread.task.context.get(i)]
2172+
assert(i < Thread.CONTEXT_LENGTH)
2173+
return [thread.context[i]]
21042174

21052175
### 🔀 `canon context.set`
21062176

21072177
async def canon_context_set(t, i, thread, v):
21082178
assert(t == 'i32')
2109-
assert(i < ContextLocalStorage.LENGTH)
2110-
thread.task.context.set(i, v)
2179+
assert(i < Thread.CONTEXT_LENGTH)
2180+
thread.context[i] = v
21112181
return []
21122182

21132183
### 🔀 `canon backpressure.set`

0 commit comments

Comments
 (0)