Skip to content

Commit f427dcf

Browse files
committed
llvm: Deduplicate LLVM-IR strings.
This changes the format of the LLVM-IR program graphs to store a list of unique strings, rather than storing an LLVM-IR string in each node. We use a graph-level "strings" feature to store the list of original LLVM-IR strings, and each node stores an integer "llvm_string" feature that indexes into that list. This allows us to refer to the same string from multiple nodes without duplication. This breaks compatibility with the inst2vec encoder on program graphs generated prior to this commit.
1 parent 6aa9b5b commit f427dcf

6 files changed

+80
-30
lines changed

programl/graph/program_graph_builder.h

+3
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ class ProgramGraphBuilder {
9999
inline Edge* AddEdge(const Edge::Flow& flow, int32_t position,
100100
const Node* source, const Node* target);
101101

102+
// Return a mutable pointer to the root node in the graph.
103+
Node* GetMutableRootNode() { return graph_.mutable_node(0); }
104+
102105
// Return a mutable pointer to the graph protocol buffer.
103106
ProgramGraph* GetMutableProgramGraph() { return &graph_; }
104107

programl/ir/llvm/inst2vec_encoder.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@
5050
)
5151

5252

53-
def NodeFullText(node: node_pb2.Node) -> str:
53+
def NodeFullText(
54+
graph: program_graph_pb2.ProgramGraph,
55+
node: node_pb2.Node
56+
) -> str:
5457
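"""Get the full text of a node by looking up its "llvm_string" index in the graph-level "strings" list. Raises IndexError if the node has no "llvm_string" feature."""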
55-
if len(node.features.feature["full_text"].bytes_list.value):
56-
return (
57-
node.features.feature["full_text"].bytes_list.value[0].decode("utf-8")
58-
)
59-
return ""
58+
idx = node.features.feature["llvm_string"].int64_list.value[0]
59+
return graph.features.feature["strings"].bytes_list.value[idx].decode("utf-8")
6060

6161

6262
class Inst2vecEncoder(object):
@@ -94,7 +94,7 @@ def Encode(
9494
"""
9595
# Gather the instruction texts to pre-process.
9696
lines = [
97-
[NodeFullText(node)]
97+
[NodeFullText(proto, node)]
9898
for node in proto.node
9999
if node.type == node_pb2.Node.INSTRUCTION
100100
]

programl/ir/llvm/inst2vec_encoder_test.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,21 @@ def AddVariable(self, full_text: str):
5656

5757
def Build(self):
5858
proto = super(Inst2vecGraphBuilder, self).Build()
59+
60+
# Add the root node string feature.
61+
proto.node[0].features.feature["llvm_string"].int64_list.value[:] = [0]
62+
63+
# Build the strings list.
64+
strings_list = list(set(self.full_texts.values()))
65+
proto.features.feature["strings"].bytes_list.value[:] = [
66+
string.encode("utf-8") for string in strings_list
67+
]
68+
69+
# Add the string indices.
5970
for node, full_text in self.full_texts.items():
60-
proto.node[node].features.feature["full_text"].bytes_list.value.append(
61-
full_text.encode("utf-8")
62-
)
71+
idx = strings_list.index(full_text)
72+
node_feature = proto.node[node].features.feature["llvm_string"]
73+
node_feature.int64_list.value.append(idx)
6374
return proto
6475

6576

programl/ir/llvm/internal/program_graph_builder.cc

+28-5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ namespace ir {
3939
namespace llvm {
4040
namespace internal {
4141

42+
ProgramGraphBuilder::ProgramGraphBuilder(const ProgramGraphOptions& options)
43+
: programl::graph::ProgramGraphBuilder(),
44+
options_(options),
45+
blockCount_(0),
46+
stringsList_((*GetMutableProgramGraph()
47+
->mutable_features()
48+
->mutable_feature())["strings"]
49+
.mutable_bytes_list()) {
50+
// Add an empty string for the root node.
51+
graph::AddScalarFeature(GetMutableRootNode(), "llvm_string", AddString(""));
52+
}
53+
4254
labm8::StatusOr<BasicBlockEntryExit> ProgramGraphBuilder::VisitBasicBlock(
4355
const ::llvm::BasicBlock& block, const Function* functionMessage,
4456
InstructionMap* instructions, ArgumentConsumerMap* argumentConsumers,
@@ -194,7 +206,7 @@ labm8::StatusOr<FunctionEntryExits> ProgramGraphBuilder::VisitFunction(
194206

195207
if (function.isDeclaration()) {
196208
Node* node = AddInstruction("; undefined function", functionMessage);
197-
graph::AddScalarFeature(node, "full_text", "");
209+
graph::AddScalarFeature(node, "llvm_string", AddString(""));
198210
functionEntryExits.first = node;
199211
functionEntryExits.second.push_back(node);
200212
return functionEntryExits;
@@ -325,7 +337,7 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(
325337
const LlvmTextComponents text = textEncoder_.Encode(instruction);
326338
Node* node = AddInstruction(text.opcode_name, function);
327339
node->set_block(blockCount_);
328-
graph::AddScalarFeature(node, "full_text", text.text);
340+
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
329341

330342
// Add profiling information features, if available.
331343
uint64_t profTotalWeight;
@@ -347,7 +359,7 @@ Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
347359
const LlvmTextComponents text = textEncoder_.Encode(operand);
348360
Node* node = AddVariable(text.lhs_type, function);
349361
node->set_block(blockCount_);
350-
graph::AddScalarFeature(node, "full_text", text.lhs);
362+
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
351363

352364
return node;
353365
}
@@ -357,7 +369,7 @@ Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
357369
const LlvmTextComponents text = textEncoder_.Encode(argument);
358370
Node* node = AddVariable(text.lhs_type, function);
359371
node->set_block(blockCount_);
360-
graph::AddScalarFeature(node, "full_text", text.lhs);
372+
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
361373

362374
return node;
363375
}
@@ -366,7 +378,7 @@ Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
366378
const LlvmTextComponents text = textEncoder_.Encode(constant);
367379
Node* node = AddConstant(text.lhs_type);
368380
node->set_block(blockCount_);
369-
graph::AddScalarFeature(node, "full_text", text.text);
381+
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
370382

371383
return node;
372384
}
@@ -465,6 +477,17 @@ void ProgramGraphBuilder::Clear() {
465477
programl::graph::ProgramGraphBuilder::Clear();
466478
}
467479

480+
int32_t ProgramGraphBuilder::AddString(const string& text) {
481+
auto it = stringsListPositions_.find(text);
482+
if (it == stringsListPositions_.end()) {
483+
int32_t index = stringsListPositions_.size();
484+
stringsListPositions_[text] = index;
485+
stringsList_->add_value(text);
486+
return index;
487+
}
488+
return it->second;
489+
}
490+
468491
} // namespace internal
469492
} // namespace llvm
470493
} // namespace ir

programl/ir/llvm/internal/program_graph_builder.h

+16-6
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,10 @@ using ArgumentConsumerMap =
6464
// A specialized program graph builder for LLVM-IR.
6565
class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
6666
public:
67-
explicit ProgramGraphBuilder(const ProgramGraphOptions& options)
68-
: programl::graph::ProgramGraphBuilder(),
69-
options_(options),
70-
blockCount_(0){}
67+
explicit ProgramGraphBuilder(const ProgramGraphOptions& options);
7168

72-
[[nodiscard]] labm8::StatusOr<ProgramGraph> Build(
73-
const ::llvm::Module& module);
69+
[[nodiscard]] labm8::StatusOr<ProgramGraph> Build(
70+
const ::llvm::Module& module);
7471

7572
void Clear();
7673

@@ -94,6 +91,13 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
9491
const Function* function);
9592
Node* AddLlvmConstant(const ::llvm::Constant* constant);
9693

94+
// Add a string to the strings list and return its position.
95+
//
96+
// We use a graph-level "strings" feature to store a list of the original
97+
// LLVM-IR string corresponding to each graph node. This allows us to
98+
// refer to the same string from multiple nodes without duplication.
99+
int32_t AddString(const string& text);
100+
97101
private:
98102
const ProgramGraphOptions options_;
99103

@@ -110,6 +114,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
110114
// visited.
111115
absl::flat_hash_map<const ::llvm::Constant*, std::vector<PositionalNode>>
112116
constants_;
117+
118+
// A mapping from string table value to its position in the "strings"
119+
// graph-level feature.
120+
absl::flat_hash_map<string, int32_t> stringsListPositions_;
121+
// The underlying storage for the strings table.
122+
BytesList* stringsList_;
113123
};
114124

115125
} // namespace internal

programl/ir/llvm/py/llvm_test.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,13 @@
3838
"""
3939

4040

41-
def GetStringScalar(proto, name):
42-
return proto.features.feature[name].bytes_list.value[0].decode("utf-8")
41+
def NodeFullText(
42+
graph: program_graph_pb2.ProgramGraph,
43+
node: node_pb2.Node
44+
) -> str:
45+
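"""Get the full text of a node by looking up its "llvm_string" index in the graph-level "strings" list. Raises IndexError if the node has no "llvm_string" feature."""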
46+
idx = node.features.feature["llvm_string"].int64_list.value[0]
47+
return graph.features.feature["strings"].bytes_list.value[idx].decode("utf-8")
4348

4449

4550
def test_simple_ir():
@@ -57,27 +62,25 @@ def test_simple_ir():
5762

5863
assert proto.node[1].text == "add"
5964
assert proto.node[1].type == node_pb2.Node.INSTRUCTION
60-
assert (
61-
GetStringScalar(proto.node[1], "full_text") == "%3 = add nsw i32 %1, %0"
62-
)
65+
assert NodeFullText(proto, proto.node[1]) == "%3 = add nsw i32 %1, %0"
6366

6467
assert proto.node[2].text == "ret"
6568
assert proto.node[2].type == node_pb2.Node.INSTRUCTION
66-
assert GetStringScalar(proto.node[2], "full_text") == "ret i32 %3"
69+
assert NodeFullText(proto, proto.node[2]) == "ret i32 %3"
6770

6871
assert proto.node[3].text == "i32"
6972
assert proto.node[3].type == node_pb2.Node.VARIABLE
70-
assert GetStringScalar(proto.node[3], "full_text") == "i32 %3"
73+
assert NodeFullText(proto, proto.node[3]) == "i32 %3"
7174

7275
# Use startswith() to compare names for these last two variables as their
7376
# order may differ.
7477
assert proto.node[4].text == "i32"
7578
assert proto.node[4].type == node_pb2.Node.VARIABLE
76-
assert GetStringScalar(proto.node[4], "full_text").startswith("i32 %")
79+
assert NodeFullText(proto, proto.node[4]).startswith("i32 %")
7780

7881
assert proto.node[5].text == "i32"
7982
assert proto.node[5].type == node_pb2.Node.VARIABLE
80-
assert GetStringScalar(proto.node[5], "full_text").startswith("i32 %")
83+
assert NodeFullText(proto, proto.node[5]).startswith("i32 %")
8184

8285

8386
def test_opt_level():

0 commit comments

Comments
 (0)