The stochy crate is a collection of stochastic approximation algorithms:
- RSPSA (Resilient Simultaneous Perturbation Stochastic Approximation)
- SPSA (Simultaneous Perturbation Stochastic Approximation)
You can use stochy to:
- Minimize functions with multiple parameters, without requiring a gradient.
- Optimize parameters in game-playing programs using relative difference functions.
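To illustrate the first use case, here is a minimal, dependency-free sketch of the idea behind SPSA: every coordinate is perturbed at once by a random ±1 sign, and two evaluations of the function give a gradient estimate regardless of the number of parameters. This is not stochy's implementation. The names `spsa_minimize`, `cost`, and `SplitMix64`, and the fixed gains `a` and `c`, are assumptions made for this sketch; SPSA as usually described uses decaying gain sequences rather than constants.

```rust
/// SplitMix64: a tiny pseudo-random generator, used here only to draw
/// the ±1 perturbation signs without pulling in an external crate.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns +1.0 or -1.0, chosen by the top bit of the next output.
    fn sign(&mut self) -> f64 {
        if self.next_u64() >> 63 == 1 { 1.0 } else { -1.0 }
    }
}

/// Same test function as the crate's examples: minimum at [1.5, 0.0].
fn cost(x: &[f64]) -> f64 {
    (x[0] - 1.5).powi(2) + x[1].powi(2)
}

/// Simplified SPSA loop with fixed gains (assumption: `a` is the step size,
/// `c` the perturbation size). No gradient of `f` is ever requested.
fn spsa_minimize(
    f: impl Fn(&[f64]) -> f64,
    x0: &[f64],
    a: f64,
    c: f64,
    iters: usize,
    seed: u64,
) -> Vec<f64> {
    let mut rng = SplitMix64(seed);
    let mut x = x0.to_vec();
    for _ in 0..iters {
        // Perturb all coordinates simultaneously.
        let delta: Vec<f64> = (0..x.len()).map(|_| rng.sign()).collect();
        let plus: Vec<f64> = x.iter().zip(&delta).map(|(xi, di)| xi + c * di).collect();
        let minus: Vec<f64> = x.iter().zip(&delta).map(|(xi, di)| xi - c * di).collect();
        // Two function evaluations per iteration, whatever the dimension.
        let diff = f(&plus) - f(&minus);
        for (xi, di) in x.iter_mut().zip(&delta) {
            *xi -= a * diff / (2.0 * c * di);
        }
    }
    x
}

fn main() {
    let x = spsa_minimize(cost, &[1.0, 1.0], 0.1, 0.1, 1_000, 42);
    println!("x = {:?}, cost = {:e}", x, cost(&x));
}
```

In practice the crate handles the gain schedules and hyperparameters for you; the sketch only shows why no gradient is needed.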
stochy is compatible with both the stepwise algorithm API and the argmin solver API (enabled via the argmin feature flag). Difference functions are only supported under stepwise.
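As a rough illustration of what a relative difference function can look like, the sketch below tunes two parameters using only head-to-head comparisons, never an absolute cost. It does not use stochy's API. The function names, the sign convention (negative means the first parameter set did better), and the hidden `IDEAL` weights standing in for real games are all assumptions of this sketch. To keep it deterministic, the perturbation directions cycle through a fixed set instead of being drawn at random as SPSA would.

```rust
/// Hypothetical stand-in for playing two parameter sets against each other.
/// Returns an estimate of cost(a) - cost(b): negative when `a` did better.
/// A real program would play games here instead of consulting hidden weights.
fn relative_difference(a: &[f64], b: &[f64]) -> f64 {
    const IDEAL: [f64; 2] = [0.3, -0.7];
    let cost = |p: &[f64]| -> f64 {
        p.iter()
            .zip(IDEAL.iter())
            .map(|(pi, ti)| (pi - ti).powi(2))
            .sum::<f64>()
    };
    cost(a) - cost(b)
}

/// One SPSA-style update driven only by the head-to-head difference
/// between the two perturbed parameter sets.
fn tune_step(x: &mut [f64], delta: &[f64], a: f64, c: f64) {
    let plus: Vec<f64> = x.iter().zip(delta).map(|(xi, di)| xi + c * di).collect();
    let minus: Vec<f64> = x.iter().zip(delta).map(|(xi, di)| xi - c * di).collect();
    let diff = relative_difference(&plus, &minus);
    for (xi, di) in x.iter_mut().zip(delta) {
        *xi -= a * diff / (2.0 * c * di);
    }
}

/// Runs `iters` updates from [0.0, 0.0], alternating between two fixed
/// ±1 directions (a simplification; SPSA draws them randomly).
fn tune(iters: usize) -> Vec<f64> {
    let directions = [[1.0, 1.0], [1.0, -1.0]];
    let mut x = vec![0.0, 0.0];
    for k in 0..iters {
        tune_step(&mut x, &directions[k % 2], 0.1, 0.1);
    }
    x
}

fn main() {
    println!("tuned parameters: {:?}", tune(100));
}
```

The point of the design is that `relative_difference` never reveals how good either side is in absolute terms, which matches game tuning, where only the outcome of a match between two configurations is observable.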
Example Cargo.toml:
[dependencies]
stochy = "0.0.3"
# if using argmin, replace the above with:
# stochy = { version = "0.0.3", features = ["argmin"] }
use stepwise::{Driver as _, fixed_iters, assert_approx_eq};
use stochy::{SpsaAlgo, SpsaParams};
let f = |x: &[f64]| (x[0] - 1.5).powi(2) + x[1].powi(2);
let hyperparams = SpsaParams::default();
let initial_guess = vec![1.0, 1.0];
let spsa = SpsaAlgo::from_fn(hyperparams, initial_guess, f).expect("bad hyperparams!");
let (solved, final_step) = fixed_iters(spsa, 20_000)
    .on_step(|algo, step| println!("{} {:?}", step.iteration(), algo.x()))
    .solve()
    .expect("solving failed!");
assert_approx_eq!(solved.x(), &[1.5, 0.0]);
println!("Solved in {} iterations.", final_step.iteration());
This example is equivalent to Example 1, but uses the argmin crate to manage the SPSA algorithm.
use stepwise::assert_approx_eq;
struct MySimpleCost;
impl argmin::core::CostFunction for MySimpleCost {
    type Param = Vec<f64>;
    type Output = f64;
    fn cost(&self, x: &Self::Param) -> Result<Self::Output, argmin::core::Error> {
        Ok((x[0] - 1.5).powi(2) + x[1].powi(2))
    }
}
let hyperparams = stochy::SpsaParams::default();
let algo = stochy::SpsaSolverArgmin::new(hyperparams);
let exec = argmin::core::Executor::new(MySimpleCost, algo);
let initial_param = vec![1.0, 1.0];
let result = exec
    .configure(|step| step.param(initial_param).max_iters(20_000))
    .run()
    .unwrap();
let best_param = result.state.best_param.unwrap();
assert_approx_eq!(best_param.as_slice(), &[1.5, 0.0]);
println!("Solved in {} iterations.", result.state.iter);
Table 1: Feature comparison of the algorithms contrasted with the more familiar Gradient Descent algorithm.