Skip to content

Commit d9b1076

Browse files
committed
Add inference meta => value
1 parent 6bfc104 commit d9b1076

28 files changed

+95
-96
lines changed

arbor/include/arbor/cable_cell.hpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <arbor/morph/primitives.hpp>
2222
#include <arbor/util/hash_def.hpp>
2323
#include <arbor/util/typed_map.hpp>
24+
#include <arbor/util/extra_traits.hpp>
2425

2526
namespace arb {
2627

@@ -34,12 +35,26 @@ struct ARB_SYMBOL_VISIBLE cable_probe_point_info {
3435

3536
// Cable cell type definitions
3637
using cable_sample_type = const double;
37-
using cable_sample_range = std::pair<cable_sample_type*, cable_sample_type*>;
3838

3939
using cable_state_meta_type = const mlocation;
4040
using cable_state_cell_meta_type = const mcable;
4141
using cable_point_meta_type = const cable_probe_point_info;
4242

43+
template <>
44+
struct probe_value_type_of<cable_state_meta_type> {
45+
using type = cable_sample_type;
46+
};
47+
48+
template <>
49+
struct probe_value_type_of<cable_state_cell_meta_type> {
50+
using type = cable_sample_type;
51+
};
52+
53+
template <>
54+
struct probe_value_type_of<cable_point_meta_type> {
55+
using type = cable_sample_type;
56+
};
57+
4358
// Each kind of probe has its own type for representing its address, as below.
4459
// The metadata associated with a probe is also passed to a sampler via an `any_ptr`;
4560
// the underlying pointer will be a const pointer to the associated metadata type.

arbor/include/arbor/lif_cell.hpp

+9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <arbor/export.hpp>
66
#include <arbor/units.hpp>
77

8+
#include <arbor/util/extra_traits.hpp>
9+
810
namespace arb {
911

1012
namespace U = arb::units;
@@ -29,6 +31,7 @@ struct ARB_SYMBOL_VISIBLE lif_cell {
2931
struct ARB_SYMBOL_VISIBLE lif_probe_metadata {};
3032

3133
using lif_sample_type = const double;
34+
using lif_meta_type = const lif_probe_metadata;
3235

3336
// Voltage estimate [mV].
3437
// Sample value type: `double`
@@ -37,4 +40,10 @@ struct ARB_SYMBOL_VISIBLE lif_probe_voltage {
3740
using meta_type = const lif_probe_metadata;
3841
};
3942

43+
template <>
44+
struct probe_value_type_of<const lif_meta_type> {
45+
using type = lif_sample_type;
46+
};
47+
48+
4049
} // namespace arb

arbor/include/arbor/sampling.hpp

+10-19
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
#include <cstddef>
44
#include <format>
55
#include <functional>
6-
#include <iostream>
76

87
#include <arbor/assert.hpp>
98
#include <arbor/common_types.hpp>
109
#include <arbor/util/any_ptr.hpp>
10+
#include <arbor/util/extra_traits.hpp>
1111

1212
namespace arb {
1313

@@ -52,18 +52,9 @@ struct sample_records {
5252
std::any values; // resolves to pointer of probe-specific payload data D of layout D[n_sample][width]
5353
};
5454

55-
// Helper class, to be specialized in each cell header, mapping from Metadata to Value types
56-
template <typename M>
57-
struct probe_value_type_of {
58-
using meta_type = M;
59-
using type = void;
60-
};
61-
62-
template<typename M,
63-
typename V// = probe_value_type_of<M>::type
64-
>
55+
template<typename M>
6556
struct sample_reader {
66-
using value_type = V;
57+
using value_type = probe_value_type_of_t<M>;
6758
using meta_type = M;
6859

6960
std::size_t width = 0;
@@ -92,15 +83,15 @@ struct sample_reader {
9283
}
9384
};
9485

95-
// TODO M is enough to know V!
96-
template<typename M, typename V>
86+
template<typename M>
9787
auto make_sample_reader(util::any_ptr apm, const sample_records& sr) {
9888
using util::any_cast;
9989
auto pm = any_cast<M*>(apm);
10090
if (!pm) {
10191
throw std::runtime_error{std::format("Sample reader: could not cast to metadata type; expected {}, got {}.",
10292
typeid((M*)nullptr).name(), apm.type().name())};
10393
}
94+
using V = sample_reader<M>::value_type;
10495
V* val = nullptr;
10596
try {
10697
val = any_cast<V*>(sr.values);
@@ -109,11 +100,11 @@ auto make_sample_reader(util::any_ptr apm, const sample_records& sr) {
109100
throw std::runtime_error{std::format("Sample reader: could not cast to value type; expected {}, got {}.",
110101
typeid((V*)nullptr).name(), sr.values.type().name())};
111102
}
112-
return sample_reader<M, V> { .width=sr.width,
113-
.n_sample=sr.n_sample,
114-
.time=sr.time,
115-
.values=val,
116-
.metadata=pm, };
103+
return sample_reader<M> { .width=sr.width,
104+
.n_sample=sr.n_sample,
105+
.time=sr.time,
106+
.values=val,
107+
.metadata=pm, };
117108
}
118109

119110
using sampler_function = std::function<void(const probe_metadata&, const sample_records&)>;

arbor/include/arbor/simple_sampler.hpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@ namespace arb {
1313
// Simple(st?) implementation of a recorder of scalar trace data from a cell
1414
// probe, with some metadata.
1515

16-
template<typename M, typename V>
16+
template<typename M>
1717
struct simple_sampler_result {
18+
using value_type = probe_value_type_of_t<M>;
1819
std::size_t n_sample = 0;
1920
std::size_t width = 0;
2021
std::vector<time_type> time;
21-
std::vector<std::vector<std::remove_const_t<V>>> values;
22+
std::vector<std::vector<std::remove_const_t<value_type>>> values;
2223
std::vector<std::remove_const_t<M>> metadata;
2324

24-
void from_reader(const sample_reader<M, V>& reader) {
25+
void from_reader(const sample_reader<M>& reader) {
2526
n_sample = reader.n_sample;
2627
width = reader.width;
2728
values.resize(width);
@@ -38,10 +39,10 @@ struct simple_sampler_result {
3839
}
3940
};
4041

41-
template <typename M, typename V>
42-
auto make_simple_sampler(simple_sampler_result<M, V>& trace) {
42+
template <typename M>
43+
auto make_simple_sampler(simple_sampler_result<M>& trace) {
4344
return [&trace](const probe_metadata& pm, const sample_records& recs) {
44-
auto reader = make_sample_reader<M, V>(pm.meta, recs);
45+
auto reader = make_sample_reader<M>(pm.meta, recs);
4546
trace.from_reader(reader);
4647
};
4748
}
+7-9
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
#pragma once
22

33
namespace arb {
4-
namespace util {
54

6-
// TODO: C++20 replace with std::remove_cvref, std::remove_cvref_t
7-
8-
template <typename T>
9-
struct remove_cvref {
10-
typedef std::remove_cv_t<std::remove_reference_t<T>> type;
5+
// Helper class, to be specialized in each cell header, mapping from Metadata to Value types
6+
template <typename M>
7+
struct probe_value_type_of {
8+
using meta_type = M;
9+
using type = void;
1110
};
1211

13-
template <typename T>
14-
using remove_cvref_t = typename remove_cvref<T>::type;
12+
template <typename M>
13+
using probe_value_type_of_t = probe_value_type_of<M>::type;
1514

16-
} // namespace util
1715
} // namespace arb

example/busyring/ring.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ using namespace arborio::literals;
4141
namespace U = arb::units;
4242

4343
// result of simple sampler for probe type
44-
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type, arb::cable_sample_type>;
44+
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type>;
4545

4646
// Writes voltage trace as a json file.
4747
void write_trace_json(const std::string& path, const sample_result&);

example/diffusion/diffusion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ std::ofstream out;
5959

6060
void sampler(probe_metadata pm, const sample_records& samples) {
6161
out << "time,prox,dist,Xd\n" << std::fixed << std::setprecision(4);
62-
auto reader = arb::make_sample_reader<arb::cable_state_cell_meta_type, arb::cable_sample_type>(pm.meta, samples);
62+
auto reader = arb::make_sample_reader<arb::cable_state_cell_meta_type>(pm.meta, samples);
6363
for (std::size_t ix= 0; ix < reader.n_sample; ++ix) {
6464
auto time = reader.get_time(ix);
6565
for (std::size_t iy = 0; iy < reader.width; ++iy) {

example/dryrun/dryrun.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct run_params {
4747

4848

4949
// result of simple sampler for probe type
50-
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type, arb::cable_sample_type>;
50+
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type>;
5151

5252
// Writes voltage trace as a json file.
5353
void write_trace_json(const sample_result&);

example/gap_junctions/gap_junctions.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ using arb::cell_kind;
5959
using arb::time_type;
6060

6161
using probe_t = arb::cable_probe_membrane_voltage;
62-
using sample_results = std::vector<arb::simple_sampler_result<probe_t::meta_type, probe_t::value_type>>;
62+
using sample_results = std::vector<arb::simple_sampler_result<probe_t::meta_type>>;
6363

6464
// Writes voltage trace as a json file.
6565
void write_trace_json(const sample_results& traces, unsigned rank);

example/generators/generators.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ using arb::time_type;
3535
using namespace arborio::literals;
3636

3737
// result of simple sampler for probe type
38-
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type, arb::cable_sample_type>;
38+
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type>;
3939

4040
// Writes voltage trace as a json file.
4141
void write_trace_json(const sample_result&);

example/lfp/lfp.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,13 @@ int main(int argc, char** argv) {
180180
auto sample_schedule = arb::regular_schedule(sample_dt*U::ms);
181181
sim.add_sampler(arb::one_probe({0, "Itotal"}), sample_schedule, lfp.callback());
182182

183-
arb::simple_sampler_result<arb::cable_probe_membrane_voltage::meta_type, arb::cable_probe_membrane_voltage::value_type> membrane_voltage;
183+
arb::simple_sampler_result<arb::cable_probe_membrane_voltage::meta_type> membrane_voltage;
184184
sim.add_sampler(arb::one_probe({0, "Um"}), sample_schedule, arb::make_simple_sampler(membrane_voltage));
185185

186-
arb::simple_sampler_result<arb::cable_probe_total_ion_current_density::meta_type, arb::cable_probe_total_ion_current_density::value_type> ionic_current_density;
186+
arb::simple_sampler_result<arb::cable_probe_total_ion_current_density::meta_type> ionic_current_density;
187187
sim.add_sampler(arb::one_probe({0, "Iion"}), sample_schedule, arb::make_simple_sampler(ionic_current_density));
188188

189-
arb::simple_sampler_result<arb::cable_probe_point_state::meta_type, arb::cable_probe_point_state::value_type> synapse_g;
189+
arb::simple_sampler_result<arb::cable_probe_point_state::meta_type> synapse_g;
190190
sim.add_sampler(arb::one_probe({0, "expsyn-g"}), sample_schedule, arb::make_simple_sampler(synapse_g));
191191

192192
sim.run(t_stop*U::ms, dt*U::ms);

example/network_description/network_description.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,11 @@ ring_params read_options(int argc, char** argv);
5858
using arb::cell_gid_type;
5959
using arb::cell_kind;
6060
using arb::cell_lid_type;
61-
using arb::cell_member_type;
6261
using arb::cell_size_type;
6362
using arb::time_type;
6463

6564
// result of simple sampler for probe type
66-
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type, arb::cable_sample_type>;
65+
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type>;
6766

6867
// Writes voltage trace as a json file.
6968
void write_trace_json(const sample_result&);

example/ornstein_uhlenbeck/ou.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct sampler {
6161
}
6262

6363
void operator()(arb::probe_metadata pm, const arb::sample_records& samples) {
64-
auto reader = arb::make_sample_reader<arb::cable_state_cell_meta_type, arb::cable_sample_type>(pm.meta, samples);
64+
auto reader = arb::make_sample_reader<arb::cable_state_cell_meta_type>(pm.meta, samples);
6565
for (std::size_t ix = 0; ix < reader.n_sample; ++ix) {
6666
for (std::size_t iy = 0; iy < reader.width; ++iy) {
6767
data_[ix*reader.width + iy] = reader.get_value(ix, iy);

example/plasticity/plasticity.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <any>
22
#include <iostream>
3-
#include <iomanip>
43
#include <unordered_map>
54

65
#include <arborio/label_parse.hpp>
@@ -79,7 +78,7 @@ std::mutex mtx;
7978

8079
void sampler(arb::probe_metadata pm, const arb::sample_records& samples) {
8180
using probe_t = arb::cable_probe_membrane_voltage;
82-
auto reader = arb::make_sample_reader<probe_t::meta_type, probe_t::value_type>(pm.meta, samples);
81+
auto reader = arb::make_sample_reader<probe_t::meta_type>(pm.meta, samples);
8382
std::lock_guard<std::mutex> lock{mtx};
8483
for (std::size_t ix = 0; ix < reader.n_sample; ++ix) {
8584
auto time = reader.get_time(ix);

example/probe-demo/probe-demo.cpp

+10-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <any>
22
#include <functional>
3-
#include <iomanip>
43
#include <iostream>
54
#include <string>
65
#include <tuple>
@@ -20,9 +19,6 @@
2019

2120
// Simulate a cell modelled as a simple cable with HH dynamics,
2221
// emitting the results of a user specified probe over time.
23-
24-
using std::any;
25-
using arb::util::any_cast;
2622
namespace U = arb::units;
2723
using namespace arb::units::literals;
2824

@@ -79,7 +75,7 @@ struct options {
7975
double sim_dt = 0.025; // [ms]
8076
double sample_dt = 1.0; // [ms]
8177
unsigned n_cv = 10;
82-
any probe_addr;
78+
std::any probe_addr;
8379
std::string value_name;
8480
probe_kind kind = probe_kind::invalid;
8581
};
@@ -105,9 +101,9 @@ std::string show_location(const M& where) {
105101
// Do this once
106102
static std::atomic<int> printed_header = 0;
107103

108-
template<typename M, typename V>
104+
template<typename M>
109105
void sampler(arb::probe_metadata pm, const arb::sample_records& samples) {
110-
auto reader = arb::make_sample_reader<M, V>(pm.meta, samples);
106+
auto reader = arb::make_sample_reader<M>(pm.meta, samples);
111107
// Print CSV header for sample output
112108
if (0 == printed_header.fetch_add(1)) {
113109
std::cout << std::format("t", "");
@@ -127,9 +123,9 @@ void sampler(arb::probe_metadata pm, const arb::sample_records& samples) {
127123

128124
struct cable_recipe: public arb::recipe {
129125
arb::cable_cell_global_properties gprop;
130-
any probe_addr;
126+
std::any probe_addr;
131127

132-
explicit cable_recipe(any probe_addr, unsigned n_cv):
128+
explicit cable_recipe(std::any probe_addr, unsigned n_cv):
133129
probe_addr(std::move(probe_addr)) {
134130
gprop.default_parameters = arb::neuron_parameter_defaults;
135131
gprop.default_parameters.discretization = arb::cv_policy_fixed_per_branch(n_cv);
@@ -138,7 +134,7 @@ struct cable_recipe: public arb::recipe {
138134
arb::cell_size_type num_cells() const override { return 1; }
139135
std::vector<arb::probe_info> get_probes(arb::cell_gid_type) const override { return {{probe_addr, "probe"}}; }
140136
arb::cell_kind get_cell_kind(arb::cell_gid_type) const override { return arb::cell_kind::cable; }
141-
any get_global_properties(arb::cell_kind) const override { return gprop; }
137+
std::any get_global_properties(arb::cell_kind) const override { return gprop; }
142138

143139
arb::util::unique_any get_cell_description(arb::cell_gid_type) const override {
144140
const double length = 1000; // [µm]
@@ -174,17 +170,17 @@ int main(int argc, char** argv) {
174170
case probe_kind::cell:
175171
sim.add_sampler(arb::all_probes,
176172
arb::regular_schedule(opt.sample_dt*U::ms),
177-
sampler<arb::cable_state_cell_meta_type, arb::cable_sample_type>);
173+
sampler<arb::cable_state_cell_meta_type>);
178174
break;
179175
case probe_kind::state:
180176
sim.add_sampler(arb::all_probes,
181177
arb::regular_schedule(opt.sample_dt*U::ms),
182-
sampler<arb::cable_state_meta_type, arb::cable_sample_type>);
178+
sampler<arb::cable_state_meta_type>);
183179
break;
184180
case probe_kind::point:
185181
sim.add_sampler(arb::all_probes,
186182
arb::regular_schedule(opt.sample_dt*U::ms),
187-
sampler<arb::cable_point_meta_type, arb::cable_sample_type>);
183+
sampler<arb::cable_point_meta_type>);
188184
break;
189185
default:
190186
std::cerr << "Invalid probe kind\n";
@@ -235,7 +231,7 @@ bool parse_options(options& opt, int& argc, char** argv) {
235231
auto do_help = [&]() { usage(argv[0], help_msg); };
236232

237233
// Map probe argument to output variable name and a lambda that makes specific probe address from a location.
238-
using probe_spec_t = std::tuple<std::string, probe_kind, std::function<any(std::any)>>;
234+
using probe_spec_t = std::tuple<std::string, probe_kind, std::function<std::any(std::any)>>;
239235
std::pair<const char*, probe_spec_t> probe_tbl[] {
240236
// located probes
241237
{"v", {"v", probe_kind::state, [](std::any a) -> std::any { return arb::cable_probe_membrane_voltage{any2loc(a)}; }}},

example/remote/remote.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ mpi_handle setup_mpi() {
150150
void sampler(arb::probe_metadata pm, const arb::sample_records& samples) {
151151
if (pm.id.gid != 0) return;
152152
if (pm.id.tag != "Um") return;
153-
auto reader = arb::make_sample_reader<arb::lif_probe_metadata, arb::lif_sample_type>(pm.meta, samples);
153+
auto reader = arb::make_sample_reader<arb::lif_meta_type>(pm.meta, samples);
154154
for (std::size_t ix = 0; ix < reader.n_sample; ++ix) {
155155
double time = reader.get_time(ix);
156156
double value = reader.get_value(ix);

example/ring/ring.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ using arb::cell_kind;
5353
using arb::time_type;
5454

5555
// result of simple sampler for probe type
56-
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type, arb::cable_sample_type>;
56+
using sample_result = arb::simple_sampler_result<arb::cable_state_meta_type>;
5757

5858
// Writes voltage trace as a json file.
5959
void write_trace_json(const sample_result&);

0 commit comments

Comments
 (0)