-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathMyTrackFindingAlgorithm.hpp
More file actions
224 lines (196 loc) · 9.77 KB
/
MyTrackFindingAlgorithm.hpp
File metadata and controls
224 lines (196 loc) · 9.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#pragma once
#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

#include "Acts/Utilities/Logger.hpp"
#include "Acts/MagneticField/MagneticFieldProvider.hpp"
#include "Acts/Geometry/TrackingGeometry.hpp"
#include "Acts/TrackFitting/KalmanFitter.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/AlgorithmContext.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/EventData/Measurement.hpp"
#include "ActsExamples/EventData/Index.hpp"
//#include "ActsExamples/EventData/IndexSourceLink.hpp"
namespace My{
// Framework types re-exported from ActsExamples.
using ActsExamples::AlgorithmContext;
using ActsExamples::IAlgorithm;
using ActsExamples::ProcessCode;
using ActsExamples::ReadDataHandle;
using ActsExamples::WriteDataHandle;

// Event-data types re-exported from ActsExamples.
using ActsExamples::ConstTrackContainer;
using ActsExamples::Index;
using ActsExamples::IndexMultimap;
using ActsExamples::IndexSourceLinkAccessor;
using ActsExamples::MeasurementContainer;
using ActsExamples::MeasurementParticlesMap;
using ActsExamples::TrackContainer;
using ActsExamples::TrackParametersContainer;

// Proxy and backend types of `TrackContainer`.
using TrackProxy = TrackContainer::TrackProxy;
using TrackStateProxy = TrackContainer::TrackStateProxy;
using TrackStateContainerBackend = TrackContainer::TrackStateContainerBackend;

// Kalman-fitter extension delegates for the track-state backend.
using Updater = Acts::KalmanFitterExtensions<TrackStateContainerBackend>::Updater;
using Calibrator = Acts::KalmanFitterExtensions<TrackStateContainerBackend>::Calibrator;

// Short alias for the track-state component mask.
using PM = Acts::TrackStatePropMask;
/// Propagation result of MyTrackFindingActor.
struct MyTrackFindingResult {
  /// Track states built from the measurements of the surface currently being
  /// filtered; cleared at the start of every filtered surface.
  std::vector<TrackStateProxy> trackStateCandidates{};
  /// Branches still to be propagated; the back element is the branch the
  /// stepper is currently following.
  std::vector<TrackProxy> activeBranches{};
  /// Branches whose propagation reached the end of the world.
  std::vector<TrackProxy> collectedTracks{};
};
/// Propagator actor implementing a simple branching Kalman track finder.
///
/// On every sensitive surface (see `filter`), the predicted parameters of the
/// branch at the back of `result.activeBranches` are compared with all
/// measurements on that surface. Every 1-D measurement passing the chi2 cut
/// creates a branch: the first reuses the current track, and further ones are
/// shallow copies. Each branch is then updated with `updater`. When the
/// navigator reports the end of the world, the current branch is moved to
/// `result.collectedTracks`. The stepper and navigator are then re-initialised
/// from the outermost state of the next pending branch.
struct MyTrackFindingActor {
  using result_type = MyTrackFindingResult;

  /// Provides the source links located on a given surface.
  IndexSourceLinkAccessor slAccessor;
  /// Measurement container; only read to build verbose log output.
  const MeasurementContainer* measurements{nullptr};
  /// Backend in which all new track states are created; must be set.
  TrackStateContainerBackend* trackStates{nullptr};
  /// Calibration context forwarded to `calibrator`; must be set.
  const Acts::CalibrationContext* calibrationContext{nullptr};
  /// Truth particle id per measurement index; optional, only used for verbose output.
  const std::vector<int>* particleIds{nullptr};
  Calibrator calibrator{Acts::DelegateFuncTag<Acts::detail::voidFitterCalibrator<TrackStateContainerBackend>>{}};
  Updater updater{Acts::DelegateFuncTag<Acts::detail::voidFitterUpdater<TrackStateContainerBackend>>{}};
  /// Maximum 1-D chi2 for a measurement to be attached to a branch.
  double chi2Max{10.};

  /// Returns true if `chi2` passes the cut. NaN (e.g. from a zero combined
  /// variance) is rejected because every comparison with NaN is false.
  bool passesChi2Cut(double chi2) const { return chi2 <= chi2Max; }

  /// Actor hook invoked by the propagator (presumably once per step).
  ///
  /// Filters the current surface, if there is one. Once the end of the world is
  /// reached, it stores the current branch and switches propagation to the next
  /// pending branch.
  template <typename propagator_state_t, typename stepper_t, typename navigator_t>
  void act(propagator_state_t& state, const stepper_t& stepper, const navigator_t& navigator,
           result_type& result, const Acts::Logger& logger) const {
    if (result.activeBranches.empty()) return;
    auto surface = navigator.currentSurface(state.navigation);
    if (surface != nullptr) {
      // ACTS_DEBUG("On surface: " << surface->geometryId());
      filter(surface, state, stepper, navigator, result, logger);
    }
    if (!navigator.endOfWorldReached(state.navigation)) return;

    // The current branch is finished: store it and remove it from the active branches.
    result.collectedTracks.push_back(result.activeBranches.back());
    result.activeBranches.pop_back();
    if (result.activeBranches.empty()) return;

    // Resume propagation from the outermost state of the next pending branch.
    auto currentBranch = result.activeBranches.back();
    ACTS_VERBOSE("Switching navigator to track " << currentBranch.tipIndex());
    auto cs = currentBranch.outermostTrackState();  // current state
    stepper.initialize(state.stepping, cs.filtered(), cs.filteredCovariance(),
                       stepper.particleHypothesis(state.stepping), cs.referenceSurface());
    state.navigation.options.startSurface = &cs.referenceSurface();
    state.navigation.options.targetSurface = nullptr;
    auto navInitRes = navigator.initialize(state.navigation, stepper.position(state.stepping),
                                           stepper.direction(state.stepping), state.options.direction);
    if (!navInitRes.ok()) {
      ACTS_ERROR("Navigator re-initialisation failed for track " << currentBranch.index()
                 << ": " << navInitRes.error().message());
    }
  }

  /// Matches the measurements on `surface` to the current branch and branches
  /// on every accepted measurement.
  ///
  /// Surfaces without a detector element, and the layers listed in the switch
  /// below, are skipped. If no measurement passes the chi2 cut, neither the
  /// branches nor the stepper are modified.
  template <typename propagator_state_t, typename stepper_t, typename navigator_t>
  void filter(const Acts::Surface* surface, propagator_state_t& state, const stepper_t& stepper,
              const navigator_t& navigator, result_type& result, const Acts::Logger& logger) const {
    if (surface->associatedDetectorElement() == nullptr) return;
    const int layer = surface->geometryId().layer();
    // TODO skip pixel-like surfaces using geo class
    switch (layer) {
      case 2: case 6: case 13: case 20: case 27: case 34: case 38:
        return;
      default:
        break;
    }

    auto boundStateRes = stepper.boundState(state.stepping, *surface);
    if (!boundStateRes.ok()) {
      ACTS_ERROR("Bound state computation failed on surface " << surface->geometryId()
                 << ": " << boundStateRes.error().message());
      return;
    }
    auto& boundState = *boundStateRes;
    auto& [boundParams, jacobian, pathLength] = boundState;
    if (!boundParams.covariance().has_value()) {
      ACTS_ERROR("Bound parameters without covariance on surface " << surface->geometryId());
      return;
    }
    // TODO add material

    // Create one track-state candidate per measurement on this surface.
    result.trackStateCandidates.clear();
    auto [slBegin, slEnd] = slAccessor.range(*surface);
    const auto nSourceLinks = slEnd - slBegin;
    if (nSourceLinks > 0) {
      result.trackStateCandidates.reserve(static_cast<std::size_t>(nSourceLinks));
    }
    const auto tipIndex = result.activeBranches.back().tipIndex();
    int passedCandidates = 0;
    for (auto it = slBegin; it != slEnd; ++it) {
      // if (it != slBegin) continue; // only one hit per surface allowed
      const auto sl = *it;  // source link
      PM mask = PM::Predicted | PM::Jacobian | PM::Calibrated;
      auto ts = trackStates->makeTrackState(mask, tipIndex);
      ts.predicted() = boundParams.parameters();
      ts.predictedCovariance() = *boundParams.covariance();
      ts.jacobian() = jacobian;
      ts.pathLength() = pathLength;
      ts.setReferenceSurface(boundParams.referenceSurface().getSharedPtr());
      calibrator(state.geoContext, *calibrationContext, sl, ts);
      // Only 1-D measurements are handled; the state created above stays
      // unused in the container.
      if (ts.calibratedSize() != 1) continue;

      // chi2 for 1-D measurements: residual^2 / (measurement var + predicted var).
      // NOTE(review): uses the first bound parameter (loc0) and its variance,
      // i.e. assumes every 1-D measurement measures loc0 — confirm, or use
      // the projector to select the measured component.
      const double calib = ts.effectiveCalibrated().data()[0];
      const double pred = ts.predicted().data()[0];
      const double calibCov = ts.effectiveCalibratedCovariance().data()[0];
      const double predCov = ts.predictedCovariance().data()[0];
      const double res = calib - pred;
      const double chi2 = res * res / (calibCov + predCov);

      // Build the debug line only when it will actually be printed.
      if (logger.doPrint(Acts::Logging::VERBOSE)) {
        const auto meas = measurements->getMeasurement(sl.get<ActsExamples::IndexSourceLink>().index());
        const int particleId = (particleIds != nullptr) ? particleIds->at(meas.index()) : -1;
        char buffer[200];
        std::snprintf(buffer, sizeof(buffer),
                      "track=%2u layer=%2d partId=%2d calib=%4.1f pred=%4.1f cov=%6.4f chi2=%9.5f",
                      static_cast<unsigned int>(tipIndex), layer, particleId, calib, pred, predCov, chi2);
        ACTS_VERBOSE(buffer);
      }

      ts.chi2() = chi2;
      result.trackStateCandidates.push_back(ts);
      if (passesChi2Cut(ts.chi2())) ++passedCandidates;
    }
    if (passedCandidates == 0) return;

    // Copy the accepted candidates into fresh track states that carry the
    // filtered component as well.
    // TODO sort track candidates according to chi2
    std::vector<int> trackStateList;
    trackStateList.reserve(static_cast<std::size_t>(passedCandidates));
    for (auto& ts : result.trackStateCandidates) {
      if (!passesChi2Cut(ts.chi2())) continue;
      PM mask = PM::Predicted | PM::Filtered | PM::Jacobian | PM::Calibrated;
      auto trackState = trackStates->makeTrackState(mask, ts.previous());
      trackState.copyFrom(ts, mask, false);
      trackState.typeFlags().set(Acts::TrackStateFlag::ParameterFlag);
      trackState.typeFlags().set(Acts::TrackStateFlag::MeasurementFlag);
      trackStateList.push_back(trackState.index());
    }

    // One branch per accepted state: the first reuses the current track,
    // the others are shallow copies of it.
    auto rootBranch = result.activeBranches.back();
    auto makeShallowCopy = [&rootBranch] {
      auto copy = rootBranch.container().makeTrack();
      copy.copyFromShallow(rootBranch);
      return copy;
    };
    std::vector<TrackProxy> newBranches;
    newBranches.reserve(trackStateList.size());
    for (std::size_t i = 0; i < trackStateList.size(); ++i) {
      auto newBranch = (i == 0) ? rootBranch : makeShallowCopy();
      newBranch.tipIndex() = trackStateList[i];
      newBranches.push_back(newBranch);
    }

    result.activeBranches.pop_back();
    for (auto& newBranch : newBranches) {
      auto trackState = newBranch.outermostTrackState();
      if (trackState.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
        auto updateRes = updater(state.geoContext, trackState, logger);
        if (!updateRes.ok()) {
          ACTS_WARNING("Kalman update failed for track " << newBranch.index()
                       << ": " << updateRes.error().message());
        }
        // TODO: add branch stopper
      }
      result.activeBranches.push_back(newBranch);
    }

    // Continue propagating the last new branch from its filtered state.
    auto currentBranch = result.activeBranches.back();
    auto cs = currentBranch.outermostTrackState();  // current state
    auto freePar = Acts::MultiTrajectoryHelpers::freeFiltered(state.geoContext, cs);
    stepper.update(state.stepping, freePar, cs.filtered(), cs.filteredCovariance(), *surface);
    // TODO add material effects
  }
};
/// Framework algorithm performing track finding; its execution is implemented
/// outside this header.
class MyTrackFindingAlgorithm final : public IAlgorithm {
 public:
  /// Algorithm configuration: data-handle keys and shared geometry/field services.
  struct Config {
    /// Key of the input measurement collection.
    std::string inputMeasurements;
    /// Key of the input measurement-to-particle map.
    std::string inputMeasurementParticlesMap;
    /// Key of the input initial track parameters.
    std::string inputInitialTrackParameters;
    /// Key under which the found tracks are written.
    std::string outputTracks;
    /// Tracking geometry used for propagation.
    std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry;
    /// Magnetic field used for propagation.
    std::shared_ptr<const Acts::MagneticFieldProvider> magneticField;
  };

  /// @param config algorithm configuration (taken by value)
  /// @param level  logging threshold
  MyTrackFindingAlgorithm(Config config, Acts::Logging::Level level);

  /// Processes one algorithm context.
  ProcessCode execute(const AlgorithmContext& ctx) const override;

  /// Read-only access to the stored configuration.
  const Config& config() const { return m_cfg; }

 private:
  Config m_cfg;

  // Data handles; the second argument is the handle name.
  ReadDataHandle<MeasurementContainer> m_inputMeasurements{this, "InputMeasurements"};
  ReadDataHandle<MeasurementParticlesMap> m_inputMeasurementParticlesMap{this, "InputMeasurementParticlesMap"};
  ReadDataHandle<TrackParametersContainer> m_inputInitialTrackParameters{this, "InputInitialTrackParameters"};
  WriteDataHandle<ConstTrackContainer> m_outputTracks{this, "OutputTracks"};
};
} // namespace My