Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 199 additions & 33 deletions src/commands/checkpoint_agent/presets/agent_v1.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::{AgentPreset, ParsedHookEvent, PostFileEdit, PreFileEdit, PresetContext};
use super::{
AgentPreset, ParsedHookEvent, PostBashCall, PostFileEdit, PreBashCall, PreFileEdit,
PresetContext,
};
use crate::authorship::working_log::AgentId;
use crate::error::GitAiError;
use serde::Deserialize;
Expand All @@ -25,6 +28,69 @@ enum AgentV1Payload {
model: String,
conversation_id: String,
},
PreShellCommand {
repo_working_dir: String,
agent_name: String,
model: String,
conversation_id: String,
tool_use_id: Option<String>,
#[serde(default)]
command: Option<String>,
},
PostShellCommand {
repo_working_dir: String,
agent_name: String,
model: String,
conversation_id: String,
tool_use_id: Option<String>,
#[serde(default)]
command: Option<String>,
},
}

fn resolve_paths(paths: Option<Vec<String>>, repo_working_dir: &str) -> Vec<PathBuf> {
paths
.unwrap_or_default()
.into_iter()
.map(|p| super::parse::resolve_absolute(&p, repo_working_dir))
.collect()
}

fn resolve_dirty_files(
dirty_files: Option<HashMap<String, String>>,
repo_working_dir: &str,
) -> Option<HashMap<PathBuf, String>> {
dirty_files.map(|df| {
df.into_iter()
.map(|(k, v)| (super::parse::resolve_absolute(&k, repo_working_dir), v))
.collect()
})
}

fn agent_context(
repo_working_dir: &str,
agent_name: String,
model: String,
conversation_id: String,
trace_id: &str,
command: Option<String>,
) -> PresetContext {
let mut metadata = HashMap::new();
if let Some(command) = command {
metadata.insert("command".to_string(), command);
}

PresetContext {
agent_id: AgentId {
tool: agent_name,
id: conversation_id.clone(),
model,
},
external_session_id: conversation_id,
trace_id: trace_id.to_string(),
cwd: PathBuf::from(repo_working_dir),
metadata,
}
}

impl AgentPreset for AgentV1Preset {
Expand All @@ -43,16 +109,8 @@ impl AgentPreset for AgentV1Preset {
dirty_files,
} => {
let cwd = PathBuf::from(&repo_working_dir);
let file_paths = will_edit_filepaths
.unwrap_or_default()
.into_iter()
.map(|p| super::parse::resolve_absolute(&p, &repo_working_dir))
.collect();
let dirty = dirty_files.map(|df| {
df.into_iter()
.map(|(k, v)| (super::parse::resolve_absolute(&k, &repo_working_dir), v))
.collect()
});
let file_paths = resolve_paths(will_edit_filepaths, &repo_working_dir);
let dirty = resolve_dirty_files(dirty_files, &repo_working_dir);
ParsedHookEvent::PreFileEdit(PreFileEdit {
context: PresetContext {
agent_id: AgentId {
Expand All @@ -78,35 +136,60 @@ impl AgentPreset for AgentV1Preset {
model,
conversation_id,
} => {
let cwd = PathBuf::from(&repo_working_dir);
let file_paths = edited_filepaths
.unwrap_or_default()
.into_iter()
.map(|p| super::parse::resolve_absolute(&p, &repo_working_dir))
.collect();
let dirty = dirty_files.map(|df| {
df.into_iter()
.map(|(k, v)| (super::parse::resolve_absolute(&k, &repo_working_dir), v))
.collect()
});
let file_paths = resolve_paths(edited_filepaths, &repo_working_dir);
let dirty = resolve_dirty_files(dirty_files, &repo_working_dir);
ParsedHookEvent::PostFileEdit(PostFileEdit {
context: PresetContext {
agent_id: AgentId {
tool: agent_name,
id: conversation_id.clone(),
model,
},
external_session_id: conversation_id,
trace_id: trace_id.to_string(),
cwd,
metadata: HashMap::new(),
},
context: agent_context(
&repo_working_dir,
agent_name,
model,
conversation_id,
trace_id,
None,
),
file_paths,
dirty_files: dirty,
stream_source: None,
tool_use_id: None,
})
}
AgentV1Payload::PreShellCommand {
repo_working_dir,
agent_name,
model,
conversation_id,
tool_use_id,
command,
} => ParsedHookEvent::PreBashCall(PreBashCall {
context: agent_context(
&repo_working_dir,
agent_name,
model,
conversation_id,
trace_id,
command,
),
tool_use_id: tool_use_id.unwrap_or_else(|| "shell".to_string()),
}),
AgentV1Payload::PostShellCommand {
repo_working_dir,
agent_name,
model,
conversation_id,
tool_use_id,
command,
} => ParsedHookEvent::PostBashCall(PostBashCall {
context: agent_context(
&repo_working_dir,
agent_name,
model,
conversation_id,
trace_id,
command,
),
tool_use_id: tool_use_id.unwrap_or_else(|| "shell".to_string()),
stream_source: None,
}),
};

Ok(vec![event])
Expand Down Expand Up @@ -199,6 +282,89 @@ mod tests {
assert!(result.is_err());
}

#[test]
fn test_agent_v1_pre_shell_command_type() {
let input = json!({
"type": "pre_shell_command",
"repo_working_dir": "/home/user/project",
"agent_name": "my-agent",
"model": "gpt-4",
"conversation_id": "conv-123",
"tool_use_id": "shell-1",
"command": "printf 'generated\\n' > output.txt"
})
.to_string();
let events = AgentV1Preset.parse(&input, "t_test").unwrap();
assert_eq!(events.len(), 1);
match &events[0] {
ParsedHookEvent::PreBashCall(e) => {
assert_eq!(e.context.agent_id.tool, "my-agent");
assert_eq!(e.context.agent_id.id, "conv-123");
assert_eq!(e.context.agent_id.model, "gpt-4");
assert_eq!(e.context.external_session_id, "conv-123");
assert_eq!(e.context.cwd, PathBuf::from("/home/user/project"));
assert_eq!(e.tool_use_id, "shell-1");
assert_eq!(
e.context.metadata.get("command").map(String::as_str),
Some("printf 'generated\\n' > output.txt")
);
}
_ => panic!("Expected PreBashCall"),
}
}

#[test]
fn test_agent_v1_post_shell_command_type() {
let input = json!({
"type": "post_shell_command",
"repo_working_dir": "/home/user/project",
"agent_name": "my-agent",
"model": "gpt-4",
"conversation_id": "conv-123",
"tool_use_id": "shell-1",
"command": "printf 'generated\\n' > output.txt"
})
.to_string();
let events = AgentV1Preset.parse(&input, "t_test").unwrap();
assert_eq!(events.len(), 1);
match &events[0] {
ParsedHookEvent::PostBashCall(e) => {
assert_eq!(e.context.agent_id.tool, "my-agent");
assert_eq!(e.context.agent_id.id, "conv-123");
assert_eq!(e.context.agent_id.model, "gpt-4");
assert_eq!(e.context.external_session_id, "conv-123");
assert_eq!(e.context.cwd, PathBuf::from("/home/user/project"));
assert_eq!(e.tool_use_id, "shell-1");
assert_eq!(
e.context.metadata.get("command").map(String::as_str),
Some("printf 'generated\\n' > output.txt")
);
assert!(e.stream_source.is_none());
}
_ => panic!("Expected PostBashCall"),
}
}

#[test]
fn test_agent_v1_shell_command_defaults_tool_use_id() {
let input = json!({
"type": "pre_shell_command",
"repo_working_dir": "/home/user/project",
"agent_name": "my-agent",
"model": "gpt-4",
"conversation_id": "conv-123"
})
.to_string();
let events = AgentV1Preset.parse(&input, "t_test").unwrap();
match &events[0] {
ParsedHookEvent::PreBashCall(e) => {
assert_eq!(e.tool_use_id, "shell");
assert!(e.context.metadata.is_empty());
}
_ => panic!("Expected PreBashCall"),
}
}

#[test]
fn test_agent_v1_unknown_type() {
let input = json!({
Expand Down
47 changes: 47 additions & 0 deletions tests/integration/agent_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,50 @@ fn test_agent_v1_relative_dirty_files_e2e_attribution() {
"AI added line".ai(),
]);
}

#[test]
fn test_agent_v1_shell_command_e2e_attribution() {
let repo = TestRepo::new();
let file_path = repo.path().join("script-output.txt");

fs::write(&file_path, "base line\n").unwrap();
repo.stage_all_and_commit("Initial commit").unwrap();
let mut file = repo.filename("script-output.txt");
file.assert_committed_lines(crate::lines!["base line".unattributed_human(),]);

let repo_dir = repo.canonical_path().to_string_lossy().to_string();

let pre_payload = json!({
"type": "pre_shell_command",
"repo_working_dir": repo_dir,
"agent_name": "agent-v1-test",
"model": "test-model",
"conversation_id": "shell-session-123",
"tool_use_id": "shell-tool-1",
"command": "printf 'created by shell\\n' >> script-output.txt"
})
.to_string();
repo.git_ai(&["checkpoint", "agent-v1", "--hook-input", &pre_payload])
.unwrap();

fs::write(&file_path, "base line\ncreated by shell\n").unwrap();

let post_payload = json!({
"type": "post_shell_command",
"repo_working_dir": repo_dir,
"agent_name": "agent-v1-test",
"model": "test-model",
"conversation_id": "shell-session-123",
"tool_use_id": "shell-tool-1",
"command": "printf 'created by shell\\n' >> script-output.txt"
})
.to_string();
repo.git_ai(&["checkpoint", "agent-v1", "--hook-input", &post_payload])
.unwrap();

repo.stage_all_and_commit("Agent v1 shell edit").unwrap();
file.assert_committed_lines(crate::lines![
"base line".unattributed_human(),
"created by shell".ai(),
]);
}
1 change: 1 addition & 0 deletions tests/integration/repos/test_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const AI_AUTHOR_NAMES: &[&str] = &[
"cloud-agent",
"codex-cloud",
"git-ai-cloud-agent",
"agent-v1",
];

#[derive(Debug, Clone, PartialEq)]
Expand Down
Loading