Skip to content

Commit 76d1086

Browse files
authored
Merge pull request #547 from NVIDIA/fix/issue-158-llm-params
Fix LLMParams bug and add unit tests (fixes #158)
2 parents 7aa13e1 + 8f77004 commit 76d1086

File tree

2 files changed

+243
-9
lines changed

2 files changed

+243
-9
lines changed

nemoguardrails/llm/params.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
1919
Also allows registration of custom parameter managers for different language model types.
2020
"""
21+
2122
import logging
2223
from typing import Dict, Type
2324

@@ -41,11 +42,21 @@ def __enter__(self):
4142
if hasattr(self.llm, param):
4243
self.original_params[param] = getattr(self.llm, param)
4344
setattr(self.llm, param, value)
44-
# TODO: Fix the cases where self.llm.model_kwargs is not iterable
45-
# https://github.com/NVIDIA/NeMo-Guardrails/issues/92.
46-
# elif param in getattr(self.llm, "model_kwargs", {}):
47-
# self.original_params[param] = self.llm.model_kwargs[param]
48-
# self.llm.model_kwargs[param] = value
45+
46+
elif hasattr(self.llm, "model_kwargs"):
47+
if param not in self.llm.model_kwargs:
48+
log.warning(
49+
"Parameter %s does not exist for %s. Passing to model_kwargs",
50+
param,
51+
self.llm.__class__.__name__,
52+
)
53+
54+
self.original_params[param] = None
55+
else:
56+
self.original_params[param] = self.llm.model_kwargs[param]
57+
58+
self.llm.model_kwargs[param] = value
59+
4960
else:
5061
log.warning(
5162
"Parameter %s does not exist for %s",
@@ -58,10 +69,11 @@ def __exit__(self, type, value, traceback):
5869
for param, value in self.original_params.items():
5970
if hasattr(self.llm, param):
6071
setattr(self.llm, param, value)
61-
elif hasattr(self.llm, "model_kwargs") and param in getattr(
62-
self.llm, "model_kwargs", {}
63-
):
64-
self.llm.model_kwargs[param] = value
72+
elif hasattr(self.llm, "model_kwargs"):
73+
model_kwargs = getattr(self.llm, "model_kwargs", {})
74+
if param in model_kwargs:
75+
model_kwargs[param] = value
76+
setattr(self.llm, "model_kwargs", model_kwargs)
6577

6678

6779
# The list of registered param managers. This will allow us to override the param manager
@@ -76,6 +88,7 @@ def register_param_manager(llm_type: Type[BaseLanguageModel], manager: Type[LLMP
7688

7789
def llm_params(llm: BaseLanguageModel, **kwargs):
7890
"""Returns a parameter manager for the given language model."""
91+
7992
_llm_params = _param_managers.get(llm.__class__, LLMParams)
8093

8194
return _llm_params(llm, **kwargs)

tests/test_llm_params.py

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
from typing import Any, Dict
18+
19+
from pydantic import BaseModel
20+
21+
from nemoguardrails.llm.params import LLMParams, llm_params, register_param_manager
22+
23+
24+
class FakeLLM(BaseModel):
25+
"""Fake LLM wrapper for testing purposes."""
26+
27+
model_kwargs: Dict[str, Any] = {}
28+
param3: str = ""
29+
30+
31+
class FakeLLM2(BaseModel):
32+
param3: str = ""
33+
34+
35+
class TestLLMParams(unittest.TestCase):
36+
def setUp(self):
37+
self.llm = FakeLLM(
38+
param3="value3", model_kwargs={"param1": "value1", "param2": "value2"}
39+
)
40+
self.llm_params = LLMParams(
41+
self.llm, param1="new_value1", param2="new_value2", param3="new_value3"
42+
)
43+
44+
def test_init(self):
45+
self.assertEqual(self.llm_params.llm, self.llm)
46+
self.assertEqual(
47+
self.llm_params.altered_params,
48+
{"param1": "new_value1", "param2": "new_value2", "param3": "new_value3"},
49+
)
50+
self.assertEqual(self.llm_params.original_params, {})
51+
52+
def test_enter(self):
53+
llm = self.llm
54+
with llm_params(
55+
llm, param1="new_value1", param2="new_value2", param3="new_value3"
56+
):
57+
self.assertEqual(self.llm.param3, "new_value3")
58+
self.assertEqual(self.llm.model_kwargs["param1"], "new_value1")
59+
60+
def test_exit(self):
61+
with self.llm_params:
62+
pass
63+
self.assertEqual(self.llm.model_kwargs["param1"], "value1")
64+
self.assertEqual(self.llm.param3, "value3")
65+
66+
def test_enter_with_nonexistent_param(self):
67+
"""Test that entering the context manager with a nonexistent parameter logs a warning."""
68+
69+
with self.assertLogs(level="WARNING") as cm:
70+
with llm_params(self.llm, nonexistent_param="value"):
71+
pass
72+
self.assertIn(
73+
"Parameter nonexistent_param does not exist for FakeLLM", cm.output[0]
74+
)
75+
76+
def test_exit_with_nonexistent_param(self):
77+
"""Test that exiting the context manager with a nonexistent parameter does not raise an error."""
78+
79+
llm_params = LLMParams(self.llm, nonexistent_param="value")
80+
llm_params.original_params = {"nonexistent_param": "original_value"}
81+
try:
82+
with llm_params:
83+
pass
84+
except Exception as e:
85+
self.fail(f"Exiting the context manager raised an exception: {e}")
86+
87+
88+
class TestLLMParamsWithEmptyModelKwargs(unittest.TestCase):
89+
def setUp(self):
90+
self.llm = FakeLLM(param3="value3", model_kwargs={})
91+
self.llm_params = LLMParams(
92+
self.llm, param1="new_value1", param2="new_value2", param3="new_value3"
93+
)
94+
95+
def test_init(self):
96+
self.assertEqual(self.llm_params.llm, self.llm)
97+
self.assertEqual(
98+
self.llm_params.altered_params,
99+
{"param1": "new_value1", "param2": "new_value2", "param3": "new_value3"},
100+
)
101+
self.assertEqual(self.llm_params.original_params, {})
102+
103+
def test_enter(self):
104+
llm = self.llm
105+
with llm_params(
106+
llm, param1="new_value1", param2="new_value2", param3="new_value3"
107+
):
108+
self.assertEqual(self.llm.param3, "new_value3")
109+
self.assertEqual(self.llm.model_kwargs["param1"], "new_value1")
110+
self.assertEqual(self.llm.model_kwargs["param2"], "new_value2")
111+
112+
def test_exit(self):
113+
with self.llm_params:
114+
pass
115+
self.assertEqual(self.llm.model_kwargs["param1"], None)
116+
self.assertEqual(self.llm.param3, "value3")
117+
118+
def test_enter_with_empty_model_kwargs(self):
119+
"""Test that entering the context manager with empty model_kwargs logs a warning."""
120+
warning_message = f"Parameter param1 does not exist for {self.llm.__class__.__name__}. Passing to model_kwargs"
121+
122+
with self.assertLogs(level="WARNING") as cm:
123+
with llm_params(self.llm, param1="new_value1"):
124+
pass
125+
self.assertIn(
126+
warning_message,
127+
cm.output[0],
128+
)
129+
130+
def test_exit_with_empty_model_kwargs(self):
131+
"""Test that exiting the context manager with empty model_kwargs does not raise an error."""
132+
133+
llm_params = LLMParams(self.llm, param1="new_value1")
134+
llm_params.original_params = {"param1": "original_value"}
135+
try:
136+
with llm_params:
137+
pass
138+
except Exception as e:
139+
self.fail(f"Exiting the context manager raised an exception: {e}")
140+
141+
142+
class TestLLMParamsWithoutModelKwargs(unittest.TestCase):
143+
def setUp(self):
144+
self.llm = FakeLLM2(param3="value3")
145+
self.llm_params = LLMParams(
146+
self.llm, param1="new_value1", param2="new_value2", param3="new_value3"
147+
)
148+
149+
def test_init(self):
150+
self.assertEqual(self.llm_params.llm, self.llm)
151+
self.assertEqual(
152+
self.llm_params.altered_params,
153+
{"param1": "new_value1", "param2": "new_value2", "param3": "new_value3"},
154+
)
155+
self.assertEqual(self.llm_params.original_params, {})
156+
157+
def test_enter(self):
158+
llm = self.llm
159+
with llm_params(
160+
llm, param1="new_value1", param2="new_value2", param3="new_value3"
161+
):
162+
self.assertEqual(self.llm.param3, "new_value3")
163+
164+
def test_exit(self):
165+
with self.llm_params:
166+
pass
167+
self.assertEqual(self.llm.param3, "value3")
168+
169+
def test_enter_with_empty_model_kwargs(self):
170+
"""Test that entering the context manager with empty model_kwargs logs a warning."""
171+
warning_message = (
172+
f"Parameter param1 does not exist for {self.llm.__class__.__name__}"
173+
)
174+
with self.assertLogs(level="WARNING") as cm:
175+
with llm_params(self.llm, param1="new_value1"):
176+
pass
177+
self.assertIn(
178+
warning_message,
179+
cm.output[0],
180+
)
181+
182+
def test_exit_with_empty_model_kwargs(self):
183+
"""Test that exiting the context manager with empty model_kwargs does not raise an error."""
184+
185+
llm_params = LLMParams(self.llm, param1="new_value1")
186+
llm_params.original_params = {"param1": "original_value"}
187+
try:
188+
with llm_params:
189+
pass
190+
except Exception as e:
191+
self.fail(f"Exiting the context manager raised an exception: {e}")
192+
193+
194+
class TestRegisterParamManager(unittest.TestCase):
195+
def test_register_param_manager(self):
196+
"""Test that a custom parameter manager can be registered and retrieved."""
197+
198+
class CustomLLMParams(LLMParams):
199+
pass
200+
201+
register_param_manager(FakeLLM, CustomLLMParams)
202+
self.assertEqual(llm_params(FakeLLM()).__class__, CustomLLMParams)
203+
204+
205+
class TestLLMParamsFunction(unittest.TestCase):
206+
def test_llm_params_with_registered_manager(self):
207+
"""Test that llm_params returns the registered manager for a given LLM type."""
208+
209+
class CustomLLMParams(LLMParams):
210+
pass
211+
212+
register_param_manager(FakeLLM, CustomLLMParams)
213+
self.assertIsInstance(llm_params(FakeLLM()), CustomLLMParams)
214+
215+
def test_llm_params_with_unregistered_manager(self):
216+
"""Test that llm_params returns the default manager for an unregistered LLM type."""
217+
218+
class UnregisteredLLM(BaseModel):
219+
pass
220+
221+
self.assertIsInstance(llm_params(UnregisteredLLM()), LLMParams)

0 commit comments

Comments
 (0)