Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import yaml
from pydantic import BaseModel, ConfigDict, ValidationError, root_validator
from pydantic import (
BaseModel,
ConfigDict,
ValidationError,
model_validator,
root_validator,
)
from pydantic.fields import Field

from nemoguardrails import utils
Expand Down Expand Up @@ -103,6 +109,40 @@ class Model(BaseModel):
)
parameters: Dict[str, Any] = Field(default_factory=dict)

@model_validator(mode="before")
@classmethod
def set_and_validate_model(cls, data: Any) -> Any:
if isinstance(data, dict):
parameters = data.get("parameters")
if parameters is None:
return data
model_field = data.get("model")
model_from_params = parameters.get("model_name") or parameters.get("model")

if model_field and model_from_params:
raise ValueError(
"Model name must be specified in exactly one place: either in the 'model' field or in parameters, not both."
)
if not model_field and model_from_params:
data["model"] = model_from_params
if (
"model_name" in parameters
and parameters["model_name"] == model_from_params
):
parameters.pop("model_name")
elif "model" in parameters and parameters["model"] == model_from_params:
parameters.pop("model")
return data

@model_validator(mode="after")
def model_must_be_non_empty(self) -> "Model":
"""Validate that a model name is present either directly or in parameters."""
if not self.model or not self.model.strip():
raise ValueError(
"Model name must be specified either directly in the 'model' field or through 'model_name'/'model' in parameters"
)
return self


class Instruction(BaseModel):
"""Configuration for instructions in natural language that should be passed to the LLM."""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_configs/fact_checking/config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
models:
- type: main
engine: nim
model_name: meta/llama-3.1-70b-instruct
model: meta/llama-3.1-70b-instruct

rails:
config:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_configs/with_custom_llm/config.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
models:
- type: main
engine: custom_llm
parameters:
model: custom_model
125 changes: 125 additions & 0 deletions tests/test_rails_llm_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from pydantic import ValidationError

from nemoguardrails.rails.llm.config import Model


def test_explicit_model_param():
"""Test model specified directly via the model parameter."""
model = Model(type="main", engine="test_engine", model="test_model")
assert model.model == "test_model"
assert "model" not in model.parameters


def test_model_in_parameters():
"""Test model specified via parameters dictionary."""
model = Model(type="main", engine="test_engine", parameters={"model": "test_model"})
assert model.model == "test_model"
assert "model" not in model.parameters


def test_model_name_in_parameters():
"""Test model specified via model_name in parameters dictionary."""
model = Model(
type="main", engine="test_engine", parameters={"model_name": "test_model"}
)
assert model.model == "test_model"
assert "model_name" not in model.parameters


def test_model_equivalence():
"""Test that models defined in different ways are considered equivalent."""
model1 = Model(type="main", engine="test_engine", model="test_model")
model2 = Model(
type="main", engine="test_engine", parameters={"model": "test_model"}
)
assert model1 == model2


def test_empty_model_and_parameters():
"""Test that an empty model and parameters dict fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", parameters={})


def test_none_model_and_parameters():
"""Test that None model and empty parameters dict fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", model=None, parameters={})


def test_none_model_and_none_parameters():
"""Test that None model and None parameters fails validation."""
with pytest.raises(ValueError):
Model(type="main", engine="openai", model=None, parameters=None)


def test_model_and_model_name_in_parameters():
"""Test that having both model and model_name in parameters raises an error."""
with pytest.raises(
ValueError, match="Model name must be specified in exactly one place"
):
Model(
type="main",
engine="openai",
model="gpt-4",
parameters={"model_name": "gpt-3.5-turbo"},
)


def test_model_and_model_in_parameters():
"""Test that having both model field and model in parameters raises an error."""
with pytest.raises(
ValueError, match="Model name must be specified in exactly one place"
):
Model(
type="main",
engine="openai",
model="gpt-4",
parameters={"model": "gpt-3.5-turbo"},
)


def test_empty_string_model():
"""Test that an empty string model fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", model="", parameters={})


def test_whitespace_only_model():
"""Test that a whitespace-only model fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", model=" ", parameters={})


def test_empty_string_model_name_in_parameters():
"""Test that an empty string model_name in parameters fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", parameters={"model_name": ""})


def test_whitespace_only_model_in_parameters():
"""Test that a whitespace-only model in parameters fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", parameters={"model": " "})


def test_model_name_none_in_parameters():
"""Test that None model_name in parameters fails validation."""
with pytest.raises(ValueError, match="Model name must be specified"):
Model(type="main", engine="openai", parameters={"model_name": None})