Skip to content

Commit d994487

Browse files
authored
[IR2Vec] Add llvm-ir2vec tool for generating triplet embeddings (#147842)
Add a new LLVM tool, `llvm-ir2vec`. The tool is primarily intended to generate triplets for training the vocabulary (#141834), and potentially to generate embeddings in a standalone manner. This PR introduces the tool with triplet-generation functionality only. In upcoming PRs I'll add scripts under `utils/mlgo` to complete the vocabulary tooling; #147844 adds the embedding-generation logic to the tool. (Tracking issue: #141817)
1 parent fd5fc76 commit d994487

File tree

5 files changed

+204
-0
lines changed

5 files changed

+204
-0
lines changed

llvm/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ set(LLVM_TEST_DEPENDS
9797
llvm-exegesis
9898
llvm-extract
9999
llvm-gsymutil
100+
llvm-ir2vec
100101
llvm-isel-fuzzer
101102
llvm-ifs
102103
llvm-install-name-tool

llvm/test/lit.cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def get_asan_rtlib():
197197
"llvm-dlltool",
198198
"llvm-exegesis",
199199
"llvm-extract",
200+
"llvm-ir2vec",
200201
"llvm-isel-fuzzer",
201202
"llvm-ifs",
202203
"llvm-install-name-tool",
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS
2+
3+
define i32 @simple_add(i32 %a, i32 %b) {
4+
entry:
5+
%add = add i32 %a, %b
6+
ret i32 %add
7+
}
8+
9+
define i32 @simple_mul(i32 %x, i32 %y) {
10+
entry:
11+
%mul = mul i32 %x, %y
12+
ret i32 %mul
13+
}
14+
15+
define i32 @test_function(i32 %arg1, i32 %arg2) {
16+
entry:
17+
%local1 = alloca i32, align 4
18+
%local2 = alloca i32, align 4
19+
store i32 %arg1, ptr %local1, align 4
20+
store i32 %arg2, ptr %local2, align 4
21+
%load1 = load i32, ptr %local1, align 4
22+
%load2 = load i32, ptr %local2, align 4
23+
%result = add i32 %load1, %load2
24+
ret i32 %result
25+
}
26+
27+
; TRIPLETS: Add IntegerTy Variable Variable
28+
; TRIPLETS-NEXT: Ret VoidTy Variable
29+
; TRIPLETS-NEXT: Mul IntegerTy Variable Variable
30+
; TRIPLETS-NEXT: Ret VoidTy Variable
31+
; TRIPLETS-NEXT: Alloca PointerTy Constant
32+
; TRIPLETS-NEXT: Alloca PointerTy Constant
33+
; TRIPLETS-NEXT: Store VoidTy Variable Pointer
34+
; TRIPLETS-NEXT: Store VoidTy Variable Pointer
35+
; TRIPLETS-NEXT: Load IntegerTy Pointer
36+
; TRIPLETS-NEXT: Load IntegerTy Pointer
37+
; TRIPLETS-NEXT: Add IntegerTy Variable Variable
38+
; TRIPLETS-NEXT: Ret VoidTy Variable

llvm/tools/llvm-ir2vec/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
set(LLVM_LINK_COMPONENTS
2+
Analysis
3+
Core
4+
IRReader
5+
Support
6+
)
7+
8+
add_llvm_tool(llvm-ir2vec
9+
llvm-ir2vec.cpp
10+
)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file implements the IR2Vec embedding generation tool.
11+
///
12+
/// Currently supports triplet generation for vocabulary training.
13+
/// Future updates will support embedding generation using trained vocabulary.
14+
///
15+
/// Usage: llvm-ir2vec input.bc -o triplets.txt
16+
///
17+
/// TODO: Add embedding generation mode with vocabulary support
18+
///
19+
//===----------------------------------------------------------------------===//
20+
21+
#include "llvm/Analysis/IR2Vec.h"
22+
#include "llvm/IR/BasicBlock.h"
23+
#include "llvm/IR/Function.h"
24+
#include "llvm/IR/Instructions.h"
25+
#include "llvm/IR/LLVMContext.h"
26+
#include "llvm/IR/Module.h"
27+
#include "llvm/IR/Type.h"
28+
#include "llvm/IRReader/IRReader.h"
29+
#include "llvm/Support/CommandLine.h"
30+
#include "llvm/Support/Debug.h"
31+
#include "llvm/Support/Errc.h"
32+
#include "llvm/Support/InitLLVM.h"
33+
#include "llvm/Support/SourceMgr.h"
34+
#include "llvm/Support/raw_ostream.h"
35+
36+
using namespace llvm;
37+
using namespace ir2vec;
38+
39+
#define DEBUG_TYPE "ir2vec"
40+
41+
// Category grouping this tool's own options; main() passes it to
// cl::HideUnrelatedOptions so that only these options stay visible.
static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options");

// Required positional argument: path of the IR file to read.
static cl::opt<std::string> InputFilename(cl::Positional, cl::Required,
                                          cl::desc("<input bitcode file>"),
                                          cl::cat(IR2VecToolCategory));

// -o <filename>: where the generated output is written. The default "-" is
// handed unchanged to raw_fd_ostream in main().
static cl::opt<std::string> OutputFilename("o", cl::init("-"),
                                           cl::value_desc("filename"),
                                           cl::desc("Output filename"),
                                           cl::cat(IR2VecToolCategory));
52+
53+
namespace {
54+
55+
/// Helper class for collecting IR information and generating triplets
56+
class IR2VecTool {
57+
private:
58+
Module &M;
59+
60+
public:
61+
explicit IR2VecTool(Module &M) : M(M) {}
62+
63+
/// Generate triplets for the entire module
64+
void generateTriplets(raw_ostream &OS) const {
65+
for (const Function &F : M)
66+
generateTriplets(F, OS);
67+
}
68+
69+
/// Generate triplets for a single function
70+
void generateTriplets(const Function &F, raw_ostream &OS) const {
71+
if (F.isDeclaration())
72+
return;
73+
74+
std::string LocalOutput;
75+
raw_string_ostream LocalOS(LocalOutput);
76+
77+
for (const BasicBlock &BB : F)
78+
traverseBasicBlock(BB, LocalOS);
79+
80+
LocalOS.flush();
81+
OS << LocalOutput;
82+
}
83+
84+
private:
85+
/// Process a single basic block for triplet generation
86+
void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
87+
// Consider only non-debug and non-pseudo instructions
88+
for (const auto &I : BB.instructionsWithoutDebug()) {
89+
StringRef OpcStr = Vocabulary::getVocabKeyForOpcode(I.getOpcode());
90+
StringRef TypeStr =
91+
Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID());
92+
93+
OS << '\n' << OpcStr << ' ' << TypeStr << ' ';
94+
95+
LLVM_DEBUG({
96+
I.print(dbgs());
97+
dbgs() << "\n";
98+
I.getType()->print(dbgs());
99+
dbgs() << " Type\n";
100+
});
101+
102+
for (const Use &U : I.operands())
103+
OS << Vocabulary::getVocabKeyForOperandKind(
104+
Vocabulary::getOperandKind(U.get()))
105+
<< ' ';
106+
}
107+
}
108+
};
109+
110+
/// Run triplet generation over \p M and write the result to \p OS.
///
/// \returns Error::success() in all cases; the Error return type lets the
/// caller handle failures uniformly.
Error processModule(Module &M, raw_ostream &OS) {
  const IR2VecTool TripletGenerator(M);
  TripletGenerator.generateTriplets(OS);
  return Error::success();
}
116+
117+
} // anonymous namespace
118+
119+
int main(int argc, char **argv) {
120+
InitLLVM X(argc, argv);
121+
cl::HideUnrelatedOptions(IR2VecToolCategory);
122+
cl::ParseCommandLineOptions(
123+
argc, argv,
124+
"IR2Vec - Triplet Generation Tool\n"
125+
"Generates triplets for vocabulary training from LLVM IR.\n"
126+
"Future updates will support embedding generation.\n\n"
127+
"Usage:\n"
128+
" llvm-ir2vec input.bc -o triplets.txt\n");
129+
130+
// Parse the input LLVM IR file
131+
SMDiagnostic Err;
132+
LLVMContext Context;
133+
std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
134+
if (!M) {
135+
Err.print(argv[0], errs());
136+
return 1;
137+
}
138+
139+
std::error_code EC;
140+
raw_fd_ostream OS(OutputFilename, EC);
141+
if (EC) {
142+
errs() << "Error opening output file: " << EC.message() << "\n";
143+
return 1;
144+
}
145+
146+
if (Error Err = processModule(*M, OS)) {
147+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
148+
errs() << "Error: " << EIB.message() << "\n";
149+
});
150+
return 1;
151+
}
152+
153+
return 0;
154+
}

0 commit comments

Comments
 (0)