Skip to content

Commit

Permalink
feat: show how to use model explorer (#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Jan 30, 2025
1 parent 94ff497 commit 2242f21
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 1 deletion.
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ export default defineConfig({
text: "Initializing Weights",
link: "/manual/weight_initializers",
},
{
text: "Visualizing Lux Models",
link: "/manual/visualize_lux_models",
},
],
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/src/manual/exporting_to_jax.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Exporting Lux Models to Jax (via EnzymeJAX & Reactant)
# [Exporting Lux Models to Jax (via EnzymeJAX & Reactant)](@ref exporting_to_stablehlo)

In this manual, we will go over how to export Lux models to StableHLO and use
[EnzymeJAX](https://github.com/EnzymeAD/Enzyme-JAX) to run integrate Lux models with
Expand Down
58 changes: 58 additions & 0 deletions docs/src/manual/visualize_lux_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Visualizing Lux Models using Model Explorer

We can use [model explorer](https://ai.google.dev/edge/model-explorer) to visualize both Lux
models and the corresponding gradient expressions. To do this we just need to compile our
model [using Reactant](@ref reactant-compilation) and save the resulting `mlir` file.

```@example visualize_lux_models
using Lux, Reactant, Enzyme, Random
dev = reactant_device(; force=true)
model = Chain(
Chain(
Conv((3, 3), 3 => 32, relu; pad=SamePad()),
BatchNorm(32),
),
FlattenLayer(),
Dense(32 * 32 * 32 => 32, tanh),
BatchNorm(32),
Dense(32 => 10)
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = randn(Float32, 32, 32, 3, 4) |> dev
nothing #hide
```

Following instructions from [exporting lux models to stablehlo](@ref exporting_to_stablehlo)
we can save the `mlir` file.

```@example visualize_lux_models
hlo = @code_hlo model(x, ps, Lux.testmode(st))
open("exported_lux_model.mlir", "w") do io
write(io, string(hlo))
end
```

![model-explorer-screenshot](../public/model_explorer_graph_forward_pass.png)

We can also visualize the gradients of the model using the same method.

```@example visualize_lux_models
function ∇sumabs2_enzyme(model, x, ps, st)
return Enzyme.gradient(Enzyme.Reverse, sum ∘ first ∘ Lux.apply, Const(model),
x, ps, Const(st))
end
hlo = @code_hlo ∇sumabs2_enzyme(model, x, ps, st)
open("exported_lux_model_gradients.mlir", "w") do io
write(io, string(hlo))
end
```

This is going to be hard to read, but you get the idea.

![model-explorer-screenshot](../public/model_explorer_graph_backward_pass.png)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 2242f21

Please sign in to comment.