Skip to content

Commit 707c3e0

Browse files
authored
Merge pull request #54 from HyperCodec/activation-example
create custom activations example
2 parents 1ca896b + 7c31f30 commit 707c3e0

File tree

2 files changed

+86
-2
lines changed

2 files changed

+86
-2
lines changed

examples/custom_activation.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//! An example implementation of a custom activation function.
2+
3+
use neat::*;
4+
use rand::prelude::*;
5+
6+
#[derive(DivisionReproduction, RandomlyMutable, Clone)]
7+
struct AgentDNA {
8+
network: NeuralNetworkTopology<2, 2>,
9+
}
10+
11+
impl Prunable for AgentDNA {}
12+
13+
impl GenerateRandom for AgentDNA {
14+
fn gen_random(rng: &mut impl Rng) -> Self {
15+
Self {
16+
network: NeuralNetworkTopology::new(0.01, 3, rng),
17+
}
18+
}
19+
}
20+
21+
fn fitness(g: &AgentDNA) -> f32 {
22+
let network: NeuralNetwork<2, 2> = NeuralNetwork::from(&g.network);
23+
let mut fitness = 0.;
24+
let mut rng = rand::thread_rng();
25+
26+
for _ in 0..50 {
27+
let n = rng.gen::<f32>();
28+
let n2 = rng.gen::<f32>();
29+
30+
let expected = if (n + n2) / 2. >= 0.5 { 0 } else { 1 };
31+
32+
let result = network.predict([n, n2]);
33+
network.flush_state();
34+
35+
// partial_cmp chance of returning None in this smh
36+
let result = result.iter().max_index();
37+
38+
if result == expected {
39+
fitness += 1.;
40+
} else {
41+
fitness -= 1.;
42+
}
43+
}
44+
45+
fitness
46+
}
47+
48+
#[cfg(feature = "serde")]
49+
fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec<AgentDNA> {
50+
let max = rewards
51+
.iter()
52+
.max_by(|(_, ra), (_, rb)| ra.total_cmp(rb))
53+
.unwrap();
54+
55+
let ser = NNTSerde::from(&max.0.network);
56+
let data = serde_json::to_string_pretty(&ser).unwrap();
57+
std::fs::write("best-agent.json", data).expect("Failed to write to file");
58+
59+
division_pruning_nextgen(rewards)
60+
}
61+
62+
fn main() {
63+
let log_activation = activation_fn!(f32::log10);
64+
register_activation(log_activation);
65+
66+
#[cfg(not(feature = "rayon"))]
67+
let mut rng = rand::thread_rng();
68+
69+
let mut sim = GeneticSim::new(
70+
#[cfg(not(feature = "rayon"))]
71+
Vec::gen_random(&mut rng, 100),
72+
#[cfg(feature = "rayon")]
73+
Vec::gen_random(100),
74+
fitness,
75+
#[cfg(not(feature = "serde"))]
76+
division_pruning_nextgen,
77+
#[cfg(feature = "serde")]
78+
serde_nextgen,
79+
);
80+
81+
for _ in 0..200 {
82+
sim.next_generation();
83+
}
84+
}

src/topology/activation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ use crate::NeuronLocation;
1515
#[macro_export]
1616
macro_rules! activation_fn {
1717
($F: path) => {
18-
ActivationFn::new(Arc::new($F), ActivationScope::default(), stringify!($F).into())
18+
ActivationFn::new(std::sync::Arc::new($F), ActivationScope::default(), stringify!($F).into())
1919
};
2020

2121
($F: path, $S: expr) => {
22-
ActivationFn::new(Arc::new($F), $S, stringify!($F).into())
22+
ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into())
2323
};
2424

2525
{$($F: path),*} => {

0 commit comments

Comments
 (0)