# Turing 0.41. #661
````diff
@@ -41,9 +41,8 @@ using StatsBase
 # Functionality for working with scaled identity matrices.
 using LinearAlgebra
 
-# Set a seed for reproducibility.
-using Random
-Random.seed!(0);
+# For ensuring reproducibility.
+using StableRNGs: StableRNG
 ```

 ```{julia}
````
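The switch from `Random.seed!` to `StableRNGs` matters because a `StableRNG` produces the same random stream on every Julia release, whereas the global RNG's seeding algorithm can change between versions. A minimal sketch (reusing seed 468 from the diff purely for illustration):

```julia
using StableRNGs: StableRNG

# Two StableRNGs built from the same seed yield identical streams,
# independent of Julia's global RNG state or Julia version.
a = rand(StableRNG(468), 3)
b = rand(StableRNG(468), 3)
@assert a == b
```

Passing such an RNG explicitly (rather than seeding global state) is what the rest of this diff does for `splitobs` and `sample`.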
|
|
````diff
@@ -76,7 +75,7 @@ The next step is to get our data ready for testing. We'll split the `mtcars` dat
 select!(data, Not(:Model))
 
 # Split our dataset 70%/30% into training/test sets.
-trainset, testset = map(DataFrame, splitobs(data; at=0.7, shuffle=true))
+trainset, testset = map(DataFrame, splitobs(StableRNG(468), data; at=0.7, shuffle=true))
 
 # Turing requires data in matrix form.
 target = :MPG
````
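The diff threads the RNG through `splitobs` as its first positional argument, so the shuffled split itself becomes reproducible. A small sketch on a plain vector, assuming the MLUtils `splitobs(rng, data; at, shuffle)` method used in the diff:

```julia
using MLUtils: splitobs
using StableRNGs: StableRNG

x = collect(1:10)
# An explicit RNG makes the shuffled 70%/30% split deterministic
# without touching global state.
train, test = splitobs(StableRNG(468), x; at=0.7, shuffle=true)
@assert length(train) == 7 && length(test) == 3
# Re-creating the RNG with the same seed reproduces the same split.
train2, _ = splitobs(StableRNG(468), x; at=0.7, shuffle=true)
@assert collect(train) == collect(train2)
```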
|
|
````diff
@@ -143,7 +142,7 @@ With our model specified, we can call the sampler. We will use the No U-Turn Sam
 
 ```{julia}
 model = linear_regression(train, train_target)
-chain = sample(model, NUTS(), 5_000)
+chain = sample(StableRNG(468), model, NUTS(), 20_000)
 ```
 
 We can also check the densities and traces of the parameters visually using the `plot` functionality.
````
|
|
````diff
@@ -158,7 +157,7 @@ It looks like all parameters have converged.
 #| echo: false
 let
     ess_df = ess(chain)
-    @assert minimum(ess_df[:, :ess]) > 500 "Minimum ESS: $(minimum(ess_df[:, :ess])) - not > 700"
+    @assert minimum(ess_df[:, :ess]) > 500 "Minimum ESS: $(minimum(ess_df[:, :ess])) - not > 500"
     @assert mean(ess_df[:, :ess]) > 2_000 "Mean ESS: $(mean(ess_df[:, :ess])) - not > 2000"
     @assert maximum(ess_df[:, :ess]) > 3_500 "Maximum ESS: $(maximum(ess_df[:, :ess])) - not > 3500"
 end
````
|
|
````diff
@@ -243,9 +242,11 @@ let
     ols_test_loss = msd(test_prediction_ols, testset[!, target])
     @assert bayes_train_loss < bayes_test_loss "Bayesian training loss ($bayes_train_loss) >= Bayesian test loss ($bayes_test_loss)"
     @assert ols_train_loss < ols_test_loss "OLS training loss ($ols_train_loss) >= OLS test loss ($ols_test_loss)"
-    @assert isapprox(bayes_train_loss, ols_train_loss; rtol=0.01) "Difference between Bayesian training loss ($bayes_train_loss) and OLS training loss ($ols_train_loss) unexpectedly large!"
-    @assert isapprox(bayes_test_loss, ols_test_loss; rtol=0.05) "Difference between Bayesian test loss ($bayes_test_loss) and OLS test loss ($ols_test_loss) unexpectedly large!"
+    @assert bayes_train_loss > ols_train_loss "Bayesian training loss ($bayes_train_loss) <= OLS training loss ($ols_train_loss)"
+    @assert bayes_test_loss < ols_test_loss "Bayesian test loss ($bayes_test_loss) >= OLS test loss ($ols_test_loss)"
 end
 ```
 
-As we can see above, OLS and our Bayesian model fit our training and test data set about the same.
-We can see from this that both linear regression techniques perform fairly similarly.
+The Bayesian linear regression approach performs worse on the training set, but better on the test set.
+This indicates that the Bayesian approach is more able to generalise to unseen data, i.e., it is not overfitting the training data as much.
````

> **Review comment:** Do you know if it was just happenstance that these actually pretty much matched each other before?
>
> **Reply:** Err, not sure to be honest. I think the atols were quite fragile; even changing the way the test/train split was done led to things failing there. So I'd say yes.
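The losses being compared in these assertions are mean squared deviations computed with `msd` from StatsBase. A quick check of what that function computes, on made-up numbers:

```julia
using StatsBase: msd

y = [1.0, 2.0, 3.0]   # hypothetical targets
ŷ = [1.1, 1.9, 3.2]   # hypothetical predictions
# msd(a, b) is the mean of the squared elementwise differences.
@assert msd(ŷ, y) ≈ (0.1^2 + 0.1^2 + 0.2^2) / 3
```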
|
|
````diff
@@ -38,11 +38,11 @@ Random.seed!(3)
 
 # Define Gaussian mixture model.
 w = [0.5, 0.5]
-μ = [-3.5, 0.5]
-mixturemodel = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)
+μ = [-2.0, 2.0]
+mixturemodel = MixtureModel([MvNormal(Fill(μₖ, 2), 0.2 * I) for μₖ in μ], w)
 
 # We draw the data points.
-N = 60
+N = 30
 x = rand(mixturemodel, N);
 ```
````

> **Review comment (on lines 40 to 46):** Having 30 observations just makes life easier because it triggers fewer resampling steps with PG. To compensate for that, I made the data more tightly clustered to make sure that the results are still somewhat meaningful.
>
> **Reply:** Locally, this cuts the time taken to run this tutorial to about 12 minutes. I'm not sure what it was before, but I think it was certainly over 20 minutes.
|
|
|
|
````diff
@@ -112,7 +112,7 @@ model = gaussian_mixture_model(x);
 ```
 
 We run an MCMC simulation to obtain an approximation of the posterior distribution of the parameters $\mu$ and $w$ and assignments $k$.
 We use a `Gibbs` sampler that combines a [particle Gibbs](https://www.stats.ox.ac.uk/%7Edoucet/andrieu_doucet_holenstein_PMCMC.pdf) sampler for the discrete parameters (assignments $k$) and a Hamiltonian Monte Carlo sampler for the continuous parameters ($\mu$ and $w$).
 We generate multiple chains in parallel using multi-threading.
 
 ```{julia}
````
|
|
````diff
@@ -145,7 +145,7 @@ let
         # μ[1] and μ[2] can switch places, so we sort the values first.
         chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
         μ_mean = vec(mean(chain; dims=1))
-        @assert isapprox(sort(μ_mean), μ; rtol=0.1) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
+        @assert isapprox(sort(μ_mean), μ; atol=0.5) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
     end
 end
 ```
````
|
|
````diff
@@ -212,7 +212,7 @@ let
         # μ[1] and μ[2] can no longer switch places. Check that they've found the mean
         chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
         μ_mean = vec(mean(chain; dims=1))
-        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
+        @assert isapprox(sort(μ_mean), μ; atol=0.5) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
     end
 end
 ```
````
|
|
````diff
@@ -350,7 +350,7 @@ let
         # μ[1] and μ[2] can no longer switch places. Check that they've found the mean
         chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
         μ_mean = vec(mean(chain; dims=1))
-        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
+        @assert isapprox(sort(μ_mean), μ; atol=0.5) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
     end
 end
 ```
````
|
|
> **Review comment:** I don't actually know what went wrong here; I struggled to get this to work without bumping the samples. It might just be that the 5,000-sample run happened to hit a very lucky seed. It takes very little time, because NUTS is quite fast anyway.