WIP: improve/extend thread management and ctrl-c #276

Draft: wants to merge 4 commits into base: ctrlc
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -27,6 +27,7 @@ camino = "1.1.6"
glob = "0.3.1"
rustworkx-core = "0.14.2"
ctrlc = "3.4.2"
crossbeam-channel = "0.5.12"

[dev-dependencies]
assert_cmd = "2.0.14"
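The new crossbeam-channel dependency sits next to the existing ctrlc crate; presumably it replaces the std::sync::mpsc channel that feeds the CSV writer thread below. As a rough sketch only (the real channel setup lives in utils.rs, which is not part of this diff, and `String` stands in for MultiSearchResult), a bounded crossbeam channel can take the place of sync_channel like this:

```rust
// Illustrative only -- the real channel setup is expected to live in utils.rs,
// which is not shown in this diff. `String` stands in for MultiSearchResult.
use crossbeam_channel::bounded;

fn main() {
    // bounded() mirrors std::sync::mpsc::sync_channel, but both ends are Clone,
    // which makes them easier to hand to a shared ThreadManager.
    let (send, recv) = bounded::<String>(4);

    let writer = std::thread::spawn(move || {
        // Drain results until every Sender has been dropped.
        for row in recv {
            println!("{row}");
        }
    });

    send.send("query,match,containment".to_string()).unwrap();
    drop(send); // close the channel so the writer thread exits
    writer.join().unwrap();
}
```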
31 changes: 19 additions & 12 deletions src/multisearch.rs
@@ -9,7 +9,7 @@ use std::sync::Arc;
use std::sync::Mutex;

use crate::utils::{
csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType, ThreadManager,
csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType, ThreadManager, WriterType
};
use sourmash::ani_utils::ani_from_containment;

@@ -46,19 +46,26 @@ pub fn multisearch(
)?;
let against = load_sketches(against_collection, selection, ReportType::Against).unwrap();

// set up a multi-producer, single-consumer channel.
let (send, recv) =
std::sync::mpsc::sync_channel::<MultiSearchResult>(rayon::current_num_threads());
// // set up a multi-producer, single-consumer channel.
// let (send, recv) =
// std::sync::mpsc::sync_channel::<MultiSearchResult>(rayon::current_num_threads());

// // & spawn a thread that is dedicated to printing to a buffered output
let thrd = csvwriter_thread(recv, output);
// // // & spawn a thread that is dedicated to printing to a buffered output
// let thrd = csvwriter_thread(recv, output);

// set up manager to allow for ctrl-c handling
let manager = ThreadManager::new(send, thrd);
// // set up manager to allow for ctrl-c handling
// let manager = ThreadManager::new(send, thrd);

// Wrap ThreadManager in Arc<Mutex> for safe sharing across threads
let manager_shared = Arc::new(Mutex::new(manager));
//
// // Wrap ThreadManager in Arc<Mutex> for safe sharing across threads
// let manager_shared = Arc::new(Mutex::new(manager));
// set up manager to allow for ctrl-c handling
let mut manager = ThreadManager::new();
// start writer thread
manager.add_writer_thread(WriterType::MultiSearch, output)?;
// // Wrap ThreadManager in Arc<Mutex> for safe sharing across threads
let manager_shared = Arc::new(Mutex::new(manager));

// //
// Main loop: iterate (in parallel) over all search signature paths,
// loading them individually and searching them. Stuff results into
// the writer thread above.
@@ -138,7 +145,7 @@ pub fn multisearch(
})
.flatten()
.try_for_each_with(manager_shared.clone(), |manager, result| {
manager.lock().unwrap().send(result)
manager.lock().unwrap().send(result, WriterType::MultiSearch)
})?;

// do some cleanup and error handling -
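Both call sites above now go through a ThreadManager that owns its channels and writer threads: ThreadManager::new(), add_writer_thread(WriterType::MultiSearch, output), send(...), check_for_interrupt(), perform_cleanup(). The utils.rs side of that API is not included in this diff, so the following is only a guessed-at shape reconstructed from those call sites; every field, type, and signature here is an assumption, `String` stands in for MultiSearchResult, and output handling is omitted.

```rust
// Assumed shape of ThreadManager, reconstructed from the call sites in this
// diff; the real definition lives in src/utils.rs and may differ.
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::JoinHandle;

use crossbeam_channel::{bounded, SendError, Sender};

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum WriterType {
    MultiSearch,
    // further variants for other result kinds would go here
}

pub struct ThreadManager {
    senders: HashMap<WriterType, Sender<String>>,
    writers: Vec<JoinHandle<()>>,
    pub interrupted: Arc<AtomicBool>,
}

impl ThreadManager {
    pub fn new() -> Self {
        ThreadManager {
            senders: HashMap::new(),
            writers: Vec::new(),
            interrupted: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Spawn one writer thread per output kind and remember its sender.
    pub fn add_writer_thread(&mut self, kind: WriterType) -> std::io::Result<()> {
        let (send, recv) = bounded::<String>(64);
        let handle = std::thread::spawn(move || {
            for row in recv {
                // the real code writes CSV rows to the requested output here
                println!("{row}");
            }
        });
        self.senders.insert(kind, send);
        self.writers.push(handle);
        Ok(())
    }

    /// Route one result to the writer registered for `kind`.
    pub fn send(&self, result: String, kind: WriterType) -> Result<(), SendError<String>> {
        self.senders[&kind].send(result)
    }

    /// Lock-free check of the Ctrl-C flag, so workers can poll it cheaply.
    pub fn check_for_interrupt(&self) -> bool {
        self.interrupted.load(Ordering::SeqCst)
    }

    /// Drop all senders (closing the channels) and join the writer threads.
    pub fn perform_cleanup(&mut self) {
        self.senders.clear();
        for handle in self.writers.drain(..) {
            let _ = handle.join();
        }
    }
}
```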
139 changes: 88 additions & 51 deletions src/pairwise.rs
@@ -7,7 +7,7 @@ use std::sync::Arc;
use std::sync::Mutex;

use crate::utils::{
csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType, ThreadManager,
load_collection, load_sketches, MultiSearchResult, ReportType, ThreadManager, WriterType,
};
use sourmash::ani_utils::ani_from_containment;
use sourmash::selection::Selection;
@@ -43,18 +43,42 @@ pub fn pairwise(
let sketches = load_sketches(collection, selection, ReportType::General).unwrap();

// set up a multi-producer, single-consumer channel.
let (send, recv) =
std::sync::mpsc::sync_channel::<MultiSearchResult>(rayon::current_num_threads());
// let (send, recv) =
// std::sync::mpsc::sync_channel::<MultiSearchResult>(rayon::current_num_threads());

// // & spawn a thread that is dedicated to printing to a buffered output
let thrd = csvwriter_thread(recv, output);
// // // & spawn a thread that is dedicated to printing to a buffered output
// let thrd = csvwriter_thread(recv, output);

// let manager = ThreadManager::new(send, thrd);
// set up manager to allow for ctrl-c handling
let manager = ThreadManager::new(send, thrd);

// Wrap ThreadManager in Arc<Mutex> for safe sharing across threads
let mut manager = ThreadManager::new();
// start writer thread
manager.add_writer_thread(WriterType::MultiSearch, output)?;
// // Wrap ThreadManager in Arc<Mutex> for safe sharing across threads
let manager_shared = Arc::new(Mutex::new(manager));

// Create a new ThreadManager instance
// let thread_manager = ThreadManager::new();

// Wrap the ThreadManager instance in a Mutex to make it thread-safe
// let mutex_thread_manager = Mutex::new(thread_manager);

// Wrap the Mutex in an Arc to make it shareable across threads
// let arc_mutex_thread_manager = Arc::new(mutex_thread_manager);

// Create a new instance of ThreadManager wrapped in an Arc and Mutex
// let manager = Arc::new(Mutex::new(ThreadManager::new()));
// let manager_clone = Arc::clone(&manager);


// Lock the Mutex to acquire a guard, then unwrap
// let manager_shared = arc_mutex_thread_manager.lock().unwrap();
// let manager_shared = manager.lock().unwrap();

// let thread_manager = Arc::new(Mutex::new(ThreadManager::new()));
// let mut manager_shared = thread_manager.lock().unwrap().unwrap();
// manager_shared.add_writer_thread(WriterType::MultiSearch, output);

//
// Main loop: iterate (in parallel) over all signatures,
// Results written to the writer thread above.
@@ -64,18 +88,16 @@

sketches.par_iter().enumerate().for_each(|(idx, query)| {
// Clone the Arc to get a new reference for this thread
let manager_clone = manager_shared.clone();
// let manager_clone = manager_shared.clone();
// let manager_clone = Arc::clone(&arc_mutex_thread_manager);
let manager_clone = Arc::clone(&manager_shared);
let mut has_written_comparison = false;
for against in sketches.iter().skip(idx + 1) {
if manager_shared
.lock()
.unwrap()
.interrupted
.load(atomic::Ordering::SeqCst)
{
println!("Ctrl-C received, signaling shutdown...");
return; // Early return to stop processing further
// don't need to acquire lock to check for interrupt
if manager.check_for_interrupt() {
return; // Early return to stop processing further. This should end the loop and move to cleanup.
}

let overlap = query.minhash.count_common(&against.minhash, false).unwrap() as f64;
let query1_size = query.minhash.size() as f64;
let query2_size = against.minhash.size() as f64;
@@ -101,24 +123,32 @@
average_containment_ani = Some((qani + mani) / 2.);
max_containment_ani = Some(f64::max(qani, mani));
}
manager_clone
let multisearch_result = MultiSearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against.name.clone(),
match_md5: against.md5sum.clone(),
containment: containment_q1_in_q2,
max_containment,
jaccard,
intersect_hashes: overlap,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
};

match manager_clone
.lock()
.unwrap()
.send(MultiSearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against.name.clone(),
match_md5: against.md5sum.clone(),
containment: containment_q1_in_q2,
max_containment,
jaccard,
intersect_hashes: overlap,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
})
.unwrap()
.send(WriterType::MultiSearch, multisearch_result)
{
Ok(()) => {}
Err(send_error) => {
eprintln!("Error sending data: {:?}", send_error);
return;
}
}
}

let i = processed_cmp.fetch_add(1, atomic::Ordering::SeqCst);
@@ -138,30 +168,37 @@
average_containment_ani = Some(1.0);
max_containment_ani = Some(1.0);
}

manager_clone
let multisearch_result = MultiSearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: query.name.clone(),
match_md5: query.md5sum.clone(),
containment: 1.0,
max_containment: 1.0,
jaccard: 1.0,
intersect_hashes: query.minhash.size() as f64,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
};

match manager_clone
.lock()
.unwrap()
.send(MultiSearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: query.name.clone(),
match_md5: query.md5sum.clone(),
containment: 1.0,
max_containment: 1.0,
jaccard: 1.0,
intersect_hashes: query.minhash.size() as f64,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
})
.unwrap()
.send(WriterType::MultiSearch, multisearch_result)
{
Ok(()) => {}
Err(send_error) => {
eprintln!("Error sending data: {:?}", send_error);
return;
}
}
}
});

// do some cleanup and error handling -
manager_shared.lock().unwrap().perform_cleanup();
manager.perform_cleanup();

// done!
let i: usize = processed_cmp.load(atomic::Ordering::SeqCst);
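Nothing in this diff shows where `interrupted` actually gets set; presumably the ctrlc handler (the ctrlc crate is already in Cargo.toml) flips the shared flag that check_for_interrupt() polls between comparisons. A minimal, hypothetical sketch of that wiring, outside of ThreadManager:

```rust
// Hypothetical wiring of Ctrl-C to the shared interrupt flag; the real hookup
// is expected to live inside ThreadManager in src/utils.rs.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

fn main() {
    let interrupted = Arc::new(AtomicBool::new(false));

    // Register the handler once, before spawning any workers.
    let flag = interrupted.clone();
    ctrlc::set_handler(move || {
        flag.store(true, Ordering::SeqCst);
    })
    .expect("error setting Ctrl-C handler");

    // Workers poll the flag between units of work, mirroring the
    // check_for_interrupt() calls in the pairwise loop above.
    while !interrupted.load(Ordering::SeqCst) {
        // ... do one comparison, send the result to the writer, etc. ...
        std::thread::sleep(std::time::Duration::from_millis(100));
    }
    eprintln!("Ctrl-C received, signaling shutdown...");
}
```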