Skip to content

Commit a3932a1

Browse files
committed
refactor: move embedding cache to context module
Move embedding cache functionality from client to context module to improve separation of concerns and code organization. The cache is now maintained in the context module, which is more appropriate since it deals with context management. This change simplifies the client module by removing the embedding cache responsibility while maintaining the same caching functionality in a more logical location.

Signed-off-by: Tomas Slusny <[email protected]>
1 parent 36e6292 commit a3932a1

File tree

2 files changed

+44
-38
lines changed

2 files changed

+44
-38
lines changed

lua/CopilotChat/client.lua

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ end
244244
---@field history table<CopilotChat.Provider.input>
245245
---@field providers table<string, CopilotChat.Provider>
246246
---@field provider_cache table<string, table>
247-
---@field embedding_cache table<string, CopilotChat.context.embed>
248247
---@field models table<string, CopilotChat.Client.model>?
249248
---@field agents table<string, CopilotChat.Client.agent>?
250249
---@field current_job string?
@@ -253,7 +252,6 @@ local Client = class(function(self)
253252
self.history = {}
254253
self.providers = {}
255254
self.provider_cache = {}
256-
self.embedding_cache = {}
257255
self.models = nil
258256
self.agents = nil
259257
self.current_job = nil
@@ -749,25 +747,10 @@ function Client:embed(inputs, model)
749747
notify.publish(notify.STATUS, 'Generating embeddings for ' .. #inputs .. ' inputs')
750748

751749
-- Initialize essentials
752-
local to_process = {}
750+
local to_process = inputs
753751
local results = {}
754752
local initial_chunk_size = 10
755753

756-
-- Process each input, using cache when possible
757-
for _, input in ipairs(inputs) do
758-
input.filename = input.filename or 'unknown'
759-
input.filetype = input.filetype or 'text'
760-
761-
if input.content then
762-
local cache_key = input.filename .. utils.quick_hash(input.content)
763-
if self.embedding_cache[cache_key] then
764-
table.insert(results, self.embedding_cache[cache_key])
765-
else
766-
table.insert(to_process, input)
767-
end
768-
end
769-
end
770-
771754
-- Process inputs in batches with adaptive chunk size
772755
while #to_process > 0 do
773756
local chunk_size = initial_chunk_size -- Reset chunk size for each new batch
@@ -814,9 +797,6 @@ function Client:embed(inputs, model)
814797
for _, embedding in ipairs(data) do
815798
local result = vim.tbl_extend('force', batch[embedding.index + 1], embedding)
816799
table.insert(results, result)
817-
818-
local cache_key = result.filename .. utils.quick_hash(result.content)
819-
self.embedding_cache[cache_key] = result
820800
end
821801
end
822802
end
@@ -845,7 +825,6 @@ end
845825
function Client:reset()
846826
local stopped = self:stop()
847827
self.history = {}
848-
self.embedding_cache = {}
849828
return stopped
850829
end
851830

lua/CopilotChat/context.lua

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ local notify = require('CopilotChat.notify')
2323
local utils = require('CopilotChat.utils')
2424
local file_cache = {}
2525
local url_cache = {}
26+
local embedding_cache = {}
2627

2728
local M = {}
2829

@@ -661,15 +662,21 @@ function M.filter_embeddings(prompt, model, headless, embeddings)
661662
notify.publish(notify.STATUS, 'Ranking embeddings')
662663

663664
-- Build query from history and prompt
664-
local query = ''
665+
local query = prompt
665666
if not headless then
666-
for _, message in ipairs(client.history) do
667-
if message.role == 'user' then
668-
query = query .. '\n' .. message.content
669-
end
670-
end
667+
query = table.concat(
668+
vim
669+
.iter(client.history)
670+
:filter(function(m)
671+
return m.role == 'user'
672+
end)
673+
:map(function(m)
674+
return vim.trim(m.content)
675+
end)
676+
:totable(),
677+
'\n'
678+
) .. '\n' .. prompt
671679
end
672-
query = query .. '\n' .. prompt
673680

674681
-- Rank embeddings by symbols
675682
embeddings = data_ranked_by_symbols(query, embeddings, MIN_SYMBOL_SIMILARITY)
@@ -678,26 +685,46 @@ function M.filter_embeddings(prompt, model, headless, embeddings)
678685
log.debug(string.format('%s: %s - %s', i, item.score, item.filename))
679686
end
680687

681-
-- Embed the query
682-
table.insert(embeddings, {
688+
-- Prepare embeddings for processing
689+
local to_process = {}
690+
local results = {}
691+
for _, input in ipairs(embeddings) do
692+
input.filename = input.filename or 'unknown'
693+
input.filetype = input.filetype or 'text'
694+
if input.content then
695+
local cache_key = input.filename .. utils.quick_hash(input.content)
696+
if embedding_cache[cache_key] then
697+
table.insert(results, embedding_cache[cache_key])
698+
else
699+
table.insert(to_process, input)
700+
end
701+
end
702+
end
703+
table.insert(to_process, {
683704
content = query,
684705
filename = 'query',
685706
filetype = 'raw',
686707
})
687708

688-
-- Get embeddings from all items
689-
embeddings = client:embed(embeddings, model)
709+
-- Embed the data and process the results
710+
for _, input in ipairs(client:embed(to_process, model)) do
711+
if input.filetype ~= 'raw' then
712+
local cache_key = input.filename .. utils.quick_hash(input.content)
713+
embedding_cache[cache_key] = input
714+
end
715+
table.insert(results, input)
716+
end
690717

691718
-- Rate embeddings by relatedness to the query
692-
local embedded_query = table.remove(embeddings, #embeddings)
719+
local embedded_query = table.remove(results, #results)
693720
log.debug('Embedded query:', embedded_query.content)
694-
embeddings = data_ranked_by_relatedness(embedded_query, embeddings, MIN_SEMANTIC_SIMILARITY)
695-
log.debug('Ranked embeddings:', #embeddings)
696-
for i, item in ipairs(embeddings) do
721+
results = data_ranked_by_relatedness(embedded_query, results, MIN_SEMANTIC_SIMILARITY)
722+
log.debug('Ranked embeddings:', #results)
723+
for i, item in ipairs(results) do
697724
log.debug(string.format('%s: %s - %s', i, item.score, item.filename))
698725
end
699726

700-
return embeddings
727+
return results
701728
end
702729

703730
return M

0 commit comments

Comments (0)