Skip to content

Commit

Permalink
Make seed handler stateless by default
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Feb 19, 2025
1 parent f7746d5 commit 31ba110
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
11 changes: 11 additions & 0 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,8 @@ class seed(Messenger):
>>> assert x == y
"""

stateful = False

def __init__(
self,
fn: Optional[Callable] = None,
Expand Down Expand Up @@ -835,6 +837,15 @@ def process_message(self, msg: Message) -> None:
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg["kwargs"]["rng_key"] = rng_key_sample

def __call__(self, *args, **kwargs):
if self.fn is not None and not self.stateful:
cloned_seeded_fn = seed(
self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
)
cloned_seeded_fn.stateful = True
return cloned_seeded_fn.__call__(*args, **kwargs)
return super().__call__(*args, **kwargs)


class substitute(Messenger):
"""
Expand Down
7 changes: 5 additions & 2 deletions test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,12 @@ def model_subsample_2():
model_subsample_2,
],
)
def test_plate(model):
def test_trace_jit(model):
trace = handlers.trace(handlers.seed(model, random.PRNGKey(1))).get_trace()
jit_trace = handlers.trace(jit(handlers.seed(model, random.PRNGKey(1)))).get_trace()
with jax.check_tracer_leaks(False):
jit_trace = handlers.trace(
jit(handlers.seed(model, random.PRNGKey(1)))
).get_trace()
assert "z" in trace
for name, site in trace.items():
if site["type"] == "sample":
Expand Down

0 comments on commit 31ba110

Please sign in to comment.