From 4344081b5430fb89698c2dc3d54152c7dad3dc92 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 20 Mar 2026 22:05:25 +0000 Subject: [PATCH] storkit: merge 343_refactor_abstract_agent_runtime_to_support_non_claude_code_backends --- server/src/agents/mod.rs | 3 +- server/src/agents/pool/mod.rs | 52 +++++--- server/src/agents/pty.rs | 4 +- server/src/agents/runtime/claude_code.rs | 66 ++++++++++ server/src/agents/runtime/mod.rs | 150 +++++++++++++++++++++++ server/src/config.rs | 54 ++++++++ 6 files changed, 307 insertions(+), 22 deletions(-) create mode 100644 server/src/agents/runtime/claude_code.rs create mode 100644 server/src/agents/runtime/mod.rs diff --git a/server/src/agents/mod.rs b/server/src/agents/mod.rs index 268e99b..c5661a5 100644 --- a/server/src/agents/mod.rs +++ b/server/src/agents/mod.rs @@ -2,7 +2,8 @@ pub mod gates; pub mod lifecycle; pub mod merge; mod pool; -mod pty; +pub(crate) mod pty; +pub mod runtime; pub mod token_usage; use crate::config::AgentConfig; diff --git a/server/src/agents/pool/mod.rs b/server/src/agents/pool/mod.rs index 0b63661..85057e7 100644 --- a/server/src/agents/pool/mod.rs +++ b/server/src/agents/pool/mod.rs @@ -17,6 +17,7 @@ use super::{ AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage, pipeline_stage, }; +use super::runtime::{AgentRuntime, ClaudeCodeRuntime, RuntimeContext}; /// Build the composite key used to track agents in the pool. fn composite_key(story_id: &str, agent_name: &str) -> String { @@ -513,25 +514,38 @@ impl AgentPool { }); Self::notify_agent_state_changed(&watcher_tx_clone); - // Step 4: launch the agent process. - match super::pty::run_agent_pty_streaming( - &sid, - &aname, - &command, - &args, - &prompt, - &wt_path_str, - &tx_clone, - &log_clone, - log_writer_clone, - inactivity_timeout_secs, - child_killers_clone, - ) - .await - { - Ok(pty_result) => { + // Step 4: launch the agent process via the configured runtime. + let runtime_name = config_clone + .find_agent(&aname) + .and_then(|a| a.runtime.as_deref()) + .unwrap_or("claude-code"); + + let run_result = match runtime_name { + "claude-code" => { + let runtime = ClaudeCodeRuntime::new(child_killers_clone.clone()); + let ctx = RuntimeContext { + story_id: sid.clone(), + agent_name: aname.clone(), + command, + args, + prompt, + cwd: wt_path_str, + inactivity_timeout_secs, + }; + runtime + .start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone) + .await + } + other => Err(format!( + "Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \ + Supported: 'claude-code'" + )), + }; + + match run_result { + Ok(result) => { // Persist token usage if the agent reported it. - if let Some(ref usage) = pty_result.token_usage + if let Some(ref usage) = result.token_usage && let Ok(agents) = agents_ref.lock() && let Some(agent) = agents.get(&key_clone) && let Some(ref pr) = agent.project_root @@ -557,7 +571,7 @@ impl AgentPool { port_for_task, &sid, &aname, - pty_result.session_id, + result.session_id, watcher_tx_clone.clone(), ) .await; diff --git a/server/src/agents/pty.rs b/server/src/agents/pty.rs index 72c76e0..8697f0b 100644 --- a/server/src/agents/pty.rs +++ b/server/src/agents/pty.rs @@ -11,7 +11,7 @@ use crate::slog; use crate::slog_warn; /// Result from a PTY agent session, containing the session ID and token usage. -pub(super) struct PtyResult { +pub(in crate::agents) struct PtyResult { pub session_id: Option, pub token_usage: Option, } @@ -35,7 +35,7 @@ impl Drop for ChildKillerGuard { /// Spawn claude agent in a PTY and stream events through the broadcast channel. #[allow(clippy::too_many_arguments)] -pub(super) async fn run_agent_pty_streaming( +pub(in crate::agents) async fn run_agent_pty_streaming( story_id: &str, agent_name: &str, command: &str, diff --git a/server/src/agents/runtime/claude_code.rs b/server/src/agents/runtime/claude_code.rs new file mode 100644 index 0000000..5993c37 --- /dev/null +++ b/server/src/agents/runtime/claude_code.rs @@ -0,0 +1,66 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use portable_pty::ChildKiller; +use tokio::sync::broadcast; + +use crate::agent_log::AgentLogWriter; + +use super::{AgentEvent, AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus}; + +/// Agent runtime that spawns the `claude` CLI in a PTY and streams JSON events. +/// +/// This is the default runtime (`runtime = "claude-code"` in project.toml). +/// It wraps the existing PTY-based execution logic, preserving all streaming, +/// token tracking, and inactivity timeout behaviour. +pub struct ClaudeCodeRuntime { + child_killers: Arc>>>, +} + +impl ClaudeCodeRuntime { + pub fn new( + child_killers: Arc>>>, + ) -> Self { + Self { child_killers } + } +} + +impl AgentRuntime for ClaudeCodeRuntime { + async fn start( + &self, + ctx: RuntimeContext, + tx: broadcast::Sender, + event_log: Arc>>, + log_writer: Option>>, + ) -> Result { + let pty_result = super::super::pty::run_agent_pty_streaming( + &ctx.story_id, + &ctx.agent_name, + &ctx.command, + &ctx.args, + &ctx.prompt, + &ctx.cwd, + &tx, + &event_log, + log_writer, + ctx.inactivity_timeout_secs, + Arc::clone(&self.child_killers), + ) + .await?; + + Ok(RuntimeResult { + session_id: pty_result.session_id, + token_usage: pty_result.token_usage, + }) + } + + fn stop(&self) { + // Stopping is handled externally by the pool via kill_child_for_key(). + // The ChildKillerGuard in pty.rs deregisters automatically on process exit. + } + + fn get_status(&self) -> RuntimeStatus { + // Lifecycle status is tracked by the pool; the runtime itself is stateless. + RuntimeStatus::Idle + } +} diff --git a/server/src/agents/runtime/mod.rs b/server/src/agents/runtime/mod.rs new file mode 100644 index 0000000..31afc29 --- /dev/null +++ b/server/src/agents/runtime/mod.rs @@ -0,0 +1,150 @@ +mod claude_code; + +pub use claude_code::ClaudeCodeRuntime; + +use std::sync::{Arc, Mutex}; +use tokio::sync::broadcast; + +use crate::agent_log::AgentLogWriter; + +use super::{AgentEvent, TokenUsage}; + +/// Context passed to a runtime when launching an agent session. +pub struct RuntimeContext { + pub story_id: String, + pub agent_name: String, + pub command: String, + pub args: Vec, + pub prompt: String, + pub cwd: String, + pub inactivity_timeout_secs: u64, +} + +/// Result returned by a runtime after the agent session completes. +pub struct RuntimeResult { + pub session_id: Option, + pub token_usage: Option, +} + +/// Runtime status reported by the backend. +#[derive(Debug, Clone, PartialEq)] +#[allow(dead_code)] +pub enum RuntimeStatus { + Idle, + Running, + Completed, + Failed, +} + +/// Abstraction over different agent execution backends. +/// +/// Implementations: +/// - [`ClaudeCodeRuntime`]: spawns the `claude` CLI via a PTY (default, `runtime = "claude-code"`) +/// +/// Future implementations could include OpenAI and Gemini API runtimes. +#[allow(dead_code)] +pub trait AgentRuntime: Send + Sync { + /// Start the agent and drive it to completion, streaming events through + /// the provided broadcast sender and event log. + /// + /// Returns when the agent session finishes (success or error). + async fn start( + &self, + ctx: RuntimeContext, + tx: broadcast::Sender, + event_log: Arc>>, + log_writer: Option>>, + ) -> Result; + + /// Stop the running agent. + fn stop(&self); + + /// Get the current runtime status. + fn get_status(&self) -> RuntimeStatus; + + /// Return any events buffered outside the broadcast channel. + /// + /// PTY-based runtimes stream directly to the broadcast channel; this + /// returns empty by default. API-based runtimes may buffer events here. + fn stream_events(&self) -> Vec { + vec![] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn runtime_context_fields() { + let ctx = RuntimeContext { + story_id: "42_story_foo".to_string(), + agent_name: "coder-1".to_string(), + command: "claude".to_string(), + args: vec!["--model".to_string(), "sonnet".to_string()], + prompt: "Do the thing".to_string(), + cwd: "/tmp/wt".to_string(), + inactivity_timeout_secs: 300, + }; + assert_eq!(ctx.story_id, "42_story_foo"); + assert_eq!(ctx.agent_name, "coder-1"); + assert_eq!(ctx.command, "claude"); + assert_eq!(ctx.args.len(), 2); + assert_eq!(ctx.prompt, "Do the thing"); + assert_eq!(ctx.cwd, "/tmp/wt"); + assert_eq!(ctx.inactivity_timeout_secs, 300); + } + + #[test] + fn runtime_result_fields() { + let result = RuntimeResult { + session_id: Some("sess-123".to_string()), + token_usage: Some(TokenUsage { + input_tokens: 100, + output_tokens: 50, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + total_cost_usd: 0.01, + }), + }; + assert_eq!(result.session_id, Some("sess-123".to_string())); + assert!(result.token_usage.is_some()); + let usage = result.token_usage.unwrap(); + assert_eq!(usage.input_tokens, 100); + assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.total_cost_usd, 0.01); + } + + #[test] + fn runtime_result_no_usage() { + let result = RuntimeResult { + session_id: None, + token_usage: None, + }; + assert!(result.session_id.is_none()); + assert!(result.token_usage.is_none()); + } + + #[test] + fn runtime_status_variants() { + assert_eq!(RuntimeStatus::Idle, RuntimeStatus::Idle); + assert_ne!(RuntimeStatus::Running, RuntimeStatus::Completed); + assert_ne!(RuntimeStatus::Failed, RuntimeStatus::Idle); + } + + #[test] + fn claude_code_runtime_get_status_returns_idle() { + use std::collections::HashMap; + let killers = Arc::new(Mutex::new(HashMap::new())); + let runtime = ClaudeCodeRuntime::new(killers); + assert_eq!(runtime.get_status(), RuntimeStatus::Idle); + } + + #[test] + fn claude_code_runtime_stream_events_empty() { + use std::collections::HashMap; + let killers = Arc::new(Mutex::new(HashMap::new())); + let runtime = ClaudeCodeRuntime::new(killers); + assert!(runtime.stream_events().is_empty()); + } +} diff --git a/server/src/config.rs b/server/src/config.rs index 838c0e3..798e698 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -117,6 +117,11 @@ pub struct AgentConfig { /// and marked as Failed. Default: 300 (5 minutes). Set to 0 to disable. #[serde(default = "default_inactivity_timeout_secs")] pub inactivity_timeout_secs: u64, + /// Agent runtime backend. Controls how the agent process is spawned and + /// how events are streamed. Default: `"claude-code"` (spawns the `claude` + /// CLI in a PTY). Future values: `"openai"`, `"gemini"`. + #[serde(default)] + pub runtime: Option, } fn default_path() -> String { @@ -178,6 +183,7 @@ impl Default for ProjectConfig { system_prompt: None, stage: None, inactivity_timeout_secs: default_inactivity_timeout_secs(), + runtime: None, }], watcher: WatcherConfig::default(), default_qa: default_qa(), @@ -370,6 +376,17 @@ fn validate_agents(agents: &[AgentConfig]) -> Result<(), String> { agent.name )); } + if let Some(ref runtime) = agent.runtime { + match runtime.as_str() { + "claude-code" => {} + other => { + return Err(format!( + "Agent '{}': unknown runtime '{other}'. Supported: 'claude-code'", + agent.name + )); + } + } + } } Ok(()) } @@ -792,6 +809,43 @@ name = "coder-1" assert_eq!(config.max_coders, Some(3)); } + // ── runtime config ──────────────────────────────────────────────── + + #[test] + fn runtime_defaults_to_none() { + let toml_str = r#" +[[agent]] +name = "coder" +"#; + let config = ProjectConfig::parse(toml_str).unwrap(); + assert_eq!(config.agent[0].runtime, None); + } + + #[test] + fn runtime_claude_code_accepted() { + let toml_str = r#" +[[agent]] +name = "coder" +runtime = "claude-code" +"#; + let config = ProjectConfig::parse(toml_str).unwrap(); + assert_eq!( + config.agent[0].runtime, + Some("claude-code".to_string()) + ); + } + + #[test] + fn runtime_unknown_rejected() { + let toml_str = r#" +[[agent]] +name = "coder" +runtime = "openai" +"#; + let err = ProjectConfig::parse(toml_str).unwrap_err(); + assert!(err.contains("unknown runtime 'openai'")); + } + #[test] fn project_toml_has_three_sonnet_coders() { let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));