@@ -115,6 +115,57 @@ def Environment.transitivelyRequiredModules (env : Environment) (module : Name)
115
115
|>.filter (env.getModuleFor? · = some module)
116
116
(NameSet.ofList constants).transitivelyRequiredModules env
117
117
118
+ /--
119
+ Computes all the modules transitively required by the specified modules.
120
+ Should be equivalent to calling `transitivelyRequiredModules` on each module, but shares more of the work.
121
+ -/
122
+ partial def Environment.transitivelyRequiredModules' (env : Environment) (modules : List Name) (verbose : Bool := false ) :
123
+ CoreM (NameMap NameSet) := do
124
+ let N := env.header.moduleNames.size
125
+ let mut c2m : NameMap (BitVec N) := {}
126
+ let mut pushed : NameSet := {}
127
+ let mut result : NameMap NameSet := {}
128
+ for m in modules do
129
+ if verbose then
130
+ IO.println s! "Processing module { m} "
131
+ let mut r : BitVec N := 0
132
+ for n in env.header.moduleData[(env.header.moduleNames.getIdx? m).getD 0 ]!.constNames do
133
+ if ! n.isInternal then
134
+ -- This is messy: Mathlib is big enough that writing a recursive function causes a stack overflow.
135
+ -- So we use an explicit stack instead. We visit each constant twice:
136
+ -- once to record the constants transitively used by it,
137
+ -- and again to record the modules which defined those constants.
138
+ let mut stack : List (Name × Option NameSet) := [⟨n, none⟩]
139
+ pushed := pushed.insert n
140
+ while !stack.isEmpty do
141
+ match stack with
142
+ | [] => panic! "Stack is empty"
143
+ | (c, used?) :: tail =>
144
+ stack := tail
145
+ match used? with
146
+ | none =>
147
+ if !c2m.contains c then
148
+ let used := (← getConstInfo c).getUsedConstantsAsSet
149
+ stack := ⟨c, some used⟩ :: stack
150
+ for u in used do
151
+ if !pushed.contains u then
152
+ stack := ⟨u, none⟩ :: stack
153
+ pushed := pushed.insert u
154
+ | some used =>
155
+ let usedModules : NameSet :=
156
+ used.fold (init := {}) (fun s u => if let some m := env.getModuleFor? u then s.insert m else s)
157
+ let transitivelyUsed : BitVec N :=
158
+ used.fold (init := toBitVec usedModules) (fun s u => s ||| ((c2m.find? u).getD 0 ))
159
+ c2m := c2m.insert c transitivelyUsed
160
+ r := r ||| ((c2m.find? n).getD 0 )
161
+ result := result.insert m (toNameSet r)
162
+ return result
163
+ where
164
+ toBitVec {N : Nat} (s : NameSet) : BitVec N :=
165
+ s.fold (init := 0 ) (fun b n => b ||| BitVec.twoPow _ ((env.header.moduleNames.getIdx? n).getD 0 ))
166
+ toNameSet {N : Nat} (b : BitVec N) : NameSet :=
167
+ env.header.moduleNames.zipWithIndex.foldl (init := {}) (fun s (n, i) => if b.getLsbD i then s.insert n else s)
168
+
118
169
/--
119
170
Return the names of the modules in which constants used in the current file were defined.
120
171
0 commit comments