Skip to content

Commit

Permalink
add version asserts for various examples
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Feb 6, 2025
1 parent 9ec98e5 commit 5cf00b6
Show file tree
Hide file tree
Showing 10 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/capture_recapture.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser(
description="CJS capture-recapture model for ecological data"
)
Expand Down
2 changes: 2 additions & 0 deletions examples/cvae-flax/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from train_baseline import train_baseline
from train_cvae import train_cvae

import numpyro
from numpyro.examples.datasets import MNIST

from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model # isort:skip
Expand Down Expand Up @@ -78,6 +79,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser(
description="Conditional Variational Autoencoder on MNIST using Flax"
)
Expand Down
1 change: 1 addition & 0 deletions examples/dais_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def add_fig(samples, title, ax):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser("Usage example for AutoDAIS guide.")
parser.add_argument("--num-svi-steps", type=int, default=80 * 1000)
parser.add_argument("--num-warmup", type=int, default=2000)
Expand Down
1 change: 1 addition & 0 deletions examples/hmcecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def summary_plot(losses, hmc_samples, hmcecs_samples, hmc_runtime, hmcecs_runtim


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser(
"Hamiltonian Monte Carlo with Energy Conserving Subsampling"
)
Expand Down
1 change: 1 addition & 0 deletions examples/hmm_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser(description="HMC for HMMs")
parser.add_argument(
"-m",
Expand Down
1 change: 1 addition & 0 deletions examples/hsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
args = parse_arguments()
numpyro.enable_x64(args.x64)
numpyro.set_platform(args.device)
Expand Down
1 change: 1 addition & 0 deletions examples/ssbvm_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser(
description="Sine-skewed sine (bivariate von mises) mixture model example"
)
Expand Down
2 changes: 2 additions & 0 deletions examples/stein_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from jax import config, nn, numpy as jnp, random

import numpyro
from numpyro import deterministic, plate, sample, set_platform, subsample
from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI
from numpyro.distributions import Gamma, Normal
Expand Down Expand Up @@ -186,6 +187,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
config.update("jax_debug_nans", True)

parser = argparse.ArgumentParser()
Expand Down
1 change: 1 addition & 0 deletions examples/var2.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser(description="VAR(2) example")
parser.add_argument("--num-data", nargs="?", default=100, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
Expand Down
1 change: 1 addition & 0 deletions examples/zero_inflated_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def add_fig(var_name, title, ax):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.17.0")
parser = argparse.ArgumentParser("Zero-Inflated Poisson Regression")
parser.add_argument("--seed", nargs="?", default=42, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
Expand Down

0 comments on commit 5cf00b6

Please sign in to comment.