Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# NOTE(review): the two names below look like environment-variable keys and
# the third like the fallback for the TOP_K one; their readers are outside
# this hunk, so confirm before relying on that.
_ENV_CONTEXT_RECALL = "MEMOS_DREAM_CONTEXT_RECALL"
_ENV_CONTEXT_RECALL_TOP_K = "MEMOS_DREAM_CONTEXT_RECALL_TOP_K"
_DEFAULT_CONTEXT_RECALL_TOP_K = 2
# Slice size used when computing embeddings for memories that have none:
# the embedder is called once per slice of at most this many documents.
_MISSING_EMBEDDING_BATCH_SIZE = 10


def _env_enabled(name: str, default: str = "off") -> bool:
Expand Down Expand Up @@ -590,10 +591,13 @@ def _extract_embeddings(self, memories: list[dict[str, Any]]) -> list[list[float
missing_documents.append(mem.get("memory", ""))

if missing_indices:
computed = self.searcher.embedder.embed(missing_documents)
for idx, embedding in zip(missing_indices, computed, strict=False):
embeddings[idx] = embedding
memories[idx]["metadata"]["embedding"] = embedding
for start in range(0, len(missing_documents), _MISSING_EMBEDDING_BATCH_SIZE):
batch_documents = missing_documents[start : start + _MISSING_EMBEDDING_BATCH_SIZE]
batch_indices = missing_indices[start : start + _MISSING_EMBEDDING_BATCH_SIZE]
computed = self.searcher.embedder.embed(batch_documents)
for idx, embedding in zip(batch_indices, computed, strict=False):
embeddings[idx] = embedding
memories[idx]["metadata"]["embedding"] = embedding

return embeddings

Expand Down
41 changes: 41 additions & 0 deletions tests/api/test_search_handler_embedding_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from memos.api.handlers.base_handler import HandlerDependencies
from memos.api.handlers.search_handler import SearchHandler


class BatchLimitedEmbedder:
    """Fake embedder that records every call and rejects oversized batches.

    Each text is mapped to the two-element vector ``[float(len(text)), 0.0]``,
    so the returned vector identifies the length of the text it came from.
    """

    def __init__(self, *, limit: int) -> None:
        # Largest number of texts a single ``embed`` call may carry.
        self.limit = limit
        # One entry per ``embed`` call, in call order; each entry is a copy
        # of the texts passed to that call.
        self.calls: list[list[str]] = []

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one vector per text, failing if the batch exceeds ``limit``.

        Raises:
            AssertionError: If ``len(texts)`` is greater than ``self.limit``.
        """
        # Record before validating so a rejected batch is still observable
        # through ``calls``.
        self.calls.append(list(texts))
        if len(texts) > self.limit:
            raise AssertionError(f"batch too large: {len(texts)}")
        return [[float(len(text)), 0.0] for text in texts]


def _handler(embedder: BatchLimitedEmbedder) -> SearchHandler:
    """Build a ``SearchHandler`` whose searcher exposes only ``embedder``.

    Every other dependency is a bare ``object()`` placeholder.
    """
    # Throwaway class carrying ``embedder`` as a class attribute.
    fake_searcher_cls = type("FakeSearcher", (), {"embedder": embedder})
    dependencies = HandlerDependencies(
        naive_mem_cube=object(),
        mem_scheduler=object(),
        searcher=fake_searcher_cls(),
        deepsearch_agent=object(),
    )
    return SearchHandler(dependencies)


def test_extract_embeddings_batches_missing_documents():
    """Missing embeddings are computed in ordered batches of at most ten."""
    embedder = BatchLimitedEmbedder(limit=10)
    handler = _handler(embedder)
    documents = [f"memory {idx}" for idx in range(25)]
    memories = [{"memory": text, "metadata": {}} for text in documents]

    embeddings = handler._extract_embeddings(memories)

    assert [len(call) for call in embedder.calls] == [10, 10, 5]
    # The batches together must cover every document exactly once, in order.
    assert [text for call in embedder.calls for text in call] == documents

    # The fake embedder encodes each text's length, so a vector written to
    # the wrong index (e.g. a batch offset error) fails these comparisons.
    expected = [[float(len(text)), 0.0] for text in documents]
    assert len(embeddings) == 25
    assert embeddings == expected
    assert [mem["metadata"]["embedding"] for mem in memories] == expected