Skip to content

Commit e7549db

Browse files
committed
TypeGraph: Introduce NodeTracker for efficient cycle detection
Added to Flattener and TypeIdentifier passes for now as a proof-of-concept. Other passes can come later.
1 parent 28b813c commit e7549db

11 files changed

+259
-44
lines changed

Diff for: oi/type_graph/Flattener.cpp

+5-12
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,19 @@ namespace type_graph {
2222

2323
Pass Flattener::createPass() {
2424
auto fn = [](TypeGraph& typeGraph) {
25-
Flattener flattener;
26-
flattener.flatten(typeGraph.rootTypes());
27-
// TODO should flatten just operate on a single type and we do the looping
28-
// here?
25+
Flattener flattener{typeGraph.resetTracker()};
26+
for (auto& type : typeGraph.rootTypes()) {
27+
flattener.accept(type);
28+
}
2929
};
3030

3131
return Pass("Flattener", fn);
3232
}
3333

34-
void Flattener::flatten(std::vector<std::reference_wrapper<Type>>& types) {
35-
for (auto& type : types) {
36-
accept(type);
37-
}
38-
}
39-
4034
void Flattener::accept(Type& type) {
41-
if (visited_.count(&type) != 0)
35+
if (tracker_.visit(type))
4236
return;
4337

44-
visited_.insert(&type);
4538
type.accept(*this);
4639
}
4740

Diff for: oi/type_graph/Flattener.h

+6-5
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
#pragma once
1717

1818
#include <string>
19-
#include <unordered_set>
2019
#include <vector>
2120

21+
#include "NodeTracker.h"
2222
#include "PassManager.h"
2323
#include "Types.h"
2424
#include "Visitor.h"
@@ -28,14 +28,15 @@ namespace type_graph {
2828
/*
2929
* Flattener
3030
*
31-
* Flattens classes by removing parents and adding their members directly into
32-
* derived classes.
31+
* Flattens classes by removing parents and adding their attributes directly
32+
* into derived classes.
3333
*/
3434
class Flattener : public RecursiveVisitor {
3535
public:
3636
static Pass createPass();
3737

38-
void flatten(std::vector<std::reference_wrapper<Type>>& types);
38+
Flattener(NodeTracker& tracker) : tracker_(tracker) {
39+
}
3940

4041
using RecursiveVisitor::accept;
4142

@@ -46,7 +47,7 @@ class Flattener : public RecursiveVisitor {
4647
static const inline std::string ParentPrefix = "__oi_parent";
4748

4849
private:
49-
std::unordered_set<Type*> visited_;
50+
NodeTracker& tracker_;
5051
std::vector<Member> flattened_members_;
5152
std::vector<uint64_t> offset_stack_;
5253
};

Diff for: oi/type_graph/NodeTracker.h

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <vector>
19+
20+
#include "Types.h"
21+
22+
namespace type_graph {
23+
24+
/*
25+
* NodeTracker
26+
*
27+
* Helper class for visitors. Efficiently tracks whether or not a graph node has
28+
* been seen before, to avoid infinite looping on cycles.
29+
*/
30+
class NodeTracker {
31+
public:
32+
NodeTracker() = default;
33+
NodeTracker(size_t size) : visited_(size) {
34+
}
35+
36+
/*
37+
* visit
38+
*
39+
* Marks a given node as visited.
40+
* Returns true if this node has already been visited, false otherwise.
41+
*/
42+
bool visit(const Type& type) {
43+
auto id = type.id();
44+
if (id < 0)
45+
return false;
46+
if (visited_.size() <= static_cast<size_t>(id))
47+
visited_.resize(id + 1);
48+
bool result = visited_[id];
49+
visited_[id] = true;
50+
return result;
51+
}
52+
53+
/*
54+
* reset
55+
*
56+
* Clears the contents of this NodeTracker and marks every node as unvisited.
57+
*/
58+
void reset() {
59+
std::fill(visited_.begin(), visited_.end(), false);
60+
}
61+
62+
void resize(size_t size) {
63+
visited_.resize(size);
64+
}
65+
66+
private:
67+
std::vector<bool> visited_;
68+
};
69+
70+
} // namespace type_graph

Diff for: oi/type_graph/TypeGraph.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717

1818
namespace type_graph {
1919

20+
NodeTracker& TypeGraph::resetTracker() noexcept {
21+
tracker_.reset();
22+
tracker_.resize(size());
23+
return tracker_;
24+
}
25+
2026
template <>
2127
Primitive& TypeGraph::makeType<Primitive>(Primitive::Kind kind) {
2228
switch (kind) {

Diff for: oi/type_graph/TypeGraph.h

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <memory>
2020
#include <vector>
2121

22+
#include "NodeTracker.h"
2223
#include "Types.h"
2324

2425
namespace type_graph {
@@ -47,6 +48,8 @@ class TypeGraph {
4748
rootTypes_.push_back(type);
4849
}
4950

51+
NodeTracker& resetTracker() noexcept;
52+
5053
// Override of the generic makeType function that returns singleton Primitive
5154
// objects
5255
template <typename T>
@@ -83,6 +86,7 @@ class TypeGraph {
8386
std::vector<std::reference_wrapper<Type>> rootTypes_;
8487
// Store all type objects in vectors for ownership. Order is not significant.
8588
std::vector<std::unique_ptr<Type>> types_;
89+
NodeTracker tracker_;
8690
NodeId next_id_ = 0;
8791
};
8892

Diff for: oi/type_graph/TypeIdentifier.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace type_graph {
2323
Pass TypeIdentifier::createPass(
2424
const std::vector<ContainerInfo>& passThroughTypes) {
2525
auto fn = [&passThroughTypes](TypeGraph& typeGraph) {
26-
TypeIdentifier typeId{typeGraph, passThroughTypes};
26+
TypeIdentifier typeId{typeGraph.resetTracker(), typeGraph,
27+
passThroughTypes};
2728
for (auto& type : typeGraph.rootTypes()) {
2829
typeId.accept(type);
2930
}
@@ -48,10 +49,9 @@ bool TypeIdentifier::isAllocator(Type& t) {
4849
}
4950

5051
void TypeIdentifier::accept(Type& type) {
51-
if (visited_.count(&type) != 0)
52+
if (tracker_.visit(type))
5253
return;
5354

54-
visited_.insert(&type);
5555
type.accept(*this);
5656
}
5757

Diff for: oi/type_graph/TypeIdentifier.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
*/
1616
#pragma once
1717

18-
#include <array>
19-
#include <unordered_set>
2018
#include <vector>
2119

20+
#include "NodeTracker.h"
2221
#include "PassManager.h"
2322
#include "Types.h"
2423
#include "Visitor.h"
@@ -38,9 +37,12 @@ class TypeIdentifier : public RecursiveVisitor {
3837
static Pass createPass(const std::vector<ContainerInfo>& passThroughTypes);
3938
static bool isAllocator(Type& t);
4039

41-
TypeIdentifier(TypeGraph& typeGraph,
40+
TypeIdentifier(NodeTracker& tracker,
41+
TypeGraph& typeGraph,
4242
const std::vector<ContainerInfo>& passThroughTypes)
43-
: typeGraph_(typeGraph), passThroughTypes_(passThroughTypes) {
43+
: tracker_(tracker),
44+
typeGraph_(typeGraph),
45+
passThroughTypes_(passThroughTypes) {
4446
}
4547

4648
using RecursiveVisitor::accept;
@@ -49,7 +51,7 @@ class TypeIdentifier : public RecursiveVisitor {
4951
void visit(Container& c) override;
5052

5153
private:
52-
std::unordered_set<Type*> visited_;
54+
NodeTracker& tracker_;
5355
TypeGraph& typeGraph_;
5456
const std::vector<ContainerInfo>& passThroughTypes_;
5557
};

0 commit comments

Comments
 (0)