88import multiprocessing as mp
99import multiprocessing .connection as mpc
1010from dataclasses import dataclass
11- from typing import Any
11+ from typing import Any , Union
1212
1313import sqlalchemy
1414
@@ -247,32 +247,20 @@ def run_benchmark(
247247 results .error_msg = "No framework"
248248 return (results , {})
249249
250- results_dict = dict ()
251- _exec (
250+ output = _exec (
252251 bench ,
253252 framework ,
254253 rc .implementation ,
255254 rc .preset ,
256255 rc .repeat ,
257- results_dict ,
256+ results ,
258257 # copy output if we want to validate results
259258 rc .validate ,
260259 )
261260
262- results .error_state = results_dict .get (
263- "error_state" , ErrorCodes .FAILED_EXECUTION
264- )
265- results .error_msg = results_dict .get ("error_msg" , "Unexpected crash" )
266- results .input_size = results_dict .get ("input_size" )
267261 if results .error_state != ErrorCodes .SUCCESS :
268262 return (results , {})
269263
270- output = results_dict .get ("outputs" , {})
271- results .setup_time = results_dict ["setup_time" ]
272- results .warmup_time = results_dict ["warmup_time" ]
273- results .exec_times = results_dict ["exec_times" ]
274- results .teardown_time = results_dict ["teardown_time" ]
275-
276264 if rc .validate and results .error_state == ErrorCodes .SUCCESS :
277265 ref_framework = build_framework (rc .ref_framework )
278266 ref_output = _exec_simple (
@@ -433,11 +421,8 @@ def _exec_simple(
433421
434422 try :
435423 retval = framework .execute (impl_fn , inputs )
436- results_dict = {}
437-
438- _exec_copy_output (bench , framework , retval , inputs , results_dict )
439424
440- return results_dict [ "outputs" ]
425+ return _exec_copy_output ( bench , framework , retval , inputs , None )
441426 except Exception :
442427 logging .exception ("Benchmark execution failed at the warmup step." )
443428 return None
@@ -449,9 +434,9 @@ def _exec(
449434 impl_postfix : str ,
450435 preset : str ,
451436 repeat : int ,
452- results_dict : dict ,
437+ results : BenchmarkResults ,
453438 copy_output : bool ,
454- ):
439+ ) -> Union [ dict , None ] :
455440 """Executes a benchmark for a given implementation.
456441
457442 A helper function to execute a benchmark. The function is called in a
@@ -473,20 +458,18 @@ def _exec(
473458 repeat : Number of repetitions of the benchmark execution.
474459 precision: The precision to use for benchmark input data.
475460 args : Input arguments to benchmark implementation function.
476- results_dict : A dictionary where timing and other results are stored.
461+ results : A benchmark results where timing and other results are stored.
477462 copy_output : A flag that controls copying output.
478463 """
479464 np_input_data = bench .get_input_data (preset = preset )
480465
481466 with timer () as t :
482467 inputs = _set_input_args (bench , framework , np_input_data )
483- results_dict [ " setup_time" ] = t .get_elapsed_time ()
468+ results . setup_time = t .get_elapsed_time ()
484469
485- input_size = 0
470+ results . input_size = 0
486471 for arg in bench .info .array_args :
487- input_size += _array_size (bench .bdata [preset ][arg ])
488-
489- results_dict ["input_size" ] = input_size
472+ results .input_size += _array_size (bench .bdata [preset ][arg ])
490473
491474 impl_fn = bench .get_implementation (impl_postfix )
492475
@@ -496,11 +479,11 @@ def _exec(
496479 framework .execute (impl_fn , inputs )
497480 except Exception :
498481 logging .exception ("Benchmark execution failed at the warmup step." )
499- results_dict [ " error_state" ] = ErrorCodes .FAILED_EXECUTION
500- results_dict [ " error_msg" ] = "Execution failed"
482+ results . error_state = ErrorCodes .FAILED_EXECUTION
483+ results . error_msg = "Execution failed"
501484 return
502485
503- results_dict [ " warmup_time" ] = t .get_elapsed_time ()
486+ results . warmup_time = t .get_elapsed_time ()
504487
505488 _reset_output_args (bench , framework , inputs , np_input_data )
506489
@@ -516,24 +499,26 @@ def _exec(
516499 if i < repeat - 1 :
517500 _reset_output_args (bench , framework , inputs , np_input_data )
518501
519- results_dict [ " exec_times" ] = exec_times
502+ results . exec_times = exec_times
520503
521504 # Get the output data
522- results_dict ["teardown_time" ] = 0.0
505+ results .teardown_time = 0.0
506+ results .error_state = ErrorCodes .SUCCESS
507+ results .error_msg = ""
508+
523509 if copy_output :
524- _exec_copy_output (bench , framework , retval , inputs , results_dict )
510+ return _exec_copy_output (bench , framework , retval , inputs , results )
525511
526- results_dict ["error_state" ] = ErrorCodes .SUCCESS
527- results_dict ["error_msg" ] = ""
512+ return None
528513
529514
530515def _exec_copy_output (
531516 bench : Benchmark ,
532517 fmwrk : Framework ,
533518 retval ,
534519 inputs : dict ,
535- results_dict : dict ,
536- ):
520+ results : BenchmarkResults ,
521+ ) -> dict :
537522 output_arrays = dict ()
538523 with timer () as t :
539524 for out_arg in bench .info .output_args :
@@ -545,8 +530,10 @@ def _exec_copy_output(
545530 if retval is not None :
546531 output_arrays ["return-value" ] = convert_to_numpy (retval , fmwrk )
547532
548- results_dict ["outputs" ] = output_arrays
549- results_dict ["teardown_time" ] = t .get_elapsed_time ()
533+ if results :
534+ results .teardown_time = t .get_elapsed_time ()
535+
536+ return output_arrays
550537
551538
552539def convert_to_numpy (value : any , fmwrk : Framework ) -> any :
0 commit comments