huskies: merge 494_story_mcp_tool_to_run_project_test_suite

This commit is contained in:
dave
2026-04-07 14:39:47 +00:00
parent 1b8c391836
commit 19768c23d5
5 changed files with 503 additions and 1 deletions
+233
View File
@@ -7,6 +7,8 @@ use std::path::PathBuf;
const DEFAULT_TIMEOUT_SECS: u64 = 120;
const MAX_TIMEOUT_SECS: u64 = 600;
const TEST_TIMEOUT_SECS: u64 = 600;
const MAX_OUTPUT_LINES: usize = 100;
/// Patterns that are unconditionally blocked regardless of context.
static BLOCKED_PATTERNS: &[&str] = &[
@@ -328,6 +330,117 @@ pub(super) fn handle_run_command_sse(
})))
}
/// Truncate output to at most `max_lines` lines, keeping the tail.
fn truncate_output(output: &str, max_lines: usize) -> String {
let lines: Vec<&str> = output.lines().collect();
if lines.len() <= max_lines {
return output.to_string();
}
let omitted = lines.len() - max_lines;
let tail = lines[lines.len() - max_lines..].join("\n");
format!("[... {omitted} lines omitted ...]\n{tail}")
}
/// Parse cumulative passed/failed counts from `cargo test` output lines like:
/// `"test result: ok. 5 passed; 0 failed; ..."`
fn parse_test_counts(output: &str) -> (u64, u64) {
let mut total_passed = 0u64;
let mut total_failed = 0u64;
for line in output.lines() {
if line.contains("test result:") {
if let Some(p) = extract_count(line, "passed") {
total_passed += p;
}
if let Some(f) = extract_count(line, "failed") {
total_failed += f;
}
}
}
(total_passed, total_failed)
}
/// Extract a count immediately before `label` in `line` (e.g. `"5 passed"` → 5).
fn extract_count(line: &str, label: &str) -> Option<u64> {
let pos = line.find(label)?;
let before = line[..pos].trim_end();
let num_str: String = before.chars().rev().take_while(|c| c.is_ascii_digit()).collect();
if num_str.is_empty() {
return None;
}
let num_str: String = num_str.chars().rev().collect();
num_str.parse().ok()
}
/// Run the project's `script/test` and return a structured result.
///
/// If `worktree_path` is provided the script is run from that worktree
/// (must be inside `.huskies/worktrees/`). Otherwise the project root is used.
pub(super) async fn tool_run_tests(args: &Value, ctx: &AppContext) -> Result<String, String> {
let project_root = ctx.agents.get_project_root(&ctx.state)?;
let working_dir = match args.get("worktree_path").and_then(|v| v.as_str()) {
Some(wt) => validate_working_dir(wt, ctx)?,
None => project_root
.canonicalize()
.map_err(|e| format!("Cannot canonicalize project root: {e}"))?,
};
let script_path = working_dir.join("script").join("test");
if !script_path.exists() {
return Err(format!(
"Test script not found: {}",
script_path.display()
));
}
let result = tokio::time::timeout(
std::time::Duration::from_secs(TEST_TIMEOUT_SECS),
tokio::task::spawn_blocking({
let dir = working_dir.clone();
let script = script_path.clone();
move || {
std::process::Command::new("bash")
.arg(&script)
.current_dir(&dir)
.output()
}
}),
)
.await;
match result {
Err(_) => serde_json::to_string_pretty(&json!({
"passed": false,
"exit_code": -1,
"timed_out": true,
"tests_passed": 0,
"tests_failed": 0,
"output": format!("Test suite timed out after {TEST_TIMEOUT_SECS}s"),
}))
.map_err(|e| format!("Serialization error: {e}")),
Ok(Err(e)) => Err(format!("Task join error: {e}")),
Ok(Ok(Err(e))) => Err(format!("Failed to execute test script: {e}")),
Ok(Ok(Ok(output))) => {
let passed = output.status.success();
let exit_code = output.status.code().unwrap_or(-1);
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let combined = format!("{stdout}{stderr}");
let (tests_passed, tests_failed) = parse_test_counts(&combined);
let truncated = truncate_output(&combined, MAX_OUTPUT_LINES);
serde_json::to_string_pretty(&json!({
"passed": passed,
"exit_code": exit_code,
"timed_out": false,
"tests_passed": tests_passed,
"tests_failed": tests_failed,
"output": truncated,
}))
.map_err(|e| format!("Serialization error: {e}"))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -619,4 +732,124 @@ mod tests {
// Just ensure it doesn't panic and returns an Err about sandbox (not timeout)
assert!(result.is_err());
}
// ── tool_run_tests ────────────────────────────────────────────────
#[tokio::test]
async fn tool_run_tests_missing_script_returns_error() {
let tmp = tempfile::tempdir().unwrap();
let ctx = test_ctx(tmp.path());
// No script/test in tmp — should return Err
let result = tool_run_tests(&json!({}), &ctx).await;
assert!(result.is_err(), "expected error for missing script: {result:?}");
assert!(
result.unwrap_err().contains("not found"),
"error should mention 'not found'"
);
}
#[tokio::test]
async fn tool_run_tests_passes_when_script_exits_zero() {
let tmp = tempfile::tempdir().unwrap();
let script_dir = tmp.path().join("script");
std::fs::create_dir_all(&script_dir).unwrap();
let script_path = script_dir.join("test");
std::fs::write(&script_path, "#!/usr/bin/env bash\necho 'test result: ok. 3 passed; 0 failed'\nexit 0\n").unwrap();
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o755)).unwrap();
}
let ctx = test_ctx(tmp.path());
let result = tool_run_tests(&json!({}), &ctx).await.unwrap();
let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
assert_eq!(parsed["passed"], true);
assert_eq!(parsed["exit_code"], 0);
assert_eq!(parsed["timed_out"], false);
assert_eq!(parsed["tests_passed"], 3);
assert_eq!(parsed["tests_failed"], 0);
}
#[tokio::test]
async fn tool_run_tests_fails_when_script_exits_nonzero() {
let tmp = tempfile::tempdir().unwrap();
let script_dir = tmp.path().join("script");
std::fs::create_dir_all(&script_dir).unwrap();
let script_path = script_dir.join("test");
std::fs::write(&script_path, "#!/usr/bin/env bash\necho 'test result: FAILED. 1 passed; 2 failed'\nexit 1\n").unwrap();
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o755)).unwrap();
}
let ctx = test_ctx(tmp.path());
let result = tool_run_tests(&json!({}), &ctx).await.unwrap();
let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
assert_eq!(parsed["passed"], false);
assert_eq!(parsed["exit_code"], 1);
assert_eq!(parsed["tests_passed"], 1);
assert_eq!(parsed["tests_failed"], 2);
}
#[tokio::test]
async fn tool_run_tests_worktree_path_must_be_inside_worktrees() {
let tmp = tempfile::tempdir().unwrap();
let wt_dir = tmp.path().join(".huskies").join("worktrees");
std::fs::create_dir_all(&wt_dir).unwrap();
let ctx = test_ctx(tmp.path());
// tmp.path() itself is outside worktrees → should fail validation
let result =
tool_run_tests(&json!({"worktree_path": tmp.path().to_str().unwrap()}), &ctx).await;
assert!(result.is_err());
assert!(
result.unwrap_err().contains("worktrees"),
"expected sandbox error"
);
}
// ── truncate_output ───────────────────────────────────────────────
#[test]
fn truncate_output_short_text_unchanged() {
let text = "line1\nline2\nline3";
assert_eq!(truncate_output(text, 10), text);
}
#[test]
fn truncate_output_long_text_keeps_tail() {
let lines: Vec<String> = (1..=200).map(|i| format!("line {i}")).collect();
let text = lines.join("\n");
let result = truncate_output(&text, 50);
assert!(result.contains("line 200"), "should keep last line: {result}");
assert!(result.contains("omitted"), "should note omitted lines: {result}");
assert!(!result.contains("line 1\n"), "should not keep first line: {result}");
}
// ── parse_test_counts ─────────────────────────────────────────────
#[test]
fn parse_test_counts_extracts_passed_and_failed() {
let output = "test result: ok. 5 passed; 0 failed; 0 ignored\ntest result: FAILED. 2 passed; 3 failed;";
let (passed, failed) = parse_test_counts(output);
assert_eq!(passed, 7);
assert_eq!(failed, 3);
}
#[test]
fn parse_test_counts_no_results_returns_zeros() {
let (passed, failed) = parse_test_counts("no test output here");
assert_eq!(passed, 0);
assert_eq!(failed, 0);
}
#[test]
fn extract_count_finds_number_before_label() {
assert_eq!(extract_count("5 passed; 0 failed", "passed"), Some(5));
assert_eq!(extract_count("0 failed", "failed"), Some(0));
assert_eq!(extract_count("no number here passed", "passed"), None);
}
}