diff --git a/Qq/Match.lean b/Qq/Match.lean index 8ae3267..46ad838 100644 --- a/Qq/Match.lean +++ b/Qq/Match.lean @@ -246,17 +246,6 @@ partial def isIrrefutablePattern : Term → Bool | `(true) => false | `(false) => false -- TODO properly | stx => stx.1.isIdent -scoped elab "_comefrom" n:ident "do" b:doSeq " in " body:term : term <= expectedType => do - let _ ← extractBind expectedType - (← elabTerm (← `(?m)).1.stripPos none).mvarId!.assign expectedType - elabTerm (← `(have $n:ident : ?m := (do $b:doSeq); $body)) expectedType - -scoped syntax "_comefrom" ident "do" doSeq : term -macro_rules | `(assert! (_comefrom $n do $b); $body) => `(_comefrom $n do $b in $body) - -scoped macro "comefrom" n:ident "do" b:doSeq : doElem => - `(doElem| assert! (_comefrom $n do $b)) - def mkLetDoSeqItem [Monad m] [MonadQuotation m] (pat : Term) (rhs : TSyntax `doElem) (alt : TSyntax ``doSeq) : m (List (TSyntax ``doSeqItem)) := do match pat with | `(_) => return [] @@ -356,11 +345,8 @@ macro_rules `(doElem| do $(lifts.push t):doSeqItem*) | _ => - let (pat', auxs) ← floatQMatch (← `(doSeq| alt)) pat [] - let items := - #[← `(doSeqItem| comefrom alt do $alt:doSeq)] ++ - (← mkLetDoSeqItem pat' rhs alt) ++ - auxs + let (pat', auxs) ← floatQMatch (← `(doSeq| $alt)) pat [] + let items := Array.mk <| (← mkLetDoSeqItem pat' rhs alt) ++ auxs `(doElem| do $items:doSeqItem*) | `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do @@ -373,14 +359,15 @@ macro_rules pure (← `(x), ← `(doSeqItem| let x := $d:term)) let mut items := discrs.map (·.2) let discrs := discrs.map (·.1) - items := items.push (← `(doSeqItem| comefrom alt do throwError "nonexhaustive match")) + let mut alt : TSyntax `doElem ← `(doElem| throwError "nonexhaustive match") for pats in patss.reverse, rhs in rhss.reverse do let mut subItems : Array (TSyntax ``doSeqItem) := #[] for discr in discrs, pat in pats do - subItems := subItems ++ (← mkLetDoSeqItem pat (← `(doElem| pure $discr:term)) (← `(doSeq| alt))) - subItems := subItems.push (← `(doSeqItem| do $rhs)) - items := items.push (← `(doSeqItem| comefrom alt do $subItems:doSeqItem*)) - items := items.push (← `(doSeqItem| alt)) - `(doElem| (do $items:doSeqItem*)) + subItems := + subItems ++ (← mkLetDoSeqItem pat (← `(doElem| pure $discr:term)) (←`(doSeq|$alt:doElem))) + subItems := subItems.push (←`(doSeqItem|do $rhs)) + alt ← `(doElem| do $subItems:doSeqItem*) + items := items.push (←`(doSeqItem|$alt:doElem)) + `(doElem| do $items:doSeqItem*) end diff --git a/examples/matching.lean b/examples/matching.lean index e05d754..0a2b767 100644 --- a/examples/matching.lean +++ b/examples/matching.lean @@ -27,6 +27,7 @@ abbrev square (a : Nat) := #eval summands q(inferInstance) q(k + square (square k)) #eval summands q(⟨(· * ·)⟩) q(k * square (square k)) +set_option pp.macroStack true in def matchProd (e : Nat × Q(Nat)) : MetaM Bool := do let (2, ~q(1)) := e | return false return true @@ -51,3 +52,27 @@ def getNatAdd (e : Expr) : MetaM (Option (Q(Nat) × Q(Nat))) := do #eval do guard <| (← getNatAdd q(1 + 2)) == some (q(1), q(2)) #eval do guard <| (← getNatAdd q((1 + 2 : Int))) == none + + + +section test_return + +def foo1 (T : Q(Type)) : MetaM Nat := do + let x : Nat ← match T with + | ~q(Prop) => return (2 : Nat) + | _ => pure (1 : Nat) + pure (3 + x) + +#eval do guard <| (←foo1 q(Prop)) == 2 +#eval do guard <| (←foo1 q(Nat)) == 3 + 1 + +def foo2 (T : Q(Type)) : MetaM Nat := do + let x : Nat ← match T with + | ~q(Prop) => pure (2 : Nat) + | _ => return (1 : Nat) + pure (3 + x) + +#eval do guard <| (←foo2 q(Prop)) == 3 + 2 +#eval do guard <| (←foo2 q(Nat)) == 1 + +end test_return