Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread-safe EGraph struct #517

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/ast/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,19 @@ fn map_fallible<T>(
.collect::<Result<_, _>>()
}

pub trait Macro<T> {
Alex-Fischman marked this conversation as resolved.
Show resolved Hide resolved
pub trait Macro<T>: Send + Sync {
fn name(&self) -> Symbol;
fn parse(&self, args: &[Sexp], span: Span, parser: &mut Parser) -> Result<T, ParseError>;
}

pub struct SimpleMacro<T, F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError>>(Symbol, F);
pub struct SimpleMacro<T, F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError> + Send + Sync>(
Symbol,
F,
);

impl<T, F> SimpleMacro<T, F>
where
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError>,
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError> + Send + Sync,
{
pub fn new(head: &str, f: F) -> Self {
Self(head.into(), f)
Expand All @@ -235,7 +238,7 @@ where

impl<T, F> Macro<T> for SimpleMacro<T, F>
where
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError>,
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError> + Send + Sync,
{
fn name(&self) -> Symbol {
self.0
Expand Down
16 changes: 8 additions & 8 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct Function {
pub merge: MergeFn,
pub(crate) nodes: table::Table,
sorts: HashSet<Symbol>,
pub(crate) indexes: Vec<Rc<ColumnIndex>>,
pub(crate) indexes: Vec<Arc<ColumnIndex>>,
pub(crate) rebuild_indexes: Vec<Option<CompositeColumnIndex>>,
index_updated_through: usize,
updates: usize,
Expand All @@ -30,7 +30,7 @@ pub enum MergeFn {
Union,
// the rc is make sure it's cheaply clonable, since calling the merge fn
// requires a clone
Expr(Rc<Program>),
Expr(Arc<Program>),
}

/// All information we know determined by the input.
Expand Down Expand Up @@ -125,7 +125,7 @@ impl Function {
let program = egraph
.compile_expr(&binding, &actions, &target)
.map_err(Error::TypeErrors)?;
MergeFn::Expr(Rc::new(program))
MergeFn::Expr(Arc::new(program))
} else if decl.subtype == FunctionSubtype::Constructor {
MergeFn::Union
} else {
Expand All @@ -136,7 +136,7 @@ impl Function {
input
.iter()
.chain(once(&output))
.map(|x| Rc::new(ColumnIndex::new(x.name()))),
.map(|x| Arc::new(ColumnIndex::new(x.name()))),
);

let rebuild_indexes = Vec::from_iter(input.iter().chain(once(&output)).map(|x| {
Expand Down Expand Up @@ -179,7 +179,7 @@ impl Function {
self.nodes.clear();
self.indexes
.iter_mut()
.for_each(|x| Rc::make_mut(x).clear());
.for_each(|x| Arc::make_mut(x).clear());
self.rebuild_indexes.iter_mut().for_each(|x| {
if let Some(x) = x {
x.clear()
Expand Down Expand Up @@ -224,7 +224,7 @@ impl Function {
&self,
col: usize,
timestamps: &Range<u32>,
) -> Option<Rc<ColumnIndex>> {
) -> Option<Arc<ColumnIndex>> {
let range = self.nodes.transform_range(timestamps);
if range.end > self.index_updated_through {
return None;
Expand Down Expand Up @@ -255,7 +255,7 @@ impl Function {
.zip(self.rebuild_indexes.iter_mut())
.enumerate()
{
let as_mut = Rc::make_mut(index);
let as_mut = Arc::make_mut(index);
if col == self.schema.input.len() {
for (slot, _, out) in self.nodes.iter_range(offsets.clone(), true) {
as_mut.add(out.value, slot)
Expand Down Expand Up @@ -300,7 +300,7 @@ impl Function {
for index in &mut self.indexes {
// Everything works if we don't have a unique copy of the indexes,
// but we ought to be able to avoid this copy.
Rc::make_mut(index).clear();
Arc::make_mut(index).clear();
}
for rebuild_index in self.rebuild_indexes.iter_mut().flatten() {
rebuild_index.clear();
Expand Down
4 changes: 2 additions & 2 deletions src/gj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ type RowIdx = u32;
#[derive(Debug)]
enum LazyTrieInner {
Borrowed {
index: Rc<ColumnIndex>,
index: Arc<ColumnIndex>,
map: HashMap<Value, LazyTrie>,
},
Delayed(SmallVec<[RowIdx; 4]>),
Expand All @@ -822,7 +822,7 @@ impl LazyTrie {
LazyTrieInner::Borrowed { index, .. } => index.len(),
}
}
fn from_column_index(index: Rc<ColumnIndex>) -> LazyTrie {
fn from_column_index(index: Arc<ColumnIndex>) -> LazyTrie {
LazyTrie(UnsafeCell::new(LazyTrieInner::Borrowed {
index,
map: Default::default(),
Expand Down
16 changes: 11 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ use indexmap::map::Entry;
use instant::{Duration, Instant};
pub use serialize::{SerializeConfig, SerializedNode};
use sort::*;
use std::fmt::Debug;
use std::fmt::{Display, Formatter};
use std::fs::File;
use std::hash::Hash;
use std::io::Read;
use std::iter::once;
use std::ops::{Deref, Range};
use std::path::PathBuf;
use std::rc::Rc;
use std::str::FromStr;
use std::{fmt::Debug, sync::Arc};
use std::sync::Arc;
pub use termdag::{Term, TermDag, TermId};
use thiserror::Error;
pub use typechecking::TypeInfo;
Expand Down Expand Up @@ -292,7 +292,7 @@ impl RunReport {
}

#[derive(Clone)]
pub struct Primitive(Arc<dyn PrimitiveLike>);
pub struct Primitive(Arc<dyn PrimitiveLike + Send + Sync>);
impl Primitive {
// Takes the full signature of a primitive (including input and output types)
// Returns whether the primitive is compatible with this signature
Expand Down Expand Up @@ -344,7 +344,7 @@ impl Debug for Primitive {
}
}

impl<T: PrimitiveLike + 'static> From<T> for Primitive {
impl<T: PrimitiveLike + 'static + Send + Sync> From<T> for Primitive {
fn from(p: T) -> Self {
Self(Arc::new(p))
}
Expand Down Expand Up @@ -1585,7 +1585,9 @@ pub enum Error {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use lazy_static::lazy_static;

use crate::constraint::SimpleTypeConstraint;
use crate::sort::*;
Expand Down Expand Up @@ -1656,4 +1658,8 @@ mod tests {
)
.unwrap();
}

lazy_static! {
pub static ref RT: Mutex<EGraph> = Mutex::new(EGraph::default());
}
}