Skip to content

Commit b0be8c4

Browse files
committed
Begin updating to SS:GB 6.2.0; static inline still WIP
1 parent 9585e0c commit b0be8c4

File tree

6 files changed

+853
-5
lines changed

6 files changed

+853
-5
lines changed

suitesparse_graphblas/build.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
ffibuilder = FFI()
1010

11-
with open(os.path.join(thisdir, "source.c")) as f:
11+
if is_win:
12+
source_filename = "source_no_complex.c"
13+
else:
14+
source_filename = "source.c"
15+
16+
with open(os.path.join(thisdir, source_filename)) as f:
1217
source = f.read()
1318

1419
include_dirs = [os.path.join(sys.prefix, "include")]

suitesparse_graphblas/create_headers.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def groupby(index, seq):
315315
"CMPLX",
316316
"CMPLXF",
317317
"GB_PUBLIC",
318+
"GB_restrict",
318319
"GRAPHBLAS_H",
319320
"GrB_INVALID_HANDLE",
320321
"GrB_NULL",
@@ -426,6 +427,23 @@ def get_groups(ast):
426427
seen.update(vals)
427428
groups["GxB typedef funcs"] = sorted(vals, key=sort_key)
428429

430+
vals = []
431+
next_i = -1
432+
for i, line in enumerate(lines):
433+
if i < next_i or line in seen:
434+
continue
435+
if "inline static" in line and ("GB" in line or "GrB" in line or "GxB" in line):
436+
val = [line]
437+
i += 1
438+
while lines[i] != "}":
439+
val.append(lines[i])
440+
i += 1
441+
val.append(lines[i])
442+
next_i = i + 1
443+
seen.update(val)
444+
vals.append("\n".join(val))
445+
groups["static inline"] = vals
446+
429447
vals = {x for x in lines if "typedef" in x and "GrB" in x} - seen
430448
assert not vals, ", ".join(sorted(vals))
431449
groups["not seen"] = sorted(set(lines) - seen, key=sort_key)
@@ -579,6 +597,19 @@ def visit_Decl(self, node):
579597
if isinstance(node.type, c_ast.FuncDecl) and node.storage == ["extern"]:
580598
self.functions.append(node)
581599

600+
class FuncDefVisitorStaticInline(c_ast.NodeVisitor):
601+
def __init__(self):
602+
self.functions = []
603+
604+
def visit_FuncDef(self, node):
605+
decl = node.decl
606+
if (
607+
isinstance(decl.type, c_ast.FuncDecl)
608+
and decl.storage == ["static"]
609+
and decl.funcspec == ["inline"]
610+
):
611+
self.functions.append(node)
612+
582613
def handle_function_node(node):
583614
if generator.visit(node.type.type) != "GrB_Info":
584615
raise ValueError(generator.visit(node))
@@ -599,6 +630,7 @@ def handle_function_node(node):
599630
group = {
600631
# Apply our naming scheme
601632
"GrB_Matrix": "matrix",
633+
"Matrix": "matrix",
602634
"GrB_Vector": "vector",
603635
"GxB_Scalar": "scalar",
604636
"SelectOp": "selectop",
@@ -610,6 +642,7 @@ def handle_function_node(node):
610642
"Type": "type",
611643
"UnaryOp": "unary",
612644
"IndexUnaryOp": "indexunary",
645+
"Iterator": "iterator",
613646
# "everything else" is "core"
614647
"getVersion": "core",
615648
"Global": "core",
@@ -636,16 +669,42 @@ def handle_function_node(node):
636669
assert len(gxb_nodes) == len(groups["GxB methods"])
637670
assert len(gb_nodes) == len(groups["GB methods"])
638671

672+
visitor = FuncDefVisitorStaticInline()
673+
visitor.visit(ast)
674+
static_inline_nodes = visitor.functions
675+
assert len(static_inline_nodes) == len(groups["static inline"])
676+
for node in static_inline_nodes:
677+
# Sanity check
678+
text = generator.visit(node).strip()
679+
assert text in groups["static inline"]
680+
681+
def handle_static_inline(node):
682+
decl = node.decl
683+
if decl.name in DEPRECATED:
684+
return
685+
text = generator.visit(node).strip()
686+
if skip_complex and has_complex(text):
687+
return
688+
return {
689+
"name": decl.name,
690+
"group": "static inline",
691+
"node": node,
692+
"text": text + "\n",
693+
}
694+
639695
grb_funcs = (handle_function_node(node) for node in grb_nodes)
640696
gxb_funcs = (handle_function_node(node) for node in gxb_nodes)
641697
gb_funcs = (handle_function_node(node) for node in gb_nodes)
698+
si_funcs = (handle_static_inline(node) for node in static_inline_nodes)
642699
grb_funcs = [x for x in grb_funcs if x is not None]
643700
gxb_funcs = [x for x in gxb_funcs if x is not None]
644701
gb_funcs = [x for x in gb_funcs if x is not None]
702+
si_funcs = [x for x in si_funcs if x is not None]
645703

646704
rv["GrB methods"] = sorted(grb_funcs, key=lambda x: sort_key(x["text"]))
647705
rv["GxB methods"] = sorted(gxb_funcs, key=lambda x: sort_key(x["text"]))
648706
rv["GB methods"] = sorted(gb_funcs, key=lambda x: sort_key(x["text"]))
707+
rv["static inline"] = si_funcs # Should we sort these?
649708
for key in groups.keys() - rv.keys():
650709
rv[key] = groups[key]
651710
return rv
@@ -732,6 +791,10 @@ def handle_funcs(group):
732791
text.append("****************/")
733792
text.extend(handle_funcs(groups["GxB methods"]))
734793

794+
# Cython doesn't like compiling this; add to source.c instead (may work?)
795+
# text.append("")
796+
# text.extend(handle_funcs(groups["static inline"]))
797+
735798
text.append("")
736799
text.append("/* int DEFINES */")
737800
for item in sorted(defines, key=sort_key):
@@ -744,7 +807,7 @@ def handle_funcs(group):
744807
return text
745808

746809

747-
def create_source_text(*, char_defines=None):
810+
def create_source_text(groups, *, char_defines=None):
748811
if char_defines is None:
749812
char_defines = CHAR_DEFINES
750813
text = [
@@ -753,6 +816,9 @@ def create_source_text(*, char_defines=None):
753816
]
754817
for item in sorted(char_defines, key=sort_key):
755818
text.append(f"char *{item}_STR = {item};")
819+
text.append("")
820+
for node in groups["static inline"]:
821+
text.append(node["text"])
756822
return text
757823

758824

@@ -780,6 +846,7 @@ def main():
780846
final_h = os.path.join(thisdir, "suitesparse_graphblas.h")
781847
final_no_complex_h = os.path.join(thisdir, "suitesparse_graphblas_no_complex.h")
782848
source_c = os.path.join(thisdir, "source.c")
849+
source_no_complex_c = os.path.join(thisdir, "source_no_complex.c")
783850

784851
# Copy original file
785852
print(f"Step 1: copy {args.graphblas} to {graphblas_h}")
@@ -814,12 +881,18 @@ def main():
814881

815882
# Create source
816883
print(f"Step 5: create {source_c}")
817-
text = create_source_text()
884+
text = create_source_text(groups)
818885
with open(source_c, "w") as f:
819886
f.write("\n".join(text))
820887

888+
# Create source (no complex)
889+
print(f"Step 6: create {source_no_complex_c}")
890+
text = create_source_text(groups_no_complex)
891+
with open(source_no_complex_c, "w") as f:
892+
f.write("\n".join(text))
893+
821894
# Check defines
822-
print("Step 6: check #define definitions")
895+
print("Step 7: check #define definitions")
823896
with open(graphblas_h) as f:
824897
text = f.read()
825898
define_lines = re.compile(r".*?#define\s+\w+\s+")

0 commit comments

Comments
 (0)