From be53724fa9e5d153d8b3355dc2ffc6014fbc34a8 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Tue, 18 Apr 2023 21:33:31 -0700 Subject: [PATCH] Fix missing file for num_model_inputs --- catwalk/models/rank_classification.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 75fd10a..abd0324 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -64,7 +64,8 @@ def predict( # type: ignore batch_size: int = 32, max_instances_in_memory: int = 32 * 1024, num_shots: int = 0, - fewshot_seed: Optional[int] = None + fewshot_seed: Optional[int] = None, + num_model_inputs: Optional[int] = 0, # Number of instances to log in detail ) -> Iterator[Dict[str, Any]]: model = self._make_model( self.pretrained_model_name_or_path, @@ -80,7 +81,8 @@ def predict( # type: ignore tokenizer, batch_size=batch_size, num_shots=num_shots, - fewshot_seed=fewshot_seed + fewshot_seed=fewshot_seed, + num_model_inputs=num_model_inputs, ) def predict_chunk( @@ -91,7 +93,8 @@ def predict_chunk( tokenizer: _Tokenizer, batch_size: int = 32, num_shots: int = 0, - fewshot_seed: Optional[int] = None + fewshot_seed: Optional[int] = None, + num_model_inputs: Optional[int] = 0, # Number of model inputs to log in detail ) -> Iterator[Dict[str, Any]]: instance_index_to_tuple_indices: Mapping[int, List[int]] = collections.defaultdict(list) tuples: List[Tuple[str, str]] = [] @@ -121,10 +124,12 @@ def predict_chunk( results_for_instance = [results[i] for i in tuple_indices] result_tensor = torch.tensor(results_for_instance) metric_args = (result_tensor, instance.correct_choice) - yield { - metric_name: metric_args - for metric_name in task.metrics.keys() - } + prediction = {metric_name: metric_args for metric_name in task.metrics.keys()} + if instance_index >= num_model_inputs: + yield prediction + else: + model_input = [tuples[i] for i in tuple_indices] + yield {"model_input": model_input, "prediction": prediction} def _run_loglikelihood( self, @@ -308,10 +313,11 @@ def _make_model( pretrained_model_name_or_path: str, *, make_copy: bool = False, + model_class: Any = AutoModelForCausalLM, **kwargs ) -> GPT2LMHeadModel: return cached_transformers.get( - AutoModelForCausalLM, + model_class, pretrained_model_name_or_path, make_copy=make_copy, **kwargs)