Skip to content

Commit 25a353b

Browse files
committed
feat(model): validate model name configuration
Implement proper model name validation in the Model class to ensure that a model name is specified exactly once, either directly via the 'model' field or through parameters. This also normalizes the configuration by moving model names from parameters to the dedicated model field. Add tests to verify all validation scenarios fix: ensure custom_model config has model
1 parent 34a94c6 commit 25a353b

File tree

3 files changed

+168
-1
lines changed

3 files changed

+168
-1
lines changed

nemoguardrails/rails/llm/config.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@
2323
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2424

2525
import yaml
26-
from pydantic import BaseModel, ConfigDict, ValidationError, root_validator
26+
from pydantic import (
27+
BaseModel,
28+
ConfigDict,
29+
ValidationError,
30+
model_validator,
31+
root_validator,
32+
)
2733
from pydantic.fields import Field
2834

2935
from nemoguardrails import utils
@@ -103,6 +109,40 @@ class Model(BaseModel):
103109
)
104110
parameters: Dict[str, Any] = Field(default_factory=dict)
105111

112+
@model_validator(mode="before")
113+
@classmethod
114+
def set_and_validate_model(cls, data: Any) -> Any:
115+
if isinstance(data, dict):
116+
parameters = data.get("parameters")
117+
if parameters is None:
118+
return data
119+
model_field = data.get("model")
120+
model_from_params = parameters.get("model_name") or parameters.get("model")
121+
122+
if model_field and model_from_params:
123+
raise ValueError(
124+
"Model name must be specified in exactly one place: either in the 'model' field or in parameters, not both."
125+
)
126+
if not model_field and model_from_params:
127+
data["model"] = model_from_params
128+
if (
129+
"model_name" in parameters
130+
and parameters["model_name"] == model_from_params
131+
):
132+
parameters.pop("model_name")
133+
elif "model" in parameters and parameters["model"] == model_from_params:
134+
parameters.pop("model")
135+
return data
136+
137+
@model_validator(mode="after")
138+
def model_must_be_non_empty(self) -> "Model":
139+
"""Validate that a model name is present either directly or in parameters."""
140+
if not self.model or not self.model.strip():
141+
raise ValueError(
142+
"Model name must be specified either directly in the 'model' field or through 'model_name'/'model' in parameters"
143+
)
144+
return self
145+
106146

107147
class Instruction(BaseModel):
108148
"""Configuration for instructions in natural language that should be passed to the LLM."""
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
models:
22
- type: main
33
engine: custom_llm
4+
parameters:
5+
model: custom_model

tests/test_rails_llm_config.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pytest
17+
from pydantic import ValidationError
18+
19+
from nemoguardrails.rails.llm.config import Model
20+
21+
22+
def test_explicit_model_param():
23+
"""Test model specified directly via the model parameter."""
24+
model = Model(type="main", engine="test_engine", model="test_model")
25+
assert model.model == "test_model"
26+
assert "model" not in model.parameters
27+
28+
29+
def test_model_in_parameters():
30+
"""Test model specified via parameters dictionary."""
31+
model = Model(type="main", engine="test_engine", parameters={"model": "test_model"})
32+
assert model.model == "test_model"
33+
assert "model" not in model.parameters
34+
35+
36+
def test_model_name_in_parameters():
37+
"""Test model specified via model_name in parameters dictionary."""
38+
model = Model(
39+
type="main", engine="test_engine", parameters={"model_name": "test_model"}
40+
)
41+
assert model.model == "test_model"
42+
assert "model_name" not in model.parameters
43+
44+
45+
def test_model_equivalence():
46+
"""Test that models defined in different ways are considered equivalent."""
47+
model1 = Model(type="main", engine="test_engine", model="test_model")
48+
model2 = Model(
49+
type="main", engine="test_engine", parameters={"model": "test_model"}
50+
)
51+
assert model1 == model2
52+
53+
54+
def test_empty_model_and_parameters():
55+
"""Test that an empty model and parameters dict fails validation."""
56+
with pytest.raises(ValueError, match="Model name must be specified"):
57+
Model(type="main", engine="openai", parameters={})
58+
59+
60+
def test_none_model_and_parameters():
61+
"""Test that None model and empty parameters dict fails validation."""
62+
with pytest.raises(ValueError, match="Model name must be specified"):
63+
Model(type="main", engine="openai", model=None, parameters={})
64+
65+
66+
def test_none_model_and_none_parameters():
67+
"""Test that None model and None parameters fails validation."""
68+
with pytest.raises(ValueError):
69+
Model(type="main", engine="openai", model=None, parameters=None)
70+
71+
72+
def test_model_and_model_name_in_parameters():
73+
"""Test that having both model and model_name in parameters raises an error."""
74+
with pytest.raises(
75+
ValueError, match="Model name must be specified in exactly one place"
76+
):
77+
Model(
78+
type="main",
79+
engine="openai",
80+
model="gpt-4",
81+
parameters={"model_name": "gpt-3.5-turbo"},
82+
)
83+
84+
85+
def test_model_and_model_in_parameters():
86+
"""Test that having both model field and model in parameters raises an error."""
87+
with pytest.raises(
88+
ValueError, match="Model name must be specified in exactly one place"
89+
):
90+
Model(
91+
type="main",
92+
engine="openai",
93+
model="gpt-4",
94+
parameters={"model": "gpt-3.5-turbo"},
95+
)
96+
97+
98+
def test_empty_string_model():
99+
"""Test that an empty string model fails validation."""
100+
with pytest.raises(ValueError, match="Model name must be specified"):
101+
Model(type="main", engine="openai", model="", parameters={})
102+
103+
104+
def test_whitespace_only_model():
105+
"""Test that a whitespace-only model fails validation."""
106+
with pytest.raises(ValueError, match="Model name must be specified"):
107+
Model(type="main", engine="openai", model=" ", parameters={})
108+
109+
110+
def test_empty_string_model_name_in_parameters():
111+
"""Test that an empty string model_name in parameters fails validation."""
112+
with pytest.raises(ValueError, match="Model name must be specified"):
113+
Model(type="main", engine="openai", parameters={"model_name": ""})
114+
115+
116+
def test_whitespace_only_model_in_parameters():
117+
"""Test that a whitespace-only model in parameters fails validation."""
118+
with pytest.raises(ValueError, match="Model name must be specified"):
119+
Model(type="main", engine="openai", parameters={"model": " "})
120+
121+
122+
def test_model_name_none_in_parameters():
123+
"""Test that None model_name in parameters fails validation."""
124+
with pytest.raises(ValueError, match="Model name must be specified"):
125+
Model(type="main", engine="openai", parameters={"model_name": None})

0 commit comments

Comments
 (0)