11from concurrent .futures import ThreadPoolExecutor
22from copy import deepcopy
3+ import itertools
34import sys
45import traceback
56from typing import TYPE_CHECKING
@@ -297,6 +298,12 @@ def as_dataframe(self) -> None:
297298 return pd .DataFrame (data = records_list , columns = pd .MultiIndex .from_tuples (column_tuples ))
298299
299300
301+ class _ExperimentRunInfo :
302+ def __init__ (self , run_interation : int ):
303+ self ._id = uuid .uuid4 ()
304+ self ._run_iteration = run_interation
305+
306+
300307class Experiment :
301308 def __init__ (
302309 self ,
@@ -316,6 +323,7 @@ def __init__(
316323 ]
317324 ]
318325 ] = None ,
326+ runs : Optional [int ] = None ,
319327 ) -> None :
320328 self .name = name
321329 self ._task = task
@@ -326,6 +334,7 @@ def __init__(
326334 self ._tags : Dict [str , str ] = tags or {}
327335 self ._tags ["ddtrace.version" ] = str (ddtrace .__version__ )
328336 self ._config : Dict [str , JSONType ] = config or {}
337+ self ._runs : int = runs or 1
329338 self ._llmobs_instance = _llmobs_instance
330339
331340 if not project_name :
@@ -358,18 +367,23 @@ def run(self, jobs: int = 1, raise_errors: bool = False, sample_size: Optional[i
358367 self ._config ,
359368 convert_tags_dict_to_list (self ._tags ),
360369 self ._description ,
370+ self ._runs ,
361371 )
362372 self ._id = experiment_id
363373 self ._tags ["experiment_id" ] = str (experiment_id )
364374 self ._run_name = experiment_run_name
365- task_results = self ._run_task (jobs , raise_errors , sample_size )
366- evaluations = self ._run_evaluators (task_results , raise_errors = raise_errors )
367- summary_evals = self ._run_summary_evaluators (task_results , evaluations , raise_errors )
368- experiment_results = self ._merge_results (task_results , evaluations , summary_evals )
369- experiment_evals = self ._generate_metrics_from_exp_results (experiment_results )
370- self ._llmobs_instance ._dne_client .experiment_eval_post (
371- self ._id , experiment_evals , convert_tags_dict_to_list (self ._tags )
372- )
375+ for run_iteration in range (self ._runs ):
376+ run = _ExperimentRunInfo (run_iteration )
377+ self ._tags ["run_id" ] = str (run ._id )
378+ self ._tags ["run_iteration" ] = str (run ._run_iteration )
379+ task_results = self ._run_task (jobs , run , raise_errors , sample_size )
380+ evaluations = self ._run_evaluators (task_results , raise_errors = raise_errors )
381+ summary_evals = self ._run_summary_evaluators (task_results , evaluations , raise_errors )
382+ experiment_results = self ._merge_results (task_results , evaluations , summary_evals )
383+ experiment_evals = self ._generate_metrics_from_exp_results (experiment_results )
384+ self ._llmobs_instance ._dne_client .experiment_eval_post (
385+ self ._id , experiment_evals , convert_tags_dict_to_list (self ._tags )
386+ )
373387
374388 return experiment_results
375389
@@ -378,11 +392,13 @@ def url(self) -> str:
378392 # FIXME: will not work for subdomain orgs
379393 return f"{ _get_base_url ()} /llm/experiments/{ self ._id } "
380394
381- def _process_record (self , idx_record : Tuple [int , DatasetRecord ]) -> Optional [TaskResult ]:
395+ def _process_record (self , idx_record : Tuple [int , DatasetRecord ], run : _ExperimentRunInfo ) -> Optional [TaskResult ]:
382396 if not self ._llmobs_instance or not self ._llmobs_instance .enabled :
383397 return None
384398 idx , record = idx_record
385- with self ._llmobs_instance ._experiment (name = self ._task .__name__ , experiment_id = self ._id ) as span :
399+ with self ._llmobs_instance ._experiment (
400+ name = self ._task .__name__ , experiment_id = self ._id , run_id = str (run ._id ), run_iteration = run ._run_iteration
401+ ) as span :
386402 span_context = self ._llmobs_instance .export_span (span = span )
387403 if span_context :
388404 span_id = span_context .get ("span_id" , "" )
@@ -422,7 +438,9 @@ def _process_record(self, idx_record: Tuple[int, DatasetRecord]) -> Optional[Tas
422438 },
423439 }
424440
425- def _run_task (self , jobs : int , raise_errors : bool = False , sample_size : Optional [int ] = None ) -> List [TaskResult ]:
441+ def _run_task (
442+ self , jobs : int , run : _ExperimentRunInfo , raise_errors : bool = False , sample_size : Optional [int ] = None
443+ ) -> List [TaskResult ]:
426444 if not self ._llmobs_instance or not self ._llmobs_instance .enabled :
427445 return []
428446 if sample_size is not None and sample_size < len (self ._dataset ):
@@ -441,7 +459,9 @@ def _run_task(self, jobs: int, raise_errors: bool = False, sample_size: Optional
441459 subset_dataset = self ._dataset
442460 task_results = []
443461 with ThreadPoolExecutor (max_workers = jobs ) as executor :
444- for result in executor .map (self ._process_record , enumerate (subset_dataset )):
462+ for result in executor .map (
463+ self ._process_record , enumerate (subset_dataset ), itertools .repeat (run , len (subset_dataset ))
464+ ):
445465 if not result :
446466 continue
447467 task_results .append (result )
0 commit comments