Skip to content

Commit b3eb8ff

Browse files
Experimental - Kerngraph (loop interchange) Initial Commit
Adds a new python script called "kerngraph" which will be used for kernel graph representation and optimization of kernel output. Currently, kerngraph supports loop interchange for part/rns for non-composable kernels such as add, sub, mul, or muli. Future PR will include ntt/intt, mod, relin, etc. Kerngraph is currently experimental and is subject to change in the near future.
1 parent e84776b commit b3eb8ff

File tree

9 files changed

+622
-1
lines changed

9 files changed

+622
-1
lines changed

p-isa_tools/kerngen/const/options.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Module for defining constants and enums used in the kernel generator"""
5+
from enum import Enum
6+
7+
8+
class LoopKey(Enum):
9+
"""Sort keys for PIsaOp instructions"""
10+
11+
RNS = "rns"
12+
PART = "part"
13+
UNIT = "unit"
14+
15+
@classmethod
16+
def from_str(cls, value: str) -> "LoopKey":
17+
"""Convert a string to a LoopKey enum"""
18+
if value is None:
19+
raise ValueError("LoopKey cannot be None")
20+
try:
21+
return cls[value.upper()]
22+
except KeyError:
23+
raise ValueError(f"Invalid LoopKey: {value}") from None

p-isa_tools/kerngen/high_parser/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __call__(self, *args) -> str:
4646
return self.expand(*args)
4747

4848
def __repr__(self) -> str:
49-
return self.name
49+
return f"Polys(name={self.name}, parts={self.parts}, rns={self.rns})"
5050

5151
@classmethod
5252
def from_polys(cls, poly: "Polys", *, mode: str | None = None) -> "Polys":
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Module for loop interchange optimization in P-ISA operations"""
5+
6+
import re
7+
from const.options import LoopKey
8+
from high_parser.pisa_operations import PIsaOp, Comment
9+
10+
11+
def loop_interchange(
12+
pisa_list: list[PIsaOp],
13+
primary_key: LoopKey | None = LoopKey.PART,
14+
secondary_key: LoopKey | None = LoopKey.RNS,
15+
) -> list[PIsaOp]:
16+
"""Batch pisa_list into groups and sort them by primary and optional secondary keys.
17+
18+
Args:
19+
pisa_list: List of PIsaOp instructions
20+
primary_key: Primary sort criterion from SortKey enum
21+
secondary_key: Optional secondary sort criterion from SortKey enum
22+
23+
Returns:
24+
List of processed PIsaOp instructions
25+
26+
Raises:
27+
ValueError: If invalid sort key values provided
28+
"""
29+
if primary_key is None and secondary_key is None:
30+
return pisa_list
31+
32+
def get_sort_value(pisa: PIsaOp, key: LoopKey) -> int:
33+
match key:
34+
case LoopKey.RNS:
35+
return pisa.q
36+
case LoopKey.PART:
37+
match = re.search(r"_(\d+)_", str(pisa))
38+
return int(match[1]) if match else 0
39+
case LoopKey.UNIT:
40+
match = re.search(r"_(\d+),", str(pisa))
41+
return int(match[1]) if match else 0
42+
case _:
43+
raise ValueError(f"Invalid sort key value: {key}")
44+
45+
def get_sort_key(pisa: PIsaOp) -> tuple:
46+
primary_value = get_sort_value(pisa, primary_key)
47+
if secondary_key:
48+
secondary_value = get_sort_value(pisa, secondary_key)
49+
return (primary_value, secondary_value)
50+
return (primary_value,)
51+
52+
# Filter out comments
53+
pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)]
54+
# Sort based on primary and optional secondary keys
55+
pisa_list_wo_comments.sort(key=get_sort_key)
56+
return pisa_list_wo_comments
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Module for parsing kernel commands from Kerngen"""
5+
6+
import re
7+
from high_parser.types import Immediate, KernelContext, Polys, Context
8+
from pisa_generators.basic import Copy, HighOp, Add, Sub, Mul, Muli
9+
from pisa_generators.ntt import NTT, INTT
10+
from pisa_generators.square import Square
11+
from pisa_generators.relin import Relin
12+
from pisa_generators.rotate import Rotate
13+
from pisa_generators.mod import Mod, ModUp
14+
from pisa_generators.rescale import Rescale
15+
16+
17+
class KernelParser:
18+
"""Parser for kernel operations."""
19+
20+
high_op_map = {
21+
"Add": Add,
22+
"Mul": Mul,
23+
"Muli": Muli,
24+
"Copy": Copy,
25+
"Sub": Sub,
26+
"Square": Square,
27+
"NTT": NTT,
28+
"INTT": INTT,
29+
"Mod": Mod,
30+
"ModUp": ModUp,
31+
"Relin": Relin,
32+
"Rotate": Rotate,
33+
"Rescale": Rescale,
34+
}
35+
36+
@staticmethod
37+
def parse_context(context_str: str) -> KernelContext:
38+
"""Parse the context string and return a KernelContext object."""
39+
context_match = re.search(
40+
r"KernelContext\(scheme='(?P<scheme>\w+)', "
41+
+ r"poly_order=(?P<poly_order>\w+), key_rns=(?P<key_rns>\w+), "
42+
r"current_rns=(?P<current_rns>\w+), .*? label='(?P<label>\w+)'\)",
43+
context_str,
44+
)
45+
if not context_match:
46+
raise ValueError("Invalid context string format.")
47+
return KernelContext.from_context(
48+
Context(
49+
scheme=context_match.group("scheme"),
50+
poly_order=int(context_match.group("poly_order")),
51+
key_rns=int(context_match.group("key_rns")),
52+
current_rns=int(context_match.group("current_rns")),
53+
max_rns=int(context_match.group("key_rns")) - 1,
54+
),
55+
label=context_match.group("label"),
56+
)
57+
58+
@staticmethod
59+
def parse_polys(polys_str: str) -> Polys:
60+
"""Parse the Polys string and return a Polys object."""
61+
polys_match = re.search(
62+
r"Polys\(name=(.*?), parts=(\d+), rns=(\d+)\)", polys_str
63+
)
64+
if not polys_match:
65+
raise ValueError("Invalid Polys string format.")
66+
name, parts, rns = polys_match.groups()
67+
return Polys(name=name, parts=int(parts), rns=int(rns))
68+
69+
@staticmethod
70+
def parse_immediate(immediate_str: str) -> Immediate:
71+
"""Parse the Immediate string and return an Immediate object."""
72+
immediate_match = re.search(
73+
r"Immediate\(name='(?P<name>\w+)', rns=(?P<rns>\w+)\)", immediate_str
74+
)
75+
if not immediate_match:
76+
raise ValueError("Invalid Immediate string format.")
77+
name, rns = immediate_match.group("name"), immediate_match.group("rns")
78+
rns = None if rns == "None" else int(rns)
79+
return Immediate(name=name, rns=rns)
80+
81+
@staticmethod
82+
def parse_high_op(kernel_str: str) -> HighOp:
83+
"""Parse a HighOp kernel string and return the corresponding object."""
84+
pattern = (
85+
r"### Kernel \(\d+\): (?P<op_type>\w+)\(context=(KernelContext\(.*?\)), "
86+
r"output=(Polys\(.*?\)), input0=(Polys\(.*?\))"
87+
)
88+
has_second_input = False
89+
# Check if the kernel string contains "input1" or not
90+
if "input1" not in kernel_str:
91+
# Match the operation type and its arguments
92+
high_op_match = re.search(pattern, kernel_str)
93+
else:
94+
# Adjust the pattern to include input1
95+
pattern += r", input1=(Polys\(.*?\)\)|Immediate\(.*?\)\))"
96+
# Match the operation type and its arguments
97+
high_op_match = re.search(pattern, kernel_str)
98+
has_second_input = True
99+
100+
if not high_op_match:
101+
raise ValueError(f"Invalid kernel string format: {kernel_str}.")
102+
103+
op_type = high_op_match.group("op_type")
104+
context_str, output_str, input0_str = high_op_match.groups()[1:4]
105+
106+
if has_second_input:
107+
input1_str = high_op_match.group(5)
108+
109+
# Parse the components
110+
context = KernelParser.parse_context(context_str)
111+
output = KernelParser.parse_polys(output_str)
112+
input0 = KernelParser.parse_polys(input0_str)
113+
if has_second_input:
114+
if op_type == "Muli":
115+
input1 = KernelParser.parse_immediate(input1_str)
116+
else:
117+
# For other operations, parse as Polys
118+
input1 = KernelParser.parse_polys(input1_str)
119+
120+
if op_type not in KernelParser.high_op_map:
121+
raise ValueError(f"Unsupported HighOp type: {op_type}")
122+
123+
# Instantiate the HighOp object
124+
if has_second_input:
125+
return KernelParser.high_op_map[op_type](
126+
context=context, output=output, input0=input0, input1=input1
127+
)
128+
# For operations without a second input, we can ignore the input1 parameter
129+
return KernelParser.high_op_map[op_type](
130+
context=context, output=output, input0=input0
131+
)
132+
133+
@staticmethod
134+
def parse_kernel(kernel_str: str) -> HighOp:
135+
"""Parse a kernel string and return the corresponding HighOp object."""
136+
return KernelParser.parse_high_op(kernel_str)

p-isa_tools/kerngen/kerngraph.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#! /usr/bin/env python3
2+
# Copyright (C) 2024 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
kerngraph.py
7+
8+
This script provides a command-line tool for parsing kernel strings from standard input using the KernelParser class.
9+
Future improvements may include graph representation of the parsed kernels and optimization.
10+
11+
Functions:
12+
parse_args():
13+
Parses command-line arguments.
14+
Returns:
15+
argparse.Namespace: Parsed arguments including debug flag.
16+
17+
main(args):
18+
Reads lines from standard input, parses each line as a kernel string using KernelParser,
19+
and prints the successfully parsed kernel objects. If parsing fails for a line, an error
20+
message is printed if debug mode is enabled.
21+
22+
Usage:
23+
Run the script and provide kernel strings via standard input. Use the '-d' or '--debug' flag
24+
to enable debug output for parsing errors.
25+
26+
Example:
27+
$ cat bgv.add.high | ./kerngen.py | ./kerngraph.py
28+
"""
29+
30+
31+
import argparse
32+
import sys
33+
from kernel_parser.parser import KernelParser
34+
from kernel_optimization.loops import loop_interchange
35+
from const.options import LoopKey
36+
from pisa_generators.basic import mixed_to_pisa_ops
37+
38+
39+
def parse_args():
40+
"""Parse arguments from the commandline"""
41+
parser = argparse.ArgumentParser(description="Kernel Graph Parser")
42+
parser.add_argument("-d", "--debug", action="store_true", help="Enable Debug Print")
43+
parser.add_argument(
44+
"-t",
45+
"--target",
46+
nargs="*",
47+
default=[],
48+
# Composition high ops such are ntt, mod, and relin are not currently supported
49+
choices=["add", "sub", "mul", "muli", "copy"], # currently supports single ops
50+
help="List of high_op names",
51+
)
52+
parser.add_argument(
53+
"-p",
54+
"--primary",
55+
type=LoopKey,
56+
default=LoopKey.PART,
57+
choices=list(LoopKey),
58+
help="Primary key for loop interchange (default: PART, options: RNS, PART))",
59+
)
60+
parser.add_argument(
61+
"-s",
62+
"--secondary",
63+
type=LoopKey,
64+
default=None,
65+
choices=list(LoopKey) + list([None]),
66+
help="Secondary key for loop interchange (default: None, Options: RNS, PART)",
67+
)
68+
parsed_args = parser.parse_args()
69+
# verify that primary and secondary keys are not the same
70+
if parsed_args.primary == parsed_args.secondary:
71+
raise ValueError("Primary and secondary keys cannot be the same.")
72+
return parser.parse_args()
73+
74+
75+
def main(args):
76+
"""Main function to read input and parse each line with KernelParser."""
77+
input_lines = sys.stdin.read().strip().splitlines()
78+
valid_kernels = []
79+
80+
for line in input_lines:
81+
try:
82+
kernel = KernelParser.parse_kernel(line)
83+
valid_kernels.append(kernel)
84+
except ValueError as e:
85+
if args.debug:
86+
print(f"Error parsing line: {line}\nReason: {e}")
87+
continue # Skip invalid lines
88+
89+
if not valid_kernels:
90+
print("No valid kernel strings were parsed.")
91+
else:
92+
if args.debug:
93+
print(
94+
f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}"
95+
)
96+
for kernel in valid_kernels:
97+
if args.target and any(
98+
target.capitalize() in str(kernel) for target in args.target
99+
):
100+
kernel = loop_interchange(
101+
kernel.to_pisa(),
102+
primary_key=args.primary,
103+
secondary_key=args.secondary,
104+
)
105+
for pisa in mixed_to_pisa_ops(kernel):
106+
print(pisa)
107+
else:
108+
for pisa in kernel.to_pisa():
109+
print(pisa)
110+
111+
112+
if __name__ == "__main__":
113+
cmdline_args = parse_args()
114+
main(cmdline_args)

0 commit comments

Comments
 (0)