# informational_agent.py

try:
    from ..llm_factory import OpenAILLMs, GoogleAILLMs
    from .informational_prompts import \
        informational_role_prompt, conv_pref_prompt, update_conv_pref_prompt, summary_prompt, update_summary_prompt, summary_system_prompt
    from ..utils.types import InvokeAgentResponseType
except ImportError:
    from src.agents.llm_factory import OpenAILLMs, GoogleAILLMs
    from src.agents.informational_agent.informational_prompts import \
        informational_role_prompt, conv_pref_prompt, update_conv_pref_prompt, summary_prompt, update_summary_prompt, summary_system_prompt
    from src.agents.utils.types import InvokeAgentResponseType

from langgraph.graph import StateGraph, START, END
from langchain_core.messages import SystemMessage, RemoveMessage, HumanMessage, AIMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph.message import add_messages
from typing import Annotated, TypeAlias
from typing_extensions import TypedDict
"""
Based on the base_agent [LLM workflow with a summarisation, profiling, and chat agent that receives an external conversation history].
This agent is designed to:
- [summarise_prompt] summarise the conversation after 'max_messages_to_summarize' number of messages is reached in the conversation
- [conv_pref_prompt] analyse the conversation style of the student
- [informational_role_prompt] role of a tutor to answer student's ALL questions on any topic
"""
# TODO: return/uncomment improved conversational style use & analysis
ValidMessageTypes: TypeAlias = SystemMessage | HumanMessage | AIMessage
AllMessageTypes: TypeAlias = ValidMessageTypes | RemoveMessage

class State(TypedDict):
    messages: Annotated[list[AllMessageTypes], add_messages]
    summary: str
    conversationalStyle: str
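
# An illustrative, hypothetical populated state, for orientation only:
# {"messages": [HumanMessage(content="hi"), AIMessage(content="Hello!")],
#  "summary": "", "conversationalStyle": ""}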

class InformationalAgent:
    def __init__(self,
                 informational_role_prompt: str = informational_role_prompt,
                 conv_pref_prompt: str = conv_pref_prompt,
                 update_conv_pref_prompt: str = update_conv_pref_prompt,
                 summary_prompt: str = summary_prompt,
                 update_summary_prompt: str = update_summary_prompt):
        llm = GoogleAILLMs()
        self.llm = llm.get_llm()
        summarisation_llm = OpenAILLMs()
        self.summarisation_llm = summarisation_llm.get_llm()
        self.summary = ""
        self.conversationalStyle = ""

        # Define agent-specific parameters
        self.max_messages_to_summarize = 11
        self.role_prompt = informational_role_prompt
        self.summary_prompt = summary_prompt
        self.update_summary_prompt = update_summary_prompt
        self.conversation_preference_prompt = conv_pref_prompt
        self.update_conversation_preference_prompt = update_conv_pref_prompt

        # Define a new graph for the conversation & compile it
        self.workflow = StateGraph(State)
        self.workflow_definition()
        self.app = self.workflow.compile()
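
    # All five prompts are injectable at construction time, e.g. (hypothetical
    # prompt text, shown only to illustrate the override mechanism):
    #   InformationalAgent(informational_role_prompt="You are a concise maths tutor ...")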

    def call_model(self, state: State, config: RunnableConfig) -> dict:
        """Call the LLM model knowing the role system prompt, the summary and the conversational style."""

        # Default AI tutor role prompt
        system_message = self.role_prompt

        # Adding external student progress and question context details from data queries
        question_response_details = config["configurable"].get("question_response_details", "")
        if question_response_details:
            system_message += f"\n\n ## Known Learning Materials: {question_response_details} \n\n"

        # Adding summary and conversational style to the system message
        summary = state.get("summary", "")
        conversationalStyle = state.get("conversationalStyle", "")
        if summary:
            system_message += summary_system_prompt.format(summary=summary)
        # if conversationalStyle:
        #     system_message += f"\n\n ## Known conversational style and preferences of the student for this conversation: {conversationalStyle}. \n\nYour answer must be in line with this conversational style."

        messages = [SystemMessage(content=system_message)] + state['messages']
        valid_messages = self.check_for_valid_messages(messages)
        print("Informational agent valid messages, ready for LLM call...")
        response = self.llm.invoke(valid_messages)
        print("Informational agent response successfully received.")

        # Save summary for fetching outside the class
        self.summary = summary
        self.conversationalStyle = conversationalStyle

        return {"summary": summary, "conversationalStyle": conversationalStyle, "messages": [response]}

    def check_for_valid_messages(self, messages: list[AllMessageTypes]) -> list[ValidMessageTypes]:
        """Remove any RemoveMessage entries from the list of messages."""
        valid_messages: list[ValidMessageTypes] = []
        for message in messages:
            if message.type != 'remove':
                valid_messages.append(message)
        return valid_messages

    def summarize_conversation(self, state: State, config: RunnableConfig) -> dict:
        """Summarize the conversation."""
        summary = state.get("summary", "")
        previous_summary = config["configurable"].get("summary", "")
        previous_conversationalStyle = config["configurable"].get("conversational_style", "")
        if previous_summary:
            summary = previous_summary

        if summary:
            summary_message = (
                f"This is a summary of the conversation to date: {summary}\n\n" +
                self.update_summary_prompt
            )
        else:
            summary_message = self.summary_prompt

        if previous_conversationalStyle:
            conversationalStyle_message = (
                f"This is the previous conversational style of the student for this conversation: {previous_conversationalStyle}\n\n" +
                self.update_conversation_preference_prompt
            )
        else:
            conversationalStyle_message = self.conversation_preference_prompt

        # STEP 1: Summarize the conversation
        messages = state["messages"][:-1] + [HumanMessage(content=summary_message)]
        valid_messages = self.check_for_valid_messages(messages)
        summary_response = self.summarisation_llm.invoke(valid_messages)

        # STEP 2: Analyze the conversational style
        messages = state["messages"][:-1] + [HumanMessage(content=conversationalStyle_message)]
        valid_messages = self.check_for_valid_messages(messages)
        conversationalStyle_response = self.summarisation_llm.invoke(valid_messages)
        print("Informational agent summary and conversational style responses successfully received.")

        # Delete all but the last three messages from the graph state
        delete_messages: list[AllMessageTypes] = [RemoveMessage(id=m.id) for m in state["messages"][:-3]]

        return {"summary": summary_response.content, "conversationalStyle": conversationalStyle_response.content, "messages": delete_messages}

    def should_summarize(self, state: State) -> str:
        """
        Return the next node to execute.
        If there are more than X messages, then we summarize the conversation.
        Otherwise, we call the LLM.
        """
        messages = state["messages"]
        valid_messages = self.check_for_valid_messages(messages)
        nr_messages = len(valid_messages)
        if len(valid_messages) == 0:
            raise Exception("Internal Error: No valid messages found in the conversation history. Conversation history might be empty.")
        if "system" in valid_messages[-1].type:
            nr_messages -= 1

        # always pairs of (sent, response) + 1 latest message
        if nr_messages > self.max_messages_to_summarize:
            print("Informational agent: summarizing conversation needed...")
            return "summarize_conversation"
        return "call_llm"

    def workflow_definition(self) -> None:
        self.workflow.add_node("call_llm", self.call_model)
        self.workflow.add_node("summarize_conversation", self.summarize_conversation)

        self.workflow.add_conditional_edges(source=START, path=self.should_summarize)
        self.workflow.add_edge("summarize_conversation", "call_llm")
        self.workflow.add_edge("call_llm", END)

    def get_summary(self) -> str:
        return self.summary

    def get_conversational_style(self) -> str:
        return self.conversationalStyle

    def print_update(self, update: dict) -> None:
        for k, v in update.items():
            for m in v["messages"]:
                m.pretty_print()
            if "summary" in v:
                print(v["summary"])

    def pretty_response_value(self, event: dict) -> str:
        return event["messages"][-1].content

agent = InformationalAgent()

def invoke_informational_agent(query: str, conversation_history: list, summary: str, conversationalStyle: str, question_response_details: str, session_id: str, agent: InformationalAgent = agent) -> InvokeAgentResponseType:
    print(f'in invoke_informational_agent(), thread_id = {session_id}')

    config = {"configurable": {"thread_id": session_id, "summary": summary, "conversational_style": conversationalStyle, "question_response_details": question_response_details}}

    print("Informational agent invoking...")
    response_events = agent.app.invoke({"messages": conversation_history, "summary": summary, "conversationalStyle": conversationalStyle}, config=config, stream_mode="values") #updates
    print("Informational agent response received.")
    pretty_printed_response = agent.pretty_response_value(response_events)  # get last event/AI answer in the response

    # Gather metadata from the agent
    summary = agent.get_summary()
    conversationalStyle = agent.get_conversational_style()

    return {
        "input": query,
        "output": pretty_printed_response,
        "intermediate_steps": [str(summary), conversationalStyle, conversation_history]
    }
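
# Minimal usage sketch, illustrative only: it assumes the llm_factory wrappers
# can reach valid Google/OpenAI credentials at import time, and the query text
# and session id below are made-up values.
if __name__ == "__main__":
    history = [HumanMessage(content="Can you explain what an eigenvalue is?")]
    result = invoke_informational_agent(
        query="Can you explain what an eigenvalue is?",
        conversation_history=history,
        summary="",
        conversationalStyle="",
        question_response_details="",
        session_id="demo-session",
    )
    print(result["output"])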