@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
85
85
return *this ;
86
86
}
87
87
88
+ Embedding &Embedding::operator *=(double Factor) {
89
+ std::transform (this ->begin (), this ->end (), this ->begin (),
90
+ [Factor](double Elem) { return Elem * Factor; });
91
+ return *this ;
92
+ }
93
+
88
94
Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
89
95
assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
90
96
for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101
107
return true ;
102
108
}
103
109
110
+ void Embedding::print (raw_ostream &OS) const {
111
+ OS << " [" ;
112
+ for (const auto &Elem : Data)
113
+ OS << " " << format (" %.2f" , Elem) << " " ;
114
+ OS << " ]\n " ;
115
+ }
116
+
104
117
// ==----------------------------------------------------------------------===//
105
118
// Embedder and its subclasses
106
119
// ===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196
209
for (const auto &I : BB.instructionsWithoutDebug ()) {
197
210
Embedding InstVector (Dimension, 0 );
198
211
199
- const auto OpcVec = lookupVocab (I.getOpcodeName ());
200
- InstVector.scaleAndAdd (OpcVec, OpcWeight);
201
-
202
212
// FIXME: Currently lookups are string based. Use numeric Keys
203
213
// for efficiency.
204
- const auto Type = I.getType ();
205
- const auto TypeVec = getTypeEmbedding (Type);
206
- InstVector.scaleAndAdd (TypeVec, TypeWeight);
207
-
214
+ InstVector += lookupVocab (I.getOpcodeName ());
215
+ InstVector += getTypeEmbedding (I.getType ());
208
216
for (const auto &Op : I.operands ()) {
209
- const auto OperandVec = getOperandEmbedding (Op.get ());
210
- InstVector.scaleAndAdd (OperandVec, ArgWeight);
217
+ InstVector += getOperandEmbedding (Op.get ());
211
218
}
212
219
InstVecMap[&I] = InstVector;
213
220
BBVector += InstVector;
@@ -251,6 +258,47 @@ bool IR2VecVocabResult::invalidate(
251
258
return !(PAC.preservedWhenStateless ());
252
259
}
253
260
261
+ Error IR2VecVocabAnalysis::parseVocabSection (const char *Key,
262
+ const json::Value ParsedVocabValue,
263
+ ir2vec::Vocab &TargetVocab,
264
+ unsigned &Dim) {
265
+ assert (Key && " Key cannot be null" );
266
+
267
+ json::Path::Root Path (" " );
268
+ const json::Object *RootObj = ParsedVocabValue.getAsObject ();
269
+ if (!RootObj) {
270
+ return createStringError (errc::invalid_argument,
271
+ " JSON root is not an object" );
272
+ }
273
+
274
+ const json::Value *SectionValue = RootObj->get (Key);
275
+ if (!SectionValue)
276
+ return createStringError (errc::invalid_argument,
277
+ " Missing '" + std::string (Key) +
278
+ " ' section in vocabulary file" );
279
+ if (!json::fromJSON (*SectionValue, TargetVocab, Path))
280
+ return createStringError (errc::illegal_byte_sequence,
281
+ " Unable to parse '" + std::string (Key) +
282
+ " ' section from vocabulary" );
283
+
284
+ Dim = TargetVocab.begin ()->second .size ();
285
+ if (Dim == 0 )
286
+ return createStringError (errc::illegal_byte_sequence,
287
+ " Dimension of '" + std::string (Key) +
288
+ " ' section of the vocabulary is zero" );
289
+
290
+ if (!std::all_of (TargetVocab.begin (), TargetVocab.end (),
291
+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
292
+ return Entry.second .size () == Dim;
293
+ }))
294
+ return createStringError (
295
+ errc::illegal_byte_sequence,
296
+ " All vectors in the '" + std::string (Key) +
297
+ " ' section of the vocabulary are not of the same dimension" );
298
+
299
+ return Error::success ();
300
+ };
301
+
254
302
// FIXME: Make this optional. We can avoid file reads
255
303
// by auto-generating a default vocabulary during the build time.
256
304
Error IR2VecVocabAnalysis::readVocabulary () {
@@ -259,32 +307,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259
307
return createFileError (VocabFile, BufOrError.getError ());
260
308
261
309
auto Content = BufOrError.get ()->getBuffer ();
262
- json::Path::Root Path ( " " );
310
+
263
311
Expected<json::Value> ParsedVocabValue = json::parse (Content);
264
312
if (!ParsedVocabValue)
265
313
return ParsedVocabValue.takeError ();
266
314
267
- bool Res = json::fromJSON (*ParsedVocabValue, Vocabulary, Path);
268
- if (!Res)
269
- return createStringError (errc::illegal_byte_sequence,
270
- " Unable to parse the vocabulary" );
315
+ ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
316
+ unsigned OpcodeDim, TypeDim, ArgDim;
317
+ if (auto Err = parseVocabSection (" Opcodes" , *ParsedVocabValue, OpcodeVocab,
318
+ OpcodeDim))
319
+ return Err;
271
320
272
- if (Vocabulary. empty ())
273
- return createStringError (errc::illegal_byte_sequence,
274
- " Vocabulary is empty " ) ;
321
+ if (auto Err =
322
+ parseVocabSection ( " Types " , *ParsedVocabValue, TypeVocab, TypeDim))
323
+ return Err ;
275
324
276
- unsigned Dim = Vocabulary.begin ()->second .size ();
277
- if (Dim == 0 )
325
+ if (auto Err =
326
+ parseVocabSection (" Arguments" , *ParsedVocabValue, ArgVocab, ArgDim))
327
+ return Err;
328
+
329
+ if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278
330
return createStringError (errc::illegal_byte_sequence,
279
- " Dimension of vocabulary is zero " );
331
+ " Vocabulary sections have different dimensions " );
280
332
281
- if (!std::all_of (Vocabulary.begin (), Vocabulary.end (),
282
- [Dim](const std::pair<StringRef, Embedding> &Entry) {
283
- return Entry.second .size () == Dim;
284
- }))
285
- return createStringError (
286
- errc::illegal_byte_sequence,
287
- " All vectors in the vocabulary are not of the same dimension" );
333
+ auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
334
+ for (auto &Entry : Vocab)
335
+ Entry.second *= Weight;
336
+ };
337
+ scaleVocabSection (OpcodeVocab, OpcWeight);
338
+ scaleVocabSection (TypeVocab, TypeWeight);
339
+ scaleVocabSection (ArgVocab, ArgWeight);
340
+
341
+ Vocabulary.insert (OpcodeVocab.begin (), OpcodeVocab.end ());
342
+ Vocabulary.insert (TypeVocab.begin (), TypeVocab.end ());
343
+ Vocabulary.insert (ArgVocab.begin (), ArgVocab.end ());
288
344
289
345
return Error::success ();
290
346
}
@@ -304,7 +360,7 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304
360
IR2VecVocabAnalysis::Result
305
361
IR2VecVocabAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
306
362
auto Ctx = &M.getContext ();
307
- // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
363
+
308
364
// If vocabulary is already populated by the constructor, use it.
309
365
if (!Vocabulary.empty ())
310
366
return IR2VecVocabResult (std::move (Vocabulary));
@@ -323,16 +379,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323
379
}
324
380
325
381
// ==----------------------------------------------------------------------===//
326
- // IR2VecPrinterPass
382
+ // Printer Passes
327
383
// ===----------------------------------------------------------------------===//
328
384
329
- void IR2VecPrinterPass::printVector (const Embedding &Vec) const {
330
- OS << " [" ;
331
- for (const auto &Elem : Vec)
332
- OS << " " << format (" %.2f" , Elem) << " " ;
333
- OS << " ]\n " ;
334
- }
335
-
336
385
PreservedAnalyses IR2VecPrinterPass::run (Module &M,
337
386
ModuleAnalysisManager &MAM) {
338
387
auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
@@ -353,15 +402,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353
402
354
403
OS << " IR2Vec embeddings for function " << F.getName () << " :\n " ;
355
404
OS << " Function vector: " ;
356
- printVector ( Emb->getFunctionVector ());
405
+ Emb->getFunctionVector (). print (OS );
357
406
358
407
OS << " Basic block vectors:\n " ;
359
408
const auto &BBMap = Emb->getBBVecMap ();
360
409
for (const BasicBlock &BB : F) {
361
410
auto It = BBMap.find (&BB);
362
411
if (It != BBMap.end ()) {
363
412
OS << " Basic block: " << BB.getName () << " :\n " ;
364
- printVector ( It->second );
413
+ It->second . print (OS );
365
414
}
366
415
}
367
416
@@ -373,10 +422,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373
422
if (It != InstMap.end ()) {
374
423
OS << " Instruction: " ;
375
424
I.print (OS);
376
- printVector ( It->second );
425
+ It->second . print (OS );
377
426
}
378
427
}
379
428
}
380
429
}
381
430
return PreservedAnalyses::all ();
382
431
}
432
+
433
+ PreservedAnalyses IR2VecVocabPrinterPass::run (Module &M,
434
+ ModuleAnalysisManager &MAM) {
435
+ auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
436
+ assert (IR2VecVocabResult.isValid () && " IR2Vec Vocabulary is invalid" );
437
+
438
+ auto Vocab = IR2VecVocabResult.getVocabulary ();
439
+ for (const auto &Entry : Vocab) {
440
+ OS << " Key: " << Entry.first << " : " ;
441
+ Entry.second .print (OS);
442
+ }
443
+
444
+ return PreservedAnalyses::all ();
445
+ }
0 commit comments