Skip to content
This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

Permalink
feat(data/fin/tuple): rename fin.append to matrix.vec_append, int…
Browse files Browse the repository at this point in the history
…roduce a new `fin.append` and `fin.repeat`. (#10346)

We already had `fin.append v w h`, which combines the append operation with casting.

This commit removes the `h` argument, allowing it to be defeq to `fin.add_cases`, and moves the previous definition to the name `matrix.vec_append` matching `matrix.vec_cons` and similar. Another advantage of dropping `h` is that it is clearer how to state lemmas like `append_assoc`, as we only have one proof argument to keep track of instead of four.
As it turns out, to implement a `gmonoid` structure on tuples (under concatenation), `fin.append` without the `h` argument is all that's needed.

We implement `matrix.vec_append` in terms of `fin.append`, but provide a lemma that unfolds it to the old definition so as to avoid having to rewrite all the other proofs.

Removing `matrix.vec_append` entirely is left to investigate in some future PR.
  • Loading branch information
eric-wieser committed Jan 16, 2023
1 parent ec2dfca commit 59505c3
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 73 deletions.
117 changes: 106 additions & 11 deletions src/data/fin/tuple/basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ We define the following operations:
* `fin.insert_nth` : insert an element to a tuple at a given position.
* `fin.find p` : returns the first index `n` where `p n` is satisfied, and `none` if it is never
satisfied.
* `fin.append a b` : append two tuples.
* `fin.repeat n a` : repeat a tuple `n` times.
-/
universes u v
Expand Down Expand Up @@ -246,17 +248,110 @@ set.ext $ λ y, exists_fin_succ.trans $ eq_comm.or iff.rfl
set.range (fin.cons x b : fin n.succ → α) = insert x (set.range b) :=
by rw [range_fin_succ, cons_zero, tail_cons]

/-- `fin.append ho u v` appends two vectors of lengths `m` and `n` to produce
one of length `o = m + n`. `ho` provides control of definitional equality
for the vector length. -/
def append {α : Type*} {o : ℕ} (ho : o = m + n) (u : fin m → α) (v : fin n → α) : fin o → α :=
λ i, if h : (i : ℕ) < m
then u ⟨i, h⟩
else v ⟨(i : ℕ) - m, (tsub_lt_iff_left (le_of_not_lt h)).2 (ho ▸ i.property)⟩

@[simp] lemma fin_append_apply_zero {α : Type*} {o : ℕ} (ho : (o + 1) = (m + 1) + n)
(u : fin (m + 1) → α) (v : fin n → α) :
fin.append ho u v 0 = u 0 := rfl
section append

/-- Append a tuple of length `m` to a tuple of length `n` to get a tuple of length `m + n`.
This is a non-dependent version of `fin.add_cases`. -/
def append {α : Type*} (a : fin m → α) (b : fin n → α) : fin (m + n) → α :=
@fin.add_cases _ _ (λ _, α) a b

@[simp] lemma append_left {α : Type*} (u : fin m → α) (v : fin n → α) (i : fin m) :
append u v (fin.cast_add n i) = u i :=
add_cases_left _ _ _

@[simp] lemma append_right {α : Type*} (u : fin m → α) (v : fin n → α) (i : fin n) :
append u v (nat_add m i) = v i :=
add_cases_right _ _ _

lemma append_right_nil {α : Type*} (u : fin m → α) (v : fin n → α) (hv : n = 0) :
append u v = u ∘ fin.cast (by rw [hv, add_zero]) :=
begin
refine funext (fin.add_cases (λ l, _) (λ r, _)),
{ rw [append_left, function.comp_apply],
refine congr_arg u (fin.ext _),
simp },
{ exact (fin.cast hv r).elim0' }
end

@[simp] lemma append_elim0' {α : Type*} (u : fin m → α) :
append u fin.elim0' = u ∘ fin.cast (add_zero _) :=
append_right_nil _ _ rfl

lemma append_left_nil {α : Type*} (u : fin m → α) (v : fin n → α) (hu : m = 0) :
append u v = v ∘ fin.cast (by rw [hu, zero_add]) :=
begin
refine funext (fin.add_cases (λ l, _) (λ r, _)),
{ exact (fin.cast hu l).elim0' },
{ rw [append_right, function.comp_apply],
refine congr_arg v (fin.ext _),
simp [hu] },
end

@[simp] lemma elim0'_append {α : Type*} (v : fin n → α) :
append fin.elim0' v = v ∘ fin.cast (zero_add _) :=
append_left_nil _ _ rfl

lemma append_assoc {p : ℕ} {α : Type*} (a : fin m → α) (b : fin n → α) (c : fin p → α) :
append (append a b) c = append a (append b c) ∘ fin.cast (add_assoc _ _ _) :=
begin
ext i,
rw function.comp_apply,
apply fin.add_cases (λ l, _) (λ r, _) i,
{ rw append_left,
apply fin.add_cases (λ ll, _) (λ lr, _) l,
{ rw append_left,
simp [cast_add_cast_add] },
{ rw append_right,
simp [cast_add_nat_add], }, },
{ rw append_right,
simp [←nat_add_nat_add] },
end

end append

section repeat

/-- Repeat `a` `m` times. For example `fin.repeat 2 ![0, 3, 7] = ![0, 3, 7, 0, 3, 7]`. -/
@[simp] def repeat {α : Type*} (m : ℕ) (a : fin n → α) : fin (m * n) → α
| i := a i.mod_nat

@[simp] lemma repeat_zero {α : Type*} (a : fin n → α) :
repeat 0 a = fin.elim0' ∘ cast (zero_mul _) :=
funext $ λ x, (cast (zero_mul _) x).elim0'

@[simp] lemma repeat_one {α : Type*} (a : fin n → α) :
repeat 1 a = a ∘ cast (one_mul _) :=
begin
generalize_proofs h,
apply funext,
rw (fin.cast h.symm).surjective.forall,
intro i,
simp [mod_nat, nat.mod_eq_of_lt i.is_lt],
end

lemma repeat_succ {α : Type*} (a : fin n → α) (m : ℕ) :
repeat m.succ a = append a (repeat m a) ∘ cast ((nat.succ_mul _ _).trans (add_comm _ _)) :=
begin
generalize_proofs h,
apply funext,
rw (fin.cast h.symm).surjective.forall,
refine fin.add_cases (λ l, _) (λ r, _),
{ simp [mod_nat, nat.mod_eq_of_lt l.is_lt], },
{ simp [mod_nat] }
end

@[simp] lemma repeat_add {α : Type*} (a : fin n → α) (m₁ m₂ : ℕ) :
repeat (m₁ + m₂) a = append (repeat m₁ a) (repeat m₂ a) ∘ cast (add_mul _ _ _) :=
begin
generalize_proofs h,
apply funext,
rw (fin.cast h.symm).surjective.forall,
refine fin.add_cases (λ l, _) (λ r, _),
{ simp [mod_nat, nat.mod_eq_of_lt l.is_lt], },
{ simp [mod_nat, nat.add_mod] }
end

end repeat

end tuple

Expand Down
58 changes: 42 additions & 16 deletions src/data/fin/vec_notation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ variables {α : Type u}
section matrix_notation

/-- `![]` is the vector with no entries. -/
def vec_empty : fin 0 → α :=
fin_zero_elim
def vec_empty : fin 0 → α := fin.elim0'

/-- `vec_cons h t` prepends an entry `h` to a vector `t`.
Expand Down Expand Up @@ -174,16 +173,43 @@ of elements by virtue of the semantics of `bit0` and `bit1` and of
addition on `fin n`).
-/

@[simp] lemma empty_append (v : fin n → α) : fin.append (zero_add _).symm ![] v = v :=
by { ext, simp [fin.append] }
/-- `vec_append ho u v` appends two vectors of lengths `m` and `n` to produce
one of length `o = m + n`. This is a variant of `fin.append` with an additional `ho` argument,
which provides control of definitional equality for the vector length.
This turns out to be helpful when providing simp lemmas to reduce `![a, b, c] n`, and also means
that `vec_append ho u v 0` is valid. `fin.append u v 0` is not valid in this case because there is
no `has_zero (fin (m + n))` instance. -/
def vec_append {α : Type*} {o : ℕ} (ho : o = m + n) (u : fin m → α) (v : fin n → α) : fin o → α :=
fin.append u v ∘ fin.cast ho

lemma vec_append_eq_ite {α : Type*} {o : ℕ} (ho : o = m + n) (u : fin m → α) (v : fin n → α) :
vec_append ho u v = λ i,
if h : (i : ℕ) < m
then u ⟨i, h⟩
else v ⟨(i : ℕ) - m, (tsub_lt_iff_left (le_of_not_lt h)).2 (ho ▸ i.property)⟩ :=
begin
ext i,
rw [vec_append, fin.append, function.comp_apply, fin.add_cases],
congr' with hi,
simp only [eq_rec_constant],
refl,
end

@[simp] lemma vec_append_apply_zero {α : Type*} {o : ℕ} (ho : (o + 1) = (m + 1) + n)
(u : fin (m + 1) → α) (v : fin n → α) :
vec_append ho u v 0 = u 0 := rfl

@[simp] lemma empty_vec_append (v : fin n → α) : vec_append (zero_add _).symm ![] v = v :=
by { ext, simp [vec_append_eq_ite] }

@[simp] lemma cons_append (ho : o + 1 = m + 1 + n) (x : α) (u : fin m → α) (v : fin n → α) :
fin.append ho (vec_cons x u) v =
vec_cons x (fin.append (by rwa [add_assoc, add_comm 1, ←add_assoc,
@[simp] lemma cons_vec_append (ho : o + 1 = m + 1 + n) (x : α) (u : fin m → α) (v : fin n → α) :
vec_append ho (vec_cons x u) v =
vec_cons x (vec_append (by rwa [add_assoc, add_comm 1, ←add_assoc,
add_right_cancel_iff] at ho) u v) :=
begin
ext i,
simp_rw [fin.append],
simp_rw [vec_append_eq_ite],
split_ifs with h,
{ rcases i with ⟨⟨⟩ | i, hi⟩,
{ simp },
Expand All @@ -205,10 +231,10 @@ only alternate elements (odd-numbered). -/
def vec_alt1 (hm : m = n + n) (v : fin m → α) (k : fin n) : α :=
v ⟨(k : ℕ) + k + 1, hm.symm ▸ nat.add_succ_lt_add k.property k.property⟩

lemma vec_alt0_append (v : fin n → α) : vec_alt0 rfl (fin.append rfl v v) = v ∘ bit0 :=
lemma vec_alt0_vec_append (v : fin n → α) : vec_alt0 rfl (vec_append rfl v v) = v ∘ bit0 :=
begin
ext i,
simp_rw [function.comp, bit0, vec_alt0, fin.append],
simp_rw [function.comp, bit0, vec_alt0, vec_append_eq_ite],
split_ifs with h; congr,
{ rw fin.coe_mk at h,
simp only [fin.ext_iff, fin.coe_add, fin.coe_mk],
Expand All @@ -220,10 +246,10 @@ begin
exact add_lt_add i.property i.property }
end

lemma vec_alt1_append (v : fin (n + 1) → α) : vec_alt1 rfl (fin.append rfl v v) = v ∘ bit1 :=
lemma vec_alt1_vec_append (v : fin (n + 1) → α) : vec_alt1 rfl (vec_append rfl v v) = v ∘ bit1 :=
begin
ext i,
simp_rw [function.comp, vec_alt1, fin.append],
simp_rw [function.comp, vec_alt1, vec_append_eq_ite],
cases n,
{ simp, congr },
{ split_ifs with h; simp_rw [bit1, bit0]; congr,
Expand All @@ -248,12 +274,12 @@ end
by simp [vec_head, vec_alt1]

@[simp] lemma cons_vec_bit0_eq_alt0 (x : α) (u : fin n → α) (i : fin (n + 1)) :
vec_cons x u (bit0 i) = vec_alt0 rfl (fin.append rfl (vec_cons x u) (vec_cons x u)) i :=
by rw vec_alt0_append
vec_cons x u (bit0 i) = vec_alt0 rfl (vec_append rfl (vec_cons x u) (vec_cons x u)) i :=
by rw vec_alt0_vec_append

@[simp] lemma cons_vec_bit1_eq_alt1 (x : α) (u : fin n → α) (i : fin (n + 1)) :
vec_cons x u (bit1 i) = vec_alt1 rfl (fin.append rfl (vec_cons x u) (vec_cons x u)) i :=
by rw vec_alt1_append
vec_cons x u (bit1 i) = vec_alt1 rfl (vec_append rfl (vec_cons x u) (vec_cons x u)) i :=
by rw vec_alt1_vec_append

@[simp] lemma cons_vec_alt0 (h : m + 1 + 1 = (n + 1) + (n + 1)) (x y : α) (u : fin m → α) :
vec_alt0 h (vec_cons x (vec_cons y u)) = vec_cons x (vec_alt0
Expand Down
12 changes: 6 additions & 6 deletions src/linear_algebra/cross_product.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ lemma triple_product_permutation (u v w : fin 3 → R) :
u ⬝ᵥ (v ×₃ w) = v ⬝ᵥ (w ×₃ u) :=
begin
simp only [cross_apply, vec3_dot_product,
matrix.head_cons, matrix.cons_vec_bit0_eq_alt0, matrix.empty_append, matrix.cons_val_one,
matrix.cons_vec_alt0, matrix.cons_append, matrix.cons_val_zero],
matrix.head_cons, matrix.cons_vec_bit0_eq_alt0, matrix.empty_vec_append, matrix.cons_val_one,
matrix.cons_vec_alt0, matrix.cons_vec_append, matrix.cons_val_zero],
ring,
end

Expand All @@ -108,17 +108,17 @@ theorem triple_product_eq_det (u v w : fin 3 → R) :
begin
simp only [vec3_dot_product, cross_apply, matrix.det_fin_three,
matrix.head_cons, matrix.cons_vec_bit0_eq_alt0, matrix.empty_vec_alt0, matrix.cons_vec_alt0,
matrix.vec_head_vec_alt0, fin.fin_append_apply_zero, matrix.empty_append, matrix.cons_append,
matrix.cons_val', matrix.cons_val_one, matrix.cons_val_zero],
matrix.vec_head_vec_alt0, matrix.vec_append_apply_zero, matrix.empty_vec_append,
matrix.cons_vec_append, matrix.cons_val', matrix.cons_val_one, matrix.cons_val_zero],
ring,
end

/-- The scalar quadruple product identity, related to the Binet-Cauchy identity. -/
theorem cross_dot_cross (u v w x : fin 3 → R) :
(u ×₃ v) ⬝ᵥ (w ×₃ x) = (u ⬝ᵥ w) * (v ⬝ᵥ x) - (u ⬝ᵥ x) * (v ⬝ᵥ w) :=
begin
simp only [vec3_dot_product, cross_apply, cons_append, cons_vec_bit0_eq_alt0,
cons_val_one, cons_vec_alt0, linear_map.mk₂_apply, cons_val_zero, head_cons, empty_append],
simp only [vec3_dot_product, cross_apply, cons_vec_append, cons_vec_bit0_eq_alt0,
cons_val_one, cons_vec_alt0, linear_map.mk₂_apply, cons_val_zero, head_cons, empty_vec_append],
ring_nf,
end

Expand Down
4 changes: 2 additions & 2 deletions src/model_theory/encoding.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ begin
(fin_range n).map (option.some ∘ ts) ++ list_decode l,
{ induction (fin_range n) with i l' l'ih,
{ refl },
{ rw [cons_bind, append_assoc, ih, map_cons, l'ih, cons_append] } },
{ rw [cons_bind, list.append_assoc, ih, map_cons, l'ih, cons_append] } },
have h' : ∀ i, (list_decode ((fin_range n).bind (λ (i : fin n), (ts i).list_encode) ++ l)).nth
↑i = some (some (ts i)),
{ intro i,
Expand Down Expand Up @@ -268,7 +268,7 @@ begin
rw [list.drop_append_eq_append_drop, length_map, length_fin_range, nat.sub_self, drop,
drop_eq_nil_of_le, nil_append],
rw [length_map, length_fin_range], }, },
{ rw [list_encode, append_assoc, cons_append, list_decode],
{ rw [list_encode, list.append_assoc, cons_append, list_decode],
simp only [subtype.val_eq_coe] at *,
rw [(ih1 _).1, (ih1 _).2, (ih2 _).1, (ih2 _).2, sigma_imp, dif_pos rfl],
exact ⟨rfl, rfl⟩, },
Expand Down
2 changes: 1 addition & 1 deletion src/number_theory/legendre_symbol/gauss_sum.lean
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ begin
{ ext, congr, apply pow_one },
convert_to (0 + 1 * τ ^ 1 + 0 + (-1) * τ ^ 3 + 0 + (-1) * τ ^ 5 + 0 + 1 * τ ^ 7) ^ 2 = _,
{ simp only [χ₈_apply, matrix.cons_val_zero, matrix.cons_val_one, matrix.head_cons,
matrix.cons_vec_bit0_eq_alt0, matrix.cons_vec_bit1_eq_alt1, matrix.cons_append,
matrix.cons_vec_bit0_eq_alt0, matrix.cons_vec_bit1_eq_alt1, matrix.cons_vec_append,
matrix.cons_vec_alt0, matrix.cons_vec_alt1, int.cast_zero, int.cast_one, int.cast_neg,
zero_mul], refl },
convert_to 8 + (τ ^ 4 + 1) * (τ ^ 10 - 2 * τ ^ 8 - 2 * τ ^ 6 + 6 * τ ^ 4 + τ ^ 2 - 8) = _,
Expand Down
Loading

0 comments on commit 59505c3

Please sign in to comment.