Skip to content

Commit c48969c

Browse files
authored
Merge pull request #546 from stacklok/issue-419
fix: Groups conversations based on the user's messages.
2 parents 05c2574 + c4d46c6 commit c48969c

File tree

3 files changed

+390
-148
lines changed

3 files changed

+390
-148
lines changed

src/codegate/dashboard/post_processing.py

+147-80
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import asyncio
22
import json
33
import re
4-
from typing import List, Optional, Tuple, Union
4+
from collections import defaultdict
5+
from typing import List, Optional, Union
56

67
import structlog
78

89
from codegate.dashboard.request_models import (
910
AlertConversation,
1011
ChatMessage,
1112
Conversation,
12-
PartialConversation,
13+
PartialQuestionAnswer,
14+
PartialQuestions,
1315
QuestionAnswer,
1416
)
1517
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
@@ -74,60 +76,57 @@ async def parse_request(request_str: str) -> Optional[str]:
7476
return None
7577

7678
# Only respond with the latest message
77-
return messages[-1]
79+
return messages
7880

7981

80-
async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
82+
async def parse_output(output_str: str) -> Optional[str]:
8183
"""
82-
Parse the output string from the pipeline and return the message and chat_id.
84+
Parse the output string from the pipeline and return the message.
8385
"""
8486
try:
8587
if output_str is None:
86-
return None, None
88+
return None
8789

8890
output = json.loads(output_str)
8991
except Exception as e:
9092
logger.warning(f"Error parsing output: {output_str}. {e}")
91-
return None, None
93+
return None
9294

9395
def _parse_single_output(single_output: dict) -> str:
94-
single_chat_id = single_output.get("id")
9596
single_output_message = ""
9697
for choice in single_output.get("choices", []):
9798
if not isinstance(choice, dict):
9899
continue
99100
content_dict = choice.get("delta", {}) or choice.get("message", {})
100101
single_output_message += content_dict.get("content", "")
101-
return single_output_message, single_chat_id
102+
return single_output_message
102103

103104
full_output_message = ""
104-
chat_id = None
105105
if isinstance(output, list):
106106
for output_chunk in output:
107-
output_message, output_chat_id = "", None
107+
output_message = ""
108108
if isinstance(output_chunk, dict):
109-
output_message, output_chat_id = _parse_single_output(output_chunk)
109+
output_message = _parse_single_output(output_chunk)
110110
elif isinstance(output_chunk, str):
111111
try:
112112
output_decoded = json.loads(output_chunk)
113-
output_message, output_chat_id = _parse_single_output(output_decoded)
113+
output_message = _parse_single_output(output_decoded)
114114
except Exception:
115115
logger.error(f"Error reading chunk: {output_chunk}")
116116
else:
117117
logger.warning(
118118
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
119119
)
120-
chat_id = chat_id or output_chat_id
121120
full_output_message += output_message
122121
elif isinstance(output, dict):
123-
full_output_message, chat_id = _parse_single_output(output)
122+
full_output_message = _parse_single_output(output)
124123

125-
return full_output_message, chat_id
124+
return full_output_message
126125

127126

128127
async def _get_question_answer(
129128
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
130-
) -> Tuple[Optional[QuestionAnswer], Optional[str]]:
129+
) -> Optional[PartialQuestionAnswer]:
131130
"""
132131
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
133132
@@ -137,17 +136,19 @@ async def _get_question_answer(
137136
request_task = tg.create_task(parse_request(row.request))
138137
output_task = tg.create_task(parse_output(row.output))
139138

140-
request_msg_str = request_task.result()
141-
output_msg_str, chat_id = output_task.result()
139+
request_user_msgs = request_task.result()
140+
output_msg_str = output_task.result()
142141

143-
# If we couldn't parse the request or output, return None
144-
if not request_msg_str:
145-
return None, None
142+
# If we couldn't parse the request, return None
143+
if not request_user_msgs:
144+
return None
146145

147-
request_message = ChatMessage(
148-
message=request_msg_str,
146+
request_message = PartialQuestions(
147+
messages=request_user_msgs,
149148
timestamp=row.timestamp,
150149
message_id=row.id,
150+
provider=row.provider,
151+
type=row.type,
151152
)
152153
if output_msg_str:
153154
output_message = ChatMessage(
@@ -157,28 +158,7 @@ async def _get_question_answer(
157158
)
158159
else:
159160
output_message = None
160-
chat_id = row.id
161-
return QuestionAnswer(question=request_message, answer=output_message), chat_id
162-
163-
164-
async def parse_get_prompt_with_output(
165-
row: GetPromptWithOutputsRow,
166-
) -> Optional[PartialConversation]:
167-
"""
168-
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
169-
170-
The row contains the raw request and output strings from the pipeline.
171-
"""
172-
question_answer, chat_id = await _get_question_answer(row)
173-
if not question_answer or not chat_id:
174-
return None
175-
return PartialConversation(
176-
question_answer=question_answer,
177-
provider=row.provider,
178-
type=row.type,
179-
chat_id=chat_id,
180-
request_timestamp=row.timestamp,
181-
)
161+
return PartialQuestionAnswer(partial_questions=request_message, answer=output_message)
182162

183163

184164
def parse_question_answer(input_text: str) -> str:
@@ -195,50 +175,135 @@ def parse_question_answer(input_text: str) -> str:
195175
return input_text
196176

197177

178+
def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
179+
"""
180+
A PartialQuestion is an object that contains several user messages provided from a
181+
chat conversation. Example:
182+
- PartialQuestion(messages=["Hello"], timestamp=2022-01-01T00:00:00Z)
183+
- PartialQuestion(messages=["Hello", "How are you?"], timestamp=2022-01-01T00:00:01Z)
184+
In the above example both PartialQuestions are part of the same conversation and should be
185+
matched together.
186+
Group PartialQuestions objects such that:
187+
- If one PartialQuestion (pq) is a subset of another pq's messages, group them together.
188+
- If multiple subsets exist for the same superset, choose only the one
189+
closest in timestamp to the superset.
190+
- Leave any unpaired pq by itself.
191+
- Finally, sort the resulting groups by the earliest timestamp in each group.
192+
"""
193+
# 1) Sort by length of messages descending (largest/most-complete first),
194+
# then by timestamp ascending for stable processing.
195+
pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp))
196+
197+
used = set()
198+
groups = []
199+
200+
# 2) Iterate in order of "largest messages first"
201+
for sup in pq_list_sorted:
202+
if sup.message_id in used:
203+
continue # Already grouped
204+
205+
# Find all potential subsets of 'sup' that are not yet used
206+
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
207+
possible_subsets = []
208+
for sub in pq_list_sorted:
209+
if sub.message_id == sup.message_id:
210+
continue
211+
if sub.message_id in used:
212+
continue
213+
if (
214+
set(sub.messages).issubset(set(sup.messages))
215+
and sub.provider == sup.provider
216+
and set(sub.messages) != set(sup.messages)
217+
):
218+
possible_subsets.append(sub)
219+
220+
# 3) If there are no subsets, this sup stands alone
221+
if not possible_subsets:
222+
groups.append([sup])
223+
used.add(sup.message_id)
224+
else:
225+
# 4) Group subsets by messages to discard duplicates e.g.: 2 subsets with single 'hello'
226+
subs_group_by_messages = defaultdict(list)
227+
for q in possible_subsets:
228+
subs_group_by_messages[tuple(q.messages)].append(q)
229+
230+
new_group = [sup]
231+
used.add(sup.message_id)
232+
for subs_same_message in subs_group_by_messages.values():
233+
# If more than one pick the one subset closest in time to sup
234+
closest_subset = min(
235+
subs_same_message, key=lambda s: abs(s.timestamp - sup.timestamp)
236+
)
237+
new_group.append(closest_subset)
238+
used.add(closest_subset.message_id)
239+
groups.append(new_group)
240+
241+
# 5) Sort the groups by the earliest timestamp within each group
242+
groups.sort(key=lambda g: min(pq.timestamp for pq in g))
243+
return groups
244+
245+
246+
def _get_question_answer_from_partial(
247+
partial_question_answer: PartialQuestionAnswer,
248+
) -> QuestionAnswer:
249+
"""
250+
Get a QuestionAnswer object from a PartialQuestionAnswer object.
251+
"""
252+
# Get the last user message as the question
253+
question = ChatMessage(
254+
message=partial_question_answer.partial_questions.messages[-1],
255+
timestamp=partial_question_answer.partial_questions.timestamp,
256+
message_id=partial_question_answer.partial_questions.message_id,
257+
)
258+
259+
return QuestionAnswer(question=question, answer=partial_question_answer.answer)
260+
261+
198262
async def match_conversations(
199-
partial_conversations: List[Optional[PartialConversation]],
263+
partial_question_answers: List[Optional[PartialQuestionAnswer]],
200264
) -> List[Conversation]:
201265
"""
202266
Match partial conversations to form a complete conversation.
203267
"""
204-
convers = {}
205-
for partial_conversation in partial_conversations:
206-
if not partial_conversation:
207-
continue
208-
209-
# Group by chat_id
210-
if partial_conversation.chat_id not in convers:
211-
convers[partial_conversation.chat_id] = []
212-
convers[partial_conversation.chat_id].append(partial_conversation)
268+
valid_partial_qas = [
269+
partial_qas for partial_qas in partial_question_answers if partial_qas is not None
270+
]
271+
grouped_partial_questions = _group_partial_messages(
272+
[partial_qs_a.partial_questions for partial_qs_a in valid_partial_qas]
273+
)
213274

214-
# Sort by timestamp
215-
sorted_convers = {
216-
chat_id: sorted(conversations, key=lambda x: x.request_timestamp)
217-
for chat_id, conversations in convers.items()
218-
}
219275
# Create the conversation objects
220276
conversations = []
221-
for chat_id, sorted_convers in sorted_convers.items():
277+
for group in grouped_partial_questions:
222278
questions_answers = []
223-
first_partial_conversation = None
224-
for partial_conversation in sorted_convers:
279+
first_partial_qa = None
280+
for partial_question in sorted(group, key=lambda x: x.timestamp):
281+
# Partial questions don't contain the answer, so we need to find the corresponding
282+
selected_partial_qa = None
283+
for partial_qa in valid_partial_qas:
284+
if partial_question.message_id == partial_qa.partial_questions.message_id:
285+
selected_partial_qa = partial_qa
286+
break
287+
225288
# check if we have an answer, otherwise do not add it
226-
if partial_conversation.question_answer.answer is not None:
227-
first_partial_conversation = partial_conversation
228-
partial_conversation.question_answer.question.message = parse_question_answer(
229-
partial_conversation.question_answer.question.message
289+
if selected_partial_qa.answer is not None:
290+
# if we don't have a first question, set it
291+
first_partial_qa = first_partial_qa or selected_partial_qa
292+
question_answer = _get_question_answer_from_partial(selected_partial_qa)
293+
question_answer.question.message = parse_question_answer(
294+
question_answer.question.message
230295
)
231-
questions_answers.append(partial_conversation.question_answer)
296+
questions_answers.append(question_answer)
232297

233298
# only add conversation if we have some answers
234-
if len(questions_answers) > 0 and first_partial_conversation is not None:
299+
if len(questions_answers) > 0 and first_partial_qa is not None:
235300
conversations.append(
236301
Conversation(
237302
question_answers=questions_answers,
238-
provider=first_partial_conversation.provider,
239-
type=first_partial_conversation.type,
240-
chat_id=chat_id,
241-
conversation_timestamp=sorted_convers[0].request_timestamp,
303+
provider=first_partial_qa.partial_questions.provider,
304+
type=first_partial_qa.partial_questions.type,
305+
chat_id=first_partial_qa.partial_questions.message_id,
306+
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
242307
)
243308
)
244309

@@ -254,10 +319,10 @@ async def parse_messages_in_conversations(
254319

255320
# Parse the prompts and outputs in parallel
256321
async with asyncio.TaskGroup() as tg:
257-
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
258-
partial_conversations = [task.result() for task in tasks]
322+
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
323+
partial_question_answers = [task.result() for task in tasks]
259324

260-
conversations = await match_conversations(partial_conversations)
325+
conversations = await match_conversations(partial_question_answers)
261326
return conversations
262327

263328

@@ -269,15 +334,17 @@ async def parse_row_alert_conversation(
269334
270335
The row contains the raw request and output strings from the pipeline.
271336
"""
272-
question_answer, chat_id = await _get_question_answer(row)
273-
if not question_answer or not chat_id:
337+
partial_qa = await _get_question_answer(row)
338+
if not partial_qa:
274339
return None
275340

341+
question_answer = _get_question_answer_from_partial(partial_qa)
342+
276343
conversation = Conversation(
277344
question_answers=[question_answer],
278345
provider=row.provider,
279346
type=row.type,
280-
chat_id=chat_id or "chat-id-not-found",
347+
chat_id=row.id,
281348
conversation_timestamp=row.timestamp,
282349
)
283350
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None

src/codegate/dashboard/request_models.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,25 @@ class QuestionAnswer(BaseModel):
2525
answer: Optional[ChatMessage]
2626

2727

28-
class PartialConversation(BaseModel):
28+
class PartialQuestions(BaseModel):
2929
"""
30-
Represents a partial conversation obtained from a DB row.
30+
Represents all user messages obtained from a DB row.
3131
"""
3232

33-
question_answer: QuestionAnswer
33+
messages: List[str]
34+
timestamp: datetime.datetime
35+
message_id: str
3436
provider: Optional[str]
3537
type: str
36-
chat_id: str
37-
request_timestamp: datetime.datetime
38+
39+
40+
class PartialQuestionAnswer(BaseModel):
41+
"""
42+
Represents a partial conversation.
43+
"""
44+
45+
partial_questions: PartialQuestions
46+
answer: Optional[ChatMessage]
3847

3948

4049
class Conversation(BaseModel):

0 commit comments

Comments
 (0)