Skip to content

Commit de62e23

Browse files
Working loop interchange, initial checking
1 parent 9d493c3 commit de62e23

File tree

4 files changed

+82
-1
lines changed

4 files changed

+82
-1
lines changed

kerngen/const/options.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Module for defining constants and enums used in the kernel generator"""
5+
from enum import Enum
6+
7+
8+
class LoopKey(Enum):
9+
"""Sort keys for PIsaOp instructions"""
10+
11+
RNS = "rns"
12+
PART = "part"
13+
UNIT = "unit"
14+
15+
@classmethod
16+
def from_str(cls, value: str) -> "LoopKey":
17+
"""Convert a string to a LoopKey enum"""
18+
if value is None:
19+
raise ValueError("LoopKey cannot be None")
20+
try:
21+
return cls[value.upper()]
22+
except KeyError:
23+
raise ValueError(f"Invalid LoopKey: {value}") from None
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0

kerngen/kernel_optimization/loops.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Module for loop interchange optimization in P-ISA operations"""
5+
6+
import re
7+
from const.options import LoopKey
8+
from high_parser.pisa_operations import PIsaOp, Comment
9+
10+
11+
def loop_interchange(
12+
pisa_list: list[PIsaOp],
13+
primary_key: LoopKey | None = LoopKey.PART,
14+
secondary_key: LoopKey | None = LoopKey.RNS,
15+
) -> list[PIsaOp]:
16+
"""Batch pisa_list into groups and sort them by primary and optional secondary keys.
17+
18+
Args:
19+
pisa_list: List of PIsaOp instructions
20+
primary_key: Primary sort criterion from SortKey enum
21+
secondary_key: Optional secondary sort criterion from SortKey enum
22+
23+
Returns:
24+
List of processed PIsaOp instructions
25+
26+
Raises:
27+
ValueError: If invalid sort key values provided
28+
"""
29+
if primary_key is None and secondary_key is None:
30+
return pisa_list
31+
32+
def get_sort_value(pisa: PIsaOp, key: LoopKey) -> int:
33+
match key:
34+
case LoopKey.RNS:
35+
return pisa.q
36+
case LoopKey.PART:
37+
match = re.search(r"_(\d+)_", str(pisa))
38+
return int(match[1]) if match else 0
39+
case LoopKey.UNIT:
40+
match = re.search(r"_(\d+),", str(pisa))
41+
return int(match[1]) if match else 0
42+
case _:
43+
raise ValueError(f"Invalid sort key value: {key}")
44+
45+
def get_sort_key(pisa: PIsaOp) -> tuple:
46+
primary_value = get_sort_value(pisa, primary_key)
47+
if secondary_key:
48+
secondary_value = get_sort_value(pisa, secondary_key)
49+
return (primary_value, secondary_value)
50+
return (primary_value,)
51+
52+
# Filter out comments
53+
pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)]
54+
# Sort based on primary and optional secondary keys
55+
pisa_list_wo_comments.sort(key=get_sort_key)
56+
return pisa_list_wo_comments
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Copyright (C) 2024 Intel Corporation
2-
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-License-Identifier: Apache-2.0

0 commit comments

Comments
 (0)