Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lua/CopilotChat/constants.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,13 @@ return {
SYSTEM = 'system',
TOOL = 'tool',
},

SYMBOLS = {
STICKY = '> ',
MODEL = '$',
TOOL = '@',
RESOURCE = '#',
URI = '##',
PROMPT = '/',
},
}
139 changes: 54 additions & 85 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ local client = require('CopilotChat.client')
local constants = require('CopilotChat.constants')
local notify = require('CopilotChat.notify')
local utils = require('CopilotChat.utils')
local prompts = require('CopilotChat.prompts')

local WORD = '([^%s:]+)'
local WORD_NO_INPUT = '([^%s]+)'
local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`'
local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)'
local BLOCK_OUTPUT_FORMAT = '```%s\n%s\n```'

---@class CopilotChat
Expand Down Expand Up @@ -315,10 +312,11 @@ function M.resolve_functions(prompt, config)
tools[tool.name] = tool
end

local refs = prompts.parse(prompt)
local found_tools = utils.to_table(config.tools)
local enabled_tools = {}
local resolved_resources = {}
local resolved_tools = {}
local matches = utils.to_table(config.tools)
local tool_calls = {}
for _, message in ipairs(M.chat.messages) do
if message.tool_calls then
Expand All @@ -328,54 +326,33 @@ function M.resolve_functions(prompt, config)
end
end

-- Check for @tool pattern to find enabled tools
prompt = prompt:gsub('@' .. WORD, function(match)
for name, tool in pairs(M.config.functions) do
if name == match or tool.group == match then
table.insert(matches, match)
return ''
-- Find enabled tools from @ references
for _, ref in ipairs(refs) do
if ref.type == 'function_reference' then
for name, tool in pairs(M.config.functions) do
if name == ref.value or tool.group == ref.value then
table.insert(found_tools, ref.value)
end
end
end
return '@' .. match
end)
for _, match in ipairs(matches) do
end

-- Convert tool names to tool objects
for _, match in ipairs(found_tools) do
for name, tool in pairs(M.config.functions) do
if name == match or tool.group == match then
table.insert(enabled_tools, tools[name])
end
end
end

local matches = utils.ordered_map()

-- Check for #word:`input` pattern
for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do
local pattern = string.format('#%s:`%s`', word, input)
matches:set(pattern, {
word = word,
input = input,
})
end

-- Check for #word:input pattern
for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do
local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input)
matches:set(pattern, {
word = word,
input = input,
})
end

-- Check for ##word:input pattern
for word in prompt:gmatch('##' .. WORD_NO_INPUT) do
local pattern = string.format('##%s', word)
matches:set(pattern, {
word = word,
})
end
prompt = prompts.replace(prompt, refs, function(ref)
if ref.type ~= 'function_call' then
return
end

-- Resolve each function reference
local function expand_function(name, input)
local name = ref.value
local input = ref.input
notify.publish(notify.STATUS, 'Running function: ' .. name)

local tool_id = nil
Expand Down Expand Up @@ -448,17 +425,7 @@ function M.resolve_functions(prompt, config)
end

return result
end

-- Resolve and process all tools
for _, pattern in ipairs(matches:keys()) do
if not utils.empty(pattern) then
local match = matches:get(pattern)
local out = expand_function(match.word, match.input) or pattern
out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub
prompt = prompt:gsub(vim.pesc(pattern), out, 1)
end
end
end)

return enabled_tools, resolved_resources, resolved_tools, prompt
end
Expand All @@ -479,45 +446,44 @@ function M.resolve_prompt(prompt, config)
local depth = 0
local MAX_DEPTH = 10

local function resolve(inner_config, inner_prompt)
if depth >= MAX_DEPTH then
local function resolve_prompt_template(inner_config, inner_prompt)
if depth >= MAX_DEPTH or not inner_prompt then
return inner_config, inner_prompt
end

depth = depth + 1

inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match)
local p = prompts_to_use[match]
if p then
local resolved_config, resolved_prompt = resolve(p, p.prompt or '')
inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config)
return resolved_prompt
for _, ref in ipairs(prompts.parse(inner_prompt)) do
if ref.type == 'prompt' then
local template = prompts_to_use[ref.value]
if template then
local resolved_config, resolved_prompt = resolve_prompt_template(template, template.prompt)
inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config)
if resolved_prompt then
inner_prompt = inner_prompt:sub(1, ref.start_pos - 1)
.. resolved_prompt
.. inner_prompt:sub(ref.end_pos + 1)
end
end
end

return '/' .. match
end)
end

depth = depth - 1
return inner_config, inner_prompt
end

local function resolve_system_prompt(system_prompt)
if type(system_prompt) == 'function' then
local ok, result = pcall(system_prompt)
if not ok then
log.warn('Failed to resolve system prompt function: ' .. result)
return nil
end
return result
end

return system_prompt
end

config = vim.tbl_deep_extend('force', M.config, config or {})
config, prompt = resolve(config, prompt or '')
config, prompt = resolve_prompt_template(config, prompt)
prompt = prompt or ''

if config.system_prompt then
config.system_prompt = resolve_system_prompt(config.system_prompt)
if type(config.system_prompt) == 'function' then
---@diagnostic disable-next-line: param-type-mismatch
local ok, result = pcall(config.system_prompt)
if ok then
config.system_prompt = result
end
end

if M.config.prompts[config.system_prompt] then
-- Name references are good for making system prompt auto sticky
Expand Down Expand Up @@ -547,13 +513,16 @@ function M.resolve_model(prompt, config)
return model.id
end, list_models())

local refs = prompts.parse(prompt)
local selected_model = config.model or ''
prompt = prompt:gsub('%$' .. WORD, function(match)
if vim.tbl_contains(models, match) then
selected_model = match
return ''

prompt = prompts.replace(prompt, refs, function(ref)
if ref.type == 'model' then
if vim.tbl_contains(models, ref.value) then
selected_model = ref.value
return ''
end
end
return '$' .. match
end)

return selected_model, prompt
Expand Down
152 changes: 152 additions & 0 deletions lua/CopilotChat/prompts.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
local M = {}

local WORD = '([^%s:]+)'
local WORD_NO_INPUT = '([^%s]+)'
local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`'
local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)'

---@class CopilotChat.prompts.Reference
---@field type 'model'|'function'|'function_call'|'resource'|'sticky'|'prompt'
---@field value string
---@field input? string
---@field start_pos integer
---@field end_pos integer

--- Parse all references from a prompt string, tracking positions.
---@param prompt string
---@return CopilotChat.prompts.Reference[] refs
function M.parse(prompt)
local refs = {}

-- $model
for s, value, e in prompt:gmatch('()%$' .. WORD .. '()') do
table.insert(refs, {
type = 'model',
value = value,
start_pos = s,
end_pos = e - 1,
})
end

-- @function
for s, value, e in prompt:gmatch('()@' .. WORD .. '()') do
table.insert(refs, {
type = 'function_reference',
value = value,
start_pos = s,
end_pos = e - 1,
})
end

-- #function_call
local function function_call_matches(str)
local matches = {}
-- #function_call:`input` (quoted)
for s, value, input, e in str:gmatch('()#' .. WORD_WITH_INPUT_QUOTED .. '()') do
table.insert(matches, { s = s, e = e - 1, value = value, input = input })
end
-- #function_call:input (unquoted)
for s, value, input, e in str:gmatch('()#' .. WORD_WITH_INPUT_UNQUOTED .. '()') do
table.insert(matches, { s = s, e = e - 1, value = value, input = input })
end
-- #function_call (no input)
for s, value, e in str:gmatch('()#' .. WORD_NO_INPUT .. '()') do
table.insert(matches, { s = s, e = e - 1, value = value, input = nil })
end
return matches
end
for _, m in ipairs(function_call_matches(prompt)) do
table.insert(refs, {
type = 'function_call',
value = m.value,
input = m.input or nil,
start_pos = m.s,
end_pos = m.e,
})
end

-- ##resource
for s, value, e in prompt:gmatch('()##' .. WORD_NO_INPUT .. '()') do
table.insert(refs, {
type = 'resource',
value = value,
start_pos = s,
end_pos = e - 1,
})
end

-- > sticky
local function sticky_matches(str)
local matches = {}
-- > sticky (newline)
for s, value, e in str:gmatch('()\n> ([^\n]+)()') do
table.insert(matches, { s = s + 1, e = e - 1, value = value })
end
-- > sticky (start of string)
for s, value, e in str:gmatch('()^> ([^\n]+)()') do
table.insert(matches, { s = s, e = e - 1, value = value })
end
return matches
end
for _, m in ipairs(sticky_matches(prompt)) do
table.insert(refs, {
type = 'sticky',
value = m.value,
start_pos = m.s,
end_pos = m.e,
})
end

-- /prompt
for s, value, e in prompt:gmatch('()/' .. WORD_NO_INPUT .. '()') do
table.insert(refs, {
type = 'prompt',
value = value,
start_pos = s,
end_pos = e - 1,
})
end

local keep = {}
for i, ref in ipairs(refs) do
local contained = false
for j, other in ipairs(refs) do
if i ~= j then
-- Strictly contained
if other.type ~= 'sticky' and ref.start_pos > other.start_pos and ref.end_pos < other.end_pos then
contained = true
break
end
-- Exact match, only keep the first occurrence
if ref.start_pos == other.start_pos and ref.end_pos == other.end_pos and j < i then
contained = true
break
end
end
end
if not contained then
table.insert(keep, ref)
end
end

return keep
end

--- Replace references in the prompt using positions (descending order).
---@param prompt string
---@param refs CopilotChat.prompts.Reference[]
---@param resolver fun(ref: CopilotChat.prompts.Reference): string?
function M.replace(prompt, refs, resolver)
table.sort(refs, function(a, b)
return a.start_pos > b.start_pos
end)
for _, ref in ipairs(refs) do
local output = resolver(ref)
if output then
prompt = prompt:sub(1, ref.start_pos - 1) .. output .. prompt:sub(ref.end_pos + 1)
end
end
return prompt
end

return M
Loading