Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: WF: create unfold theorems eagerly #6898

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/Lean/Elab/PreDefinition/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,21 @@ private def unfoldLHS (declName : Name) (mvarId : MVarId) : MetaM MVarId := mvar
-- Else use delta reduction
deltaLHS mvarId

private partial def mkEqnProof (declName : Name) (type : Expr) : MetaM Expr := do
private partial def mkEqnProof (declName : Name) (type : Expr) (tryRefl : Bool) : MetaM Expr := do
trace[Elab.definition.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros

-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
-- the lemma ineligible for dsimp
unless ← withAtLeastTransparency .all (tryURefl mvarId) do
go (← unfoldLHS declName mvarId)
-- For well-founded recursion this is disabled: The equation may hold
-- definitionally as written, but not embedded in larger proofs
if tryRefl then
if (← withAtLeastTransparency .all (tryURefl mvarId)) then
return ← instantiateMVars main

go (← unfoldLHS declName mvarId)
instantiateMVars main
where
/--
Expand Down Expand Up @@ -391,9 +397,10 @@ This unfolds the function application on the LHS (using an unfold theorem, if pr
delta-reduction), calculates the types for the equational theorems using `mkEqnTypes`, and then
proves them using `mkEqnProof`.

This is currently used for non-recursive functions and for functions defined by partial_fixpoint.
This is currently used for non-recursive functions, well-founded recursion and partial_fixpoint,
but not for structural recursion.
-/
def mkEqns (declName : Name) : MetaM (Array Name) := do
def mkEqns (declName : Name) (declNames : Array Name) (tryRefl := true): MetaM (Array Name) := do
let info ← getConstInfoDefn declName
let us := info.levelParams.map mkLevelParam
withOptions (tactic.hygienic.set · false) do
Expand All @@ -402,14 +409,14 @@ def mkEqns (declName : Name) : MetaM (Array Name) := do
forallTelescope (cleanupAnnotations := true) target fun xs target => do
let goal ← mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes #[] goal.mvarId!
mkEqnTypes declNames goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
let name := (Name.str declName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkEqnProof declName type
let value ← mkEqnProof declName type tryRefl
let (type, value) ← removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
return none
if (← getEnv).contains declName then
if backward.eqns.nonrecursive.get (← getOptions) then
mkEqns declName
mkEqns declName #[]
else
let o ← mkSimpleEqThm declName
return o.map (#[·])
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do
return some (← mkUnfoldEq declName info)

def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some _ := eqnInfoExt.find? (← getEnv) declName then
mkEqns declName
if let some info := eqnInfoExt.find? (← getEnv) declName then
mkEqns declName info.declNames
else
return none

Expand Down
1 change: 1 addition & 0 deletions src/Lean/Elab/PreDefinition/WF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Elab.PreDefinition.WF.Main
import Lean.Elab.PreDefinition.WF.Eqns
113 changes: 14 additions & 99 deletions src/Lean/Elab/PreDefinition/WF/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Lean.Meta.Tactic.Split
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Eqns
import Lean.Meta.ArgsPacker.Basic
import Lean.Elab.PreDefinition.WF.Unfold
import Init.Data.Array.Basic

namespace Lean.Elab.WF
Expand All @@ -22,101 +23,6 @@ structure EqnInfo extends EqnInfoCore where
argsPacker : ArgsPacker
deriving Inhabited

private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | unreachable!

-- lhs should be an application of the declNameNonrec, which unfolds to an
-- application of fix in one step
let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}"
let_expr WellFounded.fix _α _C _r _hwf F x := lhs'
| throwTacticEx `rwFixEq mvarId "expected saturated fixed-point application in {lhs'}"
let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs

-- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that
-- would include more copies of `fix` resulting in large and confusing terms.
-- Instead we manually construct the new term in terms of the current functions,
-- which should be headed by the `declNameNonRec`, and should be defeq to the expected type

-- if lhs == e x and lhs' == fix .., then lhsNew := e x = F x (fun y _ => e y)
let ftype := (← inferType (mkApp F x)).bindingDomain!
let f' ← forallBoundedTelescope ftype (some 2) fun ys _ => do
mkLambdaFVars ys (.app lhs.appFn! ys[0]!)
let lhsNew := mkApp2 F x f'
let targetNew ← mkEq lhsNew rhs
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
mvarId.assign (← mkEqTrans h mvarNew)
return mvarNew.mvarId!

private partial def mkProof (declName declNameNonRec : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.wf.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if ← withAtLeastTransparency .all (tryURefl mvarId) then
trace[Elab.definition.wf.eqns] "refl!"
return ()
else if (← tryContradiction mvarId) then
trace[Elab.definition.wf.eqns] "contradiction!"
return ()
else if let some mvarId ← simpMatch? mvarId then
trace[Elab.definition.wf.eqns] "simpMatch!"
go mvarId
else if let some mvarId ← simpIf? mvarId then
trace[Elab.definition.wf.eqns] "simpIf!"
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
trace[Elab.definition.wf.eqns] "whnfReducibleLHS!"
go mvarId
else
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId =>
trace[Elab.definition.wf.eqns] "simp only!"
go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals"
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals"
mvarIds.forM go
else
-- At some point in the past, we looked for occurrences of Wf.fix to fold on the
-- LHS (introduced in 096e4eb), but it seems that code path was never used,
-- so #3133 removed it again (and can be recovered from there if this was premature).
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"

let mvarId ← if declName != declNameNonRec then deltaLHS mvarId else pure mvarId
let mvarId ← rwFixEq mvarId
go mvarId
instantiateMVars main

def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
withOptions (tactic.hygienic.set · false) do
let baseName := declName
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal ← mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes info.declNames goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.wf.eqns] "{eqnTypes[i]}"
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkProof declName info.declNameNonRec type
let (type, value) ← removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames

builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension

Expand All @@ -138,17 +44,26 @@ def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fi

def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some info := eqnInfoExt.find? (← getEnv) declName then
mkEqns declName info
mkEqns declName info.declNames (tryRefl := false)
else
return none

builtin_initialize
registerGetEqnsFn getEqnsFor?


-- Remove the rest of this file after the next stage update,
-- as we generate these eagerly now.
def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do
let name := Name.str declName unfoldThmSuffix
let env ← getEnv
Eqns.getUnfoldFor? declName fun _ => eqnInfoExt.find? env declName |>.map (·.toEqnInfoCore)
if env.contains name then return name
let some info := eqnInfoExt.find? env declName | return none
mkUnfoldEq info.toEqnInfoCore info.declNameNonRec
return some name

builtin_initialize
registerGetEqnsFn getEqnsFor?
registerGetUnfoldEqnFn getUnfoldFor?
registerTraceClass `Elab.definition.wf.eqns


end Lean.Elab.WF
6 changes: 5 additions & 1 deletion src/Lean/Elab/PreDefinition/WF/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Lean.Elab.PreDefinition.WF.PackMutual
import Lean.Elab.PreDefinition.WF.Preprocess
import Lean.Elab.PreDefinition.WF.Rel
import Lean.Elab.PreDefinition.WF.Fix
import Lean.Elab.PreDefinition.WF.Eqns
import Lean.Elab.PreDefinition.WF.Unfold
import Lean.Elab.PreDefinition.WF.Ite
import Lean.Elab.PreDefinition.WF.GuessLex

Expand Down Expand Up @@ -61,6 +61,10 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
Mutual.addPreDefsFromUnary preDefs preDefsNonrec preDefNonRec
let preDefs ← Mutual.cleanPreDefs preDefs
registerEqnsInfo preDefs preDefNonRec.declName fixedPrefixSize argsPacker
for preDef in preDefs do
unless preDef.kind.isTheorem do
unless (← isProp preDef.type) do
WF.mkUnfoldEq { preDef with } preDefNonRec.declName
Mutual.addPreDefAttributes preDefs

builtin_initialize registerTraceClass `Elab.definition.wf
Expand Down
104 changes: 104 additions & 0 deletions src/Lean/Elab/PreDefinition/WF/Unfold.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Eqns

namespace Lean.Elab.WF
open Meta
open Eqns

private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | unreachable!

-- lhs should be an application of the declNameNonrec, which unfolds to an
-- application of fix in one step
let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}"
let_expr WellFounded.fix _α _C _r _hwf F x := lhs'
| throwTacticEx `rwFixEq mvarId "expected saturated fixed-point application in {lhs'}"
let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs

-- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that
-- would include more copies of `fix` resulting in large and confusing terms.
-- Instead we manually construct the new term in terms of the current functions,
-- which should be headed by the `declNameNonRec`, and should be defeq to the expected type

-- if lhs == e x and lhs' == fix .., then lhsNew := e x = F x (fun y _ => e y)
let ftype := (← inferType (mkApp F x)).bindingDomain!
let f' ← forallBoundedTelescope ftype (some 2) fun ys _ => do
mkLambdaFVars ys (.app lhs.appFn! ys[0]!)
let lhsNew := mkApp2 F x f'
let targetNew ← mkEq lhsNew rhs
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
mvarId.assign (← mkEqTrans h mvarNew)
return mvarNew.mvarId!

private partial def mkUnfoldProof (declName declNameNonRec : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.wf.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if ← withAtLeastTransparency .all (tryURefl mvarId) then
trace[Elab.definition.wf.eqns] "refl!"
return ()
else if (← tryContradiction mvarId) then
trace[Elab.definition.wf.eqns] "contradiction!"
return ()
else if let some mvarId ← simpMatch? mvarId then
trace[Elab.definition.wf.eqns] "simpMatch!"
go mvarId
else if let some mvarId ← simpIf? mvarId then
trace[Elab.definition.wf.eqns] "simpIf!"
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
trace[Elab.definition.wf.eqns] "whnfReducibleLHS!"
go mvarId
else
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId =>
trace[Elab.definition.wf.eqns] "simp only!"
go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals"
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals"
mvarIds.forM go
else
-- At some point in the past, we looked for occurrences of Wf.fix to fold on the
-- LHS (introduced in 096e4eb), but it seems that code path was never used,
-- so #3133 removed it again (and can be recovered from there if this was premature).
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"

let mvarId ← if declName != declNameNonRec then deltaLHS mvarId else pure mvarId
let mvarId ← rwFixEq mvarId
go mvarId
instantiateMVars main

-- TODO: Afer the next stage0 update, change the type to PreDefinition
def mkUnfoldEq (preDef : EqnInfoCore) (unaryPreDefName : Name) : MetaM Unit := do
withOptions (tactic.hygienic.set · false) do
let baseName := preDef.declName
lambdaTelescope preDef.value fun xs body => do
let us := preDef.levelParams.map mkLevelParam
let type ← mkEq (mkAppN (Lean.mkConst preDef.declName us) xs) body
let value ← mkUnfoldProof preDef.declName unaryPreDefName type
let type ← mkForallFVars xs type
let value ← mkLambdaFVars xs value
let name := Name.str baseName unfoldThmSuffix
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := preDef.levelParams
}
trace[Elab.definition.wf] "mkUnfoldEq defined {.ofConstName name}"

end Lean.Elab.WF
6 changes: 0 additions & 6 deletions tests/lean/run/simpDiag.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ info: [simp] Diagnostics
[simp] ack.eq_1 ↦ 768, succeeded: 768
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
info: [diag] Diagnostics
[kernel] unfolded declarations (max: 29, num: 2):
[kernel] Nat.casesOn ↦ 29
[kernel] Nat.rec ↦ 29
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
error: tactic 'simp' failed, nested error:
maximum recursion depth has been reached
use `set_option maxRecDepth <num>` to increase limit
Expand Down
Loading
Loading