Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import os
import subprocess
from pathlib import Path
from typing import ClassVar
Expand All @@ -14,6 +15,9 @@
from nemo_safe_synthesizer_plugin.config import config
from nemo_safe_synthesizer_plugin.runtime import runtime_info, runtime_task_command, setup_runtime

NEMO_DEPLOYMENT_TYPE_ENVVAR = "NEMO_DEPLOYMENT_TYPE"
NMP_DEPLOYMENT_TYPE = "nmp"


class SafeSynthesizerCLI(NemoCLI):
"""CLI extensions for host-local Safe Synthesizer development."""
Expand Down Expand Up @@ -100,7 +104,9 @@ def run_local_command(
typer.echo(str(e), err=True)
raise typer.Exit(1) from e

result = subprocess.run(command, check=False)
runtime_env = os.environ.copy()
runtime_env[NEMO_DEPLOYMENT_TYPE_ENVVAR] = NMP_DEPLOYMENT_TYPE
result = subprocess.run(command, check=False, env=runtime_env)
if result.returncode != 0:
raise typer.Exit(result.returncode)
typer.echo(f"Wrote Safe Synthesizer results to {output_dir}")
Expand Down
48 changes: 48 additions & 0 deletions plugins/nemo-safe-synthesizer/tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from subprocess import CompletedProcess

from nemo_safe_synthesizer_plugin import cli
from nemo_safe_synthesizer_plugin.cli import NEMO_DEPLOYMENT_TYPE_ENVVAR, NMP_DEPLOYMENT_TYPE, SafeSynthesizerCLI
from typer.testing import CliRunner


def test_run_local_sets_nmp_deployment_type_for_runtime_subprocess(tmp_path, monkeypatch):
spec_file = tmp_path / "nss-job.json"
spec_file.write_text("{}", encoding="utf-8")
data_file = tmp_path / "input.csv"
data_file.write_text("name\nAda\n", encoding="utf-8")
output_dir = tmp_path / "nss-output"
captured = {}

def fake_runtime_task_command(_config, args):
return ["runtime-python", *args]

def fake_run(command, *, check=False, env=None):
captured["command"] = command
captured["check"] = check
captured["env"] = env
return CompletedProcess(command, 0)

monkeypatch.setattr(cli, "runtime_task_command", fake_runtime_task_command)
monkeypatch.setattr(cli.subprocess, "run", fake_run)

result = CliRunner().invoke(
SafeSynthesizerCLI().get_cli(),
[
"run-local",
"--workspace",
"default",
"--spec-file",
str(spec_file),
"--data-source",
str(data_file),
"--output-dir",
str(output_dir),
],
)

assert result.exit_code == 0, result.output
assert captured["check"] is False
assert captured["env"][NEMO_DEPLOYMENT_TYPE_ENVVAR] == NMP_DEPLOYMENT_TYPE