@@ -142,12 +142,11 @@ void propagatePartitionUp(node_impl &Node, int PartitionNum) {
142
142
// / @param PartitionNum Number to propagate.
143
143
// / @param HostTaskList List of host tasks that have already been processed and
144
144
// / are encountered as successors to the node Node.
145
- void propagatePartitionDown (
146
- node_impl &Node, int PartitionNum,
147
- std::list<std::shared_ptr<node_impl>> &HostTaskList) {
145
+ void propagatePartitionDown (node_impl &Node, int PartitionNum,
146
+ std::list<node_impl *> &HostTaskList) {
148
147
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
149
148
if (Node.MPartitionNum != -1 ) {
150
- HostTaskList.push_front (Node. shared_from_this () );
149
+ HostTaskList.push_front (& Node);
151
150
}
152
151
return ;
153
152
}
@@ -181,11 +180,11 @@ void partition::updateSchedule() {
181
180
182
181
void exec_graph_impl::makePartitions () {
183
182
int CurrentPartition = -1 ;
184
- std::list<std::shared_ptr< node_impl> > HostTaskList;
183
+ std::list<node_impl * > HostTaskList;
185
184
// find all the host-tasks in the graph
186
- for (auto &Node : MNodeStorage ) {
187
- if (Node-> MCGType == sycl::detail::CGType::CodeplayHostTask) {
188
- HostTaskList.push_back (Node);
185
+ for (node_impl &Node : nodes () ) {
186
+ if (Node. MCGType == sycl::detail::CGType::CodeplayHostTask) {
187
+ HostTaskList.push_back (& Node);
189
188
}
190
189
}
191
190
@@ -215,29 +214,29 @@ void exec_graph_impl::makePartitions() {
215
214
// group that includes the predecessor of `B` can be merged with the group of
216
215
// the predecessors of the node `A`.
217
216
while (HostTaskList.size () > 0 ) {
218
- auto Node = HostTaskList.front ();
217
+ node_impl & Node = * HostTaskList.front ();
219
218
HostTaskList.pop_front ();
220
219
CurrentPartition++;
221
- for (node_impl &Predecessor : Node-> predecessors ()) {
220
+ for (node_impl &Predecessor : Node. predecessors ()) {
222
221
propagatePartitionUp (Predecessor, CurrentPartition);
223
222
}
224
223
CurrentPartition++;
225
- Node-> MPartitionNum = CurrentPartition;
224
+ Node. MPartitionNum = CurrentPartition;
226
225
CurrentPartition++;
227
226
auto TmpSize = HostTaskList.size ();
228
- for (node_impl &Successor : Node-> successors ()) {
227
+ for (node_impl &Successor : Node. successors ()) {
229
228
propagatePartitionDown (Successor, CurrentPartition, HostTaskList);
230
229
}
231
230
if (HostTaskList.size () > TmpSize) {
232
231
// At least one HostTask has been re-numbered so group merge opportunities
233
- for (const auto & HT : HostTaskList) {
232
+ for (node_impl * HT : HostTaskList) {
234
233
auto HTPartitionNum = HT->MPartitionNum ;
235
234
if (HTPartitionNum != -1 ) {
236
235
// can merge predecessors of node `Node` with predecessors of node
237
236
// `HT` (HTPartitionNum-1) since HT must be reprocessed
238
- for (const auto &NodeImpl : MNodeStorage ) {
239
- if (NodeImpl-> MPartitionNum == Node-> MPartitionNum - 1 ) {
240
- NodeImpl-> MPartitionNum = HTPartitionNum - 1 ;
237
+ for (node_impl &NodeImpl : nodes () ) {
238
+ if (NodeImpl. MPartitionNum == Node. MPartitionNum - 1 ) {
239
+ NodeImpl. MPartitionNum = HTPartitionNum - 1 ;
241
240
}
242
241
}
243
242
} else {
@@ -251,12 +250,12 @@ void exec_graph_impl::makePartitions() {
251
250
int PartitionFinalNum = 0 ;
252
251
for (int i = -1 ; i <= CurrentPartition; i++) {
253
252
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
254
- for (auto &Node : MNodeStorage ) {
255
- if (Node-> MPartitionNum == i) {
256
- MPartitionNodes[Node. get () ] = PartitionFinalNum;
257
- if (isPartitionRoot (* Node)) {
258
- Partition->MRoots .insert (Node. get () );
259
- if (Node-> MCGType == CGType::CodeplayHostTask) {
253
+ for (node_impl &Node : nodes () ) {
254
+ if (Node. MPartitionNum == i) {
255
+ MPartitionNodes[& Node] = PartitionFinalNum;
256
+ if (isPartitionRoot (Node)) {
257
+ Partition->MRoots .insert (& Node);
258
+ if (Node. MCGType == CGType::CodeplayHostTask) {
260
259
Partition->MIsHostTask = true ;
261
260
}
262
261
}
@@ -295,8 +294,8 @@ void exec_graph_impl::makePartitions() {
295
294
}
296
295
297
296
// Reset node groups (if node have to be re-processed - e.g. subgraph)
298
- for (auto &Node : MNodeStorage ) {
299
- Node-> MPartitionNum = -1 ;
297
+ for (node_impl &Node : nodes () ) {
298
+ Node. MPartitionNum = -1 ;
300
299
}
301
300
}
302
301
@@ -376,19 +375,19 @@ std::set<node_impl *> graph_impl::getCGEdges(
376
375
// A unique set of dependencies obtained by checking requirements and events
377
376
for (auto &Req : Requirements) {
378
377
// Look through the graph for nodes which share this requirement
379
- for (auto &Node : MNodeStorage ) {
380
- if (Node-> hasRequirementDependency (Req)) {
378
+ for (node_impl &Node : nodes () ) {
379
+ if (Node. hasRequirementDependency (Req)) {
381
380
bool ShouldAddDep = true ;
382
381
// If any of this node's successors have this requirement then we skip
383
382
// adding the current node as a dependency.
384
- for (node_impl &Succ : Node-> successors ()) {
383
+ for (node_impl &Succ : Node. successors ()) {
385
384
if (Succ.hasRequirementDependency (Req)) {
386
385
ShouldAddDep = false ;
387
386
break ;
388
387
}
389
388
}
390
389
if (ShouldAddDep) {
391
- UniqueDeps.insert (Node. get () );
390
+ UniqueDeps.insert (& Node);
392
391
}
393
392
}
394
393
}
@@ -487,7 +486,7 @@ node_impl &graph_impl::add(std::function<void(handler &)> CGF,
487
486
}
488
487
489
488
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
490
- DynamicParam->registerNode (NodeImpl. shared_from_this () , ArgIndex);
489
+ DynamicParam->registerNode (NodeImpl, ArgIndex);
491
490
}
492
491
493
492
return NodeImpl;
@@ -611,10 +610,9 @@ void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
611
610
MInorderQueueMap[Queue.weak_from_this ()] = &Node;
612
611
}
613
612
614
- void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
615
- std::shared_ptr<node_impl> Dest) {
613
+ void graph_impl::makeEdge (node_impl &Src, node_impl &Dest) {
616
614
throwIfGraphRecordingQueue (" make_edge()" );
617
- if (Src == Dest) {
615
+ if (& Src == & Dest) {
618
616
throw sycl::exception (
619
617
make_error_code (sycl::errc::invalid),
620
618
" make_edge() cannot be called when Src and Dest are the same." );
@@ -624,8 +622,8 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
624
622
bool DestFound = false ;
625
623
for (const auto &Node : MNodeStorage) {
626
624
627
- SrcFound |= Node == Src;
628
- DestFound |= Node == Dest;
625
+ SrcFound |= Node. get () == & Src;
626
+ DestFound |= Node. get () == & Dest;
629
627
630
628
if (SrcFound && DestFound) {
631
629
break ;
@@ -641,49 +639,49 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
641
639
" Dest must be a node inside the graph." );
642
640
}
643
641
644
- bool DestWasGraphRoot = Dest-> MPredecessors .size () == 0 ;
642
+ bool DestWasGraphRoot = Dest. MPredecessors .size () == 0 ;
645
643
646
644
// We need to add the edges first before checking for cycles
647
- Src-> registerSuccessor (* Dest);
645
+ Src. registerSuccessor (Dest);
648
646
649
- bool DestLostRootStatus = DestWasGraphRoot && Dest-> MPredecessors .size () == 1 ;
647
+ bool DestLostRootStatus = DestWasGraphRoot && Dest. MPredecessors .size () == 1 ;
650
648
if (DestLostRootStatus) {
651
649
// Dest is no longer a Root node, so we need to remove it from MRoots.
652
- MRoots.erase (Dest. get () );
650
+ MRoots.erase (& Dest);
653
651
}
654
652
655
653
// We can skip cycle checks if either Dest has no successors (cycle not
656
654
// possible) or cycle checks have been disabled with the no_cycle_check
657
655
// property;
658
- if (Dest-> MSuccessors .empty () || !MSkipCycleChecks) {
656
+ if (Dest. MSuccessors .empty () || !MSkipCycleChecks) {
659
657
bool CycleFound = checkForCycles ();
660
658
661
659
if (CycleFound) {
662
660
// Remove the added successor and predecessor.
663
- Src-> MSuccessors .pop_back ();
664
- Dest-> MPredecessors .pop_back ();
661
+ Src. MSuccessors .pop_back ();
662
+ Dest. MPredecessors .pop_back ();
665
663
if (DestLostRootStatus) {
666
664
// Add Dest back into MRoots.
667
- MRoots.insert (Dest. get () );
665
+ MRoots.insert (& Dest);
668
666
}
669
667
670
668
throw sycl::exception (make_error_code (sycl::errc::invalid),
671
669
" Command graphs cannot contain cycles." );
672
670
}
673
671
}
674
- removeRoot (* Dest); // remove receiver from root node list
672
+ removeRoot (Dest); // remove receiver from root node list
675
673
}
676
674
677
675
std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents (
678
676
std::weak_ptr<sycl::detail::queue_impl> RecordedQueue) {
679
677
std::vector<sycl::detail::EventImplPtr> Events;
680
678
681
679
auto RecordedQueueSP = RecordedQueue.lock ();
682
- for (auto &Node : MNodeStorage ) {
683
- if (Node-> MSuccessors .empty ()) {
684
- auto EventForNode = getEventForNode (* Node);
680
+ for (node_impl &Node : nodes () ) {
681
+ if (Node. MSuccessors .empty ()) {
682
+ auto EventForNode = getEventForNode (Node);
685
683
if (EventForNode->getSubmittedQueue () == RecordedQueueSP) {
686
- Events.push_back (getEventForNode (* Node));
684
+ Events.push_back (getEventForNode (Node));
687
685
}
688
686
}
689
687
}
@@ -1433,15 +1431,14 @@ void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
1433
1431
std::make_pair (GraphImpl->MNodeStorage [i]->MID , MNodeStorage[i].get ()));
1434
1432
}
1435
1433
1436
- update (GraphImpl->MNodeStorage );
1434
+ update (GraphImpl->nodes () );
1437
1435
}
1438
1436
1439
- void exec_graph_impl::update (std::shared_ptr< node_impl> Node) {
1440
- this ->update (std::vector<std::shared_ptr< node_impl>>{ Node});
1437
+ void exec_graph_impl::update (node_impl & Node) {
1438
+ this ->update (std::vector<node_impl *>{& Node});
1441
1439
}
1442
1440
1443
- void exec_graph_impl::update (
1444
- const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1441
+ void exec_graph_impl::update (nodes_range Nodes) {
1445
1442
if (!MIsUpdatable) {
1446
1443
throw sycl::exception (sycl::make_error_code (errc::invalid),
1447
1444
" update() cannot be called on a executable graph "
@@ -1502,7 +1499,7 @@ void exec_graph_impl::update(
1502
1499
}
1503
1500
1504
1501
bool exec_graph_impl::needsScheduledUpdate (
1505
- const std::vector<std::shared_ptr<node_impl>> & Nodes,
1502
+ nodes_range Nodes,
1506
1503
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
1507
1504
// If there are any accessor requirements, we have to update through the
1508
1505
// scheduler to ensure that any allocations have taken place before trying to
@@ -1511,30 +1508,30 @@ bool exec_graph_impl::needsScheduledUpdate(
1511
1508
// At worst we may have as many requirements as there are for the entire graph
1512
1509
// for updating.
1513
1510
UpdateRequirements.reserve (MRequirements.size ());
1514
- for (auto &Node : Nodes) {
1511
+ for (node_impl &Node : Nodes) {
1515
1512
// Check if node(s) derived from this modifiable node exists in this graph
1516
- if (MIDCache.count (Node-> getID ()) == 0 ) {
1513
+ if (MIDCache.count (Node. getID ()) == 0 ) {
1517
1514
throw sycl::exception (
1518
1515
sycl::make_error_code (errc::invalid),
1519
1516
" Node passed to update() is not part of the graph." );
1520
1517
}
1521
1518
1522
- if (!Node-> isUpdatable ()) {
1519
+ if (!Node. isUpdatable ()) {
1523
1520
std::string ErrorString = " node_type::" ;
1524
- ErrorString += nodeTypeToString (Node-> MNodeType );
1521
+ ErrorString += nodeTypeToString (Node. MNodeType );
1525
1522
ErrorString +=
1526
1523
" nodes are not supported for update. Only kernel, host_task, "
1527
1524
" barrier and empty nodes are supported." ;
1528
1525
throw sycl::exception (errc::invalid, ErrorString);
1529
1526
}
1530
1527
1531
- if (const auto &CG = Node-> MCommandGroup ;
1528
+ if (const auto &CG = Node. MCommandGroup ;
1532
1529
CG && CG->getRequirements ().size () != 0 ) {
1533
1530
NeedScheduledUpdate = true ;
1534
1531
1535
1532
UpdateRequirements.insert (UpdateRequirements.end (),
1536
- Node-> MCommandGroup ->getRequirements ().begin (),
1537
- Node-> MCommandGroup ->getRequirements ().end ());
1533
+ Node. MCommandGroup ->getRequirements ().begin (),
1534
+ Node. MCommandGroup ->getRequirements ().end ());
1538
1535
}
1539
1536
}
1540
1537
@@ -1740,18 +1737,17 @@ exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const {
1740
1737
return PartitionedNodes;
1741
1738
}
1742
1739
1743
- void exec_graph_impl::updateHostTasksImpl (
1744
- const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1745
- for (auto &Node : Nodes) {
1746
- if (Node->MNodeType != node_type::host_task) {
1740
+ void exec_graph_impl::updateHostTasksImpl (nodes_range Nodes) const {
1741
+ for (node_impl &Node : Nodes) {
1742
+ if (Node.MNodeType != node_type::host_task) {
1747
1743
continue ;
1748
1744
}
1749
1745
// Query the ID cache to find the equivalent exec node for the node passed
1750
1746
// to this function.
1751
- auto ExecNode = MIDCache.find (Node-> MID );
1747
+ auto ExecNode = MIDCache.find (Node. MID );
1752
1748
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1753
1749
1754
- ExecNode->second ->updateFromOtherNode (* Node);
1750
+ ExecNode->second ->updateFromOtherNode (Node);
1755
1751
}
1756
1752
}
1757
1753
@@ -1852,21 +1848,18 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
1852
1848
void modifiable_command_graph::addGraphLeafDependencies (node Node) {
1853
1849
// Find all exit nodes in the current graph and add them to the dependency
1854
1850
// vector
1855
- std::shared_ptr<detail::node_impl> DstImpl =
1856
- sycl::detail::getSyclObjImpl (Node);
1851
+ detail::node_impl &DstImpl = *sycl::detail::getSyclObjImpl (Node);
1857
1852
graph_impl::WriteLock Lock (impl->MMutex );
1858
1853
for (auto &NodeImpl : impl->MNodeStorage ) {
1859
- if ((NodeImpl->MSuccessors .size () == 0 ) && (NodeImpl != DstImpl)) {
1860
- impl->makeEdge (NodeImpl, DstImpl);
1854
+ if ((NodeImpl->MSuccessors .size () == 0 ) && (NodeImpl. get () != & DstImpl)) {
1855
+ impl->makeEdge (* NodeImpl, DstImpl);
1861
1856
}
1862
1857
}
1863
1858
}
1864
1859
1865
1860
void modifiable_command_graph::make_edge (node &Src, node &Dest) {
1866
- std::shared_ptr<detail::node_impl> SenderImpl =
1867
- sycl::detail::getSyclObjImpl (Src);
1868
- std::shared_ptr<detail::node_impl> ReceiverImpl =
1869
- sycl::detail::getSyclObjImpl (Dest);
1861
+ detail::node_impl &SenderImpl = *sycl::detail::getSyclObjImpl (Src);
1862
+ detail::node_impl &ReceiverImpl = *sycl::detail::getSyclObjImpl (Dest);
1870
1863
1871
1864
graph_impl::WriteLock Lock (impl->MMutex );
1872
1865
impl->makeEdge (SenderImpl, ReceiverImpl);
@@ -2030,17 +2023,11 @@ void executable_command_graph::update(
2030
2023
}
2031
2024
2032
2025
void executable_command_graph::update (const node &Node) {
2033
- impl->update (sycl::detail::getSyclObjImpl (Node));
2026
+ impl->update (* sycl::detail::getSyclObjImpl (Node));
2034
2027
}
2035
2028
2036
2029
void executable_command_graph::update (const std::vector<node> &Nodes) {
2037
- std::vector<std::shared_ptr<node_impl>> NodeImpls{};
2038
- NodeImpls.reserve (Nodes.size ());
2039
- for (auto &Node : Nodes) {
2040
- NodeImpls.push_back (sycl::detail::getSyclObjImpl (Node));
2041
- }
2042
-
2043
- impl->update (NodeImpls);
2030
+ impl->update (Nodes);
2044
2031
}
2045
2032
2046
2033
size_t executable_command_graph::get_required_mem_size () const {
0 commit comments