Skip to content

Commit 7c79f05

Browse files
authored
feat: API to avoid deadlocks from dropped promises (#6958)
This PR improves the `Promise` API by considering how dropped promises can lead to never-finished tasks.
1 parent 1248a55 commit 7c79f05

26 files changed

+397
-320
lines changed

Diff for: src/Init/System/IO.lean

+6-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,12 @@ protected def TaskState.toString : TaskState → String
238238

239239
instance : ToString TaskState := ⟨TaskState.toString⟩
240240

241-
/-- Returns current state of the `Task` in the Lean runtime's task manager. -/
241+
/--
242+
Returns current state of the `Task` in the Lean runtime's task manager.
243+
244+
Note that for tasks derived from `Promise`s, `waiting` and `running` should be considered
245+
equivalent.
246+
-/
242247
@[extern "lean_io_get_task_state"] opaque getTaskState : @& Task α → BaseIO TaskState
243248

244249
/-- Check if the task has finished execution, at which point calling `Task.get` will return immediately. -/

Diff for: src/Init/System/Promise.lean

+30-11
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ private structure PromiseImpl (α : Type) : Type where
2121
2222
Typical usage is as follows:
2323
1. `let promise ← Promise.new` creates a promise
24-
2. `promise.result : Task α` can now be passed around
25-
3. `promise.result.get` blocks until the promise is resolved
24+
2. `promise.result? : Task (Option α)` can now be passed around
25+
3. `promise.result?.get` blocks until the promise is resolved
2626
4. `promise.resolve a` resolves the promise
27-
5. `promise.result.get` now returns `a`
27+
5. `promise.result?.get` now returns `some a`
2828
29-
Every promise must eventually be resolved.
30-
Otherwise the memory used for the promise will be leaked,
31-
and any tasks depending on the promise's result will wait forever.
29+
If the promise is dropped without ever being resolved, `promise.result?.get` will return `none`.
30+
See `Promise.result!/resultD` for other ways to handle this case.
3231
-/
3332
def Promise (α : Type) : Type := PromiseImpl α
3433

@@ -47,12 +46,32 @@ Only the first call to this function has an effect.
4746
@[extern "lean_io_promise_resolve"]
4847
opaque Promise.resolve (value : α) (promise : @& Promise α) : BaseIO Unit
4948

49+
/--
50+
Like `Promise.result`, but resolves to `none` if the promise is dropped without ever being resolved.
51+
-/
52+
@[extern "lean_io_promise_result_opt"]
53+
opaque Promise.result? (promise : @& Promise α) : Task (Option α)
54+
55+
-- SU: not planning to make this public without a lot more thought and motivation
56+
@[extern "lean_option_get_or_block"]
57+
private opaque Option.getOrBlock! [Nonempty α] : Option α → α
58+
5059
/--
5160
The result task of a `Promise`.
5261
53-
The task blocks until `Promise.resolve` is called.
62+
The task blocks until `Promise.resolve` is called. If the promise is dropped without ever being
63+
resolved, evaluating the task will panic and, when not using fatal panics, block forever. Use
64+
`Promise.result?` to handle this case explicitly.
65+
-/
66+
def Promise.result! (promise : @& Promise α) : Task α :=
67+
let _ : Nonempty α := promise.h
68+
promise.result?.map (sync := true) Option.getOrBlock!
69+
70+
@[inherit_doc Promise.result!, deprecated Promise.result! (since := "2025-02-05")]
71+
def Promise.result := @Promise.result!
72+
73+
/--
74+
Like `Promise.result`, but resolves to `dflt` if the promise is dropped without ever being resolved.
5475
-/
55-
@[extern "lean_io_promise_result"]
56-
opaque Promise.result (promise : Promise α) : Task α :=
57-
have : Nonempty α := promise.h
58-
Classical.choice inferInstance
76+
def Promise.resultD (promise : Promise α) (dflt : α): Task α :=
77+
promise.result?.map (sync := true) (·.getD dflt)

Diff for: src/Lean/AddDecl.lean

+3-6
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,9 @@ def addDecl (decl : Declaration) : CoreM Unit := do
7777
async.commitConst async.asyncEnv (some info)
7878
setEnv async.mainEnv
7979
let checkAct ← Core.wrapAsyncAsSnapshot fun _ => do
80-
try
81-
setEnv async.asyncEnv
82-
doAdd
83-
async.commitCheckEnv (← getEnv)
84-
finally
85-
async.commitFailure
80+
setEnv async.asyncEnv
81+
doAdd
82+
async.commitCheckEnv (← getEnv)
8683
let t ← BaseIO.mapTask (fun _ => checkAct) env.checked
8784
let endRange? := (← getRef).getTailPos?.map fun pos => ⟨pos, pos⟩
8885
Core.logSnapshotTask { range? := endRange?, task := t }

Diff for: src/Lean/Elab/Command.lean

+32-32
Original file line numberDiff line numberDiff line change
@@ -489,38 +489,38 @@ partial def elabCommand (stx : Syntax) : CommandElabM Unit := do
489489
return oldSnap
490490
let oldCmds? := oldSnap?.map fun old =>
491491
if old.newStx.isOfKind nullKind then old.newStx.getArgs else #[old.newStx]
492-
Language.withAlwaysResolvedPromises cmds.size fun cmdPromises => do
493-
snap.new.resolve <| .ofTyped {
494-
diagnostics := .empty
495-
macroDecl := decl
496-
newStx := stxNew
497-
newNextMacroScope := nextMacroScope
498-
hasTraces
499-
next := Array.zipWith (fun cmdPromise cmd =>
500-
{ range? := cmd.getRange?, task := cmdPromise.result }) cmdPromises cmds
501-
: MacroExpandedSnapshot
502-
}
503-
-- After the first command whose syntax tree changed, we must disable
504-
-- incremental reuse
505-
let mut reusedCmds := true
506-
let opts ← getOptions
507-
-- For each command, associate it with new promise and old snapshot, if any, and
508-
-- elaborate recursively
509-
for cmd in cmds, cmdPromise in cmdPromises, i in [0:cmds.size] do
510-
let oldCmd? := oldCmds?.bind (·[i]?)
511-
withReader ({ · with snap? := some {
512-
new := cmdPromise
513-
old? := do
514-
guard reusedCmds
515-
let old ← oldSnap?
516-
return { stx := (← oldCmd?), val := (← old.next[i]?) }
517-
} }) do
518-
elabCommand cmd
519-
-- Resolve promise for commands not supporting incrementality; waiting for
520-
-- `withAlwaysResolvedPromises` to do this could block reporting by later
521-
-- commands
522-
cmdPromise.resolve default
523-
reusedCmds := reusedCmds && oldCmd?.any (·.eqWithInfoAndTraceReuse opts cmd)
492+
let cmdPromises ← cmds.mapM fun _ => IO.Promise.new
493+
snap.new.resolve <| .ofTyped {
494+
diagnostics := .empty
495+
macroDecl := decl
496+
newStx := stxNew
497+
newNextMacroScope := nextMacroScope
498+
hasTraces
499+
next := Array.zipWith (fun cmdPromise cmd =>
500+
{ range? := cmd.getRange?, task := cmdPromise.resultD default }) cmdPromises cmds
501+
: MacroExpandedSnapshot
502+
}
503+
-- After the first command whose syntax tree changed, we must disable
504+
-- incremental reuse
505+
let mut reusedCmds := true
506+
let opts ← getOptions
507+
-- For each command, associate it with new promise and old snapshot, if any, and
508+
-- elaborate recursively
509+
for cmd in cmds, cmdPromise in cmdPromises, i in [0:cmds.size] do
510+
let oldCmd? := oldCmds?.bind (·[i]?)
511+
withReader ({ · with snap? := some {
512+
new := cmdPromise
513+
old? := do
514+
guard reusedCmds
515+
let old ← oldSnap?
516+
return { stx := (← oldCmd?), val := (← old.next[i]?) }
517+
} }) do
518+
elabCommand cmd
519+
-- Resolve promise for commands not supporting incrementality; waiting for
520+
-- `withAlwaysResolvedPromises` to do this could block reporting by later
521+
-- commands
522+
cmdPromise.resolve default
523+
reusedCmds := reusedCmds && oldCmd?.any (·.eqWithInfoAndTraceReuse opts cmd)
524524
else
525525
elabCommand stxNew
526526
| _ =>

Diff for: src/Lean/Elab/MutualDef.lean

+80-80
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ where
249249
mkBodyTask (body : Syntax) (new : IO.Promise (Option BodyProcessedSnapshot)) :
250250
Language.SnapshotTask (Option BodyProcessedSnapshot) :=
251251
let rangeStx := getBodyTerm? body |>.getD body
252-
{ range? := rangeStx.getRange?, task := new.result }
252+
{ range? := rangeStx.getRange?, task := new.resultD default }
253253

254254
/--
255255
If `body` allows for incremental tactic reporting and reuse, creates a snapshot task out of the
@@ -261,7 +261,7 @@ where
261261
:= do
262262
if let some e := getBodyTerm? body then
263263
if let `(by $tacs*) := e then
264-
return (e, some { range? := mkNullNode tacs |>.getRange?, task := tacPromise.result })
264+
return (e, some { range? := mkNullNode tacs |>.getRange?, task := tacPromise.resultD default })
265265
tacPromise.resolve default
266266
return (none, none)
267267

@@ -1005,45 +1005,45 @@ def elabMutualDef (vars : Array Expr) (sc : Command.Scope) (views : Array DefVie
10051005
else
10061006
go
10071007
where
1008-
go :=
1009-
withAlwaysResolvedPromises views.size fun bodyPromises =>
1010-
withAlwaysResolvedPromises views.size fun tacPromises => do
1011-
let scopeLevelNames ← getLevelNames
1012-
let headers ← elabHeaders views bodyPromises tacPromises
1013-
let headers ← levelMVarToParamHeaders views headers
1014-
let allUserLevelNames := getAllUserLevelNames headers
1015-
withFunLocalDecls headers fun funFVars => do
1016-
for view in views, funFVar in funFVars do
1017-
addLocalVarInfo view.declId funFVar
1018-
let values ←
1019-
try
1020-
let values ← elabFunValues headers vars sc
1021-
Term.synthesizeSyntheticMVarsNoPostponing
1022-
values.mapM (instantiateMVarsProfiling ·)
1023-
catch ex =>
1024-
logException ex
1025-
headers.mapM fun header => withRef header.declId <| mkLabeledSorry header.type (synthetic := true) (unique := true)
1026-
let headers ← headers.mapM instantiateMVarsAtHeader
1027-
let letRecsToLift ← getLetRecsToLift
1028-
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
1029-
checkLetRecsToLiftTypes funFVars letRecsToLift
1030-
(if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars sc headers else withUsed vars headers values letRecsToLift) fun vars => do
1031-
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
1032-
checkAllDeclNamesDistinct preDefs
1033-
for preDef in preDefs do
1034-
trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
1035-
let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamTypesPreDecls preDefs
1036-
let preDefs ← instantiateMVarsAtPreDecls preDefs
1037-
let preDefs ← shareCommonPreDefs preDefs
1038-
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
1039-
for preDef in preDefs do
1040-
trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
1041-
addPreDefinitions preDefs
1042-
processDeriving headers
1043-
for view in views, header in headers do
1044-
-- NOTE: this should be the full `ref`, and thus needs to be done after any snapshotting
1045-
-- that depends only on a part of the ref
1046-
addDeclarationRangesForBuiltin header.declName view.modifiers.stx view.ref
1008+
go := do
1009+
let bodyPromises ← views.mapM fun _ => IO.Promise.new
1010+
let tacPromises ← views.mapM fun _ => IO.Promise.new
1011+
let scopeLevelNames ← getLevelNames
1012+
let headers ← elabHeaders views bodyPromises tacPromises
1013+
let headers ← levelMVarToParamHeaders views headers
1014+
let allUserLevelNames := getAllUserLevelNames headers
1015+
withFunLocalDecls headers fun funFVars => do
1016+
for view in views, funFVar in funFVars do
1017+
addLocalVarInfo view.declId funFVar
1018+
let values ←
1019+
try
1020+
let values ← elabFunValues headers vars sc
1021+
Term.synthesizeSyntheticMVarsNoPostponing
1022+
values.mapM (instantiateMVarsProfiling ·)
1023+
catch ex =>
1024+
logException ex
1025+
headers.mapM fun header => withRef header.declId <| mkLabeledSorry header.type (synthetic := true) (unique := true)
1026+
let headers ← headers.mapM instantiateMVarsAtHeader
1027+
let letRecsToLift ← getLetRecsToLift
1028+
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
1029+
checkLetRecsToLiftTypes funFVars letRecsToLift
1030+
(if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars sc headers else withUsed vars headers values letRecsToLift) fun vars => do
1031+
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
1032+
checkAllDeclNamesDistinct preDefs
1033+
for preDef in preDefs do
1034+
trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
1035+
let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamTypesPreDecls preDefs
1036+
let preDefs ← instantiateMVarsAtPreDecls preDefs
1037+
let preDefs ← shareCommonPreDefs preDefs
1038+
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
1039+
for preDef in preDefs do
1040+
trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
1041+
addPreDefinitions preDefs
1042+
processDeriving headers
1043+
for view in views, header in headers do
1044+
-- NOTE: this should be the full `ref`, and thus needs to be done after any snapshotting
1045+
-- that depends only on a part of the ref
1046+
addDeclarationRangesForBuiltin header.declName view.modifiers.stx view.ref
10471047

10481048

10491049
processDeriving (headers : Array DefViewElabHeader) := do
@@ -1060,46 +1060,46 @@ namespace Command
10601060

10611061
def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
10621062
let opts ← getOptions
1063-
withAlwaysResolvedPromises ds.size fun headerPromises => do
1064-
let snap? := (← read).snap?
1065-
let mut views := #[]
1066-
let mut defs := #[]
1067-
let mut reusedAllHeaders := true
1068-
for h : i in [0:ds.size], headerPromise in headerPromises do
1069-
let d := ds[i]
1070-
let modifiers ← elabModifiers ⟨d[0]⟩
1071-
if ds.size > 1 && modifiers.isNonrec then
1072-
throwErrorAt d "invalid use of 'nonrec' modifier in 'mutual' block"
1073-
let mut view ← mkDefView modifiers d[1]
1074-
let fullHeaderRef := mkNullNode #[d[0], view.headerRef]
1075-
if let some snap := snap? then
1076-
view := { view with headerSnap? := some {
1077-
old? := do
1078-
-- transitioning from `Context.snap?` to `DefView.headerSnap?` invariant: if the
1079-
-- elaboration context and state are unchanged, and the syntax of this as well as all
1080-
-- previous headers is unchanged, then the elaboration result for this header (which
1081-
-- includes state from elaboration of previous headers!) should be unchanged.
1082-
guard reusedAllHeaders
1083-
let old ← snap.old?
1084-
-- blocking wait, `HeadersParsedSnapshot` (and hopefully others) should be quick
1085-
let old ← old.val.get.toTyped? DefsParsedSnapshot
1086-
let oldParsed ← old.defs[i]?
1087-
guard <| fullHeaderRef.eqWithInfoAndTraceReuse opts oldParsed.fullHeaderRef
1088-
-- no syntax guard to store, we already did the necessary checks
1089-
return ⟨.missing, oldParsed.headerProcessedSnap⟩
1090-
new := headerPromise
1091-
} }
1092-
defs := defs.push {
1093-
fullHeaderRef
1094-
headerProcessedSnap := { range? := d.getRange?, task := headerPromise.result }
1095-
}
1096-
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
1097-
views := views.push view
1063+
let headerPromises ← ds.mapM fun _ => IO.Promise.new
1064+
let snap? := (← read).snap?
1065+
let mut views := #[]
1066+
let mut defs := #[]
1067+
let mut reusedAllHeaders := true
1068+
for h : i in [0:ds.size], headerPromise in headerPromises do
1069+
let d := ds[i]
1070+
let modifiers ← elabModifiers ⟨d[0]⟩
1071+
if ds.size > 1 && modifiers.isNonrec then
1072+
throwErrorAt d "invalid use of 'nonrec' modifier in 'mutual' block"
1073+
let mut view ← mkDefView modifiers d[1]
1074+
let fullHeaderRef := mkNullNode #[d[0], view.headerRef]
10981075
if let some snap := snap? then
1099-
-- no non-fatal diagnostics at this point
1100-
snap.new.resolve <| .ofTyped { defs, diagnostics := .empty : DefsParsedSnapshot }
1101-
let sc ← getScope
1102-
runTermElabM fun vars => Term.elabMutualDef vars sc views
1076+
view := { view with headerSnap? := some {
1077+
old? := do
1078+
-- transitioning from `Context.snap?` to `DefView.headerSnap?` invariant: if the
1079+
-- elaboration context and state are unchanged, and the syntax of this as well as all
1080+
-- previous headers is unchanged, then the elaboration result for this header (which
1081+
-- includes state from elaboration of previous headers!) should be unchanged.
1082+
guard reusedAllHeaders
1083+
let old ← snap.old?
1084+
-- blocking wait, `HeadersParsedSnapshot` (and hopefully others) should be quick
1085+
let old ← old.val.get.toTyped? DefsParsedSnapshot
1086+
let oldParsed ← old.defs[i]?
1087+
guard <| fullHeaderRef.eqWithInfoAndTraceReuse opts oldParsed.fullHeaderRef
1088+
-- no syntax guard to store, we already did the necessary checks
1089+
return ⟨.missing, oldParsed.headerProcessedSnap⟩
1090+
new := headerPromise
1091+
} }
1092+
defs := defs.push {
1093+
fullHeaderRef
1094+
headerProcessedSnap := { range? := d.getRange?, task := headerPromise.resultD default }
1095+
}
1096+
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
1097+
views := views.push view
1098+
if let some snap := snap? then
1099+
-- no non-fatal diagnostics at this point
1100+
snap.new.resolve <| .ofTyped { defs, diagnostics := .empty : DefsParsedSnapshot }
1101+
let sc ← getScope
1102+
runTermElabM fun vars => Term.elabMutualDef vars sc views
11031103

11041104
builtin_initialize
11051105
registerTraceClass `Elab.definition.mkClosure

Diff for: src/Lean/Elab/Tactic/Basic.lean

+18-18
Original file line numberDiff line numberDiff line change
@@ -224,26 +224,26 @@ where
224224
guard <| state.term.meta.core.traceState.traces.size == 0
225225
guard <| traceState.traces.size == 0
226226
return old.val.get
227-
Language.withAlwaysResolvedPromise fun promise => do
228-
-- Store new unfolding in the snapshot tree
229-
snap.new.resolve {
230-
stx := stx'
227+
let promise ← IO.Promise.new
228+
-- Store new unfolding in the snapshot tree
229+
snap.new.resolve {
230+
stx := stx'
231+
diagnostics := .empty
232+
inner? := none
233+
finished := .pure {
231234
diagnostics := .empty
232-
inner? := none
233-
finished := .pure {
234-
diagnostics := .empty
235-
state? := (← Tactic.saveState)
236-
}
237-
next := #[{ range? := stx'.getRange?, task := promise.result }]
235+
state? := (← Tactic.saveState)
238236
}
239-
-- Update `tacSnap?` to old unfolding
240-
withTheReader Term.Context ({ · with tacSnap? := some {
241-
new := promise
242-
old? := do
243-
let old ← old?
244-
return ⟨old.stx, (← old.next.get? 0)⟩
245-
} }) do
246-
evalTactic stx'
237+
next := #[{ range? := stx'.getRange?, task := promise.resultD default }]
238+
}
239+
-- Update `tacSnap?` to old unfolding
240+
withTheReader Term.Context ({ · with tacSnap? := some {
241+
new := promise
242+
old? := do
243+
let old ← old?
244+
return ⟨old.stx, (← old.next.get? 0)⟩
245+
} }) do
246+
evalTactic stx'
247247
return
248248
evalTactic stx'
249249
catch ex => handleEx s failures ex (expandEval s ms evalFns)

0 commit comments

Comments
 (0)