Commit a017f4e
committed
Optimize tree_sum compile time using tree_reduce_associative
Changed tree_sum implementation to use jax.tree_util.tree_reduce_associative
when available (JAX >= 0.6.0). Since addition is an associative operation,
tree_reduce_associative can provide better compilation performance.
Testing shows runtime is very close but compile time is significantly
lower (18s vs 23s in reported cases).
For compatibility with older JAX versions (< 0.6.0), the implementation
falls back to the original jax.tree.reduce when tree_reduce_associative
is not available.1 parent dcff838 commit a017f4e
1 file changed
+7
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
| 23 | + | |
23 | 24 | | |
24 | 25 | | |
25 | 26 | | |
| |||
176 | 177 | | |
177 | 178 | | |
178 | 179 | | |
179 | | - | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
180 | 186 | | |
181 | 187 | | |
182 | 188 | | |
| |||
0 commit comments