Skip to content

Commit accef4e

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Inference Gym: Move the Stan stuff to the spinoffs directory.
PiperOrigin-RevId: 328403234
1 parent 18d6af3 commit accef4e

17 files changed

+94
-102
lines changed

spinoffs/inference_gym/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ models.
109109
Currently we have a little tooling to help use `cmdstanpy` to generate ground
110110
truth values (in the correct format) for models without analytic ground truth.
111111
Using this requires adding a model implementation inside the
112-
[`tools/inference_gym_ground_truth`][ground_truth_dir]
112+
[`inference_gym/tools/stan`][ground_truth_dir]
113113
directory.
114114

115115
### Adding a new real dataset
@@ -124,11 +124,11 @@ Follow the example of the [`SyntheticItemResponseTheory`][irt] model.
124124

125125
### Generating ground truth files.
126126

127-
See [`tools/inference_gym_ground_truth/get_ground_truth.py`][get_ground_truth].
127+
See [`inference_gym/tools/get_ground_truth.py`][get_ground_truth].
128128

129129
[model]: https://github.com/tensorflow/probability/tree/master/spinoffs/inference_gym/targets/model.py
130-
[get_ground_truth]: https://github.com/tensorflow/probability/tree/master/tools/inference_gym_ground_truth/get_ground_truth.py
131-
[ground_truth_dir]: https://github.com/tensorflow/probability/tree/master/tools/inference_gym_ground_truth
130+
[get_ground_truth]: https://github.com/tensorflow/probability/tree/master/spinoffs/inference_gym/tools/get_ground_truth.py
131+
[ground_truth_dir]: https://github.com/tensorflow/probability/tree/master/spinoffs/inference_gym/tools/stan
132132
[bayesian_model]: https://github.com/tensorflow/probability/tree/master/spinoffs/inference_gym/targets/bayesian_model.py
133133
[sparse_logistic_regression]: https://github.com/tensorflow/probability/tree/master/spinoffs/inference_gym/targets/sparse_logistic_regression.py
134134
[logistic_regression]: https://github.com/tensorflow/probability/tree/master/spinoffs/inference_gym/targets/logistic_regression.py

spinoffs/inference_gym/internal/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
package(
2222
default_visibility = [
2323
"//tensorflow_probability:__subpackages__",
24-
"//tools/inference_gym_ground_truth:__subpackages__",
24+
"//tensorflow_probability/opensource/tools/inference_gym_ground_truth:__subpackages__",
2525
"//spinoffs/inference_gym:__subpackages__",
2626
],
2727
)

spinoffs/inference_gym/tools/BUILD

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2020 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
# Ground truth computation.
16+
17+
package(
18+
default_visibility = [
19+
"//tensorflow_probability:__subpackages__",
20+
"//spinoffs/inference_gym:__subpackages__",
21+
],
22+
)
23+
24+
licenses(["notice"])
25+
26+
exports_files(["LICENSE"])
27+
28+
# We can't use strict/pytype because `cmdstanpy` is not available internally.
29+
py_binary(
30+
name = "get_ground_truth",
31+
srcs = ["get_ground_truth.py"],
32+
python_version = "PY3",
33+
deps = [
34+
"//tensorflow_probability",
35+
"//spinoffs/inference_gym/internal:ground_truth_encoding",
36+
"//spinoffs/inference_gym/tools/stan:targets",
37+
],
38+
)

tools/inference_gym_ground_truth/get_ground_truth.py renamed to spinoffs/inference_gym/tools/get_ground_truth.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -43,10 +43,6 @@
4343
- tfds-nightly
4444
"""
4545

46-
from __future__ import absolute_import
47-
from __future__ import division
48-
from __future__ import print_function
49-
5046
import functools
5147
import os
5248
import sys
@@ -57,8 +53,8 @@
5753
import pandas as pd
5854
import tensorflow.compat.v2 as tf
5955
import tensorflow_probability as tfp
60-
from tools.inference_gym_ground_truth import targets
6156
from spinoffs.inference_gym.internal import ground_truth_encoding
57+
from spinoffs.inference_gym.tools.stan import targets
6258
# Direct import for flatten_with_tuple_paths.
6359
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
6460

@@ -72,7 +68,7 @@
7268
flags.DEFINE_string('output_directory', None,
7369
'Where to save the ground truth values. By default, this '
7470
'places it in the appropriate directory in the '
75-
'TensorFlow Probability source directory.')
71+
'Inference Gym source directory.')
7672

7773
FLAGS = flags.FLAGS
7874

@@ -154,7 +150,7 @@ def main(argv):
154150

155151
argv_str = '\n'.join([' {} \\'.format(arg) for arg in sys.argv[1:]])
156152
command_str = (
157-
"""bazel run //tools/inference_gym_ground_truth:get_ground_truth -- \
153+
"""bazel run //spinoffs/inference_gym/tools:get_ground_truth -- \
158154
{argv_str}""".format(argv_str=argv_str))
159155

160156
file_str = ground_truth_encoding.get_ground_truth_module_source(
@@ -163,7 +159,7 @@ def main(argv):
163159
if FLAGS.output_directory is None:
164160
file_basedir = os.path.dirname(os.path.realpath(__file__))
165161
output_directory = os.path.join(
166-
file_basedir, '../../spinoffs/inference_gym/targets/ground_truth')
162+
file_basedir, '../targets/ground_truth')
167163
else:
168164
output_directory = FLAGS.output_directory
169165
file_path = os.path.join(output_directory, '{}.py'.format(FLAGS.target))

tools/inference_gym_ground_truth/BUILD renamed to spinoffs/inference_gym/tools/stan/BUILD

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,19 @@
1313
# limitations under the License.
1414
# ============================================================================
1515
# Ground truth computation using Stan.
16-
licenses(["notice"])
17-
18-
package(default_visibility = ["//visibility:public"])
1916

20-
py_binary(
21-
name = "get_ground_truth",
22-
srcs = ["get_ground_truth.py"],
23-
python_version = "PY3",
24-
tags = ["notap"],
25-
deps = [
26-
":targets",
27-
"//tensorflow_probability",
28-
"//spinoffs/inference_gym/internal:ground_truth_encoding",
17+
package(
18+
default_visibility = [
19+
"//tensorflow_probability:__subpackages__",
20+
"//spinoffs/inference_gym:__subpackages__",
2921
],
3022
)
3123

24+
licenses(["notice"])
25+
26+
exports_files(["LICENSE"])
27+
28+
# We can't use strict/pytype because `cmdstanpy` is not available internally.
3229
py_library(
3330
name = "brownian_motion",
3431
srcs = ["brownian_motion.py"],

tools/inference_gym_ground_truth/brownian_motion.py renamed to spinoffs/inference_gym/tools/stan/brownian_motion.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,10 @@
1515
# ============================================================================
1616
"""Brownian Motion model, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import numpy as np
2319

24-
from tools.inference_gym_ground_truth import stan_model
25-
from tools.inference_gym_ground_truth import util
20+
from spinoffs.inference_gym.tools.stan import stan_model
21+
from spinoffs.inference_gym.tools.stan import util
2622

2723
__all__ = [
2824
'brownian_motion',

tools/inference_gym_ground_truth/item_response_theory.py renamed to spinoffs/inference_gym/tools/stan/item_response_theory.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,10 @@
1515
# ============================================================================
1616
"""1PL item-response theory model, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import collections
2319

24-
from tools.inference_gym_ground_truth import stan_model
25-
from tools.inference_gym_ground_truth import util
20+
from spinoffs.inference_gym.tools.stan import stan_model
21+
from spinoffs.inference_gym.tools.stan import util
2622

2723
__all__ = [
2824
'item_response_theory',

tools/inference_gym_ground_truth/log_gaussian_cox_process.py renamed to spinoffs/inference_gym/tools/stan/log_gaussian_cox_process.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,10 @@
1515
# ============================================================================
1616
"""Log-gaussian Cox Process, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import collections
2319

24-
from tools.inference_gym_ground_truth import stan_model
25-
from tools.inference_gym_ground_truth import util
20+
from spinoffs.inference_gym.tools.stan import stan_model
21+
from spinoffs.inference_gym.tools.stan import util
2622

2723
__all__ = [
2824
'log_gaussian_cox_process',

tools/inference_gym_ground_truth/logistic_regression.py renamed to spinoffs/inference_gym/tools/stan/logistic_regression.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,10 @@
1515
# ============================================================================
1616
"""Logistic regression, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import numpy as np
2319

24-
from tools.inference_gym_ground_truth import stan_model
25-
from tools.inference_gym_ground_truth import util
20+
from spinoffs.inference_gym.tools.stan import stan_model
21+
from spinoffs.inference_gym.tools.stan import util
2622

2723
__all__ = [
2824
'logistic_regression',

tools/inference_gym_ground_truth/probit_regression.py renamed to spinoffs/inference_gym/tools/stan/probit_regression.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,10 @@
1515
# ============================================================================
1616
"""Probit regression, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import numpy as np
2319

24-
from tools.inference_gym_ground_truth import stan_model
25-
from tools.inference_gym_ground_truth import util
20+
from spinoffs.inference_gym.tools.stan import stan_model
21+
from spinoffs.inference_gym.tools.stan import util
2622

2723
__all__ = [
2824
'probit_regression',

tools/inference_gym_ground_truth/radon_contextual_effects.py renamed to spinoffs/inference_gym/tools/stan/radon_contextual_effects.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,16 +15,12 @@
1515
# ============================================================================
1616
"""Radon model, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import collections
2319

2420
import numpy as np
2521

26-
from tools.inference_gym_ground_truth import stan_model
27-
from tools.inference_gym_ground_truth import util
22+
from spinoffs.inference_gym.tools.stan import stan_model
23+
from spinoffs.inference_gym.tools.stan import util
2824

2925
__all__ = [
3026
'radon_contextual_effects',
@@ -115,7 +111,8 @@ def _ext_identity(samples):
115111
samples, r'^county_effect_scale$')[:, 0]
116112
res['county_effect'] = util.get_columns(samples, r'^county_effect\.\d+$')
117113
res['weight'] = util.get_columns(samples, r'^weight\.\d+$')
118-
res['log_radon_scale'] = util.get_columns(samples, r'^log_radon_scale$')[:, 0]
114+
res['log_radon_scale'] = (
115+
util.get_columns(samples, r'^log_radon_scale$')[:, 0])
119116
return res
120117

121118
extract_fns = {'identity': _ext_identity}

tools/inference_gym_ground_truth/sparse_logistic_regression.py renamed to spinoffs/inference_gym/tools/stan/sparse_logistic_regression.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,15 +15,11 @@
1515
# ============================================================================
1616
"""Sparse logistic regression, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import collections
2319
import numpy as np
2420

25-
from tools.inference_gym_ground_truth import stan_model
26-
from tools.inference_gym_ground_truth import util
21+
from spinoffs.inference_gym.tools.stan import stan_model
22+
from spinoffs.inference_gym.tools.stan import util
2723

2824
__all__ = [
2925
'sparse_logistic_regression',

tools/inference_gym_ground_truth/stan_model.py renamed to spinoffs/inference_gym/tools/stan/stan_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");

tools/inference_gym_ground_truth/stochastic_volatility.py renamed to spinoffs/inference_gym/tools/stan/stochastic_volatility.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,10 @@
1515
# ============================================================================
1616
"""Stochastic volatility model, implemented in Stan."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import collections
2319

24-
from tools.inference_gym_ground_truth import stan_model
25-
from tools.inference_gym_ground_truth import util
20+
from spinoffs.inference_gym.tools.stan import stan_model
21+
from spinoffs.inference_gym.tools.stan import util
2622

2723
__all__ = [
2824
'stochastic_volatility',

tools/inference_gym_ground_truth/targets.py renamed to spinoffs/inference_gym/tools/stan/targets.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Lint as: python2, python3
1+
# Lint as: python3
22
# Copyright 2020 The TensorFlow Probability Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,19 +15,15 @@
1515
# ============================================================================
1616
"""Stan models, used as a source of ground truth."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
22-
from tools.inference_gym_ground_truth import brownian_motion
23-
from tools.inference_gym_ground_truth import item_response_theory
24-
from tools.inference_gym_ground_truth import log_gaussian_cox_process
25-
from tools.inference_gym_ground_truth import logistic_regression
26-
from tools.inference_gym_ground_truth import probit_regression
27-
from tools.inference_gym_ground_truth import radon_contextual_effects
28-
from tools.inference_gym_ground_truth import sparse_logistic_regression
29-
from tools.inference_gym_ground_truth import stochastic_volatility
3018
from spinoffs.inference_gym.internal import data
19+
from spinoffs.inference_gym.tools.stan import brownian_motion
20+
from spinoffs.inference_gym.tools.stan import item_response_theory
21+
from spinoffs.inference_gym.tools.stan import log_gaussian_cox_process
22+
from spinoffs.inference_gym.tools.stan import logistic_regression
23+
from spinoffs.inference_gym.tools.stan import probit_regression
24+
from spinoffs.inference_gym.tools.stan import radon_contextual_effects
25+
from spinoffs.inference_gym.tools.stan import sparse_logistic_regression
26+
from spinoffs.inference_gym.tools.stan import stochastic_volatility
3127

3228
__all__ = [
3329
'brownian_motion_missing_middle_observations',

0 commit comments

Comments
 (0)