Skip to content

Commit b20509d

Browse files
committed
feat: API to avoid deadlocks from dropped promises
1 parent b65715e commit b20509d

23 files changed

+194
-104
lines changed

Diff for: src/Init/System/IO.lean

+6-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,12 @@ protected def TaskState.toString : TaskState → String
238238

239239
instance : ToString TaskState := ⟨TaskState.toString⟩
240240

241-
/-- Returns current state of the `Task` in the Lean runtime's task manager. -/
241+
/--
242+
Returns current state of the `Task` in the Lean runtime's task manager.
243+
244+
Note that for tasks derived from `Promise`s, `waiting` and `running` should be considered
245+
equivalent.
246+
-/
242247
@[extern "lean_io_get_task_state"] opaque getTaskState : @& Task α → BaseIO TaskState
243248

244249
/-- Check if the task has finished execution, at which point calling `Task.get` will return immediately. -/

Diff for: src/Init/System/Promise.lean

+30-11
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ private structure PromiseImpl (α : Type) : Type where
2121
2222
Typical usage is as follows:
2323
1. `let promise ← Promise.new` creates a promise
24-
2. `promise.result : Task α` can now be passed around
25-
3. `promise.result.get` blocks until the promise is resolved
24+
2. `promise.result? : Task (Option α)` can now be passed around
25+
3. `promise.result?.get` blocks until the promise is resolved
2626
4. `promise.resolve a` resolves the promise
27-
5. `promise.result.get` now returns `a`
27+
5. `promise.result?.get` now returns `some a`
2828
29-
Every promise must eventually be resolved.
30-
Otherwise the memory used for the promise will be leaked,
31-
and any tasks depending on the promise's result will wait forever.
29+
If the promise is dropped without ever being resolved, `promise.result.get` will panic and return
30+
`default : α`. See `Promise.result?/resultD` for other ways to handle this case.
3231
-/
3332
def Promise (α : Type) : Type := PromiseImpl α
3433

@@ -47,12 +46,32 @@ Only the first call to this function has an effect.
4746
@[extern "lean_io_promise_resolve"]
4847
opaque Promise.resolve (value : α) (promise : @& Promise α) : BaseIO Unit
4948

49+
/--
50+
Like `Promise.result`, but resolves to `none` if the promise is dropped without ever being resolved.
51+
-/
52+
@[extern "lean_io_promise_result_opt"]
53+
opaque Promise.result? (promise : @& Promise α) : Task (Option α)
54+
55+
-- TODO: publicize?
56+
@[extern "lean_option_get_or_block"]
57+
private opaque Option.getOrBlock! [Nonempty α] : Option α → α
58+
5059
/--
5160
The result task of a `Promise`.
5261
53-
The task blocks until `Promise.resolve` is called.
62+
The task blocks until `Promise.resolve` is called. If the promise is dropped without ever being
63+
resolved, evaluating the task will panic and, when not using fatal panics, block forever. Use
64+
`Promise.result?` to handle this case explicitly.
65+
-/
66+
def Promise.result! (promise : @& Promise α) : Task α :=
67+
let _ : Nonempty α := promise.h
68+
promise.result?.map (sync := true) Option.getOrBlock!
69+
70+
@[inherit_doc Promise.result!, deprecated Promise.result! (since := "2025-02-05")]
71+
def Promise.result := @Promise.result!
72+
73+
/--
74+
Like `Promise.result`, but resolves to `dflt` if the promise is dropped without ever being resolved.
5475
-/
55-
@[extern "lean_io_promise_result"]
56-
opaque Promise.result (promise : Promise α) : Task α :=
57-
have : Nonempty α := promise.h
58-
Classical.choice inferInstance
76+
def Promise.resultD (promise : Promise α) (dflt : α): Task α :=
77+
promise.result?.map (sync := true) (·.getD dflt)

Diff for: src/Lean/AddDecl.lean

+3-6
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,9 @@ def addDecl (decl : Declaration) : CoreM Unit := do
7777
async.commitConst async.asyncEnv (some info)
7878
setEnv async.mainEnv
7979
let checkAct ← Core.wrapAsyncAsSnapshot fun _ => do
80-
try
81-
setEnv async.asyncEnv
82-
doAdd
83-
async.commitCheckEnv (← getEnv)
84-
finally
85-
async.commitFailure
80+
setEnv async.asyncEnv
81+
doAdd
82+
async.commitCheckEnv (← getEnv)
8683
let t ← BaseIO.mapTask (fun _ => checkAct) env.checked
8784
let endRange? := (← getRef).getTailPos?.map fun pos => ⟨pos, pos⟩
8885
Core.logSnapshotTask { range? := endRange?, task := t }

Diff for: src/Lean/Elab/Command.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ partial def elabCommand (stx : Syntax) : CommandElabM Unit := do
497497
newNextMacroScope := nextMacroScope
498498
hasTraces
499499
next := Array.zipWith (fun cmdPromise cmd =>
500-
{ range? := cmd.getRange?, task := cmdPromise.result }) cmdPromises cmds
500+
{ range? := cmd.getRange?, task := cmdPromise.resultD default }) cmdPromises cmds
501501
: MacroExpandedSnapshot
502502
}
503503
-- After the first command whose syntax tree changed, we must disable

Diff for: src/Lean/Elab/MutualDef.lean

+3-3
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ private def elabHeaders (views : Array DefView)
222222
view.ref.getPos?.map fun pos => ⟨pos, pos⟩
223223
else
224224
getBodyTerm? view.value |>.getD view.value |>.getRange?
225-
task := bodyPromise.result }
225+
task := bodyPromise.resultD default }
226226
snap.new.resolve <| some {
227227
diagnostics :=
228228
(← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog))
@@ -263,7 +263,7 @@ where
263263
:= do
264264
if let some e := getBodyTerm? body then
265265
if let `(by $tacs*) := e then
266-
return (e, some { range? := mkNullNode tacs |>.getRange?, task := tacPromise.result })
266+
return (e, some { range? := mkNullNode tacs |>.getRange?, task := tacPromise.resultD default })
267267
tacPromise.resolve default
268268
return (none, none)
269269

@@ -1077,7 +1077,7 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
10771077
} }
10781078
defs := defs.push {
10791079
fullHeaderRef
1080-
headerProcessedSnap := { range? := d.getRange?, task := headerPromise.result }
1080+
headerProcessedSnap := { range? := d.getRange?, task := headerPromise.resultD default }
10811081
}
10821082
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
10831083
views := views.push view

Diff for: src/Lean/Elab/Tactic/Basic.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ where
234234
diagnostics := .empty
235235
state? := (← Tactic.saveState)
236236
}
237-
next := #[{ range? := stx'.getRange?, task := promise.result }]
237+
next := #[{ range? := stx'.getRange?, task := promise.resultD default }]
238238
}
239239
-- Update `tacSnap?` to old unfolding
240240
withTheReader Term.Context ({ · with tacSnap? := some {

Diff for: src/Lean/Elab/Tactic/BuiltinTactic.lean

+3-3
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ where
9393
desc := tac.getKind.toString
9494
diagnostics := .empty
9595
stx := tac
96-
inner? := some { range?, task := inner.result }
97-
finished := { range?, task := finished.result }
98-
next := #[{ range? := stxs.getRange?, task := next.result }]
96+
inner? := some { range?, task := inner.resultD default }
97+
finished := { range?, task := finished.resultD default }
98+
next := #[{ range? := stxs.getRange?, task := next.resultD default }]
9999
}
100100
-- Run `tac` in a fresh info tree state and store resulting state in snapshot for
101101
-- incremental reporting, then add back saved trees. Here we rely on `evalTactic`

Diff for: src/Lean/Elab/Tactic/Induction.lean

+2-2
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ where
285285
stx := mkNullNode altStxs
286286
diagnostics := .empty
287287
inner? := none
288-
finished := { range? := none, task := finished.result }
288+
finished := { range? := none, task := finished.resultD default }
289289
next := Array.zipWith
290-
(fun stx prom => { range? := stx.getRange?, task := prom.result })
290+
(fun stx prom => { range? := stx.getRange?, task := prom.resultD default })
291291
altStxs altPromises
292292
}
293293
goWithIncremental <| altPromises.mapIdx fun i prom => {

Diff for: src/Lean/Environment.lean

+32-37
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ private def setCheckedSync (env : Environment) (newChecked : Kernel.Environment)
463463

464464
def promiseChecked (env : Environment) : BaseIO (Environment × IO.Promise Environment) := do
465465
let prom ← IO.Promise.new
466-
return ({ env with checked := prom.result.bind (sync := true) (·.checked) }, prom)
466+
return ({ env with checked := prom.result?.bind (sync := true) (·.getD env |>.checked) }, prom)
467467

468468
/--
469469
Checks whether the given declaration name may potentially added, or have been added, to the current
@@ -607,8 +607,9 @@ structure AddConstAsyncResult where
607607
/--
608608
Resulting "async branch" environment which should be used to add the desired declaration in a new
609609
task and then call `AddConstAsyncResult.commit*` to commit results back to the main environment.
610-
One of `commitCheckEnv` or `commitFailure` must be called eventually to prevent deadlocks on main
611-
branch accesses.
610+
`commitCheckEnv` completes the addition; if it is not called and the `AddConstAsyncResult` object
611+
is dropped, `sorry`ed default values will be reported instead and the kernel environment will be
612+
left unchanged.
612613
-/
613614
asyncEnv : Environment
614615
private constName : Name
@@ -632,20 +633,43 @@ def addConstAsync (env : Environment) (constName : Name) (kind : ConstantKind) (
632633
let infoPromise ← IO.Promise.new
633634
let extensionsPromise ← IO.Promise.new
634635
let checkedEnvPromise ← IO.Promise.new
636+
637+
-- fallback info in case promises are dropped unfulfilled
638+
let fallbackVal := {
639+
name := constName
640+
levelParams := []
641+
type := mkApp2 (mkConst ``sorryAx [0]) (mkSort 0) (mkConst ``true)
642+
}
643+
let fallbackInfo := match kind with
644+
| .defn => .defnInfo { fallbackVal with
645+
value := mkApp2 (mkConst ``sorryAx [0]) fallbackVal.type (mkConst ``true)
646+
hints := .abbrev
647+
safety := .safe
648+
}
649+
| .thm => .thmInfo { fallbackVal with
650+
value := mkApp2 (mkConst ``sorryAx [0]) fallbackVal.type (mkConst ``true)
651+
}
652+
| .axiom => .axiomInfo { fallbackVal with
653+
isUnsafe := false
654+
}
655+
| k => panic! s!"AddConstAsyncResult.addConstAsync: unsupported constant kind {repr k}"
656+
635657
let asyncConst := {
636658
constInfo := {
637659
name := constName
638660
kind
639-
sig := sigPromise.result
640-
constInfo := infoPromise.result
661+
sig := sigPromise.resultD fallbackVal
662+
constInfo := infoPromise.resultD fallbackInfo
641663
}
642-
exts? := guard reportExts *> some extensionsPromise.result
664+
exts? := guard reportExts *> some (extensionsPromise.resultD #[])
643665
}
644666
return {
645667
constName, kind
646668
mainEnv := { env with
647669
asyncConsts := env.asyncConsts.add asyncConst
648-
checked := checkedEnvPromise.result }
670+
checked := checkedEnvPromise.result?.bind (sync := true) fun
671+
| some kenv => .pure kenv
672+
| none => env.checked }
649673
asyncEnv := { env with
650674
asyncCtx? := some { declPrefix := privateToUserName constName.eraseMacroScopes }
651675
}
@@ -679,43 +703,14 @@ def AddConstAsyncResult.commitConst (res : AddConstAsyncResult) (env : Environme
679703
let kind' := .ofConstantInfo info
680704
if res.kind != kind' then
681705
throw <| .userError s!"AddConstAsyncResult.commitConst: constant has kind {repr kind'} but expected {repr res.kind}"
682-
let sig := res.sigPromise.result.get
706+
let sig := res.sigPromise.result!.get
683707
if sig.levelParams != info.levelParams then
684708
throw <| .userError s!"AddConstAsyncResult.commitConst: constant has level params {info.levelParams} but expected {sig.levelParams}"
685709
if sig.type != info.type then
686710
throw <| .userError s!"AddConstAsyncResult.commitConst: constant has type {info.type} but expected {sig.type}"
687711
res.infoPromise.resolve info
688712
res.extensionsPromise.resolve env.checkedWithoutAsync.extensions
689713

690-
/--
691-
Aborts async addition, filling in missing information with default values/sorries and leaving the
692-
kernel environment unchanged.
693-
-/
694-
def AddConstAsyncResult.commitFailure (res : AddConstAsyncResult) : BaseIO Unit := do
695-
let val := if (← IO.hasFinished res.sigPromise.result) then
696-
res.sigPromise.result.get
697-
else {
698-
name := res.constName
699-
levelParams := []
700-
type := mkApp2 (mkConst ``sorryAx [0]) (mkSort 0) (mkConst ``true)
701-
}
702-
res.sigPromise.resolve val
703-
res.infoPromise.resolve <| match res.kind with
704-
| .defn => .defnInfo { val with
705-
value := mkApp2 (mkConst ``sorryAx [0]) val.type (mkConst ``true)
706-
hints := .abbrev
707-
safety := .safe
708-
}
709-
| .thm => .thmInfo { val with
710-
value := mkApp2 (mkConst ``sorryAx [0]) val.type (mkConst ``true)
711-
}
712-
| .axiom => .axiomInfo { val with
713-
isUnsafe := false
714-
}
715-
| k => panic! s!"AddConstAsyncResult.commitFailure: unsupported constant kind {repr k}"
716-
res.extensionsPromise.resolve #[]
717-
let _ ← BaseIO.mapTask (t := res.asyncEnv.checked) (sync := true) res.checkedEnvPromise.resolve
718-
719714
/--
720715
Assuming `Lean.addDecl` has been run for the constant to be added on the async environment branch,
721716
commits the full constant info from that call to the main environment, waits for the final kernel

Diff for: src/Lean/Language/Lean.lean

+9-9
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ where
467467
infoTree? := cmdState.infoState.trees[0]!
468468
result? := some {
469469
cmdState
470-
firstCmdSnap := { range? := none, task := prom.result }
470+
firstCmdSnap := { range? := none, task := prom.result! }
471471
}
472472
}
473473

@@ -489,7 +489,7 @@ where
489489
oldNext.bindIO (sync := true) fun oldNext => do
490490
parseCmd oldNext newParserState oldFinished.cmdState newProm sync ctx
491491
return .pure ()
492-
prom.resolve <| { old with nextCmdSnap? := some { range? := none, task := newProm.result } }
492+
prom.resolve <| { old with nextCmdSnap? := some { range? := none, task := newProm.result! } }
493493
else prom.resolve old -- terminal command, we're done!
494494

495495
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
@@ -548,15 +548,15 @@ where
548548
-- report terminal tasks on first line of decl such as not to hide incremental tactics'
549549
-- progress
550550
let initRange? := getNiceCommandStartPos? stx |>.map fun pos => ⟨pos, pos⟩
551-
let finishedSnap := { range? := initRange?, task := finishedPromise.result }
551+
let finishedSnap := { range? := initRange?, task := finishedPromise.result! }
552552
let tacticCache ← old?.map (·.tacticCache) |>.getDM (IO.mkRef {})
553553

554554
let minimalSnapshots := internal.cmdlineSnapshots.get cmdState.scopes.head!.opts
555555
let next? ← if Parser.isTerminalCommand stx then pure none
556556
-- for now, wait on "command finished" snapshot before parsing next command
557557
else some <$> IO.Promise.new
558558
let nextCmdSnap? := next?.map
559-
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })
559+
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result! })
560560
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
561561
let (stx', parserState') := if minimalSnapshots && !Parser.isTerminalCommand stx then
562562
(default, default)
@@ -565,8 +565,8 @@ where
565565
prom.resolve {
566566
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
567567
stx := stx', parserState := parserState'
568-
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
569-
reportSnap := { range? := initRange?, task := reportPromise.result }
568+
elabSnap := { range? := stx.getRange?, task := elabPromise.result! }
569+
reportSnap := { range? := initRange?, task := reportPromise.result! }
570570
}
571571
let cmdState ← doElab stx cmdState beginPos
572572
{ old? := old?.map fun old => ⟨old.stx, old.elabSnap⟩, new := elabPromise }
@@ -576,8 +576,8 @@ where
576576
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
577577
-- create a temporary snapshot tree containing all tasks but it
578578
let snaps := #[
579-
{ range? := none, task := elabPromise.result.map (sync := true) toSnapshotTree },
580-
{ range? := none, task := finishedPromise.result.map (sync := true) toSnapshotTree }] ++
579+
{ range? := none, task := elabPromise.result!.map (sync := true) toSnapshotTree },
580+
{ range? := none, task := finishedPromise.result!.map (sync := true) toSnapshotTree }] ++
581581
cmdState.snapshotTasks
582582
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
583583
BaseIO.bindTask (← tree.waitAll) fun _ => do
@@ -672,7 +672,7 @@ def processCommands (inputCtx : Parser.InputContext) (parserState : Parser.Modul
672672
process.parseCmd (old?.map (·.2)) parserState commandState prom (sync := true)
673673
|>.run (old?.map (·.1))
674674
|>.run { inputCtx with }
675-
return prom.result
675+
return prom.result!
676676

677677
/-- Waits for and returns final command state, if importing was successful. -/
678678
partial def waitForFinalCmdState? (snap : InitialSnapshot) : Option Command.State := do

Diff for: src/Lean/Server/FileWorker.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ section Initialization
379379
let processor := Language.Lean.process (setupImports meta opts chanOut srcSearchPathPromise)
380380
let processor ← Language.mkIncrementalProcessor processor
381381
let initSnap ← processor meta.mkInputContext
382-
let _ ← IO.mapTask (t := srcSearchPathPromise.result) fun srcSearchPath => do
382+
let _ ← IO.mapTask (t := srcSearchPathPromise.result!) fun srcSearchPath => do
383383
let importClosure := getImportClosure? initSnap
384384
let importClosure ← importClosure.filterMapM (documentUriFromModule srcSearchPath ·)
385385
chanOut.send <| mkImportClosureNotification importClosure

Diff for: src/Std/Internal/Async/Basic.lean

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ Create an `AsyncTask` that resolves to the value of `x`.
9191
-/
9292
@[inline]
9393
def ofPromise (x : IO.Promise (Except IO.Error α)) : AsyncTask α :=
94-
x.result
94+
x.result!
9595

9696
/--
9797
Create an `AsyncTask` that resolves to the value of `x`.
9898
-/
9999
@[inline]
100100
def ofPurePromise (x : IO.Promise α) : AsyncTask α :=
101-
x.result.map pure
101+
x.result!.map pure
102102

103103
/--
104104
Obtain the `IO.TaskState` of `x`.

Diff for: src/Std/Sync/Channel.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
7676
else if !st.closed then
7777
let promise ← IO.Promise.new
7878
set { st with consumers := st.consumers.enqueue promise }
79-
return promise.result
79+
return promise.result?.map (sync := true) (·.bind id)
8080
else
8181
return .pure none
8282

Diff for: src/bin/lean-gdb.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ class LeanObjectPrinter:
4949
"""Print a lean_object object."""
5050

5151
kinds = [
52-
# 244, ...
52+
# 243, ...
5353
('ctor', []),
54+
('promise', ['m_result']),
5455
('closure', ['m_arity', 'm_fun', 'm_num_fixed']),
5556
('array', ['m_size', 'm_capacity']),
5657
('sarray', ['m_size', 'm_capacity']),
@@ -62,7 +63,7 @@ class LeanObjectPrinter:
6263
('ref', ['m_value']),
6364
('external', ['m_class', 'm_data']),
6465
]
65-
lean_max_ctor_tag = 244
66+
lean_max_ctor_tag = 243
6667

6768
def __init__(self, val):
6869
self.val = val.address

0 commit comments

Comments
 (0)