Skip to content

Commit

Permalink
Remove cached plates from AutoGuide.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Feb 14, 2025
1 parent aa829fb commit 66e2b8f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,24 @@ def __init__(
self._prototype_frames = {}
self._prototype_frame_full_sizes = {}

def _create_plates(self, *args, **kwargs):
def _create_plates(self, *args, **kwargs) -> dict[str, numpyro.plate]:
if self.create_plates is None:
self.plates = {}
plates = {}
else:
plates = self.create_plates(*args, **kwargs)
if isinstance(plates, numpyro.plate):
plates = [plates]
assert all(isinstance(p, numpyro.plate) for p in plates), (
"create_plates() returned a non-plate"
)
self.plates = {p.name: p for p in plates}
plates = {p.name: p for p in plates}
for name, frame in sorted(self._prototype_frames.items()):
if name not in self.plates:
if name not in plates:
full_size = self._prototype_frame_full_sizes[name]
self.plates[name] = numpyro.plate(
plates[name] = numpyro.plate(
name, full_size, dim=frame.dim, subsample_size=frame.size
)
return self.plates
return plates

def __getstate__(self):
state = self.__dict__.copy()
Expand Down

0 comments on commit 66e2b8f

Please sign in to comment.