Skip to content

Commit c4d46c6

Browse files
fix: Groups conversations based on the user's messages.
Closes: #419 Before we could use the `chat_id` at the output messages as means to group the messages into conversations. This logic is not working anymore. The new logic takes into account the user messages provided as input to the LLM to map the messages into conversations. Usually LLMs receive all last user messages. Example: ``` req1 = {messages:[{"role": "user", "content": "hello"}]} req2 = {messages:[{"role": "user", "content": "hello"}, {"role": "user", "content": "how are you?}]} ``` In this last example, `req1` and `req2` should be mapped together to form a conversation
1 parent ff4a3a7 commit c4d46c6

File tree

3 files changed

+390
-148
lines changed

3 files changed

+390
-148
lines changed

src/codegate/dashboard/post_processing.py

+147-80
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import asyncio
22
import json
33
import re
4-
from typing import List, Optional, Tuple, Union
4+
from collections import defaultdict
5+
from typing import List, Optional, Union
56

67
import structlog
78

89
from codegate.dashboard.request_models import (
910
AlertConversation,
1011
ChatMessage,
1112
Conversation,
12-
PartialConversation,
13+
PartialQuestionAnswer,
14+
PartialQuestions,
1315
QuestionAnswer,
1416
)
1517
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
@@ -74,60 +76,57 @@ async def parse_request(request_str: str) -> Optional[str]:
7476
return None
7577

7678
# Only respond with the latest message
77-
return messages[-1]
79+
return messages
7880

7981

80-
async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
82+
async def parse_output(output_str: str) -> Optional[str]:
8183
"""
82-
Parse the output string from the pipeline and return the message and chat_id.
84+
Parse the output string from the pipeline and return the message.
8385
"""
8486
try:
8587
if output_str is None:
86-
return None, None
88+
return None
8789

8890
output = json.loads(output_str)
8991
except Exception as e:
9092
logger.warning(f"Error parsing output: {output_str}. {e}")
91-
return None, None
93+
return None
9294

9395
def _parse_single_output(single_output: dict) -> str:
94-
single_chat_id = single_output.get("id")
9596
single_output_message = ""
9697
for choice in single_output.get("choices", []):
9798
if not isinstance(choice, dict):
9899
continue
99100
content_dict = choice.get("delta", {}) or choice.get("message", {})
100101
single_output_message += content_dict.get("content", "")
101-
return single_output_message, single_chat_id
102+
return single_output_message
102103

103104
full_output_message = ""
104-
chat_id = None
105105
if isinstance(output, list):
106106
for output_chunk in output:
107-
output_message, output_chat_id = "", None
107+
output_message = ""
108108
if isinstance(output_chunk, dict):
109-
output_message, output_chat_id = _parse_single_output(output_chunk)
109+
output_message = _parse_single_output(output_chunk)
110110
elif isinstance(output_chunk, str):
111111
try:
112112
output_decoded = json.loads(output_chunk)
113-
output_message, output_chat_id = _parse_single_output(output_decoded)
113+
output_message = _parse_single_output(output_decoded)
114114
except Exception:
115115
logger.error(f"Error reading chunk: {output_chunk}")
116116
else:
117117
logger.warning(
118118
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
119119
)
120-
chat_id = chat_id or output_chat_id
121120
full_output_message += output_message
122121
elif isinstance(output, dict):
123-
full_output_message, chat_id = _parse_single_output(output)
122+
full_output_message = _parse_single_output(output)
124123

125-
return full_output_message, chat_id
124+
return full_output_message
126125

127126

128127
async def _get_question_answer(
129128
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
130-
) -> Tuple[Optional[QuestionAnswer], Optional[str]]:
129+
) -> Optional[PartialQuestionAnswer]:
131130
"""
132131
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
133132
@@ -137,17 +136,19 @@ async def _get_question_answer(
137136
request_task = tg.create_task(parse_request(row.request))
138137
output_task = tg.create_task(parse_output(row.output))
139138

140-
request_msg_str = request_task.result()
141-
output_msg_str, chat_id = output_task.result()
139+
request_user_msgs = request_task.result()
140+
output_msg_str = output_task.result()
142141

143-
# If we couldn't parse the request or output, return None
144-
if not request_msg_str:
145-
return None, None
142+
# If we couldn't parse the request, return None
143+
if not request_user_msgs:
144+
return None
146145

147-
request_message = ChatMessage(
148-
message=request_msg_str,
146+
request_message = PartialQuestions(
147+
messages=request_user_msgs,
149148
timestamp=row.timestamp,
150149
message_id=row.id,
150+
provider=row.provider,
151+
type=row.type,
151152
)
152153
if output_msg_str:
153154
output_message = ChatMessage(
@@ -157,28 +158,7 @@ async def _get_question_answer(
157158
)
158159
else:
159160
output_message = None
160-
chat_id = row.id
161-
return QuestionAnswer(question=request_message, answer=output_message), chat_id
162-
163-
164-
async def parse_get_prompt_with_output(
165-
row: GetPromptWithOutputsRow,
166-
) -> Optional[PartialConversation]:
167-
"""
168-
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
169-
170-
The row contains the raw request and output strings from the pipeline.
171-
"""
172-
question_answer, chat_id = await _get_question_answer(row)
173-
if not question_answer or not chat_id:
174-
return None
175-
return PartialConversation(
176-
question_answer=question_answer,
177-
provider=row.provider,
178-
type=row.type,
179-
chat_id=chat_id,
180-
request_timestamp=row.timestamp,
181-
)
161+
return PartialQuestionAnswer(partial_questions=request_message, answer=output_message)
182162

183163

184164
def parse_question_answer(input_text: str) -> str:
@@ -195,50 +175,135 @@ def parse_question_answer(input_text: str) -> str:
195175
return input_text
196176

197177

178+
def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
179+
"""
180+
A PartialQuestion is an object that contains several user messages provided from a
181+
chat conversation. Example:
182+
- PartialQuestion(messages=["Hello"], timestamp=2022-01-01T00:00:00Z)
183+
- PartialQuestion(messages=["Hello", "How are you?"], timestamp=2022-01-01T00:00:01Z)
184+
In the above example both PartialQuestions are part of the same conversation and should be
185+
matched together.
186+
Group PartialQuestions objects such that:
187+
- If one PartialQuestion (pq) is a subset of another pq's messages, group them together.
188+
- If multiple subsets exist for the same superset, choose only the one
189+
closest in timestamp to the superset.
190+
- Leave any unpaired pq by itself.
191+
- Finally, sort the resulting groups by the earliest timestamp in each group.
192+
"""
193+
# 1) Sort by length of messages descending (largest/most-complete first),
194+
# then by timestamp ascending for stable processing.
195+
pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp))
196+
197+
used = set()
198+
groups = []
199+
200+
# 2) Iterate in order of "largest messages first"
201+
for sup in pq_list_sorted:
202+
if sup.message_id in used:
203+
continue # Already grouped
204+
205+
# Find all potential subsets of 'sup' that are not yet used
206+
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
207+
possible_subsets = []
208+
for sub in pq_list_sorted:
209+
if sub.message_id == sup.message_id:
210+
continue
211+
if sub.message_id in used:
212+
continue
213+
if (
214+
set(sub.messages).issubset(set(sup.messages))
215+
and sub.provider == sup.provider
216+
and set(sub.messages) != set(sup.messages)
217+
):
218+
possible_subsets.append(sub)
219+
220+
# 3) If there are no subsets, this sup stands alone
221+
if not possible_subsets:
222+
groups.append([sup])
223+
used.add(sup.message_id)
224+
else:
225+
# 4) Group subsets by messages to discard duplicates e.g.: 2 subsets with single 'hello'
226+
subs_group_by_messages = defaultdict(list)
227+
for q in possible_subsets:
228+
subs_group_by_messages[tuple(q.messages)].append(q)
229+
230+
new_group = [sup]
231+
used.add(sup.message_id)
232+
for subs_same_message in subs_group_by_messages.values():
233+
# If more than one pick the one subset closest in time to sup
234+
closest_subset = min(
235+
subs_same_message, key=lambda s: abs(s.timestamp - sup.timestamp)
236+
)
237+
new_group.append(closest_subset)
238+
used.add(closest_subset.message_id)
239+
groups.append(new_group)
240+
241+
# 5) Sort the groups by the earliest timestamp within each group
242+
groups.sort(key=lambda g: min(pq.timestamp for pq in g))
243+
return groups
244+
245+
246+
def _get_question_answer_from_partial(
247+
partial_question_answer: PartialQuestionAnswer,
248+
) -> QuestionAnswer:
249+
"""
250+
Get a QuestionAnswer object from a PartialQuestionAnswer object.
251+
"""
252+
# Get the last user message as the question
253+
question = ChatMessage(
254+
message=partial_question_answer.partial_questions.messages[-1],
255+
timestamp=partial_question_answer.partial_questions.timestamp,
256+
message_id=partial_question_answer.partial_questions.message_id,
257+
)
258+
259+
return QuestionAnswer(question=question, answer=partial_question_answer.answer)
260+
261+
198262
async def match_conversations(
199-
partial_conversations: List[Optional[PartialConversation]],
263+
partial_question_answers: List[Optional[PartialQuestionAnswer]],
200264
) -> List[Conversation]:
201265
"""
202266
Match partial conversations to form a complete conversation.
203267
"""
204-
convers = {}
205-
for partial_conversation in partial_conversations:
206-
if not partial_conversation:
207-
continue
208-
209-
# Group by chat_id
210-
if partial_conversation.chat_id not in convers:
211-
convers[partial_conversation.chat_id] = []
212-
convers[partial_conversation.chat_id].append(partial_conversation)
268+
valid_partial_qas = [
269+
partial_qas for partial_qas in partial_question_answers if partial_qas is not None
270+
]
271+
grouped_partial_questions = _group_partial_messages(
272+
[partial_qs_a.partial_questions for partial_qs_a in valid_partial_qas]
273+
)
213274

214-
# Sort by timestamp
215-
sorted_convers = {
216-
chat_id: sorted(conversations, key=lambda x: x.request_timestamp)
217-
for chat_id, conversations in convers.items()
218-
}
219275
# Create the conversation objects
220276
conversations = []
221-
for chat_id, sorted_convers in sorted_convers.items():
277+
for group in grouped_partial_questions:
222278
questions_answers = []
223-
first_partial_conversation = None
224-
for partial_conversation in sorted_convers:
279+
first_partial_qa = None
280+
for partial_question in sorted(group, key=lambda x: x.timestamp):
281+
# Partial questions don't contain the answer, so we need to find the corresponding
282+
selected_partial_qa = None
283+
for partial_qa in valid_partial_qas:
284+
if partial_question.message_id == partial_qa.partial_questions.message_id:
285+
selected_partial_qa = partial_qa
286+
break
287+
225288
# check if we have an answer, otherwise do not add it
226-
if partial_conversation.question_answer.answer is not None:
227-
first_partial_conversation = partial_conversation
228-
partial_conversation.question_answer.question.message = parse_question_answer(
229-
partial_conversation.question_answer.question.message
289+
if selected_partial_qa.answer is not None:
290+
# if we don't have a first question, set it
291+
first_partial_qa = first_partial_qa or selected_partial_qa
292+
question_answer = _get_question_answer_from_partial(selected_partial_qa)
293+
question_answer.question.message = parse_question_answer(
294+
question_answer.question.message
230295
)
231-
questions_answers.append(partial_conversation.question_answer)
296+
questions_answers.append(question_answer)
232297

233298
# only add conversation if we have some answers
234-
if len(questions_answers) > 0 and first_partial_conversation is not None:
299+
if len(questions_answers) > 0 and first_partial_qa is not None:
235300
conversations.append(
236301
Conversation(
237302
question_answers=questions_answers,
238-
provider=first_partial_conversation.provider,
239-
type=first_partial_conversation.type,
240-
chat_id=chat_id,
241-
conversation_timestamp=sorted_convers[0].request_timestamp,
303+
provider=first_partial_qa.partial_questions.provider,
304+
type=first_partial_qa.partial_questions.type,
305+
chat_id=first_partial_qa.partial_questions.message_id,
306+
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
242307
)
243308
)
244309

@@ -254,10 +319,10 @@ async def parse_messages_in_conversations(
254319

255320
# Parse the prompts and outputs in parallel
256321
async with asyncio.TaskGroup() as tg:
257-
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
258-
partial_conversations = [task.result() for task in tasks]
322+
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
323+
partial_question_answers = [task.result() for task in tasks]
259324

260-
conversations = await match_conversations(partial_conversations)
325+
conversations = await match_conversations(partial_question_answers)
261326
return conversations
262327

263328

@@ -269,15 +334,17 @@ async def parse_row_alert_conversation(
269334
270335
The row contains the raw request and output strings from the pipeline.
271336
"""
272-
question_answer, chat_id = await _get_question_answer(row)
273-
if not question_answer or not chat_id:
337+
partial_qa = await _get_question_answer(row)
338+
if not partial_qa:
274339
return None
275340

341+
question_answer = _get_question_answer_from_partial(partial_qa)
342+
276343
conversation = Conversation(
277344
question_answers=[question_answer],
278345
provider=row.provider,
279346
type=row.type,
280-
chat_id=chat_id or "chat-id-not-found",
347+
chat_id=row.id,
281348
conversation_timestamp=row.timestamp,
282349
)
283350
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None

src/codegate/dashboard/request_models.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,25 @@ class QuestionAnswer(BaseModel):
2525
answer: Optional[ChatMessage]
2626

2727

28-
class PartialConversation(BaseModel):
28+
class PartialQuestions(BaseModel):
2929
"""
30-
Represents a partial conversation obtained from a DB row.
30+
Represents all user messages obtained from a DB row.
3131
"""
3232

33-
question_answer: QuestionAnswer
33+
messages: List[str]
34+
timestamp: datetime.datetime
35+
message_id: str
3436
provider: Optional[str]
3537
type: str
36-
chat_id: str
37-
request_timestamp: datetime.datetime
38+
39+
40+
class PartialQuestionAnswer(BaseModel):
41+
"""
42+
Represents a partial conversation.
43+
"""
44+
45+
partial_questions: PartialQuestions
46+
answer: Optional[ChatMessage]
3847

3948

4049
class Conversation(BaseModel):

0 commit comments

Comments
 (0)