Skip to content

Commit a49de7b

Browse files
authored
Merge pull request #214 from beomki-yeo/fix-cuda-example
Make CUDA algorithms identical to SYCL algorithms
2 parents 79cb8ea + 77deb52 commit a49de7b

27 files changed

+871
-544
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ cmake --build <build_directory> <options>
211211
### cpu reconstruction chain
212212

213213
```sh
214-
<build_directory>/bin/traccc_seq_example --detector_file=tml_detector/trackml-detector.csv --digitization_config_file=tml_detector/default-geometric-config-generic.json --cell_directory=tml_pixels/ --events=10
214+
<build_directory>/bin/traccc_seq_example --detector_file=tml_detector/trackml-detector.csv --digitization_config_file=tml_detector/default-geometric-config-generic.json --input_directory=tml_pixels/ --events=10
215215
```
216216

217217
### cuda reconstruction chain

device/cuda/include/traccc/cuda/clusterization/clusterization_algorithm.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "traccc/edm/measurement.hpp"
1414
#include "traccc/edm/spacepoint.hpp"
1515
#include "traccc/utils/algorithm.hpp"
16+
#include "traccc/utils/memory_resource.hpp"
1617

1718
// VecMem include(s).
1819
#include <vecmem/memory/memory_resource.hpp>
@@ -31,7 +32,7 @@ class clusterization_algorithm
3132
/// Constructor for clusterization algorithm
3233
///
3334
/// @param mr is a memory resource (device)
34-
clusterization_algorithm(vecmem::memory_resource& mr);
35+
clusterization_algorithm(const traccc::memory_resource& mr);
3536

3637
/// Callable operator for clusterization algorithm
3738
///
@@ -43,7 +44,8 @@ class clusterization_algorithm
4344
const cell_container_types::host& cells_per_event) const override;
4445

4546
private:
46-
std::reference_wrapper<vecmem::memory_resource> m_mr;
47+
traccc::memory_resource m_mr;
48+
std::unique_ptr<vecmem::copy> m_copy;
4749
};
4850

4951
} // namespace traccc::cuda

device/cuda/include/traccc/cuda/seeding/seed_finding.hpp

+35-8
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
#include "traccc/seeding/detail/seeding_config.hpp"
1414
#include "traccc/seeding/detail/spacepoint_grid.hpp"
1515
#include "traccc/utils/algorithm.hpp"
16+
#include "traccc/utils/memory_resource.hpp"
1617

1718
// VecMem include(s).
1819
#include <vecmem/containers/data/vector_buffer.hpp>
19-
#include <vecmem/memory/memory_resource.hpp>
20+
#include <vecmem/utils/copy.hpp>
2021

2122
// System include(s).
2223
#include <functional>
@@ -25,27 +26,53 @@ namespace traccc::cuda {
2526

2627
/// Seed finding for cuda
2728
class seed_finding
28-
: public algorithm<host_seed_collection(
29-
const spacepoint_container_types::view&, const sp_grid_const_view&)> {
29+
: public algorithm<vecmem::data::vector_buffer<seed>(
30+
const spacepoint_container_types::const_view&,
31+
const sp_grid_const_view&)>,
32+
public algorithm<vecmem::data::vector_buffer<seed>(
33+
const spacepoint_container_types::buffer&, const sp_grid_buffer&)> {
3034

3135
public:
3236
/// Constructor for the cuda seed finding
3337
///
3438
/// @param config is seed finder configuration parameters
3539
/// @param sp_grid spacepoint grid
3640
/// @param mr vecmem memory resource
37-
seed_finding(const seedfinder_config& config, vecmem::memory_resource& mr);
41+
seed_finding(const seedfinder_config& config,
42+
const traccc::memory_resource& mr);
3843

3944
/// Callable operator for the seed finding
4045
///
41-
/// @return seed_collection is the vector of seeds per event
42-
output_type operator()(const spacepoint_container_types::view& spacepoints,
43-
const sp_grid_const_view& g2_view) const override;
46+
/// @param spacepoints_view is a view of all spacepoints in the event
47+
/// @param g2_view is a view of the spacepoint grid
48+
/// @return a vector buffer of seeds
49+
///
50+
vecmem::data::vector_buffer<seed> operator()(
51+
const spacepoint_container_types::const_view& spacepoints_view,
52+
const sp_grid_const_view& g2_view) const override;
53+
54+
/// Callable operator for the seed finding
55+
///
56+
/// @param spacepoints_buffer is a buffer of all spacepoints in the event
57+
/// @param g2_buffer is a buffer of the spacepoint grid
58+
/// @return a vector buffer of seeds
59+
///
60+
vecmem::data::vector_buffer<seed> operator()(
61+
const spacepoint_container_types::buffer& spacepoints_buffer,
62+
const sp_grid_buffer& g2_buffer) const override;
63+
64+
private:
65+
/// Implementation for the public seed finding operators.
66+
vecmem::data::vector_buffer<seed> operator()(
67+
const spacepoint_container_types::const_view& spacepoints_view,
68+
const sp_grid_const_view& g2_view,
69+
const std::vector<unsigned int>& grid_sizes) const;
4470

4571
private:
4672
seedfinder_config m_seedfinder_config;
4773
seedfilter_config m_seedfilter_config;
48-
std::reference_wrapper<vecmem::memory_resource> m_mr;
74+
traccc::memory_resource m_mr;
75+
std::unique_ptr<vecmem::copy> m_copy;
4976
};
5077

5178
} // namespace traccc::cuda

device/cuda/include/traccc/cuda/seeding/seed_selecting.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace cuda {
3333
void seed_selecting(
3434
const seedfilter_config& filter_config,
3535
const vecmem::vector<device::doublet_counter_header>& dcc_headers,
36-
const spacepoint_container_types::view& spacepoints,
36+
const spacepoint_container_types::const_view& spacepoints,
3737
sp_grid_const_view internal_sp_view,
3838
device::doublet_counter_container_types::const_view dcc_view,
3939
triplet_counter_container_view tcc_view, triplet_container_view tc_view,

device/cuda/include/traccc/cuda/seeding/seeding_algorithm.hpp

+22-7
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,41 @@
1919
// VecMem include(s).
2020
#include <vecmem/memory/memory_resource.hpp>
2121

22+
// traccc library include(s).
23+
#include "traccc/utils/memory_resource.hpp"
24+
2225
namespace traccc::cuda {
2326

2427
/// Main algorithm for performing the track seeding on an NVIDIA GPU
25-
class seeding_algorithm : public algorithm<host_seed_collection(
26-
const spacepoint_container_types::view&)> {
28+
class seeding_algorithm : public algorithm<vecmem::data::vector_buffer<seed>(
29+
const spacepoint_container_types::const_view&)>,
30+
public algorithm<vecmem::data::vector_buffer<seed>(
31+
const spacepoint_container_types::buffer&)> {
2732

2833
public:
2934
/// Constructor for the seed finding algorithm
3035
///
3136
/// @param mr The memory resource to use
3237
///
33-
seeding_algorithm(vecmem::memory_resource& mr);
38+
seeding_algorithm(const traccc::memory_resource& mr);
39+
40+
/// Operator executing the algorithm.
41+
///
42+
/// @param spacepoints_view is a view of all spacepoints in the event
43+
/// @return the buffer of track seeds reconstructed from the spacepoints
44+
///
45+
vecmem::data::vector_buffer<seed> operator()(
46+
const spacepoint_container_types::const_view& spacepoints_view)
47+
const override;
3448

3549
/// Operator executing the algorithm.
3650
///
37-
/// @param spacepoint All spacepoints in the event
38-
/// @return The track seeds reconstructed from the spacepoints
51+
/// @param spacepoints_buffer is a buffer of all spacepoints in the event
52+
/// @return the buffer of track seeds reconstructed from the spacepoints
3953
///
40-
output_type operator()(
41-
const spacepoint_container_types::view& spacepoints) const override;
54+
vecmem::data::vector_buffer<seed> operator()(
55+
const spacepoint_container_types::buffer& spacepoints_buffer)
56+
const override;
4257

4358
private:
4459
/// Sub-algorithm performing the spacepoint binning

device/cuda/include/traccc/cuda/seeding/spacepoint_binning.hpp

+21-7
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
#include "traccc/seeding/detail/seeding_config.hpp"
1313
#include "traccc/seeding/detail/spacepoint_grid.hpp"
1414
#include "traccc/utils/algorithm.hpp"
15+
#include "traccc/utils/memory_resource.hpp"
1516

1617
// VecMem include(s).
17-
#include <vecmem/memory/memory_resource.hpp>
18+
#include <vecmem/utils/copy.hpp>
1819

1920
// System include(s).
2021
#include <functional>
@@ -24,22 +25,35 @@ namespace traccc::cuda {
2425

2526
/// Spacepoing binning executed on a CUDA device
2627
class spacepoint_binning : public algorithm<sp_grid_buffer(
27-
const spacepoint_container_types::view&)> {
28+
const spacepoint_container_types::const_view&)>,
29+
public algorithm<sp_grid_buffer(
30+
const spacepoint_container_types::buffer&)> {
2831

2932
public:
3033
/// Constructor for the algorithm
3134
spacepoint_binning(const seedfinder_config& config,
3235
const spacepoint_grid_config& grid_config,
33-
vecmem::memory_resource& mr);
36+
const traccc::memory_resource& mr);
3437

35-
/// Function executing the algorithm
36-
sp_grid_buffer operator()(
37-
const spacepoint_container_types::view& sp_data) const override;
38+
/// Function executing the algorithm with a a view of spacepoints
39+
sp_grid_buffer operator()(const spacepoint_container_types::const_view&
40+
spacepoints_view) const override;
41+
42+
/// Function executing the algorithm with spacepoint buffer
43+
sp_grid_buffer operator()(const spacepoint_container_types::buffer&
44+
spacepoints_buffer) const override;
3845

3946
private:
47+
/// Implementation for the public spacepoint binning operators
48+
sp_grid_buffer operator()(
49+
const spacepoint_container_types::const_view& spacepoints_view,
50+
const std::vector<unsigned int>& sp_sizes) const;
51+
52+
/// Member variables
4053
seedfinder_config m_config;
4154
std::pair<sp_grid::axis_p0_type, sp_grid::axis_p1_type> m_axes;
42-
std::reference_wrapper<vecmem::memory_resource> m_mr;
55+
traccc::memory_resource m_mr;
56+
std::unique_ptr<vecmem::copy> m_copy;
4357

4458
}; // class spacepoint_binning
4559

device/cuda/include/traccc/cuda/seeding/track_params_estimation.hpp

+36-8
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,61 @@
77

88
#pragma once
99

10+
// Project include(s)
1011
#include "traccc/seeding/track_params_estimation_helper.hpp"
1112
#include "traccc/utils/algorithm.hpp"
13+
#include "traccc/utils/memory_resource.hpp"
14+
15+
// VecMem include(s).
16+
#include <vecmem/utils/copy.hpp>
1217

1318
namespace traccc {
1419
namespace cuda {
1520

1621
/// track parameter estimation for cuda
1722
struct track_params_estimation
1823
: public algorithm<host_bound_track_parameters_collection(
19-
const spacepoint_container_types::view&, host_seed_collection&&)> {
24+
const spacepoint_container_types::const_view&,
25+
const vecmem::data::vector_view<const seed>&)>,
26+
public algorithm<host_bound_track_parameters_collection(
27+
const spacepoint_container_types::buffer&,
28+
const vecmem::data::vector_buffer<seed>&)> {
29+
2030
public:
2131
/// Constructor for track_params_estimation
2232
///
2333
/// @param mr is the memory resource
24-
track_params_estimation(vecmem::memory_resource& mr) : m_mr(mr) {}
34+
track_params_estimation(const traccc::memory_resource& mr);
2535

2636
/// Callable operator for track_params_esitmation
2737
///
28-
/// @param input_type is the seed container
38+
/// @param spaepoints_view is the view of the spacepoint container
39+
/// @param seeds_view is the view of the seed container
40+
/// @return vector of bound track parameters
2941
///
30-
/// @return vector of bound track parameters
31-
output_type operator()(
32-
const spacepoint_container_types::view& spacepoints_view,
33-
host_seed_collection&& seeds) const override;
42+
host_bound_track_parameters_collection operator()(
43+
const spacepoint_container_types::const_view& spacepoints_view,
44+
const vecmem::data::vector_view<const seed>& seeds_view) const override;
45+
46+
/// Callable operator for track_params_esitmation
47+
///
48+
/// @param spaepoints_buffer is the buffer of the spacepoint container
49+
/// @param seeds_buffer is the buffer of the seed container
50+
/// @return vector of bound track parameters
51+
///
52+
host_bound_track_parameters_collection operator()(
53+
const spacepoint_container_types::buffer& spacepoints_buffer,
54+
const vecmem::data::vector_buffer<seed>& seeds_buffer) const override;
3455

3556
private:
36-
std::reference_wrapper<vecmem::memory_resource> m_mr;
57+
/// Implementation for the public track params estimation operators
58+
host_bound_track_parameters_collection operator()(
59+
const spacepoint_container_types::const_view& spacepoints_view,
60+
const vecmem::data::vector_view<const seed>& seeds_view,
61+
std::size_t seeds_size) const;
62+
63+
traccc::memory_resource m_mr;
64+
std::unique_ptr<vecmem::copy> m_copy;
3765
};
3866

3967
} // namespace cuda

0 commit comments

Comments
 (0)