Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 64 additions & 18 deletions src/together/lib/cli/api/beta/jig/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import json
from typing import TYPE_CHECKING, Any, Optional
from pathlib import Path
from dataclasses import field, asdict, dataclass
from dataclasses import field, asdict

import click
from pydantic import ValidationError
from pydantic.dataclasses import dataclass

if TYPE_CHECKING:
import tomli as tomllib
Expand All @@ -36,6 +38,25 @@
# --- Configuration Dataclasses ---


def _format_validation_errors(
exc: ValidationError,
prefix: str,
section: str,
path: Path,
) -> str:
"""Format a pydantic ValidationError with file context."""
header = f"Configuration error in {path}"
if prefix:
header += f" [{prefix}.{section}]" if section else f" [{prefix}]"
elif section:
header += f" [{section}]"
lines = [header + ":"]
for e in exc.errors():
loc = " -> ".join(str(part) for part in e["loc"])
lines.append(f" - {loc}: {e['msg']}")
return "\n".join(lines)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More idiomatic way of doing it with pydantic might be just using the @field_validator on every field: https://docs.pydantic.dev/latest/errors/errors/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels too extreme to have validators for each field. Also didn't want to fully commit to pydantic yet assuming we might move to msgspec



@dataclass
class ImageConfig:
"""Container image configuration from pyproject.toml"""
Expand All @@ -45,12 +66,15 @@ class ImageConfig:
environment: dict[str, str] = field(default_factory=dict[str, str])
run: list[str] = field(default_factory=list[str])
cmd: str = "python app.py"
copy: list[str] = field(default_factory=list[str])
copy_files: list[str] = field(default_factory=list[str])
auto_include_git: bool = False

@classmethod
def from_dict(cls, data: dict[str, Any]) -> ImageConfig:
return cls(**{k: v for k, v in data.items() if k in cls.__annotations__})
mapped = {k: v for k, v in data.items() if k in cls.__annotations__}
if "copy" in data:
mapped["copy_files"] = data["copy"]
return cls(**mapped)


@dataclass
Expand Down Expand Up @@ -101,8 +125,8 @@ class Config:
dockerfile: str = "Dockerfile"
image: ImageConfig = field(default_factory=ImageConfig)
deploy: DeployConfig = field(default_factory=DeployConfig)
_path: Path = field(default_factory=lambda: Path("pyproject.toml"))
_unique_name_tip: str = "Update project.name in pyproject.toml"
config_path: Path = field(default_factory=lambda: Path("pyproject.toml"))
unique_name_tip: str = "Update project.name in pyproject.toml"

@classmethod
def find(cls, config_path: Optional[str] = None, init: bool = False) -> Config:
Expand Down Expand Up @@ -161,13 +185,33 @@ def load(cls, data: dict[str, Any], path: Path) -> Config:
# Support volume_mounts at jig level (merge into deploy config)
jig_config["deploy"]["volume_mounts"] = jig_config.get("volume_mounts", [])

prefix = "tool.jig" if is_pyproject else ""
errors: list[str] = []

try:
image = ImageConfig.from_dict(jig_config.get("image", {}))
except ValidationError as exc:
errors.append(_format_validation_errors(exc, prefix, "image", path))
image = None

try:
deploy = DeployConfig.from_dict(jig_config.get("deploy", {}))
except ValidationError as exc:
errors.append(_format_validation_errors(exc, prefix, "deploy", path))
deploy = None

if errors:
click.echo("\n\n".join(errors), err=True)
sys.exit(1)

assert image is not None and deploy is not None
return cls(
image=ImageConfig.from_dict(jig_config.get("image", {})),
deploy=DeployConfig.from_dict(jig_config.get("deploy", {})),
image=image,
deploy=deploy,
dockerfile=jig_config.get("dockerfile", "Dockerfile"),
model_name=name,
_path=path,
_unique_name_tip=tip,
config_path=path,
unique_name_tip=tip,
)


Expand All @@ -178,16 +222,18 @@ def load(cls, data: dict[str, Any], path: Path) -> Config:
class State:
"""Persistent state stored in .jig.json"""

_config_dir: Path
_project_name: str
config_dir: Path
project_name: str
registry_base_path: str = ""
secrets: dict[str, str] = field(default_factory=dict[str, str])
volumes: dict[str, str] = field(default_factory=dict[str, str])

@classmethod
def from_dict(cls, config_dir: Path, project_name: str, **data: Any) -> State:
filtered = {k: v for k, v in data.items() if k in cls.__annotations__ and not k.startswith("_")}
return cls(_config_dir=config_dir, _project_name=project_name, **filtered)
filtered = {
k: v for k, v in data.items() if k in cls.__annotations__ and k not in ("config_dir", "project_name")
}
return cls(config_dir=config_dir, project_name=project_name, **filtered)

@classmethod
def load(cls, config_dir: Path, project_name: str) -> State:
Expand Down Expand Up @@ -218,16 +264,16 @@ def load(cls, config_dir: Path, project_name: str) -> State:
if "secrets" in all_data or "volumes" in all_data:
return cls.from_dict(config_dir, project_name, **all_data)
# File exists but this project isn't in it yet
return cls(_config_dir=config_dir, _project_name=project_name)
return cls(config_dir=config_dir, project_name=project_name)
except FileNotFoundError:
return cls(_config_dir=config_dir, _project_name=project_name)
return cls(config_dir=config_dir, project_name=project_name)

def save(self) -> None:
"""Save state for this project to .jig.json.

Preserves other projects' state in the same file.
"""
path = self._config_dir / ".jig.json"
path = self.config_dir / ".jig.json"

# Load existing file to preserve other projects
try:
Expand All @@ -237,8 +283,8 @@ def save(self) -> None:
all_data = {}

# Update this project's state
project_data = {k: v for k, v in asdict(self).items() if not k.startswith("_")}
all_data[self._project_name] = project_data
project_data = {k: v for k, v in asdict(self).items() if k not in ("config_dir", "project_name")}
all_data[self.project_name] = project_data

# Save back to file
with open(path, "w") as f:
Expand Down
16 changes: 10 additions & 6 deletions src/together/lib/cli/api/beta/jig/jig.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _generate_dockerfile(config: Config) -> str:

def _get_files_to_copy(config: Config) -> list[str]:
"""Get list of files to copy"""
files = set(config.image.copy)
files = set(config.image.copy_files)
if config.image.auto_include_git:
try:
if _run(["git", "status", "--porcelain"]).stdout.strip():
Expand Down Expand Up @@ -147,7 +147,11 @@ def _dockerfile(config: Config) -> bool:
return False

# Skip regeneration if config hasn't changed
if config._path and config._path.exists() and dockerfile_path.stat().st_mtime >= config._path.stat().st_mtime:
if (
config.config_path
and config.config_path.exists()
and dockerfile_path.stat().st_mtime >= config.config_path.stat().st_mtime
):
return True

with open(dockerfile_path, "w") as f:
Expand Down Expand Up @@ -373,7 +377,7 @@ def build(

client: Together = ctx.obj
config = Config.find(config_path)
state = State.load(config._path.parent, config.model_name)
state = State.load(config.config_path.parent, config.model_name)
_ensure_registry_base_path(client, state)

image = _get_image(state, config, tag)
Expand Down Expand Up @@ -407,7 +411,7 @@ def push(ctx: click.Context, tag: str, config_path: str | None) -> None:
"""Push image to registry"""
client: Together = ctx.obj
config = Config.find(config_path)
state = State.load(config._path.parent, config.model_name)
state = State.load(config.config_path.parent, config.model_name)
_ensure_registry_base_path(client, state)

image = _get_image(state, config, tag)
Expand Down Expand Up @@ -441,7 +445,7 @@ def deploy(
"""Deploy model"""
client: Together = ctx.obj
config = Config.find(config_path)
state = State.load(config._path.parent, config.model_name)
state = State.load(config.config_path.parent, config.model_name)
_ensure_registry_base_path(client, state)

if existing_image:
Expand Down Expand Up @@ -528,7 +532,7 @@ def handle_create() -> dict[str, Any]:
error_body: Any = getattr(e, "body", None)
error_message = error_body.get("error", "") if isinstance(error_body, dict) else "" # pyright: ignore
if "already exists" in error_message or "must be unique" in error_message:
raise RuntimeError(f"Deployment name must be unique. Tip: {config._unique_name_tip}") from None
raise RuntimeError(f"Deployment name must be unique. Tip: {config.unique_name_tip}") from None
# TODO: helpful tips for more error cases
raise

Expand Down
6 changes: 3 additions & 3 deletions src/together/lib/cli/api/beta/jig/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def secrets_set(
"""Set a secret (create or update)"""
client: Together = ctx.obj
config = Config.find(config_path)
state = State.load(config._path.parent, config.model_name)
state = State.load(config.config_path.parent, config.model_name)

deployment_secret_name = f"{config.model_name}-{name}"

Expand Down Expand Up @@ -76,7 +76,7 @@ def secrets_unset(
) -> None:
"""Remove a secret from both remote and local state"""
config = Config.find(config_path)
state = State.load(config._path.parent, config.model_name)
state = State.load(config.config_path.parent, config.model_name)

if state.secrets.pop(name, ""):
state.save()
Expand All @@ -96,7 +96,7 @@ def secrets_list(
"""List all secrets with sync status"""
client: Together = ctx.obj
config = Config.find(config_path)
state = State.load(config._path.parent, config.model_name)
state = State.load(config.config_path.parent, config.model_name)

prefix = f"{config.model_name}-"

Expand Down