Skip to content

Commit

Permalink
feat: add Vector.mapM, ForIn/ToStream instances (#6835)
Browse files Browse the repository at this point in the history
This PR fills some gaps in the `Vector` API, adding `mapM`, `zip`, and
`ForIn'` and `ToStream` instances.
  • Loading branch information
kim-em authored Jan 29, 2025
1 parent aa65107 commit c93012f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
27 changes: 27 additions & 0 deletions src/Init/Data/Vector/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ prelude
import Init.Data.Array.Lemmas
import Init.Data.Array.MapIdx
import Init.Data.Range
import Init.Data.Stream

/-!
# Vectors
Expand Down Expand Up @@ -178,6 +179,16 @@ which also receives the index of the element, and the fact that the index is les
@[inline] def mapFinIdx (v : Vector α n) (f : (i : Nat) → α → (h : i < n) → β) : Vector β n :=
⟨v.toArray.mapFinIdx (fun i a h => f i a (by simpa [v.size_toArray] using h)), by simp⟩

/-- Map a monadic function over a vector. -/
def mapM [Monad m] (f : α → m β) (v : Vector α n) : m (Vector β n) := do
go 0 (Nat.zero_le n) #v[]
where
go (i : Nat) (h : i ≤ n) (r : Vector β i) : m (Vector β n) := do
if h' : i < n then
go (i+1) (by omega) (r.push (← f v[i]))
else
return r.cast (by omega)

@[inline] def flatten (v : Vector (Vector α n) m) : Vector α (m * n) :=
⟨(v.toArray.map Vector.toArray).flatten,
by rcases v; simp_all [Function.comp_def, Array.map_const']⟩
Expand All @@ -191,6 +202,9 @@ which also receives the index of the element, and the fact that the index is les
@[deprecated zipIdx (since := "2025-01-21")]
abbrev zipWithIndex := @zipIdx

@[inline] def zip (v : Vector α n) (w : Vector β n) : Vector (α × β) n :=
⟨v.toArray.zip w.toArray, by simp⟩

/-- Maps corresponding elements of two vectors of equal size using the function `f`. -/
@[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α → β → φ) : Vector φ n :=
⟨Array.zipWith a.toArray b.toArray f, by simp⟩
Expand Down Expand Up @@ -323,6 +337,19 @@ no element of the index matches the given value.
@[inline] def count [BEq α] (a : α) (v : Vector α n) : Nat :=
v.toArray.count a

/-! ### ForIn instance -/

@[simp] theorem mem_toArray_iff (a : α) (v : Vector α n) : a ∈ v.toArray ↔ a ∈ v :=
fun h => ⟨h⟩, fun ⟨h⟩ => h⟩

instance : ForIn' m (Vector α n) α inferInstance where
forIn' v b f := Array.forIn' v.toArray b (fun a h b => f a (by simpa using h) b)

/-! ### ToStream instance -/

instance : ToStream (Vector α n) (Subarray α) where
toStream v := v.toArray[:n]

/-! ### Lexicographic ordering -/

instance instLT [LT α] : LT (Vector α n) := ⟨fun v w => v.toArray < w.toArray⟩
Expand Down
3 changes: 0 additions & 3 deletions src/Init/Data/Vector/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,6 @@ protected theorem ext {a b : Vector α n} (h : (i : Nat) → (_ : i < n) → a[i
rcases v with ⟨v, h⟩
exact ⟨by rintro rfl; simp_all, by rintro rfl; simpa using h⟩

@[simp] theorem mem_toArray_iff (a : α) (v : Vector α n) : a ∈ v.toArray ↔ a ∈ v :=
fun h => ⟨h⟩, fun ⟨h⟩ => h⟩

/-! ### toList -/

theorem toArray_toList (a : Vector α n) : a.toArray.toList = a.toList := rfl
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,7 @@ def withLocalDeclsD [Inhabited α] (declInfos : Array (Name × (Array Expr → n
(declInfos.map (fun (name, typeCtor) => (name, BinderInfo.default, typeCtor))) k

/--
Simpler variant of `withLocalDeclsD` for brining variables into scope whose types do not depend
Simpler variant of `withLocalDeclsD` for bringing variables into scope whose types do not depend
on each other.
-/
def withLocalDeclsDND [Inhabited α] (declInfos : Array (Name × Expr)) (k : (xs : Array Expr) → n α) : n α :=
Expand Down

0 comments on commit c93012f

Please sign in to comment.