Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/ucode/agents/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
PI_UCODE_HOME = APP_DIR / "pi-home"
PI_CONFIG_DIR = PI_UCODE_HOME / ".pi" / "agent"
PI_CONFIG_PATH = PI_CONFIG_DIR / "models.json"
PI_SETTINGS_PATH = PI_CONFIG_DIR / "settings.json"
PI_BACKUP_PATH = APP_DIR / "pi-models.backup.json"

SPEC: ToolSpec = {
Expand Down Expand Up @@ -182,11 +183,24 @@ def write_tool_config(
providers.pop(stale, None)
merged = deep_merge_dict(existing, overlay)
write_json_file(PI_CONFIG_PATH, merged)
_write_settings(overlay["model"])
state = mark_tool_managed(state, "pi", managed_keys)
Comment on lines 184 to 187
save_state(state)
return state, token


def _write_settings(model_selector: str) -> None:
# Pin defaultProvider/defaultModel in settings.json so Pi doesn't fall
# through to an env-key-backed provider (e.g. HF_TOKEN exposing
# huggingface) in `findInitialModel` when no --model is passed.
provider, _, model_id = model_selector.partition("/")
if not model_id:
return
existing = read_json_safe(PI_SETTINGS_PATH)
merged = deep_merge_dict(existing, {"defaultProvider": provider, "defaultModel": model_id})
write_json_file(PI_SETTINGS_PATH, merged)


def default_model(state: dict) -> str | None:
"""Prefer Claude opus → sonnet → haiku; fall back to codex, gemini."""
claude_models = state.get("claude_models") or {}
Expand Down
26 changes: 22 additions & 4 deletions tests/test_agent_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,11 @@ def _setup(self, tmp_path, monkeypatch):
monkeypatch.setattr(config_io_mod, "APP_DIR", tmp_path)
config_file = tmp_path / "models.json"
backup_file = tmp_path / "pi-backup.json"
settings_file = tmp_path / "settings.json"
monkeypatch.setattr(pi_mod, "PI_CONFIG_PATH", config_file)
monkeypatch.setattr(pi_mod, "PI_SETTINGS_PATH", settings_file)
monkeypatch.setattr(pi_mod, "PI_BACKUP_PATH", backup_file)
return pi_mod, config_file
return pi_mod, config_file, settings_file

def _state(self, **overrides) -> dict:
state = {
Expand All @@ -284,7 +286,7 @@ def _state(self, **overrides) -> dict:
return state

def test_stale_managed_providers_removed_before_merge(self, tmp_path, monkeypatch):
pi_mod, config_file = self._setup(tmp_path, monkeypatch)
pi_mod, config_file, _ = self._setup(tmp_path, monkeypatch)

stale = {
"providers": {
Expand Down Expand Up @@ -312,7 +314,7 @@ def test_legacy_providers_removed_on_upgrade(self, tmp_path, monkeypatch):
"""Earlier ucode versions wrote `databricks-anthropic`, `databricks-codex`,
and `databricks-oss` providers. They must be stripped on the next write
so users don't end up with stale entries pointing at routes that 400."""
pi_mod, config_file = self._setup(tmp_path, monkeypatch)
pi_mod, config_file, _ = self._setup(tmp_path, monkeypatch)

config_file.write_text(
json.dumps(
Expand All @@ -339,7 +341,7 @@ def test_legacy_providers_removed_on_upgrade(self, tmp_path, monkeypatch):
assert "databricks-claude" in written_providers

def test_config_written_with_correct_model_and_token(self, tmp_path, monkeypatch):
pi_mod, config_file = self._setup(tmp_path, monkeypatch)
pi_mod, config_file, _ = self._setup(tmp_path, monkeypatch)

with (
patch("ucode.agents.pi.get_databricks_token", return_value="tok"),
Expand All @@ -350,3 +352,19 @@ def test_config_written_with_correct_model_and_token(self, tmp_path, monkeypatch
written = json.loads(config_file.read_text())
assert written["model"] == "databricks-claude/claude-sonnet"
assert written["providers"]["databricks-claude"]["apiKey"] == "tok"

def test_settings_pins_default_provider_and_model(self, tmp_path, monkeypatch):
# Without this, Pi's `findInitialModel` can fall through to a built-in
# provider when an unrelated env var (e.g. HF_TOKEN) makes one look
# auth-configured. Pinning the default keeps Pi on our provider.
pi_mod, _, settings_file = self._setup(tmp_path, monkeypatch)

with (
patch("ucode.agents.pi.get_databricks_token", return_value="tok"),
patch("ucode.agents.pi.save_state"),
):
pi_mod.write_tool_config(self._state(), "claude-sonnet", token="tok")

settings = json.loads(settings_file.read_text())
assert settings["defaultProvider"] == "databricks-claude"
assert settings["defaultModel"] == "claude-sonnet"
Loading