diff --git a/docs/src/manual/exporting_to_jax.md b/docs/src/manual/exporting_to_jax.md index 50014b99e..dc23905a1 100644 --- a/docs/src/manual/exporting_to_jax.md +++ b/docs/src/manual/exporting_to_jax.md @@ -59,7 +59,7 @@ end Now we define a python script to run the model using EnzymeJAX. ```python -from enzyme_ad.jax import primitives +from enzyme_ad.jax import hlo_call import jax import jax.numpy as jnp @@ -81,7 +81,7 @@ def run_lux_model( weight6_3, bias6_3, ): - return primitives.ffi_call( + return hlo_call( x, weight1, bias1, @@ -93,13 +93,7 @@ def run_lux_model( bias6_2, weight6_3, bias6_3, - out_shapes=[ - jax.core.ShapedArray([4, 10], jnp.float32), - ], - fn="main", source=code, - lang=primitives.LANG_MHLO, - pipeline_options=primitives.JaXPipeline(""), )