Skip to content

Commit 702adeb

Browse files
committed
triplet-ext-script
1 parent 42f9479 commit 702adeb

File tree

2 files changed

+242
-0
lines changed

2 files changed

+242
-0
lines changed

llvm/docs/CommandGuide/llvm-ir2vec.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ embedding training (see
5050
<https://github.com/thunlp/OpenKE/tree/OpenKE-PyTorch?tab=readme-ov-file#data-format>
5151
for details).
5252

53+
See ``llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py`` for more details on how
54+
these two modes are used to generate the triplets and entity mappings.
55+
5356
Triplet Generation Mode
5457
~~~~~~~~~~~~~~~~~~~~~~~
5558

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
"""IR2Vec Triplet Generator
5+
6+
Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
7+
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
8+
files: entity2id.txt, relation2id.txt, and train2id.txt.
9+
10+
Usage:
11+
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
12+
"""
13+
14+
import argparse
15+
import logging
16+
import os
17+
import random
18+
import subprocess
19+
import sys
20+
from concurrent.futures import ThreadPoolExecutor, as_completed
21+
from pathlib import Path
22+
from typing import List, Set, Tuple
23+
24+
# Configuration

# Optimization levels that may be passed to `opt` (as `-<level>`); a random
# subset of these is applied to every input file.
OPT_LEVELS = [
    "O0",
    "O1",
    "O2",
    "O3",
    "Os",
    "Oz",
]

# Default size of the worker thread pool; overridable with -j/--max-workers.
DEFAULT_MAX_WORKERS = 100

# Module-level logger; handlers and level are configured in main().
logger = logging.getLogger(__name__)
29+
30+
31+
class TripletResult:
    """Outcome of processing one LLVM IR file.

    Attributes are restricted to the two slots below:
      triplets:     the unique triplet lines collected for the file.
      max_relation: the largest relation id reported for the file.
    """

    __slots__ = ["triplets", "max_relation"]

    def __init__(self, triplets: Set[str], max_relation: int):
        # Store both values as given; no copying or validation is performed.
        self.triplets, self.max_relation = triplets, max_relation
38+
39+
40+
class IR2VecTripletGenerator:
    """Generate IR2Vec training triplets from a list of LLVM IR files.

    Every input file is run through ``opt`` at ``num_optimizations`` randomly
    chosen optimization levels, and each optimized module is fed to
    ``llvm-ir2vec --mode=triplets``. The unique triplets of all files are
    written to ``train2id.txt``; ``entity2id.txt`` and ``relation2id.txt`` are
    generated next to it in ``output_dir``.
    """

    def __init__(
        self,
        llvm_build_dir: Path,
        num_optimizations: int,
        output_dir: Path,
        max_workers: int = DEFAULT_MAX_WORKERS,
    ):
        """Store the configuration and validate it.

        Args:
            llvm_build_dir: LLVM build directory containing ``bin/opt`` and
                ``bin/llvm-ir2vec``.
            num_optimizations: how many distinct optimization levels to apply
                to each file (1..len(OPT_LEVELS)).
            output_dir: directory for the generated files; created if missing.
            max_workers: size of the worker thread pool.

        Raises:
            FileNotFoundError: a required directory or binary is missing.
            ValueError: ``num_optimizations`` or ``max_workers`` is out of range.
        """
        # Coerce to Path so callers may also pass plain strings.
        self.llvm_build_dir = Path(llvm_build_dir)
        self.num_optimizations = num_optimizations
        self.output_dir = Path(output_dir)
        self.max_workers = max_workers

        # Tool paths
        self.opt_binary = os.path.join(self.llvm_build_dir, "bin", "opt")
        self.ir2vec_binary = os.path.join(self.llvm_build_dir, "bin", "llvm-ir2vec")

        self._validate_setup()

    def _validate_setup(self):
        """Validate tools and parameters, then create the output directory."""
        if not self.llvm_build_dir.exists():
            raise FileNotFoundError(
                f"LLVM build directory not found: {self.llvm_build_dir}"
            )

        if not os.path.isfile(self.opt_binary) or not os.access(
            self.opt_binary, os.X_OK
        ):
            raise FileNotFoundError(
                f"opt binary not found or not executable: {self.opt_binary}"
            )

        if not os.path.isfile(self.ir2vec_binary) or not os.access(
            self.ir2vec_binary, os.X_OK
        ):
            raise FileNotFoundError(
                f"llvm-ir2vec binary not found or not executable: {self.ir2vec_binary}"
            )

        if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
            raise ValueError(
                f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
            )

        # Fail early instead of when the thread pool is created.
        if self.max_workers < 1:
            raise ValueError("Number of workers must be at least 1")

        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _select_optimization_levels(self) -> List[str]:
        """Return ``num_optimizations`` distinct, randomly chosen levels."""
        return random.sample(OPT_LEVELS, self.num_optimizations)

    def _process_single_file(self, input_file: Path) -> TripletResult:
        """Collect the triplets of one file over several optimization levels.

        Failures at an individual optimization level are logged and skipped so
        that one bad level does not discard the results of the others.
        """
        all_triplets: Set[str] = set()
        max_relation = 1

        for opt_level in self._select_optimization_levels():
            try:
                triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
            except Exception as e:
                # Best effort per level: e.g. malformed tool output or an
                # OSError while spawning the tools.
                logger.warning(
                    "Error processing %s with %s: %s", input_file, opt_level, e
                )
                continue

            if triplets:
                all_triplets.update(triplets)
                max_relation = max(max_relation, file_max_relation)
                logger.debug(
                    "Generated %d triplets for %s with %s",
                    len(triplets),
                    input_file,
                    opt_level,
                )

        return TripletResult(all_triplets, max_relation)

    def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int]:
        """Run ``opt -<opt_level>`` on the file and pipe it into llvm-ir2vec.

        The two tools are started with argument lists (no shell), so paths
        containing quotes or shell metacharacters are passed through verbatim,
        and the exit status of *each* stage is checked.

        Returns:
            ``(triplets, max_relation)`` as produced by
            ``_parse_triplet_output``, or ``(set(), 1)`` if either tool exits
            with a non-zero status.
        """
        opt_cmd = [self.opt_binary, f"-{opt_level}", str(input_file), "-o", "-"]
        ir2vec_cmd = [self.ir2vec_binary, "--mode=triplets", "-", "-o", "-"]

        try:
            # Bytes mode: the output of opt is handed to llvm-ir2vec untouched.
            optimized = subprocess.run(opt_cmd, capture_output=True, check=True)
            result = subprocess.run(
                ir2vec_cmd, input=optimized.stdout, capture_output=True, check=True
            )
        except subprocess.CalledProcessError as err:
            stderr = (err.stderr or b"").decode(errors="replace").strip()
            logger.warning(
                "%s failed (exit %d) for %s with %s: %s",
                os.path.basename(err.cmd[0]),
                err.returncode,
                input_file,
                opt_level,
                stderr,
            )
            return set(), 1

        return self._parse_triplet_output(result.stdout.decode(errors="replace"))

    def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
        """Split tool output into unique triplet lines and the max relation.

        An optional first line of the form ``MAX_RELATION=<int>`` is consumed
        as metadata; every remaining non-blank line is one triplet.

        Returns:
            ``(triplets, max_relation)``; ``(set(), 1)`` for empty output.
            ``max_relation`` defaults to 1 when no metadata line is present.

        Raises:
            ValueError: the ``MAX_RELATION=`` value is not an integer.
        """
        # Strip each line and drop blanks so stray empty or CR-terminated
        # lines never end up as bogus triplets.
        lines = [stripped for line in output.splitlines() if (stripped := line.strip())]
        if not lines:
            return set(), 1

        max_relation = 1

        # Extract max relation from metadata line
        if lines[0].startswith("MAX_RELATION="):
            max_relation = int(lines[0].partition("=")[2])
            lines = lines[1:]

        # Remove duplicate triplets by converting to a set
        return set(lines), max_relation

    def generate_triplets(self, file_list: Path) -> None:
        """Process every file named in ``file_list`` and write the outputs.

        Files are processed concurrently in a thread pool of ``max_workers``
        threads; the merged triplets and the largest relation id seen are then
        passed to ``_generate_output_files``.

        Raises:
            ValueError: ``file_list`` names no existing file.
        """
        input_files = self._read_file_list(file_list)
        logger.info(
            "Processing %d files with %d optimization levels using %d workers",
            len(input_files),
            self.num_optimizations,
            self.max_workers,
        )

        all_triplets: Set[str] = set()
        global_max_relation = 1
        failed = 0

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {
                executor.submit(self._process_single_file, file): file
                for file in input_files
            }

            for future in as_completed(future_to_file):
                try:
                    result = future.result()
                except Exception as e:
                    failed += 1
                    logger.error("Error processing %s: %s", future_to_file[future], e)
                    continue
                all_triplets.update(result.triplets)
                global_max_relation = max(global_max_relation, result.max_relation)

        self._generate_output_files(all_triplets, global_max_relation)
        if failed:
            logger.warning("Processing completed with %d failed file(s)", failed)
        else:
            logger.info("Processing completed successfully")

    def _read_file_list(self, file_list: Path) -> List[Path]:
        """Read one path per line from ``file_list``, keeping existing files.

        Blank lines are ignored; paths that do not exist are logged with their
        line number and skipped.

        Raises:
            ValueError: no listed file exists.
        """
        input_files = []
        with open(file_list, "r") as f:
            for line_num, line in enumerate(f, 1):
                if not (line := line.strip()):
                    continue
                file_path = Path(line)
                if file_path.exists():
                    input_files.append(file_path)
                else:
                    logger.warning(
                        "File not found (line %d): %s", line_num, file_path
                    )

        if not input_files:
            raise ValueError("No valid input files found")
        return input_files

    def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> None:
        """Write train2id.txt, entity2id.txt and relation2id.txt.

        ``train2id.txt`` starts with the number of triplets followed by one
        triplet per line, written in sorted order so that repeated runs over
        the same triplets produce identical files.
        """
        logger.info(
            "Generating output files with %d unique triplets", len(all_triplets)
        )

        train2id_file = os.path.join(self.output_dir, "train2id.txt")
        entity2id_file = os.path.join(self.output_dir, "entity2id.txt")
        relation2id_file = os.path.join(self.output_dir, "relation2id.txt")

        with open(train2id_file, "w") as f:
            f.write(f"{len(all_triplets)}\n")
            f.writelines(f"{triplet}\n" for triplet in sorted(all_triplets))

        self._generate_entity2id(entity2id_file)
        self._generate_relation2id(relation2id_file, max_relation)

    def _generate_entity2id(self, output_file: Path) -> None:
        """Generate entity2id.txt by running ``llvm-ir2vec --mode=entities``.

        Raises:
            subprocess.CalledProcessError: the tool exits with non-zero status
                (its stderr is logged before re-raising).
        """
        try:
            subprocess.run(
                [str(self.ir2vec_binary), "--mode=entities", "-o", str(output_file)],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as err:
            # The exception text omits stderr, so surface it here.
            logger.error(
                "llvm-ir2vec --mode=entities failed: %s",
                (err.stderr or b"").decode(errors="replace").strip(),
            )
            raise

    def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
        """Write relation2id.txt for relation ids ``0..max_relation``.

        The first line is the number of relations; id 0 is ``Type``, id 1 is
        ``Next`` and every id ``i >= 2`` is named ``Arg<i-2>``.
        """
        max_relation = max(max_relation, 1)  # At least Type and Next relations
        num_relations = max_relation + 1

        with open(output_file, "w") as f:
            f.write(f"{num_relations}\n")
            f.write("Type\t0\n")
            f.write("Next\t1\n")
            f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
197+
198+
def main():
    """Command-line entry point.

    Parses the arguments, configures logging from -v/-q (quiet takes
    precedence over verbose), then builds an IR2VecTripletGenerator and runs
    it. Any exception is logged and turned into exit status 1.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate IR2Vec triplets from LLVM IR files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Positional arguments, in command-line order: (name, type, help).
    positionals = (
        ("llvm_build_dir", Path, "Path to LLVM build directory"),
        ("num_optimizations", int, "Number of optimization levels to apply (1-6)"),
        ("ll_file_list", Path, "File containing list of LLVM IR files to process"),
        ("output_dir", Path, "Output directory for generated files"),
    )
    for arg_name, arg_type, arg_help in positionals:
        arg_parser.add_argument(arg_name, type=arg_type, help=arg_help)

    arg_parser.add_argument(
        "-j",
        "--max-workers",
        type=int,
        default=DEFAULT_MAX_WORKERS,
        help=f"Maximum number of parallel workers (default: {DEFAULT_MAX_WORKERS})",
    )
    arg_parser.add_argument(
        "-v", "--verbose", action="store_true", help="Enable debug logging"
    )
    arg_parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Suppress all output except errors",
    )

    opts = arg_parser.parse_args()

    # Configure logging
    if opts.quiet:
        log_level = logging.ERROR
    elif opts.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        level=log_level,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        datefmt='%H:%M:%S',
    )

    try:
        IR2VecTripletGenerator(
            opts.llvm_build_dir,
            opts.num_optimizations,
            opts.output_dir,
            opts.max_workers,
        ).generate_triplets(opts.ll_file_list)
    except Exception as exc:
        logger.error(f"Error: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)