Support printing reasons in the console output for pydantic-evals #2163
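
In brief, the diff below threads a new keyword-only `include_reasons` flag from `EvaluationReport.print` / `EvaluationReport.console_table` through `EvaluationRenderer` into `ReportCaseRenderer`, which then appends each `EvaluationResult.reason` next to the rendered scores, labels, and assertions. A minimal usage sketch (how the report is produced is outside this diff; `my_report` is just a placeholder for an existing `EvaluationReport`, and the import path is assumed from the file layout shown below):

```python
from pydantic_evals.reporting import EvaluationReport

def show_report_with_reasons(my_report: EvaluationReport) -> None:
    # Print the report table with evaluator reasons rendered next to
    # each score, label, and assertion (new keyword added by this PR).
    my_report.print(width=120, include_reasons=True)

    # Or build the rich Table yourself, e.g. to render it elsewhere.
    table = my_report.console_table(include_reasons=True)
```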

Open · wants to merge 3 commits into main
44 changes: 33 additions & 11 deletions pydantic_evals/pydantic_evals/reporting/__init__.py
@@ -4,7 +4,7 @@
from collections.abc import Mapping
from dataclasses import dataclass
from io import StringIO
from typing import Any, Callable, Generic, Literal, Protocol
from typing import Any, Callable, Generic, Literal, Protocol, cast

from pydantic import BaseModel, TypeAdapter
from rich.console import Console
@@ -168,6 +168,7 @@ def print(
self,
width: int | None = None,
baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
*,
include_input: bool = False,
include_metadata: bool = False,
include_expected_output: bool = False,
@@ -183,6 +184,7 @@ def print(
label_configs: dict[str, RenderValueConfig] | None = None,
metric_configs: dict[str, RenderNumberConfig] | None = None,
duration_config: RenderNumberConfig | None = None,
include_reasons: bool = False,
): # pragma: no cover
"""Print this report to the console, optionally comparing it to a baseline report.

@@ -205,12 +207,14 @@
label_configs=label_configs,
metric_configs=metric_configs,
duration_config=duration_config,
include_reasons=include_reasons,
)
Console(width=width).print(table)

def console_table(
self,
baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
*,
include_input: bool = False,
include_metadata: bool = False,
include_expected_output: bool = False,
@@ -226,6 +230,7 @@ def console_table(
label_configs: dict[str, RenderValueConfig] | None = None,
metric_configs: dict[str, RenderNumberConfig] | None = None,
duration_config: RenderNumberConfig | None = None,
include_reasons: bool = False,
) -> Table:
"""Return a table containing the data from this report, or the diff between this report and a baseline report.

@@ -247,6 +252,7 @@
label_configs=label_configs or {},
metric_configs=metric_configs or {},
duration_config=duration_config or _DEFAULT_DURATION_CONFIG,
include_reasons=include_reasons,
)
if baseline is None:
return renderer.build_table(self)
@@ -529,15 +535,16 @@ class ReportCaseRenderer:
include_labels: bool
include_metrics: bool
include_assertions: bool
include_reasons: bool
include_durations: bool
include_total_duration: bool

input_renderer: _ValueRenderer
metadata_renderer: _ValueRenderer
output_renderer: _ValueRenderer
score_renderers: dict[str, _NumberRenderer]
label_renderers: dict[str, _ValueRenderer]
metric_renderers: dict[str, _NumberRenderer]
score_renderers: Mapping[str, _NumberRenderer]
label_renderers: Mapping[str, _ValueRenderer]
metric_renderers: Mapping[str, _NumberRenderer]
duration_renderer: _NumberRenderer

def build_base_table(self, title: str) -> Table:
@@ -581,10 +588,10 @@ def build_row(self, case: ReportCase) -> list[str]:
row.append(self.output_renderer.render_value(None, case.output) or EMPTY_CELL_STR)

if self.include_scores:
row.append(self._render_dict({k: v.value for k, v in case.scores.items()}, self.score_renderers))
row.append(self._render_dict({k: v for k, v in case.scores.items()}, self.score_renderers))

if self.include_labels:
row.append(self._render_dict({k: v.value for k, v in case.labels.items()}, self.label_renderers))
row.append(self._render_dict({k: v for k, v in case.labels.items()}, self.label_renderers))

if self.include_metrics:
row.append(self._render_dict(case.metrics, self.metric_renderers))
@@ -779,26 +786,36 @@ def _render_dicts_diff(
diff_lines.append(rendered)
return '\n'.join(diff_lines) if diff_lines else EMPTY_CELL_STR

@staticmethod
def _render_dict(
case_dict: dict[str, T],
self,
case_dict: Mapping[str, EvaluationResult[T] | T],
renderers: Mapping[str, _AbstractRenderer[T]],
*,
include_names: bool = True,
) -> str:
diff_lines: list[str] = []
for key, val in case_dict.items():
rendered = renderers[key].render_value(key if include_names else None, val)
value = cast(EvaluationResult[T], val).value if isinstance(val, EvaluationResult) else val
rendered = renderers[key].render_value(key if include_names else None, value)
if self.include_reasons and isinstance(val, EvaluationResult) and (reason := val.reason):
rendered += f'\n Reason: {reason}\n'
diff_lines.append(rendered)
return '\n'.join(diff_lines) if diff_lines else EMPTY_CELL_STR

@staticmethod
def _render_assertions(
self,
assertions: list[EvaluationResult[bool]],
) -> str:
if not assertions:
return EMPTY_CELL_STR
return ''.join(['[green]✔[/]' if a.value else '[red]✗[/]' for a in assertions])
lines: list[str] = []
for a in assertions:
line = '[green]✔[/]' if a.value else '[red]✗[/]'
if self.include_reasons:
line = f'{a.name}: {line}\n'
line = f'{line} Reason: {a.reason}\n\n' if a.reason else line
lines.append(line)
return ''.join(lines)

@staticmethod
def _render_aggregate_assertions(
@@ -859,6 +876,10 @@ class EvaluationRenderer:
metric_configs: dict[str, RenderNumberConfig]
duration_config: RenderNumberConfig

# TODO: Make this class kw-only so we can reorder the kwargs
# Data to include
include_reasons: bool # only applies to reports, not to diffs

def include_scores(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
return any(case.scores for case in self._all_cases(report, baseline))

@@ -905,6 +926,7 @@ def _get_case_renderer(
include_labels=self.include_labels(report, baseline),
include_metrics=self.include_metrics(report, baseline),
include_assertions=self.include_assertions(report, baseline),
include_reasons=self.include_reasons,
include_durations=self.include_durations,
include_total_duration=self.include_total_duration,
input_renderer=input_renderer,
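To make the rendering change above concrete, here is a small standalone sketch, not the library's code: `Result` is a stand-in for `EvaluationResult`, reduced to the `name`/`value`/`reason` attributes the diff actually touches, and `str()` stands in for the configured value/number renderers.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class Result:
    # Stand-in for pydantic_evals' EvaluationResult: the diff reads its
    # .value and .reason attributes (and .name when rendering assertions).
    name: str
    value: Any
    reason: str | None = None


def render_cell(val: Any, include_reasons: bool) -> str:
    # Unwrap result objects so the raw value is what gets rendered
    # (the real _render_dict delegates this to per-key renderers).
    value = val.value if isinstance(val, Result) else val
    rendered = str(value)
    # When reasons are requested, append them beneath the rendered value.
    if include_reasons and isinstance(val, Result) and val.reason:
        rendered += f"\n Reason: {val.reason}"
    return rendered


print(render_cell(Result("accuracy", 0.9, "9/10 outputs matched"), include_reasons=True))
```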
4 changes: 4 additions & 0 deletions tests/evals/test_reporting.py
@@ -120,6 +120,7 @@ async def test_evaluation_renderer_basic(sample_report: EvaluationReport):
label_configs={},
metric_configs={},
duration_config={},
include_reasons=False,
)

table = renderer.build_table(sample_report)
@@ -191,6 +192,7 @@ async def test_evaluation_renderer_with_baseline(sample_report: EvaluationReport
label_configs={},
metric_configs={},
duration_config={},
include_reasons=False,
)

table = renderer.build_diff_table(sample_report, baseline_report)
@@ -248,6 +250,7 @@ async def test_evaluation_renderer_with_removed_cases(sample_report: EvaluationR
label_configs={},
metric_configs={},
duration_config={},
include_reasons=False,
)

table = renderer.build_diff_table(sample_report, baseline_report)
@@ -311,6 +314,7 @@ async def test_evaluation_renderer_with_custom_configs(sample_report: Evaluation
'diff_increase_style': 'bold red',
'diff_decrease_style': 'bold green',
},
include_reasons=False,
)

table = renderer.build_table(sample_report)