Add option to show a few model inputs

OyvindTafjord committed Apr 19, 2023
1 parent f0d3b76 commit b017f8c

Showing 3 changed files with 20 additions and 7 deletions.
20 changes: 15 additions & 5 deletions catwalk/run_eval.py
@@ -22,6 +22,7 @@
 _parser.add_argument('--limit', type=int)
 _parser.add_argument('--full_output_file', type=str, default=None, help="Filename for verbose output")
 _parser.add_argument('--metrics_file', type=str, default=None, help="Filename for metrics output")
+_parser.add_argument('--num_model_inputs', type=int, default=0, help="Number of sample model inputs in full output, for sanity checks")
 _parser.add_argument('-d', '-w', type=str, default=None, metavar="workspace", dest="workspace", help="the Tango workspace with the cache")
 
 
@@ -71,14 +72,16 @@ def main(args: argparse.Namespace):
         kwargs["num_shots"] = args.num_shots
     if args.fewshot_seed is not None:
         kwargs["fewshot_seed"] = args.fewshot_seed
+    if args.num_model_inputs:
+        kwargs["num_model_inputs"] = args.num_model_inputs
     random_subsample_seed = None

     metric_task_dict = {}
     if save_output:
         verbose_output = []
     for task in tasks:
         logger.info(f"Processing task: {task}")
-        predictions = PredictStep(
+        full_predictions = PredictStep(
             model=args.model,
             task=task,
             split=args.split,
@@ -88,7 +91,7 @@
         metrics = CalculateMetricsStep(
             model=args.model,
             task=task,
-            predictions=predictions)
+            predictions=full_predictions)
         metric_task_dict[task] = metrics
         if save_output:
             task_obj = task
@@ -98,13 +101,20 @@
             if split is None:
                 split = task_obj.default_split
             instances = get_instances(task_obj, split, limit, random_subsample_seed)
-            predictions_explicit = list(predictions.result(workspace))
+            predictions_explicit = list(full_predictions.result(workspace))
             metrics_explicit = metrics.result(workspace)
             output = {"task": task, "model": args.model, "split": split, "limit": limit, "metrics": metrics_explicit,
                       "num_instances": len(instances)}
             logger.info(f"Results from task {task}: {output}")
-            output["per_instance"] = [{"instance": guess_instance_id(inst), "prediction": prediction} for \
-                inst, prediction in zip(instances, predictions_explicit)]
+            per_instance = []
+            for inst, p in zip(instances, predictions_explicit):
+                res1 = {"instance": guess_instance_id(inst), "prediction": p.get('prediction', p)}
+                if 'model_input' in p:
+                    res1['model_input'] = p['model_input']
+                per_instance.append(res1)
+            output["per_instance"] = per_instance
+            if per_instance:
+                logger.info(f"First instance details for task {task}: {per_instance[0]}")
             verbose_output.append(output)
         if args.full_output_file:
             logger.info(f"Saving full output in {args.full_output_file}...")
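The new --num_model_inputs flag is forwarded through kwargs, presumably into PredictStep, so that a few rendered model inputs are attached to the predictions and surface per instance in the full output file. A minimal invocation sketch, assuming run_eval.py's existing --model and --task flags (not shown in this diff) and with hypothetical model/task names:

    python -m catwalk.run_eval --model rc::gpt2 --task piqa --split validation \
        --limit 100 --num_model_inputs 3 --full_output_file full_output.json

Each entry of output["per_instance"] then has the shape built in the loop above (values hypothetical; "model_input" appears only for predictions that carry one):

    {"instance": {"id": "..."}, "prediction": {...}, "model_input": "Question: ..."}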
4 changes: 2 additions & 2 deletions catwalk/steps.py
@@ -95,8 +95,8 @@ def run(
             model = MODELS[model]
         if isinstance(task, str):
             task = TASKS[task]
-
-        return model.calculate_metrics(task, predictions)
+        predictions_raw = [p.get('prediction', p) for p in predictions]
+        return model.calculate_metrics(task, predictions_raw)
 
 
 @Step.register("catwalk::finetune")
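In CalculateMetricsStep, the p.get('prediction', p) unwrap keeps the step backward compatible: a new-style prediction dict is reduced to its 'prediction' field before metrics are computed, while an old-style prediction (a dict without that key) passes through unchanged. A standalone sketch of the pattern, with hypothetical field names:

    # Old style: the prediction itself. New style: a wrapper that also
    # carries the rendered model input.
    old_style = {"acc": 1.0}
    new_style = {"prediction": {"acc": 1.0}, "model_input": "Question: ..."}

    assert old_style.get('prediction', old_style) == {"acc": 1.0}
    assert new_style.get('prediction', new_style) == {"acc": 1.0}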
3 changes: 3 additions & 0 deletions catwalk/utils.py
@@ -1,4 +1,5 @@
 from typing import Any
+import dataclasses
 import torch
 
 def sanitize(x: Any) -> Any:
@@ -22,6 +23,8 @@ def sanitize(x: Any) -> Any:
         return "None"
     elif hasattr(x, "to_json"):
         return x.to_json()
+    elif dataclasses.is_dataclass(x):
+        return sanitize(dataclasses.asdict(x))
     else:
         raise ValueError(
             f"Cannot sanitize {x} of type {type(x)}. "
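The new elif branch lets sanitize accept dataclass instances, flattening them to plain dicts and recursing instead of falling through to the ValueError. A quick illustration of the two standard-library calls it relies on, using a hypothetical dataclass:

    import dataclasses

    @dataclasses.dataclass
    class ModelInput:
        text: str
        num_tokens: int

    x = ModelInput(text="Question: ...", num_tokens=12)
    assert dataclasses.is_dataclass(x)
    # asdict recursively converts nested dataclasses into JSON-friendly dicts
    assert dataclasses.asdict(x) == {"text": "Question: ...", "num_tokens": 12}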
