Skip to content

Commit f1ffda1

Browse files
sethaxenyebaisunxd3penelopeysmgithub-actions[bot]
authored
Add StatsBase.predict to the interface (#81)
* Add StatsBase as a dependency * Implement StatsBase.predict * use `fix` and fix some errors * Format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Bump StatsBase compat * slim down implementations --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 903e0c6 commit f1ffda1

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
1111
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1212
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415

1516
[compat]
1617
AbstractMCMC = "2, 3, 4, 5"
1718
Accessors = "0.1"
1819
DensityInterface = "0.4"
1920
JSON = "0.19 - 0.21"
2021
Random = "1.6"
22+
StatsBase = "0.32, 0.33, 0.34"
2123
julia = "~1.6.6, 1.7.3"

src/abstractprobprog.jl

+14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using AbstractMCMC
22
using DensityInterface
33
using Random
4+
using StatsBase
45

56
"""
67
AbstractProbabilisticProgram
@@ -116,3 +117,16 @@ end
116117
function Base.rand(model::AbstractProbabilisticProgram)
117118
return rand(Random.default_rng(), NamedTuple, model)
118119
end
120+
121+
"""
122+
predict(
123+
[rng::AbstractRNG=Random.default_rng(),]
124+
model::AbstractProbabilisticProgram,
125+
params,
126+
)
127+
128+
Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`.
129+
"""
130+
function StatsBase.predict(model::AbstractProbabilisticProgram, params)
131+
return predict(Random.default_rng(), NamedTuple, model, params)
132+
end

0 commit comments

Comments
 (0)