Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No more locks in sampler add/del #2441

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
34 changes: 15 additions & 19 deletions arbor/cable_cell_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,7 @@ void cable_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange
sample_size_type n_samples = 0;
sample_size_type max_samples_per_call = 0;

if (!sampler_map_.empty()) { // NOTE: We avoid the lock here as often as possible
// SAFETY: We need the lock here, as _schedule_ is not reentrant.
std::lock_guard<std::mutex> guard(sampler_mex_);
if (!sampler_map_.empty()) { // NOTE: We avoid work here as often as possible
for (auto& [sk, sa]: sampler_map_) {
if (sa.probeset_ids.empty()) continue; // No need to make any schedule
auto sample_times = util::make_range(sa.sched.events(tstart, ep.t1));
Expand All @@ -401,18 +399,20 @@ void cable_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange
for (const auto& pid: sa.probeset_ids) {
unsigned index = 0;
for (const auto& pdata: probe_map_.data_on(pid)) {
call_info.push_back({sa.sampler,
pid,
index,
pdata,
n_samples,
n_samples + n_times*pdata->n_raw()});
call_info.push_back({.sampler=sa.sampler,
.probeset_id=pid,
.index=index,
// SAFETY: this is ok as no additions to the probe map can happen during
// advance.
.pdata_ptr=&pdata,
.begin_offset=n_samples,
.end_offset=n_samples + n_times*pdata.n_raw()});
index++;
for (auto t: sample_times) {
auto it = timesteps_.find(t);
arb_assert(it != timesteps_.end());
const auto timestep_index = it - timesteps_.begin();
for (probe_handle h: pdata->raw_handle_range()) {
for (const auto& h: pdata.raw_handle_range()) {
sample_event ev{t, {h, n_samples++}};
sample_events_[timestep_index].push_back(ev);
}
Expand Down Expand Up @@ -447,7 +447,6 @@ void cable_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange
// generate spikes with global spike source ids. The threshold crossings
// record the local spike source index, which must be converted to a
// global index for spike communication.

for (auto c: result.crossings) {
spikes_.emplace_back(spike_sources_[c.index], time_type(c.time));
}
Expand All @@ -457,37 +456,34 @@ void cable_cell_group::add_sampler(sampler_association_handle h,
cell_member_predicate probeset_ids,
schedule sched,
sampler_function fn) {
// SAFETY? Both probe_map and sampler must be protected by this lock?!
std::lock_guard<std::mutex> guard(sampler_mex_);
auto probeset = probe_map_.keys(probeset_ids);
if (!probeset.empty()) {
auto result = sampler_map_.insert({h, sampler_association{std::move(sched),
auto result = sampler_map_.emplace(h, sampler_association{std::move(sched),
std::move(fn),
std::move(probeset)}});
std::move(probeset)});
arb_assert(result.second);
}
}

void cable_cell_group::remove_sampler(sampler_association_handle h) {
std::lock_guard<std::mutex> guard(sampler_mex_);
sampler_map_.erase(h);
}

void cable_cell_group::remove_all_samplers() {
std::lock_guard<std::mutex> guard(sampler_mex_);
sampler_map_.clear();
}

std::vector<probe_metadata> cable_cell_group::get_probe_metadata(const cell_address_type& probeset_id) const {
// SAFETY: Probe associations are fixed after construction, so we do not
// need to grab the mutex.
auto data = probe_map_.data_on(probeset_id);
const auto& data = probe_map_.data_on(probeset_id);

std::vector<probe_metadata> result;
result.reserve(data.size());
unsigned index = 0;
for (const auto& info: data) {
result.push_back({probeset_id, index++, info->get_metadata_ptr()});
result.push_back({probeset_id, index, info.get_metadata_ptr()});
index++;
}
return result;
}
Expand Down
3 changes: 0 additions & 3 deletions arbor/cable_cell_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ struct ARB_ARBOR_API cable_cell_group: public cell_group {

// Collection of samplers to be run against probes in this group.
sampler_association_map sampler_map_;

// Mutex for thread-safe access to sampler associations.
std::mutex sampler_mex_;
};

} // namespace arb
44 changes: 26 additions & 18 deletions arbor/fvm_lowered_cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,40 +173,48 @@ struct fvm_probe_data {
// map to multiple probe representations within the cable_cell_group.

struct probe_association_map {
// unique keys from multimap
std::vector<cell_address_type> keys(cell_member_predicate pred=all_probes) const {
std::vector<cell_address_type> res;
std::unordered_set<cell_address_type> seen;
for (const auto& [k, v]: data) {
if (!seen.count(k)) {
if (pred(k)) res.push_back(k);
seen.insert(k);
if (std::holds_alternative<predicate_function>(pred)) {
auto fun = std::get<predicate_function>(pred);
for (const auto& [k, v]: data_) {
if (fun(k)) res.push_back(k);
}
}
else if (std::holds_alternative<one_probe>(pred)) {
auto pid = std::get<one_probe>(pred).pid;
if (data_.contains(pid)) res.push_back(pid);
}
else if (std::holds_alternative<all_probes_t>(pred)) {
for (const auto& [k, v]: data_) {
res.push_back(k);
}
}
return res;
}

auto count(const cell_address_type& k) const { return data.count(k); }
auto count(const cell_address_type& key) const {
return data_.contains(key) ? data_.at(key).size() : 0;
}

// Return range of fvm_probe_data values associated with probeset_id.
std::vector<const fvm_probe_data*> data_on(const cell_address_type& probeset_id) const {
std::vector<const fvm_probe_data*> res;
const auto& [beg, end] = data.equal_range(probeset_id);
for (auto it = beg; it != end; ++it) {
res.push_back(&it->second);
}
return res;
const std::vector<fvm_probe_data>& data_on(const cell_address_type& key) const {
if (!data_.contains(key)) return nil;
return data_.at(key);
}

probe_association_map& insert(const cell_address_type& k, fvm_probe_data v) {
data.insert({k, std::move(v)});
probe_association_map& insert(const cell_address_type& key, fvm_probe_data val) {
data_[key].emplace_back(std::move(val));
size_ += 1;
return *this;
}

std::size_t size() const { return data.size(); }
std::size_t size() const { return size_; }

private:
std::unordered_multimap<cell_address_type, fvm_probe_data> data;
std::unordered_map<cell_address_type, std::vector<fvm_probe_data>> data_;
constexpr static std::vector<fvm_probe_data> nil = {};
std::size_t size_ = 0;
};

struct fvm_initialization_data {
Expand Down
16 changes: 12 additions & 4 deletions arbor/include/arbor/sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,35 @@

#include <cstddef>
#include <functional>
#include <variant>

#include <arbor/common_types.hpp>
#include <arbor/util/any_ptr.hpp>

namespace arb {

using cell_member_predicate = std::function<bool (const cell_address_type&)>;

static cell_member_predicate all_probes = [](const cell_address_type&) { return true; };
struct all_probes_t {};

struct one_probe {
one_probe(cell_address_type p): pid{std::move(p)} {}
cell_address_type pid;
bool operator()(const cell_address_type& x) { return x == pid; }
};

using predicate_function = std::function<bool (const cell_address_type&)>;

using cell_member_predicate = std::variant<all_probes_t,
one_probe,
predicate_function>;

static cell_member_predicate all_probes = all_probes_t{};


struct one_gid {
one_gid(cell_gid_type p): gid{std::move(p)} {}
cell_gid_type gid;
bool operator()(const cell_address_type& x) { return x.gid == gid; }
};

struct one_tag {
one_tag(cell_tag_type p): tag{std::move(p)} {}
cell_tag_type tag;
Expand Down
45 changes: 19 additions & 26 deletions arbor/include/arbor/util/handle_set.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#pragma once

#include <mutex>
#include <atomic>
#include <stdexcept>
#include <utility>
#include <vector>

/*
* Manage a set of integer-valued handles.
Expand All @@ -17,43 +15,38 @@
namespace arb {
namespace util {

template <typename Handle>
class handle_set {
public:
using value_type = Handle;
using value_type = std::size_t;

value_type acquire() {
lock_guard lock(mex_);

if (top_==std::numeric_limits<Handle>::max()) {
auto nxt = top_.fetch_add(1);
// We would run into UB _next_ time, so die now.
// SAFETY:
// 1. cannot check if we have already overrun, since another thread
// might already be using a wrapped value
// 2. we also cannot check if top == max since it might already have
// been wrapped by another thread
if (nxt + 1 == std::numeric_limits<value_type>::max()) {
throw std::out_of_range("no more handles");
}
return top_++;
return nxt;
}

// Pre-requisite: h is a handle returned by
// `acquire`, which has not been subject
// to a subsequent `release`.
// Pre-requisite: h is a handle returned by `acquire`, which has not been
// subject to a subsequent `release`.
void release(value_type h) {
lock_guard lock(mex_);

if (h+1==top_) {
--top_;
}
// _if_ this was the last handle to be acquire, release it.
value_type ex = h + 1;
top_.compare_exchange_strong(ex, ex - 1);
// if not, continue as if nothing happened.
}

// Release all handles.
void clear() {
lock_guard lock(mex_);

top_ = 0;
}
void clear() { top_.store(0); }

private:
value_type top_ = 0;

using lock_guard = std::lock_guard<std::mutex>;
std::mutex mex_;
std::atomic<value_type> top_ = 0;
};

} // namespace util
Expand Down
24 changes: 16 additions & 8 deletions arbor/lif_cell_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,26 @@ void lif_cell_group::clear_spikes() {
}

void lif_cell_group::add_sampler(sampler_association_handle h,
cell_member_predicate probeset_ids,
cell_member_predicate pred,
schedule sched,
sampler_function fn) {
std::lock_guard<std::mutex> guard(sampler_mex_);
std::vector<cell_address_type> probeset;
for (const auto& [k, v]: probes_) {
if (probeset_ids(k)) probeset.push_back(k);
if (std::holds_alternative<predicate_function>(pred)) {
auto fun = std::get<predicate_function>(pred);
for (const auto& [k, v]: probes_) {
if (fun(k)) probeset.push_back(k);
}
}
else if (std::holds_alternative<one_probe>(pred)) {
auto pid = std::get<one_probe>(pred).pid;
if (probes_.contains(pid)) probeset.push_back(pid);
}
else if (std::holds_alternative<all_probes_t>(pred)) {
for (const auto& [k, v]: probes_) {
probeset.push_back(k);
}
}

auto assoc = arb::sampler_association{std::move(sched),
std::move(fn),
std::move(probeset)};
Expand All @@ -79,11 +91,9 @@ void lif_cell_group::add_sampler(sampler_association_handle h,
}

void lif_cell_group::remove_sampler(sampler_association_handle h) {
std::lock_guard<std::mutex> guard(sampler_mex_);
samplers_.erase(h);
}
void lif_cell_group::remove_all_samplers() {
std::lock_guard<std::mutex> guard(sampler_mex_);
samplers_.clear();
}

Expand Down Expand Up @@ -121,7 +131,6 @@ void lif_cell_group::advance_cell(time_type tfinal,
std::vector<std::pair<time_type, sampler_association_handle>> samples;
if (!samplers_.empty()) {
auto tlast = last_time_sampled_[lid];
std::lock_guard<std::mutex> guard(sampler_mex_);
for (auto& [hdl, assoc]: samplers_) {
// No need to generate events
if (assoc.probeset_ids.empty()) continue;
Expand Down Expand Up @@ -229,7 +238,6 @@ void lif_cell_group::advance_cell(time_type tfinal,
arb_assert (sampled_voltages.size() <= n_values);
// Now we need to call all sampler callbacks with the data we have collected
{
std::lock_guard<std::mutex> guard(sampler_mex_);
for (auto& [k, vs]: sampled) {
const auto& fun = samplers_[k].sampler;
for (auto& [id, us]: vs) {
Expand Down
4 changes: 1 addition & 3 deletions arbor/lif_cell_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ struct ARB_ARBOR_API lif_cell_group: public cell_group {
// Time when the cell can _next_ be updated;
std::vector<time_type> next_time_updatable_;

// SAFETY: We need to access samplers_ through a mutex since
// simulation::add_sampler might be called concurrently.
std::mutex sampler_mex_;
// sampler id -> (schedule, callback, probe_ids)
sampler_association_map samplers_;

// LIF probe metadata, precalculated to pass to callbacks
Expand Down
2 changes: 1 addition & 1 deletion arbor/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class simulation_state {
std::array<thread_private_spike_store, 2> local_spikes_;

// Sampler associations handles are managed by a helper class.
util::handle_set<sampler_association_handle> sassoc_handles_;
util::handle_set sassoc_handles_;

// Accessors to events
std::vector<pse_vector>& event_lanes(std::ptrdiff_t epoch_id) { return event_lanes_[epoch_id&1]; }
Expand Down
16 changes: 7 additions & 9 deletions example/brunel/brunel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <iomanip>
#include <iostream>
#include <optional>
#include <set>
#include <vector>

#include <tinyopt/tinyopt.h>
Expand Down Expand Up @@ -353,19 +352,18 @@ void add_subset(cell_gid_type gid,
auto gid_in_range = int(gid >= start && gid < end);
if (m + start + gid_in_range >= end) throw std::runtime_error("Requested too many connections from the given range of gids.");
// Exclude ourself
std::set<cell_gid_type> seen{gid};
std::vector<bool> seen(end - start + 1, false);
if (gid >= start && gid < end) seen[gid - start] = true;
std::mt19937 gen(gid + 42);
while(m > 0) {
while(m) {
cell_gid_type val = rand_range(gen, start, end);
if (!seen.count(val)) {
conns.push_back({{val, src}, {tgt}, weight, delay*U::ms});
seen.insert(val);
m--;
}
if (seen[val - start]) continue;
conns.push_back({{val, src}, {tgt}, weight, delay*U::ms});
seen[val - start] = true;
m--;
}
}


// Read options from (optional) json file and command line arguments.
std::optional<cl_options> read_options(int argc, char** argv) {
using namespace to;
Expand Down
Loading
Loading