Skip to content

Commit

Permalink
further speedups
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisHughes24 committed Nov 21, 2022
1 parent 801e4d9 commit 119b68f
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 72 deletions.
2 changes: 1 addition & 1 deletion _target/deps/mathlib
Submodule mathlib updated from aae5e5 to 568eb9
193 changes: 124 additions & 69 deletions src/v3/circuit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ inductive circuit (α : Type u) : Type u
| or : circuit → circuit → circuit
| not : circuit → circuit
| xor : circuit → circuit → circuit
| imp : circuit → circuit → circuit

namespace circuit
variables {α : Type u} {β : Type v}
Expand All @@ -24,7 +23,6 @@ def reppr [has_repr α] : (circuit α) → string
| (or x y) := "(" ++ reppr x ++ "" ++ reppr y ++ ")"
| (not x) := "¬" ++ reppr x
| (xor x y) := "(" ++ reppr x ++ "" ++ reppr y ++ ")"
| (imp x y) := "(" ++ reppr x ++ "" ++ reppr y ++ ")"

instance [has_repr α] : has_repr (circuit α) := ⟨reppr⟩

Expand All @@ -36,7 +34,6 @@ def vars [decidable_eq α] : circuit α → list α
| (or c₁ c₂) := (vars c₁ ++ vars c₂).dedup
| (not c) := vars c
| (xor c₁ c₂) := (vars c₁ ++ vars c₂).dedup
| (imp c₁ c₂) := (vars c₁ ++ vars c₂).dedup

@[simp] def eval : circuit α → (α → bool) → bool
| true _ := tt
Expand All @@ -46,7 +43,6 @@ def vars [decidable_eq α] : circuit α → list α
| (or c₁ c₂) f := (eval c₁ f) || (eval c₂ f)
| (not c) f := bnot (eval c f)
| (xor c₁ c₂) f := bxor (eval c₁ f) (eval c₂ f)
| (imp c₁ c₂) f := bnot (eval c₁ f) || (eval c₂ f)

@[simp] def evalv [decidable_eq α] : Π (c : circuit α), (Π a ∈ vars c, bool) → bool
| true _ := tt
Expand All @@ -59,8 +55,6 @@ def vars [decidable_eq α] : circuit α → list α
| (not c) f := bnot (evalv c (λ i hi, f i (by simp [hi, vars])))
| (xor c₁ c₂) f := bxor (evalv c₁ (λ i hi, f i (by simp [hi, vars])))
(evalv c₂ (λ i hi, f i (by simp [hi, vars])))
| (imp c₁ c₂) f := bnot (evalv c₁ (λ i hi, f i (by simp [hi, vars]))) ||
(evalv c₂ (λ i hi, f i (by simp [hi, vars])))

lemma eval_eq_evalv [decidable_eq α] : ∀ (c : circuit α) (f : α → bool),
eval c f = evalv c (λ x _, f x)
Expand All @@ -71,7 +65,6 @@ lemma eval_eq_evalv [decidable_eq α] : ∀ (c : circuit α) (f : α → bool),
| (or c₁ c₂) f := by rw [eval, evalv, eval_eq_evalv, eval_eq_evalv]
| (not c) f := by rw [eval, evalv, eval_eq_evalv]
| (xor c₁ c₂) f := by rw [eval, evalv, eval_eq_evalv, eval_eq_evalv]
| (imp c₁ c₂) f := by rw [eval, evalv, eval_eq_evalv, eval_eq_evalv]

@[simp] def of_bool : bool → circuit α
| tt := true
Expand All @@ -85,12 +78,25 @@ instance : preorder (circuit α) :=
le_refl := λ c f h, h,
le_trans := λ c₁ c₂ c₃ h₁₂ h₂₃ f h₁, h₂₃ f (h₁₂ f h₁) }

instance [decidable_eq α] :
decidable_rel ((≤) : circuit α → circuit α → Prop) :=
λ c₁ c₂, decidable_of_iff (∀ (x : Π (i : α), (i ∈ (c₁.vars ++ c₂.vars).dedup) → bool),
x ∈ (c₁.vars ++ c₂.vars).dedup.pi (λ _, [tt, ff]) →
c₁.evalv (λ i hi, x i (by simp [hi])) → c₂.evalv (λ i hi, x i (by simp [hi])))
sorry
lemma le_def : ∀ (c₁ c₂ : circuit α), c₁ ≤ c₂ ↔ ∀ f, eval c₁ f → eval c₂ f :=
λ _ _, iff.rfl

lemma exists_eval_iff_exists_evalv [decidable_eq α] (c : circuit α) :
(∃ x, eval c x) ↔ ∃ x, evalv c x :=
begin
split,
{ rintro ⟨x, hx⟩,
use λ a _, x a,
rw [eval_eq_evalv] at hx,
exact hx },
{ rintro ⟨x, hx⟩,
refine ⟨λ a, dite (a ∈ c.vars) (x a) (λ _, ff), _⟩,
convert hx,
rw [eval_eq_evalv],
congr' 1,
ext i hi,
simp [hi] }
end

def simplify_and : circuit α → circuit α → circuit α
| true c := c
Expand Down Expand Up @@ -150,20 +156,6 @@ begin
cases c₁; cases c₂; simp [*, circuit.simplify_xor, eval, bnot_bxor] at *,
end

def simplify_imp : circuit α → circuit α → circuit α
| false _ := true
| _ true := true
| true c := c
| c false := not c
| c₁ c₂ := imp c₁ c₂

@[simp] lemma eval_simplify_imp : ∀ (c₁ c₂ : circuit α) (f : α → bool),
eval (simplify_imp c₁ c₂) f = bnot (eval c₁ f) || (eval c₂ f) :=
begin
intros c₁ c₂ f,
cases c₁; cases c₂; simp [*, circuit.simplify_imp, eval] at *
end

def map : Π (c : circuit α) (f : α → β), circuit β
| true _ := true
| false _ := false
Expand All @@ -172,7 +164,6 @@ def map : Π (c : circuit α) (f : α → β), circuit β
| (or c₁ c₂) f := simplify_or (map c₁ f) (map c₂ f)
| (not c) f := simplify_not (map c f)
| (xor c₁ c₂) f := simplify_xor (map c₁ f) (map c₂ f)
| (imp c₁ c₂) f := simplify_imp (map c₁ f) (map c₂ f)

lemma eval_map {c : circuit α} {f : α → β} {g : β → bool} :
eval (map c f) g = eval c (λ x, g (f x)) :=
Expand All @@ -188,7 +179,6 @@ def simplify : Π (c : circuit α), circuit α
| (or c₁ c₂) := simplify_or (simplify c₁) (simplify c₂)
| (not c) := simplify_not (simplify c)
| (xor c₁ c₂) := simplify_xor (simplify c₁) (simplify c₂)
| (imp c₁ c₂) := simplify_imp (simplify c₁) (simplify c₂)

@[simp] lemma eval_simplify : Π {c : circuit α} {f : α → bool},
eval (simplify c) f = eval c f
Expand All @@ -199,7 +189,6 @@ def simplify : Π (c : circuit α), circuit α
| (or c₁ c₂) f := by rw [simplify]; simp *
| (not c) f := by rw [simplify]; simp *
| (xor c₁ c₂) f := by rw [simplify]; simp *
| (imp c₁ c₂) f := by rw [simplify]; simp *

def sum_vars_left [decidable_eq α] [decidable_eq β] : circuit (α ⊕ β) → list α
| true := []
Expand All @@ -210,7 +199,6 @@ def sum_vars_left [decidable_eq α] [decidable_eq β] : circuit (α ⊕ β) →
| (or c₁ c₂) := (sum_vars_left c₁ ++ sum_vars_left c₂).dedup
| (not c) := sum_vars_left c
| (xor c₁ c₂) := (sum_vars_left c₁ ++ sum_vars_left c₂).dedup
| (imp c₁ c₂) := (sum_vars_left c₁ ++ sum_vars_left c₂).dedup

def sum_vars_right [decidable_eq α] [decidable_eq β] : circuit (α ⊕ β) → list β
| true := []
Expand All @@ -221,7 +209,6 @@ def sum_vars_right [decidable_eq α] [decidable_eq β] : circuit (α ⊕ β) →
| (or c₁ c₂) := (sum_vars_right c₁ ++ sum_vars_right c₂).dedup
| (not c) := sum_vars_right c
| (xor c₁ c₂) := (sum_vars_right c₁ ++ sum_vars_right c₂).dedup
| (imp c₁ c₂) := (sum_vars_right c₁ ++ sum_vars_right c₂).dedup

lemma eval_eq_of_eq_on_vars [decidable_eq α] : Π {c : circuit α} {f g : α → bool}
(h : ∀ x ∈ c.vars, f x = g x), eval c f = eval c g
Expand Down Expand Up @@ -254,13 +241,6 @@ begin
eval_eq_of_eq_on_vars (λ x hx, h x (or.inl hx)),
eval_eq_of_eq_on_vars (λ x hx, h x (or.inr hx))]
end
| (imp c₁ c₂) f g h :=
begin
simp only [vars, list.mem_append, list.mem_dedup] at h,
rw [eval, eval,
eval_eq_of_eq_on_vars (λ x hx, h x (or.inl hx)),
eval_eq_of_eq_on_vars (λ x hx, h x (or.inr hx))]
end

@[simp] lemma mem_vars_iff_mem_sum_vars_left [decidable_eq α] [decidable_eq β] :
Π {c : circuit (α ⊕ β)} {x : α},
Expand Down Expand Up @@ -289,11 +269,6 @@ end
simp [vars, sum_vars_left],
simp [mem_vars_iff_mem_sum_vars_left]
end
| (imp c₁ c₂) _ :=
begin
simp [vars, sum_vars_left],
simp [mem_vars_iff_mem_sum_vars_left]
end

@[simp] lemma mem_vars_iff_mem_sum_vars_right [decidable_eq α] [decidable_eq β] :
Π {c : circuit (α ⊕ β)} {x : β},
Expand Down Expand Up @@ -322,11 +297,6 @@ end
simp [vars, sum_vars_right],
simp [mem_vars_iff_mem_sum_vars_right]
end
| (imp c₁ c₂) _ :=
begin
simp [vars, sum_vars_right],
simp [mem_vars_iff_mem_sum_vars_right]
end

lemma eval_eq_of_eq_on_sum_vars_left_right
[decidable_eq α] [decidable_eq β] :
Expand Down Expand Up @@ -370,15 +340,6 @@ begin
eval_eq_of_eq_on_sum_vars_left_right
(λ x hx, h₁ x (or.inr hx)) (λ x hx, h₂ x (or.inr hx))]
end
| (imp c₁ c₂) f g h₁ h₂ :=
begin
simp only [sum_vars_left, sum_vars_right, list.mem_append, list.mem_dedup] at h₁ h₂,
rw [eval, eval,
eval_eq_of_eq_on_sum_vars_left_right
(λ x hx, h₁ x (or.inl hx)) (λ x hx, h₂ x (or.inl hx)),
eval_eq_of_eq_on_sum_vars_left_right
(λ x hx, h₁ x (or.inr hx)) (λ x hx, h₂ x (or.inr hx))]
end

def bOr : Π (s : list α) (f : α → circuit β), circuit β
| [] _ := false
Expand Down Expand Up @@ -448,8 +409,6 @@ def assign_vars [decidable_eq α] :
| (not c) f := simplify_not (assign_vars c (λ x hx, f x (by simp [hx, vars])))
| (xor c₁ c₂) f := simplify_xor (assign_vars c₁ (λ x hx, f x (by simp [hx, vars])))
(assign_vars c₂ (λ x hx, f x (by simp [hx, vars])))
| (imp c₁ c₂) f := simplify_imp (assign_vars c₁ (λ x hx, f x (by simp [hx, vars])))
(assign_vars c₂ (λ x hx, f x (by simp [hx, vars])))

lemma eval_assign_vars [decidable_eq α] : ∀ {c : circuit α}
{f : Π (a : α) (ha : a ∈ c.vars), β ⊕ bool} {g : β → bool},
Expand Down Expand Up @@ -479,10 +438,6 @@ end
simp [assign_vars, eval, vars],
rw [eval_assign_vars, eval_assign_vars]
end
| (imp c₁ c₂) f g := begin
simp [assign_vars, eval, vars],
rw [eval_assign_vars, eval_assign_vars]
end

def bind : Π (c : circuit α) (f : α → circuit β), circuit β
| true _ := true
Expand All @@ -492,7 +447,6 @@ def bind : Π (c : circuit α) (f : α → circuit β), circuit β
| (or c₁ c₂) f := simplify_or (bind c₁ f) (bind c₂ f)
| (not c) f := simplify_not (bind c f)
| (xor c₁ c₂) f := simplify_xor (bind c₁ f) (bind c₂ f)
| (imp c₁ c₂) f := simplify_imp (bind c₁ f) (bind c₂ f)

lemma eval_bind : ∀ (c : circuit α) (f : α → circuit β) (g : β → bool),
eval (bind c f) g = eval c (λ a, eval (f a) g)
Expand All @@ -515,10 +469,6 @@ end
simp [bind, eval],
rw [eval_bind, eval_bind]
end
| (imp c₁ c₂) f g := begin
simp [bind, eval],
rw [eval_bind, eval_bind]
end

def single [decidable_eq α] {s : list α} (x : Π a ∈ s, bool) : circuit α :=
bAnd s (λ i, if hi : i ∈ s then cond (x i hi) (var i) (not (var i)) else true)
Expand All @@ -539,5 +489,110 @@ begin
cases x a ha; simp [h] }
end

def nonempty_aux [decidable_eq α] :
Π (c : circuit α), { b : bool // (∃ x, eval c x) = (b : Prop) }
| true := ⟨tt, by simp⟩
| false := ⟨ff, by simp⟩
| (var x) := ⟨tt, by simp⟩
-- | (or c₁ c₂) :=
-- let b₁ := nonempty_aux c₁ in
-- let b₂ := nonempty_aux c₂ in
-- ⟨b₁ || b₂, by
-- { simp only [exists_or_distrib, eval, bor_coe_iff, eq_iff_iff],
-- rw [← b₁.prop, ← b₂.prop] }⟩
-- | (and c₁ (or c₂ c₃)) :=
-- let b₁ := nonempty_aux (and c₁ c₂) in
-- let b₂ := nonempty_aux (and c₁ c₃) in
-- ⟨b₁ || b₂, by
-- { simp only [eval, bor_coe_iff, eq_iff_iff, band_coe_iff, and_or_distrib_left,
-- exists_or_distrib],
-- rw [← b₁.prop, ← b₂.prop],
-- simp }⟩
-- | (and (or c₁ c₂) c₃) :=
-- let b₁ := nonempty_aux (and c₁ c₃) in
-- let b₂ := nonempty_aux (and c₂ c₃) in
-- ⟨b₁ || b₂, by
-- { simp only [eval, bor_coe_iff, eq_iff_iff, band_coe_iff, or_and_distrib_right,
-- exists_or_distrib],
-- rw [← b₁.prop, ← b₂.prop],
-- simp }⟩
-- | (not (and c₁ c₂)) :=
-- let b₁ := nonempty_aux (not c₁) in
-- let b₂ := nonempty_aux (not c₂) in
-- ⟨b₁ || b₂, by
-- { simp only [eval, bor_coe_iff, eq_iff_iff, band_coe_iff, not_and_distrib, exists_or_distrib,
-- bool.bnot_iff_not, eq_ff_eq_not_eq_tt],
-- rw [← b₁.prop, ← b₂.prop],
-- simp }⟩
| c :=
let v := vars c in
have hv : vars c = v := rfl,
match v, hv with
| [], hv := ⟨eval c (λ _, ff), begin
rw [exists_eval_iff_exists_evalv, eq_iff_iff, eval_eq_evalv],
split,
{ rintros ⟨x, hx⟩,
convert hx,
ext i hi,
rw [hv] at hi,
exact false.elim hi },
{ intro h,
use λ i _, ff,
exact h }
end
| (i::l), hv :=
let c₁ := c.assign_vars (λ j _, if i = j then sum.inr tt else sum.inl j) in
let c₂ := c.assign_vars (λ j _, if i = j then sum.inr ff else sum.inl j) in
have h₁ : sizeof c₁ < sizeof c := sorry,
have h₂ : sizeof c₂ < sizeof c := sorry,
let b₁ := nonempty_aux c₁ in
let b₂ := nonempty_aux c₂ in
⟨b₁ || b₂, begin
simp only [bor_coe_iff, eq_iff_iff, eval_eq_evalv],
rw [← b₁.prop, ← b₂.prop],
simp only [eval_assign_vars],
split,
{ rintro ⟨x, hx⟩,
cases hi : x i,
{ right,
use x,
convert hx,
ext j hj,
split_ifs,
{ subst h,
simp [hi] },
{ simp } },
{ left,
use x,
convert hx,
ext j hj,
split_ifs,
{ subst h,
simp [hi] },
{ simp } } },
{ intro h,
rcases h with ⟨x, hx⟩ | ⟨x, hx⟩,
{ refine ⟨_, hx⟩ },
{ refine ⟨_, hx⟩ } }
end
end

def nonempty [decidable_eq α] (c : circuit α) : bool :=
(nonempty_aux c).1

lemma nonempty_iff [decidable_eq α] (c : circuit α) :
nonempty c ↔ ∃ x, eval c x :=
by rw [nonempty, ← (nonempty_aux c).2]

def always_true [decidable_eq α] (c : circuit α) : bool :=
bnot (nonempty (not c))

lemma always_true_iff [decidable_eq α] (c : circuit α) :
always_true c ↔ ∀ x, eval c x :=
by simp [always_true, nonempty_iff, not_not]

instance [decidable_eq α] : decidable_rel ((≤) : circuit α → circuit α → Prop) :=
λ c₁ c₂, decidable_of_iff (always_true ((not c₁).or c₂))
begin simp [always_true_iff, le_def, or_iff_not_imp_left], end

end circuit
8 changes: 6 additions & 2 deletions src/v3/strucs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,19 @@ set_option profiler true
def x : term := term.var 0
def y : term := term.var 1
def z : term := term.var 2
def a : term := term.var 3
def b : term := term.var 4
def c : term := term.var 5
def d : term := term.var 6

#eval check_eq (x +- y) 0 2
#eval check_eq (x +- x) 0 2
#eval check_eq (x - y) (x + -y) 2
#eval check_eq (x + 1) x.incr 2
#eval check_eq (x - 1) x.decr 2

#eval check_eq (x.xor x) term.zero 1
#eval check_eq (x + y) (y + x) 1
#eval check_eq (x + (y + z)) ((x + y) + z) 2
#eval check_eq ((x + y) + z) (x + (y + z)) 2
#eval check_eq (not (xor x y)) (and x y - or x y - 1) 2

-- #eval (bitwise_struc bxor).nth_output (λ _, (tt, tt)) 0
Expand Down

0 comments on commit 119b68f

Please sign in to comment.