Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rich.theme import Theme

DEFAULT_NUM_RECORDS = 10
DEFAULT_DISPLAY_WIDTH = 110

EPSILON = 1e-8
REPORTING_PRECISION = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import html
import json
import logging
import os
import re
from collections import OrderedDict
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from rich.console import Console, Group
from rich.padding import Padding
Expand All @@ -25,7 +28,11 @@
from data_designer.config.models import ModelConfig, ModelProvider
from data_designer.config.sampler_params import SamplerType
from data_designer.config.utils.code_lang import code_lang_to_syntax_lexer
from data_designer.config.utils.constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
from data_designer.config.utils.constants import (
DEFAULT_DISPLAY_WIDTH,
NVIDIA_API_KEY_ENV_VAR_NAME,
OPENAI_API_KEY_ENV_VAR_NAME,
)
from data_designer.config.utils.errors import DatasetSampleDisplayError
from data_designer.config.utils.image_helpers import (
extract_base64_from_data_uri,
Expand All @@ -45,6 +52,7 @@


console = Console()
logger = logging.getLogger(__name__)


def _display_image_if_in_notebook(image_data: str, col_name: str) -> bool:
Expand Down Expand Up @@ -158,6 +166,9 @@ def display_sample_record(
background_color: str | None = None,
processors_to_display: list[str] | None = None,
hide_seed_columns: bool = False,
save_path: str | Path | None = None,
theme: Literal["dark", "light"] = "dark",
display_width: int = DEFAULT_DISPLAY_WIDTH,
) -> None:
"""Display a sample record from the Data Designer dataset preview.

Expand All @@ -170,12 +181,15 @@ def display_sample_record(
documentation from `rich` for information about available background colors.
processors_to_display: List of processors to display the artifacts for. If None, all processors will be displayed.
hide_seed_columns: If True, seed columns will not be displayed separately.
save_path: Optional path to save the rendered output as an HTML or SVG file.
theme: Color theme for saved HTML files (dark or light).
display_width: Width of the rendered output in characters.
"""
i = self._display_cycle_index if index is None else index

num_records = len(self._record_sampler_dataset)
try:
record = self._record_sampler_dataset.iloc[i]
num_records = len(self._record_sampler_dataset)
except IndexError:
raise DatasetSampleDisplayError(f"Index {i} is out of bounds for dataset of length {num_records}.")

Expand Down Expand Up @@ -207,6 +221,9 @@ def display_sample_record(
syntax_highlighting_theme=syntax_highlighting_theme,
record_index=i,
seed_column_names=seed_column_names,
save_path=save_path,
theme=theme,
display_width=display_width,
)
if index is None:
self._display_cycle_index = (self._display_cycle_index + 1) % num_records
Expand Down Expand Up @@ -240,7 +257,10 @@ def display_sample_record(
syntax_highlighting_theme: str = "dracula",
record_index: int | None = None,
seed_column_names: list[str] | None = None,
):
save_path: str | Path | None = None,
theme: Literal["dark", "light"] = "dark",
display_width: int = DEFAULT_DISPLAY_WIDTH,
) -> None:
if isinstance(record, (dict, pd.Series)):
record = pd.DataFrame([record]).iloc[0]
elif isinstance(record, pd.DataFrame):
Expand Down Expand Up @@ -416,7 +436,12 @@ def display_sample_record(
index_label = Text(f"[index: {record_index}]", justify="center")
render_list.append(index_label)

console.print(Group(*render_list), markup=False)
if save_path is not None:
recording_console = Console(record=True, width=display_width)
recording_console.print(Group(*render_list), markup=False)
_save_console_output(recording_console, save_path, theme=theme)
else:
console.print(Group(*render_list), markup=False)

# Display images at the bottom with captions (only in notebook)
if len(images_to_display_later) > 0:
Expand Down Expand Up @@ -540,7 +565,7 @@ def mask_api_key(api_key: str | None) -> str:
return "***" + api_key[-4:] if len(api_key) > 4 else "***"


def convert_to_row_element(elem):
def convert_to_row_element(elem: Any) -> Any:
try:
elem = Pretty(json.loads(elem))
except (TypeError, json.JSONDecodeError):
Expand All @@ -552,7 +577,7 @@ def convert_to_row_element(elem):
return elem


def pad_console_element(elem, padding=(1, 0, 1, 0)):
def pad_console_element(elem: Any, padding: tuple[int, int, int, int] = (1, 0, 1, 0)) -> Padding:
    """Return ``elem`` wrapped in a ``Padding`` renderable.

    Args:
        elem: The element to wrap.
        padding: Four-integer padding handed to ``Padding`` unchanged.

    Returns:
        The ``Padding`` instance built from ``elem`` and ``padding``.
    """
    padded_elem = Padding(elem, padding)
    return padded_elem


Expand Down Expand Up @@ -622,3 +647,53 @@ def _get_field_constraints(field: dict, schema: dict) -> str:
constraints.append(f"allowed: {', '.join(enum_values.keys())}")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This flips the background to dark blue, but Rich's default markup colors (greens, reds, dim text) were designed for white. They're not broken, but some of them feel a bit washed out on the dark bg. The dracula code blocks look great since that theme was made for dark — just something to keep an eye on if you want to polish the dark mode later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good callout! It would definitely be a good idea to iterate once more folks have seen it on their screens.

return ", ".join(constraints)


_SAMPLE_RECORD_DARK_CSS = """
:root { color-scheme: dark; }
html, body { background: #020a1d !important; color: #dbe8ff !important; }
pre, code { color: inherit !important; }
table, th, td { border-color: rgba(184, 210, 255, 0.5) !important; }
"""


def apply_html_post_processing(html_path: str | Path, *, theme: Literal["dark", "light"] = "dark") -> None:
"""Inject viewport meta tag and optional dark-mode CSS into a Rich-exported HTML file."""
path = Path(html_path)
try:
content = path.read_text(encoding="utf-8")
except (FileNotFoundError, UnicodeDecodeError) as exc:
logger.warning("Could not post-process HTML at %s: %s", path, exc)
return

if 'name="viewport"' in content:
return

viewport_tag = '<meta name="viewport" content="width=device-width, initial-scale=1">\n'
injection = viewport_tag

if theme == "dark":
dark_css = _SAMPLE_RECORD_DARK_CSS.strip()
injection += f'<style id="data-designer-styles">\n{dark_css}\n</style>\n'

if re.search(r"</head>", content, flags=re.I):
content = re.sub(r"</head>", lambda m: injection + m.group(), content, count=1, flags=re.I)
else:
content = injection + content
path.write_text(content, encoding="utf-8")


def _save_console_output(
    recorded_console: Console, save_path: str | Path, *, theme: Literal["dark", "light"] = "dark"
) -> None:
    """Write the output recorded by ``recorded_console`` to ``save_path``.

    The output format is chosen from the file extension, compared
    case-insensitively:

    * ``.html`` -- saved with ``save_html`` and then passed through
      ``apply_html_post_processing`` with ``theme``.
    * ``.svg`` -- saved with ``save_svg`` using an empty title; ``theme`` is
      not used.

    Args:
        recorded_console: Console whose recorded output is saved.
        save_path: Destination file; its extension selects the format.
        theme: Forwarded to ``apply_html_post_processing`` for HTML output.

    Raises:
        DatasetSampleDisplayError: If the extension is neither ``.html`` nor
            ``.svg``.
    """
    target = str(save_path)
    extension = Path(target).suffix.lower()

    # Reject unsupported formats up front; nothing is written in that case.
    if extension not in (".html", ".svg"):
        raise DatasetSampleDisplayError(
            f"The extension of the save path must be either .html or .svg. You provided {target}."
        )

    if extension == ".svg":
        recorded_console.save_svg(target, title="")
        return

    recorded_console.save_html(target)
    apply_html_post_processing(target, theme=theme)
Loading