66"""
77
88import json
9+ import os
910import subprocess as sp
1011import sys
1112from dataclasses import dataclass
1213from glob import glob , iglob
1314from inspect import cleandoc
1415from os import getenv
1516from pathlib import Path
16- from typing import TypedDict
17+ from typing import TypedDict , Self
1718
1819USAGE = cleandoc (
1920 """
5152ARTIFACT_GLOB = "baseline-icount*"
5253# Place this in a PR body to skip regression checks (must be at the start of a line).
5354REGRESSION_DIRECTIVE = "ci: allow-regressions"
55+ # Place this in a PR body to skip extensive tests
56+ SKIP_EXTENSIVE_DIRECTIVE = "ci: skip-extensive"
5457
5558# Don't run exhaustive tests if these files change, even if they contaiin a function
5659# definition.
@@ -68,6 +71,39 @@ def eprint(*args, **kwargs):
6871 print (* args , file = sys .stderr , ** kwargs )
6972
7073
74+ @dataclass
75+ class PrInfo :
76+ """GitHub response for PR query"""
77+
78+ body : str
79+ commits : list [str ]
80+ created_at : str
81+ number : int
82+
83+ @classmethod
84+ def load (cls , pr_number : int | str ) -> Self :
85+ """For a given PR number, query the body and commit list"""
86+ pr_info = sp .check_output (
87+ [
88+ "gh" ,
89+ "pr" ,
90+ "view" ,
91+ str (pr_number ),
92+ "--json=number,commits,body,createdAt" ,
93+ # Flatten the commit list to only hashes, change a key to snake naming
94+ "--jq=.commits |= map(.oid) | .created_at = .createdAt | del(.createdAt)" ,
95+ ],
96+ text = True ,
97+ )
98+ eprint ("PR info:" , json .dumps (pr_info , indent = 4 ))
99+ return cls (** json .loads (pr_info ))
100+
101+ def contains_directive (self , directive : str ) -> bool :
102+ """Return true if the provided directive is on a line in the PR body"""
103+ lines = self .body .splitlines ()
104+ return any (line .startswith (directive ) for line in lines )
105+
106+
71107class FunctionDef (TypedDict ):
72108 """Type for an entry in `function-definitions.json`"""
73109
@@ -149,7 +185,7 @@ def changed_routines(self) -> dict[str, list[str]]:
149185 eprint (f"changed files for { name } : { changed } " )
150186 routines .add (name )
151187
152- ret = {}
188+ ret : dict [ str , list [ str ]] = {}
153189 for r in sorted (routines ):
154190 ret .setdefault (self .defs [r ]["type" ], []).append (r )
155191
@@ -159,13 +195,27 @@ def make_workflow_output(self) -> str:
159195 """Create a JSON object a list items for each type's changed files, if any
160196 did change, and the routines that were affected by the change.
161197 """
198+
199+ pr_number = os .environ .get ("PR_NUMBER" )
200+ skip_tests = False
201+
202+ if pr_number is not None :
203+ pr = PrInfo .load (pr_number )
204+ skip_tests = pr .contains_directive (SKIP_EXTENSIVE_DIRECTIVE )
205+
206+ if skip_tests :
207+ eprint ("Skipping all extensive tests" )
208+
162209 changed = self .changed_routines ()
163210 ret = []
164211 for ty in TYPES :
165212 ty_changed = changed .get (ty , [])
213+ changed_str = "," .join (ty_changed )
214+
166215 item = {
167216 "ty" : ty ,
168- "changed" : "," .join (ty_changed ),
217+ "changed" : changed_str ,
218+ "to_test" : "" if skip_tests else changed_str ,
169219 }
170220 ret .append (item )
171221 output = json .dumps ({"matrix" : ret }, separators = ("," , ":" ))
@@ -266,13 +316,13 @@ def check_iai_regressions(args: list[str]):
266316 found.
267317 """
268318
269- iai_home = "iai-home"
270- pr_number = False
319+ iai_home_str = "iai-home"
320+ pr_number = None
271321
272322 while len (args ) > 0 :
273323 match args :
274324 case ["--home" , home , * rest ]:
275- iai_home = home
325+ iai_home_str = home
276326 args = rest
277327 case ["--allow-pr-override" , pr_num , * rest ]:
278328 pr_number = pr_num
@@ -281,18 +331,20 @@ def check_iai_regressions(args: list[str]):
281331 eprint (USAGE )
282332 exit (1 )
283333
284- iai_home = Path (iai_home )
334+ iai_home = Path (iai_home_str )
285335
286336 found_summaries = False
287- regressions = []
337+ regressions : list [ dict ] = []
288338 for summary_path in iglob ("**/summary.json" , root_dir = iai_home , recursive = True ):
289339 found_summaries = True
290340 with open (iai_home / summary_path , "r" ) as f :
291341 summary = json .load (f )
292342
293343 summary_regs = []
294344 run = summary ["callgrind_summary" ]["callgrind_run" ]
295- name_entry = {"name" : f"{ summary ["function_name" ]} .{ summary ["id" ]} " }
345+ fname = summary ["function_name" ]
346+ id = summary ["id" ]
347+ name_entry = {"name" : f"{ fname } .{ id } " }
296348
297349 for segment in run ["segments" ]:
298350 summary_regs .extend (segment ["regressions" ])
@@ -312,22 +364,8 @@ def check_iai_regressions(args: list[str]):
312364 eprint ("Found regressions:" , json .dumps (regressions , indent = 4 ))
313365
314366 if pr_number is not None :
315- pr_info = sp .check_output (
316- [
317- "gh" ,
318- "pr" ,
319- "view" ,
320- str (pr_number ),
321- "--json=number,commits,body,createdAt" ,
322- "--jq=.commits |= map(.oid)" ,
323- ],
324- text = True ,
325- )
326- pr = json .loads (pr_info )
327- eprint ("PR info:" , json .dumps (pr , indent = 4 ))
328-
329- lines = pr ["body" ].splitlines ()
330- if any (line .startswith (REGRESSION_DIRECTIVE ) for line in lines ):
367+ pr = PrInfo .load (pr_number )
368+ if pr .contains_directive (REGRESSION_DIRECTIVE ):
331369 eprint ("PR allows regressions, returning" )
332370 return
333371
0 commit comments