Skip to content

Commit 7ae2f3a

Browse files
committed
feat(runnable): complete rewrite of RunnableRails with full LangChain Runnable protocol support
- Implement comprehensive async/sync invoke, batch, and streaming support - Add robust input/output transformation for all LangChain formats (ChatPromptValue, BaseMessage, dict, string) - Enhance chaining behavior with intelligent __or__ method handling RunnableBinding and complex chains - Add concurrency controls, error handling, and configurable blocking messages - Implement proper tool calling support with tool call passthrough - Add extensive test suite (14 test files, 2800+ lines) covering all major functionality including batching, streaming, composition, piping, and tool calling - Reorganize and expand test structure for better maintainability
1 parent 0294474 commit 7ae2f3a

15 files changed

+3021
-112
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
167167
return messages
168168

169169

170-
except Exception as e:
171-
raise LLMCallException(e)
170+
def _store_tool_calls(response) -> None:
171+
"""Extract and store tool calls from response in context."""
172+
tool_calls = getattr(response, "tool_calls", None)
173+
tool_calls_var.set(tool_calls)
172174

173175

174176
def _extract_content(response) -> str:

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 817 additions & 110 deletions
Large diffs are not rendered by default.
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Tests for basic RunnableRails operations (invoke, async, batch, stream).
18+
"""
19+
20+
import pytest
21+
from langchain_core.messages import AIMessage, HumanMessage
22+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
23+
from langchain_core.runnables import RunnablePassthrough
24+
25+
from nemoguardrails import RailsConfig
26+
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
27+
from tests.utils import FakeLLM
28+
29+
30+
def test_updated_runnable_rails_basic():
31+
"""Test basic functionality of updated RunnableRails."""
32+
llm = FakeLLM(
33+
responses=[
34+
"Hello there! How can I help you today?",
35+
]
36+
)
37+
config = RailsConfig.from_content(config={"models": []})
38+
model_with_rails = RunnableRails(config, llm=llm)
39+
40+
result = model_with_rails.invoke("Hi there")
41+
42+
assert isinstance(result, str)
43+
assert "Hello there" in result
44+
45+
46+
async def test_updated_runnable_rails_async():
47+
"""Test async functionality of updated RunnableRails."""
48+
llm = FakeLLM(
49+
responses=[
50+
"Hello there! How can I help you today?",
51+
]
52+
)
53+
config = RailsConfig.from_content(config={"models": []})
54+
model_with_rails = RunnableRails(config, llm=llm)
55+
56+
result = await model_with_rails.ainvoke("Hi there")
57+
58+
assert isinstance(result, str)
59+
assert "Hello there" in result
60+
61+
62+
def test_updated_runnable_rails_batch():
63+
"""Test batch functionality of updated RunnableRails."""
64+
llm = FakeLLM(
65+
responses=[
66+
"Response 1",
67+
"Response 2",
68+
]
69+
)
70+
config = RailsConfig.from_content(config={"models": []})
71+
model_with_rails = RunnableRails(config, llm=llm)
72+
73+
results = model_with_rails.batch(["Question 1", "Question 2"])
74+
75+
assert len(results) == 2
76+
assert results[0] == "Response 1"
77+
assert results[1] == "Response 2"
78+
79+
80+
def test_updated_runnable_rails_stream():
81+
"""Test streaming functionality of updated RunnableRails."""
82+
llm = FakeLLM(
83+
responses=[
84+
"Hello there!",
85+
]
86+
)
87+
config = RailsConfig.from_content(config={"models": []})
88+
model_with_rails = RunnableRails(config, llm=llm)
89+
90+
chunks = []
91+
for chunk in model_with_rails.stream("Hi there"):
92+
chunks.append(chunk)
93+
94+
assert len(chunks) == 2
95+
assert chunks[0].content == "Hello "
96+
assert chunks[1].content == "there!"
97+
98+
99+
def test_runnable_rails_with_message_history():
100+
"""Test handling of message history with updated RunnableRails."""
101+
llm = FakeLLM(
102+
responses=[
103+
"Yes, Paris is the capital of France.",
104+
]
105+
)
106+
config = RailsConfig.from_content(config={"models": []})
107+
model_with_rails = RunnableRails(config, llm=llm)
108+
109+
history = [
110+
HumanMessage(content="Hello"),
111+
AIMessage(content="Hi there!"),
112+
HumanMessage(content="What's the capital of France?"),
113+
]
114+
115+
result = model_with_rails.invoke(history)
116+
117+
assert isinstance(result, AIMessage)
118+
assert "Paris" in result.content
119+
120+
121+
def test_runnable_rails_with_chat_template():
122+
"""Test updated RunnableRails with chat templates."""
123+
llm = FakeLLM(
124+
responses=[
125+
"Yes, Paris is the capital of France.",
126+
]
127+
)
128+
config = RailsConfig.from_content(config={"models": []})
129+
model_with_rails = RunnableRails(config, llm=llm)
130+
131+
prompt = ChatPromptTemplate.from_messages(
132+
[
133+
MessagesPlaceholder(variable_name="history"),
134+
("human", "{question}"),
135+
]
136+
)
137+
138+
chain = prompt | model_with_rails
139+
140+
result = chain.invoke(
141+
{
142+
"history": [
143+
HumanMessage(content="Hello"),
144+
AIMessage(content="Hi there!"),
145+
],
146+
"question": "What's the capital of France?",
147+
}
148+
)
149+
150+
assert isinstance(result, AIMessage)
151+
assert "Paris" in result.content
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for batch_as_completed methods."""
17+
18+
import pytest
19+
20+
from nemoguardrails import RailsConfig
21+
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
22+
from tests.utils import FakeLLM
23+
24+
25+
@pytest.fixture
26+
def rails():
27+
"""Create a RunnableRails instance for testing."""
28+
config = RailsConfig.from_content(config={"models": []})
29+
llm = FakeLLM(responses=["response 1", "response 2", "response 3"])
30+
return RunnableRails(config, llm=llm)
31+
32+
33+
def test_batch_as_completed_exists(rails):
34+
"""Test that batch_as_completed method exists."""
35+
# Check if method exists - this should pass if it's inherited from base class
36+
assert hasattr(rails, "batch_as_completed")
37+
38+
39+
@pytest.mark.asyncio
40+
async def test_abatch_as_completed_exists(rails):
41+
"""Test that abatch_as_completed method exists."""
42+
# Check if async version exists
43+
assert hasattr(rails, "abatch_as_completed")

tests/runnable_rails/test_batching.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pytest
17+
from langchain_core.messages import AIMessage, HumanMessage
18+
19+
from nemoguardrails import RailsConfig
20+
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
21+
from tests.utils import FakeLLM
22+
23+
24+
def test_batch_processing():
25+
"""Test batch processing of multiple inputs."""
26+
llm = FakeLLM(
27+
responses=[
28+
"Paris.",
29+
"Rome.",
30+
"Berlin.",
31+
]
32+
)
33+
config = RailsConfig.from_content(config={"models": []})
34+
model_with_rails = RunnableRails(config, llm=llm)
35+
36+
inputs = [
37+
"What's the capital of France?",
38+
"What's the capital of Italy?",
39+
"What's the capital of Germany?",
40+
]
41+
42+
results = model_with_rails.batch(inputs)
43+
44+
assert len(results) == 3
45+
assert results[0] == "Paris."
46+
assert results[1] == "Rome."
47+
assert results[2] == "Berlin."
48+
49+
50+
@pytest.mark.asyncio
51+
async def test_abatch_processing():
52+
"""Test async batch processing of multiple inputs."""
53+
llm = FakeLLM(
54+
responses=[
55+
"Paris.",
56+
"Rome.",
57+
"Berlin.",
58+
]
59+
)
60+
config = RailsConfig.from_content(config={"models": []})
61+
model_with_rails = RunnableRails(config, llm=llm)
62+
63+
inputs = [
64+
"What's the capital of France?",
65+
"What's the capital of Italy?",
66+
"What's the capital of Germany?",
67+
]
68+
69+
results = await model_with_rails.abatch(inputs)
70+
71+
assert len(results) == 3
72+
assert results[0] == "Paris."
73+
assert results[1] == "Rome."
74+
assert results[2] == "Berlin."
75+
76+
77+
def test_batch_with_different_input_types():
78+
"""Test batch processing with different input types."""
79+
llm = FakeLLM(
80+
responses=[
81+
"Paris.",
82+
"Rome.",
83+
"Berlin.",
84+
]
85+
)
86+
config = RailsConfig.from_content(config={"models": []})
87+
model_with_rails = RunnableRails(config, llm=llm)
88+
89+
inputs = [
90+
"What's the capital of France?",
91+
HumanMessage(content="What's the capital of Italy?"),
92+
{"input": "What's the capital of Germany?"},
93+
]
94+
95+
results = model_with_rails.batch(inputs)
96+
97+
assert len(results) == 3
98+
assert results[0] == "Paris."
99+
assert isinstance(results[1], AIMessage)
100+
assert results[1].content == "Rome."
101+
assert isinstance(results[2], dict)
102+
assert results[2]["output"] == "Berlin."
103+
104+
105+
def test_stream_output():
106+
"""Test streaming output (simplified for now)."""
107+
llm = FakeLLM(
108+
responses=[
109+
"Paris.",
110+
]
111+
)
112+
config = RailsConfig.from_content(config={"models": []})
113+
model_with_rails = RunnableRails(config, llm=llm)
114+
115+
# Collect all chunks from the stream
116+
chunks = []
117+
for chunk in model_with_rails.stream("What's the capital of France?"):
118+
chunks.append(chunk)
119+
120+
# Currently, stream just yields the full response as a single chunk
121+
assert len(chunks) == 1
122+
assert chunks[0].content == "Paris."
123+
124+
125+
@pytest.mark.asyncio
126+
async def test_astream_output():
127+
"""Test async streaming output (simplified for now)."""
128+
llm = FakeLLM(
129+
responses=[
130+
"hello what can you do?",
131+
],
132+
streaming=True,
133+
)
134+
config = RailsConfig.from_content(config={"models": [], "streaming": True})
135+
model_with_rails = RunnableRails(config, llm=llm)
136+
137+
# Collect all chunks from the stream
138+
chunks = []
139+
async for chunk in model_with_rails.astream("What's the capital of France?"):
140+
chunks.append(chunk)
141+
142+
# Stream should yield individual word chunks
143+
assert len(chunks) == 5
144+
assert chunks[0].content == "hello "
145+
assert chunks[1].content == "what "
146+
assert chunks[2].content == "can "
147+
assert chunks[3].content == "you "
148+
assert chunks[4].content == "do?"

0 commit comments

Comments
 (0)