Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 41 additions & 27 deletions examples/real_data_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,35 @@
#![allow(unused_comparisons)]

use ndarray::{arr2, Array2};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rust_lstm::loss::MSELoss;
use rust_lstm::models::lstm_network::LSTMNetwork;
use rust_lstm::optimizers::Adam;
use rust_lstm::training::LSTMTrainer;
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
use std::fs::File;
use std::io::{BufRead, BufReader};

pub const DEMO_SENSOR_DAYS: usize = 7;
pub const DEMO_SEQUENCE_LENGTH: usize = 12;
pub const DEMO_HIDDEN_SIZE: usize = 32;
pub const DEMO_RECENT_WINDOW_HOURS: usize = 48;
pub const DEMO_PREDICTION_START_HOUR: usize = 24;
pub const DEMO_NUM_PREDICTIONS: usize = 5;
pub const DEMO_EPOCHS: usize = 5;
pub const DEMO_PRINT_EVERY: usize = 2;
pub const DEMO_RANDOM_SEED: u64 = 42;

/// Generic data point for time series
#[derive(Debug, Clone)]
struct DataPoint {
timestamp: String,
values: Vec<f64>,
#[derive(Debug, Clone, PartialEq)]
pub struct DataPoint {
pub timestamp: String,
pub values: Vec<f64>,
}

/// Data loader for CSV files
struct CSVDataLoader {
data: Vec<DataPoint>,
feature_names: Vec<String>,
pub struct CSVDataLoader {
pub data: Vec<DataPoint>,
pub feature_names: Vec<String>,
normalizers: Vec<(f64, f64)>, // (mean, std) for each feature
}

Expand Down Expand Up @@ -98,7 +109,8 @@ impl CSVDataLoader {
}

/// Generate synthetic CSV-like data for demonstration
fn generate_synthetic_sensor_data(days: usize) -> Self {
pub fn generate_synthetic_sensor_data(days: usize) -> Self {
let mut rng = StdRng::seed_from_u64(DEMO_RANDOM_SEED);
let mut data = Vec::new();

// Simulate IoT sensor data: temperature, humidity, pressure, light
Expand All @@ -112,22 +124,22 @@ impl CSVDataLoader {
let seasonal_temp_cycle =
15.0 * (2.0 * std::f64::consts::PI * day_of_year / 365.0).sin();
let temperature =
20.0 + daily_temp_cycle + seasonal_temp_cycle + (rand::random::<f64>() - 0.5) * 3.0;
20.0 + daily_temp_cycle + seasonal_temp_cycle + (rng.gen::<f64>() - 0.5) * 3.0;

// Humidity inversely related to temperature
let humidity = 70.0 - (temperature - 20.0) * 1.5 + (rand::random::<f64>() - 0.5) * 15.0;
let humidity = 70.0 - (temperature - 20.0) * 1.5 + (rng.gen::<f64>() - 0.5) * 15.0;
let humidity = humidity.clamp(20.0, 95.0);

// Pressure with weather patterns
let pressure =
1013.25 + 10.0 * (day_of_year / 30.0).sin() + (rand::random::<f64>() - 0.5) * 20.0;
1013.25 + 10.0 * (day_of_year / 30.0).sin() + (rng.gen::<f64>() - 0.5) * 20.0;

// Light with daily cycle
let light = if (6.0..=18.0).contains(&hour_of_day) {
1000.0 * (std::f64::consts::PI * (hour_of_day - 6.0) / 12.0).sin()
+ (rand::random::<f64>() - 0.5) * 200.0
+ (rng.gen::<f64>() - 0.5) * 200.0
} else {
(rand::random::<f64>() * 50.0).max(0.0)
(rng.gen::<f64>() * 50.0).max(0.0)
};

let timestamp = format!(
Expand Down Expand Up @@ -279,12 +291,7 @@ impl TimeSeriesPredictor {
let optimizer = Adam::new(0.001);
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer);

// Configure for quick demo
let mut config = rust_lstm::training::TrainingConfig::default();
config.epochs = 5; // Very reduced for quick demo
config.print_every = 2; // Print every 2 epochs

trainer = trainer.with_config(config);
trainer = trainer.with_config(real_data_training_config());

trainer.train(train_data, Some(val_data));

Expand Down Expand Up @@ -321,13 +328,21 @@ impl TimeSeriesPredictor {
}
}

pub fn real_data_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_EPOCHS,
print_every: DEMO_PRINT_EVERY,
..TrainingConfig::default()
}
}

fn main() {
println!("📈 Real Data Time Series Prediction with LSTM");
println!("===============================================\n");

// Generate synthetic sensor data (in practice, load from real CSV)
println!("📡 Generating synthetic IoT sensor data...");
let mut data_loader = CSVDataLoader::generate_synthetic_sensor_data(7); // 7 days for quick demo
let mut data_loader = CSVDataLoader::generate_synthetic_sensor_data(DEMO_SENSOR_DAYS);

println!(
"📊 Data loaded: {} data points with {} features",
Expand Down Expand Up @@ -360,8 +375,8 @@ fn main() {
// Create predictor to predict temperature (feature 0)
let mut predictor = TimeSeriesPredictor::new(
data_loader.feature_names.len(), // All features as input
12, // 12-hour sequences (reduced for speed)
32, // 32 hidden units (reduced for speed)
DEMO_SEQUENCE_LENGTH, // Reduced for speed
DEMO_HIDDEN_SIZE, // Reduced for speed
0, // Predict temperature (index 0)
);

Expand All @@ -370,11 +385,10 @@ fn main() {

// Make predictions on recent data
println!("\n🔮 Making temperature predictions:");
let recent_data = &data_loader.data[data_loader.data.len() - 48..]; // Last 48 hours
let recent_data = &data_loader.data[data_loader.data.len() - DEMO_RECENT_WINDOW_HOURS..];

for i in 24..29 {
// Predict for hours 25-29
let input_data = &recent_data[i - 24..i];
for i in DEMO_PREDICTION_START_HOUR..DEMO_PREDICTION_START_HOUR + DEMO_NUM_PREDICTIONS {
let input_data = &recent_data[i - DEMO_PREDICTION_START_HOUR..i];
if let Some(predicted_temp) = predictor.predict_next(&data_loader, input_data) {
let actual_temp = recent_data[i].values[0];
let error = (predicted_temp - actual_temp).abs();
Expand Down
121 changes: 121 additions & 0 deletions tests/example_training_bounds_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ mod dropout_example;
#[path = "../examples/weather_prediction.rs"]
mod weather_prediction;

#[allow(dead_code)]
#[path = "../examples/real_data_example.rs"]
mod real_data_example;

use std::hint::black_box;

#[test]
Expand Down Expand Up @@ -462,3 +466,120 @@ fn weather_prediction_demo_data_is_reproducible() {
"generated weather fixture should keep bounded meteorological values"
);
}

#[test]
fn real_data_example_applies_bounded_training_config() {
let config = real_data_example::real_data_training_config();

assert!(
black_box(config.epochs) <= 5,
"real_data_example should avoid long default training runs"
);
assert!(
black_box(config.print_every) > 0,
"real_data_example progress logging should stay enabled"
);
assert!(
black_box(config.print_every) <= black_box(config.epochs),
"real_data_example progress logging should not exceed the epoch budget"
);
assert!(
config.early_stopping.is_none(),
"real_data_example should avoid hidden early-stopping work"
);
}

#[test]
fn real_data_example_uses_bounded_demo_sizes() {
assert!(
black_box(real_data_example::DEMO_SENSOR_DAYS) <= 7,
"real_data_example should keep its synthetic sensor dataset bounded"
);
assert!(
black_box(real_data_example::DEMO_SEQUENCE_LENGTH) <= 12,
"real_data_example should keep each sequence bounded"
);
assert!(
black_box(real_data_example::DEMO_HIDDEN_SIZE) <= 32,
"real_data_example should keep hidden size bounded"
);
assert!(
black_box(real_data_example::DEMO_RECENT_WINDOW_HOURS) <= 48,
"real_data_example should keep preview prediction data bounded"
);
assert!(
black_box(real_data_example::DEMO_NUM_PREDICTIONS) <= 5,
"real_data_example should keep preview prediction count bounded"
);
assert!(
black_box(real_data_example::DEMO_PREDICTION_START_HOUR)
>= black_box(real_data_example::DEMO_SEQUENCE_LENGTH),
"real_data_example should feed enough history into each preview prediction"
);
assert!(
black_box(real_data_example::DEMO_RECENT_WINDOW_HOURS)
>= black_box(
real_data_example::DEMO_PREDICTION_START_HOUR
+ real_data_example::DEMO_NUM_PREDICTIONS,
),
"real_data_example should keep enough recent data for every preview prediction"
);
assert!(
black_box(real_data_example::DEMO_SENSOR_DAYS * 24)
>= black_box(real_data_example::DEMO_RECENT_WINDOW_HOURS),
"real_data_example dataset should cover the preview prediction window"
);
}

#[test]
fn real_data_example_synthetic_sensor_data_is_reproducible() {
let first = real_data_example::CSVDataLoader::generate_synthetic_sensor_data(black_box(2));
let second = real_data_example::CSVDataLoader::generate_synthetic_sensor_data(black_box(2));

assert_eq!(
first.data, second.data,
"real_data_example synthetic sensor fixture should be deterministic"
);
assert_eq!(
first.feature_names, second.feature_names,
"real_data_example synthetic sensor features should be deterministic"
);
assert_eq!(
first.data.len(),
black_box(2 * 24),
"real_data_example should generate hourly sensor readings"
);
assert!(
first.data.windows(2).any(|window| window[0] != window[1]),
"real_data_example synthetic sensor fixture should vary over time"
);
}

#[test]
fn real_data_example_synthetic_sensor_data_has_sane_value_bounds() {
let loader = real_data_example::CSVDataLoader::generate_synthetic_sensor_data(black_box(12));

assert_eq!(
loader.feature_names,
["temperature", "humidity", "pressure", "light"],
"real_data_example should keep the expected sensor feature layout"
);
assert!(
loader.data.iter().all(|point| point.values.len() == 4),
"real_data_example should generate four sensor values per reading"
);
assert!(
loader.data.iter().all(|point| {
let temperature = point.values[0];
let humidity = point.values[1];
let pressure = point.values[2];
let light = point.values[3];

(-5.0..=45.0).contains(&temperature)
&& (20.0..=95.0).contains(&humidity)
&& (990.0..=1040.0).contains(&pressure)
&& (-100.0..=1100.0).contains(&light)
}),
"real_data_example synthetic sensor values should stay in sane ranges"
);
}
Loading