feat: rewrite query with context consideration #170

Closed
wants to merge 15 commits into from
28 changes: 23 additions & 5 deletions examples/core_memories/tree_textual_memory.py
@@ -193,27 +193,45 @@ def embed_memory_item(memory: str) -> list[float]:

time.sleep(60)

init_time = time.time()
results = my_tree_textual_memory.search(
"Talk about the user's childhood story?",
top_k=10,
info={"query": "Talk about the user's childhood story?", "user_id": "111", "session": "2234"},
info={
"query": "Talk about the user's childhood story?",
"user_id": "111",
"session_id": "2234",
"chat_history": [{"role": "user", "content": "xxxxx"}],
},
)
for i, r in enumerate(results):
r = r.to_dict()
print(f"{i}'th similar result is: " + str(r["memory"]))
print(f"Successfully search {len(results)} memories")
print(f"Successfully search {len(results)} memories in {round(time.time() - init_time)}s")

# Try this when using 'fine' mode (note that you must pass the internet config; see examples/core_memories/textual_internet_memoy.py)
init_time = time.time()
results_fine_search = my_tree_textual_memory.search(
"Recent news in NewYork",
"Recent news in the first city you've mentioned.",
top_k=10,
mode="fine",
info={"query": "Recent news in NewYork", "user_id": "111", "session": "2234"},
info={
"query": "Recent news in NewYork",
"user_id": "111",
"session_id": "2234",
"chat_history": [
{"role": "user", "content": "I want to know three beautiful cities"},
{"role": "assistant", "content": "New York, London, and Shanghai"},
],
},
)

for i, r in enumerate(results_fine_search):
r = r.to_dict()
print(f"{i}'th similar result is: " + str(r["memory"]))
print(f"Successfully search {len(results_fine_search)} memories")
print(
    f"Successfully searched {len(results_fine_search)} memories in {round(time.time() - init_time)}s"
)

# find related nodes
related_nodes = my_tree_textual_memory.get_relevant_subgraph("Painting")
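The calls above pass the full context payload. Judging from the searcher changes further down in this diff, a minimal sketch of the info shape that search now consumes looks like this (key names come from the diff; the values here are hypothetical):

info = {
    "user_id": "111",  # owner of the memories being searched
    "session_id": "2234",  # current conversation/session identifier
    "chat_history": [  # optional prior turns, used for context-aware query rewriting
        {"role": "user", "content": "I want to know three beautiful cities"},
        {"role": "assistant", "content": "New York, London, and Shanghai"},
    ],
}
results = my_tree_textual_memory.search(
    "Recent news in the first city you've mentioned.", top_k=10, mode="fine", info=info
)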
3 changes: 2 additions & 1 deletion src/memos/graph_dbs/nebular.py
@@ -830,7 +830,8 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:

def _escape_value(value):
if isinstance(value, str):
return f'"{value}"'
escaped = value.replace('"', '\\"')
return f'"{escaped}"'
elif isinstance(value, list):
return "[" + ", ".join(_escape_value(v) for v in value) + "]"
else:
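For context, the one-line fix above keeps embedded double quotes from breaking the generated query literal. A standalone sketch of the resulting behavior, copied out of the diff for illustration (the else branch is truncated in the diff, so str(value) is an assumption here; note that backslashes in the input still pass through unescaped):

def _escape_value(value):
    # Strings are wrapped in double quotes, with embedded quotes escaped.
    if isinstance(value, str):
        escaped = value.replace('"', '\\"')
        return f'"{escaped}"'
    # Lists are rendered as bracketed, comma-separated escaped values.
    elif isinstance(value, list):
        return "[" + ", ".join(_escape_value(v) for v in value) + "]"
    else:
        return str(value)  # assumed fallback for non-string, non-list values

print(_escape_value('say "hi"'))  # "say \"hi\""
print(_escape_value(["a", 'b "c"', 3]))  # ["a", "b \"c\"", 3]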
51 changes: 45 additions & 6 deletions src/memos/mem_os/core.py
@@ -1,6 +1,5 @@
import json
import os
import uuid

from datetime import datetime
from pathlib import Path
@@ -24,6 +23,7 @@
from memos.memories.activation.item import ActivationMemoryItem
from memos.memories.parametric.item import ParametricMemoryItem
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
from memos.templates.mos_prompts import QUERY_REWRITING_PROMPT
from memos.types import ChatHistory, MessageList, MOSSearchResult


@@ -282,7 +282,15 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
)
self.mem_scheduler.submit_messages(messages=[message_item])

memories = mem_cube.text_mem.search(query, top_k=self.config.top_k)
memories = mem_cube.text_mem.search(
query,
top_k=self.config.top_k,
info={
"user_id": target_user_id,
"session_id": self.session_id,
"chat_history": chat_history,
},
)
memories_all.extend(memories)
logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
system_prompt = self._build_system_prompt(memories_all, base_prompt=base_prompt)
@@ -555,6 +563,9 @@ def search(
logger.info(
f"User {target_user_id} has access to {len(user_cube_ids)} cubes: {user_cube_ids}"
)

chat_history = self.chat_history_manager[target_user_id]

result: MOSSearchResult = {
"text_mem": [],
"act_mem": [],
@@ -573,7 +584,11 @@
top_k=top_k if top_k else self.config.top_k,
mode=mode,
manual_close_internet=not internet_search,
info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
info={
"user_id": target_user_id,
"session_id": self.session_id,
"chat_history": chat_history,
},
)
result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
logger.info(
@@ -639,7 +654,7 @@ def add(
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
info={"user_id": target_user_id, "session_id": self.session_id},
)

mem_ids = []
@@ -683,7 +698,7 @@ def add(
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
info={"user_id": target_user_id, "session_id": self.session_id},
)

mem_ids = []
@@ -717,7 +732,7 @@ def add(
doc_memories = self.mem_reader.get_memory(
documents,
type="doc",
info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
info={"user_id": target_user_id, "session_id": self.session_id},
)

mem_ids = []
@@ -971,3 +986,27 @@ def share_cube_with_user(self, cube_id: str, target_user_id: str) -> bool:
raise ValueError(f"Target user '{target_user_id}' does not exist or is inactive.")

return self.user_manager.add_user_to_cube(target_user_id, cube_id)

def get_query_rewrite(self, query: str, user_id: str | None = None):
    """
    Rewrite the user's query according to the conversational context.

    Args:
        query (str): The search query that needs rewriting.
        user_id (str, optional): The identifier of the user the query belongs to.
            If None, the default user is used.

    Returns:
        str: The query after rewriting, or the original query if no rewrite applies.
    """
    target_user_id = user_id if user_id is not None else self.user_id
    chat_history = self.chat_history_manager[target_user_id]

    dialogue = "————{}".format("\n————".join(chat_history.chat_history))
    user_prompt = QUERY_REWRITING_PROMPT.format(dialogue=dialogue, query=query)
    # generate() expects a list of message dicts, as at its other call sites.
    messages = [{"role": "user", "content": user_prompt}]
    rewritten_result = self.chat_llm.generate(messages=messages)
    rewritten_result = json.loads(rewritten_result)
    if rewritten_result.get("former_dialogue_related", False):
        rewritten_query = rewritten_result.get("rewritten_question", "")
        return rewritten_query if rewritten_query else query
    return query
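get_query_rewrite relies on the LLM returning a small JSON object with the two fields read above. A minimal sketch of that contract (the response payload here is hypothetical; QUERY_REWRITING_PROMPT itself lives in memos.templates.mos_prompts and is not shown in this diff):

# Suppose the chat history was:
#   user: I want to know three beautiful cities
#   assistant: New York, London, and Shanghai
rewritten_result = {"former_dialogue_related": True, "rewritten_question": "Recent news in New York"}
query = "Recent news in the first city you've mentioned."
if rewritten_result.get("former_dialogue_related", False):
    rewritten_query = rewritten_result.get("rewritten_question", "")
    query = rewritten_query if rewritten_query else query
print(query)  # Recent news in New York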
@@ -141,6 +141,8 @@ def retrieve_from_internet(
Returns:
List of TextualMemoryItem
"""
if not info:
info = {"user_id": "", "session_id": ""}
# Get search results
search_results = self.google_api.get_all_results(query, max_results=top_k)

@@ -10,4 +10,6 @@ class ParsedTaskGoal:
memories: list[str] = field(default_factory=list)
keys: list[str] = field(default_factory=list)
tags: list[str] = field(default_factory=list)
rephrased_query: str | None = None
internet_search: bool = False
goal_type: str | None = None # e.g., 'default', 'explanation', etc.
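For reference, a fine-mode parse of the example query from earlier in this diff might populate the two new fields like so (only the field names come from the dataclass; every value below is hypothetical):

goal = ParsedTaskGoal(  # assumes ParsedTaskGoal is imported from the module above
    memories=["cities the user asked about"],
    keys=["New York", "news"],
    tags=["current-events"],
    rephrased_query="Recent news in New York",  # context-resolved form of the query
    internet_search=True,  # tells the searcher whether to run the internet path
    goal_type="default",
)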
55 changes: 42 additions & 13 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -6,6 +6,7 @@
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.factory import Neo4jGraphDB
from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
from memos.log import get_logger
from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem

from .internet_retriever_factory import InternetRetrieverFactory
@@ -15,6 +16,9 @@
from .task_goal_parser import TaskGoalParser


logger = get_logger(__name__)


class Searcher:
def __init__(
self,
@@ -53,7 +57,12 @@ def search(
Returns:
list[TextualMemoryItem]: List of matching memories.
"""

if not info:
    logger.warning(
        "Please provide 'info' when calling tree.search so that "
        "the database can store the consumption history."
    )
    info = {"user_id": "", "session_id": ""}
# Step 1: Parse task structure into topic, concept, and fact levels
context = []
if mode == "fine":
@@ -67,7 +76,18 @@
context = list(set(context))

# Step 1a: Parse task structure into topic, concept, and fact levels
parsed_goal = self.task_goal_parser.parse(query, "\n".join(context))
parsed_goal = self.task_goal_parser.parse(
task_description=query,
context="\n".join(context),
conversation=info.get("chat_history", []),
mode=mode,
)

        query = parsed_goal.rephrased_query or query

if parsed_goal.memories:
query_embedding = self.embedder.embed(list({query, *parsed_goal.memories}))
@@ -136,7 +156,7 @@ def retrieve_from_internet():
"""
Retrieve information from the internet using Google Custom Search API.
"""
if not self.internet_retriever or mode == "fast":
if not self.internet_retriever or mode == "fast" or not parsed_goal.internet_search:
return []
if memory_type not in ["All"]:
return []
@@ -154,16 +174,25 @@
)
return ranked_memories

# Step 3: Parallel execution of all paths
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
future_working = executor.submit(retrieve_from_working_memory)
future_hybrid = executor.submit(retrieve_ranked_long_term_and_user)
future_internet = executor.submit(retrieve_from_internet)

working_results = future_working.result()
hybrid_results = future_hybrid.result()
internet_results = future_internet.result()
searched_res = working_results + hybrid_results + internet_results
# Step 3: Parallel execution of all paths (the internet path is enabled according to the parsed goal's internet_search flag)
if parsed_goal.internet_search:
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
future_working = executor.submit(retrieve_from_working_memory)
future_hybrid = executor.submit(retrieve_ranked_long_term_and_user)
future_internet = executor.submit(retrieve_from_internet)

working_results = future_working.result()
hybrid_results = future_hybrid.result()
internet_results = future_internet.result()
searched_res = working_results + hybrid_results + internet_results
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
future_working = executor.submit(retrieve_from_working_memory)
future_hybrid = executor.submit(retrieve_ranked_long_term_and_user)

working_results = future_working.result()
hybrid_results = future_hybrid.result()
searched_res = working_results + hybrid_results

# Deduplicate by item.memory, keep higher score
deduped_result = {}
@@ -1,4 +1,5 @@
import json
import logging
import traceback

from string import Template

@@ -14,54 +15,80 @@ class TaskGoalParser:
- mode == 'fine': use LLM to parse structured topic/keys/tags
"""

def __init__(self, llm=BaseLLM, mode: str = "fast"):
def __init__(self, llm: BaseLLM | None = None):
self.llm = llm
self.mode = mode

def parse(self, task_description: str, context: str = "") -> ParsedTaskGoal:
def parse(
self,
task_description: str,
context: str = "",
conversation: list[dict] | None = None,
mode: str = "fast",
) -> ParsedTaskGoal:
"""
Parse user input into structured semantic layers.
Returns:
ParsedTaskGoal: object containing topic/concept/fact levels and optional metadata
- mode == 'fast': return the query as-is, without LLM parsing
- mode == 'fine': use LLM to parse structured topic/keys/tags
"""
if self.mode == "fast":
if mode == "fast":
return self._parse_fast(task_description)
elif self.mode == "fine":
elif mode == "fine":
if not self.llm:
raise ValueError("LLM not provided for slow mode.")
return self._parse_fine(task_description, context)
return self._parse_fine(task_description, context, conversation)
else:
raise ValueError(f"Unknown mode: {self.mode}")
raise ValueError(f"Unknown mode: {mode}")

def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal:
"""
Fast mode: return the task description as-is, without LLM parsing.
"""
return ParsedTaskGoal(
memories=[task_description], keys=[task_description], tags=[], goal_type="default"
memories=[task_description],
keys=[task_description],
tags=[],
goal_type="default",
rephrased_query=task_description,
internet_search=False,
)

def _parse_fine(self, query: str, context: str = "") -> ParsedTaskGoal:
def _parse_fine(
self, query: str, context: str = "", conversation: list[dict] | None = None
) -> ParsedTaskGoal:
"""
Slow mode: LLM structured parse.
"""
prompt = Template(TASK_PARSE_PROMPT).substitute(task=query.strip(), context=context)
response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
return self._parse_response(response)
try:
if conversation:
conversation_prompt = "\n".join(
[f"{each['role']}: {each['content']}" for each in conversation]
)
else:
conversation_prompt = ""
prompt = Template(TASK_PARSE_PROMPT).substitute(
task=query.strip(), context=context, conversation=conversation_prompt
)
response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
return self._parse_response(response)
except Exception:
logging.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}")
return self._parse_fast(query)

def _parse_response(self, response: str) -> ParsedTaskGoal:
"""
Parse LLM JSON output safely.
"""
try:
response = response.replace("```", "").replace("json", "")
response_json = json.loads(response.strip())
response = response.replace("```", "").replace("json", "").strip()
response_json = eval(response)
return ParsedTaskGoal(
memories=response_json.get("memories", []),
keys=response_json.get("keys", []),
tags=response_json.get("tags", []),
rephrased_query=response_json.get("rephrased_instruction", None),
internet_search=response_json.get("internet_search", False),
goal_type=response_json.get("goal_type", "default"),
)
except Exception as e:
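As a quick check of the fence handling in _parse_response, a fenced LLM reply parses cleanly with the stripping shown above (the payload values are hypothetical):

import json

raw = '```json\n{"memories": [], "keys": ["news"], "tags": [], "rephrased_instruction": "Recent news in New York", "internet_search": true, "goal_type": "default"}\n```'
cleaned = raw.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
print(json.loads(cleaned)["rephrased_instruction"])  # Recent news in New York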