Skip to content

Commit 1d4af81

Browse files
committed
fix pytests
1 parent b98f9ea commit 1d4af81

File tree

6 files changed

+39
-9
lines changed

6 files changed

+39
-9
lines changed

index_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_missing_argument(self):
3131
for arg in arguments:
3232
event = {
3333
"message": "Hello, World",
34-
"params": {"conversation_id": "1234Test"}
34+
"params": {"conversation_id": "1234Test", "conversation_history": [{"type": "user", "content": "Hello, World"}]}
3535
}
3636
event.pop(arg)
3737

@@ -42,7 +42,7 @@ def test_missing_argument(self):
4242
def test_correct_arguments(self):
4343
event = {
4444
"message": "Hello, World",
45-
"params": {"conversation_id": "1234Test"}
45+
"params": {"conversation_id": "1234Test", "conversation_history": [{"type": "user", "content": "Hello, World"}]}
4646
}
4747

4848
result = handler(event, None)
@@ -52,7 +52,7 @@ def test_correct_arguments(self):
5252
def test_correct_response(self):
5353
event = {
5454
"message": "Hello, World",
55-
"params": {"conversation_id": "1234Test"}
55+
"params": {"conversation_id": "1234Test", "conversation_history": [{"type": "user", "content": "Hello, World"}]}
5656
}
5757

5858
result = handler(event, None)

src/agents/base_agent/base_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def should_summarize(self, state: State) -> str:
144144
messages = state["messages"]
145145
valid_messages = self.check_for_valid_messages(messages)
146146
nr_messages = len(valid_messages)
147+
if len(valid_messages) == 0:
148+
raise Exception("Internal Error: No valid messages found in the conversation history. Conversation history might be empty.")
147149
if "system" in valid_messages[-1].type:
148150
nr_messages -= 1
149151

src/agents/google_learnLM_agent/google_learnLM_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def should_summarize(self, state: State) -> str:
148148
messages = state["messages"]
149149
valid_messages = self.check_for_valid_messages(messages)
150150
nr_messages = len(valid_messages)
151+
if len(valid_messages) == 0:
152+
raise Exception("Internal Error: No valid messages found in the conversation history. Conversation history might be empty.")
151153
if "system" in valid_messages[-1].type:
152154
nr_messages -= 1
153155

src/agents/informational_agent/informational_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def should_summarize(self, state: State) -> str:
147147
messages = state["messages"]
148148
valid_messages = self.check_for_valid_messages(messages)
149149
nr_messages = len(valid_messages)
150+
if len(valid_messages) == 0:
151+
raise Exception("Internal Error: No valid messages found in the conversation history. Conversation history might be empty.")
150152
if "system" in valid_messages[-1].type:
151153
nr_messages -= 1
152154

src/agents/socratic_agent/socratic_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def should_summarize(self, state: State) -> str:
147147
messages = state["messages"]
148148
valid_messages = self.check_for_valid_messages(messages)
149149
nr_messages = len(valid_messages)
150+
if len(valid_messages) == 0:
151+
raise Exception("Internal Error: No valid messages found in the conversation history. Conversation history might be empty.")
150152
if "system" in valid_messages[-1].type:
151153
nr_messages -= 1
152154

src/module_test.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ class TestChatModuleFunction(unittest.TestCase):
2727

2828
def test_missing_parameters(self):
2929
# Checking state for missing parameters on default agent
30-
response, params = "Hello, World", Params()
31-
expected_params = Params(include_test_data=True, conversation_history=[], \
30+
response = "Hello, World"
31+
expected_params = Params(include_test_data=True, conversation_history=[{ "type": "user", "content": response }], \
3232
summary="", conversational_style="", \
3333
question_response_details={}, conversation_id="1234Test")
3434

3535
for p in expected_params:
3636
params = expected_params.copy()
3737
# except for the special parameters
38-
if p not in ["include_test_data", "conversation_id"]:
38+
if p not in ["include_test_data", "conversation_id", "conversation_history"]:
3939
params.pop(p)
4040

4141
result = chat_module(response, params)
@@ -58,20 +58,42 @@ def test_missing_parameters(self):
5858

5959
self.assertTrue("Internal Error" in str(cm.exception))
6060
self.assertTrue("conversation id" in str(cm.exception))
61+
elif p == "conversation_history":
62+
params.pop(p)
63+
64+
with self.assertRaises(Exception) as cm:
65+
chat_module(response, params)
66+
67+
self.assertTrue("Internal Error" in str(cm.exception))
68+
self.assertTrue("conversation history" in str(cm.exception))
6169

6270
def test_all_agents_output(self):
6371
# Checking the output of the agents
6472
agents = ["informational", "socratic"]
6573
for agent in agents:
66-
response, params = "Hello, World", Params(conversation_id="1234Test", agent_type=agent)
74+
response = "Hello, World"
75+
params = Params(conversation_id="1234Test", agent_type=agent, conversation_history=[{ "type": "user", "content": response }])
6776

6877
result = chat_module(response, params)
6978

7079
self.assertIsNotNone(result.get("chatbot_response"))
71-
80+
81+
def test_unknown_agent_type(self):
82+
agents = ["unknown"]
83+
for agent in agents:
84+
response = "Hello, World"
85+
params = Params(conversation_id="1234Test", agent_type=agent, conversation_history=[{ "type": "user", "content": response }])
86+
87+
with self.assertRaises(Exception) as cm:
88+
chat_module(response, params)
89+
90+
self.assertTrue("Input Parameter Error:" in str(cm.exception))
91+
self.assertTrue("Agent Type" in str(cm.exception))
92+
7293
def test_processing_time_calc(self):
7394
# Checking the processing time calculation
74-
response, params = "Hello, World", Params(include_test_data=True, conversation_id="1234Test")
95+
response = "Hello, World"
96+
params = Params(include_test_data=True, conversation_id="1234Test", conversation_history=[{ "type": "user", "content": response }])
7597

7698
result = chat_module(response, params)
7799

0 commit comments

Comments
 (0)