Skip to content

Commit e7edd49

Browse files
[NFC][SYCL][Graph] Switch misc shared_ptr<node_impl> to raw ptr/ref (#19487)
1 parent 2e736fb commit e7edd49

File tree

8 files changed

+96
-116
lines changed

8 files changed

+96
-116
lines changed

sycl/source/detail/graph/dynamic_impl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ void dynamic_local_accessor_base::updateLocalAccessor(
9090
->updateLocalAccessor(NewAllocationSize);
9191
}
9292

93+
void dynamic_parameter_impl::registerNode(node_impl &NodeImpl, int ArgIndex) {
94+
MNodes.emplace_back(NodeImpl.weak_from_this(), ArgIndex);
95+
}
96+
9397
void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
9498
size_t Size) {
9599
// Number of bytes is taken from member of raw_kernel_arg object rather

sycl/source/detail/graph/dynamic_impl.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ class dynamic_parameter_impl {
115115
/// @param NodeImpl The node to be registered
116116
/// @param ArgIndex The arg index for the kernel arg associated with this
117117
/// dynamic_parameter in NodeImpl
118-
void registerNode(std::shared_ptr<node_impl> NodeImpl, int ArgIndex) {
119-
MNodes.emplace_back(NodeImpl, ArgIndex);
120-
}
118+
void registerNode(node_impl &NodeImpl, int ArgIndex);
121119

122120
/// Struct detailing an instance of the usage of the dynamic parameter in a
123121
/// dynamic CG.

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 69 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,11 @@ void propagatePartitionUp(node_impl &Node, int PartitionNum) {
142142
/// @param PartitionNum Number to propagate.
143143
/// @param HostTaskList List of host tasks that have already been processed and
144144
/// are encountered as successors to the node Node.
145-
void propagatePartitionDown(
146-
node_impl &Node, int PartitionNum,
147-
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
145+
void propagatePartitionDown(node_impl &Node, int PartitionNum,
146+
std::list<node_impl *> &HostTaskList) {
148147
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
149148
if (Node.MPartitionNum != -1) {
150-
HostTaskList.push_front(Node.shared_from_this());
149+
HostTaskList.push_front(&Node);
151150
}
152151
return;
153152
}
@@ -181,11 +180,11 @@ void partition::updateSchedule() {
181180

182181
void exec_graph_impl::makePartitions() {
183182
int CurrentPartition = -1;
184-
std::list<std::shared_ptr<node_impl>> HostTaskList;
183+
std::list<node_impl *> HostTaskList;
185184
// find all the host-tasks in the graph
186-
for (auto &Node : MNodeStorage) {
187-
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
188-
HostTaskList.push_back(Node);
185+
for (node_impl &Node : nodes()) {
186+
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
187+
HostTaskList.push_back(&Node);
189188
}
190189
}
191190

@@ -215,29 +214,29 @@ void exec_graph_impl::makePartitions() {
215214
// group that includes the predecessor of `B` can be merged with the group of
216215
// the predecessors of the node `A`.
217216
while (HostTaskList.size() > 0) {
218-
auto Node = HostTaskList.front();
217+
node_impl &Node = *HostTaskList.front();
219218
HostTaskList.pop_front();
220219
CurrentPartition++;
221-
for (node_impl &Predecessor : Node->predecessors()) {
220+
for (node_impl &Predecessor : Node.predecessors()) {
222221
propagatePartitionUp(Predecessor, CurrentPartition);
223222
}
224223
CurrentPartition++;
225-
Node->MPartitionNum = CurrentPartition;
224+
Node.MPartitionNum = CurrentPartition;
226225
CurrentPartition++;
227226
auto TmpSize = HostTaskList.size();
228-
for (node_impl &Successor : Node->successors()) {
227+
for (node_impl &Successor : Node.successors()) {
229228
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
230229
}
231230
if (HostTaskList.size() > TmpSize) {
232231
// At least one HostTask has been re-numbered so group merge opportunities
233-
for (const auto &HT : HostTaskList) {
232+
for (node_impl *HT : HostTaskList) {
234233
auto HTPartitionNum = HT->MPartitionNum;
235234
if (HTPartitionNum != -1) {
236235
// can merge predecessors of node `Node` with predecessors of node
237236
// `HT` (HTPartitionNum-1) since HT must be reprocessed
238-
for (const auto &NodeImpl : MNodeStorage) {
239-
if (NodeImpl->MPartitionNum == Node->MPartitionNum - 1) {
240-
NodeImpl->MPartitionNum = HTPartitionNum - 1;
237+
for (node_impl &NodeImpl : nodes()) {
238+
if (NodeImpl.MPartitionNum == Node.MPartitionNum - 1) {
239+
NodeImpl.MPartitionNum = HTPartitionNum - 1;
241240
}
242241
}
243242
} else {
@@ -251,12 +250,12 @@ void exec_graph_impl::makePartitions() {
251250
int PartitionFinalNum = 0;
252251
for (int i = -1; i <= CurrentPartition; i++) {
253252
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
254-
for (auto &Node : MNodeStorage) {
255-
if (Node->MPartitionNum == i) {
256-
MPartitionNodes[Node.get()] = PartitionFinalNum;
257-
if (isPartitionRoot(*Node)) {
258-
Partition->MRoots.insert(Node.get());
259-
if (Node->MCGType == CGType::CodeplayHostTask) {
253+
for (node_impl &Node : nodes()) {
254+
if (Node.MPartitionNum == i) {
255+
MPartitionNodes[&Node] = PartitionFinalNum;
256+
if (isPartitionRoot(Node)) {
257+
Partition->MRoots.insert(&Node);
258+
if (Node.MCGType == CGType::CodeplayHostTask) {
260259
Partition->MIsHostTask = true;
261260
}
262261
}
@@ -295,8 +294,8 @@ void exec_graph_impl::makePartitions() {
295294
}
296295

297296
// Reset node groups (if node have to be re-processed - e.g. subgraph)
298-
for (auto &Node : MNodeStorage) {
299-
Node->MPartitionNum = -1;
297+
for (node_impl &Node : nodes()) {
298+
Node.MPartitionNum = -1;
300299
}
301300
}
302301

@@ -376,19 +375,19 @@ std::set<node_impl *> graph_impl::getCGEdges(
376375
// A unique set of dependencies obtained by checking requirements and events
377376
for (auto &Req : Requirements) {
378377
// Look through the graph for nodes which share this requirement
379-
for (auto &Node : MNodeStorage) {
380-
if (Node->hasRequirementDependency(Req)) {
378+
for (node_impl &Node : nodes()) {
379+
if (Node.hasRequirementDependency(Req)) {
381380
bool ShouldAddDep = true;
382381
// If any of this node's successors have this requirement then we skip
383382
// adding the current node as a dependency.
384-
for (node_impl &Succ : Node->successors()) {
383+
for (node_impl &Succ : Node.successors()) {
385384
if (Succ.hasRequirementDependency(Req)) {
386385
ShouldAddDep = false;
387386
break;
388387
}
389388
}
390389
if (ShouldAddDep) {
391-
UniqueDeps.insert(Node.get());
390+
UniqueDeps.insert(&Node);
392391
}
393392
}
394393
}
@@ -487,7 +486,7 @@ node_impl &graph_impl::add(std::function<void(handler &)> CGF,
487486
}
488487

489488
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
490-
DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex);
489+
DynamicParam->registerNode(NodeImpl, ArgIndex);
491490
}
492491

493492
return NodeImpl;
@@ -611,10 +610,9 @@ void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
611610
MInorderQueueMap[Queue.weak_from_this()] = &Node;
612611
}
613612

614-
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
615-
std::shared_ptr<node_impl> Dest) {
613+
void graph_impl::makeEdge(node_impl &Src, node_impl &Dest) {
616614
throwIfGraphRecordingQueue("make_edge()");
617-
if (Src == Dest) {
615+
if (&Src == &Dest) {
618616
throw sycl::exception(
619617
make_error_code(sycl::errc::invalid),
620618
"make_edge() cannot be called when Src and Dest are the same.");
@@ -624,8 +622,8 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
624622
bool DestFound = false;
625623
for (const auto &Node : MNodeStorage) {
626624

627-
SrcFound |= Node == Src;
628-
DestFound |= Node == Dest;
625+
SrcFound |= Node.get() == &Src;
626+
DestFound |= Node.get() == &Dest;
629627

630628
if (SrcFound && DestFound) {
631629
break;
@@ -641,49 +639,49 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
641639
"Dest must be a node inside the graph.");
642640
}
643641

644-
bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;
642+
bool DestWasGraphRoot = Dest.MPredecessors.size() == 0;
645643

646644
// We need to add the edges first before checking for cycles
647-
Src->registerSuccessor(*Dest);
645+
Src.registerSuccessor(Dest);
648646

649-
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
647+
bool DestLostRootStatus = DestWasGraphRoot && Dest.MPredecessors.size() == 1;
650648
if (DestLostRootStatus) {
651649
// Dest is no longer a Root node, so we need to remove it from MRoots.
652-
MRoots.erase(Dest.get());
650+
MRoots.erase(&Dest);
653651
}
654652

655653
// We can skip cycle checks if either Dest has no successors (cycle not
656654
// possible) or cycle checks have been disabled with the no_cycle_check
657655
// property;
658-
if (Dest->MSuccessors.empty() || !MSkipCycleChecks) {
656+
if (Dest.MSuccessors.empty() || !MSkipCycleChecks) {
659657
bool CycleFound = checkForCycles();
660658

661659
if (CycleFound) {
662660
// Remove the added successor and predecessor.
663-
Src->MSuccessors.pop_back();
664-
Dest->MPredecessors.pop_back();
661+
Src.MSuccessors.pop_back();
662+
Dest.MPredecessors.pop_back();
665663
if (DestLostRootStatus) {
666664
// Add Dest back into MRoots.
667-
MRoots.insert(Dest.get());
665+
MRoots.insert(&Dest);
668666
}
669667

670668
throw sycl::exception(make_error_code(sycl::errc::invalid),
671669
"Command graphs cannot contain cycles.");
672670
}
673671
}
674-
removeRoot(*Dest); // remove receiver from root node list
672+
removeRoot(Dest); // remove receiver from root node list
675673
}
676674

677675
std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
678676
std::weak_ptr<sycl::detail::queue_impl> RecordedQueue) {
679677
std::vector<sycl::detail::EventImplPtr> Events;
680678

681679
auto RecordedQueueSP = RecordedQueue.lock();
682-
for (auto &Node : MNodeStorage) {
683-
if (Node->MSuccessors.empty()) {
684-
auto EventForNode = getEventForNode(*Node);
680+
for (node_impl &Node : nodes()) {
681+
if (Node.MSuccessors.empty()) {
682+
auto EventForNode = getEventForNode(Node);
685683
if (EventForNode->getSubmittedQueue() == RecordedQueueSP) {
686-
Events.push_back(getEventForNode(*Node));
684+
Events.push_back(getEventForNode(Node));
687685
}
688686
}
689687
}
@@ -1433,15 +1431,14 @@ void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
14331431
std::make_pair(GraphImpl->MNodeStorage[i]->MID, MNodeStorage[i].get()));
14341432
}
14351433

1436-
update(GraphImpl->MNodeStorage);
1434+
update(GraphImpl->nodes());
14371435
}
14381436

1439-
void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
1440-
this->update(std::vector<std::shared_ptr<node_impl>>{Node});
1437+
void exec_graph_impl::update(node_impl &Node) {
1438+
this->update(std::vector<node_impl *>{&Node});
14411439
}
14421440

1443-
void exec_graph_impl::update(
1444-
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1441+
void exec_graph_impl::update(nodes_range Nodes) {
14451442
if (!MIsUpdatable) {
14461443
throw sycl::exception(sycl::make_error_code(errc::invalid),
14471444
"update() cannot be called on a executable graph "
@@ -1502,7 +1499,7 @@ void exec_graph_impl::update(
15021499
}
15031500

15041501
bool exec_graph_impl::needsScheduledUpdate(
1505-
const std::vector<std::shared_ptr<node_impl>> &Nodes,
1502+
nodes_range Nodes,
15061503
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
15071504
// If there are any accessor requirements, we have to update through the
15081505
// scheduler to ensure that any allocations have taken place before trying to
@@ -1511,30 +1508,30 @@ bool exec_graph_impl::needsScheduledUpdate(
15111508
// At worst we may have as many requirements as there are for the entire graph
15121509
// for updating.
15131510
UpdateRequirements.reserve(MRequirements.size());
1514-
for (auto &Node : Nodes) {
1511+
for (node_impl &Node : Nodes) {
15151512
// Check if node(s) derived from this modifiable node exists in this graph
1516-
if (MIDCache.count(Node->getID()) == 0) {
1513+
if (MIDCache.count(Node.getID()) == 0) {
15171514
throw sycl::exception(
15181515
sycl::make_error_code(errc::invalid),
15191516
"Node passed to update() is not part of the graph.");
15201517
}
15211518

1522-
if (!Node->isUpdatable()) {
1519+
if (!Node.isUpdatable()) {
15231520
std::string ErrorString = "node_type::";
1524-
ErrorString += nodeTypeToString(Node->MNodeType);
1521+
ErrorString += nodeTypeToString(Node.MNodeType);
15251522
ErrorString +=
15261523
" nodes are not supported for update. Only kernel, host_task, "
15271524
"barrier and empty nodes are supported.";
15281525
throw sycl::exception(errc::invalid, ErrorString);
15291526
}
15301527

1531-
if (const auto &CG = Node->MCommandGroup;
1528+
if (const auto &CG = Node.MCommandGroup;
15321529
CG && CG->getRequirements().size() != 0) {
15331530
NeedScheduledUpdate = true;
15341531

15351532
UpdateRequirements.insert(UpdateRequirements.end(),
1536-
Node->MCommandGroup->getRequirements().begin(),
1537-
Node->MCommandGroup->getRequirements().end());
1533+
Node.MCommandGroup->getRequirements().begin(),
1534+
Node.MCommandGroup->getRequirements().end());
15381535
}
15391536
}
15401537

@@ -1740,18 +1737,17 @@ exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const {
17401737
return PartitionedNodes;
17411738
}
17421739

1743-
void exec_graph_impl::updateHostTasksImpl(
1744-
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1745-
for (auto &Node : Nodes) {
1746-
if (Node->MNodeType != node_type::host_task) {
1740+
void exec_graph_impl::updateHostTasksImpl(nodes_range Nodes) const {
1741+
for (node_impl &Node : Nodes) {
1742+
if (Node.MNodeType != node_type::host_task) {
17471743
continue;
17481744
}
17491745
// Query the ID cache to find the equivalent exec node for the node passed
17501746
// to this function.
1751-
auto ExecNode = MIDCache.find(Node->MID);
1747+
auto ExecNode = MIDCache.find(Node.MID);
17521748
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17531749

1754-
ExecNode->second->updateFromOtherNode(*Node);
1750+
ExecNode->second->updateFromOtherNode(Node);
17551751
}
17561752
}
17571753

@@ -1852,21 +1848,18 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
18521848
void modifiable_command_graph::addGraphLeafDependencies(node Node) {
18531849
// Find all exit nodes in the current graph and add them to the dependency
18541850
// vector
1855-
std::shared_ptr<detail::node_impl> DstImpl =
1856-
sycl::detail::getSyclObjImpl(Node);
1851+
detail::node_impl &DstImpl = *sycl::detail::getSyclObjImpl(Node);
18571852
graph_impl::WriteLock Lock(impl->MMutex);
18581853
for (auto &NodeImpl : impl->MNodeStorage) {
1859-
if ((NodeImpl->MSuccessors.size() == 0) && (NodeImpl != DstImpl)) {
1860-
impl->makeEdge(NodeImpl, DstImpl);
1854+
if ((NodeImpl->MSuccessors.size() == 0) && (NodeImpl.get() != &DstImpl)) {
1855+
impl->makeEdge(*NodeImpl, DstImpl);
18611856
}
18621857
}
18631858
}
18641859

18651860
void modifiable_command_graph::make_edge(node &Src, node &Dest) {
1866-
std::shared_ptr<detail::node_impl> SenderImpl =
1867-
sycl::detail::getSyclObjImpl(Src);
1868-
std::shared_ptr<detail::node_impl> ReceiverImpl =
1869-
sycl::detail::getSyclObjImpl(Dest);
1861+
detail::node_impl &SenderImpl = *sycl::detail::getSyclObjImpl(Src);
1862+
detail::node_impl &ReceiverImpl = *sycl::detail::getSyclObjImpl(Dest);
18701863

18711864
graph_impl::WriteLock Lock(impl->MMutex);
18721865
impl->makeEdge(SenderImpl, ReceiverImpl);
@@ -2030,17 +2023,11 @@ void executable_command_graph::update(
20302023
}
20312024

20322025
void executable_command_graph::update(const node &Node) {
2033-
impl->update(sycl::detail::getSyclObjImpl(Node));
2026+
impl->update(*sycl::detail::getSyclObjImpl(Node));
20342027
}
20352028

20362029
void executable_command_graph::update(const std::vector<node> &Nodes) {
2037-
std::vector<std::shared_ptr<node_impl>> NodeImpls{};
2038-
NodeImpls.reserve(Nodes.size());
2039-
for (auto &Node : Nodes) {
2040-
NodeImpls.push_back(sycl::detail::getSyclObjImpl(Node));
2041-
}
2042-
2043-
impl->update(NodeImpls);
2030+
impl->update(Nodes);
20442031
}
20452032

20462033
size_t executable_command_graph::get_required_mem_size() const {

0 commit comments

Comments
 (0)