Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Add to your `Cargo.toml`:

```toml
[dependencies]
rust-lstm = "0.6"
rust-lstm = "0.8"
```

### Basic Usage
Expand All @@ -78,6 +78,7 @@ fn main() {
### Training Example

```rust
use ndarray::Array2;
use rust_lstm::{LSTMNetwork, create_basic_trainer, TrainingConfig};

fn main() {
Expand All @@ -94,15 +95,23 @@ fn main() {
..Default::default()
});

// Train (train_data is slice of (input_sequence, target_sequence) tuples)
// Each input_sequence and target_sequence is Vec<Array2<f64>>
// Train data is a slice of (input_sequence, target_sequence) tuples.
// Each input_sequence and target_sequence is Vec<Array2<f64>>.
let train_data = vec![(
vec![Array2::from_shape_vec((1, 1), vec![0.0]).unwrap()],
vec![Array2::from_shape_vec((10, 1), vec![0.0; 10]).unwrap()],
)];
// Keep validation data separate from training data in real applications.
let validation_data = train_data.clone();

trainer.train(&train_data, Some(&validation_data));
}
```

### Early Stopping

```rust
use ndarray::Array2;
use rust_lstm::{
LSTMNetwork, create_basic_trainer, TrainingConfig,
EarlyStoppingConfig, EarlyStoppingMetric
Expand All @@ -128,6 +137,13 @@ fn main() {
let mut trainer = create_basic_trainer(network, 0.001)
.with_config(config);

let train_data = vec![(
vec![Array2::from_shape_vec((1, 1), vec![0.0]).unwrap()],
vec![Array2::from_shape_vec((10, 1), vec![0.0; 10]).unwrap()],
)];
// Keep validation data separate from training data in real applications.
let validation_data = train_data.clone();

// Training will stop early if validation loss stops improving
trainer.train(&train_data, Some(&validation_data));
}
Expand All @@ -136,7 +152,16 @@ fn main() {
### Bidirectional LSTM

```rust
use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode};
use ndarray::Array2;
use rust_lstm::layers::bilstm_network::BiLSTMNetwork;

let input_size = 3;
let hidden_size = 5;
let num_layers = 1;
let sequence = vec![
Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(),
Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(),
];

// BiLSTM with concatenated outputs (output_size = 2 * hidden_size)
let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);
Expand Down Expand Up @@ -170,23 +195,38 @@ graph TD
### GRU Networks

```rust
use ndarray::Array2;
use rust_lstm::models::gru_network::GRUNetwork;

let input_size = 3;
let hidden_size = 5;
let num_layers = 2;

// Create GRU network (alternative to LSTM)
let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
.with_input_dropout(0.2, true)
.with_recurrent_dropout(0.3, true);

// Forward pass
let (output, _) = gru.forward(&input, &hidden_state);
let input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap();
let hidden_states = vec![Array2::zeros((hidden_size, 1)); num_layers];

// Forward pass returns one hidden state per layer
let outputs = gru.forward(&input, &hidden_states);
let output = outputs.last().unwrap();
```

### Linear Layer

```rust
use ndarray::Array2;
use rust_lstm::layers::linear::LinearLayer;
use rust_lstm::optimizers::Adam;

let hidden_size = 4;
let num_classes = 3;
let lstm_output = Array2::ones((hidden_size, 1));
let grad_output = Array2::ones((num_classes, 1));

// Create linear layer for classification: hidden_size -> num_classes
let mut classifier = LinearLayer::new(hidden_size, num_classes);
let mut optimizer = Adam::new(0.001);
Expand Down Expand Up @@ -251,7 +291,7 @@ use rust_lstm::{
let network = LSTMNetwork::new(1, 10, 2);

// Step decay: reduce LR by 50% every 10 epochs
let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5);
let mut trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5);

// OneCycle policy for modern deep learning
let mut trainer = create_one_cycle_trainer(network.clone(), 0.1, 100);
Expand Down
139 changes: 136 additions & 3 deletions tests/readme_examples_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@

use ndarray::Array2;
use rust_lstm::{
layers::bilstm_network::BiLSTMNetwork,
layers::dropout::{Dropout, Zoneout},
layers::linear::LinearLayer,
layers::peephole_lstm_cell::PeepholeLSTMCell,
loss::{CrossEntropyLoss, MAELoss, MSELoss},
optimizers::{Adam, RMSprop, SGD},
training::create_basic_trainer,
LSTMNetwork, LSTMTrainer, LayerDropoutConfig, TrainingConfig,
models::gru_network::GRUNetwork,
optimizers::{Adam, RMSprop, ScheduledOptimizer, SGD},
schedulers::{CyclicalLR, LRScheduleVisualizer, PolynomialLR, WarmupScheduler},
training::{
create_basic_trainer, create_cosine_annealing_trainer, create_one_cycle_trainer,
create_step_lr_trainer,
},
EarlyStoppingConfig, EarlyStoppingMetric, LSTMNetwork, LSTMTrainer, LayerDropoutConfig,
TrainingConfig,
};

#[test]
Expand Down Expand Up @@ -119,6 +127,131 @@ fn test_training_example() {
assert_eq!(predictions[0].shape(), &[4, 1]);
}

#[test]
fn test_readme_early_stopping_example() {
    // Exercises the README early-stopping flow: build a stopping rule, attach
    // it to a training config, train with a validation set, and confirm the
    // rule is still present on the trainer afterwards.
    let stopping_rule = EarlyStoppingConfig {
        patience: 2,
        min_delta: 1e-4,
        restore_best_weights: true,
        monitor: EarlyStoppingMetric::ValidationLoss,
    };
    let training_config = TrainingConfig {
        epochs: 2,
        early_stopping: Some(stopping_rule),
        ..Default::default()
    };

    let lstm = LSTMNetwork::new(1, 4, 1);
    let mut trainer = create_basic_trainer(lstm, 0.001).with_config(training_config);

    // Passing a validation set gives the ValidationLoss monitor something to observe.
    let training_set = generate_test_data();
    let validation_set = generate_test_data();
    trainer.train(&training_set, Some(&validation_set));

    let attached_rule = trainer.config.early_stopping.as_ref().unwrap();
    assert_eq!(attached_rule.patience, 2);
}

#[test]
fn test_readme_bilstm_example() {
    // Runs the README bidirectional example and checks the output shapes.
    let (input_size, hidden_size, num_layers) = (3, 5, 1);

    // Concatenating both directions doubles the feature dimension
    // (output_size = 2 * hidden_size).
    let mut network = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);

    // Two time steps, each an (input_size, 1) column vector.
    let inputs = vec![
        Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(),
        Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(),
    ];
    let step_outputs = network.forward_sequence(&inputs);

    // One output per input step, each with the concatenated width.
    let expected_shape = [2 * hidden_size, 1];
    assert_eq!(step_outputs.len(), inputs.len());
    for step in step_outputs {
        assert_eq!(step.shape(), &expected_shape);
    }
}

#[test]
fn test_readme_gru_example() {
    // Runs the README GRU example: a two-layer network with input and
    // recurrent dropout configured, fed a single input column.
    let (input_size, hidden_size, num_layers) = (3, 5, 2);

    let mut network = GRUNetwork::new(input_size, hidden_size, num_layers)
        .with_input_dropout(0.2, true)
        .with_recurrent_dropout(0.3, true);

    let step_input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap();
    // One zeroed (hidden_size, 1) state per layer as the starting point.
    let initial_states = vec![Array2::zeros((hidden_size, 1)); num_layers];

    // The forward pass yields one hidden state per layer; the last entry is
    // taken as the network output.
    let layer_states = network.forward(&step_input, &initial_states);
    let top_state = layer_states.last().unwrap();

    let expected_shape = [hidden_size, 1];
    assert_eq!(layer_states.len(), num_layers);
    assert_eq!(top_state.shape(), &expected_shape);
}

#[test]
fn test_readme_linear_layer_example() {
    // Runs the README linear-layer example: forward, backward, and one
    // parameter update, then checks the shapes on both sides of the layer.
    let (hidden_size, num_classes) = (4, 3);

    // Classification head mapping hidden_size -> num_classes.
    let mut head = LinearLayer::new(hidden_size, num_classes);
    let mut adam = Adam::new(0.001);

    // Forward: an all-ones column stands in for an LSTM output.
    let features = Array2::ones((hidden_size, 1));
    let logits = head.forward(&features);

    // Backward: an all-ones upstream gradient, followed by a single update.
    let upstream_grad = Array2::ones((num_classes, 1));
    let (param_grads, input_grad) = head.backward(&upstream_grad);
    head.update_parameters(&param_grads, &mut adam, "classifier");

    let expected_logits_shape = [num_classes, 1];
    let expected_input_grad_shape = [hidden_size, 1];
    assert_eq!(logits.shape(), &expected_logits_shape);
    assert_eq!(input_grad.shape(), &expected_input_grad_shape);
}

#[test]
fn test_readme_advanced_learning_rate_scheduling_example() {
    // Smoke test for the README scheduling examples: build each scheduled
    // trainer/optimizer, step each one once, and check the generated schedule.
    let base_network = LSTMNetwork::new(1, 4, 1);

    // Each factory takes the network by value, so every trainer gets its own clone.
    // Step decay: reduce LR by 50% every 10 epochs.
    let mut step_decay = create_step_lr_trainer(base_network.clone(), 0.01, 10, 0.5);
    // OneCycle policy.
    let mut one_cycle = create_one_cycle_trainer(base_network.clone(), 0.1, 100);
    // Cosine annealing with warm restarts.
    let mut cosine = create_cosine_annealing_trainer(base_network.clone(), 0.01, 20, 1e-6);

    // Warmup wrapped around a cyclical schedule, driving an Adam optimizer.
    let warmup_then_cyclical = WarmupScheduler::new(5, CyclicalLR::new(0.001, 0.01, 10), 0.0001);
    let mut scheduled_adam = ScheduledOptimizer::new(Adam::new(0.01), warmup_then_cyclical, 0.01);

    // Polynomial decay, materialised as a 100-entry schedule.
    let lr_curve =
        LRScheduleVisualizer::generate_schedule(PolynomialLR::new(100, 2.0, 0.001), 0.01, 100);

    step_decay.optimizer.step();
    one_cycle.optimizer.step();
    cosine.optimizer.step();
    scheduled_adam.step();

    assert_eq!(lr_curve.len(), 100);
    assert!(scheduled_adam.get_current_lr() > 0.0);
}

#[test]
fn test_dropout_types_example() {
// Standard dropout
Expand Down
Loading