Skip to content

Commit baebccb

Browse files
committed
test: _reduce_messages in LangChainAgent
1 parent 89ec76c commit baebccb

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

src/rai_core/rai/agents/langchain/agent.py

+34-24
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,28 @@ class BaseState(TypedDict):
3333
messages: List[BaseMessage]
3434

3535

36+
newMessageBehaviorType = Literal[
37+
"take_all",
38+
"keep_last",
39+
"queue",
40+
"interuppt_take_all",
41+
"interuppt_keep_last",
42+
]
43+
44+
3645
class LangChainAgent(BaseAgent):
3746
def __init__(
3847
self,
3948
target_connectors: Dict[str, HRIConnector[HRIMessage]],
4049
runnable: Runnable,
4150
state: BaseState | None = None,
42-
new_message_behavior: Literal[
43-
"take_all",
44-
"keep_last",
45-
"queue",
46-
"interuppt_take_all",
47-
"interuppt_keep_last",
48-
] = "interuppt_keep_last",
51+
new_message_behavior: newMessageBehaviorType = "interuppt_keep_last",
4952
max_size: int = 100,
5053
):
5154
super().__init__()
5255
self.logger = logging.getLogger(__name__)
5356
self.agent = runnable
54-
self.new_message_behavior = new_message_behavior
57+
self.new_message_behavior: newMessageBehaviorType = new_message_behavior
5558
self.tracing_callbacks = get_tracing_callbacks()
5659
self.state = state or ReActAgentState(messages=[])
5760
self._langchain_callback = HRICallbackHandler(
@@ -141,26 +144,33 @@ def stop(self):
141144
self.thread = None
142145
self.logger.info("Agent stopped")
143146

144-
def _reduce_messages(self) -> HRIMessage:
145-
text = ""
146-
images = []
147-
audios = []
148-
source_messages = list()
149-
if "take_all" in self.new_message_behavior:
147+
@staticmethod
148+
def _apply_reduction_behavior(
149+
method: newMessageBehaviorType, buffer: Deque
150+
) -> List:
151+
output = list()
152+
if "take_all" in method:
150153
# Take all starting from the oldest
151-
while len(self._received_messages) > 0:
152-
source_messages.append(self._received_messages.popleft())
153-
elif "keep_last" in self.new_message_behavior:
154+
while len(buffer) > 0:
155+
output.append(buffer.popleft())
156+
elif "keep_last" in method:
154157
# Take the recently added message
155-
source_messages.append(self._received_messages.pop())
156-
self._received_messages.clear()
157-
elif self.new_message_behavior == "queue":
158+
output.append(buffer.pop())
159+
buffer.clear()
160+
elif method == "queue":
158161
# Take the first message from the queue. Let other messages wait.
159-
source_messages.append(self._received_messages.popleft())
162+
output.append(buffer.popleft())
160163
else:
161-
raise ValueError(
162-
f"Invalid new_message_behavior: {self.new_message_behavior}"
163-
)
164+
raise ValueError(f"Invalid new_message_behavior: {method}")
165+
return output
166+
167+
def _reduce_messages(self) -> HRIMessage:
168+
text = ""
169+
images = []
170+
audios = []
171+
source_messages = self._apply_reduction_behavior(
172+
self.new_message_behavior, self._received_messages
173+
)
164174
for source_message in source_messages:
165175
text += f"{source_message.text}\n"
166176
images.extend(source_message.images)

tests/agents/test_langchain_agent.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from collections import deque
2+
from typing import List
3+
4+
import pytest
5+
from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType
6+
7+
8+
@pytest.mark.parametrize(
9+
"new_message_behavior,in_buffer,out_buffer,output",
10+
[
11+
("take_all", [1, 2, 3], [], [1, 2, 3]),
12+
("keep_last", [1, 2, 3], [], [3]),
13+
("queue", [1, 2, 3], [2, 3], [1]),
14+
("interuppt_take_all", [1, 2, 3], [], [1, 2, 3]),
15+
("interuppt_keep_last", [1, 2, 3], [], [3]),
16+
],
17+
)
18+
def test_reduce_messages(
19+
new_message_behavior: newMessageBehaviorType,
20+
in_buffer: List,
21+
out_buffer: List,
22+
output: List,
23+
):
24+
buffer = deque(in_buffer)
25+
output = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer)
26+
assert output == output
27+
assert buffer == deque(out_buffer)

0 commit comments

Comments
 (0)