Skip to content

Commit dcb3ae9

Browse files
authored
Merge pull request #2401 from jltsiren/master
Cache GBWT nodes in gapless extension
2 parents 3ccb6f9 + a8575ab commit dcb3ae9

6 files changed

+156
-63
lines changed

scripts/giraffe-wrangler.sh

+32-13
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,30 @@ shift
4646

4747
# Define the Giraffe parameters
4848
GIRAFFE_OPTS=(-s75 -u 0.1 -v 1 -w 5 -C 600)
49-
# And the thread count for everyone
50-
THREAD_COUNT=32
49+
50+
# And the thread count for everyone.
51+
# Should fit on a NUMA node
52+
THREAD_COUNT=24
5153

5254
# Define a work directory
5355
# TODO: this requires GNU mktemp
5456
WORK="$(mktemp -d)"
5557

58+
# Check for NUMA. If we have NUMA and no numactl, results may be unreliable
59+
NUMA_COUNT="$(lscpu | grep "NUMA node(s)" | cut -f3- -d' ' | tr -d ' ')"
60+
NUMA_PREFIX=""
61+
NUMA_WARNING=0
62+
63+
if [[ "${NUMA_COUNT}" -gt "1" ]] ; then
64+
if which numactl >/dev/null 2>&1 ; then
65+
# Run everything on one NUMA node
66+
NUMA_PREFIX="numactl --cpunodebind=0 --membind=0"
67+
else
68+
# We should warn in the report that NUMA may confound the results
69+
NUMA_WARNING=1
70+
fi
71+
fi
72+
5673
if which perf >/dev/null 2>&1 ; then
5774
# Record profile.
5875
# Do this first because perf is likely to be misconfigured and we want to fail fast.
@@ -61,17 +78,17 @@ if which perf >/dev/null 2>&1 ; then
6178
# script makes take forever because the binary is huge
6279
strip bin/vg
6380

64-
perf record -F 100 --call-graph dwarf -o "${WORK}/perf.data" vg gaffe -x "${XG_INDEX}" -m "${MINIMIZER_INDEX}" -H "${GBWT_INDEX}" -d "${DISTANCE_INDEX}" -f "${REAL_FASTQ}" -t "${THREAD_COUNT}" "${GIRAFFE_OPTS[@]}" >"${WORK}/perf.gam"
81+
${NUMA_PREFIX} perf record -F 100 --call-graph dwarf -o "${WORK}/perf.data" vg gaffe -x "${XG_INDEX}" -m "${MINIMIZER_INDEX}" -H "${GBWT_INDEX}" -d "${DISTANCE_INDEX}" -f "${REAL_FASTQ}" -t "${THREAD_COUNT}" "${GIRAFFE_OPTS[@]}" >"${WORK}/perf.gam"
6582
perf script -i "${WORK}/perf.data" >"${WORK}/out.perf"
6683
deps/FlameGraph/stackcollapse-perf.pl "${WORK}/out.perf" >"${WORK}/out.folded"
6784
deps/FlameGraph/flamegraph.pl "${WORK}/out.folded" > "${WORK}/profile.svg"
6885
fi
6986

7087
# Run simulated reads, with stats
71-
vg gaffe --track-correctness -x "${XG_INDEX}" -m "${MINIMIZER_INDEX}" -H "${GBWT_INDEX}" -d "${DISTANCE_INDEX}" -G "${SIM_GAM}" -t "${THREAD_COUNT}" "${GIRAFFE_OPTS[@]}" >"${WORK}/mapped.gam"
88+
${NUMA_PREFIX} vg gaffe --track-correctness -x "${XG_INDEX}" -m "${MINIMIZER_INDEX}" -H "${GBWT_INDEX}" -d "${DISTANCE_INDEX}" -G "${SIM_GAM}" -t "${THREAD_COUNT}" "${GIRAFFE_OPTS[@]}" >"${WORK}/mapped.gam"
7289

7390
# And map to compare with them
74-
vg map -x "${XG_INDEX}" -g "${GCSA_INDEX}" -G "${SIM_GAM}" -t "${THREAD_COUNT}" >"${WORK}/mapped-map.gam"
91+
${NUMA_PREFIX} vg map -x "${XG_INDEX}" -g "${GCSA_INDEX}" -G "${SIM_GAM}" -t "${THREAD_COUNT}" >"${WORK}/mapped-map.gam"
7592

7693
# Annotate and compare against truth
7794
vg annotate -p -x "${XG_INDEX}" -a "${WORK}/mapped.gam" >"${WORK}/annotated.gam"
@@ -91,26 +108,28 @@ vg view -aj "${WORK}/mapped.gam" | scripts/giraffe-facts.py "${WORK}/facts" >"${
91108
# Now do the real reads
92109

93110
# Get RPS
94-
vg gaffe -p -x "${XG_INDEX}" -m "${MINIMIZER_INDEX}" -H "${GBWT_INDEX}" -d "${DISTANCE_INDEX}" -f "${REAL_FASTQ}" -t "${THREAD_COUNT}" "${GIRAFFE_OPTS[@]}" >"${WORK}/real.gam" 2>"${WORK}/log.txt"
111+
${NUMA_PREFIX} vg gaffe -p -x "${XG_INDEX}" -m "${MINIMIZER_INDEX}" -H "${GBWT_INDEX}" -d "${DISTANCE_INDEX}" -f "${REAL_FASTQ}" -t "${THREAD_COUNT}" "${GIRAFFE_OPTS[@]}" >"${WORK}/real.gam" 2>"${WORK}/log.txt"
95112

96113
GIRAFFE_RPS="$(cat "${WORK}/log.txt" | grep "reads per second" | sed 's/[^0-9.]//g')"
97114

98115
# Get RPS for bwa-mem
99116
REAL_READ_COUNT="$(cat "${REAL_FASTQ}" | wc -l)"
100117
((REAL_READ_COUNT /= 4))
101118

102-
bwa mem -t "${THREAD_COUNT}" "${FASTA}" "${REAL_FASTQ}" >"${WORK}/mapped.bam" 2>"${WORK}/bwa-log.txt"
103-
cat "${REAL_FASTQ}" "${REAL_FASTQ}" >"${WORK}/double.fq"
104-
bwa mem -t "${THREAD_COUNT}" "${FASTA}" "${WORK}/double.fq" >"${WORK}/mapped-double.bam" 2>"${WORK}/bwa-log-double.txt"
119+
${NUMA_PREFIX} bwa mem -t "${THREAD_COUNT}" "${FASTA}" "${REAL_FASTQ}" >"${WORK}/mapped.bam" 2>"${WORK}/bwa-log.txt"
105120

106-
BWA_TIME="$(cat "${WORK}/bwa-log.txt" | grep "Real time:" | sed 's/[^0-9.]*\([0-9.]*\).*/\1/')"
107-
BWA_DOUBLE_TIME="$(cat "${WORK}/bwa-log-double.txt" | grep "Real time:" | sed 's/[^0-9.]*\([0-9.]*\).*/\1/')"
108-
109-
BWA_RPS="$(echo "${REAL_READ_COUNT} / (${BWA_DOUBLE_TIME} - ${BWA_TIME}) / ${THREAD_COUNT}" | bc -l)"
121+
# Now we get all the batch times from BWA and use those to compute RPS values.
122+
# This is optimistic but hopefully consistent.
123+
BWA_RPS_ALL_THREADS="$(cat "${WORK}/bwa-log.txt" | grep "Processed" | sed 's/[^0-9]*\([0-9]*\) reads in .* CPU sec, \([0-9]*\.[0-9]*\) real sec/\1 \2/g' | tr ' ' '\t' | awk '{sum1+=$1; sum2+=$2} END {print sum1/sum2}')"
110124

125+
BWA_RPS="$(echo "${BWA_RPS_ALL_THREADS} / ${THREAD_COUNT}" | bc -l)"
111126

112127
echo "==== Giraffe Wrangler Report for vg $(vg version -s) ===="
113128

129+
if [[ "${NUMA_WARNING}" == "1" ]] ; then
130+
echo "WARNING! Unable to restrict to a single NUMA node! Results may have high variance!"
131+
fi
132+
114133
if which perf >/dev/null 2>&1 ; then
115134
# Output perf stuff
116135
mv "${WORK}/perf.data" ./perf.data

src/gapless_extender.cpp

+24-8
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ std::vector<GaplessExtension> GaplessExtender::extend(cluster_type& cluster, con
238238
return result;
239239
}
240240

241+
// Allocate a GBWT record cache.
242+
gbwt::CachedGBWT cache = this->graph->get_cache();
243+
241244
// Find either the best extension for each seed or the best full-length alignment
242245
// for the entire cluster. If we have found a full-length alignment with
243246
// at most max_mismatches mismatches, we are no longer interested in extensions with
@@ -260,7 +263,7 @@ std::vector<GaplessExtension> GaplessExtender::extend(cluster_type& cluster, con
260263
size_t read_offset = get_read_offset(seed);
261264
size_t node_offset = get_node_offset(seed);
262265
GaplessExtension match {
263-
{ seed.first }, node_offset, this->graph->get_bd_state(seed.first),
266+
{ seed.first }, node_offset, this->graph->get_bd_state(cache, seed.first),
264267
{ read_offset, read_offset }, { },
265268
static_cast<int32_t>(0), false, false,
266269
false, false, static_cast<uint32_t>(0), static_cast<uint32_t>(0)
@@ -299,7 +302,7 @@ std::vector<GaplessExtension> GaplessExtender::extend(cluster_type& cluster, con
299302

300303
// Case 1: Extend to the right.
301304
if (!curr.right_maximal) {
302-
this->graph->follow_paths(curr.state, false, [&](const gbwt::BidirectionalState& next_state) -> bool {
305+
this->graph->follow_paths(cache, curr.state, false, [&](const gbwt::BidirectionalState& next_state) -> bool {
303306
if (next_state.empty()) {
304307
return true;
305308
}
@@ -340,7 +343,7 @@ std::vector<GaplessExtension> GaplessExtender::extend(cluster_type& cluster, con
340343

341344
// Case 2: Extend to the left.
342345
else if (!curr.left_maximal) {
343-
this->graph->follow_paths(curr.state, true, [&](const gbwt::BidirectionalState& next_state) -> bool {
346+
this->graph->follow_paths(cache, curr.state, true, [&](const gbwt::BidirectionalState& next_state) -> bool {
344347
if (next_state.empty()) {
345348
return true;
346349
}
@@ -409,7 +412,7 @@ std::vector<GaplessExtension> GaplessExtender::extend(cluster_type& cluster, con
409412
remove_duplicates(result);
410413
find_mismatches(sequence, *(this->graph), result);
411414
if (trim_extensions) {
412-
this->trim(result, max_mismatches);
415+
this->trim(result, max_mismatches, &cache);
413416
}
414417

415418
return result;
@@ -419,7 +422,7 @@ std::vector<GaplessExtension> GaplessExtender::extend(cluster_type& cluster, con
419422

420423
// Trim mismatches from the extension to maximize the score. Returns true if the
421424
// extension was trimmed.
422-
bool trim_mismatches(GaplessExtension& extension, const GBWTGraph& graph, const Aligner& aligner) {
425+
bool trim_mismatches(GaplessExtension& extension, const GBWTGraph& graph, const gbwt::CachedGBWT& cache, const Aligner& aligner) {
423426

424427
if (extension.exact()) {
425428
return false;
@@ -512,7 +515,7 @@ bool trim_mismatches(GaplessExtension& extension, const GBWTGraph& graph, const
512515
}
513516
if (head > 0 || tail < extension.path.size()) {
514517
in_place_subvector(extension.path, head, tail);
515-
extension.state = graph.bd_find(extension.path);
518+
extension.state = graph.bd_find(cache, extension.path);
516519
}
517520

518521
// Trim the mismatches.
@@ -529,16 +532,29 @@ bool trim_mismatches(GaplessExtension& extension, const GBWTGraph& graph, const
529532
return true;
530533
}
531534

532-
void GaplessExtender::trim(std::vector<GaplessExtension>& extensions, size_t max_mismatches) const {
535+
void GaplessExtender::trim(std::vector<GaplessExtension>& extensions, size_t max_mismatches, const gbwt::CachedGBWT* cache) const {
536+
537+
// Allocate a cache if we were not provided with one.
538+
bool free_cache = (cache == nullptr);
539+
if (free_cache) {
540+
cache = new gbwt::CachedGBWT(this->graph->get_cache());
541+
}
542+
533543
bool trimmed = false;
534544
for (GaplessExtension& extension : extensions) {
535545
if (!extension.full() || extension.mismatches() > max_mismatches) {
536-
trimmed |= trim_mismatches(extension, *(this->graph), *(this->aligner));
546+
trimmed |= trim_mismatches(extension, *(this->graph), *cache, *(this->aligner));
537547
}
538548
}
539549
if (trimmed) {
540550
remove_duplicates(extensions);
541551
}
552+
553+
// Free the cache if we allocated it.
554+
if (free_cache) {
555+
delete cache;
556+
cache = nullptr;
557+
}
542558
}
543559

544560
//------------------------------------------------------------------------------

src/gapless_extender.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ class GaplessExtender {
141141
/**
142142
* Try to improve the score of each extension by trimming mismatches from the flanks.
143143
* Do not trim full-length alignments with <= max_mismatches mismatches.
144+
* Use the provided CachedGBWT or allocate a new one.
144145
* Note that extend() already calls this by default.
145146
*/
146-
void trim(std::vector<GaplessExtension>& extensions, size_t max_mismatches = MAX_MISMATCHES) const;
147+
void trim(std::vector<GaplessExtension>& extensions, size_t max_mismatches = MAX_MISMATCHES, const gbwt::CachedGBWT* cache = nullptr) const;
147148

148149
const GBWTGraph* graph;
149150
const Aligner* aligner;

src/gbwt_helper.cpp

+61-40
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ id_t GBWTGraph::max_node_id() const {
211211
return next_id - 1;
212212
}
213213

214-
// Using undocumented parts of the GBWT interface. --Jouni
215214
bool GBWTGraph::follow_edges_impl(const handle_t& handle, bool go_left, const std::function<bool(const handle_t&)>& iteratee) const {
216215

217216
// Incoming edges correspond to the outgoing edges of the reverse node.
@@ -220,9 +219,12 @@ bool GBWTGraph::follow_edges_impl(const handle_t& handle, bool go_left, const st
220219
curr = gbwt::Node::reverse(curr);
221220
}
222221

223-
gbwt::CompressedRecord record = this->index->record(curr);
224-
for (gbwt::rank_type outrank = 0; outrank < record.outdegree(); outrank++) {
225-
gbwt::node_type next = record.successor(outrank);
222+
// Cache the node.
223+
gbwt::CachedGBWT cache(*(this->index), true);
224+
gbwt::size_type cache_index = cache.findRecord(curr);
225+
226+
for (gbwt::rank_type outrank = 0; outrank < cache.outdegree(cache_index); outrank++) {
227+
gbwt::node_type next = cache.successor(cache_index, outrank);
226228
if (next == gbwt::ENDMARKER) {
227229
continue;
228230
}
@@ -361,15 +363,47 @@ gbwt::BidirectionalState GBWTGraph::bd_find(const std::vector<handle_t>& path) c
361363
return result;
362364
}
363365

364-
// Using undocumented parts of the GBWT interface. --Jouni
365366
bool GBWTGraph::follow_paths(gbwt::SearchState state, const std::function<bool(const gbwt::SearchState&)>& iteratee) const {
366-
gbwt::CompressedRecord record = this->index->record(state.node);
367-
for (gbwt::rank_type outrank = 0; outrank < record.outdegree(); outrank++) {
368-
gbwt::node_type next_node = record.successor(outrank);
369-
if (next_node == gbwt::ENDMARKER) {
367+
gbwt::CachedGBWT cache(*(this->index), true);
368+
return this->follow_paths(cache, state, iteratee);
369+
}
370+
371+
bool GBWTGraph::follow_paths(gbwt::BidirectionalState state, bool backward, const std::function<bool(const gbwt::BidirectionalState&)>& iteratee) const {
372+
gbwt::CachedGBWT cache(*(this->index), true);
373+
return this->follow_paths(cache, state, backward, iteratee);
374+
}
375+
376+
//------------------------------------------------------------------------------
377+
378+
gbwt::SearchState GBWTGraph::find(const gbwt::CachedGBWT& cache, const std::vector<handle_t>& path) const {
379+
if (path.empty()) {
380+
return gbwt::SearchState();
381+
}
382+
gbwt::SearchState result = this->get_state(cache, path[0]);
383+
for (size_t i = 1; i < path.size() && !result.empty(); i++) {
384+
result = cache.extend(result, handle_to_node(path[i]));
385+
}
386+
return result;
387+
}
388+
389+
gbwt::BidirectionalState GBWTGraph::bd_find(const gbwt::CachedGBWT& cache, const std::vector<handle_t>& path) const {
390+
if (path.empty()) {
391+
return gbwt::BidirectionalState();
392+
}
393+
gbwt::BidirectionalState result = this->get_bd_state(cache, path[0]);
394+
for (size_t i = 1; i < path.size() && !result.empty(); i++) {
395+
result = cache.bdExtendForward(result, handle_to_node(path[i]));
396+
}
397+
return result;
398+
}
399+
400+
bool GBWTGraph::follow_paths(const gbwt::CachedGBWT& cache, gbwt::SearchState state, const std::function<bool(const gbwt::SearchState&)>& iteratee) const {
401+
gbwt::size_type cache_index = cache.findRecord(state.node);
402+
for (gbwt::rank_type outrank = 0; outrank < cache.outdegree(cache_index); outrank++) {
403+
if (cache.successor(cache_index, outrank) == gbwt::ENDMARKER) {
370404
continue;
371405
}
372-
gbwt::SearchState next_state(next_node, record.LF(state.range, next_node));
406+
gbwt::SearchState next_state = cache.cachedExtend(state, cache_index, outrank);
373407
if (!iteratee(next_state)) {
374408
return false;
375409
}
@@ -378,27 +412,13 @@ bool GBWTGraph::follow_paths(gbwt::SearchState state, const std::function<bool(c
378412
return true;
379413
}
380414

381-
// Using undocumented parts of the GBWT interface. --Jouni
382-
bool GBWTGraph::follow_paths(gbwt::BidirectionalState state, bool backward, const std::function<bool(const gbwt::BidirectionalState&)>& iteratee) const {
383-
if (backward) {
384-
state.flip();
385-
}
386-
387-
gbwt::CompressedRecord record = this->index->record(state.forward.node);
388-
for (gbwt::rank_type outrank = 0; outrank < record.outdegree(); outrank++) {
389-
gbwt::node_type next_node = record.successor(outrank);
390-
if (next_node == gbwt::ENDMARKER) {
415+
bool GBWTGraph::follow_paths(const gbwt::CachedGBWT& cache, gbwt::BidirectionalState state, bool backward, const std::function<bool(const gbwt::BidirectionalState&)>& iteratee) const {
416+
gbwt::size_type cache_index = cache.findRecord(backward ? state.backward.node : state.forward.node);
417+
for (gbwt::rank_type outrank = 0; outrank < cache.outdegree(cache_index); outrank++) {
418+
if (cache.successor(cache_index, outrank) == gbwt::ENDMARKER) {
391419
continue;
392420
}
393-
gbwt::size_type reverse_offset = 0;
394-
gbwt::BidirectionalState next_state = state;
395-
next_state.forward.node = next_node;
396-
next_state.forward.range = record.bdLF(state.forward.range, next_node, reverse_offset);
397-
next_state.backward.range.first += reverse_offset;
398-
next_state.backward.range.second = next_state.backward.range.first + next_state.forward.size() - 1;
399-
if (backward) {
400-
next_state.flip();
401-
}
421+
gbwt::BidirectionalState next_state = (backward ? cache.cachedExtendBackward(state, cache_index, outrank) : cache.cachedExtendForward(state, cache_index, outrank));
402422
if (!iteratee(next_state)) {
403423
return false;
404424
}
@@ -443,12 +463,16 @@ void for_each_haplotype_window(const GBWTGraph& graph, size_t window_size,
443463

444464
// Traverse all starting nodes in parallel.
445465
graph.for_each_handle([&](const handle_t& h) -> bool {
466+
467+
// Get a GBWT cache.
468+
gbwt::CachedGBWT cache = graph.get_cache();
469+
446470
// Initialize the stack with both orientations.
447471
std::stack<GBWTTraversal> windows;
448472
size_t node_length = graph.get_length(h);
449473
for (bool is_reverse : { false, true }) {
450474
handle_t handle = (is_reverse ? graph.flip(h) : h);
451-
gbwt::SearchState state = graph.get_state(handle);
475+
gbwt::SearchState state = graph.get_state(cache, handle);
452476
if (state.empty()) {
453477
continue;
454478
}
@@ -467,24 +491,21 @@ void for_each_haplotype_window(const GBWTGraph& graph, size_t window_size,
467491
}
468492

469493
// Try to extend the window to all successor nodes.
470-
// We are using undocumented parts of the GBWT interface. --Jouni
471494
bool extend_success = false;
472-
gbwt::CompressedRecord record = graph.index->record(window.state.node);
473-
for (gbwt::rank_type outrank = 0; outrank < record.outdegree(); outrank++) {
474-
gbwt::node_type next_node = record.successor(outrank);
475-
if (next_node == gbwt::ENDMARKER) {
495+
gbwt::size_type cache_index = cache.findRecord(window.state.node);
496+
for (gbwt::rank_type outrank = 0; outrank < cache.outdegree(cache_index); outrank++) {
497+
if (cache.successor(cache_index, outrank) == gbwt::ENDMARKER) {
476498
continue;
477499
}
478-
gbwt::range_type next_range = record.LF(window.state.range, next_node);
479-
if (gbwt::Range::empty(next_range)) {
500+
gbwt::SearchState next_state = cache.cachedExtend(window.state, cache_index, outrank);
501+
if (next_state.empty()) {
480502
continue;
481503
}
482-
handle_t next_handle = GBWTGraph::node_to_handle(next_node);
504+
handle_t next_handle = GBWTGraph::node_to_handle(next_state.node);
483505
GBWTTraversal next_window = window;
484506
next_window.traversal.push_back(next_handle);
485507
next_window.length += std::min(graph.get_length(next_handle), target_length - window.length);
486-
next_window.state.node = next_node;
487-
next_window.state.range = next_range;
508+
next_window.state = next_state;
488509
windows.push(next_window);
489510
extend_success = true;
490511
}

0 commit comments

Comments
 (0)