diff --git a/src/Init/Data/Int.lean b/src/Init/Data/Int.lean index f909e7e18976..48c60e78ce15 100644 --- a/src/Init/Data/Int.lean +++ b/src/Init/Data/Int.lean @@ -14,3 +14,4 @@ import Init.Data.Int.LemmasAux import Init.Data.Int.Order import Init.Data.Int.Pow import Init.Data.Int.Cooper +import Init.Data.Int.Linear diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean new file mode 100644 index 000000000000..32a20e8600db --- /dev/null +++ b/src/Init/Data/Int/Linear.lean @@ -0,0 +1,212 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.ByCases +import Init.Data.Prod +import Init.Data.Int.Lemmas +import Init.Data.Int.LemmasAux +import Init.Data.RArray + +namespace Int.Linear + +/-! Helper definitions and theorems for constructing linear arithmetic proofs. -/ + +abbrev Var := Nat +abbrev Context := Lean.RArray Int + +def Var.denote (ctx : Context) (v : Var) : Int := + ctx.get v + +inductive Expr where + | num (v : Int) + | var (i : Var) + | add (a b : Expr) + | sub (a b : Expr) + | mulL (k : Int) (a : Expr) + | mulR (a : Expr) (k : Int) + deriving Inhabited + +def Expr.denote (ctx : Context) : Expr → Int + | .add a b => Int.add (denote ctx a) (denote ctx b) + | .sub a b => Int.sub (denote ctx a) (denote ctx b) + | .num k => k + | .var v => v.denote ctx + | .mulL k e => Int.mul k (denote ctx e) + | .mulR e k => Int.mul (denote ctx e) k + +inductive Poly where + | num (k : Int) + | add (k : Int) (v : Var) (p : Poly) + deriving BEq, Repr + +def Poly.denote (ctx : Context) (p : Poly) : Int := + match p with + | .num k => k + | .add k v p => Int.add (Int.mul k (v.denote ctx)) (denote ctx p) + +def Poly.addConst (p : Poly) (k : Int) : Poly := + match p with + | .num k' => .num (k+k') + | .add k' v' p => .add k' v' (addConst p k) + +def Poly.insert (k : Int) (v : Var) (p : Poly) : Poly := + match p with + | .num k' => .add k v (.num k') + | .add k' v' p => + bif Nat.blt v v' then + .add k v <| .add k' v' p + else bif Nat.beq v v' then + if Int.add k k' == 0 then + p + else + .add (Int.add k k') v' p + else + .add k' v' (insert k v p) + +def Poly.norm (p : Poly) : Poly := + match p with + | .num k => .num k + | .add k v p => (norm p).insert k v + +def Expr.toPoly' (e : Expr) := + go 1 e (.num 0) +where + go (coeff : Int) : Expr → (Poly → Poly) + | .num k => bif k == 0 then id else (Poly.addConst · (Int.mul coeff k)) + | .var v => (.add coeff v ·) + | .add a b => go coeff a ∘ go coeff b + | .sub a b => go coeff a ∘ go (-coeff) b + | .mulL k a + | .mulR a k => bif k == 0 then id else go (Int.mul coeff k) a + +def Expr.toPoly (e : Expr) : Poly := + e.toPoly'.norm + +inductive PolyCnstr where + | eq (p : Poly) + | le (p : Poly) + deriving BEq, Repr + +def PolyCnstr.denote (ctx : Context) : PolyCnstr → Prop + | .eq p => p.denote ctx = 0 + | .le p => p.denote ctx ≤ 0 + +def PolyCnstr.norm : PolyCnstr → PolyCnstr + | .eq p => .eq p.norm + | .le p => .le p.norm + +inductive ExprCnstr where + | eq (p₁ p₂ : Expr) + | le (p₁ p₂ : Expr) + deriving Inhabited + +def ExprCnstr.denote (ctx : Context) : ExprCnstr → Prop + | .eq e₁ e₂ => e₁.denote ctx = e₂.denote ctx + | .le e₁ e₂ => e₁.denote ctx ≤ e₂.denote ctx + +def ExprCnstr.toPoly : ExprCnstr → PolyCnstr + | .eq e₁ e₂ => .eq (e₁.sub e₂).toPoly.norm + | .le e₁ e₂ => .le (e₁.sub e₂).toPoly.norm + +attribute [local simp] Int.add_comm Int.add_assoc Int.add_left_comm Int.add_mul Int.mul_add +attribute [local simp] Poly.insert Poly.denote Poly.norm Poly.addConst + +theorem Poly.denote_addConst (ctx : Context) (p : Poly) (k : Int) : (p.addConst k).denote ctx = p.denote ctx + k := by + induction p <;> simp [*] + +attribute [local simp] Poly.denote_addConst + +theorem Poly.denote_insert (ctx : Context) (k : Int) (v : Var) (p : Poly) : + (p.insert k v).denote ctx = p.denote ctx + k * v.denote ctx := by + induction p <;> simp [*] + next k' v' p' ih => + by_cases h₁ : Nat.blt v v' <;> simp [*] + by_cases h₂ : Nat.beq v v' <;> simp [*] + by_cases h₃ : k + k' = 0 <;> simp [*, Nat.eq_of_beq_eq_true h₂] + rw [← Int.add_mul] + simp [*] + +attribute [local simp] Poly.denote_insert + +theorem Poly.denote_norm (ctx : Context) (p : Poly) : p.norm.denote ctx = p.denote ctx := by + induction p <;> simp [*] + +attribute [local simp] Poly.denote_norm + +private theorem sub_fold (a b : Int) : a.sub b = a - b := rfl + +attribute [local simp] sub_fold +attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly PolyCnstr.denote Expr.denote + +theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) : + (toPoly'.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by + induction k, e using Expr.toPoly'.go.induct generalizing p with + | case1 k k' => + simp only [toPoly'.go] + by_cases h : k' == 0 + · simp [h, eq_of_beq h] + · simp [h, Var.denote] + | case2 k i => simp [toPoly'.go] + | case3 k a b iha ihb => simp [toPoly'.go, iha, ihb] + | case4 k a b iha ihb => + simp [toPoly'.go, iha, ihb, Int.mul_sub] + rw [Int.sub_eq_add_neg, ←Int.neg_mul, Int.add_assoc] + | case5 k k' a ih + | case6 k a k' ih => + simp only [toPoly'.go] + by_cases h : k' == 0 + · simp [h, eq_of_beq h] + · simp [h, cond_false, Int.mul_assoc] + simp at ih + rw [ih] + rw [Int.mul_assoc, Int.mul_comm k'] + +theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by + simp [toPoly, toPoly', Expr.denote_toPoly'_go] + +attribute [local simp] Expr.denote_toPoly + +theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denote ctx = c.denote ctx := by + cases c <;> simp + · rw [Int.sub_eq_zero] + · constructor + · exact Int.le_of_sub_nonpos + · exact Int.sub_nonpos_of_le + +instance : LawfulBEq Poly where + eq_of_beq {a} := by + induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq] + · rename_i k₁ v₁ p₁ k₂ v₂ p₂ ih + intro _ _ h + exact ih h + rfl := by + intro a + induction a <;> simp! [BEq.beq] + · rename_i k v p ih + exact ih + +instance : LawfulBEq PolyCnstr where + eq_of_beq {a b} := by + cases a <;> cases b <;> rename_i p₁ p₂ <;> simp_all! [BEq.beq] + · show (p₁ == p₂) = true → _ + simp + · show (p₁ == p₂) = true → _ + simp + rfl {a} := by + cases a <;> rename_i p <;> show (p == p) = true + <;> simp + +theorem Expr.eq_of_toPoly_eq (ctx : Context) (e e' : Expr) (h : e.toPoly == e'.toPoly) : e.denote ctx = e'.denote ctx := by + have h := congrArg (Poly.denote ctx) (eq_of_beq h) + simp [Poly.norm] at h + assumption + +theorem ExprCnstr.eq_of_toPoly_eq (ctx : Context) (c c' : ExprCnstr) (h : c.toPoly == c'.toPoly) : c.denote ctx = c'.denote ctx := by + have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h) + rw [denote_toPoly, denote_toPoly] at h + assumption + +end Int.Linear diff --git a/tests/lean/run/liaByRefl.lean b/tests/lean/run/liaByRefl.lean new file mode 100644 index 000000000000..4143ec02d4af --- /dev/null +++ b/tests/lean/run/liaByRefl.lean @@ -0,0 +1,82 @@ +import Lean + +open Int.Linear + +-- Convenient RArray literals +elab tk:"#R[" ts:term,* "]" : term => do + let ts : Array Lean.Syntax := ts + let es ← ts.mapM fun stx => Lean.Elab.Term.elabTerm stx none + if h : 0 < es.size then + return (Lean.RArray.toExpr (← Lean.Meta.inferType es[0]!) id (Lean.RArray.ofArray es h)) + else + throwErrorAt tk "RArray cannot be empty" + +example (x₁ x₂ : Int) : + Expr.denote #R[x₁, x₂] (.add (.add (.var 0) (.var 1)) (.num 3)) + = + x₁ + x₂ + 3 := + rfl + + +example (x₁ x₂ : Int) : + Expr.denote #R[x₁, x₂] (.sub (.add (.mulR (.var 0) 4) (.mulL 2 (.var 1))) (.num 3)) + = + (x₁*4) + 2*x₂ - 3 := + rfl + +example : + Expr.toPoly (.add (.add (.var 1) (.var 1)) (.num 3)) + = + Expr.toPoly (.add (.num 3) (.mulL 2 (.var 1))) := + rfl + +example : + Expr.toPoly (.add (.add (.add (.var 1) (.var 1)) (.num 3)) (.var 2)) + = + Expr.toPoly (.add (.add (.num 3) (.var 2)) (.mulL 2 (.var 1))) := + rfl + +example (x₁ x₂ x₃ : Int) : + ExprCnstr.denote #R[x₁, x₂, x₃] (.eq (.sub (.add (.mulR (.var 0) 4) (.mulL 2 (.var 1))) (.num 3)) (.sub (.var 1) (.var 2))) + = + ((x₁*4) + 2*x₂ - 3 = x₂ - x₃) := + rfl + +example : + ExprCnstr.toPoly (.eq (.sub (.add (.mulR (.var 0) 4) (.mulL 2 (.var 1))) (.num 3)) (.sub (.var 1) (.var 2))) + = + ExprCnstr.toPoly (.eq (.add (.var 2) (.add (.var 1) (.add (.mulL 4 (.var 0)) (.num (-3))))) (.num 0)) := + rfl + +example (x₁ x₂ x₃ : Int) : (x₁ + x₂) + (x₂ + x₃) = x₃ + 2*x₂ + x₁ := + Expr.eq_of_toPoly_eq #R[x₁, x₂, x₃] + (Expr.add (Expr.add (Expr.var 0) (Expr.var 1)) (Expr.add (Expr.var 1) (Expr.var 2))) + (Expr.add (Expr.add (Expr.var 2) (Expr.mulL 2 (Expr.var 1))) (Expr.var 0)) + rfl + +example : + ExprCnstr.toPoly + (.eq (Expr.add (Expr.add (Expr.var 0) (Expr.var 1)) (Expr.add (Expr.var 1) (Expr.var 2))) + (Expr.add (Expr.var 2) (Expr.var 1))) + = + ExprCnstr.toPoly + (.eq (Expr.add (Expr.var 0) (Expr.var 1)) + (Expr.num 0)) + := + rfl + +example (x₁ x₂ x₃ : Int) : ((x₁ + x₂) + (x₂ + x₃) = x₃ + x₂) = (x₁ + x₂ = 0) := + ExprCnstr.eq_of_toPoly_eq #R[x₁, x₂, x₃] + (.eq (Expr.add (Expr.add (Expr.var 0) (Expr.var 1)) (Expr.add (Expr.var 1) (Expr.var 2))) + (Expr.add (Expr.var 2) (Expr.var 1))) + (.eq (Expr.add (Expr.var 0) (Expr.var 1)) + (Expr.num 0)) + rfl + +example (x₁ x₂ x₃ : Int) : ((x₁ + x₂) + (x₂ + x₃) ≤ x₃ + x₂) = (x₁ + x₂ ≤ 0) := + ExprCnstr.eq_of_toPoly_eq #R[x₁, x₂, x₃] + (.le (Expr.add (Expr.add (Expr.var 0) (Expr.var 1)) (Expr.add (Expr.var 1) (Expr.var 2))) + (Expr.add (Expr.var 2) (Expr.var 1))) + (.le (Expr.add (Expr.var 0) (Expr.var 1)) + (Expr.num 0)) + rfl