Skip to content

Commit

Permalink
Merge pull request #4 from OyvindTafjord/log-inputs
Browse files — browse the repository at this point in the history
Fix missing file for num_model_inputs
  • Loading branch information
OyvindTafjord authored Apr 19, 2023
2 parents 8ea4e88 + be53724 commit ff1dccb
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions catwalk/models/rank_classification.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,7 +64,8 @@ def predict( # type: ignore
batch_size: int = 32,
max_instances_in_memory: int = 32 * 1024,
num_shots: int = 0,
fewshot_seed: Optional[int] = None
fewshot_seed: Optional[int] = None,
num_model_inputs: Optional[int] = 0, # Number of instances to log in detail
) -> Iterator[Dict[str, Any]]:
model = self._make_model(
self.pretrained_model_name_or_path,
Expand All @@ -80,7 +81,8 @@ def predict( # type: ignore
tokenizer,
batch_size=batch_size,
num_shots=num_shots,
fewshot_seed=fewshot_seed
fewshot_seed=fewshot_seed,
num_model_inputs=num_model_inputs,
)

def predict_chunk(
Expand All @@ -91,7 +93,8 @@ def predict_chunk(
tokenizer: _Tokenizer,
batch_size: int = 32,
num_shots: int = 0,
fewshot_seed: Optional[int] = None
fewshot_seed: Optional[int] = None,
num_model_inputs: Optional[int] = 0, # Number of instances whose model inputs are logged in detail
) -> Iterator[Dict[str, Any]]:
instance_index_to_tuple_indices: Mapping[int, List[int]] = collections.defaultdict(list)
tuples: List[Tuple[str, str]] = []
Expand Down Expand Up @@ -121,10 +124,12 @@ def predict_chunk(
results_for_instance = [results[i] for i in tuple_indices]
result_tensor = torch.tensor(results_for_instance)
metric_args = (result_tensor, instance.correct_choice)
yield {
metric_name: metric_args
for metric_name in task.metrics.keys()
}
prediction = {metric_name: metric_args for metric_name in task.metrics.keys()}
if instance_index >= num_model_inputs:
yield prediction
else:
model_input = [tuples[i] for i in tuple_indices]
yield {"model_input": model_input, "prediction": prediction}

def _run_loglikelihood(
self,
Expand Down Expand Up @@ -308,10 +313,11 @@ def _make_model(
pretrained_model_name_or_path: str,
*,
make_copy: bool = False,
model_class: Any = AutoModelForCausalLM,
**kwargs
) -> GPT2LMHeadModel:
return cached_transformers.get(
AutoModelForCausalLM,
model_class,
pretrained_model_name_or_path,
make_copy=make_copy,
**kwargs)
Expand Down

0 comments on commit ff1dccb

Please sign in to comment.