Skip to content

feat: Add github workspace command #804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions lua/CopilotChat/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,24 @@ function Client:embed(inputs, model)
return results
end

--- Search for the given query
---@param query string: The query to search for
---@param repository string: The repository to search in
---@param model string: The model to use for search
---@return table<CopilotChat.context.embed>
function Client:search(query, repository, model)
  local models = self:fetch_models()
  local provider_name, search_fn = resolve_provider_function('search', model, models, self.providers)
  local headers = self:authenticate(provider_name)

  -- Search failures are non-fatal: log a warning and fall back to no results.
  local ok, result = pcall(search_fn, query, repository, headers)
  if ok then
    return result
  end

  log.warn('Failed to search: ', result)
  return {}
end

--- Stop the running job
---@return boolean
function Client:stop()
Expand Down
8 changes: 7 additions & 1 deletion lua/CopilotChat/config/contexts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ local utils = require('CopilotChat.utils')
---@class CopilotChat.config.context
---@field description string?
---@field input fun(callback: fun(input: string?), source: CopilotChat.source)?
---@field resolve fun(input: string?, source: CopilotChat.source, prompt: string):table<CopilotChat.context.embed>
---@field resolve fun(input: string?, source: CopilotChat.source, prompt: string, model: string):table<CopilotChat.context.embed>

---@type table<string, CopilotChat.config.context>
return {
Expand Down Expand Up @@ -173,4 +173,10 @@ return {
return context.quickfix()
end,
},
-- NOTE(review): despite the description, this resolves via context.workspace,
-- which derives the GitHub repository from the git remote and queries the
-- remote Copilot search index — confirm the description matches intent.
workspace = {
description = 'Includes all non-hidden files in the current workspace in chat context.',
resolve = function(_, _, prompt, model)
return context.workspace(prompt, model)
end,
},
}
3 changes: 2 additions & 1 deletion lua/CopilotChat/config/mappings.lua
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ return {
async.run(function()
local embeddings = {}
if section and not section.answer then
embeddings = copilot.resolve_embeddings(section.content, chat.config)
local _, selected_model = pcall(copilot.resolve_model, section.content, chat.config)
embeddings = copilot.resolve_embeddings(section.content, selected_model, chat.config)
end

for _, embedding in ipairs(embeddings) do
Expand Down
108 changes: 108 additions & 0 deletions lua/CopilotChat/config/providers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ local utils = require('CopilotChat.utils')
---@field get_agents nil|fun(headers:table):table<CopilotChat.Provider.agent>
---@field get_models nil|fun(headers:table):table<CopilotChat.Provider.model>
---@field embed nil|string|fun(inputs:table<string>, headers:table):table<CopilotChat.Provider.embed>
---@field search nil|string|fun(query:string, repository:string, headers:table):table<CopilotChat.Provider.output>
---@field prepare_input nil|fun(inputs:table<CopilotChat.Provider.input>, opts:CopilotChat.Provider.options):table
---@field prepare_output nil|fun(output:table, opts:CopilotChat.Provider.options):CopilotChat.Provider.output
---@field get_url nil|fun(opts:CopilotChat.Provider.options):string
Expand Down Expand Up @@ -100,11 +101,41 @@ local function get_github_token()
error('Failed to find GitHub token')
end

-- Session-lifetime cache for the gh CLI oauth token.
local cached_gh_apps_token = nil

--- Get the github apps token (gho_ token)
--- Reads the oauth token from the GitHub CLI config file (gh/hosts.yml) and
--- memoizes it for the rest of the session.
---@return string
local function get_gh_apps_token()
  if cached_gh_apps_token then
    return cached_gh_apps_token
  end

  -- vim.fn.* below must run on the main loop.
  async.util.scheduler()

  local config_path = utils.config_path()
  if not config_path then
    error('Failed to find config path for GitHub token')
  end

  local file_path = config_path .. '/gh/hosts.yml'
  if vim.fn.filereadable(file_path) == 1 then
    local content = table.concat(vim.fn.readfile(file_path), '\n')
    -- NOTE(review): grabs the first oauth_token in hosts.yml regardless of
    -- host entry — assumes github.com is the only/first configured host.
    local token = content:match('oauth_token:%s*([%w_]+)')
    if token then
      cached_gh_apps_token = token
      return token
    end
  end

  -- Distinct message from get_github_token's failure so logs identify which
  -- token lookup actually failed.
  error('Failed to find GitHub apps token in gh config (' .. file_path .. ')')
end

---@type table<string, CopilotChat.Provider>
local M = {}

M.copilot = {
embed = 'copilot_embeddings',
search = 'copilot_search',

get_headers = function()
local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', {
Expand Down Expand Up @@ -271,6 +302,7 @@ M.copilot = {

M.github_models = {
embed = 'copilot_embeddings',
search = 'copilot_search',

get_headers = function()
return {
Expand Down Expand Up @@ -350,4 +382,80 @@ M.copilot_embeddings = {
end,
}

M.copilot_search = {
  get_headers = M.copilot.get_headers,

  get_token = function()
    return get_gh_apps_token(), nil
  end,

  search = function(query, repository, headers)
    local index_url = 'https://api.github.com/repos/' .. repository .. '/copilot_internal/embeddings_index'

    -- Ask GitHub to (re)build the repository's embeddings index.
    -- The response is intentionally ignored; readiness is checked below.
    utils.curl_post(index_url, {
      headers = headers,
    })

    -- Query the index status.
    local status_response, status_err = utils.curl_get(index_url, {
      headers = headers,
    })

    if status_err then
      error(status_err)
    end

    if status_response.status ~= 200 then
      error('Failed to check search: ' .. tostring(status_response.status))
    end

    -- Every capability flag must be set before searching is allowed.
    local index_state = vim.json.decode(status_response.body)
    local index_ready = index_state.can_index == 'ok'
      and index_state.bm25_search_ok
      and index_state.lexical_search_ok
      and index_state.semantic_code_search_ok
      and index_state.semantic_doc_search_ok
      and index_state.semantic_indexing_enabled

    if not index_ready then
      error('Failed to search: ' .. vim.inspect(index_state))
    end

    local request_body = vim.json.encode({
      query = query,
      scopingQuery = '(repo:' .. repository .. ')',
      similarity = 0.766,
      limit = 100,
    })

    local search_response, search_err = utils.curl_post('https://api.individual.githubcopilot.com/search/code', {
      headers = headers,
      body = utils.temp_file(request_body),
    })

    if search_err then
      error(search_err)
    end

    if search_response.status ~= 200 then
      error('Failed to search: ' .. tostring(search_response.body))
    end

    -- Map API results into the embed shape the rest of the plugin consumes.
    local results = {}
    for _, item in ipairs(vim.json.decode(search_response.body)) do
      results[#results + 1] = {
        filename = item.path,
        filetype = item.languageName:lower(),
        score = item.score,
        content = item.contents,
      }
    end
    return results
  end,
}

return M
14 changes: 14 additions & 0 deletions lua/CopilotChat/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,20 @@ function M.quickfix()
return out
end

--- Get the content of the current workspace
--- Resolves the GitHub "owner/repo" slug from the origin remote and delegates
--- to the remote Copilot search for the given prompt.
---@param prompt string
---@param model string
---@return table<CopilotChat.context.embed>
function M.workspace(prompt, model)
  local git_remote =
    vim.trim(utils.system({ 'git', 'config', '--get', 'remote.origin.url' }).stdout)

  -- Accept both SSH (git@github.com:owner/repo[.git]) and HTTPS
  -- (https://github.com/owner/repo[.git]) remotes. The previous pattern
  -- ('github.com[:/](.+).git$') used an unescaped '.' and required a '.git'
  -- suffix, so suffix-less HTTPS remotes failed to resolve.
  local repo_path = git_remote:match('github%.com[:/](.+)$')
  if repo_path then
    repo_path = repo_path:gsub('%.git$', '')
  end
  if not repo_path or repo_path == '' then
    error('Could not determine GitHub repository from git remote: ' .. git_remote)
  end

  return client:search(prompt, repo_path, model)
end

--- Filter embeddings based on the query
---@param prompt string
---@param model string
Expand Down
7 changes: 4 additions & 3 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,10 @@ end

--- Resolve the embeddings from the prompt.
---@param prompt string
---@param model string
---@param config CopilotChat.config.shared
---@return table<CopilotChat.context.embed>, string
function M.resolve_embeddings(prompt, config)
function M.resolve_embeddings(prompt, model, config)
local contexts = {}
local function parse_context(prompt_context)
local split = vim.split(prompt_context, ':')
Expand Down Expand Up @@ -289,7 +290,7 @@ function M.resolve_embeddings(prompt, config)
for _, context_data in ipairs(contexts) do
local context_value = M.config.contexts[context_data.name]
for _, embedding in
ipairs(context_value.resolve(context_data.input, state.source or {}, prompt))
ipairs(context_value.resolve(context_data.input, state.source or {}, prompt, model))
do
if embedding then
embeddings:set(embedding.filename, embedding)
Expand Down Expand Up @@ -672,7 +673,7 @@ function M.ask(prompt, config)
local ok, err = pcall(async.run, function()
local selected_agent, prompt = M.resolve_agent(prompt, config)
local selected_model, prompt = M.resolve_model(prompt, config)
local embeddings, prompt = M.resolve_embeddings(prompt, config)
local embeddings, prompt = M.resolve_embeddings(prompt, selected_model, config)

local has_output = false
local query_ok, filtered_embeddings =
Expand Down
Loading