Restore codebase deleted by bad auto-commit e4227cf

Commit e4227cf (a story creation auto-commit) erroneously deleted 175
files from master's tree, likely due to a race condition between
concurrent git operations. This commit re-adds all files from the
working directory.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
dave
2026-03-22 19:07:07 +00:00
parent 89f776b978
commit f610ef6046
174 changed files with 84280 additions and 0 deletions
+426
View File
@@ -0,0 +1,426 @@
use std::path::Path;
use std::process::Command;
use std::time::Duration;
use wait_timeout::ChildExt;
/// Maximum time any single test command is allowed to run before being killed.
const TEST_TIMEOUT: Duration = Duration::from_secs(600); // 10 minutes
/// Detect whether the base branch in a worktree is `master` or `main`.
/// Falls back to `"master"` if neither is found.
pub(crate) fn detect_worktree_base_branch(wt_path: &Path) -> String {
for branch in &["master", "main"] {
let ok = Command::new("git")
.args(["rev-parse", "--verify", branch])
.current_dir(wt_path)
.output()
.map(|o| o.status.success())
.unwrap_or(false);
if ok {
return branch.to_string();
}
}
"master".to_string()
}
/// Return `true` if the git worktree at `wt_path` has commits on its current
/// branch that are not present on the base branch (`master` or `main`).
///
/// Used during server startup reconciliation to detect stories whose agent work
/// was committed while the server was offline.
pub(crate) fn worktree_has_committed_work(wt_path: &Path) -> bool {
let base_branch = detect_worktree_base_branch(wt_path);
let output = Command::new("git")
.args(["log", &format!("{base_branch}..HEAD"), "--oneline"])
.current_dir(wt_path)
.output();
match output {
Ok(out) if out.status.success() => {
!String::from_utf8_lossy(&out.stdout).trim().is_empty()
}
_ => false,
}
}
/// Check whether the given directory has any uncommitted git changes.
/// Returns `Err` with a descriptive message if there are any.
pub(crate) fn check_uncommitted_changes(path: &Path) -> Result<(), String> {
let output = Command::new("git")
.args(["status", "--porcelain"])
.current_dir(path)
.output()
.map_err(|e| format!("Failed to run git status: {e}"))?;
let stdout = String::from_utf8_lossy(&output.stdout);
if !stdout.trim().is_empty() {
return Err(format!(
"Worktree has uncommitted changes. Please commit all work before \
the agent exits:\n{stdout}"
));
}
Ok(())
}
/// Run the project's test suite.
///
/// Uses `script/test` if present, treating it as the canonical single test entry point.
/// Falls back to `cargo nextest run` / `cargo test` when `script/test` is absent.
/// Returns `(tests_passed, output)`.
pub(crate) fn run_project_tests(path: &Path) -> Result<(bool, String), String> {
let script_test = path.join("script").join("test");
if script_test.exists() {
let mut output = String::from("=== script/test ===\n");
let (success, out) = run_command_with_timeout(&script_test, &[], path)?;
output.push_str(&out);
output.push('\n');
return Ok((success, output));
}
// Fallback: cargo nextest run / cargo test
let mut output = String::from("=== tests ===\n");
let (success, test_out) = match run_command_with_timeout("cargo", &["nextest", "run"], path) {
Ok(result) => result,
Err(_) => {
// nextest not available — fall back to cargo test
run_command_with_timeout("cargo", &["test"], path)
.map_err(|e| format!("Failed to run cargo test: {e}"))?
}
};
output.push_str(&test_out);
output.push('\n');
Ok((success, output))
}
/// Run a command with a timeout. Returns `(success, combined_output)`.
/// Kills the child process if it exceeds `TEST_TIMEOUT`.
///
/// Stdout and stderr are drained in background threads to avoid a pipe-buffer
/// deadlock: if the child fills the 64 KB OS pipe buffer while the parent
/// blocks on `waitpid`, neither side can make progress.
fn run_command_with_timeout(
program: impl AsRef<std::ffi::OsStr>,
args: &[&str],
dir: &Path,
) -> Result<(bool, String), String> {
let mut child = Command::new(program)
.args(args)
.current_dir(dir)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| format!("Failed to spawn command: {e}"))?;
// Drain stdout/stderr in background threads so the pipe buffers never fill.
let stdout_handle = child.stdout.take().map(|r| {
std::thread::spawn(move || {
let mut s = String::new();
let mut r = r;
std::io::Read::read_to_string(&mut r, &mut s).ok();
s
})
});
let stderr_handle = child.stderr.take().map(|r| {
std::thread::spawn(move || {
let mut s = String::new();
let mut r = r;
std::io::Read::read_to_string(&mut r, &mut s).ok();
s
})
});
match child.wait_timeout(TEST_TIMEOUT) {
Ok(Some(status)) => {
let stdout = stdout_handle
.and_then(|h| h.join().ok())
.unwrap_or_default();
let stderr = stderr_handle
.and_then(|h| h.join().ok())
.unwrap_or_default();
Ok((status.success(), format!("{stdout}{stderr}")))
}
Ok(None) => {
// Timed out — kill the child.
let _ = child.kill();
let _ = child.wait();
Err(format!(
"Command timed out after {} seconds",
TEST_TIMEOUT.as_secs()
))
}
Err(e) => Err(format!("Failed to wait for command: {e}")),
}
}
/// Run `cargo clippy` and the project test suite (via `script/test` if present,
/// otherwise `cargo nextest run` / `cargo test`) in the given directory.
/// Returns `(gates_passed, combined_output)`.
pub(crate) fn run_acceptance_gates(path: &Path) -> Result<(bool, String), String> {
let mut all_output = String::new();
let mut all_passed = true;
// ── cargo clippy ──────────────────────────────────────────────
let clippy = Command::new("cargo")
.args(["clippy", "--all-targets", "--all-features"])
.current_dir(path)
.output()
.map_err(|e| format!("Failed to run cargo clippy: {e}"))?;
all_output.push_str("=== cargo clippy ===\n");
let clippy_stdout = String::from_utf8_lossy(&clippy.stdout);
let clippy_stderr = String::from_utf8_lossy(&clippy.stderr);
if !clippy_stdout.is_empty() {
all_output.push_str(&clippy_stdout);
}
if !clippy_stderr.is_empty() {
all_output.push_str(&clippy_stderr);
}
all_output.push('\n');
if !clippy.status.success() {
all_passed = false;
}
// ── tests (script/test if available, else cargo nextest/test) ─
let (test_success, test_out) = run_project_tests(path)?;
all_output.push_str(&test_out);
if !test_success {
all_passed = false;
}
Ok((all_passed, all_output))
}
/// Run `script/test_coverage` in the given directory if the script exists.
///
/// Used as a QA gate before advancing a story from `3_qa/` to `4_merge/`.
/// Returns `(passed, output)`. If the script does not exist, returns `(true, …)`.
pub(crate) fn run_coverage_gate(path: &Path) -> Result<(bool, String), String> {
let script = path.join("script").join("test_coverage");
if !script.exists() {
return Ok((
true,
"script/test_coverage not found; coverage gate skipped.\n".to_string(),
));
}
let mut output = String::from("=== script/test_coverage ===\n");
let result = Command::new(&script)
.current_dir(path)
.output()
.map_err(|e| format!("Failed to run script/test_coverage: {e}"))?;
let combined = format!(
"{}{}",
String::from_utf8_lossy(&result.stdout),
String::from_utf8_lossy(&result.stderr)
);
output.push_str(&combined);
output.push('\n');
Ok((result.status.success(), output))
}
#[cfg(test)]
mod tests {
use super::*;
fn init_git_repo(repo: &std::path::Path) {
Command::new("git")
.args(["init"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["config", "user.email", "test@test.com"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["config", "user.name", "Test"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["commit", "--allow-empty", "-m", "init"])
.current_dir(repo)
.output()
.unwrap();
}
// ── run_project_tests tests ───────────────────────────────────
#[cfg(unix)]
#[test]
fn run_project_tests_uses_script_test_when_present_and_passes() {
use std::fs;
use std::os::unix::fs::PermissionsExt;
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path();
let script_dir = path.join("script");
fs::create_dir_all(&script_dir).unwrap();
let script_test = script_dir.join("test");
fs::write(&script_test, "#!/usr/bin/env bash\necho 'all tests passed'\nexit 0\n").unwrap();
let mut perms = fs::metadata(&script_test).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(&script_test, perms).unwrap();
let (passed, output) = run_project_tests(path).unwrap();
assert!(passed, "script/test exiting 0 should pass");
assert!(output.contains("script/test"), "output should mention script/test");
}
#[cfg(unix)]
#[test]
fn run_project_tests_reports_failure_when_script_test_exits_nonzero() {
use std::fs;
use std::os::unix::fs::PermissionsExt;
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path();
let script_dir = path.join("script");
fs::create_dir_all(&script_dir).unwrap();
let script_test = script_dir.join("test");
fs::write(&script_test, "#!/usr/bin/env bash\nexit 1\n").unwrap();
let mut perms = fs::metadata(&script_test).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(&script_test, perms).unwrap();
let (passed, output) = run_project_tests(path).unwrap();
assert!(!passed, "script/test exiting 1 should fail");
assert!(output.contains("script/test"), "output should mention script/test");
}
// ── run_coverage_gate tests ───────────────────────────────────────────────
#[cfg(unix)]
#[test]
fn coverage_gate_passes_when_script_absent() {
use tempfile::tempdir;
let tmp = tempdir().unwrap();
let (passed, output) = run_coverage_gate(tmp.path()).unwrap();
assert!(passed, "coverage gate should pass when script is absent");
assert!(
output.contains("not found"),
"output should mention script not found"
);
}
#[cfg(unix)]
#[test]
fn coverage_gate_passes_when_script_exits_zero() {
use std::fs;
use std::os::unix::fs::PermissionsExt;
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path();
let script_dir = path.join("script");
fs::create_dir_all(&script_dir).unwrap();
let script = script_dir.join("test_coverage");
fs::write(
&script,
"#!/usr/bin/env bash\necho 'Rust line coverage: 85%'\necho 'PASS: Coverage 85% meets threshold 0%'\nexit 0\n",
)
.unwrap();
let mut perms = fs::metadata(&script).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(&script, perms).unwrap();
let (passed, output) = run_coverage_gate(path).unwrap();
assert!(passed, "coverage gate should pass when script exits 0");
assert!(
output.contains("script/test_coverage"),
"output should mention script/test_coverage"
);
}
#[cfg(unix)]
#[test]
fn coverage_gate_fails_when_script_exits_nonzero() {
use std::fs;
use std::os::unix::fs::PermissionsExt;
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path();
let script_dir = path.join("script");
fs::create_dir_all(&script_dir).unwrap();
let script = script_dir.join("test_coverage");
fs::write(
&script,
"#!/usr/bin/env bash\necho 'FAIL: Coverage 40% is below threshold 80%'\nexit 1\n",
)
.unwrap();
let mut perms = fs::metadata(&script).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(&script, perms).unwrap();
let (passed, output) = run_coverage_gate(path).unwrap();
assert!(!passed, "coverage gate should fail when script exits 1");
assert!(
output.contains("script/test_coverage"),
"output should mention script/test_coverage"
);
}
// ── worktree_has_committed_work tests ─────────────────────────────────────
#[test]
fn worktree_has_committed_work_false_on_fresh_repo() {
let tmp = tempfile::tempdir().unwrap();
let repo = tmp.path();
// init_git_repo creates the initial commit on the default branch.
// HEAD IS the base branch — no commits ahead.
init_git_repo(repo);
assert!(!worktree_has_committed_work(repo));
}
#[test]
fn worktree_has_committed_work_true_after_commit_on_feature_branch() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let project_root = tmp.path().join("project");
fs::create_dir_all(&project_root).unwrap();
init_git_repo(&project_root);
// Create a git worktree on a feature branch.
let wt_path = tmp.path().join("wt");
Command::new("git")
.args([
"worktree",
"add",
&wt_path.to_string_lossy(),
"-b",
"feature/story-99_test",
])
.current_dir(&project_root)
.output()
.unwrap();
// No commits on the feature branch yet — same as base branch.
assert!(!worktree_has_committed_work(&wt_path));
// Add a commit to the feature branch in the worktree.
fs::write(wt_path.join("work.txt"), "done").unwrap();
Command::new("git")
.args(["add", "."])
.current_dir(&wt_path)
.output()
.unwrap();
Command::new("git")
.args([
"-c",
"user.email=test@test.com",
"-c",
"user.name=Test",
"commit",
"-m",
"coder: implement story",
])
.current_dir(&wt_path)
.output()
.unwrap();
// Now the feature branch is ahead of the base branch.
assert!(worktree_has_committed_work(&wt_path));
}
}
+829
View File
@@ -0,0 +1,829 @@
use std::path::{Path, PathBuf};
use std::process::Command;
use crate::io::story_metadata::{clear_front_matter_field, write_rejection_notes};
use crate::slog;
pub(super) fn item_type_from_id(item_id: &str) -> &'static str {
// New format: {digits}_{type}_{slug}
let after_num = item_id.trim_start_matches(|c: char| c.is_ascii_digit());
if after_num.starts_with("_bug_") {
"bug"
} else if after_num.starts_with("_spike_") {
"spike"
} else {
"story"
}
}
/// Return the source directory path for a work item (always work/1_backlog/).
fn item_source_dir(project_root: &Path, _item_id: &str) -> PathBuf {
project_root.join(".storkit").join("work").join("1_backlog")
}
/// Return the done directory path for a work item (always work/5_done/).
fn item_archive_dir(project_root: &Path, _item_id: &str) -> PathBuf {
project_root.join(".storkit").join("work").join("5_done")
}
/// Move a work item (story, bug, or spike) from `work/1_backlog/` to `work/2_current/`.
///
/// Idempotent: if the item is already in `2_current/`, returns Ok without committing.
/// If the item is not found in `1_backlog/`, logs a warning and returns Ok.
pub fn move_story_to_current(project_root: &Path, story_id: &str) -> Result<(), String> {
let sk = project_root.join(".storkit").join("work");
let current_dir = sk.join("2_current");
let current_path = current_dir.join(format!("{story_id}.md"));
if current_path.exists() {
// Already in 2_current/ — idempotent, nothing to do.
return Ok(());
}
let source_dir = item_source_dir(project_root, story_id);
let source_path = source_dir.join(format!("{story_id}.md"));
if !source_path.exists() {
slog!(
"[lifecycle] Work item '{story_id}' not found in {}; skipping move to 2_current/",
source_dir.display()
);
return Ok(());
}
std::fs::create_dir_all(&current_dir)
.map_err(|e| format!("Failed to create work/2_current/ directory: {e}"))?;
std::fs::rename(&source_path, &current_path)
.map_err(|e| format!("Failed to move '{story_id}' to 2_current/: {e}"))?;
slog!(
"[lifecycle] Moved '{story_id}' from {} to work/2_current/",
source_dir.display()
);
Ok(())
}
/// Check whether a feature branch `feature/story-{story_id}` exists and has
/// commits that are not yet on master. Returns `true` when there is unmerged
/// work, `false` when there is no branch or all its commits are already
/// reachable from master.
pub fn feature_branch_has_unmerged_changes(project_root: &Path, story_id: &str) -> bool {
let branch = format!("feature/story-{story_id}");
// Check if the branch exists.
let branch_check = Command::new("git")
.args(["rev-parse", "--verify", &branch])
.current_dir(project_root)
.output();
match branch_check {
Ok(out) if out.status.success() => {}
_ => return false, // No feature branch → nothing to merge.
}
// Check if the branch has commits not reachable from master.
let log = Command::new("git")
.args(["log", &format!("master..{branch}"), "--oneline"])
.current_dir(project_root)
.output();
match log {
Ok(out) => {
let stdout = String::from_utf8_lossy(&out.stdout);
!stdout.trim().is_empty()
}
Err(_) => false,
}
}
/// Move a story from `work/2_current/` to `work/5_done/` and auto-commit.
///
/// * If the story is in `2_current/`, it is moved to `5_done/` and committed.
/// * If the story is in `4_merge/`, it is moved to `5_done/` and committed.
/// * If the story is already in `5_done/` or `6_archived/`, this is a no-op (idempotent).
/// * If the story is not found in `2_current/`, `4_merge/`, `5_done/`, or `6_archived/`, an error is returned.
pub fn move_story_to_archived(project_root: &Path, story_id: &str) -> Result<(), String> {
let sk = project_root.join(".storkit").join("work");
let current_path = sk.join("2_current").join(format!("{story_id}.md"));
let merge_path = sk.join("4_merge").join(format!("{story_id}.md"));
let done_dir = sk.join("5_done");
let done_path = done_dir.join(format!("{story_id}.md"));
let archived_path = sk.join("6_archived").join(format!("{story_id}.md"));
if done_path.exists() || archived_path.exists() {
// Already in done or archived — idempotent, nothing to do.
return Ok(());
}
// Check 2_current/ first, then 4_merge/
let source_path = if current_path.exists() {
current_path.clone()
} else if merge_path.exists() {
merge_path.clone()
} else {
return Err(format!(
"Story '{story_id}' not found in work/2_current/ or work/4_merge/. Cannot accept story."
));
};
std::fs::create_dir_all(&done_dir)
.map_err(|e| format!("Failed to create work/5_done/ directory: {e}"))?;
std::fs::rename(&source_path, &done_path)
.map_err(|e| format!("Failed to move story '{story_id}' to 5_done/: {e}"))?;
// Strip stale pipeline fields from front matter now that the story is done.
for field in &["merge_failure", "retry_count", "blocked"] {
if let Err(e) = clear_front_matter_field(&done_path, field) {
slog!("[lifecycle] Warning: could not clear {field} from '{story_id}': {e}");
}
}
let from_dir = if source_path == current_path {
"work/2_current/"
} else {
"work/4_merge/"
};
slog!("[lifecycle] Moved story '{story_id}' from {from_dir} to work/5_done/");
Ok(())
}
/// Move a story/bug from `work/2_current/` or `work/3_qa/` to `work/4_merge/`.
///
/// This stages a work item as ready for the mergemaster to pick up and merge into master.
/// Idempotent: if already in `4_merge/`, returns Ok without committing.
pub fn move_story_to_merge(project_root: &Path, story_id: &str) -> Result<(), String> {
let sk = project_root.join(".storkit").join("work");
let current_path = sk.join("2_current").join(format!("{story_id}.md"));
let qa_path = sk.join("3_qa").join(format!("{story_id}.md"));
let merge_dir = sk.join("4_merge");
let merge_path = merge_dir.join(format!("{story_id}.md"));
if merge_path.exists() {
// Already in 4_merge/ — idempotent, nothing to do.
return Ok(());
}
// Accept from 2_current/ (manual trigger) or 3_qa/ (pipeline advancement from QA stage).
let source_path = if current_path.exists() {
current_path.clone()
} else if qa_path.exists() {
qa_path.clone()
} else {
return Err(format!(
"Work item '{story_id}' not found in work/2_current/ or work/3_qa/. Cannot move to 4_merge/."
));
};
std::fs::create_dir_all(&merge_dir)
.map_err(|e| format!("Failed to create work/4_merge/ directory: {e}"))?;
std::fs::rename(&source_path, &merge_path)
.map_err(|e| format!("Failed to move '{story_id}' to 4_merge/: {e}"))?;
let from_dir = if source_path == current_path {
"work/2_current/"
} else {
"work/3_qa/"
};
// Reset retry count and blocked for the new stage.
if let Err(e) = clear_front_matter_field(&merge_path, "retry_count") {
slog!("[lifecycle] Warning: could not clear retry_count for '{story_id}': {e}");
}
if let Err(e) = clear_front_matter_field(&merge_path, "blocked") {
slog!("[lifecycle] Warning: could not clear blocked for '{story_id}': {e}");
}
slog!("[lifecycle] Moved '{story_id}' from {from_dir} to work/4_merge/");
Ok(())
}
/// Move a story/bug from `work/2_current/` to `work/3_qa/` and auto-commit.
///
/// This stages a work item for QA review before merging to master.
/// Idempotent: if already in `3_qa/`, returns Ok without committing.
pub fn move_story_to_qa(project_root: &Path, story_id: &str) -> Result<(), String> {
let sk = project_root.join(".storkit").join("work");
let current_path = sk.join("2_current").join(format!("{story_id}.md"));
let qa_dir = sk.join("3_qa");
let qa_path = qa_dir.join(format!("{story_id}.md"));
if qa_path.exists() {
// Already in 3_qa/ — idempotent, nothing to do.
return Ok(());
}
if !current_path.exists() {
return Err(format!(
"Work item '{story_id}' not found in work/2_current/. Cannot move to 3_qa/."
));
}
std::fs::create_dir_all(&qa_dir)
.map_err(|e| format!("Failed to create work/3_qa/ directory: {e}"))?;
std::fs::rename(&current_path, &qa_path)
.map_err(|e| format!("Failed to move '{story_id}' to 3_qa/: {e}"))?;
// Reset retry count for the new stage.
if let Err(e) = clear_front_matter_field(&qa_path, "retry_count") {
slog!("[lifecycle] Warning: could not clear retry_count for '{story_id}': {e}");
}
if let Err(e) = clear_front_matter_field(&qa_path, "blocked") {
slog!("[lifecycle] Warning: could not clear blocked for '{story_id}': {e}");
}
slog!("[lifecycle] Moved '{story_id}' from work/2_current/ to work/3_qa/");
Ok(())
}
/// Move a story from `work/3_qa/` back to `work/2_current/` and write rejection notes.
///
/// Used when a human reviewer rejects a story during manual QA.
/// Clears the `review_hold` front matter field and appends rejection notes to the story file.
pub fn reject_story_from_qa(
project_root: &Path,
story_id: &str,
notes: &str,
) -> Result<(), String> {
let sk = project_root.join(".storkit").join("work");
let qa_path = sk.join("3_qa").join(format!("{story_id}.md"));
let current_dir = sk.join("2_current");
let current_path = current_dir.join(format!("{story_id}.md"));
if current_path.exists() {
return Ok(()); // Already in 2_current — idempotent.
}
if !qa_path.exists() {
return Err(format!(
"Work item '{story_id}' not found in work/3_qa/. Cannot reject."
));
}
std::fs::create_dir_all(&current_dir)
.map_err(|e| format!("Failed to create work/2_current/ directory: {e}"))?;
std::fs::rename(&qa_path, &current_path)
.map_err(|e| format!("Failed to move '{story_id}' from 3_qa/ to 2_current/: {e}"))?;
// Clear review_hold since the story is going back for rework.
if let Err(e) = clear_front_matter_field(&current_path, "review_hold") {
slog!("[lifecycle] Warning: could not clear review_hold from '{story_id}': {e}");
}
// Write rejection notes into the story file so the coder can see what needs fixing.
if !notes.is_empty()
&& let Err(e) = write_rejection_notes(&current_path, notes)
{
slog!("[lifecycle] Warning: could not write rejection notes to '{story_id}': {e}");
}
slog!("[lifecycle] Rejected '{story_id}' from work/3_qa/ back to work/2_current/");
Ok(())
}
/// Move any work item to an arbitrary pipeline stage by searching all stages.
///
/// Accepts `target_stage` as one of: `backlog`, `current`, `qa`, `merge`, `done`.
/// Idempotent: if the item is already in the target stage, returns Ok.
/// Returns `(from_stage, to_stage)` on success.
pub fn move_story_to_stage(
project_root: &Path,
story_id: &str,
target_stage: &str,
) -> Result<(String, String), String> {
let stage_dirs: &[(&str, &str)] = &[
("backlog", "1_backlog"),
("current", "2_current"),
("qa", "3_qa"),
("merge", "4_merge"),
("done", "5_done"),
];
let target_dir_name = stage_dirs
.iter()
.find(|(name, _)| *name == target_stage)
.map(|(_, dir)| *dir)
.ok_or_else(|| {
format!(
"Invalid target_stage '{target_stage}'. Must be one of: backlog, current, qa, merge, done"
)
})?;
let sk = project_root.join(".storkit").join("work");
let target_dir = sk.join(target_dir_name);
let target_path = target_dir.join(format!("{story_id}.md"));
if target_path.exists() {
return Ok((target_stage.to_string(), target_stage.to_string()));
}
// Search all named stages plus the archive stage.
let search_dirs: &[(&str, &str)] = &[
("backlog", "1_backlog"),
("current", "2_current"),
("qa", "3_qa"),
("merge", "4_merge"),
("done", "5_done"),
("archived", "6_archived"),
];
let mut found_path: Option<std::path::PathBuf> = None;
let mut from_stage = "";
for (stage_name, dir_name) in search_dirs {
let candidate = sk.join(dir_name).join(format!("{story_id}.md"));
if candidate.exists() {
found_path = Some(candidate);
from_stage = stage_name;
break;
}
}
let source_path =
found_path.ok_or_else(|| format!("Work item '{story_id}' not found in any pipeline stage."))?;
std::fs::create_dir_all(&target_dir)
.map_err(|e| format!("Failed to create work/{target_dir_name}/ directory: {e}"))?;
std::fs::rename(&source_path, &target_path)
.map_err(|e| format!("Failed to move '{story_id}' to work/{target_dir_name}/: {e}"))?;
slog!(
"[lifecycle] Moved '{story_id}' from work/{from_stage}/ to work/{target_dir_name}/"
);
Ok((from_stage.to_string(), target_stage.to_string()))
}
/// Move a bug from `work/2_current/` or `work/1_backlog/` to `work/5_done/` and auto-commit.
///
/// * If the bug is in `2_current/`, it is moved to `5_done/` and committed.
/// * If the bug is still in `1_backlog/` (never started), it is moved directly to `5_done/`.
/// * If the bug is already in `5_done/`, this is a no-op (idempotent).
/// * If the bug is not found anywhere, an error is returned.
pub fn close_bug_to_archive(project_root: &Path, bug_id: &str) -> Result<(), String> {
let sk = project_root.join(".storkit").join("work");
let current_path = sk.join("2_current").join(format!("{bug_id}.md"));
let backlog_path = sk.join("1_backlog").join(format!("{bug_id}.md"));
let archive_dir = item_archive_dir(project_root, bug_id);
let archive_path = archive_dir.join(format!("{bug_id}.md"));
if archive_path.exists() {
return Ok(());
}
let source_path = if current_path.exists() {
current_path.clone()
} else if backlog_path.exists() {
backlog_path.clone()
} else {
return Err(format!(
"Bug '{bug_id}' not found in work/2_current/ or work/1_backlog/. Cannot close bug."
));
};
std::fs::create_dir_all(&archive_dir)
.map_err(|e| format!("Failed to create work/5_done/ directory: {e}"))?;
std::fs::rename(&source_path, &archive_path)
.map_err(|e| format!("Failed to move bug '{bug_id}' to 5_done/: {e}"))?;
slog!(
"[lifecycle] Closed bug '{bug_id}' → work/5_done/"
);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// ── move_story_to_current tests ────────────────────────────────────────────
#[test]
fn move_story_to_current_moves_file() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let backlog = root.join(".storkit/work/1_backlog");
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&backlog).unwrap();
fs::create_dir_all(&current).unwrap();
fs::write(backlog.join("10_story_foo.md"), "test").unwrap();
move_story_to_current(root, "10_story_foo").unwrap();
assert!(!backlog.join("10_story_foo.md").exists());
assert!(current.join("10_story_foo.md").exists());
}
#[test]
fn move_story_to_current_is_idempotent_when_already_current() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&current).unwrap();
fs::write(current.join("11_story_foo.md"), "test").unwrap();
move_story_to_current(root, "11_story_foo").unwrap();
assert!(current.join("11_story_foo.md").exists());
}
#[test]
fn move_story_to_current_noop_when_not_in_backlog() {
let tmp = tempfile::tempdir().unwrap();
assert!(move_story_to_current(tmp.path(), "99_missing").is_ok());
}
#[test]
fn move_bug_to_current_moves_from_backlog() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let backlog = root.join(".storkit/work/1_backlog");
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&backlog).unwrap();
fs::create_dir_all(&current).unwrap();
fs::write(backlog.join("1_bug_test.md"), "# Bug 1\n").unwrap();
move_story_to_current(root, "1_bug_test").unwrap();
assert!(!backlog.join("1_bug_test.md").exists());
assert!(current.join("1_bug_test.md").exists());
}
// ── close_bug_to_archive tests ─────────────────────────────────────────────
#[test]
fn close_bug_moves_from_current_to_archive() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&current).unwrap();
fs::write(current.join("2_bug_test.md"), "# Bug 2\n").unwrap();
close_bug_to_archive(root, "2_bug_test").unwrap();
assert!(!current.join("2_bug_test.md").exists());
assert!(root.join(".storkit/work/5_done/2_bug_test.md").exists());
}
#[test]
fn close_bug_moves_from_backlog_when_not_started() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let backlog = root.join(".storkit/work/1_backlog");
fs::create_dir_all(&backlog).unwrap();
fs::write(backlog.join("3_bug_test.md"), "# Bug 3\n").unwrap();
close_bug_to_archive(root, "3_bug_test").unwrap();
assert!(!backlog.join("3_bug_test.md").exists());
assert!(root.join(".storkit/work/5_done/3_bug_test.md").exists());
}
// ── item_type_from_id tests ────────────────────────────────────────────────
#[test]
fn item_type_from_id_detects_types() {
assert_eq!(item_type_from_id("1_bug_test"), "bug");
assert_eq!(item_type_from_id("1_spike_research"), "spike");
assert_eq!(item_type_from_id("50_story_my_story"), "story");
assert_eq!(item_type_from_id("1_story_simple"), "story");
}
// ── move_story_to_merge tests ──────────────────────────────────────────────
#[test]
fn move_story_to_merge_moves_file() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&current).unwrap();
fs::write(current.join("20_story_foo.md"), "test").unwrap();
move_story_to_merge(root, "20_story_foo").unwrap();
assert!(!current.join("20_story_foo.md").exists());
assert!(root.join(".storkit/work/4_merge/20_story_foo.md").exists());
}
#[test]
fn move_story_to_merge_from_qa_dir() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let qa_dir = root.join(".storkit/work/3_qa");
fs::create_dir_all(&qa_dir).unwrap();
fs::write(qa_dir.join("40_story_test.md"), "test").unwrap();
move_story_to_merge(root, "40_story_test").unwrap();
assert!(!qa_dir.join("40_story_test.md").exists());
assert!(root.join(".storkit/work/4_merge/40_story_test.md").exists());
}
#[test]
fn move_story_to_merge_idempotent_when_already_in_merge() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let merge_dir = root.join(".storkit/work/4_merge");
fs::create_dir_all(&merge_dir).unwrap();
fs::write(merge_dir.join("21_story_test.md"), "test").unwrap();
move_story_to_merge(root, "21_story_test").unwrap();
assert!(merge_dir.join("21_story_test.md").exists());
}
#[test]
fn move_story_to_merge_errors_when_not_in_current_or_qa() {
let tmp = tempfile::tempdir().unwrap();
let result = move_story_to_merge(tmp.path(), "99_nonexistent");
assert!(result.unwrap_err().contains("not found in work/2_current/ or work/3_qa/"));
}
// ── move_story_to_qa tests ────────────────────────────────────────────────
#[test]
fn move_story_to_qa_moves_file() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&current).unwrap();
fs::write(current.join("30_story_qa.md"), "test").unwrap();
move_story_to_qa(root, "30_story_qa").unwrap();
assert!(!current.join("30_story_qa.md").exists());
assert!(root.join(".storkit/work/3_qa/30_story_qa.md").exists());
}
#[test]
fn move_story_to_qa_idempotent_when_already_in_qa() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let qa_dir = root.join(".storkit/work/3_qa");
fs::create_dir_all(&qa_dir).unwrap();
fs::write(qa_dir.join("31_story_test.md"), "test").unwrap();
move_story_to_qa(root, "31_story_test").unwrap();
assert!(qa_dir.join("31_story_test.md").exists());
}
#[test]
fn move_story_to_qa_errors_when_not_in_current() {
let tmp = tempfile::tempdir().unwrap();
let result = move_story_to_qa(tmp.path(), "99_nonexistent");
assert!(result.unwrap_err().contains("not found in work/2_current/"));
}
// ── move_story_to_archived tests ──────────────────────────────────────────
#[test]
fn move_story_to_archived_finds_in_merge_dir() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let merge_dir = root.join(".storkit/work/4_merge");
fs::create_dir_all(&merge_dir).unwrap();
fs::write(merge_dir.join("22_story_test.md"), "test").unwrap();
move_story_to_archived(root, "22_story_test").unwrap();
assert!(!merge_dir.join("22_story_test.md").exists());
assert!(root.join(".storkit/work/5_done/22_story_test.md").exists());
}
#[test]
fn move_story_to_archived_error_when_not_in_current_or_merge() {
let tmp = tempfile::tempdir().unwrap();
let result = move_story_to_archived(tmp.path(), "99_nonexistent");
assert!(result.unwrap_err().contains("4_merge"));
}
// ── feature_branch_has_unmerged_changes tests ────────────────────────────
fn init_git_repo(repo: &std::path::Path) {
Command::new("git")
.args(["init"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["config", "user.email", "test@test.com"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["config", "user.name", "Test"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["commit", "--allow-empty", "-m", "init"])
.current_dir(repo)
.output()
.unwrap();
}
/// Bug 226: feature_branch_has_unmerged_changes returns true when the
/// feature branch has commits not on master.
#[test]
fn feature_branch_has_unmerged_changes_detects_unmerged_code() {
use std::fs;
use tempfile::tempdir;
let tmp = tempdir().unwrap();
let repo = tmp.path();
init_git_repo(repo);
// Create a feature branch with a code commit.
Command::new("git")
.args(["checkout", "-b", "feature/story-50_story_test"])
.current_dir(repo)
.output()
.unwrap();
fs::write(repo.join("feature.rs"), "fn main() {}").unwrap();
Command::new("git")
.args(["add", "."])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["commit", "-m", "add feature"])
.current_dir(repo)
.output()
.unwrap();
Command::new("git")
.args(["checkout", "master"])
.current_dir(repo)
.output()
.unwrap();
assert!(
feature_branch_has_unmerged_changes(repo, "50_story_test"),
"should detect unmerged changes on feature branch"
);
}
/// Bug 226: feature_branch_has_unmerged_changes returns false when no
/// feature branch exists.
#[test]
fn feature_branch_has_unmerged_changes_false_when_no_branch() {
use tempfile::tempdir;
let tmp = tempdir().unwrap();
let repo = tmp.path();
init_git_repo(repo);
assert!(
!feature_branch_has_unmerged_changes(repo, "99_nonexistent"),
"should return false when no feature branch"
);
}
// ── reject_story_from_qa tests ────────────────────────────────────────────
#[test]
fn reject_story_from_qa_moves_to_current() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let qa_dir = root.join(".storkit/work/3_qa");
let current_dir = root.join(".storkit/work/2_current");
fs::create_dir_all(&qa_dir).unwrap();
fs::create_dir_all(&current_dir).unwrap();
fs::write(
qa_dir.join("50_story_test.md"),
"---\nname: Test\nreview_hold: true\n---\n# Story\n",
)
.unwrap();
reject_story_from_qa(root, "50_story_test", "Button color wrong").unwrap();
assert!(!qa_dir.join("50_story_test.md").exists());
assert!(current_dir.join("50_story_test.md").exists());
let contents = fs::read_to_string(current_dir.join("50_story_test.md")).unwrap();
assert!(contents.contains("Button color wrong"));
assert!(contents.contains("## QA Rejection Notes"));
assert!(!contents.contains("review_hold"));
}
#[test]
fn reject_story_from_qa_errors_when_not_in_qa() {
let tmp = tempfile::tempdir().unwrap();
let result = reject_story_from_qa(tmp.path(), "99_nonexistent", "notes");
assert!(result.unwrap_err().contains("not found in work/3_qa/"));
}
#[test]
fn reject_story_from_qa_idempotent_when_in_current() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current_dir = root.join(".storkit/work/2_current");
fs::create_dir_all(&current_dir).unwrap();
fs::write(current_dir.join("51_story_test.md"), "---\nname: Test\n---\n# Story\n").unwrap();
reject_story_from_qa(root, "51_story_test", "notes").unwrap();
assert!(current_dir.join("51_story_test.md").exists());
}
// ── move_story_to_stage tests ─────────────────────────────────
#[test]
fn move_story_to_stage_moves_from_backlog_to_current() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let backlog = root.join(".storkit/work/1_backlog");
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&backlog).unwrap();
fs::create_dir_all(&current).unwrap();
fs::write(backlog.join("60_story_move.md"), "test").unwrap();
let (from, to) = move_story_to_stage(root, "60_story_move", "current").unwrap();
assert_eq!(from, "backlog");
assert_eq!(to, "current");
assert!(!backlog.join("60_story_move.md").exists());
assert!(current.join("60_story_move.md").exists());
}
#[test]
fn move_story_to_stage_moves_from_current_to_backlog() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current = root.join(".storkit/work/2_current");
let backlog = root.join(".storkit/work/1_backlog");
fs::create_dir_all(&current).unwrap();
fs::create_dir_all(&backlog).unwrap();
fs::write(current.join("61_story_back.md"), "test").unwrap();
let (from, to) = move_story_to_stage(root, "61_story_back", "backlog").unwrap();
assert_eq!(from, "current");
assert_eq!(to, "backlog");
assert!(!current.join("61_story_back.md").exists());
assert!(backlog.join("61_story_back.md").exists());
}
#[test]
fn move_story_to_stage_idempotent_when_already_in_target() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let current = root.join(".storkit/work/2_current");
fs::create_dir_all(&current).unwrap();
fs::write(current.join("62_story_idem.md"), "test").unwrap();
let (from, to) = move_story_to_stage(root, "62_story_idem", "current").unwrap();
assert_eq!(from, "current");
assert_eq!(to, "current");
assert!(current.join("62_story_idem.md").exists());
}
#[test]
fn move_story_to_stage_invalid_target_returns_error() {
let tmp = tempfile::tempdir().unwrap();
let result = move_story_to_stage(tmp.path(), "1_story_test", "invalid");
assert!(result.is_err());
assert!(result.unwrap_err().contains("Invalid target_stage"));
}
#[test]
fn move_story_to_stage_not_found_returns_error() {
let tmp = tempfile::tempdir().unwrap();
let result = move_story_to_stage(tmp.path(), "99_story_ghost", "current");
assert!(result.is_err());
assert!(result.unwrap_err().contains("not found in any pipeline stage"));
}
#[test]
fn move_story_to_stage_finds_in_qa_dir() {
use std::fs;
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let qa_dir = root.join(".storkit/work/3_qa");
let backlog = root.join(".storkit/work/1_backlog");
fs::create_dir_all(&qa_dir).unwrap();
fs::create_dir_all(&backlog).unwrap();
fs::write(qa_dir.join("63_story_qa.md"), "test").unwrap();
let (from, to) = move_story_to_stage(root, "63_story_qa", "backlog").unwrap();
assert_eq!(from, "qa");
assert_eq!(to, "backlog");
assert!(!qa_dir.join("63_story_qa.md").exists());
assert!(backlog.join("63_story_qa.md").exists());
}
}
File diff suppressed because it is too large Load Diff
+222
View File
@@ -0,0 +1,222 @@
pub mod gates;
pub mod lifecycle;
pub mod merge;
mod pool;
pub(crate) mod pty;
pub mod runtime;
pub mod token_usage;
use crate::config::AgentConfig;
use serde::{Deserialize, Serialize};
pub use lifecycle::{
close_bug_to_archive, feature_branch_has_unmerged_changes, move_story_to_archived,
move_story_to_merge, move_story_to_qa, move_story_to_stage, reject_story_from_qa,
};
pub use pool::AgentPool;
/// Events emitted during server startup reconciliation to broadcast real-time
/// progress to connected WebSocket clients.
#[derive(Debug, Clone, Serialize)]
pub struct ReconciliationEvent {
/// The story being reconciled, or empty string for the overall "done" event.
pub story_id: String,
/// Coarse status: "checking", "gates_running", "advanced", "skipped", "failed", "done"
pub status: String,
/// Human-readable details.
pub message: String,
}
/// Events streamed from a running agent to SSE clients.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
/// Agent status changed.
Status {
story_id: String,
agent_name: String,
status: String,
},
/// Raw text output from the agent process.
Output {
story_id: String,
agent_name: String,
text: String,
},
/// Agent produced a JSON event from `--output-format stream-json`.
AgentJson {
story_id: String,
agent_name: String,
data: serde_json::Value,
},
/// Agent finished.
Done {
story_id: String,
agent_name: String,
session_id: Option<String>,
},
/// Agent errored.
Error {
story_id: String,
agent_name: String,
message: String,
},
/// Thinking tokens from an extended-thinking block.
Thinking {
story_id: String,
agent_name: String,
text: String,
},
}
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AgentStatus {
Pending,
Running,
Completed,
Failed,
}
impl std::fmt::Display for AgentStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Pending => write!(f, "pending"),
Self::Running => write!(f, "running"),
Self::Completed => write!(f, "completed"),
Self::Failed => write!(f, "failed"),
}
}
}
/// Pipeline stages for automatic story advancement.
#[derive(Debug, Clone, PartialEq)]
pub enum PipelineStage {
/// Coding agents (coder-1, coder-2, etc.)
Coder,
/// QA review agent
Qa,
/// Mergemaster agent
Mergemaster,
/// Supervisors and unknown agents — no automatic advancement.
Other,
}
/// Determine the pipeline stage from an agent name.
pub fn pipeline_stage(agent_name: &str) -> PipelineStage {
match agent_name {
"qa" => PipelineStage::Qa,
"mergemaster" => PipelineStage::Mergemaster,
name if name.starts_with("coder") => PipelineStage::Coder,
_ => PipelineStage::Other,
}
}
/// Determine the pipeline stage for a configured agent.
///
/// Prefers the explicit `stage` config field (added in Bug 150) over the
/// legacy name-based heuristic so that agents with non-standard names
/// (e.g. `qa-2`, `coder-opus`) are assigned to the correct stage.
pub(crate) fn agent_config_stage(cfg: &AgentConfig) -> PipelineStage {
match cfg.stage.as_deref() {
Some("coder") => PipelineStage::Coder,
Some("qa") => PipelineStage::Qa,
Some("mergemaster") => PipelineStage::Mergemaster,
Some(_) => PipelineStage::Other,
None => pipeline_stage(&cfg.name),
}
}
/// Completion report produced when acceptance gates are run.
///
/// Created automatically by the server when an agent process exits normally,
/// or via the internal `report_completion` method.
#[derive(Debug, Serialize, Clone)]
pub struct CompletionReport {
pub summary: String,
pub gates_passed: bool,
pub gate_output: String,
}
/// Token usage from a Claude Code session's `result` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub cache_creation_input_tokens: u64,
pub cache_read_input_tokens: u64,
pub total_cost_usd: f64,
}
impl TokenUsage {
/// Parse token usage from a Claude Code `result` JSON event.
pub fn from_result_event(json: &serde_json::Value) -> Option<Self> {
let usage = json.get("usage")?;
Some(Self {
input_tokens: usage
.get("input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
output_tokens: usage
.get("output_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cache_creation_input_tokens: usage
.get("cache_creation_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cache_read_input_tokens: usage
.get("cache_read_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
total_cost_usd: json
.get("total_cost_usd")
.and_then(|v| v.as_f64())
.unwrap_or(0.0),
})
}
}
#[derive(Debug, Serialize, Clone)]
pub struct AgentInfo {
pub story_id: String,
pub agent_name: String,
pub status: AgentStatus,
pub session_id: Option<String>,
pub worktree_path: Option<String>,
pub base_branch: Option<String>,
pub completion: Option<CompletionReport>,
/// UUID identifying the persistent log file for this session.
pub log_session_id: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
// ── pipeline_stage tests ──────────────────────────────────────────────────
#[test]
fn pipeline_stage_detects_coders() {
assert_eq!(pipeline_stage("coder-1"), PipelineStage::Coder);
assert_eq!(pipeline_stage("coder-2"), PipelineStage::Coder);
assert_eq!(pipeline_stage("coder-3"), PipelineStage::Coder);
}
#[test]
fn pipeline_stage_detects_qa() {
assert_eq!(pipeline_stage("qa"), PipelineStage::Qa);
}
#[test]
fn pipeline_stage_detects_mergemaster() {
assert_eq!(pipeline_stage("mergemaster"), PipelineStage::Mergemaster);
}
#[test]
fn pipeline_stage_supervisor_is_other() {
assert_eq!(pipeline_stage("supervisor"), PipelineStage::Other);
assert_eq!(pipeline_stage("default"), PipelineStage::Other);
assert_eq!(pipeline_stage("unknown"), PipelineStage::Other);
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+591
View File
@@ -0,0 +1,591 @@
use std::collections::HashMap;
use std::io::{BufRead, BufReader};
use std::sync::{Arc, Mutex};
use portable_pty::{ChildKiller, CommandBuilder, PtySize, native_pty_system};
use tokio::sync::broadcast;
use super::{AgentEvent, TokenUsage};
use crate::agent_log::AgentLogWriter;
use crate::io::watcher::WatcherEvent;
use crate::slog;
use crate::slog_warn;
/// Result from a PTY agent session, containing the session ID and token usage.
pub(in crate::agents) struct PtyResult {
pub session_id: Option<String>,
pub token_usage: Option<TokenUsage>,
}
fn composite_key(story_id: &str, agent_name: &str) -> String {
format!("{story_id}:{agent_name}")
}
struct ChildKillerGuard {
killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
key: String,
}
impl Drop for ChildKillerGuard {
fn drop(&mut self) {
if let Ok(mut killers) = self.killers.lock() {
killers.remove(&self.key);
}
}
}
/// Spawn claude agent in a PTY and stream events through the broadcast channel.
#[allow(clippy::too_many_arguments)]
pub(in crate::agents) async fn run_agent_pty_streaming(
story_id: &str,
agent_name: &str,
command: &str,
args: &[String],
prompt: &str,
cwd: &str,
tx: &broadcast::Sender<AgentEvent>,
event_log: &Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
inactivity_timeout_secs: u64,
child_killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
watcher_tx: broadcast::Sender<WatcherEvent>,
) -> Result<PtyResult, String> {
let sid = story_id.to_string();
let aname = agent_name.to_string();
let cmd = command.to_string();
let args = args.to_vec();
let prompt = prompt.to_string();
let cwd = cwd.to_string();
let tx = tx.clone();
let event_log = event_log.clone();
tokio::task::spawn_blocking(move || {
run_agent_pty_blocking(
&sid,
&aname,
&cmd,
&args,
&prompt,
&cwd,
&tx,
&event_log,
log_writer.as_deref(),
inactivity_timeout_secs,
&child_killers,
&watcher_tx,
)
})
.await
.map_err(|e| format!("Agent task panicked: {e}"))?
}
/// Dispatch a `stream_event` from Claude Code's `--include-partial-messages` output.
///
/// Extracts `thinking_delta` and `text_delta` from `content_block_delta` events
/// and routes them as `AgentEvent::Thinking` and `AgentEvent::Output` respectively.
/// This ensures thinking traces flow through the dedicated `ThinkingBlock` UI
/// component rather than appearing as unbounded regular output.
fn handle_agent_stream_event(
event: &serde_json::Value,
story_id: &str,
agent_name: &str,
tx: &broadcast::Sender<AgentEvent>,
event_log: &Mutex<Vec<AgentEvent>>,
log_writer: Option<&Mutex<AgentLogWriter>>,
) {
let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
if event_type == "content_block_delta"
&& let Some(delta) = event.get("delta")
{
let delta_type = delta.get("type").and_then(|t| t.as_str()).unwrap_or("");
match delta_type {
"thinking_delta" => {
if let Some(thinking) = delta.get("thinking").and_then(|t| t.as_str()) {
emit_event(
AgentEvent::Thinking {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
text: thinking.to_string(),
},
tx,
event_log,
log_writer,
);
}
}
"text_delta" => {
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
emit_event(
AgentEvent::Output {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
text: text.to_string(),
},
tx,
event_log,
log_writer,
);
}
}
_ => {}
}
}
}
/// Helper to send an event to broadcast, event log, and optional persistent log file.
pub(super) fn emit_event(
event: AgentEvent,
tx: &broadcast::Sender<AgentEvent>,
event_log: &Mutex<Vec<AgentEvent>>,
log_writer: Option<&Mutex<AgentLogWriter>>,
) {
if let Ok(mut log) = event_log.lock() {
log.push(event.clone());
}
if let Some(writer) = log_writer
&& let Ok(mut w) = writer.lock()
&& let Err(e) = w.write_event(&event)
{
eprintln!("[agent_log] Failed to write event to log file: {e}");
}
let _ = tx.send(event);
}
#[allow(clippy::too_many_arguments)]
fn run_agent_pty_blocking(
story_id: &str,
agent_name: &str,
command: &str,
args: &[String],
prompt: &str,
cwd: &str,
tx: &broadcast::Sender<AgentEvent>,
event_log: &Mutex<Vec<AgentEvent>>,
log_writer: Option<&Mutex<AgentLogWriter>>,
inactivity_timeout_secs: u64,
child_killers: &Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
watcher_tx: &broadcast::Sender<WatcherEvent>,
) -> Result<PtyResult, String> {
let pty_system = native_pty_system();
let pair = pty_system
.openpty(PtySize {
rows: 50,
cols: 200,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| format!("Failed to open PTY: {e}"))?;
let mut cmd = CommandBuilder::new(command);
// -p <prompt> must come first
cmd.arg("-p");
cmd.arg(prompt);
// Add configured args (e.g., --directory /path/to/worktree, --model, etc.)
for arg in args {
cmd.arg(arg);
}
cmd.arg("--output-format");
cmd.arg("stream-json");
cmd.arg("--verbose");
// Enable partial streaming so we receive thinking_delta and text_delta
// events in real-time, rather than only complete assistant events.
// Without this, thinking traces may not appear in the structured output
// and instead leak as unstructured PTY text.
cmd.arg("--include-partial-messages");
// Supervised agents don't need interactive permission prompts
cmd.arg("--permission-mode");
cmd.arg("bypassPermissions");
cmd.cwd(cwd);
cmd.env("NO_COLOR", "1");
// Allow spawning Claude Code from within a Claude Code session
cmd.env_remove("CLAUDECODE");
cmd.env_remove("CLAUDE_CODE_ENTRYPOINT");
slog!("[agent:{story_id}:{agent_name}] Spawning {command} in {cwd} with args: {args:?}");
let mut child = pair
.slave
.spawn_command(cmd)
.map_err(|e| format!("Failed to spawn agent for {story_id}:{agent_name}: {e}"))?;
// Register the child killer so that kill_all_children() / stop_agent() can
// terminate this process on server shutdown, even if the blocking thread
// cannot be interrupted. The ChildKillerGuard deregisters on function exit.
let killer_key = composite_key(story_id, agent_name);
{
let killer = child.clone_killer();
if let Ok(mut killers) = child_killers.lock() {
killers.insert(killer_key.clone(), killer);
}
}
let _killer_guard = ChildKillerGuard {
killers: Arc::clone(child_killers),
key: killer_key,
};
drop(pair.slave);
let reader = pair
.master
.try_clone_reader()
.map_err(|e| format!("Failed to clone PTY reader: {e}"))?;
drop(pair.master);
// Spawn a reader thread to collect PTY output lines.
// We use a channel so the main thread can apply an inactivity deadline
// via recv_timeout: if no output arrives within the configured window
// the process is killed and the agent is marked Failed.
let (line_tx, line_rx) = std::sync::mpsc::channel::<std::io::Result<String>>();
std::thread::spawn(move || {
let buf_reader = BufReader::new(reader);
for line in buf_reader.lines() {
if line_tx.send(line).is_err() {
break;
}
}
});
let timeout_dur = if inactivity_timeout_secs > 0 {
Some(std::time::Duration::from_secs(inactivity_timeout_secs))
} else {
None
};
let mut session_id: Option<String> = None;
let mut token_usage: Option<TokenUsage> = None;
loop {
let recv_result = match timeout_dur {
Some(dur) => line_rx.recv_timeout(dur),
None => line_rx
.recv()
.map_err(|_| std::sync::mpsc::RecvTimeoutError::Disconnected),
};
let line = match recv_result {
Ok(Ok(l)) => l,
Ok(Err(_)) => {
// IO error reading from PTY — treat as EOF.
break;
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
// Reader thread exited (EOF from PTY).
break;
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
slog_warn!(
"[agent:{story_id}:{agent_name}] Inactivity timeout after \
{inactivity_timeout_secs}s with no output. Killing process."
);
let _ = child.kill();
let _ = child.wait();
return Err(format!(
"Agent inactivity timeout: no output received for {inactivity_timeout_secs}s"
));
}
};
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
// Try to parse as JSON
let json: serde_json::Value = match serde_json::from_str(trimmed) {
Ok(j) => j,
Err(_) => {
// Non-JSON output (terminal escapes etc.) — send as raw output
emit_event(
AgentEvent::Output {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
text: trimmed.to_string(),
},
tx,
event_log,
log_writer,
);
continue;
}
};
let event_type = json.get("type").and_then(|t| t.as_str()).unwrap_or("");
match event_type {
"system" => {
session_id = json
.get("session_id")
.and_then(|s| s.as_str())
.map(|s| s.to_string());
}
// With --include-partial-messages, thinking and text arrive
// incrementally via stream_event → content_block_delta. Handle
// them here for real-time streaming to the frontend.
"stream_event" => {
if let Some(event) = json.get("event") {
handle_agent_stream_event(
event,
story_id,
agent_name,
tx,
event_log,
log_writer,
);
}
}
// Complete assistant events are skipped for content extraction
// because thinking and text already arrived via stream_event.
// The raw JSON is still forwarded as AgentJson below.
"assistant" | "user" => {}
"rate_limit_event" => {
slog!(
"[agent:{story_id}:{agent_name}] API rate limit warning received"
);
let _ = watcher_tx.send(WatcherEvent::RateLimitWarning {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
});
}
"result" => {
// Extract token usage from the result event.
if let Some(usage) = TokenUsage::from_result_event(&json) {
slog!(
"[agent:{story_id}:{agent_name}] Token usage: in={} out={} cache_create={} cache_read={} cost=${:.4}",
usage.input_tokens,
usage.output_tokens,
usage.cache_creation_input_tokens,
usage.cache_read_input_tokens,
usage.total_cost_usd,
);
token_usage = Some(usage);
}
}
_ => {}
}
// Forward all JSON events
emit_event(
AgentEvent::AgentJson {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
data: json,
},
tx,
event_log,
log_writer,
);
}
let _ = child.kill();
let _ = child.wait();
slog!(
"[agent:{story_id}:{agent_name}] Done. Session: {:?}",
session_id
);
Ok(PtyResult {
session_id,
token_usage,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agents::AgentEvent;
use crate::io::watcher::WatcherEvent;
use std::collections::HashMap;
use std::sync::Arc;
// ── AC1: pty detects rate_limit_event and emits RateLimitWarning ─────────
/// Verify that when a `rate_limit_event` JSON line appears in PTY output,
/// `run_agent_pty_streaming` sends a `WatcherEvent::RateLimitWarning` with
/// the correct story_id and agent_name.
///
/// The command invoked is: `sh -p -- <script>` where `--` terminates
/// option parsing so the script path is treated as the operand.
#[tokio::test]
async fn rate_limit_event_json_sends_watcher_warning() {
use std::os::unix::fs::PermissionsExt;
let tmp = tempfile::tempdir().unwrap();
let script = tmp.path().join("emit_rate_limit.sh");
std::fs::write(
&script,
"#!/bin/sh\nprintf '%s\\n' '{\"type\":\"rate_limit_event\",\"rate_limit_info\":{\"status\":\"allowed_warning\"}}'\n",
)
.unwrap();
std::fs::set_permissions(&script, std::fs::Permissions::from_mode(0o755)).unwrap();
let (tx, _rx) = broadcast::channel::<AgentEvent>(64);
let (watcher_tx, mut watcher_rx) = broadcast::channel::<WatcherEvent>(16);
let event_log = Arc::new(Mutex::new(Vec::new()));
let child_killers = Arc::new(Mutex::new(HashMap::new()));
// sh -p "--" <script>: -p = privileged mode, "--" = end options,
// then the script path is the file operand.
let result = run_agent_pty_streaming(
"365_story_test",
"coder-1",
"sh",
&[script.to_string_lossy().to_string()],
"--",
"/tmp",
&tx,
&event_log,
None,
0,
child_killers,
watcher_tx,
)
.await;
assert!(result.is_ok(), "PTY run should succeed: {:?}", result.err());
let evt = watcher_rx
.try_recv()
.expect("Expected a RateLimitWarning to be sent on watcher_tx");
match evt {
WatcherEvent::RateLimitWarning {
story_id,
agent_name,
} => {
assert_eq!(story_id, "365_story_test");
assert_eq!(agent_name, "coder-1");
}
other => panic!("Expected RateLimitWarning, got: {other:?}"),
}
}
#[test]
fn test_emit_event_writes_to_log_writer() {
let tmp = tempfile::tempdir().unwrap();
let root = tmp.path();
let log_writer =
AgentLogWriter::new(root, "42_story_foo", "coder-1", "sess-emit").unwrap();
let log_mutex = Mutex::new(log_writer);
let (tx, _rx) = broadcast::channel::<AgentEvent>(64);
let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());
let event = AgentEvent::Status {
story_id: "42_story_foo".to_string(),
agent_name: "coder-1".to_string(),
status: "running".to_string(),
};
emit_event(event, &tx, &event_log, Some(&log_mutex));
// Verify event was added to in-memory log
let mem_events = event_log.lock().unwrap();
assert_eq!(mem_events.len(), 1);
drop(mem_events);
// Verify event was written to the log file
let log_path =
crate::agent_log::log_file_path(root, "42_story_foo", "coder-1", "sess-emit");
let entries = crate::agent_log::read_log(&log_path).unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].event["type"], "status");
assert_eq!(entries[0].event["status"], "running");
}
// ── bug 167: handle_agent_stream_event routes thinking/text correctly ───
#[test]
fn stream_event_thinking_delta_emits_thinking_event() {
let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());
let event = serde_json::json!({
"type": "content_block_delta",
"delta": {"type": "thinking_delta", "thinking": "Let me analyze this..."}
});
handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);
let received = rx.try_recv().unwrap();
match received {
AgentEvent::Thinking {
story_id,
agent_name,
text,
} => {
assert_eq!(story_id, "s1");
assert_eq!(agent_name, "coder-1");
assert_eq!(text, "Let me analyze this...");
}
other => panic!("Expected Thinking event, got: {other:?}"),
}
}
#[test]
fn stream_event_text_delta_emits_output_event() {
let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());
let event = serde_json::json!({
"type": "content_block_delta",
"delta": {"type": "text_delta", "text": "Here is the result."}
});
handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);
let received = rx.try_recv().unwrap();
match received {
AgentEvent::Output {
story_id,
agent_name,
text,
} => {
assert_eq!(story_id, "s1");
assert_eq!(agent_name, "coder-1");
assert_eq!(text, "Here is the result.");
}
other => panic!("Expected Output event, got: {other:?}"),
}
}
#[test]
fn stream_event_input_json_delta_ignored() {
let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());
let event = serde_json::json!({
"type": "content_block_delta",
"delta": {"type": "input_json_delta", "partial_json": "{\"file\":"}
});
handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);
// No event should be emitted for tool argument deltas
assert!(rx.try_recv().is_err());
}
#[test]
fn stream_event_non_delta_type_ignored() {
let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());
let event = serde_json::json!({
"type": "message_start",
"message": {"role": "assistant"}
});
handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);
assert!(rx.try_recv().is_err());
}
}
+73
View File
@@ -0,0 +1,73 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use portable_pty::ChildKiller;
use tokio::sync::broadcast;
use crate::agent_log::AgentLogWriter;
use crate::io::watcher::WatcherEvent;
use super::{AgentEvent, AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
/// Agent runtime that spawns the `claude` CLI in a PTY and streams JSON events.
///
/// This is the default runtime (`runtime = "claude-code"` in project.toml).
/// It wraps the existing PTY-based execution logic, preserving all streaming,
/// token tracking, and inactivity timeout behaviour.
pub struct ClaudeCodeRuntime {
child_killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
watcher_tx: broadcast::Sender<WatcherEvent>,
}
impl ClaudeCodeRuntime {
pub fn new(
child_killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
watcher_tx: broadcast::Sender<WatcherEvent>,
) -> Self {
Self {
child_killers,
watcher_tx,
}
}
}
impl AgentRuntime for ClaudeCodeRuntime {
async fn start(
&self,
ctx: RuntimeContext,
tx: broadcast::Sender<AgentEvent>,
event_log: Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
) -> Result<RuntimeResult, String> {
let pty_result = super::super::pty::run_agent_pty_streaming(
&ctx.story_id,
&ctx.agent_name,
&ctx.command,
&ctx.args,
&ctx.prompt,
&ctx.cwd,
&tx,
&event_log,
log_writer,
ctx.inactivity_timeout_secs,
Arc::clone(&self.child_killers),
self.watcher_tx.clone(),
)
.await?;
Ok(RuntimeResult {
session_id: pty_result.session_id,
token_usage: pty_result.token_usage,
})
}
fn stop(&self) {
// Stopping is handled externally by the pool via kill_child_for_key().
// The ChildKillerGuard in pty.rs deregisters automatically on process exit.
}
fn get_status(&self) -> RuntimeStatus {
// Lifecycle status is tracked by the pool; the runtime itself is stateless.
RuntimeStatus::Idle
}
}
+809
View File
@@ -0,0 +1,809 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::sync::broadcast;
use crate::agent_log::AgentLogWriter;
use crate::slog;
use super::super::{AgentEvent, TokenUsage};
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
// ── Public runtime struct ────────────────────────────────────────────
/// Agent runtime that drives a Gemini model through the Google AI
/// `generateContent` REST API.
///
/// The runtime:
/// 1. Fetches MCP tool definitions from storkit's MCP server.
/// 2. Converts them to Gemini function-calling format.
/// 3. Sends the agent prompt + tools to the Gemini API.
/// 4. Executes any requested function calls via MCP `tools/call`.
/// 5. Loops until the model produces a text-only response or an error.
/// 6. Tracks token usage from the API response metadata.
pub struct GeminiRuntime {
/// Whether a stop has been requested.
cancelled: Arc<AtomicBool>,
}
impl GeminiRuntime {
pub fn new() -> Self {
Self {
cancelled: Arc::new(AtomicBool::new(false)),
}
}
}
impl AgentRuntime for GeminiRuntime {
async fn start(
&self,
ctx: RuntimeContext,
tx: broadcast::Sender<AgentEvent>,
event_log: Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
) -> Result<RuntimeResult, String> {
let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| {
"GOOGLE_AI_API_KEY environment variable is not set. \
Set it to your Google AI API key to use the Gemini runtime."
.to_string()
})?;
let model = if ctx.command.starts_with("gemini") {
// The pool puts the model into `command` for non-CLI runtimes,
// but also check args for a --model flag.
ctx.command.clone()
} else {
// Fall back to args: look for --model <value>
ctx.args
.iter()
.position(|a| a == "--model")
.and_then(|i| ctx.args.get(i + 1))
.cloned()
.unwrap_or_else(|| "gemini-2.5-pro".to_string())
};
let mcp_port = ctx.mcp_port;
let mcp_base = format!("http://localhost:{mcp_port}/mcp");
let client = Client::new();
let cancelled = Arc::clone(&self.cancelled);
// Step 1: Fetch MCP tool definitions and convert to Gemini format.
let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
// Step 2: Build the initial conversation contents.
let system_instruction = build_system_instruction(&ctx);
let mut contents: Vec<Value> = vec![json!({
"role": "user",
"parts": [{ "text": ctx.prompt }]
})];
let mut total_usage = TokenUsage {
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
total_cost_usd: 0.0,
};
let emit = |event: AgentEvent| {
super::super::pty::emit_event(
event,
&tx,
&event_log,
log_writer.as_ref().map(|w| w.as_ref()),
);
};
emit(AgentEvent::Status {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
status: "running".to_string(),
});
// Step 3: Conversation loop.
let mut turn = 0u32;
let max_turns = 200; // Safety limit
loop {
if cancelled.load(Ordering::Relaxed) {
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: "Agent was stopped by user".to_string(),
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
turn += 1;
if turn > max_turns {
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: format!("Exceeded maximum turns ({max_turns})"),
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
slog!("[gemini] Turn {turn} for {}:{}", ctx.story_id, ctx.agent_name);
let request_body = build_generate_content_request(
&system_instruction,
&contents,
&gemini_tools,
);
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
);
let response = client
.post(&url)
.json(&request_body)
.send()
.await
.map_err(|e| format!("Gemini API request failed: {e}"))?;
let status = response.status();
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse Gemini API response: {e}"))?;
if !status.is_success() {
let error_msg = body["error"]["message"]
.as_str()
.unwrap_or("Unknown API error");
let err = format!("Gemini API error ({status}): {error_msg}");
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: err.clone(),
});
return Err(err);
}
// Accumulate token usage.
if let Some(usage) = parse_usage_metadata(&body) {
total_usage.input_tokens += usage.input_tokens;
total_usage.output_tokens += usage.output_tokens;
}
// Extract the candidate response.
let candidate = body["candidates"]
.as_array()
.and_then(|c| c.first())
.ok_or_else(|| "No candidates in Gemini response".to_string())?;
let parts = candidate["content"]["parts"]
.as_array()
.ok_or_else(|| "No parts in Gemini response candidate".to_string())?;
// Check finish reason.
let finish_reason = candidate["finishReason"].as_str().unwrap_or("");
// Separate text parts and function call parts.
let mut text_parts: Vec<String> = Vec::new();
let mut function_calls: Vec<GeminiFunctionCall> = Vec::new();
for part in parts {
if let Some(text) = part["text"].as_str() {
text_parts.push(text.to_string());
}
if let Some(fc) = part.get("functionCall")
&& let (Some(name), Some(args)) =
(fc["name"].as_str(), fc.get("args"))
{
function_calls.push(GeminiFunctionCall {
name: name.to_string(),
args: args.clone(),
});
}
}
// Emit any text output.
for text in &text_parts {
if !text.is_empty() {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: text.clone(),
});
}
}
// If no function calls, the model is done.
if function_calls.is_empty() {
emit(AgentEvent::Done {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
session_id: None,
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
// Add the model's response to the conversation.
let model_parts: Vec<Value> = parts.to_vec();
contents.push(json!({
"role": "model",
"parts": model_parts
}));
// Execute function calls via MCP and build response parts.
let mut response_parts: Vec<Value> = Vec::new();
for fc in &function_calls {
if cancelled.load(Ordering::Relaxed) {
break;
}
slog!(
"[gemini] Calling MCP tool '{}' for {}:{}",
fc.name,
ctx.story_id,
ctx.agent_name
);
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("\n[Tool call: {}]\n", fc.name),
});
let tool_result =
call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await;
let response_value = match &tool_result {
Ok(result) => {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!(
"[Tool result: {} chars]\n",
result.len()
),
});
json!({ "result": result })
}
Err(e) => {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("[Tool error: {e}]\n"),
});
json!({ "error": e })
}
};
response_parts.push(json!({
"functionResponse": {
"name": fc.name,
"response": response_value
}
}));
}
// Add function responses to the conversation.
contents.push(json!({
"role": "user",
"parts": response_parts
}));
// If the model indicated it's done despite having function calls,
// respect the finish reason.
if finish_reason == "STOP" && function_calls.is_empty() {
break;
}
}
emit(AgentEvent::Done {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
session_id: None,
});
Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
})
}
fn stop(&self) {
self.cancelled.store(true, Ordering::Relaxed);
}
fn get_status(&self) -> RuntimeStatus {
if self.cancelled.load(Ordering::Relaxed) {
RuntimeStatus::Failed
} else {
RuntimeStatus::Idle
}
}
}
// ── Internal types ───────────────────────────────────────────────────
struct GeminiFunctionCall {
name: String,
args: Value,
}
// ── Gemini API types (for serde) ─────────────────────────────────────
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionDeclaration {
name: String,
description: String,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<Value>,
}
// ── Helper functions ─────────────────────────────────────────────────
/// Build the system instruction content from the RuntimeContext.
fn build_system_instruction(ctx: &RuntimeContext) -> Value {
// Use system_prompt from args if provided via --append-system-prompt,
// otherwise use a sensible default.
let system_text = ctx
.args
.iter()
.position(|a| a == "--append-system-prompt")
.and_then(|i| ctx.args.get(i + 1))
.cloned()
.unwrap_or_else(|| {
format!(
"You are an AI coding agent working on story {}. \
You have access to tools via function calling. \
Use them to complete the task. \
Work in the directory: {}",
ctx.story_id, ctx.cwd
)
});
json!({
"parts": [{ "text": system_text }]
})
}
/// Build the full `generateContent` request body.
fn build_generate_content_request(
system_instruction: &Value,
contents: &[Value],
gemini_tools: &[GeminiFunctionDeclaration],
) -> Value {
let mut body = json!({
"system_instruction": system_instruction,
"contents": contents,
"generationConfig": {
"temperature": 0.2,
"maxOutputTokens": 65536,
}
});
if !gemini_tools.is_empty() {
body["tools"] = json!([{
"functionDeclarations": gemini_tools
}]);
}
body
}
/// Fetch MCP tool definitions from storkit's MCP server and convert
/// them to Gemini function declaration format.
async fn fetch_and_convert_mcp_tools(
client: &Client,
mcp_base: &str,
) -> Result<Vec<GeminiFunctionDeclaration>, String> {
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let response = client
.post(mcp_base)
.json(&request)
.send()
.await
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
let tools = body["result"]["tools"]
.as_array()
.ok_or_else(|| "No tools array in MCP response".to_string())?;
let mut declarations = Vec::new();
for tool in tools {
let name = tool["name"].as_str().unwrap_or("").to_string();
let description = tool["description"].as_str().unwrap_or("").to_string();
if name.is_empty() {
continue;
}
// Convert MCP inputSchema (JSON Schema) to Gemini parameters
// (OpenAPI-subset schema). They are structurally compatible for
// simple object schemas.
let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));
declarations.push(GeminiFunctionDeclaration {
name,
description,
parameters,
});
}
slog!("[gemini] Loaded {} MCP tools as function declarations", declarations.len());
Ok(declarations)
}
/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
/// OpenAPI-subset parameter schema.
///
/// Gemini function calling expects parameters in OpenAPI format, which
/// is structurally similar to JSON Schema for simple object types.
/// We strip unsupported fields and ensure the type is "object".
fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
let schema = schema?;
// If the schema has no properties (empty tool), return None.
let properties = schema.get("properties")?;
if properties.as_object().is_some_and(|p| p.is_empty()) {
return None;
}
let mut result = json!({
"type": "object",
"properties": clean_schema_properties(properties),
});
// Preserve required fields if present.
if let Some(required) = schema.get("required") {
result["required"] = required.clone();
}
Some(result)
}
/// Recursively clean schema properties to be Gemini-compatible.
/// Removes unsupported JSON Schema keywords.
fn clean_schema_properties(properties: &Value) -> Value {
let Some(obj) = properties.as_object() else {
return properties.clone();
};
let mut cleaned = serde_json::Map::new();
for (key, value) in obj {
let mut prop = value.clone();
// Remove JSON Schema keywords not supported by Gemini
if let Some(p) = prop.as_object_mut() {
p.remove("$schema");
p.remove("additionalProperties");
// Recursively clean nested object properties
if let Some(nested_props) = p.get("properties").cloned() {
p.insert(
"properties".to_string(),
clean_schema_properties(&nested_props),
);
}
// Clean items schema for arrays
if let Some(items) = p.get("items").cloned()
&& let Some(items_obj) = items.as_object()
{
let mut cleaned_items = items_obj.clone();
cleaned_items.remove("$schema");
cleaned_items.remove("additionalProperties");
p.insert("items".to_string(), Value::Object(cleaned_items));
}
}
cleaned.insert(key.clone(), prop);
}
Value::Object(cleaned)
}
/// Call an MCP tool via storkit's MCP server.
async fn call_mcp_tool(
client: &Client,
mcp_base: &str,
tool_name: &str,
args: &Value,
) -> Result<String, String> {
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": args
}
});
let response = client
.post(mcp_base)
.json(&request)
.send()
.await
.map_err(|e| format!("MCP tool call failed: {e}"))?;
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
if let Some(error) = body.get("error") {
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
}
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
let content = &body["result"]["content"];
if let Some(arr) = content.as_array() {
let texts: Vec<&str> = arr
.iter()
.filter_map(|c| c["text"].as_str())
.collect();
if !texts.is_empty() {
return Ok(texts.join("\n"));
}
}
// Fall back to serializing the entire result.
Ok(body["result"].to_string())
}
/// Parse token usage metadata from a Gemini API response.
fn parse_usage_metadata(response: &Value) -> Option<TokenUsage> {
let metadata = response.get("usageMetadata")?;
Some(TokenUsage {
input_tokens: metadata
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0),
output_tokens: metadata
.get("candidatesTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0),
// Gemini doesn't have cache token fields, but we keep the struct uniform.
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
// Google AI API doesn't report cost; leave at 0.
total_cost_usd: 0.0,
})
}
// ── Tests ────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert_mcp_schema_simple_object() {
let schema = json!({
"type": "object",
"properties": {
"story_id": {
"type": "string",
"description": "Story identifier"
}
},
"required": ["story_id"]
});
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
assert_eq!(result["type"], "object");
assert!(result["properties"]["story_id"].is_object());
assert_eq!(result["required"][0], "story_id");
}
#[test]
fn convert_mcp_schema_empty_properties_returns_none() {
let schema = json!({
"type": "object",
"properties": {}
});
assert!(convert_mcp_schema_to_gemini(Some(&schema)).is_none());
}
#[test]
fn convert_mcp_schema_none_returns_none() {
assert!(convert_mcp_schema_to_gemini(None).is_none());
}
#[test]
fn convert_mcp_schema_strips_additional_properties() {
let schema = json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
});
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
let name_prop = &result["properties"]["name"];
assert!(name_prop.get("additionalProperties").is_none());
assert!(name_prop.get("$schema").is_none());
assert_eq!(name_prop["type"], "string");
}
#[test]
fn convert_mcp_schema_with_nested_objects() {
let schema = json!({
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"key": { "type": "string" }
}
}
}
});
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
assert!(result["properties"]["config"]["properties"]["key"].is_object());
}
#[test]
fn convert_mcp_schema_with_array_items() {
let schema = json!({
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" }
},
"additionalProperties": false
}
}
}
});
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
let items_schema = &result["properties"]["items"]["items"];
assert!(items_schema.get("additionalProperties").is_none());
}
#[test]
fn build_system_instruction_uses_args() {
let ctx = RuntimeContext {
story_id: "42_story_test".to_string(),
agent_name: "coder-1".to_string(),
command: "gemini-2.5-pro".to_string(),
args: vec![
"--append-system-prompt".to_string(),
"Custom system prompt".to_string(),
],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
let instruction = build_system_instruction(&ctx);
assert_eq!(instruction["parts"][0]["text"], "Custom system prompt");
}
#[test]
fn build_system_instruction_default() {
let ctx = RuntimeContext {
story_id: "42_story_test".to_string(),
agent_name: "coder-1".to_string(),
command: "gemini-2.5-pro".to_string(),
args: vec![],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
let instruction = build_system_instruction(&ctx);
let text = instruction["parts"][0]["text"].as_str().unwrap();
assert!(text.contains("42_story_test"));
assert!(text.contains("/tmp/wt"));
}
#[test]
fn build_generate_content_request_includes_tools() {
let system = json!({"parts": [{"text": "system"}]});
let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
let tools = vec![GeminiFunctionDeclaration {
name: "my_tool".to_string(),
description: "A tool".to_string(),
parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})),
}];
let body = build_generate_content_request(&system, &contents, &tools);
assert!(body["tools"][0]["functionDeclarations"].is_array());
assert_eq!(body["tools"][0]["functionDeclarations"][0]["name"], "my_tool");
}
#[test]
fn build_generate_content_request_no_tools() {
let system = json!({"parts": [{"text": "system"}]});
let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
let tools: Vec<GeminiFunctionDeclaration> = vec![];
let body = build_generate_content_request(&system, &contents, &tools);
assert!(body.get("tools").is_none());
}
#[test]
fn parse_usage_metadata_valid() {
let response = json!({
"usageMetadata": {
"promptTokenCount": 100,
"candidatesTokenCount": 50,
"totalTokenCount": 150
}
});
let usage = parse_usage_metadata(&response).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.cache_creation_input_tokens, 0);
assert_eq!(usage.total_cost_usd, 0.0);
}
#[test]
fn parse_usage_metadata_missing() {
let response = json!({"candidates": []});
assert!(parse_usage_metadata(&response).is_none());
}
#[test]
fn gemini_runtime_stop_sets_cancelled() {
let runtime = GeminiRuntime::new();
assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
runtime.stop();
assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
}
#[test]
fn model_extraction_from_command() {
// When command starts with "gemini", use it as model name
let ctx = RuntimeContext {
story_id: "1".to_string(),
agent_name: "coder".to_string(),
command: "gemini-2.5-pro".to_string(),
args: vec![],
prompt: "test".to_string(),
cwd: "/tmp".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
// The model extraction logic is inside start(), but we test the
// condition here.
assert!(ctx.command.starts_with("gemini"));
}
}
+163
View File
@@ -0,0 +1,163 @@
mod claude_code;
mod gemini;
mod openai;
pub use claude_code::ClaudeCodeRuntime;
pub use gemini::GeminiRuntime;
pub use openai::OpenAiRuntime;
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
use crate::agent_log::AgentLogWriter;
use super::{AgentEvent, TokenUsage};
/// Context passed to a runtime when launching an agent session.
pub struct RuntimeContext {
pub story_id: String,
pub agent_name: String,
pub command: String,
pub args: Vec<String>,
pub prompt: String,
pub cwd: String,
pub inactivity_timeout_secs: u64,
/// Port of the storkit MCP server, used by API-based runtimes (Gemini, OpenAI)
/// to call back for tool execution.
pub mcp_port: u16,
}
/// Result returned by a runtime after the agent session completes.
pub struct RuntimeResult {
pub session_id: Option<String>,
pub token_usage: Option<TokenUsage>,
}
/// Runtime status reported by the backend.
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)]
pub enum RuntimeStatus {
Idle,
Running,
Completed,
Failed,
}
/// Abstraction over different agent execution backends.
///
/// Implementations:
/// - [`ClaudeCodeRuntime`]: spawns the `claude` CLI via a PTY (default, `runtime = "claude-code"`)
///
/// Future implementations could include OpenAI and Gemini API runtimes.
#[allow(dead_code)]
pub trait AgentRuntime: Send + Sync {
/// Start the agent and drive it to completion, streaming events through
/// the provided broadcast sender and event log.
///
/// Returns when the agent session finishes (success or error).
async fn start(
&self,
ctx: RuntimeContext,
tx: broadcast::Sender<AgentEvent>,
event_log: Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
) -> Result<RuntimeResult, String>;
/// Stop the running agent.
fn stop(&self);
/// Get the current runtime status.
fn get_status(&self) -> RuntimeStatus;
/// Return any events buffered outside the broadcast channel.
///
/// PTY-based runtimes stream directly to the broadcast channel; this
/// returns empty by default. API-based runtimes may buffer events here.
fn stream_events(&self) -> Vec<AgentEvent> {
vec![]
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn runtime_context_fields() {
let ctx = RuntimeContext {
story_id: "42_story_foo".to_string(),
agent_name: "coder-1".to_string(),
command: "claude".to_string(),
args: vec!["--model".to_string(), "sonnet".to_string()],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert_eq!(ctx.story_id, "42_story_foo");
assert_eq!(ctx.agent_name, "coder-1");
assert_eq!(ctx.command, "claude");
assert_eq!(ctx.args.len(), 2);
assert_eq!(ctx.prompt, "Do the thing");
assert_eq!(ctx.cwd, "/tmp/wt");
assert_eq!(ctx.inactivity_timeout_secs, 300);
assert_eq!(ctx.mcp_port, 3001);
}
#[test]
fn runtime_result_fields() {
let result = RuntimeResult {
session_id: Some("sess-123".to_string()),
token_usage: Some(TokenUsage {
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
total_cost_usd: 0.01,
}),
};
assert_eq!(result.session_id, Some("sess-123".to_string()));
assert!(result.token_usage.is_some());
let usage = result.token_usage.unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.total_cost_usd, 0.01);
}
#[test]
fn runtime_result_no_usage() {
let result = RuntimeResult {
session_id: None,
token_usage: None,
};
assert!(result.session_id.is_none());
assert!(result.token_usage.is_none());
}
#[test]
fn runtime_status_variants() {
assert_eq!(RuntimeStatus::Idle, RuntimeStatus::Idle);
assert_ne!(RuntimeStatus::Running, RuntimeStatus::Completed);
assert_ne!(RuntimeStatus::Failed, RuntimeStatus::Idle);
}
#[test]
fn claude_code_runtime_get_status_returns_idle() {
use std::collections::HashMap;
use crate::io::watcher::WatcherEvent;
let killers = Arc::new(Mutex::new(HashMap::new()));
let (watcher_tx, _) = broadcast::channel::<WatcherEvent>(16);
let runtime = ClaudeCodeRuntime::new(killers, watcher_tx);
assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
}
#[test]
fn claude_code_runtime_stream_events_empty() {
use std::collections::HashMap;
use crate::io::watcher::WatcherEvent;
let killers = Arc::new(Mutex::new(HashMap::new()));
let (watcher_tx, _) = broadcast::channel::<WatcherEvent>(16);
let runtime = ClaudeCodeRuntime::new(killers, watcher_tx);
assert!(runtime.stream_events().is_empty());
}
}
+704
View File
@@ -0,0 +1,704 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use reqwest::Client;
use serde_json::{json, Value};
use tokio::sync::broadcast;
use crate::agent_log::AgentLogWriter;
use crate::slog;
use super::super::{AgentEvent, TokenUsage};
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
// ── Public runtime struct ────────────────────────────────────────────
/// Agent runtime that drives an OpenAI model (GPT-4o, o3, etc.) through
/// the OpenAI Chat Completions API.
///
/// The runtime:
/// 1. Fetches MCP tool definitions from storkit's MCP server.
/// 2. Converts them to OpenAI function-calling format.
/// 3. Sends the agent prompt + tools to the Chat Completions API.
/// 4. Executes any requested tool calls via MCP `tools/call`.
/// 5. Loops until the model produces a response with no tool calls.
/// 6. Tracks token usage from the API response.
pub struct OpenAiRuntime {
/// Whether a stop has been requested.
cancelled: Arc<AtomicBool>,
}
impl OpenAiRuntime {
pub fn new() -> Self {
Self {
cancelled: Arc::new(AtomicBool::new(false)),
}
}
}
impl AgentRuntime for OpenAiRuntime {
async fn start(
&self,
ctx: RuntimeContext,
tx: broadcast::Sender<AgentEvent>,
event_log: Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
) -> Result<RuntimeResult, String> {
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
"OPENAI_API_KEY environment variable is not set. \
Set it to your OpenAI API key to use the OpenAI runtime."
.to_string()
})?;
let model = if ctx.command.starts_with("gpt") || ctx.command.starts_with("o") {
// The pool puts the model into `command` for non-CLI runtimes.
ctx.command.clone()
} else {
// Fall back to args: look for --model <value>
ctx.args
.iter()
.position(|a| a == "--model")
.and_then(|i| ctx.args.get(i + 1))
.cloned()
.unwrap_or_else(|| "gpt-4o".to_string())
};
let mcp_port = ctx.mcp_port;
let mcp_base = format!("http://localhost:{mcp_port}/mcp");
let client = Client::new();
let cancelled = Arc::clone(&self.cancelled);
// Step 1: Fetch MCP tool definitions and convert to OpenAI format.
let openai_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
// Step 2: Build the initial conversation messages.
let system_text = build_system_text(&ctx);
let mut messages: Vec<Value> = vec![
json!({ "role": "system", "content": system_text }),
json!({ "role": "user", "content": ctx.prompt }),
];
let mut total_usage = TokenUsage {
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
total_cost_usd: 0.0,
};
let emit = |event: AgentEvent| {
super::super::pty::emit_event(
event,
&tx,
&event_log,
log_writer.as_ref().map(|w| w.as_ref()),
);
};
emit(AgentEvent::Status {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
status: "running".to_string(),
});
// Step 3: Conversation loop.
let mut turn = 0u32;
let max_turns = 200; // Safety limit
loop {
if cancelled.load(Ordering::Relaxed) {
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: "Agent was stopped by user".to_string(),
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
turn += 1;
if turn > max_turns {
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: format!("Exceeded maximum turns ({max_turns})"),
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
slog!(
"[openai] Turn {turn} for {}:{}",
ctx.story_id,
ctx.agent_name
);
let mut request_body = json!({
"model": model,
"messages": messages,
"temperature": 0.2,
});
if !openai_tools.is_empty() {
request_body["tools"] = json!(openai_tools);
}
let response = client
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(&api_key)
.json(&request_body)
.send()
.await
.map_err(|e| format!("OpenAI API request failed: {e}"))?;
let status = response.status();
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse OpenAI API response: {e}"))?;
if !status.is_success() {
let error_msg = body["error"]["message"]
.as_str()
.unwrap_or("Unknown API error");
let err = format!("OpenAI API error ({status}): {error_msg}");
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: err.clone(),
});
return Err(err);
}
// Accumulate token usage.
if let Some(usage) = parse_usage(&body) {
total_usage.input_tokens += usage.input_tokens;
total_usage.output_tokens += usage.output_tokens;
}
// Extract the first choice.
let choice = body["choices"]
.as_array()
.and_then(|c| c.first())
.ok_or_else(|| "No choices in OpenAI response".to_string())?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or("");
// Emit any text content.
if !content.is_empty() {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: content.to_string(),
});
}
// Check for tool calls.
let tool_calls = message["tool_calls"].as_array();
if tool_calls.is_none() || tool_calls.is_some_and(|tc| tc.is_empty()) {
// No tool calls — model is done.
emit(AgentEvent::Done {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
session_id: None,
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
let tool_calls = tool_calls.unwrap();
// Add the assistant message (with tool_calls) to the conversation.
messages.push(message.clone());
// Execute each tool call via MCP and add results.
for tc in tool_calls {
if cancelled.load(Ordering::Relaxed) {
break;
}
let call_id = tc["id"].as_str().unwrap_or("");
let function = &tc["function"];
let tool_name = function["name"].as_str().unwrap_or("");
let arguments_str = function["arguments"].as_str().unwrap_or("{}");
let args: Value = serde_json::from_str(arguments_str).unwrap_or(json!({}));
slog!(
"[openai] Calling MCP tool '{}' for {}:{}",
tool_name,
ctx.story_id,
ctx.agent_name
);
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("\n[Tool call: {tool_name}]\n"),
});
let tool_result = call_mcp_tool(&client, &mcp_base, tool_name, &args).await;
let result_content = match &tool_result {
Ok(result) => {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("[Tool result: {} chars]\n", result.len()),
});
result.clone()
}
Err(e) => {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("[Tool error: {e}]\n"),
});
format!("Error: {e}")
}
};
// OpenAI expects tool results as role=tool messages with
// the matching tool_call_id.
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": result_content,
}));
}
}
}
fn stop(&self) {
self.cancelled.store(true, Ordering::Relaxed);
}
fn get_status(&self) -> RuntimeStatus {
if self.cancelled.load(Ordering::Relaxed) {
RuntimeStatus::Failed
} else {
RuntimeStatus::Idle
}
}
}
// ── Helper functions ─────────────────────────────────────────────────
/// Build the system message text from the RuntimeContext.
fn build_system_text(ctx: &RuntimeContext) -> String {
ctx.args
.iter()
.position(|a| a == "--append-system-prompt")
.and_then(|i| ctx.args.get(i + 1))
.cloned()
.unwrap_or_else(|| {
format!(
"You are an AI coding agent working on story {}. \
You have access to tools via function calling. \
Use them to complete the task. \
Work in the directory: {}",
ctx.story_id, ctx.cwd
)
})
}
/// Fetch MCP tool definitions from storkit's MCP server and convert
/// them to OpenAI function-calling format.
async fn fetch_and_convert_mcp_tools(
client: &Client,
mcp_base: &str,
) -> Result<Vec<Value>, String> {
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let response = client
.post(mcp_base)
.json(&request)
.send()
.await
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
let tools = body["result"]["tools"]
.as_array()
.ok_or_else(|| "No tools array in MCP response".to_string())?;
let mut openai_tools = Vec::new();
for tool in tools {
let name = tool["name"].as_str().unwrap_or("").to_string();
let description = tool["description"].as_str().unwrap_or("").to_string();
if name.is_empty() {
continue;
}
// OpenAI function calling uses JSON Schema natively for parameters,
// so the MCP inputSchema can be used with minimal cleanup.
let parameters = convert_mcp_schema_to_openai(tool.get("inputSchema"));
openai_tools.push(json!({
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters.unwrap_or_else(|| json!({"type": "object", "properties": {}})),
}
}));
}
slog!(
"[openai] Loaded {} MCP tools as function definitions",
openai_tools.len()
);
Ok(openai_tools)
}
/// Convert an MCP inputSchema (JSON Schema) to OpenAI-compatible
/// function parameters.
///
/// OpenAI uses JSON Schema natively, so less transformation is needed
/// compared to Gemini. We still strip `$schema` to keep payloads clean.
fn convert_mcp_schema_to_openai(schema: Option<&Value>) -> Option<Value> {
let schema = schema?;
let mut result = json!({
"type": "object",
});
if let Some(properties) = schema.get("properties") {
result["properties"] = clean_schema_properties(properties);
} else {
result["properties"] = json!({});
}
if let Some(required) = schema.get("required") {
result["required"] = required.clone();
}
// OpenAI recommends additionalProperties: false for strict mode.
result["additionalProperties"] = json!(false);
Some(result)
}
/// Recursively clean schema properties, removing unsupported keywords.
fn clean_schema_properties(properties: &Value) -> Value {
let Some(obj) = properties.as_object() else {
return properties.clone();
};
let mut cleaned = serde_json::Map::new();
for (key, value) in obj {
let mut prop = value.clone();
if let Some(p) = prop.as_object_mut() {
p.remove("$schema");
// Recursively clean nested object properties.
if let Some(nested_props) = p.get("properties").cloned() {
p.insert(
"properties".to_string(),
clean_schema_properties(&nested_props),
);
}
// Clean items schema for arrays.
if let Some(items) = p.get("items").cloned()
&& let Some(items_obj) = items.as_object()
{
let mut cleaned_items = items_obj.clone();
cleaned_items.remove("$schema");
p.insert("items".to_string(), Value::Object(cleaned_items));
}
}
cleaned.insert(key.clone(), prop);
}
Value::Object(cleaned)
}
/// Call an MCP tool via storkit's MCP server.
async fn call_mcp_tool(
client: &Client,
mcp_base: &str,
tool_name: &str,
args: &Value,
) -> Result<String, String> {
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": args
}
});
let response = client
.post(mcp_base)
.json(&request)
.send()
.await
.map_err(|e| format!("MCP tool call failed: {e}"))?;
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
if let Some(error) = body.get("error") {
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
}
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
let content = &body["result"]["content"];
if let Some(arr) = content.as_array() {
let texts: Vec<&str> = arr
.iter()
.filter_map(|c| c["text"].as_str())
.collect();
if !texts.is_empty() {
return Ok(texts.join("\n"));
}
}
// Fall back to serializing the entire result.
Ok(body["result"].to_string())
}
/// Parse token usage from an OpenAI API response.
fn parse_usage(response: &Value) -> Option<TokenUsage> {
let usage = response.get("usage")?;
Some(TokenUsage {
input_tokens: usage
.get("prompt_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
output_tokens: usage
.get("completion_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
// OpenAI API doesn't report cost directly; leave at 0.
total_cost_usd: 0.0,
})
}
// ── Tests ────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert_mcp_schema_simple_object() {
let schema = json!({
"type": "object",
"properties": {
"story_id": {
"type": "string",
"description": "Story identifier"
}
},
"required": ["story_id"]
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
assert_eq!(result["type"], "object");
assert!(result["properties"]["story_id"].is_object());
assert_eq!(result["required"][0], "story_id");
assert_eq!(result["additionalProperties"], false);
}
#[test]
fn convert_mcp_schema_empty_properties() {
let schema = json!({
"type": "object",
"properties": {}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
assert_eq!(result["type"], "object");
assert!(result["properties"].as_object().unwrap().is_empty());
}
#[test]
fn convert_mcp_schema_none_returns_none() {
assert!(convert_mcp_schema_to_openai(None).is_none());
}
#[test]
fn convert_mcp_schema_strips_dollar_schema() {
let schema = json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
let name_prop = &result["properties"]["name"];
assert!(name_prop.get("$schema").is_none());
assert_eq!(name_prop["type"], "string");
}
#[test]
fn convert_mcp_schema_with_nested_objects() {
let schema = json!({
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"key": { "type": "string" }
}
}
}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
assert!(result["properties"]["config"]["properties"]["key"].is_object());
}
#[test]
fn convert_mcp_schema_with_array_items() {
let schema = json!({
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" }
},
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
let items_schema = &result["properties"]["items"]["items"];
assert!(items_schema.get("$schema").is_none());
}
#[test]
fn build_system_text_uses_args() {
let ctx = RuntimeContext {
story_id: "42_story_test".to_string(),
agent_name: "coder-1".to_string(),
command: "gpt-4o".to_string(),
args: vec![
"--append-system-prompt".to_string(),
"Custom system prompt".to_string(),
],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert_eq!(build_system_text(&ctx), "Custom system prompt");
}
#[test]
fn build_system_text_default() {
let ctx = RuntimeContext {
story_id: "42_story_test".to_string(),
agent_name: "coder-1".to_string(),
command: "gpt-4o".to_string(),
args: vec![],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
let text = build_system_text(&ctx);
assert!(text.contains("42_story_test"));
assert!(text.contains("/tmp/wt"));
}
#[test]
fn parse_usage_valid() {
let response = json!({
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150
}
});
let usage = parse_usage(&response).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.cache_creation_input_tokens, 0);
assert_eq!(usage.total_cost_usd, 0.0);
}
#[test]
fn parse_usage_missing() {
let response = json!({"choices": []});
assert!(parse_usage(&response).is_none());
}
#[test]
fn openai_runtime_stop_sets_cancelled() {
let runtime = OpenAiRuntime::new();
assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
runtime.stop();
assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
}
#[test]
fn model_extraction_from_command_gpt() {
let ctx = RuntimeContext {
story_id: "1".to_string(),
agent_name: "coder".to_string(),
command: "gpt-4o".to_string(),
args: vec![],
prompt: "test".to_string(),
cwd: "/tmp".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert!(ctx.command.starts_with("gpt"));
}
#[test]
fn model_extraction_from_command_o3() {
let ctx = RuntimeContext {
story_id: "1".to_string(),
agent_name: "coder".to_string(),
command: "o3".to_string(),
args: vec![],
prompt: "test".to_string(),
cwd: "/tmp".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert!(ctx.command.starts_with("o"));
}
}
+202
View File
@@ -0,0 +1,202 @@
use std::fs;
use std::path::Path;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use super::TokenUsage;
/// A single token usage record persisted to disk.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsageRecord {
pub story_id: String,
pub agent_name: String,
pub timestamp: String,
#[serde(default)]
pub model: Option<String>,
pub usage: TokenUsage,
}
/// Append a token usage record to the persistent JSONL file.
///
/// Each line is a self-contained JSON object, making appends atomic and
/// reads simple. The file lives at `.storkit/token_usage.jsonl`.
pub fn append_record(project_root: &Path, record: &TokenUsageRecord) -> Result<(), String> {
let path = token_usage_path(project_root);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create token_usage directory: {e}"))?;
}
let mut line =
serde_json::to_string(record).map_err(|e| format!("Failed to serialize record: {e}"))?;
line.push('\n');
use std::io::Write;
let file = fs::OpenOptions::new()
.create(true)
.append(true)
.open(&path)
.map_err(|e| format!("Failed to open token_usage file: {e}"))?;
let mut writer = std::io::BufWriter::new(file);
writer
.write_all(line.as_bytes())
.map_err(|e| format!("Failed to write token_usage record: {e}"))?;
writer
.flush()
.map_err(|e| format!("Failed to flush token_usage file: {e}"))?;
Ok(())
}
/// Read all token usage records from the persistent file.
pub fn read_all(project_root: &Path) -> Result<Vec<TokenUsageRecord>, String> {
let path = token_usage_path(project_root);
if !path.exists() {
return Ok(Vec::new());
}
let content =
fs::read_to_string(&path).map_err(|e| format!("Failed to read token_usage file: {e}"))?;
let mut records = Vec::new();
for line in content.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
match serde_json::from_str::<TokenUsageRecord>(trimmed) {
Ok(record) => records.push(record),
Err(e) => {
crate::slog_warn!("[token_usage] Skipping malformed line: {e}");
}
}
}
Ok(records)
}
/// Build a `TokenUsageRecord` from the parts available at completion time.
pub fn build_record(
story_id: &str,
agent_name: &str,
model: Option<String>,
usage: TokenUsage,
) -> TokenUsageRecord {
TokenUsageRecord {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
timestamp: Utc::now().to_rfc3339(),
model,
usage,
}
}
fn token_usage_path(project_root: &Path) -> std::path::PathBuf {
project_root.join(".storkit").join("token_usage.jsonl")
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn sample_usage() -> TokenUsage {
TokenUsage {
input_tokens: 100,
output_tokens: 200,
cache_creation_input_tokens: 5000,
cache_read_input_tokens: 10000,
total_cost_usd: 1.57,
}
}
#[test]
fn append_and_read_roundtrip() {
let dir = TempDir::new().unwrap();
let root = dir.path();
let record = build_record("42_story_foo", "coder-1", None, sample_usage());
append_record(root, &record).unwrap();
let records = read_all(root).unwrap();
assert_eq!(records.len(), 1);
assert_eq!(records[0].story_id, "42_story_foo");
assert_eq!(records[0].agent_name, "coder-1");
assert_eq!(records[0].usage, sample_usage());
}
#[test]
fn multiple_appends_accumulate() {
let dir = TempDir::new().unwrap();
let root = dir.path();
let r1 = build_record("s1", "coder-1", None, sample_usage());
let r2 = build_record("s2", "coder-2", None, sample_usage());
append_record(root, &r1).unwrap();
append_record(root, &r2).unwrap();
let records = read_all(root).unwrap();
assert_eq!(records.len(), 2);
assert_eq!(records[0].story_id, "s1");
assert_eq!(records[1].story_id, "s2");
}
#[test]
fn read_empty_returns_empty() {
let dir = TempDir::new().unwrap();
let records = read_all(dir.path()).unwrap();
assert!(records.is_empty());
}
#[test]
fn malformed_lines_are_skipped() {
let dir = TempDir::new().unwrap();
let root = dir.path();
let path = root.join(".storkit").join("token_usage.jsonl");
fs::create_dir_all(path.parent().unwrap()).unwrap();
fs::write(&path, "not json\n{\"bad\":true}\n").unwrap();
let records = read_all(root).unwrap();
assert!(records.is_empty());
}
#[test]
fn token_usage_from_result_event() {
let json = serde_json::json!({
"type": "result",
"total_cost_usd": 1.57,
"usage": {
"input_tokens": 7,
"output_tokens": 475,
"cache_creation_input_tokens": 185020,
"cache_read_input_tokens": 810585
}
});
let usage = TokenUsage::from_result_event(&json).unwrap();
assert_eq!(usage.input_tokens, 7);
assert_eq!(usage.output_tokens, 475);
assert_eq!(usage.cache_creation_input_tokens, 185020);
assert_eq!(usage.cache_read_input_tokens, 810585);
assert!((usage.total_cost_usd - 1.57).abs() < f64::EPSILON);
}
#[test]
fn token_usage_from_result_event_missing_usage() {
let json = serde_json::json!({"type": "result"});
assert!(TokenUsage::from_result_event(&json).is_none());
}
#[test]
fn token_usage_from_result_event_partial_fields() {
let json = serde_json::json!({
"type": "result",
"total_cost_usd": 0.5,
"usage": {
"input_tokens": 10,
"output_tokens": 20
}
});
let usage = TokenUsage::from_result_event(&json).unwrap();
assert_eq!(usage.input_tokens, 10);
assert_eq!(usage.output_tokens, 20);
assert_eq!(usage.cache_creation_input_tokens, 0);
assert_eq!(usage.cache_read_input_tokens, 0);
}
}