Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 97 additions & 73 deletions examples/early_stopping_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,89 @@ use rust_lstm::{
create_basic_trainer, EarlyStoppingConfig, EarlyStoppingMetric, LSTMNetwork, TrainingConfig,
};

/// Epoch budget shared by every demo `TrainingConfig` built in this file.
pub(crate) const DEMO_EPOCHS: usize = 6;
/// `print_every` value used by every demo `TrainingConfig` built in this file.
pub(crate) const DEMO_PRINT_EVERY: usize = 1;
/// Patience used by `custom_patience_training_config`; the other demo configs
/// in this file use a smaller literal patience.
pub(crate) const MAX_DEMO_PATIENCE: usize = 4;
/// Number of training sequences produced by `generate_overfitting_data`.
pub(crate) const DEMO_TRAIN_SEQUENCES: usize = 8;
/// Number of validation sequences produced by `generate_overfitting_data`.
pub(crate) const DEMO_VALIDATION_SEQUENCES: usize = 3;
/// Number of time steps generated per sequence in `generate_overfitting_data`.
pub(crate) const DEMO_SEQUENCE_LENGTH: usize = 5;

/// One demo sequence as a pair of per-step matrices.
// NOTE(review): callers destructure this as `(inputs, targets)` — confirm the
// element order against the generator body.
type SequencePair = (Vec<Array2<f64>>, Vec<Array2<f64>>);

/// Assembles an [`EarlyStoppingConfig`] from its four settings.
///
/// Each argument is stored unchanged in the field of the same name; this
/// helper only exists so the demo config builders can describe their early
/// stopping setup in a single call.
fn demo_early_stopping_config(
    patience: usize,
    min_delta: f64,
    restore_best_weights: bool,
    monitor: EarlyStoppingMetric,
) -> EarlyStoppingConfig {
    EarlyStoppingConfig {
        monitor,
        restore_best_weights,
        min_delta,
        patience,
    }
}

/// Training configuration for the validation-loss early stopping demo.
///
/// Monitors `EarlyStoppingMetric::ValidationLoss` with patience 3 and
/// `min_delta` 1e-4, restoring the best weights, and runs for at most
/// `DEMO_EPOCHS` epochs.
pub(crate) fn validation_early_stopping_training_config() -> TrainingConfig {
    let stop_on_validation_loss =
        demo_early_stopping_config(3, 1e-4, true, EarlyStoppingMetric::ValidationLoss);

    TrainingConfig {
        early_stopping: Some(stop_on_validation_loss),
        log_lr_changes: false,
        clip_gradient: Some(1.0),
        print_every: DEMO_PRINT_EVERY,
        epochs: DEMO_EPOCHS,
    }
}

/// Training configuration for the training-loss early stopping demo.
///
/// Monitors `EarlyStoppingMetric::TrainLoss` with patience 3 and `min_delta`
/// 0.1, restoring the best weights, and runs for at most `DEMO_EPOCHS` epochs.
pub(crate) fn train_loss_early_stopping_training_config() -> TrainingConfig {
    let stop_on_train_loss =
        demo_early_stopping_config(3, 0.1, true, EarlyStoppingMetric::TrainLoss);

    TrainingConfig {
        early_stopping: Some(stop_on_train_loss),
        log_lr_changes: false,
        clip_gradient: Some(1.0),
        print_every: DEMO_PRINT_EVERY,
        epochs: DEMO_EPOCHS,
    }
}

/// Training configuration for the demo that does not restore best weights.
///
/// Same early stopping settings as the validation-loss demo (patience 3,
/// `min_delta` 1e-4, `EarlyStoppingMetric::ValidationLoss`) except that
/// `restore_best_weights` is `false`.
pub(crate) fn no_weight_restoration_training_config() -> TrainingConfig {
    let stop_without_restore =
        demo_early_stopping_config(3, 1e-4, false, EarlyStoppingMetric::ValidationLoss);

    TrainingConfig {
        early_stopping: Some(stop_without_restore),
        log_lr_changes: false,
        clip_gradient: Some(1.0),
        print_every: DEMO_PRINT_EVERY,
        epochs: DEMO_EPOCHS,
    }
}

/// Training configuration for the higher-patience early stopping demo.
///
/// Monitors `EarlyStoppingMetric::ValidationLoss` with patience
/// `MAX_DEMO_PATIENCE` and `min_delta` 1e-6, restoring the best weights, and
/// runs for at most `DEMO_EPOCHS` epochs.
pub(crate) fn custom_patience_training_config() -> TrainingConfig {
    let stop_with_max_patience = demo_early_stopping_config(
        MAX_DEMO_PATIENCE,
        1e-6,
        true,
        EarlyStoppingMetric::ValidationLoss,
    );

    TrainingConfig {
        early_stopping: Some(stop_with_max_patience),
        log_lr_changes: false,
        clip_gradient: Some(1.0),
        print_every: DEMO_PRINT_EVERY,
        epochs: DEMO_EPOCHS,
    }
}

fn main() {
println!("Early Stopping Demonstration");
println!("================================\n");
Expand Down Expand Up @@ -41,25 +124,11 @@ fn demonstrate_validation_early_stopping(

let network = LSTMNetwork::new(1, 8, 1);

// Configure early stopping with default settings (validation loss monitoring)
let early_stopping_config = EarlyStoppingConfig {
patience: 5,
min_delta: 1e-4,
restore_best_weights: true,
monitor: EarlyStoppingMetric::ValidationLoss,
};

let training_config = TrainingConfig {
epochs: 100, // Will likely stop early
print_every: 1,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: Some(early_stopping_config),
};
let training_config = validation_early_stopping_training_config();

let mut trainer = create_basic_trainer(network, 0.01).with_config(training_config);

println!("Training with validation loss monitoring (patience=5)...");
println!("Training with validation loss monitoring (patience=3)...");
trainer.train(train_data, Some(val_data));

// Show final metrics
Expand All @@ -83,25 +152,11 @@ fn demonstrate_train_loss_early_stopping(

let network = LSTMNetwork::new(1, 8, 1);

// Configure early stopping to monitor training loss
let early_stopping_config = EarlyStoppingConfig {
patience: 8,
min_delta: 1e-5,
restore_best_weights: true,
monitor: EarlyStoppingMetric::TrainLoss,
};

let training_config = TrainingConfig {
epochs: 100,
print_every: 1,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: Some(early_stopping_config),
};
let training_config = train_loss_early_stopping_training_config();

let mut trainer = create_basic_trainer(network, 0.01).with_config(training_config);

println!("Training with training loss monitoring (patience=8)...");
println!("Training with training loss monitoring (patience=3)...");
trainer.train(train_data, Some(val_data));

if let Some(final_metrics) = trainer.get_latest_metrics() {
Expand All @@ -124,21 +179,7 @@ fn demonstrate_no_weight_restoration(

let network = LSTMNetwork::new(1, 8, 1);

// Configure early stopping without restoring best weights
let early_stopping_config = EarlyStoppingConfig {
patience: 5,
min_delta: 1e-4,
restore_best_weights: false, // Don't restore best weights
monitor: EarlyStoppingMetric::ValidationLoss,
};

let training_config = TrainingConfig {
epochs: 100,
print_every: 1,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: Some(early_stopping_config),
};
let training_config = no_weight_restoration_training_config();

let mut trainer = create_basic_trainer(network, 0.01).with_config(training_config);

Expand All @@ -161,30 +202,16 @@ fn demonstrate_custom_patience(
train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
) {
println!("4. EARLY STOPPING WITH HIGH PATIENCE");
println!("====================================");
println!("4. EARLY STOPPING WITH BOUNDED HIGHER PATIENCE");
println!("================================================");

let network = LSTMNetwork::new(1, 8, 1);

// Configure early stopping with higher patience
let early_stopping_config = EarlyStoppingConfig {
patience: 15, // More patient
min_delta: 1e-6, // Smaller improvement threshold
restore_best_weights: true,
monitor: EarlyStoppingMetric::ValidationLoss,
};

let training_config = TrainingConfig {
epochs: 100,
print_every: 2,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: Some(early_stopping_config),
};
let training_config = custom_patience_training_config();

let mut trainer = create_basic_trainer(network, 0.01).with_config(training_config);

println!("Training with high patience (patience=15)...");
println!("Training with bounded higher patience (patience=4)...");
trainer.train(train_data, Some(val_data));

if let Some(final_metrics) = trainer.get_latest_metrics() {
Expand All @@ -199,20 +226,17 @@ fn demonstrate_custom_patience(

/// Generate synthetic data that will cause overfitting
/// This creates a simple pattern that's easy to memorize but doesn't generalize well
fn generate_overfitting_data() -> (
Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>,
Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>,
) {
pub(crate) fn generate_overfitting_data() -> (Vec<SequencePair>, Vec<SequencePair>) {
let mut train_data = Vec::new();
let mut val_data = Vec::new();

// Create training data - simple sine wave with noise
for i in 0..20 {
for i in 0..DEMO_TRAIN_SEQUENCES {
let mut inputs = Vec::new();
let mut targets = Vec::new();

let phase = i as f64 * 0.1;
for t in 0..10 {
for t in 0..DEMO_SEQUENCE_LENGTH {
let x = (t as f64 * 0.3 + phase).sin();
let y = ((t + 1) as f64 * 0.3 + phase).sin(); // Next value

Expand All @@ -224,12 +248,12 @@ fn generate_overfitting_data() -> (
}

// Create validation data - different phase to test generalization
for i in 0..5 {
for i in 0..DEMO_VALIDATION_SEQUENCES {
let mut inputs = Vec::new();
let mut targets = Vec::new();

let phase = (i as f64 + 100.0) * 0.1; // Different phase
for t in 0..10 {
for t in 0..DEMO_SEQUENCE_LENGTH {
let x = (t as f64 * 0.3 + phase).sin();
let y = ((t + 1) as f64 * 0.3 + phase).sin();

Expand Down
62 changes: 62 additions & 0 deletions tests/example_training_bounds_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ mod training_example;
#[path = "../examples/stock_prediction.rs"]
mod stock_prediction;

#[allow(dead_code)]
#[path = "../examples/early_stopping_example.rs"]
mod early_stopping_example;

use std::hint::black_box;

#[test]
Expand Down Expand Up @@ -66,3 +70,61 @@ fn stock_prediction_demo_data_is_reproducible() {
"demo data should still vary across generated days"
);
}

/// Every demo config builder must keep its epoch budget, logging interval and
/// early stopping patience within the small bounds asserted here.
#[test]
fn early_stopping_example_applies_bounded_configs_to_all_demo_trainers() {
    let demo_configs = [
        early_stopping_example::validation_early_stopping_training_config(),
        early_stopping_example::train_loss_early_stopping_training_config(),
        early_stopping_example::no_weight_restoration_training_config(),
        early_stopping_example::custom_patience_training_config(),
    ];

    for demo_config in demo_configs {
        // The bounds are literals on purpose: they pin the budget
        // independently of the constants defined in the example.
        let epoch_budget = black_box(demo_config.epochs);
        assert!(
            epoch_budget <= 6,
            "early_stopping_example should keep every demo training path bounded"
        );

        let print_interval = black_box(demo_config.print_every);
        assert!(
            print_interval <= black_box(demo_config.epochs),
            "early_stopping_example progress logging should not exceed the epoch budget"
        );

        let stopping = demo_config
            .early_stopping
            .expect("early_stopping_example demos should enable early stopping");
        let patience = black_box(stopping.patience);
        assert!(
            patience <= 4,
            "early_stopping_example patience should stay bounded for deterministic CI runs"
        );
    }
}

/// The demo fixture must be identical across calls and stay within the small
/// sequence-count and sequence-length bounds asserted here.
#[test]
fn early_stopping_example_uses_small_deterministic_fixture() {
    let (train_a, val_a) = early_stopping_example::generate_overfitting_data();
    let (train_b, val_b) = early_stopping_example::generate_overfitting_data();

    // Two independent calls must produce equal data.
    assert_eq!(
        train_a, train_b,
        "demo training fixture should be deterministic"
    );
    assert_eq!(
        val_a, val_b,
        "demo validation fixture should be deterministic"
    );

    let train_count = black_box(train_a.len());
    assert!(
        train_count <= 8,
        "early_stopping_example should keep training sequence count bounded"
    );

    let val_count = black_box(val_a.len());
    assert!(
        val_count <= 3,
        "early_stopping_example should keep validation sequence count bounded"
    );

    // A sequence is oversized when either side of the pair exceeds five steps.
    let has_oversized_sequence = train_a
        .iter()
        .chain(val_a.iter())
        .any(|(inputs, targets)| inputs.len() > 5 || targets.len() > 5);
    assert!(
        !has_oversized_sequence,
        "early_stopping_example should keep each fixture sequence bounded"
    );
}
Loading