storkit: create 365_story_surface_api_rate_limit_warnings_in_chat
This commit is contained in:
@@ -1,426 +0,0 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use std::time::Duration;
|
||||
use wait_timeout::ChildExt;
|
||||
|
||||
/// Maximum time any single test command is allowed to run before being killed.
/// Enforced by `run_command_with_timeout`, which all test/gate runners go through.
const TEST_TIMEOUT: Duration = Duration::from_secs(600); // 10 minutes
|
||||
|
||||
/// Detect whether the base branch in a worktree is `master` or `main`.
///
/// Each candidate is probed with `git rev-parse --verify`; a branch counts as
/// present only when git runs and exits successfully. `master` wins when both
/// exist, and `"master"` is also the fallback when neither can be verified
/// (including when git itself cannot be spawned).
pub(crate) fn detect_worktree_base_branch(wt_path: &Path) -> String {
    let branch_exists = |branch: &str| {
        Command::new("git")
            .args(["rev-parse", "--verify", branch])
            .current_dir(wt_path)
            .output()
            .map(|out| out.status.success())
            .unwrap_or(false)
    };

    ["master", "main"]
        .into_iter()
        .find(|&candidate| branch_exists(candidate))
        .unwrap_or("master")
        .to_string()
}
|
||||
|
||||
/// Return `true` if the git worktree at `wt_path` has commits on its current
|
||||
/// branch that are not present on the base branch (`master` or `main`).
|
||||
///
|
||||
/// Used during server startup reconciliation to detect stories whose agent work
|
||||
/// was committed while the server was offline.
|
||||
pub(crate) fn worktree_has_committed_work(wt_path: &Path) -> bool {
|
||||
let base_branch = detect_worktree_base_branch(wt_path);
|
||||
let output = Command::new("git")
|
||||
.args(["log", &format!("{base_branch}..HEAD"), "--oneline"])
|
||||
.current_dir(wt_path)
|
||||
.output();
|
||||
match output {
|
||||
Ok(out) if out.status.success() => {
|
||||
!String::from_utf8_lossy(&out.stdout).trim().is_empty()
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether the given directory has any uncommitted git changes.
/// Returns `Err` with a descriptive message if there are any.
///
/// Cleanliness is judged by `git status --porcelain`: any non-blank output
/// means uncommitted work. A failure to run git at all is also an `Err`.
pub(crate) fn check_uncommitted_changes(path: &Path) -> Result<(), String> {
    let output = Command::new("git")
        .args(["status", "--porcelain"])
        .current_dir(path)
        .output()
        .map_err(|e| format!("Failed to run git status: {e}"))?;

    let stdout = String::from_utf8_lossy(&output.stdout);
    match stdout.trim() {
        "" => Ok(()),
        _ => Err(format!(
            "Worktree has uncommitted changes. Please commit all work before \
             the agent exits:\n{stdout}"
        )),
    }
}
|
||||
|
||||
/// Run the project's test suite.
|
||||
///
|
||||
/// Uses `script/test` if present, treating it as the canonical single test entry point.
|
||||
/// Falls back to `cargo nextest run` / `cargo test` when `script/test` is absent.
|
||||
/// Returns `(tests_passed, output)`.
|
||||
pub(crate) fn run_project_tests(path: &Path) -> Result<(bool, String), String> {
|
||||
let script_test = path.join("script").join("test");
|
||||
if script_test.exists() {
|
||||
let mut output = String::from("=== script/test ===\n");
|
||||
let (success, out) = run_command_with_timeout(&script_test, &[], path)?;
|
||||
output.push_str(&out);
|
||||
output.push('\n');
|
||||
return Ok((success, output));
|
||||
}
|
||||
|
||||
// Fallback: cargo nextest run / cargo test
|
||||
let mut output = String::from("=== tests ===\n");
|
||||
let (success, test_out) = match run_command_with_timeout("cargo", &["nextest", "run"], path) {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
// nextest not available — fall back to cargo test
|
||||
run_command_with_timeout("cargo", &["test"], path)
|
||||
.map_err(|e| format!("Failed to run cargo test: {e}"))?
|
||||
}
|
||||
};
|
||||
output.push_str(&test_out);
|
||||
output.push('\n');
|
||||
Ok((success, output))
|
||||
}
|
||||
|
||||
/// Run a command with a timeout. Returns `(success, combined_output)`.
/// Kills the child process if it exceeds `TEST_TIMEOUT`.
///
/// Stdout and stderr are drained in background threads to avoid a pipe-buffer
/// deadlock: if the child fills the 64 KB OS pipe buffer while the parent
/// blocks on `waitpid`, neither side can make progress.
///
/// * `program` — executable to run (path, or a name resolved via `PATH`).
/// * `args` — arguments passed verbatim to the program.
/// * `dir` — working directory for the child process.
///
/// Errors when the process cannot be spawned, times out, or cannot be waited on.
fn run_command_with_timeout(
    program: impl AsRef<std::ffi::OsStr>,
    args: &[&str],
    dir: &Path,
) -> Result<(bool, String), String> {
    // Both streams are piped so they can be folded into one output string.
    let mut child = Command::new(program)
        .args(args)
        .current_dir(dir)
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| format!("Failed to spawn command: {e}"))?;

    // Drain stdout/stderr in background threads so the pipe buffers never fill.
    // Read errors are deliberately ignored (`.ok()`); each thread returns
    // whatever it managed to read (possibly an empty string).
    let stdout_handle = child.stdout.take().map(|r| {
        std::thread::spawn(move || {
            let mut s = String::new();
            let mut r = r;
            std::io::Read::read_to_string(&mut r, &mut s).ok();
            s
        })
    });
    let stderr_handle = child.stderr.take().map(|r| {
        std::thread::spawn(move || {
            let mut s = String::new();
            let mut r = r;
            std::io::Read::read_to_string(&mut r, &mut s).ok();
            s
        })
    });

    match child.wait_timeout(TEST_TIMEOUT) {
        // Child exited within the deadline: join the reader threads and
        // report its combined output, stdout first then stderr.
        Ok(Some(status)) => {
            let stdout = stdout_handle
                .and_then(|h| h.join().ok())
                .unwrap_or_default();
            let stderr = stderr_handle
                .and_then(|h| h.join().ok())
                .unwrap_or_default();
            Ok((status.success(), format!("{stdout}{stderr}")))
        }
        Ok(None) => {
            // Timed out — kill the child.
            // The follow-up `wait()` reaps the killed process; both results
            // are ignored since the timeout itself is the reported error.
            let _ = child.kill();
            let _ = child.wait();
            Err(format!(
                "Command timed out after {} seconds",
                TEST_TIMEOUT.as_secs()
            ))
        }
        Err(e) => Err(format!("Failed to wait for command: {e}")),
    }
}
|
||||
|
||||
/// Run `cargo clippy` and the project test suite (via `script/test` if present,
|
||||
/// otherwise `cargo nextest run` / `cargo test`) in the given directory.
|
||||
/// Returns `(gates_passed, combined_output)`.
|
||||
pub(crate) fn run_acceptance_gates(path: &Path) -> Result<(bool, String), String> {
|
||||
let mut all_output = String::new();
|
||||
let mut all_passed = true;
|
||||
|
||||
// ── cargo clippy ──────────────────────────────────────────────
|
||||
let clippy = Command::new("cargo")
|
||||
.args(["clippy", "--all-targets", "--all-features"])
|
||||
.current_dir(path)
|
||||
.output()
|
||||
.map_err(|e| format!("Failed to run cargo clippy: {e}"))?;
|
||||
|
||||
all_output.push_str("=== cargo clippy ===\n");
|
||||
let clippy_stdout = String::from_utf8_lossy(&clippy.stdout);
|
||||
let clippy_stderr = String::from_utf8_lossy(&clippy.stderr);
|
||||
if !clippy_stdout.is_empty() {
|
||||
all_output.push_str(&clippy_stdout);
|
||||
}
|
||||
if !clippy_stderr.is_empty() {
|
||||
all_output.push_str(&clippy_stderr);
|
||||
}
|
||||
all_output.push('\n');
|
||||
|
||||
if !clippy.status.success() {
|
||||
all_passed = false;
|
||||
}
|
||||
|
||||
// ── tests (script/test if available, else cargo nextest/test) ─
|
||||
let (test_success, test_out) = run_project_tests(path)?;
|
||||
all_output.push_str(&test_out);
|
||||
if !test_success {
|
||||
all_passed = false;
|
||||
}
|
||||
|
||||
Ok((all_passed, all_output))
|
||||
}
|
||||
|
||||
/// Run `script/test_coverage` in the given directory if the script exists.
///
/// Used as a QA gate before advancing a story from `3_qa/` to `4_merge/`.
/// Returns `(passed, output)`. If the script does not exist, returns `(true, …)`.
pub(crate) fn run_coverage_gate(path: &Path) -> Result<(bool, String), String> {
    let script = path.join("script").join("test_coverage");

    // No script → the gate is vacuously satisfied.
    if !script.exists() {
        return Ok((
            true,
            "script/test_coverage not found; coverage gate skipped.\n".to_string(),
        ));
    }

    let result = Command::new(&script)
        .current_dir(path)
        .output()
        .map_err(|e| format!("Failed to run script/test_coverage: {e}"))?;

    // Report stdout followed by stderr under one header.
    let mut output = String::from("=== script/test_coverage ===\n");
    output.push_str(&String::from_utf8_lossy(&result.stdout));
    output.push_str(&String::from_utf8_lossy(&result.stderr));
    output.push('\n');

    Ok((result.status.success(), output))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Initialize a throwaway git repo at `repo` with a test identity and one
    /// empty initial commit, so HEAD and the default branch both exist.
    /// NOTE(review): `git init` output is unwrapped but its exit status is not
    /// checked; assumes git is installed in the test environment.
    fn init_git_repo(repo: &std::path::Path) {
        Command::new("git")
            .args(["init"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["config", "user.email", "test@test.com"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["config", "user.name", "Test"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["commit", "--allow-empty", "-m", "init"])
            .current_dir(repo)
            .output()
            .unwrap();
    }

    // ── run_project_tests tests ───────────────────────────────────

    #[cfg(unix)]
    #[test]
    fn run_project_tests_uses_script_test_when_present_and_passes() {
        use std::fs;
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path();
        let script_dir = path.join("script");
        fs::create_dir_all(&script_dir).unwrap();
        let script_test = script_dir.join("test");
        // Script must be executable (0o755) or Command::new would fail to spawn it.
        fs::write(&script_test, "#!/usr/bin/env bash\necho 'all tests passed'\nexit 0\n").unwrap();
        let mut perms = fs::metadata(&script_test).unwrap().permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&script_test, perms).unwrap();

        let (passed, output) = run_project_tests(path).unwrap();
        assert!(passed, "script/test exiting 0 should pass");
        assert!(output.contains("script/test"), "output should mention script/test");
    }

    #[cfg(unix)]
    #[test]
    fn run_project_tests_reports_failure_when_script_test_exits_nonzero() {
        use std::fs;
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path();
        let script_dir = path.join("script");
        fs::create_dir_all(&script_dir).unwrap();
        let script_test = script_dir.join("test");
        fs::write(&script_test, "#!/usr/bin/env bash\nexit 1\n").unwrap();
        let mut perms = fs::metadata(&script_test).unwrap().permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&script_test, perms).unwrap();

        let (passed, output) = run_project_tests(path).unwrap();
        assert!(!passed, "script/test exiting 1 should fail");
        assert!(output.contains("script/test"), "output should mention script/test");
    }

    // ── run_coverage_gate tests ───────────────────────────────────────────────

    #[cfg(unix)]
    #[test]
    fn coverage_gate_passes_when_script_absent() {
        use tempfile::tempdir;
        let tmp = tempdir().unwrap();
        let (passed, output) = run_coverage_gate(tmp.path()).unwrap();
        assert!(passed, "coverage gate should pass when script is absent");
        assert!(
            output.contains("not found"),
            "output should mention script not found"
        );
    }

    #[cfg(unix)]
    #[test]
    fn coverage_gate_passes_when_script_exits_zero() {
        use std::fs;
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path();
        let script_dir = path.join("script");
        fs::create_dir_all(&script_dir).unwrap();
        let script = script_dir.join("test_coverage");
        fs::write(
            &script,
            "#!/usr/bin/env bash\necho 'Rust line coverage: 85%'\necho 'PASS: Coverage 85% meets threshold 0%'\nexit 0\n",
        )
        .unwrap();
        let mut perms = fs::metadata(&script).unwrap().permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&script, perms).unwrap();

        let (passed, output) = run_coverage_gate(path).unwrap();
        assert!(passed, "coverage gate should pass when script exits 0");
        assert!(
            output.contains("script/test_coverage"),
            "output should mention script/test_coverage"
        );
    }

    #[cfg(unix)]
    #[test]
    fn coverage_gate_fails_when_script_exits_nonzero() {
        use std::fs;
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path();
        let script_dir = path.join("script");
        fs::create_dir_all(&script_dir).unwrap();
        let script = script_dir.join("test_coverage");
        fs::write(
            &script,
            "#!/usr/bin/env bash\necho 'FAIL: Coverage 40% is below threshold 80%'\nexit 1\n",
        )
        .unwrap();
        let mut perms = fs::metadata(&script).unwrap().permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&script, perms).unwrap();

        let (passed, output) = run_coverage_gate(path).unwrap();
        assert!(!passed, "coverage gate should fail when script exits 1");
        assert!(
            output.contains("script/test_coverage"),
            "output should mention script/test_coverage"
        );
    }

    // ── worktree_has_committed_work tests ─────────────────────────────────────

    #[test]
    fn worktree_has_committed_work_false_on_fresh_repo() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path();
        // init_git_repo creates the initial commit on the default branch.
        // HEAD IS the base branch — no commits ahead.
        init_git_repo(repo);
        assert!(!worktree_has_committed_work(repo));
    }

    #[test]
    fn worktree_has_committed_work_true_after_commit_on_feature_branch() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let project_root = tmp.path().join("project");
        fs::create_dir_all(&project_root).unwrap();
        init_git_repo(&project_root);

        // Create a git worktree on a feature branch.
        let wt_path = tmp.path().join("wt");
        Command::new("git")
            .args([
                "worktree",
                "add",
                &wt_path.to_string_lossy(),
                "-b",
                "feature/story-99_test",
            ])
            .current_dir(&project_root)
            .output()
            .unwrap();

        // No commits on the feature branch yet — same as base branch.
        assert!(!worktree_has_committed_work(&wt_path));

        // Add a commit to the feature branch in the worktree.
        // Identity is supplied inline via `-c` because the worktree may not
        // inherit the repo-local config set by init_git_repo.
        fs::write(wt_path.join("work.txt"), "done").unwrap();
        Command::new("git")
            .args(["add", "."])
            .current_dir(&wt_path)
            .output()
            .unwrap();
        Command::new("git")
            .args([
                "-c",
                "user.email=test@test.com",
                "-c",
                "user.name=Test",
                "commit",
                "-m",
                "coder: implement story",
            ])
            .current_dir(&wt_path)
            .output()
            .unwrap();

        // Now the feature branch is ahead of the base branch.
        assert!(worktree_has_committed_work(&wt_path));
    }
}
|
||||
@@ -1,829 +0,0 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use crate::io::story_metadata::{clear_front_matter_field, write_rejection_notes};
|
||||
use crate::slog;
|
||||
|
||||
pub(super) fn item_type_from_id(item_id: &str) -> &'static str {
|
||||
// New format: {digits}_{type}_{slug}
|
||||
let after_num = item_id.trim_start_matches(|c: char| c.is_ascii_digit());
|
||||
if after_num.starts_with("_bug_") {
|
||||
"bug"
|
||||
} else if after_num.starts_with("_spike_") {
|
||||
"spike"
|
||||
} else {
|
||||
"story"
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the source directory path for a work item (always work/1_backlog/).
fn item_source_dir(project_root: &Path, _item_id: &str) -> PathBuf {
    let mut dir = project_root.to_path_buf();
    dir.push(".storkit");
    dir.push("work");
    dir.push("1_backlog");
    dir
}
|
||||
|
||||
/// Return the done directory path for a work item (always work/5_done/).
fn item_archive_dir(project_root: &Path, _item_id: &str) -> PathBuf {
    let mut dir = project_root.to_path_buf();
    dir.push(".storkit");
    dir.push("work");
    dir.push("5_done");
    dir
}
|
||||
|
||||
/// Move a work item (story, bug, or spike) from `work/1_backlog/` to `work/2_current/`.
|
||||
///
|
||||
/// Idempotent: if the item is already in `2_current/`, returns Ok without committing.
|
||||
/// If the item is not found in `1_backlog/`, logs a warning and returns Ok.
|
||||
pub fn move_story_to_current(project_root: &Path, story_id: &str) -> Result<(), String> {
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let current_dir = sk.join("2_current");
|
||||
let current_path = current_dir.join(format!("{story_id}.md"));
|
||||
|
||||
if current_path.exists() {
|
||||
// Already in 2_current/ — idempotent, nothing to do.
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let source_dir = item_source_dir(project_root, story_id);
|
||||
let source_path = source_dir.join(format!("{story_id}.md"));
|
||||
|
||||
if !source_path.exists() {
|
||||
slog!(
|
||||
"[lifecycle] Work item '{story_id}' not found in {}; skipping move to 2_current/",
|
||||
source_dir.display()
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
std::fs::create_dir_all(¤t_dir)
|
||||
.map_err(|e| format!("Failed to create work/2_current/ directory: {e}"))?;
|
||||
|
||||
std::fs::rename(&source_path, ¤t_path)
|
||||
.map_err(|e| format!("Failed to move '{story_id}' to 2_current/: {e}"))?;
|
||||
|
||||
slog!(
|
||||
"[lifecycle] Moved '{story_id}' from {} to work/2_current/",
|
||||
source_dir.display()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether a feature branch `feature/story-{story_id}` exists and has
/// commits that are not yet on the base branch. Returns `true` when there is
/// unmerged work, `false` when there is no branch or all its commits are
/// already reachable from the base branch.
///
/// The base branch is detected as `master` or `main` (preferring `master`).
/// Previously `master` was hard-coded, which made repositories whose default
/// branch is `main` always report "nothing to merge" — `git log master..x`
/// failed, producing empty stdout.
pub fn feature_branch_has_unmerged_changes(project_root: &Path, story_id: &str) -> bool {
    let branch = format!("feature/story-{story_id}");

    // Does `name` resolve to a real ref in this repository?
    let ref_exists = |name: &str| {
        Command::new("git")
            .args(["rev-parse", "--verify", name])
            .current_dir(project_root)
            .output()
            .map(|out| out.status.success())
            .unwrap_or(false)
    };

    // No feature branch → nothing to merge.
    if !ref_exists(&branch) {
        return false;
    }

    // Detect the base branch: prefer `master`, fall back to `main`. If
    // neither verifies, keep "master" — `git log` will then fail and we
    // conservatively report no unmerged work (matching the old behavior).
    let base = ["master", "main"]
        .into_iter()
        .find(|&candidate| ref_exists(candidate))
        .unwrap_or("master");

    // Check if the branch has commits not reachable from the base branch.
    let log = Command::new("git")
        .args(["log", &format!("{base}..{branch}"), "--oneline"])
        .current_dir(project_root)
        .output();
    match log {
        Ok(out) if out.status.success() => {
            !String::from_utf8_lossy(&out.stdout).trim().is_empty()
        }
        // Spawn failure or a failed `git log` → treat as nothing to merge.
        _ => false,
    }
}
|
||||
|
||||
/// Move a story from `work/2_current/` to `work/5_done/` and auto-commit.
|
||||
///
|
||||
/// * If the story is in `2_current/`, it is moved to `5_done/` and committed.
|
||||
/// * If the story is in `4_merge/`, it is moved to `5_done/` and committed.
|
||||
/// * If the story is already in `5_done/` or `6_archived/`, this is a no-op (idempotent).
|
||||
/// * If the story is not found in `2_current/`, `4_merge/`, `5_done/`, or `6_archived/`, an error is returned.
|
||||
pub fn move_story_to_archived(project_root: &Path, story_id: &str) -> Result<(), String> {
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let current_path = sk.join("2_current").join(format!("{story_id}.md"));
|
||||
let merge_path = sk.join("4_merge").join(format!("{story_id}.md"));
|
||||
let done_dir = sk.join("5_done");
|
||||
let done_path = done_dir.join(format!("{story_id}.md"));
|
||||
let archived_path = sk.join("6_archived").join(format!("{story_id}.md"));
|
||||
|
||||
if done_path.exists() || archived_path.exists() {
|
||||
// Already in done or archived — idempotent, nothing to do.
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check 2_current/ first, then 4_merge/
|
||||
let source_path = if current_path.exists() {
|
||||
current_path.clone()
|
||||
} else if merge_path.exists() {
|
||||
merge_path.clone()
|
||||
} else {
|
||||
return Err(format!(
|
||||
"Story '{story_id}' not found in work/2_current/ or work/4_merge/. Cannot accept story."
|
||||
));
|
||||
};
|
||||
|
||||
std::fs::create_dir_all(&done_dir)
|
||||
.map_err(|e| format!("Failed to create work/5_done/ directory: {e}"))?;
|
||||
std::fs::rename(&source_path, &done_path)
|
||||
.map_err(|e| format!("Failed to move story '{story_id}' to 5_done/: {e}"))?;
|
||||
|
||||
// Strip stale pipeline fields from front matter now that the story is done.
|
||||
for field in &["merge_failure", "retry_count", "blocked"] {
|
||||
if let Err(e) = clear_front_matter_field(&done_path, field) {
|
||||
slog!("[lifecycle] Warning: could not clear {field} from '{story_id}': {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let from_dir = if source_path == current_path {
|
||||
"work/2_current/"
|
||||
} else {
|
||||
"work/4_merge/"
|
||||
};
|
||||
slog!("[lifecycle] Moved story '{story_id}' from {from_dir} to work/5_done/");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Move a story/bug from `work/2_current/` or `work/3_qa/` to `work/4_merge/`.
|
||||
///
|
||||
/// This stages a work item as ready for the mergemaster to pick up and merge into master.
|
||||
/// Idempotent: if already in `4_merge/`, returns Ok without committing.
|
||||
pub fn move_story_to_merge(project_root: &Path, story_id: &str) -> Result<(), String> {
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let current_path = sk.join("2_current").join(format!("{story_id}.md"));
|
||||
let qa_path = sk.join("3_qa").join(format!("{story_id}.md"));
|
||||
let merge_dir = sk.join("4_merge");
|
||||
let merge_path = merge_dir.join(format!("{story_id}.md"));
|
||||
|
||||
if merge_path.exists() {
|
||||
// Already in 4_merge/ — idempotent, nothing to do.
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Accept from 2_current/ (manual trigger) or 3_qa/ (pipeline advancement from QA stage).
|
||||
let source_path = if current_path.exists() {
|
||||
current_path.clone()
|
||||
} else if qa_path.exists() {
|
||||
qa_path.clone()
|
||||
} else {
|
||||
return Err(format!(
|
||||
"Work item '{story_id}' not found in work/2_current/ or work/3_qa/. Cannot move to 4_merge/."
|
||||
));
|
||||
};
|
||||
|
||||
std::fs::create_dir_all(&merge_dir)
|
||||
.map_err(|e| format!("Failed to create work/4_merge/ directory: {e}"))?;
|
||||
std::fs::rename(&source_path, &merge_path)
|
||||
.map_err(|e| format!("Failed to move '{story_id}' to 4_merge/: {e}"))?;
|
||||
|
||||
let from_dir = if source_path == current_path {
|
||||
"work/2_current/"
|
||||
} else {
|
||||
"work/3_qa/"
|
||||
};
|
||||
// Reset retry count and blocked for the new stage.
|
||||
if let Err(e) = clear_front_matter_field(&merge_path, "retry_count") {
|
||||
slog!("[lifecycle] Warning: could not clear retry_count for '{story_id}': {e}");
|
||||
}
|
||||
if let Err(e) = clear_front_matter_field(&merge_path, "blocked") {
|
||||
slog!("[lifecycle] Warning: could not clear blocked for '{story_id}': {e}");
|
||||
}
|
||||
|
||||
slog!("[lifecycle] Moved '{story_id}' from {from_dir} to work/4_merge/");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Move a story/bug from `work/2_current/` to `work/3_qa/` and auto-commit.
|
||||
///
|
||||
/// This stages a work item for QA review before merging to master.
|
||||
/// Idempotent: if already in `3_qa/`, returns Ok without committing.
|
||||
pub fn move_story_to_qa(project_root: &Path, story_id: &str) -> Result<(), String> {
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let current_path = sk.join("2_current").join(format!("{story_id}.md"));
|
||||
let qa_dir = sk.join("3_qa");
|
||||
let qa_path = qa_dir.join(format!("{story_id}.md"));
|
||||
|
||||
if qa_path.exists() {
|
||||
// Already in 3_qa/ — idempotent, nothing to do.
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !current_path.exists() {
|
||||
return Err(format!(
|
||||
"Work item '{story_id}' not found in work/2_current/. Cannot move to 3_qa/."
|
||||
));
|
||||
}
|
||||
|
||||
std::fs::create_dir_all(&qa_dir)
|
||||
.map_err(|e| format!("Failed to create work/3_qa/ directory: {e}"))?;
|
||||
std::fs::rename(¤t_path, &qa_path)
|
||||
.map_err(|e| format!("Failed to move '{story_id}' to 3_qa/: {e}"))?;
|
||||
|
||||
// Reset retry count for the new stage.
|
||||
if let Err(e) = clear_front_matter_field(&qa_path, "retry_count") {
|
||||
slog!("[lifecycle] Warning: could not clear retry_count for '{story_id}': {e}");
|
||||
}
|
||||
if let Err(e) = clear_front_matter_field(&qa_path, "blocked") {
|
||||
slog!("[lifecycle] Warning: could not clear blocked for '{story_id}': {e}");
|
||||
}
|
||||
|
||||
slog!("[lifecycle] Moved '{story_id}' from work/2_current/ to work/3_qa/");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Move a story from `work/3_qa/` back to `work/2_current/` and write rejection notes.
|
||||
///
|
||||
/// Used when a human reviewer rejects a story during manual QA.
|
||||
/// Clears the `review_hold` front matter field and appends rejection notes to the story file.
|
||||
pub fn reject_story_from_qa(
|
||||
project_root: &Path,
|
||||
story_id: &str,
|
||||
notes: &str,
|
||||
) -> Result<(), String> {
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let qa_path = sk.join("3_qa").join(format!("{story_id}.md"));
|
||||
let current_dir = sk.join("2_current");
|
||||
let current_path = current_dir.join(format!("{story_id}.md"));
|
||||
|
||||
if current_path.exists() {
|
||||
return Ok(()); // Already in 2_current — idempotent.
|
||||
}
|
||||
|
||||
if !qa_path.exists() {
|
||||
return Err(format!(
|
||||
"Work item '{story_id}' not found in work/3_qa/. Cannot reject."
|
||||
));
|
||||
}
|
||||
|
||||
std::fs::create_dir_all(¤t_dir)
|
||||
.map_err(|e| format!("Failed to create work/2_current/ directory: {e}"))?;
|
||||
std::fs::rename(&qa_path, ¤t_path)
|
||||
.map_err(|e| format!("Failed to move '{story_id}' from 3_qa/ to 2_current/: {e}"))?;
|
||||
|
||||
// Clear review_hold since the story is going back for rework.
|
||||
if let Err(e) = clear_front_matter_field(¤t_path, "review_hold") {
|
||||
slog!("[lifecycle] Warning: could not clear review_hold from '{story_id}': {e}");
|
||||
}
|
||||
|
||||
// Write rejection notes into the story file so the coder can see what needs fixing.
|
||||
if !notes.is_empty()
|
||||
&& let Err(e) = write_rejection_notes(¤t_path, notes)
|
||||
{
|
||||
slog!("[lifecycle] Warning: could not write rejection notes to '{story_id}': {e}");
|
||||
}
|
||||
|
||||
slog!("[lifecycle] Rejected '{story_id}' from work/3_qa/ back to work/2_current/");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Move any work item to an arbitrary pipeline stage by searching all stages.
|
||||
///
|
||||
/// Accepts `target_stage` as one of: `backlog`, `current`, `qa`, `merge`, `done`.
|
||||
/// Idempotent: if the item is already in the target stage, returns Ok.
|
||||
/// Returns `(from_stage, to_stage)` on success.
|
||||
pub fn move_story_to_stage(
|
||||
project_root: &Path,
|
||||
story_id: &str,
|
||||
target_stage: &str,
|
||||
) -> Result<(String, String), String> {
|
||||
let stage_dirs: &[(&str, &str)] = &[
|
||||
("backlog", "1_backlog"),
|
||||
("current", "2_current"),
|
||||
("qa", "3_qa"),
|
||||
("merge", "4_merge"),
|
||||
("done", "5_done"),
|
||||
];
|
||||
|
||||
let target_dir_name = stage_dirs
|
||||
.iter()
|
||||
.find(|(name, _)| *name == target_stage)
|
||||
.map(|(_, dir)| *dir)
|
||||
.ok_or_else(|| {
|
||||
format!(
|
||||
"Invalid target_stage '{target_stage}'. Must be one of: backlog, current, qa, merge, done"
|
||||
)
|
||||
})?;
|
||||
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let target_dir = sk.join(target_dir_name);
|
||||
let target_path = target_dir.join(format!("{story_id}.md"));
|
||||
|
||||
if target_path.exists() {
|
||||
return Ok((target_stage.to_string(), target_stage.to_string()));
|
||||
}
|
||||
|
||||
// Search all named stages plus the archive stage.
|
||||
let search_dirs: &[(&str, &str)] = &[
|
||||
("backlog", "1_backlog"),
|
||||
("current", "2_current"),
|
||||
("qa", "3_qa"),
|
||||
("merge", "4_merge"),
|
||||
("done", "5_done"),
|
||||
("archived", "6_archived"),
|
||||
];
|
||||
|
||||
let mut found_path: Option<std::path::PathBuf> = None;
|
||||
let mut from_stage = "";
|
||||
for (stage_name, dir_name) in search_dirs {
|
||||
let candidate = sk.join(dir_name).join(format!("{story_id}.md"));
|
||||
if candidate.exists() {
|
||||
found_path = Some(candidate);
|
||||
from_stage = stage_name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let source_path =
|
||||
found_path.ok_or_else(|| format!("Work item '{story_id}' not found in any pipeline stage."))?;
|
||||
|
||||
std::fs::create_dir_all(&target_dir)
|
||||
.map_err(|e| format!("Failed to create work/{target_dir_name}/ directory: {e}"))?;
|
||||
std::fs::rename(&source_path, &target_path)
|
||||
.map_err(|e| format!("Failed to move '{story_id}' to work/{target_dir_name}/: {e}"))?;
|
||||
|
||||
slog!(
|
||||
"[lifecycle] Moved '{story_id}' from work/{from_stage}/ to work/{target_dir_name}/"
|
||||
);
|
||||
|
||||
Ok((from_stage.to_string(), target_stage.to_string()))
|
||||
}
|
||||
|
||||
/// Move a bug from `work/2_current/` or `work/1_backlog/` to `work/5_done/` and auto-commit.
|
||||
///
|
||||
/// * If the bug is in `2_current/`, it is moved to `5_done/` and committed.
|
||||
/// * If the bug is still in `1_backlog/` (never started), it is moved directly to `5_done/`.
|
||||
/// * If the bug is already in `5_done/`, this is a no-op (idempotent).
|
||||
/// * If the bug is not found anywhere, an error is returned.
|
||||
pub fn close_bug_to_archive(project_root: &Path, bug_id: &str) -> Result<(), String> {
|
||||
let sk = project_root.join(".storkit").join("work");
|
||||
let current_path = sk.join("2_current").join(format!("{bug_id}.md"));
|
||||
let backlog_path = sk.join("1_backlog").join(format!("{bug_id}.md"));
|
||||
let archive_dir = item_archive_dir(project_root, bug_id);
|
||||
let archive_path = archive_dir.join(format!("{bug_id}.md"));
|
||||
|
||||
if archive_path.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let source_path = if current_path.exists() {
|
||||
current_path.clone()
|
||||
} else if backlog_path.exists() {
|
||||
backlog_path.clone()
|
||||
} else {
|
||||
return Err(format!(
|
||||
"Bug '{bug_id}' not found in work/2_current/ or work/1_backlog/. Cannot close bug."
|
||||
));
|
||||
};
|
||||
|
||||
std::fs::create_dir_all(&archive_dir)
|
||||
.map_err(|e| format!("Failed to create work/5_done/ directory: {e}"))?;
|
||||
std::fs::rename(&source_path, &archive_path)
|
||||
.map_err(|e| format!("Failed to move bug '{bug_id}' to 5_done/: {e}"))?;
|
||||
|
||||
slog!(
|
||||
"[lifecycle] Closed bug '{bug_id}' → work/5_done/"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the lifecycle file-move helpers. Every test works in a
    //! throwaway `tempfile::tempdir()` (cleaned up on drop) laid out like a
    //! real project: `.storkit/work/{1_backlog,2_current,3_qa,4_merge,5_done}`.
    use super::*;

    // ── move_story_to_current tests ────────────────────────────────────────────

    #[test]
    fn move_story_to_current_moves_file() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let backlog = root.join(".storkit/work/1_backlog");
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&backlog).unwrap();
        fs::create_dir_all(&current).unwrap();
        fs::write(backlog.join("10_story_foo.md"), "test").unwrap();

        move_story_to_current(root, "10_story_foo").unwrap();

        // Moved, not copied: gone from backlog, present in current.
        assert!(!backlog.join("10_story_foo.md").exists());
        assert!(current.join("10_story_foo.md").exists());
    }

    #[test]
    fn move_story_to_current_is_idempotent_when_already_current() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&current).unwrap();
        fs::write(current.join("11_story_foo.md"), "test").unwrap();

        // Calling again while already in 2_current/ must not fail or move anything.
        move_story_to_current(root, "11_story_foo").unwrap();
        assert!(current.join("11_story_foo.md").exists());
    }

    #[test]
    fn move_story_to_current_noop_when_not_in_backlog() {
        let tmp = tempfile::tempdir().unwrap();
        // Missing story is a no-op, not an error.
        assert!(move_story_to_current(tmp.path(), "99_missing").is_ok());
    }

    #[test]
    fn move_bug_to_current_moves_from_backlog() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let backlog = root.join(".storkit/work/1_backlog");
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&backlog).unwrap();
        fs::create_dir_all(&current).unwrap();
        fs::write(backlog.join("1_bug_test.md"), "# Bug 1\n").unwrap();

        // Bugs ride the same helper as stories.
        move_story_to_current(root, "1_bug_test").unwrap();

        assert!(!backlog.join("1_bug_test.md").exists());
        assert!(current.join("1_bug_test.md").exists());
    }

    // ── close_bug_to_archive tests ─────────────────────────────────────────────

    #[test]
    fn close_bug_moves_from_current_to_archive() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&current).unwrap();
        fs::write(current.join("2_bug_test.md"), "# Bug 2\n").unwrap();

        close_bug_to_archive(root, "2_bug_test").unwrap();

        assert!(!current.join("2_bug_test.md").exists());
        assert!(root.join(".storkit/work/5_done/2_bug_test.md").exists());
    }

    #[test]
    fn close_bug_moves_from_backlog_when_not_started() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let backlog = root.join(".storkit/work/1_backlog");
        fs::create_dir_all(&backlog).unwrap();
        fs::write(backlog.join("3_bug_test.md"), "# Bug 3\n").unwrap();

        // A bug closed before any work started archives straight from backlog.
        close_bug_to_archive(root, "3_bug_test").unwrap();

        assert!(!backlog.join("3_bug_test.md").exists());
        assert!(root.join(".storkit/work/5_done/3_bug_test.md").exists());
    }

    // ── item_type_from_id tests ────────────────────────────────────────────────

    #[test]
    fn item_type_from_id_detects_types() {
        assert_eq!(item_type_from_id("1_bug_test"), "bug");
        assert_eq!(item_type_from_id("1_spike_research"), "spike");
        assert_eq!(item_type_from_id("50_story_my_story"), "story");
        assert_eq!(item_type_from_id("1_story_simple"), "story");
    }

    // ── move_story_to_merge tests ──────────────────────────────────────────────

    #[test]
    fn move_story_to_merge_moves_file() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&current).unwrap();
        fs::write(current.join("20_story_foo.md"), "test").unwrap();

        move_story_to_merge(root, "20_story_foo").unwrap();

        assert!(!current.join("20_story_foo.md").exists());
        assert!(root.join(".storkit/work/4_merge/20_story_foo.md").exists());
    }

    #[test]
    fn move_story_to_merge_from_qa_dir() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let qa_dir = root.join(".storkit/work/3_qa");
        fs::create_dir_all(&qa_dir).unwrap();
        fs::write(qa_dir.join("40_story_test.md"), "test").unwrap();

        // Stories can reach merge from QA, not just from current.
        move_story_to_merge(root, "40_story_test").unwrap();

        assert!(!qa_dir.join("40_story_test.md").exists());
        assert!(root.join(".storkit/work/4_merge/40_story_test.md").exists());
    }

    #[test]
    fn move_story_to_merge_idempotent_when_already_in_merge() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let merge_dir = root.join(".storkit/work/4_merge");
        fs::create_dir_all(&merge_dir).unwrap();
        fs::write(merge_dir.join("21_story_test.md"), "test").unwrap();

        move_story_to_merge(root, "21_story_test").unwrap();
        assert!(merge_dir.join("21_story_test.md").exists());
    }

    #[test]
    fn move_story_to_merge_errors_when_not_in_current_or_qa() {
        let tmp = tempfile::tempdir().unwrap();
        let result = move_story_to_merge(tmp.path(), "99_nonexistent");
        assert!(result.unwrap_err().contains("not found in work/2_current/ or work/3_qa/"));
    }

    // ── move_story_to_qa tests ────────────────────────────────────────────────

    #[test]
    fn move_story_to_qa_moves_file() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&current).unwrap();
        fs::write(current.join("30_story_qa.md"), "test").unwrap();

        move_story_to_qa(root, "30_story_qa").unwrap();

        assert!(!current.join("30_story_qa.md").exists());
        assert!(root.join(".storkit/work/3_qa/30_story_qa.md").exists());
    }

    #[test]
    fn move_story_to_qa_idempotent_when_already_in_qa() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let qa_dir = root.join(".storkit/work/3_qa");
        fs::create_dir_all(&qa_dir).unwrap();
        fs::write(qa_dir.join("31_story_test.md"), "test").unwrap();

        move_story_to_qa(root, "31_story_test").unwrap();
        assert!(qa_dir.join("31_story_test.md").exists());
    }

    #[test]
    fn move_story_to_qa_errors_when_not_in_current() {
        let tmp = tempfile::tempdir().unwrap();
        let result = move_story_to_qa(tmp.path(), "99_nonexistent");
        assert!(result.unwrap_err().contains("not found in work/2_current/"));
    }

    // ── move_story_to_archived tests ──────────────────────────────────────────

    #[test]
    fn move_story_to_archived_finds_in_merge_dir() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let merge_dir = root.join(".storkit/work/4_merge");
        fs::create_dir_all(&merge_dir).unwrap();
        fs::write(merge_dir.join("22_story_test.md"), "test").unwrap();

        move_story_to_archived(root, "22_story_test").unwrap();

        assert!(!merge_dir.join("22_story_test.md").exists());
        assert!(root.join(".storkit/work/5_done/22_story_test.md").exists());
    }

    #[test]
    fn move_story_to_archived_error_when_not_in_current_or_merge() {
        let tmp = tempfile::tempdir().unwrap();
        let result = move_story_to_archived(tmp.path(), "99_nonexistent");
        assert!(result.unwrap_err().contains("4_merge"));
    }

    // ── feature_branch_has_unmerged_changes tests ────────────────────────────

    // Shared fixture: initialise a git repo with identity config (so commits
    // work on CI) and one empty root commit on the default branch.
    fn init_git_repo(repo: &std::path::Path) {
        Command::new("git")
            .args(["init"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["config", "user.email", "test@test.com"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["config", "user.name", "Test"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["commit", "--allow-empty", "-m", "init"])
            .current_dir(repo)
            .output()
            .unwrap();
    }

    /// Bug 226: feature_branch_has_unmerged_changes returns true when the
    /// feature branch has commits not on master.
    #[test]
    fn feature_branch_has_unmerged_changes_detects_unmerged_code() {
        use std::fs;
        use tempfile::tempdir;

        let tmp = tempdir().unwrap();
        let repo = tmp.path();
        init_git_repo(repo);

        // Create a feature branch with a code commit.
        Command::new("git")
            .args(["checkout", "-b", "feature/story-50_story_test"])
            .current_dir(repo)
            .output()
            .unwrap();
        fs::write(repo.join("feature.rs"), "fn main() {}").unwrap();
        Command::new("git")
            .args(["add", "."])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["commit", "-m", "add feature"])
            .current_dir(repo)
            .output()
            .unwrap();
        Command::new("git")
            .args(["checkout", "master"])
            .current_dir(repo)
            .output()
            .unwrap();

        assert!(
            feature_branch_has_unmerged_changes(repo, "50_story_test"),
            "should detect unmerged changes on feature branch"
        );
    }

    /// Bug 226: feature_branch_has_unmerged_changes returns false when no
    /// feature branch exists.
    #[test]
    fn feature_branch_has_unmerged_changes_false_when_no_branch() {
        use tempfile::tempdir;

        let tmp = tempdir().unwrap();
        let repo = tmp.path();
        init_git_repo(repo);

        assert!(
            !feature_branch_has_unmerged_changes(repo, "99_nonexistent"),
            "should return false when no feature branch"
        );
    }

    // ── reject_story_from_qa tests ────────────────────────────────────────────

    #[test]
    fn reject_story_from_qa_moves_to_current() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let qa_dir = root.join(".storkit/work/3_qa");
        let current_dir = root.join(".storkit/work/2_current");
        fs::create_dir_all(&qa_dir).unwrap();
        fs::create_dir_all(&current_dir).unwrap();
        fs::write(
            qa_dir.join("50_story_test.md"),
            "---\nname: Test\nreview_hold: true\n---\n# Story\n",
        )
        .unwrap();

        reject_story_from_qa(root, "50_story_test", "Button color wrong").unwrap();

        assert!(!qa_dir.join("50_story_test.md").exists());
        assert!(current_dir.join("50_story_test.md").exists());
        // Rejection notes are appended and the review_hold flag is cleared.
        let contents = fs::read_to_string(current_dir.join("50_story_test.md")).unwrap();
        assert!(contents.contains("Button color wrong"));
        assert!(contents.contains("## QA Rejection Notes"));
        assert!(!contents.contains("review_hold"));
    }

    #[test]
    fn reject_story_from_qa_errors_when_not_in_qa() {
        let tmp = tempfile::tempdir().unwrap();
        let result = reject_story_from_qa(tmp.path(), "99_nonexistent", "notes");
        assert!(result.unwrap_err().contains("not found in work/3_qa/"));
    }

    #[test]
    fn reject_story_from_qa_idempotent_when_in_current() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current_dir = root.join(".storkit/work/2_current");
        fs::create_dir_all(&current_dir).unwrap();
        fs::write(current_dir.join("51_story_test.md"), "---\nname: Test\n---\n# Story\n").unwrap();

        reject_story_from_qa(root, "51_story_test", "notes").unwrap();
        assert!(current_dir.join("51_story_test.md").exists());
    }

    // ── move_story_to_stage tests ─────────────────────────────

    #[test]
    fn move_story_to_stage_moves_from_backlog_to_current() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let backlog = root.join(".storkit/work/1_backlog");
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&backlog).unwrap();
        fs::create_dir_all(&current).unwrap();
        fs::write(backlog.join("60_story_move.md"), "test").unwrap();

        // Returns the (from, to) stage names so callers can log the transition.
        let (from, to) = move_story_to_stage(root, "60_story_move", "current").unwrap();

        assert_eq!(from, "backlog");
        assert_eq!(to, "current");
        assert!(!backlog.join("60_story_move.md").exists());
        assert!(current.join("60_story_move.md").exists());
    }

    #[test]
    fn move_story_to_stage_moves_from_current_to_backlog() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current = root.join(".storkit/work/2_current");
        let backlog = root.join(".storkit/work/1_backlog");
        fs::create_dir_all(&current).unwrap();
        fs::create_dir_all(&backlog).unwrap();
        fs::write(current.join("61_story_back.md"), "test").unwrap();

        let (from, to) = move_story_to_stage(root, "61_story_back", "backlog").unwrap();

        assert_eq!(from, "current");
        assert_eq!(to, "backlog");
        assert!(!current.join("61_story_back.md").exists());
        assert!(backlog.join("61_story_back.md").exists());
    }

    #[test]
    fn move_story_to_stage_idempotent_when_already_in_target() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let current = root.join(".storkit/work/2_current");
        fs::create_dir_all(&current).unwrap();
        fs::write(current.join("62_story_idem.md"), "test").unwrap();

        let (from, to) = move_story_to_stage(root, "62_story_idem", "current").unwrap();

        // from == to signals a no-op move.
        assert_eq!(from, "current");
        assert_eq!(to, "current");
        assert!(current.join("62_story_idem.md").exists());
    }

    #[test]
    fn move_story_to_stage_invalid_target_returns_error() {
        let tmp = tempfile::tempdir().unwrap();
        let result = move_story_to_stage(tmp.path(), "1_story_test", "invalid");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid target_stage"));
    }

    #[test]
    fn move_story_to_stage_not_found_returns_error() {
        let tmp = tempfile::tempdir().unwrap();
        let result = move_story_to_stage(tmp.path(), "99_story_ghost", "current");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found in any pipeline stage"));
    }

    #[test]
    fn move_story_to_stage_finds_in_qa_dir() {
        use std::fs;
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let qa_dir = root.join(".storkit/work/3_qa");
        let backlog = root.join(".storkit/work/1_backlog");
        fs::create_dir_all(&qa_dir).unwrap();
        fs::create_dir_all(&backlog).unwrap();
        fs::write(qa_dir.join("63_story_qa.md"), "test").unwrap();

        let (from, to) = move_story_to_stage(root, "63_story_qa", "backlog").unwrap();

        assert_eq!(from, "qa");
        assert_eq!(to, "backlog");
        assert!(!qa_dir.join("63_story_qa.md").exists());
        assert!(backlog.join("63_story_qa.md").exists());
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,222 +0,0 @@
|
||||
pub mod gates;
|
||||
pub mod lifecycle;
|
||||
pub mod merge;
|
||||
mod pool;
|
||||
pub(crate) mod pty;
|
||||
pub mod runtime;
|
||||
pub mod token_usage;
|
||||
|
||||
use crate::config::AgentConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use lifecycle::{
|
||||
close_bug_to_archive, feature_branch_has_unmerged_changes, move_story_to_archived,
|
||||
move_story_to_merge, move_story_to_qa, move_story_to_stage, reject_story_from_qa,
|
||||
};
|
||||
pub use pool::AgentPool;
|
||||
|
||||
/// Events emitted during server startup reconciliation to broadcast real-time
/// progress to connected WebSocket clients.
///
/// Serialized to plain JSON via the `Serialize` derive (field names as keys,
/// no serde attributes).
#[derive(Debug, Clone, Serialize)]
pub struct ReconciliationEvent {
    /// The story being reconciled, or empty string for the overall "done" event.
    pub story_id: String,
    /// Coarse status: "checking", "gates_running", "advanced", "skipped", "failed", "done"
    pub status: String,
    /// Human-readable details for display to the user.
    pub message: String,
}
|
||||
|
||||
/// Events streamed from a running agent to SSE clients.
///
/// Serialized as internally-tagged JSON: the variant name appears as a
/// snake_case `"type"` field alongside the variant's own fields, e.g.
/// `{"type": "status", "story_id": ..., ...}`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
    /// Agent status changed.
    Status {
        story_id: String,
        agent_name: String,
        status: String,
    },
    /// Raw text output from the agent process.
    Output {
        story_id: String,
        agent_name: String,
        text: String,
    },
    /// Agent produced a JSON event from `--output-format stream-json`.
    AgentJson {
        story_id: String,
        agent_name: String,
        data: serde_json::Value,
    },
    /// Agent finished.
    Done {
        story_id: String,
        agent_name: String,
        // Session ID reported by the agent, when one was observed.
        session_id: Option<String>,
    },
    /// Agent errored.
    Error {
        story_id: String,
        agent_name: String,
        message: String,
    },
    /// Thinking tokens from an extended-thinking block.
    Thinking {
        story_id: String,
        agent_name: String,
        text: String,
    },
}
|
||||
|
||||
/// Lifecycle state of a managed agent.
///
/// Serialized in snake_case (e.g. `"pending"`); the `Display` impl in this
/// module emits the same lowercase names.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AgentStatus {
    /// Not started yet.
    Pending,
    /// Currently running.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished unsuccessfully.
    Failed,
}
|
||||
|
||||
impl std::fmt::Display for AgentStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Pending => write!(f, "pending"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pipeline stages for automatic story advancement.
///
/// Derived from an agent's config or name by [`agent_config_stage`] /
/// [`pipeline_stage`].
#[derive(Debug, Clone, PartialEq)]
pub enum PipelineStage {
    /// Coding agents (coder-1, coder-2, etc.)
    Coder,
    /// QA review agent
    Qa,
    /// Mergemaster agent
    Mergemaster,
    /// Supervisors and unknown agents — no automatic advancement.
    Other,
}
|
||||
|
||||
/// Determine the pipeline stage from an agent name.
|
||||
pub fn pipeline_stage(agent_name: &str) -> PipelineStage {
|
||||
match agent_name {
|
||||
"qa" => PipelineStage::Qa,
|
||||
"mergemaster" => PipelineStage::Mergemaster,
|
||||
name if name.starts_with("coder") => PipelineStage::Coder,
|
||||
_ => PipelineStage::Other,
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the pipeline stage for a configured agent.
|
||||
///
|
||||
/// Prefers the explicit `stage` config field (added in Bug 150) over the
|
||||
/// legacy name-based heuristic so that agents with non-standard names
|
||||
/// (e.g. `qa-2`, `coder-opus`) are assigned to the correct stage.
|
||||
pub(crate) fn agent_config_stage(cfg: &AgentConfig) -> PipelineStage {
|
||||
match cfg.stage.as_deref() {
|
||||
Some("coder") => PipelineStage::Coder,
|
||||
Some("qa") => PipelineStage::Qa,
|
||||
Some("mergemaster") => PipelineStage::Mergemaster,
|
||||
Some(_) => PipelineStage::Other,
|
||||
None => pipeline_stage(&cfg.name),
|
||||
}
|
||||
}
|
||||
|
||||
/// Completion report produced when acceptance gates are run.
///
/// Created automatically by the server when an agent process exits normally,
/// or via the internal `report_completion` method.
#[derive(Debug, Serialize, Clone)]
pub struct CompletionReport {
    /// Human-readable summary of the completed work.
    pub summary: String,
    /// Whether the acceptance gates passed.
    pub gates_passed: bool,
    /// Output captured from the gate run.
    pub gate_output: String,
}
|
||||
|
||||
/// Token usage from a Claude Code session's `result` event.
///
/// Field names mirror the keys of the event's `usage` object; see
/// [`TokenUsage::from_result_event`] for the parsing rules.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
    /// `usage.input_tokens` from the result event.
    pub input_tokens: u64,
    /// `usage.output_tokens` from the result event.
    pub output_tokens: u64,
    /// `usage.cache_creation_input_tokens` from the result event.
    pub cache_creation_input_tokens: u64,
    /// `usage.cache_read_input_tokens` from the result event.
    pub cache_read_input_tokens: u64,
    /// `total_cost_usd` from the top level of the result event.
    pub total_cost_usd: f64,
}
|
||||
|
||||
impl TokenUsage {
|
||||
/// Parse token usage from a Claude Code `result` JSON event.
|
||||
pub fn from_result_event(json: &serde_json::Value) -> Option<Self> {
|
||||
let usage = json.get("usage")?;
|
||||
Some(Self {
|
||||
input_tokens: usage
|
||||
.get("input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
output_tokens: usage
|
||||
.get("output_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
cache_creation_input_tokens: usage
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
cache_read_input_tokens: usage
|
||||
.get("cache_read_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
total_cost_usd: json
|
||||
.get("total_cost_usd")
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.0),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializable snapshot of a single agent run.
#[derive(Debug, Serialize, Clone)]
pub struct AgentInfo {
    /// Story the agent is working on.
    pub story_id: String,
    /// Configured agent name (e.g. "coder-1", "qa").
    pub agent_name: String,
    /// Current lifecycle state.
    pub status: AgentStatus,
    /// Session ID reported by the agent process, if any.
    pub session_id: Option<String>,
    /// Filesystem path of the agent's git worktree, when one is in use.
    pub worktree_path: Option<String>,
    /// Base branch the worktree was created from, when known.
    pub base_branch: Option<String>,
    /// Gate results, present once the agent has finished and gates ran.
    pub completion: Option<CompletionReport>,
    /// UUID identifying the persistent log file for this session.
    pub log_session_id: Option<String>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Tests for the name-based stage heuristic.
    use super::*;

    // ── pipeline_stage tests ──────────────────────────────────────────────────

    #[test]
    fn pipeline_stage_detects_coders() {
        // Any "coder" prefix counts, regardless of the numeric suffix.
        assert_eq!(pipeline_stage("coder-1"), PipelineStage::Coder);
        assert_eq!(pipeline_stage("coder-2"), PipelineStage::Coder);
        assert_eq!(pipeline_stage("coder-3"), PipelineStage::Coder);
    }

    #[test]
    fn pipeline_stage_detects_qa() {
        assert_eq!(pipeline_stage("qa"), PipelineStage::Qa);
    }

    #[test]
    fn pipeline_stage_detects_mergemaster() {
        assert_eq!(pipeline_stage("mergemaster"), PipelineStage::Mergemaster);
    }

    #[test]
    fn pipeline_stage_supervisor_is_other() {
        // Unrecognized names opt out of automatic advancement.
        assert_eq!(pipeline_stage("supervisor"), PipelineStage::Other);
        assert_eq!(pipeline_stage("default"), PipelineStage::Other);
        assert_eq!(pipeline_stage("unknown"), PipelineStage::Other);
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,514 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use portable_pty::{ChildKiller, CommandBuilder, PtySize, native_pty_system};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use super::{AgentEvent, TokenUsage};
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
use crate::slog;
|
||||
use crate::slog_warn;
|
||||
|
||||
/// Result from a PTY agent session, containing the session ID and token usage.
pub(in crate::agents) struct PtyResult {
    /// Session ID parsed from the agent's `system` event, if one was seen.
    pub session_id: Option<String>,
    /// Token usage from the session, if the agent reported it.
    pub token_usage: Option<TokenUsage>,
}
|
||||
|
||||
/// Build the map key used to track a child process: `"<story_id>:<agent_name>"`.
fn composite_key(story_id: &str, agent_name: &str) -> String {
    let mut key = String::with_capacity(story_id.len() + 1 + agent_name.len());
    key.push_str(story_id);
    key.push(':');
    key.push_str(agent_name);
    key
}
|
||||
|
||||
/// Guard that removes this agent's entry from the shared child-killer map
/// when dropped (see the `Drop` impl below), so finished or failed agents
/// don't leave stale killer handles behind.
struct ChildKillerGuard {
    // Shared map of composite "story:agent" key → process killer handle.
    killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
    // The entry owned by this guard.
    key: String,
}
|
||||
|
||||
impl Drop for ChildKillerGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut killers) = self.killers.lock() {
|
||||
killers.remove(&self.key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn claude agent in a PTY and stream events through the broadcast channel.
///
/// Thin async wrapper: the real PTY work happens in `run_agent_pty_blocking`
/// on a `spawn_blocking` thread. Returns the worker's `PtyResult`, or an
/// error string if the worker failed or its task panicked.
#[allow(clippy::too_many_arguments)]
pub(in crate::agents) async fn run_agent_pty_streaming(
    story_id: &str,
    agent_name: &str,
    command: &str,
    args: &[String],
    prompt: &str,
    cwd: &str,
    tx: &broadcast::Sender<AgentEvent>,
    event_log: &Arc<Mutex<Vec<AgentEvent>>>,
    log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
    inactivity_timeout_secs: u64,
    child_killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
) -> Result<PtyResult, String> {
    // Owned copies of every borrowed argument: the closure handed to
    // spawn_blocking must be 'static, so nothing borrowed from the caller
    // may cross into it.
    let sid = story_id.to_string();
    let aname = agent_name.to_string();
    let cmd = command.to_string();
    let args = args.to_vec();
    let prompt = prompt.to_string();
    let cwd = cwd.to_string();
    let tx = tx.clone();
    let event_log = event_log.clone();

    tokio::task::spawn_blocking(move || {
        run_agent_pty_blocking(
            &sid,
            &aname,
            &cmd,
            &args,
            &prompt,
            &cwd,
            &tx,
            &event_log,
            // Option<Arc<Mutex<_>>> → Option<&Mutex<_>> for the worker.
            log_writer.as_deref(),
            inactivity_timeout_secs,
            &child_killers,
        )
    })
    .await
    // A JoinError here means the blocking task panicked (or was aborted);
    // the inner Result from the worker passes through via the trailing `?`.
    .map_err(|e| format!("Agent task panicked: {e}"))?
}
|
||||
|
||||
/// Dispatch a `stream_event` from Claude Code's `--include-partial-messages` output.
|
||||
///
|
||||
/// Extracts `thinking_delta` and `text_delta` from `content_block_delta` events
|
||||
/// and routes them as `AgentEvent::Thinking` and `AgentEvent::Output` respectively.
|
||||
/// This ensures thinking traces flow through the dedicated `ThinkingBlock` UI
|
||||
/// component rather than appearing as unbounded regular output.
|
||||
fn handle_agent_stream_event(
|
||||
event: &serde_json::Value,
|
||||
story_id: &str,
|
||||
agent_name: &str,
|
||||
tx: &broadcast::Sender<AgentEvent>,
|
||||
event_log: &Mutex<Vec<AgentEvent>>,
|
||||
log_writer: Option<&Mutex<AgentLogWriter>>,
|
||||
) {
|
||||
let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
|
||||
if event_type == "content_block_delta"
|
||||
&& let Some(delta) = event.get("delta")
|
||||
{
|
||||
let delta_type = delta.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
match delta_type {
|
||||
"thinking_delta" => {
|
||||
if let Some(thinking) = delta.get("thinking").and_then(|t| t.as_str()) {
|
||||
emit_event(
|
||||
AgentEvent::Thinking {
|
||||
story_id: story_id.to_string(),
|
||||
agent_name: agent_name.to_string(),
|
||||
text: thinking.to_string(),
|
||||
},
|
||||
tx,
|
||||
event_log,
|
||||
log_writer,
|
||||
);
|
||||
}
|
||||
}
|
||||
"text_delta" => {
|
||||
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
|
||||
emit_event(
|
||||
AgentEvent::Output {
|
||||
story_id: story_id.to_string(),
|
||||
agent_name: agent_name.to_string(),
|
||||
text: text.to_string(),
|
||||
},
|
||||
tx,
|
||||
event_log,
|
||||
log_writer,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to send an event to broadcast, event log, and optional persistent log file.
|
||||
pub(super) fn emit_event(
|
||||
event: AgentEvent,
|
||||
tx: &broadcast::Sender<AgentEvent>,
|
||||
event_log: &Mutex<Vec<AgentEvent>>,
|
||||
log_writer: Option<&Mutex<AgentLogWriter>>,
|
||||
) {
|
||||
if let Ok(mut log) = event_log.lock() {
|
||||
log.push(event.clone());
|
||||
}
|
||||
if let Some(writer) = log_writer
|
||||
&& let Ok(mut w) = writer.lock()
|
||||
&& let Err(e) = w.write_event(&event)
|
||||
{
|
||||
eprintln!("[agent_log] Failed to write event to log file: {e}");
|
||||
}
|
||||
let _ = tx.send(event);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_agent_pty_blocking(
|
||||
story_id: &str,
|
||||
agent_name: &str,
|
||||
command: &str,
|
||||
args: &[String],
|
||||
prompt: &str,
|
||||
cwd: &str,
|
||||
tx: &broadcast::Sender<AgentEvent>,
|
||||
event_log: &Mutex<Vec<AgentEvent>>,
|
||||
log_writer: Option<&Mutex<AgentLogWriter>>,
|
||||
inactivity_timeout_secs: u64,
|
||||
child_killers: &Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
|
||||
) -> Result<PtyResult, String> {
|
||||
let pty_system = native_pty_system();
|
||||
|
||||
let pair = pty_system
|
||||
.openpty(PtySize {
|
||||
rows: 50,
|
||||
cols: 200,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
})
|
||||
.map_err(|e| format!("Failed to open PTY: {e}"))?;
|
||||
|
||||
let mut cmd = CommandBuilder::new(command);
|
||||
|
||||
// -p <prompt> must come first
|
||||
cmd.arg("-p");
|
||||
cmd.arg(prompt);
|
||||
|
||||
// Add configured args (e.g., --directory /path/to/worktree, --model, etc.)
|
||||
for arg in args {
|
||||
cmd.arg(arg);
|
||||
}
|
||||
|
||||
cmd.arg("--output-format");
|
||||
cmd.arg("stream-json");
|
||||
cmd.arg("--verbose");
|
||||
// Enable partial streaming so we receive thinking_delta and text_delta
|
||||
// events in real-time, rather than only complete assistant events.
|
||||
// Without this, thinking traces may not appear in the structured output
|
||||
// and instead leak as unstructured PTY text.
|
||||
cmd.arg("--include-partial-messages");
|
||||
|
||||
// Supervised agents don't need interactive permission prompts
|
||||
cmd.arg("--permission-mode");
|
||||
cmd.arg("bypassPermissions");
|
||||
|
||||
cmd.cwd(cwd);
|
||||
cmd.env("NO_COLOR", "1");
|
||||
|
||||
// Allow spawning Claude Code from within a Claude Code session
|
||||
cmd.env_remove("CLAUDECODE");
|
||||
cmd.env_remove("CLAUDE_CODE_ENTRYPOINT");
|
||||
|
||||
slog!("[agent:{story_id}:{agent_name}] Spawning {command} in {cwd} with args: {args:?}");
|
||||
|
||||
let mut child = pair
|
||||
.slave
|
||||
.spawn_command(cmd)
|
||||
.map_err(|e| format!("Failed to spawn agent for {story_id}:{agent_name}: {e}"))?;
|
||||
|
||||
// Register the child killer so that kill_all_children() / stop_agent() can
|
||||
// terminate this process on server shutdown, even if the blocking thread
|
||||
// cannot be interrupted. The ChildKillerGuard deregisters on function exit.
|
||||
let killer_key = composite_key(story_id, agent_name);
|
||||
{
|
||||
let killer = child.clone_killer();
|
||||
if let Ok(mut killers) = child_killers.lock() {
|
||||
killers.insert(killer_key.clone(), killer);
|
||||
}
|
||||
}
|
||||
let _killer_guard = ChildKillerGuard {
|
||||
killers: Arc::clone(child_killers),
|
||||
key: killer_key,
|
||||
};
|
||||
|
||||
drop(pair.slave);
|
||||
|
||||
let reader = pair
|
||||
.master
|
||||
.try_clone_reader()
|
||||
.map_err(|e| format!("Failed to clone PTY reader: {e}"))?;
|
||||
|
||||
drop(pair.master);
|
||||
|
||||
// Spawn a reader thread to collect PTY output lines.
|
||||
// We use a channel so the main thread can apply an inactivity deadline
|
||||
// via recv_timeout: if no output arrives within the configured window
|
||||
// the process is killed and the agent is marked Failed.
|
||||
let (line_tx, line_rx) = std::sync::mpsc::channel::<std::io::Result<String>>();
|
||||
std::thread::spawn(move || {
|
||||
let buf_reader = BufReader::new(reader);
|
||||
for line in buf_reader.lines() {
|
||||
if line_tx.send(line).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let timeout_dur = if inactivity_timeout_secs > 0 {
|
||||
Some(std::time::Duration::from_secs(inactivity_timeout_secs))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut session_id: Option<String> = None;
|
||||
let mut token_usage: Option<TokenUsage> = None;
|
||||
|
||||
loop {
|
||||
let recv_result = match timeout_dur {
|
||||
Some(dur) => line_rx.recv_timeout(dur),
|
||||
None => line_rx
|
||||
.recv()
|
||||
.map_err(|_| std::sync::mpsc::RecvTimeoutError::Disconnected),
|
||||
};
|
||||
|
||||
let line = match recv_result {
|
||||
Ok(Ok(l)) => l,
|
||||
Ok(Err(_)) => {
|
||||
// IO error reading from PTY — treat as EOF.
|
||||
break;
|
||||
}
|
||||
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
|
||||
// Reader thread exited (EOF from PTY).
|
||||
break;
|
||||
}
|
||||
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
||||
slog_warn!(
|
||||
"[agent:{story_id}:{agent_name}] Inactivity timeout after \
|
||||
{inactivity_timeout_secs}s with no output. Killing process."
|
||||
);
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
return Err(format!(
|
||||
"Agent inactivity timeout: no output received for {inactivity_timeout_secs}s"
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to parse as JSON
|
||||
let json: serde_json::Value = match serde_json::from_str(trimmed) {
|
||||
Ok(j) => j,
|
||||
Err(_) => {
|
||||
// Non-JSON output (terminal escapes etc.) — send as raw output
|
||||
emit_event(
|
||||
AgentEvent::Output {
|
||||
story_id: story_id.to_string(),
|
||||
agent_name: agent_name.to_string(),
|
||||
text: trimmed.to_string(),
|
||||
},
|
||||
tx,
|
||||
event_log,
|
||||
log_writer,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let event_type = json.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"system" => {
|
||||
session_id = json
|
||||
.get("session_id")
|
||||
.and_then(|s| s.as_str())
|
||||
.map(|s| s.to_string());
|
||||
}
|
||||
// With --include-partial-messages, thinking and text arrive
|
||||
// incrementally via stream_event → content_block_delta. Handle
|
||||
// them here for real-time streaming to the frontend.
|
||||
"stream_event" => {
|
||||
if let Some(event) = json.get("event") {
|
||||
handle_agent_stream_event(
|
||||
event,
|
||||
story_id,
|
||||
agent_name,
|
||||
tx,
|
||||
event_log,
|
||||
log_writer,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Complete assistant events are skipped for content extraction
|
||||
// because thinking and text already arrived via stream_event.
|
||||
// The raw JSON is still forwarded as AgentJson below.
|
||||
"assistant" | "user" => {}
|
||||
"result" => {
|
||||
// Extract token usage from the result event.
|
||||
if let Some(usage) = TokenUsage::from_result_event(&json) {
|
||||
slog!(
|
||||
"[agent:{story_id}:{agent_name}] Token usage: in={} out={} cache_create={} cache_read={} cost=${:.4}",
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
usage.cache_creation_input_tokens,
|
||||
usage.cache_read_input_tokens,
|
||||
usage.total_cost_usd,
|
||||
);
|
||||
token_usage = Some(usage);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Forward all JSON events
|
||||
emit_event(
|
||||
AgentEvent::AgentJson {
|
||||
story_id: story_id.to_string(),
|
||||
agent_name: agent_name.to_string(),
|
||||
data: json,
|
||||
},
|
||||
tx,
|
||||
event_log,
|
||||
log_writer,
|
||||
);
|
||||
}
|
||||
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
|
||||
slog!(
|
||||
"[agent:{story_id}:{agent_name}] Done. Session: {:?}",
|
||||
session_id
|
||||
);
|
||||
|
||||
Ok(PtyResult {
|
||||
session_id,
|
||||
token_usage,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::AgentEvent;

    /// emit_event must fan out to all three sinks: the broadcast channel,
    /// the in-memory event log, and (when provided) the on-disk log writer.
    #[test]
    fn test_emit_event_writes_to_log_writer() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();

        let log_writer =
            AgentLogWriter::new(root, "42_story_foo", "coder-1", "sess-emit").unwrap();
        let log_mutex = Mutex::new(log_writer);

        let (tx, _rx) = broadcast::channel::<AgentEvent>(64);
        let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());

        let event = AgentEvent::Status {
            story_id: "42_story_foo".to_string(),
            agent_name: "coder-1".to_string(),
            status: "running".to_string(),
        };

        emit_event(event, &tx, &event_log, Some(&log_mutex));

        // Verify event was added to in-memory log
        let mem_events = event_log.lock().unwrap();
        assert_eq!(mem_events.len(), 1);
        // Release the lock before reading the file so the test cannot deadlock
        // if read_log ever needs the same state.
        drop(mem_events);

        // Verify event was written to the log file
        let log_path =
            crate::agent_log::log_file_path(root, "42_story_foo", "coder-1", "sess-emit");
        let entries = crate::agent_log::read_log(&log_path).unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].event["type"], "status");
        assert_eq!(entries[0].event["status"], "running");
    }

    // ── bug 167: handle_agent_stream_event routes thinking/text correctly ───

    /// A `thinking_delta` stream event must surface as AgentEvent::Thinking.
    #[test]
    fn stream_event_thinking_delta_emits_thinking_event() {
        let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
        let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());

        let event = serde_json::json!({
            "type": "content_block_delta",
            "delta": {"type": "thinking_delta", "thinking": "Let me analyze this..."}
        });

        handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);

        let received = rx.try_recv().unwrap();
        match received {
            AgentEvent::Thinking {
                story_id,
                agent_name,
                text,
            } => {
                assert_eq!(story_id, "s1");
                assert_eq!(agent_name, "coder-1");
                assert_eq!(text, "Let me analyze this...");
            }
            other => panic!("Expected Thinking event, got: {other:?}"),
        }
    }

    /// A `text_delta` stream event must surface as AgentEvent::Output.
    #[test]
    fn stream_event_text_delta_emits_output_event() {
        let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
        let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());

        let event = serde_json::json!({
            "type": "content_block_delta",
            "delta": {"type": "text_delta", "text": "Here is the result."}
        });

        handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);

        let received = rx.try_recv().unwrap();
        match received {
            AgentEvent::Output {
                story_id,
                agent_name,
                text,
            } => {
                assert_eq!(story_id, "s1");
                assert_eq!(agent_name, "coder-1");
                assert_eq!(text, "Here is the result.");
            }
            other => panic!("Expected Output event, got: {other:?}"),
        }
    }

    /// Tool-argument deltas (`input_json_delta`) carry partial JSON for tool
    /// calls, not user-visible content, and must be silently dropped.
    #[test]
    fn stream_event_input_json_delta_ignored() {
        let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
        let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());

        let event = serde_json::json!({
            "type": "content_block_delta",
            "delta": {"type": "input_json_delta", "partial_json": "{\"file\":"}
        });

        handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);

        // No event should be emitted for tool argument deltas
        assert!(rx.try_recv().is_err());
    }

    /// Stream events other than `content_block_delta` (e.g. message_start)
    /// carry no renderable content and must not produce an AgentEvent.
    #[test]
    fn stream_event_non_delta_type_ignored() {
        let (tx, mut rx) = broadcast::channel::<AgentEvent>(64);
        let event_log: Mutex<Vec<AgentEvent>> = Mutex::new(Vec::new());

        let event = serde_json::json!({
            "type": "message_start",
            "message": {"role": "assistant"}
        });

        handle_agent_stream_event(&event, "s1", "coder-1", &tx, &event_log, None);

        assert!(rx.try_recv().is_err());
    }
}
|
||||
@@ -1,66 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use portable_pty::ChildKiller;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
|
||||
use super::{AgentEvent, AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
|
||||
|
||||
/// Agent runtime that spawns the `claude` CLI in a PTY and streams JSON events.
///
/// This is the default runtime (`runtime = "claude-code"` in project.toml).
/// It wraps the existing PTY-based execution logic, preserving all streaming,
/// token tracking, and inactivity timeout behaviour.
pub struct ClaudeCodeRuntime {
    // Shared registry of kill handles, keyed per story/agent, so spawned CLI
    // children can be terminated externally (e.g. on server shutdown).
    child_killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
}
|
||||
|
||||
impl ClaudeCodeRuntime {
    /// Create a runtime sharing the pool's `child_killers` map, so children
    /// spawned by this runtime can be killed from outside (see `stop`).
    pub fn new(
        child_killers: Arc<Mutex<HashMap<String, Box<dyn ChildKiller + Send + Sync>>>>,
    ) -> Self {
        Self { child_killers }
    }
}
|
||||
|
||||
impl AgentRuntime for ClaudeCodeRuntime {
    /// Run the agent by delegating to the PTY streaming implementation in
    /// `pty.rs`, forwarding all context (command, args, prompt, cwd, timeout)
    /// unchanged and translating its result into a `RuntimeResult`.
    ///
    /// # Errors
    /// Propagates any error string from `run_agent_pty_streaming` (spawn
    /// failures, PTY errors, inactivity timeout).
    async fn start(
        &self,
        ctx: RuntimeContext,
        tx: broadcast::Sender<AgentEvent>,
        event_log: Arc<Mutex<Vec<AgentEvent>>>,
        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
    ) -> Result<RuntimeResult, String> {
        let pty_result = super::super::pty::run_agent_pty_streaming(
            &ctx.story_id,
            &ctx.agent_name,
            &ctx.command,
            &ctx.args,
            &ctx.prompt,
            &ctx.cwd,
            &tx,
            &event_log,
            log_writer,
            ctx.inactivity_timeout_secs,
            // Clone the Arc (cheap refcount bump) so pty.rs can register the
            // child's kill handle in the shared map.
            Arc::clone(&self.child_killers),
        )
        .await?;

        Ok(RuntimeResult {
            session_id: pty_result.session_id,
            token_usage: pty_result.token_usage,
        })
    }

    fn stop(&self) {
        // Stopping is handled externally by the pool via kill_child_for_key().
        // The ChildKillerGuard in pty.rs deregisters automatically on process exit.
    }

    fn get_status(&self) -> RuntimeStatus {
        // Lifecycle status is tracked by the pool; the runtime itself is stateless.
        RuntimeStatus::Idle
    }
}
|
||||
@@ -1,809 +0,0 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
use crate::slog;
|
||||
|
||||
use super::super::{AgentEvent, TokenUsage};
|
||||
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
|
||||
|
||||
// ── Public runtime struct ────────────────────────────────────────────
|
||||
|
||||
/// Agent runtime that drives a Gemini model through the Google AI
/// `generateContent` REST API.
///
/// The runtime:
/// 1. Fetches MCP tool definitions from storkit's MCP server.
/// 2. Converts them to Gemini function-calling format.
/// 3. Sends the agent prompt + tools to the Gemini API.
/// 4. Executes any requested function calls via MCP `tools/call`.
/// 5. Loops until the model produces a text-only response or an error.
/// 6. Tracks token usage from the API response metadata.
pub struct GeminiRuntime {
    /// Whether a stop has been requested. Wrapped in an Arc so the
    /// conversation loop in `start` (which clones it) observes a `stop()`
    /// called from another task.
    cancelled: Arc<AtomicBool>,
}
|
||||
|
||||
impl GeminiRuntime {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cancelled: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentRuntime for GeminiRuntime {
    /// Drive a full Gemini conversation: send the prompt plus MCP tool
    /// declarations, execute any function calls the model requests through
    /// the MCP server, and loop until the model returns a text-only reply,
    /// the turn limit is hit, or the runtime is cancelled.
    ///
    /// # Errors
    /// Returns `Err` for a missing API key, transport/parse failures, a
    /// non-success API status, or a malformed response body. Cancellation
    /// and exceeding `max_turns` emit an Error event but return `Ok` with
    /// the usage accumulated so far.
    async fn start(
        &self,
        ctx: RuntimeContext,
        tx: broadcast::Sender<AgentEvent>,
        event_log: Arc<Mutex<Vec<AgentEvent>>>,
        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
    ) -> Result<RuntimeResult, String> {
        // Fail fast before any network work if the key is absent.
        let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| {
            "GOOGLE_AI_API_KEY environment variable is not set. \
             Set it to your Google AI API key to use the Gemini runtime."
                .to_string()
        })?;

        let model = if ctx.command.starts_with("gemini") {
            // The pool puts the model into `command` for non-CLI runtimes,
            // but also check args for a --model flag.
            ctx.command.clone()
        } else {
            // Fall back to args: look for --model <value>
            ctx.args
                .iter()
                .position(|a| a == "--model")
                .and_then(|i| ctx.args.get(i + 1))
                .cloned()
                .unwrap_or_else(|| "gemini-2.5-pro".to_string())
        };

        let mcp_port = ctx.mcp_port;
        let mcp_base = format!("http://localhost:{mcp_port}/mcp");

        let client = Client::new();
        // Clone the cancellation flag so `stop()` on self is visible here.
        let cancelled = Arc::clone(&self.cancelled);

        // Step 1: Fetch MCP tool definitions and convert to Gemini format.
        let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;

        // Step 2: Build the initial conversation contents.
        let system_instruction = build_system_instruction(&ctx);
        let mut contents: Vec<Value> = vec![json!({
            "role": "user",
            "parts": [{ "text": ctx.prompt }]
        })];

        // Running totals accumulated across every API turn.
        let mut total_usage = TokenUsage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
            total_cost_usd: 0.0,
        };

        // Small shim so every emission goes through the shared pty emit_event
        // (broadcast + in-memory log + optional on-disk writer).
        let emit = |event: AgentEvent| {
            super::super::pty::emit_event(
                event,
                &tx,
                &event_log,
                log_writer.as_ref().map(|w| w.as_ref()),
            );
        };

        emit(AgentEvent::Status {
            story_id: ctx.story_id.clone(),
            agent_name: ctx.agent_name.clone(),
            status: "running".to_string(),
        });

        // Step 3: Conversation loop.
        let mut turn = 0u32;
        let max_turns = 200; // Safety limit

        loop {
            // Cancellation is checked at the top of every turn; a stop mid-turn
            // takes effect before the next API call.
            if cancelled.load(Ordering::Relaxed) {
                emit(AgentEvent::Error {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    message: "Agent was stopped by user".to_string(),
                });
                return Ok(RuntimeResult {
                    session_id: None,
                    token_usage: Some(total_usage),
                });
            }

            turn += 1;
            if turn > max_turns {
                emit(AgentEvent::Error {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    message: format!("Exceeded maximum turns ({max_turns})"),
                });
                return Ok(RuntimeResult {
                    session_id: None,
                    token_usage: Some(total_usage),
                });
            }

            slog!("[gemini] Turn {turn} for {}:{}", ctx.story_id, ctx.agent_name);

            let request_body = build_generate_content_request(
                &system_instruction,
                &contents,
                &gemini_tools,
            );

            let url = format!(
                "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
            );

            let response = client
                .post(&url)
                .json(&request_body)
                .send()
                .await
                .map_err(|e| format!("Gemini API request failed: {e}"))?;

            // Capture the status before consuming the body: error responses
            // still carry a JSON payload we want to parse for the message.
            let status = response.status();
            let body: Value = response
                .json()
                .await
                .map_err(|e| format!("Failed to parse Gemini API response: {e}"))?;

            if !status.is_success() {
                let error_msg = body["error"]["message"]
                    .as_str()
                    .unwrap_or("Unknown API error");
                let err = format!("Gemini API error ({status}): {error_msg}");
                emit(AgentEvent::Error {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    message: err.clone(),
                });
                return Err(err);
            }

            // Accumulate token usage.
            if let Some(usage) = parse_usage_metadata(&body) {
                total_usage.input_tokens += usage.input_tokens;
                total_usage.output_tokens += usage.output_tokens;
            }

            // Extract the candidate response.
            let candidate = body["candidates"]
                .as_array()
                .and_then(|c| c.first())
                .ok_or_else(|| "No candidates in Gemini response".to_string())?;

            let parts = candidate["content"]["parts"]
                .as_array()
                .ok_or_else(|| "No parts in Gemini response candidate".to_string())?;

            // Check finish reason.
            let finish_reason = candidate["finishReason"].as_str().unwrap_or("");

            // Separate text parts and function call parts.
            let mut text_parts: Vec<String> = Vec::new();
            let mut function_calls: Vec<GeminiFunctionCall> = Vec::new();

            for part in parts {
                if let Some(text) = part["text"].as_str() {
                    text_parts.push(text.to_string());
                }
                if let Some(fc) = part.get("functionCall")
                    && let (Some(name), Some(args)) =
                        (fc["name"].as_str(), fc.get("args"))
                {
                    function_calls.push(GeminiFunctionCall {
                        name: name.to_string(),
                        args: args.clone(),
                    });
                }
            }

            // Emit any text output.
            for text in &text_parts {
                if !text.is_empty() {
                    emit(AgentEvent::Output {
                        story_id: ctx.story_id.clone(),
                        agent_name: ctx.agent_name.clone(),
                        text: text.clone(),
                    });
                }
            }

            // If no function calls, the model is done.
            if function_calls.is_empty() {
                emit(AgentEvent::Done {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    session_id: None,
                });
                return Ok(RuntimeResult {
                    session_id: None,
                    token_usage: Some(total_usage),
                });
            }

            // Add the model's response to the conversation.
            let model_parts: Vec<Value> = parts.to_vec();
            contents.push(json!({
                "role": "model",
                "parts": model_parts
            }));

            // Execute function calls via MCP and build response parts.
            let mut response_parts: Vec<Value> = Vec::new();

            for fc in &function_calls {
                // A stop request aborts remaining tool calls; the top-of-loop
                // check then surfaces the cancellation on the next turn.
                if cancelled.load(Ordering::Relaxed) {
                    break;
                }

                slog!(
                    "[gemini] Calling MCP tool '{}' for {}:{}",
                    fc.name,
                    ctx.story_id,
                    ctx.agent_name
                );

                emit(AgentEvent::Output {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    text: format!("\n[Tool call: {}]\n", fc.name),
                });

                let tool_result =
                    call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await;

                // Tool failures are fed back to the model as an error payload
                // rather than aborting the run, so it can recover or retry.
                let response_value = match &tool_result {
                    Ok(result) => {
                        emit(AgentEvent::Output {
                            story_id: ctx.story_id.clone(),
                            agent_name: ctx.agent_name.clone(),
                            text: format!(
                                "[Tool result: {} chars]\n",
                                result.len()
                            ),
                        });
                        json!({ "result": result })
                    }
                    Err(e) => {
                        emit(AgentEvent::Output {
                            story_id: ctx.story_id.clone(),
                            agent_name: ctx.agent_name.clone(),
                            text: format!("[Tool error: {e}]\n"),
                        });
                        json!({ "error": e })
                    }
                };

                response_parts.push(json!({
                    "functionResponse": {
                        "name": fc.name,
                        "response": response_value
                    }
                }));
            }

            // Add function responses to the conversation.
            contents.push(json!({
                "role": "user",
                "parts": response_parts
            }));

            // If the model indicated it's done despite having function calls,
            // respect the finish reason.
            // NOTE(review): this branch is unreachable — `function_calls.is_empty()`
            // already returned from the loop above, so the condition can never
            // hold here. Candidate for removal in a follow-up.
            if finish_reason == "STOP" && function_calls.is_empty() {
                break;
            }
        }

        // Only reachable via the (currently dead) `break` above.
        emit(AgentEvent::Done {
            story_id: ctx.story_id.clone(),
            agent_name: ctx.agent_name.clone(),
            session_id: None,
        });

        Ok(RuntimeResult {
            session_id: None,
            token_usage: Some(total_usage),
        })
    }

    /// Request cancellation; the conversation loop observes the flag at the
    /// start of each turn and between tool calls.
    fn stop(&self) {
        self.cancelled.store(true, Ordering::Relaxed);
    }

    /// Report Failed once a stop has been requested, otherwise Idle.
    fn get_status(&self) -> RuntimeStatus {
        if self.cancelled.load(Ordering::Relaxed) {
            RuntimeStatus::Failed
        } else {
            RuntimeStatus::Idle
        }
    }
}
|
||||
|
||||
// ── Internal types ───────────────────────────────────────────────────
|
||||
|
||||
/// One function call requested by the model in a response part.
struct GeminiFunctionCall {
    // MCP tool name to invoke via tools/call.
    name: String,
    // JSON arguments object exactly as the model supplied it.
    args: Value,
}
|
||||
|
||||
// ── Gemini API types (for serde) ─────────────────────────────────────
|
||||
|
||||
/// A single entry in the request's `tools[].functionDeclarations` array,
/// serialized into the Gemini `generateContent` payload.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    // OpenAPI-subset parameter schema; omitted entirely (not null) for
    // zero-argument tools, hence the skip_serializing_if.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}
|
||||
|
||||
// ── Helper functions ─────────────────────────────────────────────────
|
||||
|
||||
/// Build the system instruction content from the RuntimeContext.
|
||||
fn build_system_instruction(ctx: &RuntimeContext) -> Value {
|
||||
// Use system_prompt from args if provided via --append-system-prompt,
|
||||
// otherwise use a sensible default.
|
||||
let system_text = ctx
|
||||
.args
|
||||
.iter()
|
||||
.position(|a| a == "--append-system-prompt")
|
||||
.and_then(|i| ctx.args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
format!(
|
||||
"You are an AI coding agent working on story {}. \
|
||||
You have access to tools via function calling. \
|
||||
Use them to complete the task. \
|
||||
Work in the directory: {}",
|
||||
ctx.story_id, ctx.cwd
|
||||
)
|
||||
});
|
||||
|
||||
json!({
|
||||
"parts": [{ "text": system_text }]
|
||||
})
|
||||
}
|
||||
|
||||
/// Build the full `generateContent` request body.
|
||||
fn build_generate_content_request(
|
||||
system_instruction: &Value,
|
||||
contents: &[Value],
|
||||
gemini_tools: &[GeminiFunctionDeclaration],
|
||||
) -> Value {
|
||||
let mut body = json!({
|
||||
"system_instruction": system_instruction,
|
||||
"contents": contents,
|
||||
"generationConfig": {
|
||||
"temperature": 0.2,
|
||||
"maxOutputTokens": 65536,
|
||||
}
|
||||
});
|
||||
|
||||
if !gemini_tools.is_empty() {
|
||||
body["tools"] = json!([{
|
||||
"functionDeclarations": gemini_tools
|
||||
}]);
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
/// Fetch MCP tool definitions from storkit's MCP server and convert
/// them to Gemini function declaration format.
///
/// # Errors
/// Returns `Err` if the HTTP request fails, the response is not valid JSON,
/// or the body has no `result.tools` array. Individual malformed tools
/// (missing name) are skipped rather than failing the whole fetch.
async fn fetch_and_convert_mcp_tools(
    client: &Client,
    mcp_base: &str,
) -> Result<Vec<GeminiFunctionDeclaration>, String> {
    // JSON-RPC 2.0 tools/list request; the id is arbitrary since the call
    // is a single request/response over HTTP.
    let request = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",
        "params": {}
    });

    let response = client
        .post(mcp_base)
        .json(&request)
        .send()
        .await
        .map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;

    let body: Value = response
        .json()
        .await
        .map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;

    let tools = body["result"]["tools"]
        .as_array()
        .ok_or_else(|| "No tools array in MCP response".to_string())?;

    let mut declarations = Vec::new();

    for tool in tools {
        let name = tool["name"].as_str().unwrap_or("").to_string();
        let description = tool["description"].as_str().unwrap_or("").to_string();

        // A declaration without a name is unusable for function calling.
        if name.is_empty() {
            continue;
        }

        // Convert MCP inputSchema (JSON Schema) to Gemini parameters
        // (OpenAPI-subset schema). They are structurally compatible for
        // simple object schemas.
        let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));

        declarations.push(GeminiFunctionDeclaration {
            name,
            description,
            parameters,
        });
    }

    slog!("[gemini] Loaded {} MCP tools as function declarations", declarations.len());
    Ok(declarations)
}
|
||||
|
||||
/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
|
||||
/// OpenAPI-subset parameter schema.
|
||||
///
|
||||
/// Gemini function calling expects parameters in OpenAPI format, which
|
||||
/// is structurally similar to JSON Schema for simple object types.
|
||||
/// We strip unsupported fields and ensure the type is "object".
|
||||
fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
|
||||
let schema = schema?;
|
||||
|
||||
// If the schema has no properties (empty tool), return None.
|
||||
let properties = schema.get("properties")?;
|
||||
if properties.as_object().is_some_and(|p| p.is_empty()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut result = json!({
|
||||
"type": "object",
|
||||
"properties": clean_schema_properties(properties),
|
||||
});
|
||||
|
||||
// Preserve required fields if present.
|
||||
if let Some(required) = schema.get("required") {
|
||||
result["required"] = required.clone();
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Recursively clean schema properties to be Gemini-compatible.
|
||||
/// Removes unsupported JSON Schema keywords.
|
||||
fn clean_schema_properties(properties: &Value) -> Value {
|
||||
let Some(obj) = properties.as_object() else {
|
||||
return properties.clone();
|
||||
};
|
||||
|
||||
let mut cleaned = serde_json::Map::new();
|
||||
for (key, value) in obj {
|
||||
let mut prop = value.clone();
|
||||
// Remove JSON Schema keywords not supported by Gemini
|
||||
if let Some(p) = prop.as_object_mut() {
|
||||
p.remove("$schema");
|
||||
p.remove("additionalProperties");
|
||||
|
||||
// Recursively clean nested object properties
|
||||
if let Some(nested_props) = p.get("properties").cloned() {
|
||||
p.insert(
|
||||
"properties".to_string(),
|
||||
clean_schema_properties(&nested_props),
|
||||
);
|
||||
}
|
||||
|
||||
// Clean items schema for arrays
|
||||
if let Some(items) = p.get("items").cloned()
|
||||
&& let Some(items_obj) = items.as_object()
|
||||
{
|
||||
let mut cleaned_items = items_obj.clone();
|
||||
cleaned_items.remove("$schema");
|
||||
cleaned_items.remove("additionalProperties");
|
||||
p.insert("items".to_string(), Value::Object(cleaned_items));
|
||||
}
|
||||
}
|
||||
cleaned.insert(key.clone(), prop);
|
||||
}
|
||||
Value::Object(cleaned)
|
||||
}
|
||||
|
||||
/// Call an MCP tool via storkit's MCP server.
///
/// Sends a JSON-RPC `tools/call` request and flattens the returned content
/// into a single string for the model.
///
/// # Errors
/// Returns `Err` if the HTTP request fails, the response is not valid JSON,
/// or the JSON-RPC body contains an `error` object.
async fn call_mcp_tool(
    client: &Client,
    mcp_base: &str,
    tool_name: &str,
    args: &Value,
) -> Result<String, String> {
    let request = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": {
            "name": tool_name,
            "arguments": args
        }
    });

    let response = client
        .post(mcp_base)
        .json(&request)
        .send()
        .await
        .map_err(|e| format!("MCP tool call failed: {e}"))?;

    let body: Value = response
        .json()
        .await
        .map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;

    // A JSON-RPC level error takes precedence over any result payload.
    if let Some(error) = body.get("error") {
        let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
        return Err(format!("MCP tool '{tool_name}' error: {msg}"));
    }

    // MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
    let content = &body["result"]["content"];
    if let Some(arr) = content.as_array() {
        // Join all text entries; non-text content items are skipped.
        let texts: Vec<&str> = arr
            .iter()
            .filter_map(|c| c["text"].as_str())
            .collect();
        if !texts.is_empty() {
            return Ok(texts.join("\n"));
        }
    }

    // Fall back to serializing the entire result.
    Ok(body["result"].to_string())
}
|
||||
|
||||
/// Parse token usage metadata from a Gemini API response.
|
||||
fn parse_usage_metadata(response: &Value) -> Option<TokenUsage> {
|
||||
let metadata = response.get("usageMetadata")?;
|
||||
Some(TokenUsage {
|
||||
input_tokens: metadata
|
||||
.get("promptTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
output_tokens: metadata
|
||||
.get("candidatesTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
// Gemini doesn't have cache token fields, but we keep the struct uniform.
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
// Google AI API doesn't report cost; leave at 0.
|
||||
total_cost_usd: 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_simple_object() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"story_id": {
|
||||
"type": "string",
|
||||
"description": "Story identifier"
|
||||
}
|
||||
},
|
||||
"required": ["story_id"]
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||
assert_eq!(result["type"], "object");
|
||||
assert!(result["properties"]["story_id"].is_object());
|
||||
assert_eq!(result["required"][0], "story_id");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_empty_properties_returns_none() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
});
|
||||
|
||||
assert!(convert_mcp_schema_to_gemini(Some(&schema)).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_none_returns_none() {
|
||||
assert!(convert_mcp_schema_to_gemini(None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_strips_additional_properties() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"additionalProperties": false,
|
||||
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||
let name_prop = &result["properties"]["name"];
|
||||
assert!(name_prop.get("additionalProperties").is_none());
|
||||
assert!(name_prop.get("$schema").is_none());
|
||||
assert_eq!(name_prop["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_with_nested_objects() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": { "type": "string" }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||
assert!(result["properties"]["config"]["properties"]["key"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_with_array_items() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||
let items_schema = &result["properties"]["items"]["items"];
|
||||
assert!(items_schema.get("additionalProperties").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_system_instruction_uses_args() {
|
||||
let ctx = RuntimeContext {
|
||||
story_id: "42_story_test".to_string(),
|
||||
agent_name: "coder-1".to_string(),
|
||||
command: "gemini-2.5-pro".to_string(),
|
||||
args: vec![
|
||||
"--append-system-prompt".to_string(),
|
||||
"Custom system prompt".to_string(),
|
||||
],
|
||||
prompt: "Do the thing".to_string(),
|
||||
cwd: "/tmp/wt".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
|
||||
let instruction = build_system_instruction(&ctx);
|
||||
assert_eq!(instruction["parts"][0]["text"], "Custom system prompt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_system_instruction_default() {
|
||||
let ctx = RuntimeContext {
|
||||
story_id: "42_story_test".to_string(),
|
||||
agent_name: "coder-1".to_string(),
|
||||
command: "gemini-2.5-pro".to_string(),
|
||||
args: vec![],
|
||||
prompt: "Do the thing".to_string(),
|
||||
cwd: "/tmp/wt".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
|
||||
let instruction = build_system_instruction(&ctx);
|
||||
let text = instruction["parts"][0]["text"].as_str().unwrap();
|
||||
assert!(text.contains("42_story_test"));
|
||||
assert!(text.contains("/tmp/wt"));
|
||||
}
|
||||
|
||||
/// A non-empty tool list is embedded under `tools[0].functionDeclarations`.
#[test]
fn build_generate_content_request_includes_tools() {
    let sys = json!({"parts": [{"text": "system"}]});
    let convo = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
    let decls = vec![GeminiFunctionDeclaration {
        name: "my_tool".to_string(),
        description: "A tool".to_string(),
        parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})),
    }];

    let request = build_generate_content_request(&sys, &convo, &decls);
    let fn_decls = &request["tools"][0]["functionDeclarations"];
    assert!(fn_decls.is_array());
    assert_eq!(fn_decls[0]["name"], "my_tool");
}
|
||||
|
||||
/// With no tool declarations, the request omits the `tools` key entirely.
#[test]
fn build_generate_content_request_no_tools() {
    let sys = json!({"parts": [{"text": "system"}]});
    let convo = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
    let decls: Vec<GeminiFunctionDeclaration> = Vec::new();

    let request = build_generate_content_request(&sys, &convo, &decls);
    assert!(request.get("tools").is_none());
}
|
||||
|
||||
/// A well-formed `usageMetadata` maps prompt/candidates counts onto
/// input/output tokens; cache and cost fields stay zeroed.
#[test]
fn parse_usage_metadata_valid() {
    let body = json!({
        "usageMetadata": {
            "promptTokenCount": 100,
            "candidatesTokenCount": 50,
            "totalTokenCount": 150
        }
    });

    let parsed = parse_usage_metadata(&body).unwrap();
    assert_eq!(parsed.input_tokens, 100);
    assert_eq!(parsed.output_tokens, 50);
    assert_eq!(parsed.cache_creation_input_tokens, 0);
    assert_eq!(parsed.total_cost_usd, 0.0);
}
|
||||
|
||||
/// A response without `usageMetadata` yields no usage at all.
#[test]
fn parse_usage_metadata_missing() {
    let body = json!({"candidates": []});
    assert!(parse_usage_metadata(&body).is_none());
}
|
||||
|
||||
/// `stop()` flips the cancelled flag, which `get_status()` reports as
/// `Failed` (it was `Idle` before).
#[test]
fn gemini_runtime_stop_sets_cancelled() {
    let rt = GeminiRuntime::new();
    assert_eq!(rt.get_status(), RuntimeStatus::Idle);

    rt.stop();
    assert_eq!(rt.get_status(), RuntimeStatus::Failed);
}
|
||||
|
||||
// The model extraction logic lives inside start(); this exercises the
// condition it relies on: a command beginning with "gemini" is treated as
// a model name.
#[test]
fn model_extraction_from_command() {
    let context = RuntimeContext {
        command: "gemini-2.5-pro".to_string(),
        story_id: "1".to_string(),
        agent_name: "coder".to_string(),
        args: Vec::new(),
        prompt: "test".to_string(),
        cwd: "/tmp".to_string(),
        inactivity_timeout_secs: 300,
        mcp_port: 3001,
    };
    assert!(context.command.starts_with("gemini"));
}
|
||||
}
|
||||
@@ -1,159 +0,0 @@
|
||||
mod claude_code;
|
||||
mod gemini;
|
||||
mod openai;
|
||||
|
||||
pub use claude_code::ClaudeCodeRuntime;
|
||||
pub use gemini::GeminiRuntime;
|
||||
pub use openai::OpenAiRuntime;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
|
||||
use super::{AgentEvent, TokenUsage};
|
||||
|
||||
/// Context passed to a runtime when launching an agent session.
pub struct RuntimeContext {
    /// Identifier of the story this session works on (e.g. "42_story_foo").
    pub story_id: String,
    /// Name of this agent (e.g. "coder-1").
    pub agent_name: String,
    /// Executable name for CLI runtimes; for API runtimes the pool puts the
    /// model name here instead (e.g. "gpt-4o", "gemini-2.5-pro").
    pub command: String,
    /// Extra arguments. API runtimes mine these for flags such as
    /// `--model <name>` and `--append-system-prompt <text>`.
    pub args: Vec<String>,
    /// Initial user prompt handed to the agent.
    pub prompt: String,
    /// Working directory for the session.
    pub cwd: String,
    /// Inactivity timeout in seconds for the session.
    pub inactivity_timeout_secs: u64,
    /// Port of the storkit MCP server, used by API-based runtimes (Gemini, OpenAI)
    /// to call back for tool execution.
    pub mcp_port: u16,
}
|
||||
|
||||
/// Result returned by a runtime after the agent session completes.
pub struct RuntimeResult {
    /// Backend session identifier, when the backend provides one.
    pub session_id: Option<String>,
    /// Aggregated token usage for the session, when available.
    pub token_usage: Option<TokenUsage>,
}
|
||||
|
||||
/// Runtime status reported by the backend.
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)]
pub enum RuntimeStatus {
    /// No session active.
    Idle,
    /// A session is in progress.
    Running,
    /// The session finished successfully.
    Completed,
    /// The session failed or was stopped (API runtimes report this once a
    /// stop has been requested).
    Failed,
}
|
||||
|
||||
/// Abstraction over different agent execution backends.
///
/// Implementations:
/// - [`ClaudeCodeRuntime`]: spawns the `claude` CLI via a PTY (default, `runtime = "claude-code"`)
/// - [`GeminiRuntime`] and [`OpenAiRuntime`]: API-based runtimes that call
///   tools back through storkit's MCP server (see [`RuntimeContext::mcp_port`])
#[allow(dead_code)]
pub trait AgentRuntime: Send + Sync {
    /// Start the agent and drive it to completion, streaming events through
    /// the provided broadcast sender and event log.
    ///
    /// Returns when the agent session finishes (success or error).
    async fn start(
        &self,
        ctx: RuntimeContext,
        tx: broadcast::Sender<AgentEvent>,
        event_log: Arc<Mutex<Vec<AgentEvent>>>,
        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
    ) -> Result<RuntimeResult, String>;

    /// Stop the running agent.
    fn stop(&self);

    /// Get the current runtime status.
    fn get_status(&self) -> RuntimeStatus;

    /// Return any events buffered outside the broadcast channel.
    ///
    /// PTY-based runtimes stream directly to the broadcast channel; this
    /// returns empty by default. API-based runtimes may buffer events here.
    fn stream_events(&self) -> Vec<AgentEvent> {
        vec![]
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `RuntimeContext` field round-trips through construction.
    #[test]
    fn runtime_context_fields() {
        let ctx = RuntimeContext {
            story_id: "42_story_foo".to_string(),
            agent_name: "coder-1".to_string(),
            command: "claude".to_string(),
            args: vec!["--model".to_string(), "sonnet".to_string()],
            prompt: "Do the thing".to_string(),
            cwd: "/tmp/wt".to_string(),
            inactivity_timeout_secs: 300,
            mcp_port: 3001,
        };
        assert_eq!(ctx.story_id, "42_story_foo");
        assert_eq!(ctx.agent_name, "coder-1");
        assert_eq!(ctx.command, "claude");
        assert_eq!(ctx.args.len(), 2);
        assert_eq!(ctx.prompt, "Do the thing");
        assert_eq!(ctx.cwd, "/tmp/wt");
        assert_eq!(ctx.inactivity_timeout_secs, 300);
        assert_eq!(ctx.mcp_port, 3001);
    }

    /// A populated `RuntimeResult` exposes its session id and usage.
    #[test]
    fn runtime_result_fields() {
        let result = RuntimeResult {
            session_id: Some("sess-123".to_string()),
            token_usage: Some(TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                total_cost_usd: 0.01,
            }),
        };
        assert_eq!(result.session_id, Some("sess-123".to_string()));
        assert!(result.token_usage.is_some());
        let usage = result.token_usage.unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_cost_usd, 0.01);
    }

    /// Both optional fields may legitimately be absent.
    #[test]
    fn runtime_result_no_usage() {
        let result = RuntimeResult {
            session_id: None,
            token_usage: None,
        };
        assert!(result.session_id.is_none());
        assert!(result.token_usage.is_none());
    }

    /// Sanity-check the derived `PartialEq` on `RuntimeStatus`.
    #[test]
    fn runtime_status_variants() {
        assert_eq!(RuntimeStatus::Idle, RuntimeStatus::Idle);
        assert_ne!(RuntimeStatus::Running, RuntimeStatus::Completed);
        assert_ne!(RuntimeStatus::Failed, RuntimeStatus::Idle);
    }

    /// A freshly constructed PTY runtime reports `Idle`.
    #[test]
    fn claude_code_runtime_get_status_returns_idle() {
        use std::collections::HashMap;
        let killers = Arc::new(Mutex::new(HashMap::new()));
        let runtime = ClaudeCodeRuntime::new(killers);
        assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
    }

    /// PTY runtimes rely on the trait's default `stream_events` (empty).
    #[test]
    fn claude_code_runtime_stream_events_empty() {
        use std::collections::HashMap;
        let killers = Arc::new(Mutex::new(HashMap::new()));
        let runtime = ClaudeCodeRuntime::new(killers);
        assert!(runtime.stream_events().is_empty());
    }
}
|
||||
@@ -1,704 +0,0 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
use crate::slog;
|
||||
|
||||
use super::super::{AgentEvent, TokenUsage};
|
||||
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
|
||||
|
||||
// ── Public runtime struct ────────────────────────────────────────────

/// Agent runtime that drives an OpenAI model (GPT-4o, o3, etc.) through
/// the OpenAI Chat Completions API.
///
/// The runtime:
/// 1. Fetches MCP tool definitions from storkit's MCP server.
/// 2. Converts them to OpenAI function-calling format.
/// 3. Sends the agent prompt + tools to the Chat Completions API.
/// 4. Executes any requested tool calls via MCP `tools/call`.
/// 5. Loops until the model produces a response with no tool calls.
/// 6. Tracks token usage from the API response.
pub struct OpenAiRuntime {
    /// Whether a stop has been requested; checked between turns and
    /// between individual tool calls.
    cancelled: Arc<AtomicBool>,
}
|
||||
|
||||
impl OpenAiRuntime {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cancelled: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentRuntime for OpenAiRuntime {
    /// Drive a full OpenAI agent session: fetch tools, then loop
    /// request → tool calls → tool results until the model stops requesting
    /// tools, the turn cap is hit, or a stop is requested.
    async fn start(
        &self,
        ctx: RuntimeContext,
        tx: broadcast::Sender<AgentEvent>,
        event_log: Arc<Mutex<Vec<AgentEvent>>>,
        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
    ) -> Result<RuntimeResult, String> {
        let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
            "OPENAI_API_KEY environment variable is not set. \
             Set it to your OpenAI API key to use the OpenAI runtime."
                .to_string()
        })?;

        let model = if ctx.command.starts_with("gpt") || ctx.command.starts_with("o") {
            // The pool puts the model into `command` for non-CLI runtimes.
            ctx.command.clone()
        } else {
            // Fall back to args: look for --model <value>
            ctx.args
                .iter()
                .position(|a| a == "--model")
                .and_then(|i| ctx.args.get(i + 1))
                .cloned()
                .unwrap_or_else(|| "gpt-4o".to_string())
        };

        let mcp_port = ctx.mcp_port;
        let mcp_base = format!("http://localhost:{mcp_port}/mcp");

        let client = Client::new();
        let cancelled = Arc::clone(&self.cancelled);

        // Step 1: Fetch MCP tool definitions and convert to OpenAI format.
        let openai_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;

        // Step 2: Build the initial conversation messages.
        let system_text = build_system_text(&ctx);
        let mut messages: Vec<Value> = vec![
            json!({ "role": "system", "content": system_text }),
            json!({ "role": "user", "content": ctx.prompt }),
        ];

        let mut total_usage = TokenUsage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
            total_cost_usd: 0.0,
        };

        // Fan each event out to the broadcast channel, the in-memory event
        // log, and the optional on-disk log writer in one call.
        let emit = |event: AgentEvent| {
            super::super::pty::emit_event(
                event,
                &tx,
                &event_log,
                log_writer.as_ref().map(|w| w.as_ref()),
            );
        };

        emit(AgentEvent::Status {
            story_id: ctx.story_id.clone(),
            agent_name: ctx.agent_name.clone(),
            status: "running".to_string(),
        });

        // Step 3: Conversation loop.
        let mut turn = 0u32;
        let max_turns = 200; // Safety limit

        loop {
            // Check for user-requested stop before issuing another request.
            if cancelled.load(Ordering::Relaxed) {
                emit(AgentEvent::Error {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    message: "Agent was stopped by user".to_string(),
                });
                return Ok(RuntimeResult {
                    session_id: None,
                    token_usage: Some(total_usage),
                });
            }

            turn += 1;
            if turn > max_turns {
                emit(AgentEvent::Error {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    message: format!("Exceeded maximum turns ({max_turns})"),
                });
                return Ok(RuntimeResult {
                    session_id: None,
                    token_usage: Some(total_usage),
                });
            }

            slog!(
                "[openai] Turn {turn} for {}:{}",
                ctx.story_id,
                ctx.agent_name
            );

            let mut request_body = json!({
                "model": model,
                "messages": messages,
                "temperature": 0.2,
            });

            if !openai_tools.is_empty() {
                request_body["tools"] = json!(openai_tools);
            }

            let response = client
                .post("https://api.openai.com/v1/chat/completions")
                .bearer_auth(&api_key)
                .json(&request_body)
                .send()
                .await
                .map_err(|e| format!("OpenAI API request failed: {e}"))?;

            // Capture the HTTP status before consuming the body.
            let status = response.status();
            let body: Value = response
                .json()
                .await
                .map_err(|e| format!("Failed to parse OpenAI API response: {e}"))?;

            if !status.is_success() {
                let error_msg = body["error"]["message"]
                    .as_str()
                    .unwrap_or("Unknown API error");
                let err = format!("OpenAI API error ({status}): {error_msg}");
                emit(AgentEvent::Error {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    message: err.clone(),
                });
                return Err(err);
            }

            // Accumulate token usage.
            if let Some(usage) = parse_usage(&body) {
                total_usage.input_tokens += usage.input_tokens;
                total_usage.output_tokens += usage.output_tokens;
            }

            // Extract the first choice.
            let choice = body["choices"]
                .as_array()
                .and_then(|c| c.first())
                .ok_or_else(|| "No choices in OpenAI response".to_string())?;

            let message = &choice["message"];
            let content = message["content"].as_str().unwrap_or("");

            // Emit any text content.
            if !content.is_empty() {
                emit(AgentEvent::Output {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    text: content.to_string(),
                });
            }

            // Check for tool calls.
            let tool_calls = message["tool_calls"].as_array();

            if tool_calls.is_none() || tool_calls.is_some_and(|tc| tc.is_empty()) {
                // No tool calls — model is done.
                emit(AgentEvent::Done {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    session_id: None,
                });
                return Ok(RuntimeResult {
                    session_id: None,
                    token_usage: Some(total_usage),
                });
            }

            let tool_calls = tool_calls.unwrap();

            // Add the assistant message (with tool_calls) to the conversation.
            messages.push(message.clone());

            // Execute each tool call via MCP and add results.
            for tc in tool_calls {
                if cancelled.load(Ordering::Relaxed) {
                    break;
                }

                let call_id = tc["id"].as_str().unwrap_or("");
                let function = &tc["function"];
                let tool_name = function["name"].as_str().unwrap_or("");
                let arguments_str = function["arguments"].as_str().unwrap_or("{}");

                // Tool arguments arrive as a JSON string; malformed JSON
                // degrades to an empty object rather than aborting the run.
                let args: Value = serde_json::from_str(arguments_str).unwrap_or(json!({}));

                slog!(
                    "[openai] Calling MCP tool '{}' for {}:{}",
                    tool_name,
                    ctx.story_id,
                    ctx.agent_name
                );

                emit(AgentEvent::Output {
                    story_id: ctx.story_id.clone(),
                    agent_name: ctx.agent_name.clone(),
                    text: format!("\n[Tool call: {tool_name}]\n"),
                });

                let tool_result = call_mcp_tool(&client, &mcp_base, tool_name, &args).await;

                let result_content = match &tool_result {
                    Ok(result) => {
                        emit(AgentEvent::Output {
                            story_id: ctx.story_id.clone(),
                            agent_name: ctx.agent_name.clone(),
                            text: format!("[Tool result: {} chars]\n", result.len()),
                        });
                        result.clone()
                    }
                    Err(e) => {
                        emit(AgentEvent::Output {
                            story_id: ctx.story_id.clone(),
                            agent_name: ctx.agent_name.clone(),
                            text: format!("[Tool error: {e}]\n"),
                        });
                        format!("Error: {e}")
                    }
                };

                // OpenAI expects tool results as role=tool messages with
                // the matching tool_call_id.
                messages.push(json!({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": result_content,
                }));
            }
        }
    }

    /// Request a stop; the conversation loop notices it at the next check.
    fn stop(&self) {
        self.cancelled.store(true, Ordering::Relaxed);
    }

    /// `Failed` once a stop has been requested, `Idle` otherwise.
    fn get_status(&self) -> RuntimeStatus {
        if self.cancelled.load(Ordering::Relaxed) {
            RuntimeStatus::Failed
        } else {
            RuntimeStatus::Idle
        }
    }
}
|
||||
|
||||
// ── Helper functions ─────────────────────────────────────────────────
|
||||
|
||||
/// Build the system message text from the RuntimeContext.
|
||||
fn build_system_text(ctx: &RuntimeContext) -> String {
|
||||
ctx.args
|
||||
.iter()
|
||||
.position(|a| a == "--append-system-prompt")
|
||||
.and_then(|i| ctx.args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
format!(
|
||||
"You are an AI coding agent working on story {}. \
|
||||
You have access to tools via function calling. \
|
||||
Use them to complete the task. \
|
||||
Work in the directory: {}",
|
||||
ctx.story_id, ctx.cwd
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch MCP tool definitions from storkit's MCP server and convert
|
||||
/// them to OpenAI function-calling format.
|
||||
async fn fetch_and_convert_mcp_tools(
|
||||
client: &Client,
|
||||
mcp_base: &str,
|
||||
) -> Result<Vec<Value>, String> {
|
||||
let request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(mcp_base)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
|
||||
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
|
||||
|
||||
let tools = body["result"]["tools"]
|
||||
.as_array()
|
||||
.ok_or_else(|| "No tools array in MCP response".to_string())?;
|
||||
|
||||
let mut openai_tools = Vec::new();
|
||||
|
||||
for tool in tools {
|
||||
let name = tool["name"].as_str().unwrap_or("").to_string();
|
||||
let description = tool["description"].as_str().unwrap_or("").to_string();
|
||||
|
||||
if name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// OpenAI function calling uses JSON Schema natively for parameters,
|
||||
// so the MCP inputSchema can be used with minimal cleanup.
|
||||
let parameters = convert_mcp_schema_to_openai(tool.get("inputSchema"));
|
||||
|
||||
openai_tools.push(json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters.unwrap_or_else(|| json!({"type": "object", "properties": {}})),
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
slog!(
|
||||
"[openai] Loaded {} MCP tools as function definitions",
|
||||
openai_tools.len()
|
||||
);
|
||||
Ok(openai_tools)
|
||||
}
|
||||
|
||||
/// Convert an MCP inputSchema (JSON Schema) to OpenAI-compatible
|
||||
/// function parameters.
|
||||
///
|
||||
/// OpenAI uses JSON Schema natively, so less transformation is needed
|
||||
/// compared to Gemini. We still strip `$schema` to keep payloads clean.
|
||||
fn convert_mcp_schema_to_openai(schema: Option<&Value>) -> Option<Value> {
|
||||
let schema = schema?;
|
||||
|
||||
let mut result = json!({
|
||||
"type": "object",
|
||||
});
|
||||
|
||||
if let Some(properties) = schema.get("properties") {
|
||||
result["properties"] = clean_schema_properties(properties);
|
||||
} else {
|
||||
result["properties"] = json!({});
|
||||
}
|
||||
|
||||
if let Some(required) = schema.get("required") {
|
||||
result["required"] = required.clone();
|
||||
}
|
||||
|
||||
// OpenAI recommends additionalProperties: false for strict mode.
|
||||
result["additionalProperties"] = json!(false);
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Recursively clean schema properties, removing unsupported keywords.
|
||||
fn clean_schema_properties(properties: &Value) -> Value {
|
||||
let Some(obj) = properties.as_object() else {
|
||||
return properties.clone();
|
||||
};
|
||||
|
||||
let mut cleaned = serde_json::Map::new();
|
||||
for (key, value) in obj {
|
||||
let mut prop = value.clone();
|
||||
if let Some(p) = prop.as_object_mut() {
|
||||
p.remove("$schema");
|
||||
|
||||
// Recursively clean nested object properties.
|
||||
if let Some(nested_props) = p.get("properties").cloned() {
|
||||
p.insert(
|
||||
"properties".to_string(),
|
||||
clean_schema_properties(&nested_props),
|
||||
);
|
||||
}
|
||||
|
||||
// Clean items schema for arrays.
|
||||
if let Some(items) = p.get("items").cloned()
|
||||
&& let Some(items_obj) = items.as_object()
|
||||
{
|
||||
let mut cleaned_items = items_obj.clone();
|
||||
cleaned_items.remove("$schema");
|
||||
p.insert("items".to_string(), Value::Object(cleaned_items));
|
||||
}
|
||||
}
|
||||
cleaned.insert(key.clone(), prop);
|
||||
}
|
||||
Value::Object(cleaned)
|
||||
}
|
||||
|
||||
/// Call an MCP tool via storkit's MCP server.
|
||||
async fn call_mcp_tool(
|
||||
client: &Client,
|
||||
mcp_base: &str,
|
||||
tool_name: &str,
|
||||
args: &Value,
|
||||
) -> Result<String, String> {
|
||||
let request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": args
|
||||
}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(mcp_base)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("MCP tool call failed: {e}"))?;
|
||||
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
|
||||
|
||||
if let Some(error) = body.get("error") {
|
||||
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
|
||||
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
|
||||
}
|
||||
|
||||
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
|
||||
let content = &body["result"]["content"];
|
||||
if let Some(arr) = content.as_array() {
|
||||
let texts: Vec<&str> = arr
|
||||
.iter()
|
||||
.filter_map(|c| c["text"].as_str())
|
||||
.collect();
|
||||
if !texts.is_empty() {
|
||||
return Ok(texts.join("\n"));
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serializing the entire result.
|
||||
Ok(body["result"].to_string())
|
||||
}
|
||||
|
||||
/// Parse token usage from an OpenAI API response.
|
||||
fn parse_usage(response: &Value) -> Option<TokenUsage> {
|
||||
let usage = response.get("usage")?;
|
||||
Some(TokenUsage {
|
||||
input_tokens: usage
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
output_tokens: usage
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
// OpenAI API doesn't report cost directly; leave at 0.
|
||||
total_cost_usd: 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// A simple object schema keeps type/properties/required and gains
    /// `additionalProperties: false`.
    #[test]
    fn convert_mcp_schema_simple_object() {
        let schema = json!({
            "type": "object",
            "properties": {
                "story_id": {
                    "type": "string",
                    "description": "Story identifier"
                }
            },
            "required": ["story_id"]
        });

        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
        assert_eq!(result["type"], "object");
        assert!(result["properties"]["story_id"].is_object());
        assert_eq!(result["required"][0], "story_id");
        assert_eq!(result["additionalProperties"], false);
    }

    /// Empty properties stay an empty object.
    #[test]
    fn convert_mcp_schema_empty_properties() {
        let schema = json!({
            "type": "object",
            "properties": {}
        });

        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
        assert_eq!(result["type"], "object");
        assert!(result["properties"].as_object().unwrap().is_empty());
    }

    /// No input schema means no parameters.
    #[test]
    fn convert_mcp_schema_none_returns_none() {
        assert!(convert_mcp_schema_to_openai(None).is_none());
    }

    /// `$schema` is removed from properties, other keys survive.
    #[test]
    fn convert_mcp_schema_strips_dollar_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "$schema": "http://json-schema.org/draft-07/schema#"
                }
            }
        });

        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
        let name_prop = &result["properties"]["name"];
        assert!(name_prop.get("$schema").is_none());
        assert_eq!(name_prop["type"], "string");
    }

    /// Nested object properties are preserved through the cleanup.
    #[test]
    fn convert_mcp_schema_with_nested_objects() {
        let schema = json!({
            "type": "object",
            "properties": {
                "config": {
                    "type": "object",
                    "properties": {
                        "key": { "type": "string" }
                    }
                }
            }
        });

        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
        assert!(result["properties"]["config"]["properties"]["key"].is_object());
    }

    /// `$schema` is removed from array `items` schemas.
    #[test]
    fn convert_mcp_schema_with_array_items() {
        let schema = json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" }
                        },
                        "$schema": "http://json-schema.org/draft-07/schema#"
                    }
                }
            }
        });

        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
        let items_schema = &result["properties"]["items"]["items"];
        assert!(items_schema.get("$schema").is_none());
    }

    /// `--append-system-prompt` overrides the default system text.
    #[test]
    fn build_system_text_uses_args() {
        let ctx = RuntimeContext {
            story_id: "42_story_test".to_string(),
            agent_name: "coder-1".to_string(),
            command: "gpt-4o".to_string(),
            args: vec![
                "--append-system-prompt".to_string(),
                "Custom system prompt".to_string(),
            ],
            prompt: "Do the thing".to_string(),
            cwd: "/tmp/wt".to_string(),
            inactivity_timeout_secs: 300,
            mcp_port: 3001,
        };

        assert_eq!(build_system_text(&ctx), "Custom system prompt");
    }

    /// The default system text mentions the story id and working directory.
    #[test]
    fn build_system_text_default() {
        let ctx = RuntimeContext {
            story_id: "42_story_test".to_string(),
            agent_name: "coder-1".to_string(),
            command: "gpt-4o".to_string(),
            args: vec![],
            prompt: "Do the thing".to_string(),
            cwd: "/tmp/wt".to_string(),
            inactivity_timeout_secs: 300,
            mcp_port: 3001,
        };

        let text = build_system_text(&ctx);
        assert!(text.contains("42_story_test"));
        assert!(text.contains("/tmp/wt"));
    }

    /// prompt/completion counts map onto input/output tokens.
    #[test]
    fn parse_usage_valid() {
        let response = json!({
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let usage = parse_usage(&response).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_creation_input_tokens, 0);
        assert_eq!(usage.total_cost_usd, 0.0);
    }

    /// No `usage` object means no usage at all.
    #[test]
    fn parse_usage_missing() {
        let response = json!({"choices": []});
        assert!(parse_usage(&response).is_none());
    }

    /// stop() flips the cancelled flag, reflected by get_status().
    #[test]
    fn openai_runtime_stop_sets_cancelled() {
        let runtime = OpenAiRuntime::new();
        assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
        runtime.stop();
        assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
    }

    /// Commands starting with "gpt" are treated as model names by start().
    #[test]
    fn model_extraction_from_command_gpt() {
        let ctx = RuntimeContext {
            story_id: "1".to_string(),
            agent_name: "coder".to_string(),
            command: "gpt-4o".to_string(),
            args: vec![],
            prompt: "test".to_string(),
            cwd: "/tmp".to_string(),
            inactivity_timeout_secs: 300,
            mcp_port: 3001,
        };
        assert!(ctx.command.starts_with("gpt"));
    }

    /// Commands starting with "o" (o3 etc.) are also treated as models.
    #[test]
    fn model_extraction_from_command_o3() {
        let ctx = RuntimeContext {
            story_id: "1".to_string(),
            agent_name: "coder".to_string(),
            command: "o3".to_string(),
            args: vec![],
            prompt: "test".to_string(),
            cwd: "/tmp".to_string(),
            inactivity_timeout_secs: 300,
            mcp_port: 3001,
        };
        assert!(ctx.command.starts_with("o"));
    }
}
|
||||
@@ -1,202 +0,0 @@
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::TokenUsage;
|
||||
|
||||
/// A single token usage record persisted to disk.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsageRecord {
    /// Story the usage was recorded for.
    pub story_id: String,
    /// Agent within the story that produced the usage.
    pub agent_name: String,
    /// RFC 3339 timestamp of when the record was built (see `build_record`).
    pub timestamp: String,
    /// Optional model name; `#[serde(default)]` tolerates persisted
    /// records that lack this field.
    #[serde(default)]
    pub model: Option<String>,
    /// Token counts and cost for the session.
    pub usage: TokenUsage,
}
|
||||
|
||||
/// Append a token usage record to the persistent JSONL file.
|
||||
///
|
||||
/// Each line is a self-contained JSON object, making appends atomic and
|
||||
/// reads simple. The file lives at `.storkit/token_usage.jsonl`.
|
||||
pub fn append_record(project_root: &Path, record: &TokenUsageRecord) -> Result<(), String> {
|
||||
let path = token_usage_path(project_root);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|e| format!("Failed to create token_usage directory: {e}"))?;
|
||||
}
|
||||
let mut line =
|
||||
serde_json::to_string(record).map_err(|e| format!("Failed to serialize record: {e}"))?;
|
||||
line.push('\n');
|
||||
use std::io::Write;
|
||||
let file = fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.map_err(|e| format!("Failed to open token_usage file: {e}"))?;
|
||||
let mut writer = std::io::BufWriter::new(file);
|
||||
writer
|
||||
.write_all(line.as_bytes())
|
||||
.map_err(|e| format!("Failed to write token_usage record: {e}"))?;
|
||||
writer
|
||||
.flush()
|
||||
.map_err(|e| format!("Failed to flush token_usage file: {e}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read all token usage records from the persistent file.
|
||||
pub fn read_all(project_root: &Path) -> Result<Vec<TokenUsageRecord>, String> {
|
||||
let path = token_usage_path(project_root);
|
||||
if !path.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let content =
|
||||
fs::read_to_string(&path).map_err(|e| format!("Failed to read token_usage file: {e}"))?;
|
||||
let mut records = Vec::new();
|
||||
for line in content.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_str::<TokenUsageRecord>(trimmed) {
|
||||
Ok(record) => records.push(record),
|
||||
Err(e) => {
|
||||
crate::slog_warn!("[token_usage] Skipping malformed line: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
/// Build a `TokenUsageRecord` from the parts available at completion time.
|
||||
pub fn build_record(
|
||||
story_id: &str,
|
||||
agent_name: &str,
|
||||
model: Option<String>,
|
||||
usage: TokenUsage,
|
||||
) -> TokenUsageRecord {
|
||||
TokenUsageRecord {
|
||||
story_id: story_id.to_string(),
|
||||
agent_name: agent_name.to_string(),
|
||||
timestamp: Utc::now().to_rfc3339(),
|
||||
model,
|
||||
usage,
|
||||
}
|
||||
}
|
||||
|
||||
/// Location of the token-usage ledger: `<project_root>/.storkit/token_usage.jsonl`.
fn token_usage_path(project_root: &Path) -> std::path::PathBuf {
    let mut path = project_root.join(".storkit");
    path.push("token_usage.jsonl");
    path
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Fixed fixture so round-trip assertions can compare against a known value.
    fn sample_usage() -> TokenUsage {
        TokenUsage {
            input_tokens: 100,
            output_tokens: 200,
            cache_creation_input_tokens: 5000,
            cache_read_input_tokens: 10000,
            total_cost_usd: 1.57,
        }
    }

    // A record appended to a fresh project root reads back unchanged.
    #[test]
    fn append_and_read_roundtrip() {
        let dir = TempDir::new().unwrap();
        let root = dir.path();

        let record = build_record("42_story_foo", "coder-1", None, sample_usage());
        append_record(root, &record).unwrap();

        let records = read_all(root).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].story_id, "42_story_foo");
        assert_eq!(records[0].agent_name, "coder-1");
        assert_eq!(records[0].usage, sample_usage());
    }

    // Successive appends grow the JSONL file and preserve insertion order.
    #[test]
    fn multiple_appends_accumulate() {
        let dir = TempDir::new().unwrap();
        let root = dir.path();

        let r1 = build_record("s1", "coder-1", None, sample_usage());
        let r2 = build_record("s2", "coder-2", None, sample_usage());
        append_record(root, &r1).unwrap();
        append_record(root, &r2).unwrap();

        let records = read_all(root).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].story_id, "s1");
        assert_eq!(records[1].story_id, "s2");
    }

    // A missing file is not an error: read_all returns an empty Vec.
    #[test]
    fn read_empty_returns_empty() {
        let dir = TempDir::new().unwrap();
        let records = read_all(dir.path()).unwrap();
        assert!(records.is_empty());
    }

    // Both non-JSON lines and JSON lines that don't match the record schema
    // are skipped (with a warning) instead of failing the whole read.
    #[test]
    fn malformed_lines_are_skipped() {
        let dir = TempDir::new().unwrap();
        let root = dir.path();
        let path = root.join(".storkit").join("token_usage.jsonl");
        fs::create_dir_all(path.parent().unwrap()).unwrap();
        fs::write(&path, "not json\n{\"bad\":true}\n").unwrap();

        let records = read_all(root).unwrap();
        assert!(records.is_empty());
    }

    // Parsing a full agent `result` event: all four token counters and the
    // cost are extracted from the nested `usage` object.
    #[test]
    fn token_usage_from_result_event() {
        let json = serde_json::json!({
            "type": "result",
            "total_cost_usd": 1.57,
            "usage": {
                "input_tokens": 7,
                "output_tokens": 475,
                "cache_creation_input_tokens": 185020,
                "cache_read_input_tokens": 810585
            }
        });

        let usage = TokenUsage::from_result_event(&json).unwrap();
        assert_eq!(usage.input_tokens, 7);
        assert_eq!(usage.output_tokens, 475);
        assert_eq!(usage.cache_creation_input_tokens, 185020);
        assert_eq!(usage.cache_read_input_tokens, 810585);
        // Float compare via epsilon rather than exact equality.
        assert!((usage.total_cost_usd - 1.57).abs() < f64::EPSILON);
    }

    // A result event with no `usage` object yields None.
    #[test]
    fn token_usage_from_result_event_missing_usage() {
        let json = serde_json::json!({"type": "result"});
        assert!(TokenUsage::from_result_event(&json).is_none());
    }

    // Missing cache counters default to 0 instead of failing the parse.
    #[test]
    fn token_usage_from_result_event_partial_fields() {
        let json = serde_json::json!({
            "type": "result",
            "total_cost_usd": 0.5,
            "usage": {
                "input_tokens": 10,
                "output_tokens": 20
            }
        });

        let usage = TokenUsage::from_result_event(&json).unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 20);
        assert_eq!(usage.cache_creation_input_tokens, 0);
        assert_eq!(usage.cache_read_input_tokens, 0);
    }
}
|
||||
Reference in New Issue
Block a user