31
31
#include " programl/graph/features.h"
32
32
#include " programl/ir/llvm/internal/text_encoder.h"
33
33
#include " programl/proto/program_graph.pb.h"
34
+ #include " program_graph_builder.h"
34
35
35
36
using labm8::Status;
36
37
@@ -39,6 +40,11 @@ namespace ir {
39
40
namespace llvm {
40
41
namespace internal {
41
42
43
+ template <typename T>
44
+ void AddFullTextFeature (T* element, const std::string& fullText) {
45
+ graph::AddScalarFeature (element, " full_text" , fullText);
46
+ }
47
+
42
48
labm8::StatusOr<BasicBlockEntryExit> ProgramGraphBuilder::VisitBasicBlock (
43
49
const ::llvm::BasicBlock& block, const Function* functionMessage,
44
50
InstructionMap* instructions, ArgumentConsumerMap* argumentConsumers,
@@ -194,7 +200,7 @@ labm8::StatusOr<FunctionEntryExits> ProgramGraphBuilder::VisitFunction(
194
200
195
201
if (function.isDeclaration ()) {
196
202
Node* node = AddInstruction (" ; undefined function" , functionMessage);
197
- graph::AddScalarFeature (node, " full_text " , " " );
203
+ AddFullTextFeature (node, " " );
198
204
functionEntryExits.first = node;
199
205
functionEntryExits.second .push_back (node);
200
206
return functionEntryExits;
@@ -325,7 +331,7 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(
325
331
const LlvmTextComponents text = textEncoder_.Encode (instruction);
326
332
Node* node = AddInstruction (text.opcode_name , function);
327
333
node->set_block (blockCount_);
328
- graph::AddScalarFeature (node, " full_text " , text.text );
334
+ AddFullTextFeature (node, text.text );
329
335
330
336
// Add profiling information features, if available.
331
337
uint64_t profTotalWeight;
@@ -345,29 +351,88 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(
345
351
Node* ProgramGraphBuilder::AddLlvmVariable (const ::llvm::Instruction* operand,
346
352
const programl::Function* function) {
347
353
const LlvmTextComponents text = textEncoder_.Encode (operand);
348
- Node* node = AddVariable (text. lhs_type , function);
354
+ Node* node = AddVariable (" var " , function);
349
355
node->set_block (blockCount_);
350
- graph::AddScalarFeature (node, " full_text" , text.lhs );
351
-
356
+ AddFullTextFeature (node, text.lhs );
352
357
return node;
353
358
}
354
359
355
360
Node* ProgramGraphBuilder::AddLlvmVariable (const ::llvm::Argument* argument,
356
361
const programl::Function* function) {
357
362
const LlvmTextComponents text = textEncoder_.Encode (argument);
358
- Node* node = AddVariable (text. lhs_type , function);
363
+ Node* node = AddVariable (" var " , function);
359
364
node->set_block (blockCount_);
360
- graph::AddScalarFeature (node, " full_text" , text.lhs );
365
+ AddFullTextFeature (node, text.lhs );
366
+
367
+ Node* type = GetOrCreateType (operand->getType ());
368
+ AddTypeEdge (/* position=*/ 0 , type, node);
361
369
362
370
return node;
363
371
}
364
372
365
373
Node* ProgramGraphBuilder::AddLlvmConstant (const ::llvm::Constant* constant) {
366
374
const LlvmTextComponents text = textEncoder_.Encode (constant);
367
- Node* node = AddConstant (text. lhs_type );
375
+ Node* node = AddConstant (" const " );
368
376
node->set_block (blockCount_);
369
- graph::AddScalarFeature (node, " full_text " , text.text );
377
+ AddFullTextFeature (node, text.text );
370
378
379
+ Node* type = GetOrCreateType (operand->getType ());
380
+ AddTypeEdge (/* position=*/ 0 , type, node);
381
+ return node;
382
+ }
383
+
384
+ Node* ProgramGraphBuilder::AddLlvmType (const ::llvm::Type* type) {
385
+ const LlvmTextComponents text = textEncoder_.Encode (constant);
386
+ Node* node = AddType (text.ls_type );
387
+ AddFullTextFeature (node, text.lhs );
388
+ return node;
389
+ }
390
+
391
+ Node* ProgramGraphBuilder::AddLlvmType (const ::llvm::StructType* type) {
392
+ Node* node = AddType (" struct" );
393
+ AddFullTextFeature (node, type->hasName () ? type->getName () : " struct" );
394
+
395
+ // Add types for the struct elements, and type edges.
396
+ for (int i = 0 ; i < type->getNumElements (); ++i) {
397
+ const auto & member = type->elements ()[i];
398
+ // Re-use the type if it already exists to prevent duplication of member types.
399
+ auto memberNode = GetOrCreateType (member);
400
+ AddTypeEdge (/* position=*/ i, memberNode, node);
401
+ }
402
+
403
+ return node;
404
+ }
405
+
406
+ Node* ProgramGraphBuilder::AddLlvmType (const ::llvm::PointerType* type) {
407
+ Node* node = AddType (" *" );
408
+ AddFullTextFeature (node, textEncoder_.Encode (constant).lhs );
409
+ // Re-use the type if it already exists to prevent duplication of element types.
410
+ auto elementType = GetOrCreateType (type->getElementType ());
411
+ AddTypeEdge (/* position=*/ 0 ,, elementType, node);
412
+ return node;
413
+ }
414
+
415
+ Node* ProgramGraphBuilder::AddLlvmType (const ::llvm::FunctionType* type) {
416
+ Node* node = AddType (" fn" );
417
+ AddFullTextFeature (node, textEncoder_.Encode (constant).lhs );
418
+ return node;
419
+ }
420
+
421
+ Node* ProgramGraphBuilder::AddLlvmType (const ::llvm::ArrayType* type) {
422
+ Node* node = AddType (" []" );
423
+ AddFullTextFeature (node, textEncoder_.Encode (constant).lhs );
424
+ // Re-use the type if it already exists to prevent duplication of element types.
425
+ auto elementType = GetOrCreateType (type->getElementType ());
426
+ AddTypeEdge (/* position=*/ 0 , elementType, node);
427
+ return node;
428
+ }
429
+
430
+ Node* ProgramGraphBuilder::AddLlvmType (const ::llvm::VectorType* type) {
431
+ Node* node = AddType (" vector" );
432
+ AddFullTextFeature (node, textEncoder_.Encode (constant).lhs );
433
+ // Re-use the type if it already exists to prevent duplication of element types.
434
+ auto elementType = GetOrCreateType (type->getElementType ());
435
+ AddTypeEdge (/* position=*/ 0 , elementType, node);
371
436
return node;
372
437
}
373
438
@@ -465,6 +530,16 @@ void ProgramGraphBuilder::Clear() {
465
530
programl::graph::ProgramGraphBuilder::Clear ();
466
531
}
467
532
533
+ Node* ProgramGraphBuilder::GetOrCreateType (const ::llvm::Type* type) {
534
+ auto it = types_.find (type);
535
+ if (it == types_.end ()) {
536
+ Node* node = AddLlvmType (type);
537
+ types_[type] = node;
538
+ return node;
539
+ }
540
+ return it->second ;
541
+ }
542
+
468
543
} // namespace internal
469
544
} // namespace llvm
470
545
} // namespace ir
0 commit comments