Skip to content

Commit f81caa3

Browse files
[NFC][SYCL][Graph] Add iterator_range::to<Container>() helper (#19431)
1 parent 58f4da6 commit f81caa3

File tree

9 files changed

+81
-56
lines changed

9 files changed

+81
-56
lines changed

sycl/source/detail/device_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2295,7 +2295,7 @@ struct devices_deref_impl {
22952295
}
22962296
};
22972297
using devices_iterator =
2298-
variadic_iterator<devices_deref_impl,
2298+
variadic_iterator<devices_deref_impl, device,
22992299
std::vector<std::shared_ptr<device_impl>>::const_iterator,
23002300
std::vector<device>::const_iterator, device_impl *>;
23012301

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,15 +1970,11 @@ void modifiable_command_graph::print_graph(sycl::detail::string_view pathstr,
19701970

19711971
std::vector<node> modifiable_command_graph::get_nodes() const {
19721972
graph_impl::ReadLock Lock(impl->MMutex);
1973-
return createNodesFromImpls(impl->MNodeStorage);
1973+
return impl->nodes().to<std::vector<node>>();
19741974
}
19751975
std::vector<node> modifiable_command_graph::get_root_nodes() const {
19761976
graph_impl::ReadLock Lock(impl->MMutex);
1977-
auto &Roots = impl->MRoots;
1978-
std::vector<node_impl *> Impls{};
1979-
1980-
std::copy(Roots.begin(), Roots.end(), std::back_inserter(Impls));
1981-
return createNodesFromImpls(Impls);
1977+
return impl->roots().to<std::vector<node>>();
19821978
}
19831979

19841980
void modifiable_command_graph::checkNodePropertiesAndThrow(

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
293293
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
294294

295295
nodes_range roots() const { return MRoots; }
296+
nodes_range nodes() const { return MNodeStorage; }
296297

297298
/// Find the last node added to this graph from an in-order queue.
298299
/// @param Queue In-order queue to find the last node added to the graph from.
@@ -665,6 +666,8 @@ class exec_graph_impl {
665666
return MPartitions;
666667
}
667668

669+
nodes_range nodes() const { return MNodeStorage; }
670+
668671
/// Query whether the graph contains any host-task nodes.
669672
/// @return True if the graph contains any host-task nodes. False otherwise.
670673
bool containsHostTask() const { return MContainsHostTask; }

sycl/source/detail/graph/memory_pool.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,8 @@ graph_mem_pool::tryReuseExistingAllocation(size_t Size, usm::alloc AllocType,
114114
// free nodes. We do this in a breadth-first approach because we want to find
115115
// the shortest path to a reusable allocation.
116116

117-
std::queue<node_impl *> NodesToCheck;
118-
119117
// Add all the dependent nodes to the queue, they will be popped first
120-
for (node_impl &Dep : DepNodes) {
121-
NodesToCheck.push(&Dep);
122-
}
118+
auto NodesToCheck = DepNodes.to<std::queue<node_impl *>>();
123119

124120
// Called when traversing over nodes to check if the current node is a free
125121
// node for one of the available allocations. If it is we populate AllocInfo

sycl/source/detail/graph/node_impl.cpp

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,14 @@ inline namespace _V1 {
1515
namespace ext {
1616
namespace oneapi {
1717
namespace experimental {
18-
namespace detail {
19-
20-
/// Takes a vector of weak_ptrs to node_impls and returns a vector of node
21-
/// objects created from those impls, in the same order.
22-
std::vector<node> createNodesFromImpls(
23-
const std::vector<std::weak_ptr<detail::node_impl>> &Impls) {
24-
std::vector<node> Nodes{};
25-
Nodes.reserve(Impls.size());
26-
27-
for (std::weak_ptr<detail::node_impl> Impl : Impls) {
28-
Nodes.push_back(sycl::detail::createSyclObjFromImpl<node>(Impl.lock()));
29-
}
30-
31-
return Nodes;
32-
}
33-
34-
std::vector<node> createNodesFromImpls(nodes_range Impls) {
35-
std::vector<node> Nodes{};
36-
Nodes.reserve(Impls.size());
37-
38-
for (detail::node_impl &Impl : Impls) {
39-
Nodes.push_back(sycl::detail::createSyclObjFromImpl<node>(Impl));
40-
}
41-
42-
return Nodes;
43-
}
44-
} // namespace detail
45-
4618
node_type node::get_type() const { return impl->MNodeType; }
4719

4820
std::vector<node> node::get_predecessors() const {
49-
return detail::createNodesFromImpls(impl->MPredecessors);
21+
return impl->predecessors().to<std::vector<node>>();
5022
}
5123

5224
std::vector<node> node::get_successors() const {
53-
return detail::createNodesFromImpls(impl->MSuccessors);
25+
return impl->successors().to<std::vector<node>>();
5426
}
5527

5628
node node::get_node_from_event(event nodeEvent) {

sycl/source/detail/graph/node_impl.hpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,6 @@ class node_impl;
3636
class nodes_range;
3737
class exec_graph_impl;
3838

39-
/// Takes a vector of weak_ptrs to node_impls and returns a vector of node
40-
/// objects created from those impls, in the same order.
41-
std::vector<node>
42-
createNodesFromImpls(const std::vector<std::weak_ptr<node_impl>> &Impls);
43-
44-
std::vector<node> createNodesFromImpls(nodes_range Impls);
45-
4639
inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
4740
using sycl::detail::CG;
4841

@@ -780,7 +773,7 @@ struct nodes_deref_impl {
780773

781774
template <typename... ContainerTy>
782775
using nodes_iterator_impl =
783-
variadic_iterator<nodes_deref_impl,
776+
variadic_iterator<nodes_deref_impl, node,
784777
typename ContainerTy::const_iterator...>;
785778

786779
using nodes_iterator = nodes_iterator_impl<

sycl/source/detail/helpers.hpp

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88

99
#pragma once
1010

11+
#include <sycl/detail/impl_utils.hpp>
1112
#include <sycl/detail/kernel_name_str_t.hpp>
13+
#include <sycl/detail/type_traits.hpp>
1214

1315
#include <ur_api.h>
1416

17+
#include <algorithm>
18+
#include <iterator>
1519
#include <memory>
20+
#include <queue>
1621
#include <tuple>
1722
#include <variant>
1823
#include <vector>
@@ -30,13 +35,26 @@ const RTDeviceBinaryImage *
3035
retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
3136
CGExecKernel *CGKernel = nullptr);
3237

33-
template <typename DereferenceImpl, typename... Iterators>
38+
template <typename DereferenceImpl, typename SyclTy, typename... Iterators>
3439
class variadic_iterator {
3540
using storage_iter = std::variant<Iterators...>;
3641

3742
storage_iter It;
3843

3944
public:
45+
using iterator_category = std::forward_iterator_tag;
46+
using difference_type = std::ptrdiff_t;
47+
using reference = decltype(DereferenceImpl::dereference(
48+
*std::declval<nth_type_t<0, Iterators...>>()));
49+
using value_type = std::remove_reference_t<reference>;
50+
using sycl_type = SyclTy;
51+
using pointer = value_type *;
52+
static_assert(std::is_same_v<reference, value_type &>);
53+
54+
variadic_iterator(const variadic_iterator &) = default;
55+
variadic_iterator(variadic_iterator &&) = default;
56+
variadic_iterator(variadic_iterator &) = default;
57+
4058
template <typename IterTy>
4159
variadic_iterator(IterTy &&It) : It(std::forward<IterTy>(It)) {}
4260

@@ -52,6 +70,9 @@ class variadic_iterator {
5270
bool operator!=(const variadic_iterator &Other) const {
5371
return It != Other.It;
5472
}
73+
bool operator==(const variadic_iterator &Other) const {
74+
return It == Other.It;
75+
}
5576

5677
decltype(auto) operator*() {
5778
return std::visit(
@@ -64,6 +85,16 @@ class variadic_iterator {
6485

6586
// Non-owning!
6687
template <typename iterator> class iterator_range {
88+
using value_type = typename iterator::value_type;
89+
using sycl_type = typename iterator::sycl_type;
90+
91+
template <typename Container, typename = void>
92+
struct has_reserve : public std::false_type {};
93+
template <typename Container>
94+
struct has_reserve<
95+
Container, std::void_t<decltype(std::declval<Container>().reserve(1))>>
96+
: public std::true_type {};
97+
6798
public:
6899
iterator_range(const iterator_range &Other) = default;
69100

@@ -81,6 +112,40 @@ template <typename iterator> class iterator_range {
81112
bool empty() const { return Size == 0; }
82113
decltype(auto) front() const { return *begin(); }
83114

115+
template <typename Container>
116+
std::enable_if_t<
117+
check_type_in_v<Container, std::vector<sycl_type>,
118+
std::queue<value_type *>, std::vector<value_type *>,
119+
std::vector<std::shared_ptr<value_type>>>,
120+
Container>
121+
to() const {
122+
std::conditional_t<std::is_same_v<Container, std::queue<value_type *>>,
123+
typename std::queue<value_type *>::container_type,
124+
Container>
125+
Result;
126+
if constexpr (has_reserve<decltype(Result)>::value)
127+
Result.reserve(size());
128+
std::transform(
129+
begin(), end(), std::back_inserter(Result), [](value_type &E) {
130+
if constexpr (std::is_same_v<Container, std::vector<sycl_type>>)
131+
return createSyclObjFromImpl<sycl_type>(E);
132+
else if constexpr (std::is_same_v<
133+
Container,
134+
std::vector<std::shared_ptr<value_type>>>)
135+
return E.shared_from_this();
136+
else
137+
return &E;
138+
});
139+
if constexpr (std::is_same_v<Container, decltype(Result)>)
140+
return Result;
141+
else
142+
return Container{std::move(Result)};
143+
}
144+
145+
protected:
146+
template <typename Container>
147+
static constexpr bool has_reserve_v = has_reserve<Container>::value;
148+
84149
private:
85150
iterator Begin;
86151
iterator End;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3802,10 +3802,10 @@ bool ExecCGCommand::readyForCleanup() const {
38023802
UpdateCommandBufferCommand::UpdateCommandBufferCommand(
38033803
queue_impl *Queue,
38043804
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
3805-
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
3806-
Nodes)
3805+
ext::oneapi::experimental::detail::nodes_range Nodes)
38073806
: Command(CommandType::UPDATE_CMD_BUFFER, Queue), MGraph(Graph),
3808-
MNodes(std::move(Nodes)) {}
3807+
MNodes(Nodes.to<std::vector<std::shared_ptr<
3808+
ext::oneapi::experimental::detail::node_impl>>>()) {}
38093809

38103810
ur_result_t UpdateCommandBufferCommand::enqueueImp() {
38113811
waitForPreparedHostEvents();

sycl/source/detail/scheduler/commands.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ inline namespace _V1 {
3333
namespace ext::oneapi::experimental::detail {
3434
class exec_graph_impl;
3535
class node_impl;
36+
class nodes_range;
3637
} // namespace ext::oneapi::experimental::detail
3738
namespace detail {
3839

@@ -723,8 +724,7 @@ class UpdateCommandBufferCommand : public Command {
723724
explicit UpdateCommandBufferCommand(
724725
queue_impl *Queue,
725726
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
726-
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
727-
Nodes);
727+
ext::oneapi::experimental::detail::nodes_range Nodes);
728728

729729
void printDot(std::ostream &Stream) const final;
730730
void emitInstrumentationData() final;

0 commit comments

Comments
 (0)