
Commit 8079b99

feat: Add github workspace command

Supports resolving GitHub code index workspace data and searching in it.

TODO: Currently this API does not accept the ghu_ GitHub Copilot token, and I need to use the `gh` CLI instead, which creates hosts.yml.

See https://github.blog/engineering/the-technology-behind-githubs-new-code-search/

Signed-off-by: Tomas Slusny <[email protected]>

1 parent 07bcd20 commit 8079b99
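For orientation, a hedged usage sketch of the new context from Lua. The `#workspace` prompt syntax and the `ask()` entry point are assumed from the plugin's existing context handling; only the `workspace` context itself is added by this commit.

```lua
-- Hypothetical usage sketch (not part of this commit): ask a question with the
-- new #workspace context so the answer is grounded in the GitHub code-search
-- index of the current repository's origin remote.
require('CopilotChat').ask('#workspace How is provider authentication implemented?')
```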

File tree: 6 files changed (+152, -5 lines)


lua/CopilotChat/client.lua

Lines changed: 18 additions & 0 deletions
@@ -824,6 +824,24 @@ function Client:embed(inputs, model)
   return results
 end
 
+--- Search for the given query
+---@param query string: The query to search for
+---@param repository string: The repository to search in
+---@param model string: The model to use for search
+---@return table<CopilotChat.context.embed>
+function Client:search(query, repository, model)
+  local models = self:fetch_models()
+
+  local provider_name, search = resolve_provider_function('search', model, models, self.providers)
+  local headers = self:authenticate(provider_name)
+  local ok, response = pcall(search, query, repository, headers)
+  if not ok then
+    log.warn('Failed to search: ', response)
+    return {}
+  end
+  return response
+end
+
 --- Stop the running job
 ---@return boolean
 function Client:stop()
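`Client:search` dispatches through the same provider-function resolution that `Client:embed` uses: a chat provider points its `search` field at the name of a provider entry that implements `search(query, repository, headers)`. A hedged sketch of a hypothetical custom entry in the same style (names and return values are illustrative, not part of this commit):

```lua
-- Hypothetical sketch: a custom search implementation wired up the same way
-- copilot_search is. Client:search resolves it via
-- resolve_provider_function('search', ...) and calls search(query, repository, headers).
M.my_search = {
  get_headers = M.copilot.get_headers,

  get_token = function()
    return 'some-token', nil -- illustrative; copilot_search returns get_gh_apps_token(), nil
  end,

  search = function(query, repository, headers)
    -- must return a list of CopilotChat.context.embed-like tables
    return {
      { filename = 'README.md', filetype = 'markdown', score = 1.0, content = '...' },
    }
  end,
}
```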

lua/CopilotChat/config/contexts.lua

Lines changed: 7 additions & 1 deletion
@@ -5,7 +5,7 @@ local utils = require('CopilotChat.utils')
 ---@class CopilotChat.config.context
 ---@field description string?
 ---@field input fun(callback: fun(input: string?), source: CopilotChat.source)?
----@field resolve fun(input: string?, source: CopilotChat.source, prompt: string):table<CopilotChat.context.embed>
+---@field resolve fun(input: string?, source: CopilotChat.source, prompt: string, model: string):table<CopilotChat.context.embed>
 
 ---@type table<string, CopilotChat.config.context>
 return {
@@ -160,4 +160,10 @@ return {
       return context.quickfix()
     end,
   },
+  workspace = {
+    description = 'Includes all non-hidden files in the current workspace in chat context.',
+    resolve = function(_, _, prompt, model)
+      return context.workspace(prompt, model)
+    end,
+  },
 }
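The `resolve` signature now receives the selected model as a fourth argument. A hedged sketch of a user-defined context following the same signature, assuming the plugin's usual `setup({ contexts = ... })` extension point:

```lua
-- Hedged sketch: a custom context using the new model-aware resolve signature.
require('CopilotChat').setup({
  contexts = {
    myworkspace = {
      description = 'Example custom context delegating to the new workspace resolver.',
      resolve = function(_, _, prompt, model)
        return require('CopilotChat.context').workspace(prompt, model)
      end,
    },
  },
})
```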

lua/CopilotChat/config/mappings.lua

Lines changed: 2 additions & 1 deletion
@@ -409,7 +409,8 @@ return {
       async.run(function()
         local embeddings = {}
         if section and not section.answer then
-          embeddings = copilot.resolve_embeddings(section.content, chat.config)
+          local _, selected_model = pcall(copilot.resolve_model, section.content, chat.config)
+          embeddings = copilot.resolve_embeddings(section.content, selected_model, chat.config)
         end
 
         for _, embedding in ipairs(embeddings) do

lua/CopilotChat/config/providers.lua

Lines changed: 107 additions & 0 deletions
@@ -101,11 +101,41 @@ local function get_github_token()
   error('Failed to find GitHub token')
 end
 
+local cached_gh_apps_token = nil
+
+--- Get the github apps token (gho_ token)
+---@return string
+local function get_gh_apps_token()
+  if cached_gh_apps_token then
+    return cached_gh_apps_token
+  end
+
+  async.util.scheduler()
+
+  local config_path = utils.config_path()
+  if not config_path then
+    error('Failed to find config path for GitHub token')
+  end
+
+  local file_path = config_path .. '/gh/hosts.yml'
+  if vim.fn.filereadable(file_path) == 1 then
+    local content = table.concat(vim.fn.readfile(file_path), '\n')
+    local token = content:match('oauth_token:%s*([%w_]+)')
+    if token then
+      cached_gh_apps_token = token
+      return token
+    end
+  end
+
+  error('Failed to find GitHub token')
+end
+
 ---@type table<string, CopilotChat.Provider>
 local M = {}
 
 M.copilot = {
   embed = 'copilot_embeddings',
+  search = 'copilot_search',
 
   get_headers = function(token)
     return {
@@ -279,6 +309,7 @@ M.copilot = {
 
 M.github_models = {
   embed = 'copilot_embeddings',
+  search = 'copilot_search',
 
   get_headers = function(token)
     return {
@@ -360,4 +391,80 @@ M.copilot_embeddings = {
   end,
 }
 
+M.copilot_search = {
+  get_headers = M.copilot.get_headers,
+
+  get_token = function()
+    return get_gh_apps_token(), nil
+  end,
+
+  search = function(query, repository, headers)
+    utils.curl_post(
+      'https://api.github.com/repos/' .. repository .. '/copilot_internal/embeddings_index',
+      {
+        headers = headers,
+      }
+    )
+
+    local response, err = utils.curl_get(
+      'https://api.github.com/repos/' .. repository .. '/copilot_internal/embeddings_index',
+      {
+        headers = headers,
+      }
+    )
+
+    if err then
+      error(err)
+    end
+
+    if response.status ~= 200 then
+      error('Failed to check search: ' .. tostring(response.status))
+    end
+
+    local body = vim.json.decode(response.body)
+
+    if
+      body.can_index ~= 'ok'
+      or not body.bm25_search_ok
+      or not body.lexical_search_ok
+      or not body.semantic_code_search_ok
+      or not body.semantic_doc_search_ok
+      or not body.semantic_indexing_enabled
+    then
+      error('Failed to search: ' .. vim.inspect(body))
+    end
+
+    local body = vim.json.encode({
+      query = query,
+      scopingQuery = '(repo:' .. repository .. ')',
+      similarity = 0.766,
+      limit = 100,
+    })
+
+    local response, err = utils.curl_post('https://api.individual.githubcopilot.com/search/code', {
+      headers = headers,
+      body = utils.temp_file(body),
+    })
+
+    if err then
+      error(err)
+    end
+
+    if response.status ~= 200 then
+      error('Failed to search: ' .. tostring(response.body))
+    end
+
+    local out = {}
+    for _, result in ipairs(vim.json.decode(response.body)) do
+      table.insert(out, {
+        filename = result.path,
+        filetype = result.languageName:lower(),
+        score = result.score,
+        content = result.contents,
+      })
+    end
+    return out
+  end,
+}
+
 return M
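`get_gh_apps_token()` reads the token the `gh` CLI stores in hosts.yml. A hedged sketch of the pattern match against an illustrative hosts.yml layout (the host entry and token value below are made up):

```lua
-- Hedged sketch: the oauth_token pattern from get_gh_apps_token() applied to an
-- illustrative hosts.yml layout. The token value here is fabricated.
local example_hosts_yml = [[
github.com:
    oauth_token: gho_exampletoken1234567890
    user: someuser
    git_protocol: https
]]
print(example_hosts_yml:match('oauth_token:%s*([%w_]+)'))
-- -> gho_exampletoken1234567890
```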

lua/CopilotChat/context.lua

Lines changed: 14 additions & 0 deletions
@@ -639,6 +639,20 @@ function M.quickfix()
   return out
 end
 
+--- Get the content of the current workspace
+---@param prompt string
+---@param model string
+function M.workspace(prompt, model)
+  local git_remote =
+    vim.trim(utils.system({ 'git', 'config', '--get', 'remote.origin.url' }).stdout)
+  local repo_path = git_remote:match('github.com[:/](.+).git$')
+  if not repo_path then
+    error('Could not determine GitHub repository from git remote: ' .. git_remote)
+  end
+
+  return client:search(prompt, repo_path, model)
+end
+
 --- Filter embeddings based on the query
 ---@param prompt string
 ---@param model string
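The remote parsing expects the origin URL to end in `.git`; both SSH and HTTPS remotes then reduce to an `owner/repo` slug. A hedged sketch with illustrative URLs:

```lua
-- Hedged sketch of the remote-URL pattern used by M.workspace(); URLs are
-- illustrative. Both forms yield the 'owner/repo' slug passed to client:search.
local pattern = 'github.com[:/](.+).git$'
print(('[email protected]:owner/repo.git'):match(pattern))        -- -> owner/repo
print(('https://github.com/owner/repo.git'):match(pattern)) -- -> owner/repo
```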

lua/CopilotChat/init.lua

Lines changed: 4 additions & 3 deletions
@@ -220,9 +220,10 @@ end
 
 --- Resolve the embeddings from the prompt.
 ---@param prompt string
+---@param model string
 ---@param config CopilotChat.config.shared
 ---@return table<CopilotChat.context.embed>, string
-function M.resolve_embeddings(prompt, config)
+function M.resolve_embeddings(prompt, model, config)
   local contexts = {}
   local function parse_context(prompt_context)
     local split = vim.split(prompt_context, ':')
@@ -262,7 +263,7 @@ function M.resolve_embeddings(prompt, config)
   for _, context_data in ipairs(contexts) do
     local context_value = M.config.contexts[context_data.name]
     for _, embedding in
-      ipairs(context_value.resolve(context_data.input, state.source or {}, prompt))
+      ipairs(context_value.resolve(context_data.input, state.source or {}, prompt, model))
     do
       if embedding then
         embeddings:set(embedding.filename, embedding)
@@ -648,7 +649,7 @@ function M.ask(prompt, config)
   local ok, err = pcall(async.run, function()
     local selected_agent, prompt = M.resolve_agent(prompt, config)
     local selected_model, prompt = M.resolve_model(prompt, config)
-    local embeddings, prompt = M.resolve_embeddings(prompt, config)
+    local embeddings, prompt = M.resolve_embeddings(prompt, selected_model, config)
 
     local has_output = false
     local query_ok, filtered_embeddings =
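Because `resolve_embeddings` now takes the model between the prompt and the config, callers have to resolve (or default) a model first, as `M.ask` and the mapping change above do. A hedged sketch for an external caller; the empty config and the assumption that this runs inside the plugin's async context are mine, not the commit's:

```lua
-- Hedged sketch of the new call order for the changed API; assumed to run
-- inside the plugin's async context (as M.ask does via async.run).
local copilot = require('CopilotChat')
local config = {} -- assumption: empty table falls back to configured defaults
local selected_model, prompt = copilot.resolve_model('#workspace explain the auth flow', config)
local embeddings = copilot.resolve_embeddings(prompt, selected_model, config)
```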
