Skip to content

feat(deep_causality): Added Programmatic Verification of Model Assumptions #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 31, 2025
Merged
15 changes: 12 additions & 3 deletions deep_causality/benches/benchmarks/bench_collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ fn small_causality_collection_benchmark(criterion: &mut Criterion) {
let evidence = PropagatingEffect::Numerical(0.99);

criterion.bench_function("small_causality_collection_propagation", |bencher| {
bencher.iter(|| coll.evaluate_deterministic_propagation(&evidence).unwrap())
bencher.iter(|| {
coll.evaluate_deterministic_propagation(&evidence, &AggregateLogic::All)
.unwrap()
})
});
}

Expand All @@ -28,7 +31,10 @@ fn medium_causality_collection_benchmark(criterion: &mut Criterion) {
let evidence = PropagatingEffect::Numerical(0.99);

criterion.bench_function("medium_causality_collection_propagation", |bencher| {
bencher.iter(|| coll.evaluate_deterministic_propagation(&evidence).unwrap())
bencher.iter(|| {
coll.evaluate_deterministic_propagation(&evidence, &AggregateLogic::All)
.unwrap()
})
});
}

Expand All @@ -37,7 +43,10 @@ fn large_causality_collection_benchmark(criterion: &mut Criterion) {
let evidence = PropagatingEffect::Numerical(0.99);

criterion.bench_function("large_causality_collection_propagation", |bencher| {
bencher.iter(|| coll.evaluate_deterministic_propagation(&evidence).unwrap())
bencher.iter(|| {
coll.evaluate_deterministic_propagation(&evidence, &AggregateLogic::All)
.unwrap()
})
});
}

Expand Down
43 changes: 43 additions & 0 deletions deep_causality/src/errors/assumption_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* SPDX-License-Identifier: MIT
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
*/

//!
//! Error type for assumption checking.
//!
use std::error::Error;
use std::fmt;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AssumptionError {
/// Error returned when verification is attempted on a model with no assumptions.
NoAssumptionsDefined,
///Error returned when verification is attempted without data i.e. empty collection.
NoDataToTestDefined,
///Error to capture the specific failed assumption
AssumptionFailed(String),
/// Wraps an error that occurred during the execution of an assumption function.
EvaluationFailed(String),
}

impl Error for AssumptionError {}

impl fmt::Display for AssumptionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AssumptionError::NoAssumptionsDefined => {
write!(f, "Model has no assumptions to verify")
}
AssumptionError::NoDataToTestDefined => {
write!(f, "No Data to test provided")
}
AssumptionError::AssumptionFailed(a) => {
write!(f, "Assumption failed: {a}")
}
AssumptionError::EvaluationFailed(msg) => {
write!(f, "Failed to evaluate assumption: {msg}")
}
}
}
}
4 changes: 3 additions & 1 deletion deep_causality/src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

mod action_error;
mod adjustment_error;
mod assumption_error;
mod build_error;
mod causal_graph_index_error;
mod causality_error;
Expand All @@ -13,11 +14,12 @@ mod context_index_error;
mod index_error;
mod model_build_error;
mod model_generation_error;
mod model_validation_error;
pub mod model_validation_error;
mod update_error;

pub use action_error::*;
pub use adjustment_error::*;
pub use assumption_error::*;
pub use build_error::*;
pub use causal_graph_index_error::*;
pub use causality_error::*;
Expand Down
45 changes: 19 additions & 26 deletions deep_causality/src/extensions/causable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@ use std::hash::Hash;

// Extension trait http://xion.io/post/code/rust-extension-traits.html
use deep_causality_macros::{
make_array_to_vec, make_get_all_items, make_get_all_map_items, make_is_empty, make_len,
make_map_to_vec, make_vec_to_vec,
make_array_to_vec, make_find_from_iter_values, make_find_from_map_values, make_get_all_items,
make_get_all_map_items, make_is_empty, make_len, make_map_to_vec, make_vec_deq_to_vec,
make_vec_to_vec,
};

use crate::Causable;
use crate::traits::causable::causable_reasoning::CausableReasoning;
use crate::{Causable, IdentificationValue};

impl<T> CausableReasoning<T> for [T]
impl<K, V> CausableReasoning<V> for HashMap<K, V>
where
T: Causable + Clone,
K: Eq + Hash,
V: Causable + Clone,
{
make_len!();
make_is_empty!();
make_get_all_items!();
make_array_to_vec!();
make_map_to_vec!();
make_get_all_map_items!();
make_find_from_map_values!();
}

impl<K, V> CausableReasoning<V> for BTreeMap<K, V>
Expand All @@ -34,17 +37,18 @@ where
make_is_empty!();
make_map_to_vec!();
make_get_all_map_items!();
make_find_from_map_values!();
}

impl<K, V> CausableReasoning<V> for HashMap<K, V>
impl<T> CausableReasoning<T> for [T]
where
K: Eq + Hash,
V: Causable + Clone,
T: Causable + Clone,
{
make_len!();
make_is_empty!();
make_map_to_vec!();
make_get_all_map_items!();
make_get_all_items!();
make_array_to_vec!();
make_find_from_iter_values!();
}

impl<T> CausableReasoning<T> for Vec<T>
Expand All @@ -55,6 +59,7 @@ where
make_is_empty!();
make_vec_to_vec!();
make_get_all_items!();
make_find_from_iter_values!();
}

impl<T> CausableReasoning<T> for VecDeque<T>
Expand All @@ -64,18 +69,6 @@ where
make_len!();
make_is_empty!();
make_get_all_items!();
// VecDeque can't be turned into a vector hence the custom implementation
// https://github.com/rust-lang/rust/issues/23308
// Also, make_contiguous requires self to be mutable, which would violate the API, hence the clone.
// https://doc.rust-lang.org/std/collections/struct.VecDeque.html#method.make_contiguous
fn to_vec(&self) -> Vec<T> {
let mut v = Vec::with_capacity(self.len());
let mut deque = self.clone(); // clone to avoid mutating the original

for item in deque.make_contiguous().iter() {
v.push(item.clone());
}

v
}
make_vec_deq_to_vec!();
make_find_from_iter_values!();
}
3 changes: 3 additions & 0 deletions deep_causality/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ pub use crate::traits::observable::ObservableReasoning;
// Scalar Traits
pub use crate::traits::scalar::scalar_projector::ScalarProjector;
pub use crate::traits::scalar::scalar_value::ScalarValue;
// Transferable Trait
pub use crate::traits::transferable::Transferable;
//
// Types
//
Expand Down Expand Up @@ -139,6 +141,7 @@ pub use crate::types::model_types::inference::Inference;
pub use crate::types::model_types::model::Model;
pub use crate::types::model_types::observation::Observation;
// Reasoning types
pub use crate::types::reasoning_types::aggregate_logic::AggregateLogic;
pub use crate::types::reasoning_types::propagating_effect::PropagatingEffect;
//
//Symbolic types
Expand Down
31 changes: 23 additions & 8 deletions deep_causality/src/traits/assumable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
*/

use crate::{DescriptionValue, EvalFn, Identifiable, NumericalValue};
use crate::{AssumptionError, DescriptionValue, Identifiable, NumericalValue, PropagatingEffect};

/// The Assumable trait defines the interface for objects that represent
/// assumptions that can be tested and verified. Assumable types must also
Expand All @@ -25,10 +25,9 @@ use crate::{DescriptionValue, EvalFn, Identifiable, NumericalValue};
///
pub trait Assumable: Identifiable {
fn description(&self) -> DescriptionValue;
fn assumption_fn(&self) -> EvalFn;
fn assumption_tested(&self) -> bool;
fn assumption_valid(&self) -> bool;
fn verify_assumption(&self, data: &[NumericalValue]) -> bool;
fn verify_assumption(&self, data: &[PropagatingEffect]) -> Result<bool, AssumptionError>;
}

/// The AssumableReasoning trait provides default implementations for common
Expand Down Expand Up @@ -119,10 +118,19 @@ where
/// (from `number_assumption_valid()`) by the total number of assumptions
/// (from `len()`) and multiplying by 100.
///
/// Returns the percentage as a NumericalValue.
/// # Errors
///
fn percent_assumption_valid(&self) -> NumericalValue {
(self.number_assumption_valid() / self.len() as NumericalValue) * 100.0
/// Returns `AssumptionError::EvaluationFailed` if the number of assumptions is zero,
/// as percentage calculation would lead to a division by zero.
///
fn percent_assumption_valid(&self) -> Result<NumericalValue, AssumptionError> {
if self.is_empty() {
return Err(AssumptionError::EvaluationFailed(
"Cannot calculate percentage with zero assumptions".to_string(),
));
}
let percentage = (self.number_assumption_valid() / self.len() as NumericalValue) * 100.0;
Ok(percentage)
}

/// Verifies all assumptions in the collection against the provided data.
Expand All @@ -133,10 +141,17 @@ where
/// This will test each assumption against the data and update the
/// `assumption_valid` and `assumption_tested` flags accordingly.
///
fn verify_all_assumptions(&self, data: &[NumericalValue]) {
/// # Errors
///
/// Returns an `AssumptionError` if any of the assumption functions fail during execution.
///
fn verify_all_assumptions(&self, data: &[PropagatingEffect]) -> Result<(), AssumptionError> {
for a in self.get_all_items() {
a.verify_assumption(data);
// We are interested in the side effect of updating the assumption state,
// but we must handle the potential error.
let _ = a.verify_assumption(data)?;
}
Ok(())
}

/// Returns a vector containing references to all invalid assumptions.
Expand Down
12 changes: 11 additions & 1 deletion deep_causality/src/traits/causable/causable_reasoning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
*/

use crate::{Causable, CausalityError, NumericalValue, PropagatingEffect};
use crate::{
AggregateLogic, Causable, CausalityError, IdentificationValue, NumericalValue,
PropagatingEffect,
};

/// Provides default implementations for reasoning over collections of `Causable` items.
///
Expand All @@ -16,6 +19,7 @@ where
{
//
// These methods must be implemented by the collection type.
// See deep_causality/src/extensions/causable/mod.rs
//

/// Returns the total number of `Causable` items in the collection.
Expand All @@ -31,6 +35,9 @@ where
/// This is the primary accessor used by the trait's default methods.
fn get_all_items(&self) -> Vec<&T>;

/// Returns a reference to a `Causable` item by its ID, if found.
fn get_item_by_id(&self, id: IdentificationValue) -> Option<&T>;

//
// Default implementations for all other methods are provided below.
//
Expand All @@ -49,6 +56,7 @@ where
fn evaluate_deterministic_propagation(
&self,
effect: &PropagatingEffect,
_logic: &AggregateLogic,
) -> Result<PropagatingEffect, CausalityError> {
for cause in self.get_all_items() {
let effect = cause.evaluate(effect)?;
Expand Down Expand Up @@ -90,6 +98,7 @@ where
fn evaluate_probabilistic_propagation(
&self,
effect: &PropagatingEffect,
_logic: &AggregateLogic,
) -> Result<PropagatingEffect, CausalityError> {
let mut cumulative_prob: NumericalValue = 1.0;

Expand Down Expand Up @@ -140,6 +149,7 @@ where
fn evaluate_mixed_propagation(
&self,
effect: &PropagatingEffect,
_logic: &AggregateLogic,
) -> Result<PropagatingEffect, CausalityError> {
// The chain starts as deterministically true. It can transition to probabilistic.
let mut aggregated_effect = PropagatingEffect::Deterministic(true);
Expand Down
1 change: 1 addition & 0 deletions deep_causality/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ pub mod indexable;
pub mod inferable;
pub mod observable;
pub mod scalar;
pub mod transferable;
56 changes: 56 additions & 0 deletions deep_causality/src/traits/transferable/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: MIT
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
*/

//!
//! Trait for assumption verification used in the Model type.
//!
use crate::{Assumable, Assumption, AssumptionError, PropagatingEffect};
use std::sync::Arc;

pub trait Transferable {
fn get_assumptions(&self) -> &Option<Arc<Vec<Assumption>>>;

/// Verifies the model's assumptions against a given PropagatingEffect.
///
/// The function iterates through all defined assumptions and checks them against
/// the provided data. It short-circuits and returns immediately on the first
/// failure or error.
///
/// Overwrite the default implementation if you need customization.
///
/// # Arguments
/// * `effect` - Sample data to be tested. Details on sampling should be documented in each assumption.
///
/// # Returns
/// * `Ok(())` if all assumptions hold true.
/// * `Err(AssumptionError::AssumptionFailed(String))` if an assumption is not met.
/// * `Err(AssumptionError::NoAssumptionsDefined)` if the model has no assumptions.
/// * `Err(AssumptionError::NoDataToTestDefined)` if the effect slice is empty.
/// * `Err(AssumptionError::EvaluationError(...))` if an error occurs during evaluation.
///
fn verify_assumptions(&self, effect: &[PropagatingEffect]) -> Result<(), AssumptionError> {
if effect.is_empty() {
return Err(AssumptionError::NoDataToTestDefined);
}

if self.get_assumptions().is_none() {
return Err(AssumptionError::NoAssumptionsDefined);
}

let assumptions = self.get_assumptions().as_ref().unwrap();

for assumption in assumptions.iter() {
// The `?` operator propagates any evaluation errors.
if !assumption.verify_assumption(effect)? {
// If an assumption returns `Ok(false)`, the check has failed.
// We now return an error containing the specific assumption that failed.
return Err(AssumptionError::AssumptionFailed(assumption.to_string()));
}
}

// If the loop completes, all assumptions passed.
Ok(())
}
}
4 changes: 2 additions & 2 deletions deep_causality/src/types/alias_types/alias_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
* SPDX-License-Identifier: MIT
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
*/
use crate::{CausalityError, Context, NumericalValue, PropagatingEffect};
use crate::{AssumptionError, CausalityError, Context, PropagatingEffect};
use std::sync::Arc;

// Fn aliases for assumable, assumption, & assumption collection
/// Function type for evaluating numerical values and returning a boolean result.
/// This remains unchanged as it serves a different purpose outside the core causal reasoning.
pub type EvalFn = fn(&[NumericalValue]) -> bool;
pub type EvalFn = fn(&[PropagatingEffect]) -> Result<bool, AssumptionError>;

/// The unified function signature for all singleton causaloids that do not require an external context.
///
Expand Down
Loading
Loading