Skip to content

Commit 2657262

Browse files
committed
[IR2Vec] Scale vocab
1 parent b7ec652 commit 2657262

File tree

16 files changed

+397
-168
lines changed

16 files changed

+397
-168
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct Embedding {
108108
/// Arithmetic operators
109109
Embedding &operator+=(const Embedding &RHS);
110110
Embedding &operator-=(const Embedding &RHS);
111+
Embedding &operator*=(double Factor);
111112

112113
/// Adds Src Embedding scaled by Factor with the called Embedding.
113114
/// Called_Embedding += Src * Factor
@@ -116,6 +117,8 @@ struct Embedding {
116117
/// Returns true if the embedding is approximately equal to the RHS embedding
117118
/// within the specified tolerance.
118119
bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
120+
121+
void print(raw_ostream &OS) const;
119122
};
120123

121124
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
@@ -234,6 +237,8 @@ class IR2VecVocabResult {
234237
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
235238
ir2vec::Vocab Vocabulary;
236239
Error readVocabulary();
240+
Error parseVocabSection(const char *Key, const json::Value ParsedVocabValue,
241+
ir2vec::Vocab &TargetVocab, unsigned &Dim);
237242
void emitError(Error Err, LLVMContext &Ctx);
238243

239244
public:
@@ -249,14 +254,23 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
249254
/// functions.
250255
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
251256
raw_ostream &OS;
252-
void printVector(const ir2vec::Embedding &Vec) const;
253257

254258
public:
255259
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
256260
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
257261
static bool isRequired() { return true; }
258262
};
259263

264+
/// This pass prints the embeddings in the vocabulary.
265+
class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
266+
/// Output stream the pass writes to; held by reference, bound at construction.
raw_ostream &OS;
267+
268+
public:
269+
/// Creates a printer pass that writes to \p OS.
explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}
270+
/// Entry point of the pass for module \p M; defined out of line.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
271+
/// Always returns true.
static bool isRequired() { return true; }
272+
};
273+
260274
} // namespace llvm
261275

262276
#endif // LLVM_ANALYSIS_IR2VEC_H

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
8585
return *this;
8686
}
8787

88+
/// Multiplies every element of this embedding by \p Factor in place and
/// returns a reference to this embedding.
Embedding &Embedding::operator*=(double Factor) {
  for (auto &Elem : *this)
    Elem *= Factor;
  return *this;
}
93+
8894
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
8995
assert(this->size() == Src.size() && "Vectors must have the same dimension");
9096
for (size_t Itr = 0; Itr < this->size(); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101107
return true;
102108
}
103109

110+
/// Writes this embedding to \p OS as a bracketed list followed by a newline.
/// Each element is formatted with two decimal places and surrounded by a
/// single space on either side.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (const auto &Value : Data) {
    OS << " ";
    OS << format("%.2f", Value);
    OS << " ";
  }
  OS << "]\n";
}
116+
104117
//===----------------------------------------------------------------------===//
105118
// Embedder and its subclasses
106119
//===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196209
for (const auto &I : BB.instructionsWithoutDebug()) {
197210
Embedding InstVector(Dimension, 0);
198211

199-
const auto OpcVec = lookupVocab(I.getOpcodeName());
200-
InstVector.scaleAndAdd(OpcVec, OpcWeight);
201-
202212
// FIXME: Currently lookups are string based. Use numeric Keys
203213
// for efficiency.
204-
const auto Type = I.getType();
205-
const auto TypeVec = getTypeEmbedding(Type);
206-
InstVector.scaleAndAdd(TypeVec, TypeWeight);
207-
214+
InstVector += lookupVocab(I.getOpcodeName());
215+
InstVector += getTypeEmbedding(I.getType());
208216
for (const auto &Op : I.operands()) {
209-
const auto OperandVec = getOperandEmbedding(Op.get());
210-
InstVector.scaleAndAdd(OperandVec, ArgWeight);
217+
InstVector += getOperandEmbedding(Op.get());
211218
}
212219
InstVecMap[&I] = InstVector;
213220
BBVector += InstVector;
@@ -251,6 +258,46 @@ bool IR2VecVocabResult::invalidate(
251258
return !(PAC.preservedWhenStateless());
252259
}
253260

261+
Error IR2VecVocabAnalysis::parseVocabSection(const char *Key,
262+
const json::Value ParsedVocabValue,
263+
ir2vec::Vocab &TargetVocab,
264+
unsigned &Dim) {
265+
assert(Key && "Key cannot be null");
266+
267+
json::Path::Root Path("");
268+
const json::Object *RootObj = ParsedVocabValue.getAsObject();
269+
if (!RootObj)
270+
return createStringError(errc::invalid_argument,
271+
"JSON root is not an object");
272+
273+
const json::Value *SectionValue = RootObj->get(Key);
274+
if (!SectionValue)
275+
return createStringError(errc::invalid_argument,
276+
"Missing '" + std::string(Key) +
277+
"' section in vocabulary file");
278+
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
279+
return createStringError(errc::illegal_byte_sequence,
280+
"Unable to parse '" + std::string(Key) +
281+
"' section from vocabulary");
282+
283+
Dim = TargetVocab.begin()->second.size();
284+
if (Dim == 0)
285+
return createStringError(errc::illegal_byte_sequence,
286+
"Dimension of '" + std::string(Key) +
287+
"' section of the vocabulary is zero");
288+
289+
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
290+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
291+
return Entry.second.size() == Dim;
292+
}))
293+
return createStringError(
294+
errc::illegal_byte_sequence,
295+
"All vectors in the '" + std::string(Key) +
296+
"' section of the vocabulary are not of the same dimension");
297+
298+
return Error::success();
299+
};
300+
254301
// FIXME: Make this optional. We can avoid file reads
255302
// by auto-generating a default vocabulary during the build time.
256303
Error IR2VecVocabAnalysis::readVocabulary() {
@@ -259,32 +306,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259306
return createFileError(VocabFile, BufOrError.getError());
260307

261308
auto Content = BufOrError.get()->getBuffer();
262-
json::Path::Root Path("");
309+
263310
Expected<json::Value> ParsedVocabValue = json::parse(Content);
264311
if (!ParsedVocabValue)
265312
return ParsedVocabValue.takeError();
266313

267-
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
268-
if (!Res)
269-
return createStringError(errc::illegal_byte_sequence,
270-
"Unable to parse the vocabulary");
314+
ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
315+
unsigned OpcodeDim, TypeDim, ArgDim;
316+
if (auto Err = parseVocabSection("Opcodes", *ParsedVocabValue, OpcodeVocab,
317+
OpcodeDim))
318+
return Err;
271319

272-
if (Vocabulary.empty())
273-
return createStringError(errc::illegal_byte_sequence,
274-
"Vocabulary is empty");
320+
if (auto Err =
321+
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
322+
return Err;
275323

276-
unsigned Dim = Vocabulary.begin()->second.size();
277-
if (Dim == 0)
324+
if (auto Err =
325+
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
326+
return Err;
327+
328+
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278329
return createStringError(errc::illegal_byte_sequence,
279-
"Dimension of vocabulary is zero");
330+
"Vocabulary sections have different dimensions");
280331

281-
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
282-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
283-
return Entry.second.size() == Dim;
284-
}))
285-
return createStringError(
286-
errc::illegal_byte_sequence,
287-
"All vectors in the vocabulary are not of the same dimension");
332+
auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
333+
for (auto &Entry : Vocab)
334+
Entry.second *= Weight;
335+
};
336+
scaleVocabSection(OpcodeVocab, OpcWeight);
337+
scaleVocabSection(TypeVocab, TypeWeight);
338+
scaleVocabSection(ArgVocab, ArgWeight);
339+
340+
Vocabulary.insert(OpcodeVocab.begin(), OpcodeVocab.end());
341+
Vocabulary.insert(TypeVocab.begin(), TypeVocab.end());
342+
Vocabulary.insert(ArgVocab.begin(), ArgVocab.end());
288343

289344
return Error::success();
290345
}
@@ -304,7 +359,7 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304359
IR2VecVocabAnalysis::Result
305360
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
306361
auto Ctx = &M.getContext();
307-
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
362+
308363
// If vocabulary is already populated by the constructor, use it.
309364
if (!Vocabulary.empty())
310365
return IR2VecVocabResult(std::move(Vocabulary));
@@ -323,16 +378,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323378
}
324379

325380
//===----------------------------------------------------------------------===//
326-
// IR2VecPrinterPass
381+
// Printer Passes
327382
//===----------------------------------------------------------------------===//
328383

329-
void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
330-
OS << " [";
331-
for (const auto &Elem : Vec)
332-
OS << " " << format("%.2f", Elem) << " ";
333-
OS << "]\n";
334-
}
335-
336384
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
337385
ModuleAnalysisManager &MAM) {
338386
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
@@ -353,15 +401,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353401

354402
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
355403
OS << "Function vector: ";
356-
printVector(Emb->getFunctionVector());
404+
Emb->getFunctionVector().print(OS);
357405

358406
OS << "Basic block vectors:\n";
359407
const auto &BBMap = Emb->getBBVecMap();
360408
for (const BasicBlock &BB : F) {
361409
auto It = BBMap.find(&BB);
362410
if (It != BBMap.end()) {
363411
OS << "Basic block: " << BB.getName() << ":\n";
364-
printVector(It->second);
412+
It->second.print(OS);
365413
}
366414
}
367415

@@ -373,10 +421,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373421
if (It != InstMap.end()) {
374422
OS << "Instruction: ";
375423
I.print(OS);
376-
printVector(It->second);
424+
It->second.print(OS);
377425
}
378426
}
379427
}
380428
}
381429
return PreservedAnalyses::all();
382430
}
431+
432+
/// Prints every entry of the IR2Vec vocabulary obtained for \p M to the pass's
/// output stream: the key prefixed with "Key: ", followed by the embedding as
/// printed by Embedding::print.
///
/// \returns PreservedAnalyses::all().
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  // Bind to the cached result instead of copying it; the local is also named
  // so that it no longer shadows the IR2VecVocabResult type.
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  // `const auto &` avoids duplicating the whole vocabulary just to iterate it.
  const auto &Vocab = VocabResult.getVocabulary();
  for (const auto &Entry : Vocab) {
    OS << "Key: " << Entry.first << ": ";
    Entry.second.print(OS);
  }

  return PreservedAnalyses::all();
}

0 commit comments

Comments
 (0)