Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate LLVM-IR strings #101

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions programl/graph/program_graph_builder.h
Original file line number Diff line number Diff line change
@@ -103,6 +103,9 @@ class ProgramGraphBuilder {
inline Edge* AddEdge(const Edge::Flow& flow, int32_t position, const Node* source,
const Node* target);

// Return a mutable pointer to the root node in the graph.
Node* GetMutableRootNode() { return graph_.mutable_node(0); }

// Return a mutable pointer to the graph protocol buffer.
ProgramGraph* GetMutableProgramGraph() { return &graph_; }

9 changes: 4 additions & 5 deletions programl/ir/llvm/inst2vec_encoder.py
Original file line number Diff line number Diff line change
@@ -43,11 +43,10 @@
)


def NodeFullText(node: node_pb2.Node) -> str:
def NodeFullText(graph: program_graph_pb2.ProgramGraph, node: node_pb2.Node) -> str:
"""Get the full text of a node, or an empty string if not set."""
if len(node.features.feature["full_text"].bytes_list.value):
return node.features.feature["full_text"].bytes_list.value[0].decode("utf-8")
return ""
idx = node.features.feature["llvm_string"].int64_list.value[0]
return graph.features.feature["strings"].bytes_list.value[idx].decode("utf-8")


class Inst2vecEncoder(object):
@@ -85,7 +84,7 @@ def Encode(
"""
# Gather the instruction texts to pre-process.
lines = [
[NodeFullText(node)]
[NodeFullText(proto, node)]
for node in proto.node
if node.type == node_pb2.Node.INSTRUCTION
]
17 changes: 14 additions & 3 deletions programl/ir/llvm/inst2vec_encoder_test.py
Original file line number Diff line number Diff line change
@@ -55,10 +55,21 @@ def AddVariable(self, full_text: str):

def Build(self):
proto = super(Inst2vecGraphBuilder, self).Build()

# Add the root node string feature.
proto.node[0].features.feature["llvm_string"].int64_list.value[:] = [0]

# Build the strings list.
strings_list = list(set(self.full_texts.values()))
proto.features.feature["strings"].bytes_list.value[:] = [
string.encode("utf-8") for string in strings_list
]

# Add the string indices.
for node, full_text in self.full_texts.items():
proto.node[node].features.feature["full_text"].bytes_list.value.append(
full_text.encode("utf-8")
)
idx = strings_list.index(full_text)
node_feature = proto.node[node].features.feature["llvm_string"]
node_feature.int64_list.value.append(idx)
return proto


41 changes: 36 additions & 5 deletions programl/ir/llvm/internal/program_graph_builder.cc
Original file line number Diff line number Diff line change
@@ -39,6 +39,16 @@ namespace ir {
namespace llvm {
namespace internal {

ProgramGraphBuilder::ProgramGraphBuilder(const ProgramGraphOptions& options)
: programl::graph::ProgramGraphBuilder(),
options_(options),
blockCount_(0),
stringsList_((*GetMutableProgramGraph()->mutable_features()->mutable_feature())["strings"]
.mutable_bytes_list()) {
// Add an empty
graph::AddScalarFeature(GetMutableRootNode(), "llvm_string", AddString(""));
}

labm8::StatusOr<BasicBlockEntryExit> ProgramGraphBuilder::VisitBasicBlock(
const ::llvm::BasicBlock& block, const Function* functionMessage, InstructionMap* instructions,
ArgumentConsumerMap* argumentConsumers, std::vector<DataEdge>* dataEdgesToAdd) {
@@ -184,7 +194,7 @@ labm8::StatusOr<FunctionEntryExits> ProgramGraphBuilder::VisitFunction(

if (function.isDeclaration()) {
Node* node = AddInstruction("; undefined function", functionMessage);
graph::AddScalarFeature(node, "full_text", "");
graph::AddScalarFeature(node, "llvm_string", AddString(""));
functionEntryExits.first = node;
functionEntryExits.second.push_back(node);
return functionEntryExits;
@@ -305,7 +315,7 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(const ::llvm::Instruction* instruc
const LlvmTextComponents text = textEncoder_.Encode(instruction);
Node* node = AddInstruction(text.opcode_name, function);
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.text);
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));

// Add profiling information features, if available.
uint64_t profTotalWeight;
@@ -327,7 +337,7 @@ Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
const LlvmTextComponents text = textEncoder_.Encode(operand);
Node* node = AddVariable(text.lhs_type, function);
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.lhs);
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));

return node;
}
@@ -337,7 +347,7 @@ Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
const LlvmTextComponents text = textEncoder_.Encode(argument);
Node* node = AddVariable(text.lhs_type, function);
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.lhs);
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));

return node;
}
@@ -346,7 +356,7 @@ Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
const LlvmTextComponents text = textEncoder_.Encode(constant);
Node* node = AddConstant(text.lhs_type);
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.text);
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));

return node;
}
@@ -436,6 +446,27 @@ void ProgramGraphBuilder::Clear() {
programl::graph::ProgramGraphBuilder::Clear();
}

Node* ProgramGraphBuilder::GetOrCreateType(const ::llvm::Type* type) {
auto it = types_.find(type);
if (it == types_.end()) {
Node* node = AddLlvmType(type);
types_[type] = node;
return node;
}
return it->second;
}

int32_t ProgramGraphBuilder::AddString(const string& text) {
auto it = stringsListPositions_.find(text);
if (it == stringsListPositions_.end()) {
int32_t index = stringsListPositions_.size();
stringsListPositions_[text] = index;
stringsList_->add_value(text);
return index;
}
return it->second;
}

} // namespace internal
} // namespace llvm
} // namespace ir
16 changes: 14 additions & 2 deletions programl/ir/llvm/internal/program_graph_builder.h
Original file line number Diff line number Diff line change
@@ -64,8 +64,7 @@ using ArgumentConsumerMap =
// A specialized program graph builder for LLVM-IR.
class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
public:
explicit ProgramGraphBuilder(const ProgramGraphOptions& options)
: programl::graph::ProgramGraphBuilder(options), blockCount_(0) {}
explicit ProgramGraphBuilder(const ProgramGraphOptions& options);

[[nodiscard]] labm8::StatusOr<ProgramGraph> Build(const ::llvm::Module& module);

@@ -87,6 +86,13 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
Node* AddLlvmVariable(const ::llvm::Argument* argument, const Function* function);
Node* AddLlvmConstant(const ::llvm::Constant* constant);

// Add a string to the strings list and return its position.
//
// We use a graph-level "strings" feature to store a list of the original
// LLVM-IR string corresponding to each graph nodes. This allows to us to
// refer to the same string from multiple nodes without duplication.
int32_t AddString(const string& text);

private:
TextEncoder textEncoder_;

@@ -100,6 +106,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
// populated by VisitBasicBlock() and consumed once all functions have been
// visited.
absl::flat_hash_map<const ::llvm::Constant*, std::vector<PositionalNode>> constants_;

// A mapping from string table value to its position in the "strings_table"
// graph-level feature.
absl::flat_hash_map<string, int32_t> stringsListPositions_;
// The underlying storage for the strings table.
BytesList* stringsList_;
};

} // namespace internal
16 changes: 9 additions & 7 deletions programl/ir/llvm/py/llvm_test.py
Original file line number Diff line number Diff line change
@@ -37,8 +37,10 @@
"""


def GetStringScalar(proto, name):
return proto.features.feature[name].bytes_list.value[0].decode("utf-8")
def NodeFullText(graph: program_graph_pb2.ProgramGraph, node: node_pb2.Node) -> str:
"""Get the full text of a node, or an empty string if not set."""
idx = node.features.feature["llvm_string"].int64_list.value[0]
return graph.features.feature["strings"].bytes_list.value[idx].decode("utf-8")


def test_simple_ir():
@@ -56,25 +58,25 @@ def test_simple_ir():

assert proto.node[1].text == "add"
assert proto.node[1].type == node_pb2.Node.INSTRUCTION
assert GetStringScalar(proto.node[1], "full_text") == "%3 = add nsw i32 %1, %0"
assert NodeFullText(proto, proto.node[1]) == "%3 = add nsw i32 %1, %0"

assert proto.node[2].text == "ret"
assert proto.node[2].type == node_pb2.Node.INSTRUCTION
assert GetStringScalar(proto.node[2], "full_text") == "ret i32 %3"
assert NodeFullText(proto, proto.node[2]) == "ret i32 %3"

assert proto.node[3].text == "i32"
assert proto.node[3].type == node_pb2.Node.VARIABLE
assert GetStringScalar(proto.node[3], "full_text") == "i32 %3"
assert NodeFullText(proto, proto.node[3]) == "i32 %3"

# Use startswith() to compare names for these last two variables as thier
# order may differ.
assert proto.node[4].text == "i32"
assert proto.node[4].type == node_pb2.Node.VARIABLE
assert GetStringScalar(proto.node[4], "full_text").startswith("i32 %")
assert NodeFullText(proto, proto.node[4]).startswith("i32 %")

assert proto.node[5].text == "i32"
assert proto.node[5].type == node_pb2.Node.VARIABLE
assert GetStringScalar(proto.node[5], "full_text").startswith("i32 %")
assert NodeFullText(proto, proto.node[5]).startswith("i32 %")


def test_opt_level():