Skip to content

Commit

Permalink
feat: Support extending the system prompt (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
baxen authored Feb 12, 2025
1 parent a5e2419 commit 6220ef0
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 1 deletion.
16 changes: 16 additions & 0 deletions crates/goose-cli/src/cli_prompt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/// Returns a system prompt extension that explains CLI-specific functionality
pub fn get_cli_prompt() -> String {
String::from(
"You are being accessed through a command-line interface. The following slash commands are available
- you can let the user know about them if they need help:
- /exit or /quit - Exit the session
- /t - Toggle between Light/Dark/Ansi themes
- /? or /help - Display help message
Additional keyboard shortcuts:
- Ctrl+C - Interrupt the current interaction (resets to before the interrupted request)
- Ctrl+J - Add a newline
- Up/Down arrows - Navigate command history"
)
}
5 changes: 5 additions & 0 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ pub async fn build_session(

let prompt = Box::new(RustylinePrompt::new());

// Add CLI-specific system prompt extension
agent
.extend_system_prompt(crate::cli_prompt::get_cli_prompt())
.await;

display_session_info(resume, &provider_name, &model, &session_file);
Session::new(agent, prompt, session_file)
}
Expand Down
1 change: 1 addition & 0 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
app_name: "goose".to_string(),
});

mod cli_prompt;
mod commands;
mod log_usage;
mod logging;
Expand Down
35 changes: 35 additions & 0 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ struct VersionsResponse {
default_version: String,
}

#[derive(Deserialize)]
struct ExtendPromptRequest {
extension: String,
}

#[derive(Serialize)]
struct ExtendPromptResponse {
success: bool,
}

#[derive(Deserialize)]
struct CreateAgentRequest {
version: Option<String>,
Expand Down Expand Up @@ -61,6 +71,30 @@ async fn get_versions() -> Json<VersionsResponse> {
})
}

async fn extend_prompt(
State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<ExtendPromptRequest>,
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

let mut agent = state.agent.lock().await;
if let Some(ref mut agent) = *agent {
agent.extend_system_prompt(payload.extension).await;
Ok(Json(ExtendPromptResponse { success: true }))
} else {
Err(StatusCode::NOT_FOUND)
}
}

async fn create_agent(
State(state): State<AppState>,
headers: HeaderMap,
Expand Down Expand Up @@ -132,6 +166,7 @@ pub fn routes(state: AppState) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent/prompt", post(extend_prompt))
.route("/agent", post(create_agent))
.with_state(state)
}
3 changes: 3 additions & 0 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ pub trait Agent: Send + Sync {

/// Get the total usage of the agent
async fn usage(&self) -> Vec<ProviderUsage>;

/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);
}
19 changes: 18 additions & 1 deletion crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct Capabilities {
resource_capable_extensions: HashSet<String>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
system_prompt_extensions: Vec<String>,
}

/// A flattened representation of a resource used by the agent to prepare inference
Expand Down Expand Up @@ -88,6 +89,7 @@ impl Capabilities {
resource_capable_extensions: HashSet::new(),
provider,
provider_usage: Mutex::new(Vec::new()),
system_prompt_extensions: Vec::new(),
}
}

Expand Down Expand Up @@ -164,6 +166,11 @@ impl Capabilities {
Ok(())
}

/// Add a system prompt extension
pub fn add_system_prompt_extension(&mut self, extension: String) {
self.system_prompt_extensions.push(extension);
}

/// Get a reference to the provider
pub fn provider(&self) -> &dyn Provider {
&*self.provider
Expand Down Expand Up @@ -303,7 +310,17 @@ impl Capabilities {
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
context.insert("current_date_time", Value::String(current_date_time));

load_prompt_file("system.md", &context).expect("Prompt should render")
let base_prompt = load_prompt_file("system.md", &context).expect("Prompt should render");

if self.system_prompt_extensions.is_empty() {
base_prompt
} else {
format!(
"{}\n\n# Additional Instructions:\n\n{}",
base_prompt,
self.system_prompt_extensions.join("\n\n")
)
}
}

/// Find and return a reference to the appropriate client for a tool call
Expand Down
5 changes: 5 additions & 0 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ impl Agent for ReferenceAgent {
let capabilities = self.capabilities.lock().await;
capabilities.get_usage().await
}

async fn extend_system_prompt(&mut self, extension: String) {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_system_prompt_extension(extension);
}
}

register_agent!("reference", ReferenceAgent);
5 changes: 5 additions & 0 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ impl Agent for TruncateAgent {
let capabilities = self.capabilities.lock().await;
capabilities.get_usage().await
}

async fn extend_system_prompt(&mut self, extension: String) {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_system_prompt_extension(extension);
}
}

register_agent!("truncate", TruncateAgent);

0 comments on commit 6220ef0

Please sign in to comment.