Files
storkit/server/src/agents.rs

562 lines
17 KiB
Rust
Raw Normal View History

use crate::config::ProjectConfig;
use crate::worktree::{self, WorktreeInfo};
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use serde::Serialize;
use std::collections::HashMap;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
/// Join a story id and an agent name into the `story_id:agent_name` string
/// under which the pool stores that agent.
fn composite_key(story_id: &str, agent_name: &str) -> String {
    let mut key = String::with_capacity(story_id.len() + 1 + agent_name.len());
    key.push_str(story_id);
    key.push(':');
    key.push_str(agent_name);
    key
}
/// Events streamed from a running agent to SSE clients.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is written to a `"type"` field next to the variant's own fields
/// (e.g. `AgentJson` is tagged `"agent_json"`). Every variant carries the
/// `story_id` and `agent_name` of the pool entry that produced it.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
/// Agent status changed.
///
/// `status` is a free-form string; the values sent in this file are
/// "pending", "running" and "stopped".
Status {
story_id: String,
agent_name: String,
status: String,
},
/// Raw text output from the agent process.
///
/// Carries both non-JSON output lines and the `text` of content blocks in
/// `assistant` stream-json messages.
Output {
story_id: String,
agent_name: String,
text: String,
},
/// Agent produced a JSON event from `--output-format stream-json`.
///
/// Every output line that parses as JSON is forwarded verbatim in `data`.
AgentJson {
story_id: String,
agent_name: String,
data: serde_json::Value,
},
/// Agent finished.
///
/// `session_id` is the id taken from the agent's `system` events, or `None`
/// if none was seen.
Done {
story_id: String,
agent_name: String,
session_id: Option<String>,
},
/// Agent errored.
///
/// `message` is the error text returned by the agent run.
Error {
story_id: String,
agent_name: String,
message: String,
},
}
/// Lifecycle state of one story agent in the pool.
///
/// An entry is registered as `Pending`, becomes `Running` once its process task
/// is spawned, and ends as `Completed` (the run returned successfully) or
/// `Failed` (the run returned an error, or the agent was stopped).
/// Serialized in snake_case (`"pending"`, `"running"`, ...).
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AgentStatus {
Pending,
Running,
Completed,
Failed,
}
impl std::fmt::Display for AgentStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Pending => write!(f, "pending"),
Self::Running => write!(f, "running"),
Self::Completed => write!(f, "completed"),
Self::Failed => write!(f, "failed"),
}
}
}
/// Serializable snapshot of one agent, as returned by `AgentPool::start_agent`
/// and `AgentPool::list_agents`.
#[derive(Serialize, Clone)]
pub struct AgentInfo {
pub story_id: String,
pub agent_name: String,
pub status: AgentStatus,
/// Session id reported by the agent; `None` until a completed run recorded one.
pub session_id: Option<String>,
/// Worktree path, converted lossily to a string; `None` before a worktree exists.
pub worktree_path: Option<String>,
/// Base branch of the agent's worktree; `None` before a worktree exists.
pub base_branch: Option<String>,
}
/// Internal pool entry for one (story, agent) pair.
struct StoryAgent {
agent_name: String,
status: AgentStatus,
/// Set once the worktree has been created for this launch.
worktree_info: Option<WorktreeInfo>,
/// Recorded when the agent task completes successfully.
session_id: Option<String>,
/// Live event fan-out; SSE clients obtain receivers via `AgentPool::subscribe`.
tx: broadcast::Sender<AgentEvent>,
/// Handle of the supervising tokio task; `None` before spawn and after `stop_agent` takes it.
task_handle: Option<tokio::task::JoinHandle<()>>,
/// Accumulated events for polling via get_agent_output.
event_log: Arc<Mutex<Vec<AgentEvent>>>,
}
/// Manages concurrent story agents, each in its own worktree.
///
/// Entries are keyed by `composite_key(story_id, agent_name)`. The map sits
/// behind `Arc<Mutex<..>>` so that spawned agent tasks can update their own
/// entry's status when they finish.
pub struct AgentPool {
agents: Arc<Mutex<HashMap<String, StoryAgent>>>,
}
impl AgentPool {
pub fn new() -> Self {
Self {
agents: Arc::new(Mutex::new(HashMap::new())),
}
}
/// Start an agent for a story: load config, create worktree, spawn agent.
/// If `agent_name` is None, defaults to the first configured agent.
pub async fn start_agent(
&self,
project_root: &Path,
story_id: &str,
agent_name: Option<&str>,
) -> Result<AgentInfo, String> {
let config = ProjectConfig::load(project_root)?;
// Resolve agent name from config
let resolved_name = match agent_name {
Some(name) => {
config
.find_agent(name)
.ok_or_else(|| format!("No agent named '{name}' in config"))?;
name.to_string()
}
None => config
.default_agent()
.ok_or_else(|| "No agents configured".to_string())?
.name
.clone(),
};
let key = composite_key(story_id, &resolved_name);
// Check not already running
{
let agents = self.agents.lock().map_err(|e| e.to_string())?;
if let Some(agent) = agents.get(&key)
&& (agent.status == AgentStatus::Running || agent.status == AgentStatus::Pending)
{
return Err(format!(
"Agent '{resolved_name}' for story '{story_id}' is already {}",
agent.status
));
}
}
let (tx, _) = broadcast::channel::<AgentEvent>(1024);
let event_log: Arc<Mutex<Vec<AgentEvent>>> = Arc::new(Mutex::new(Vec::new()));
// Register as pending
{
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
agents.insert(
key.clone(),
StoryAgent {
agent_name: resolved_name.clone(),
status: AgentStatus::Pending,
worktree_info: None,
session_id: None,
tx: tx.clone(),
task_handle: None,
event_log: event_log.clone(),
},
);
}
let _ = tx.send(AgentEvent::Status {
story_id: story_id.to_string(),
agent_name: resolved_name.clone(),
status: "pending".to_string(),
});
// Create worktree
let wt_info = worktree::create_worktree(project_root, story_id, &config).await?;
// Update with worktree info
{
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
if let Some(agent) = agents.get_mut(&key) {
agent.worktree_info = Some(wt_info.clone());
}
}
// Spawn the agent process
let wt_path_str = wt_info.path.to_string_lossy().to_string();
let (command, args, prompt) =
config.render_agent_args(&wt_path_str, story_id, Some(&resolved_name), Some(&wt_info.base_branch))?;
let sid = story_id.to_string();
let aname = resolved_name.clone();
let tx_clone = tx.clone();
let agents_ref = self.agents.clone();
let cwd = wt_path_str.clone();
let key_clone = key.clone();
let log_clone = event_log.clone();
let handle = tokio::spawn(async move {
let _ = tx_clone.send(AgentEvent::Status {
story_id: sid.clone(),
agent_name: aname.clone(),
status: "running".to_string(),
});
match run_agent_pty_streaming(
&sid, &aname, &command, &args, &prompt, &cwd, &tx_clone, &log_clone,
)
.await
{
Ok(session_id) => {
if let Ok(mut agents) = agents_ref.lock()
&& let Some(agent) = agents.get_mut(&key_clone)
{
agent.status = AgentStatus::Completed;
agent.session_id = session_id.clone();
}
let _ = tx_clone.send(AgentEvent::Done {
story_id: sid.clone(),
agent_name: aname.clone(),
session_id,
});
}
Err(e) => {
if let Ok(mut agents) = agents_ref.lock()
&& let Some(agent) = agents.get_mut(&key_clone)
{
agent.status = AgentStatus::Failed;
}
let _ = tx_clone.send(AgentEvent::Error {
story_id: sid.clone(),
agent_name: aname.clone(),
message: e,
});
}
}
});
// Update status to running with task handle
{
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
if let Some(agent) = agents.get_mut(&key) {
agent.status = AgentStatus::Running;
agent.task_handle = Some(handle);
}
}
Ok(AgentInfo {
story_id: story_id.to_string(),
agent_name: resolved_name,
status: AgentStatus::Running,
session_id: None,
worktree_path: Some(wt_path_str),
base_branch: Some(wt_info.base_branch.clone()),
})
}
/// Stop a running agent. Worktree is preserved for inspection.
pub async fn stop_agent(
&self,
_project_root: &Path,
story_id: &str,
agent_name: &str,
) -> Result<(), String> {
let key = composite_key(story_id, agent_name);
let (worktree_info, task_handle, tx) = {
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
let agent = agents
.get_mut(&key)
.ok_or_else(|| format!("No agent '{agent_name}' for story '{story_id}'"))?;
let wt = agent.worktree_info.clone();
let handle = agent.task_handle.take();
let tx = agent.tx.clone();
agent.status = AgentStatus::Failed;
(wt, handle, tx)
};
// Abort the task
if let Some(handle) = task_handle {
handle.abort();
let _ = handle.await;
}
// Preserve worktree for inspection — don't destroy agent's work on stop.
if let Some(ref wt) = worktree_info {
eprintln!(
"[agents] Worktree preserved for {story_id}:{agent_name}: {}",
wt.path.display()
);
}
let _ = tx.send(AgentEvent::Status {
story_id: story_id.to_string(),
agent_name: agent_name.to_string(),
status: "stopped".to_string(),
});
// Remove from map
{
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
agents.remove(&key);
}
Ok(())
}
/// List all agents with their status.
pub fn list_agents(&self) -> Result<Vec<AgentInfo>, String> {
let agents = self.agents.lock().map_err(|e| e.to_string())?;
Ok(agents
.iter()
.map(|(key, agent)| {
// Extract story_id from composite key "story_id:agent_name"
let story_id = key
.rsplit_once(':')
.map(|(sid, _)| sid.to_string())
.unwrap_or_else(|| key.clone());
AgentInfo {
story_id,
agent_name: agent.agent_name.clone(),
status: agent.status.clone(),
session_id: agent.session_id.clone(),
worktree_path: agent
.worktree_info
.as_ref()
.map(|wt| wt.path.to_string_lossy().to_string()),
base_branch: agent
.worktree_info
.as_ref()
.map(|wt| wt.base_branch.clone()),
}
})
.collect())
}
/// Subscribe to events for a story agent.
pub fn subscribe(
&self,
story_id: &str,
agent_name: &str,
) -> Result<broadcast::Receiver<AgentEvent>, String> {
let key = composite_key(story_id, agent_name);
let agents = self.agents.lock().map_err(|e| e.to_string())?;
let agent = agents
.get(&key)
.ok_or_else(|| format!("No agent '{agent_name}' for story '{story_id}'"))?;
Ok(agent.tx.subscribe())
}
/// Drain accumulated events for polling. Returns all events since the last drain.
pub fn drain_events(
&self,
story_id: &str,
agent_name: &str,
) -> Result<Vec<AgentEvent>, String> {
let key = composite_key(story_id, agent_name);
let agents = self.agents.lock().map_err(|e| e.to_string())?;
let agent = agents
.get(&key)
.ok_or_else(|| format!("No agent '{agent_name}' for story '{story_id}'"))?;
let mut log = agent.event_log.lock().map_err(|e| e.to_string())?;
Ok(log.drain(..).collect())
}
/// Get project root helper.
pub fn get_project_root(
&self,
state: &crate::state::SessionState,
) -> Result<PathBuf, String> {
state.get_project_root()
}
}
/// Spawn claude agent in a PTY and stream events through the broadcast channel.
///
/// The PTY loop is synchronous, so it runs on tokio's blocking pool. The
/// result is the session id reported by the agent, if any. A panic inside the
/// blocking job is turned into an error message.
// Mirrors run_agent_pty_blocking's parameter list one-to-one.
#[allow(clippy::too_many_arguments)]
async fn run_agent_pty_streaming(
    story_id: &str,
    agent_name: &str,
    command: &str,
    args: &[String],
    prompt: &str,
    cwd: &str,
    tx: &broadcast::Sender<AgentEvent>,
    event_log: &Arc<Mutex<Vec<AgentEvent>>>,
) -> Result<Option<String>, String> {
    // `spawn_blocking` requires a `'static` closure, so hand it owned copies.
    let owned_story = story_id.to_owned();
    let owned_agent = agent_name.to_owned();
    let owned_command = command.to_owned();
    let owned_args: Vec<String> = args.to_vec();
    let owned_prompt = prompt.to_owned();
    let owned_cwd = cwd.to_owned();
    let sender = tx.clone();
    let log = Arc::clone(event_log);

    let joined = tokio::task::spawn_blocking(move || {
        run_agent_pty_blocking(
            &owned_story,
            &owned_agent,
            &owned_command,
            &owned_args,
            &owned_prompt,
            &owned_cwd,
            &sender,
            &log,
        )
    })
    .await;

    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(format!("Agent task panicked: {join_err}")),
    }
}
/// Publish `event` to live subscribers and append it to the polling log.
///
/// Both halves are best-effort:
/// - a poisoned log mutex skips the append;
/// - a broadcast with no subscribers is ignored.
fn emit_event(
    event: AgentEvent,
    tx: &broadcast::Sender<AgentEvent>,
    event_log: &Mutex<Vec<AgentEvent>>,
) {
    // The log keeps its own copy because the broadcast consumes `event`.
    let logged_copy = event.clone();
    if let Ok(mut entries) = event_log.lock() {
        entries.push(logged_copy);
    }
    // A send error only means nobody is subscribed right now.
    let _ignored = tx.send(event);
}
/// Spawn `command` inside a fresh PTY and forward its `stream-json` output.
///
/// Runs synchronously (`run_agent_pty_streaming` drives it from
/// `spawn_blocking`). The command line is `-p <prompt>`, then the configured
/// `args`, then flags selecting verbose streaming-JSON output and
/// `--permission-mode bypassPermissions`.
///
/// Each non-empty, trimmed output line is handled as follows:
/// - if it is not valid JSON, it is emitted as `AgentEvent::Output`;
/// - otherwise the text of each `assistant` message content block is emitted
///   as `AgentEvent::Output`, and then the whole value is emitted as
///   `AgentEvent::AgentJson`.
///
/// Returns the `session_id` from the most recent `system` event that carried
/// one, or `None` if none did.
///
/// # Errors
/// Returns a message if the PTY cannot be opened, the process cannot be
/// spawned, or the PTY reader cannot be cloned.
// Mirrors run_agent_pty_streaming's parameter list one-to-one.
#[allow(clippy::too_many_arguments)]
fn run_agent_pty_blocking(
    story_id: &str,
    agent_name: &str,
    command: &str,
    args: &[String],
    prompt: &str,
    cwd: &str,
    tx: &broadcast::Sender<AgentEvent>,
    event_log: &Mutex<Vec<AgentEvent>>,
) -> Result<Option<String>, String> {
    let pty_system = native_pty_system();
    let pair = pty_system
        .openpty(PtySize {
            rows: 50,
            cols: 200,
            pixel_width: 0,
            pixel_height: 0,
        })
        .map_err(|e| format!("Failed to open PTY: {e}"))?;

    let mut cmd = CommandBuilder::new(command);
    // -p <prompt> must come first
    cmd.arg("-p");
    cmd.arg(prompt);
    // Add configured args (e.g., --directory /path/to/worktree, --model, etc.)
    for arg in args {
        cmd.arg(arg);
    }
    cmd.arg("--output-format");
    cmd.arg("stream-json");
    cmd.arg("--verbose");
    // Supervised agents don't need interactive permission prompts
    cmd.arg("--permission-mode");
    cmd.arg("bypassPermissions");
    cmd.cwd(cwd);
    cmd.env("NO_COLOR", "1");
    // Allow spawning Claude Code from within a Claude Code session
    cmd.env_remove("CLAUDECODE");
    cmd.env_remove("CLAUDE_CODE_ENTRYPOINT");

    eprintln!("[agent:{story_id}:{agent_name}] Spawning {command} in {cwd} with args: {args:?}");
    let mut child = pair
        .slave
        .spawn_command(cmd)
        .map_err(|e| format!("Failed to spawn agent for {story_id}:{agent_name}: {e}"))?;
    // Release our handles to the PTY ends that the reader does not need;
    // presumably so the read loop sees end-of-stream once the child side closes.
    drop(pair.slave);
    let reader = pair
        .master
        .try_clone_reader()
        .map_err(|e| format!("Failed to clone PTY reader: {e}"))?;
    drop(pair.master);

    // Forward one chunk of plain text to subscribers and the polling log.
    let emit_output = |text: &str| {
        emit_event(
            AgentEvent::Output {
                story_id: story_id.to_string(),
                agent_name: agent_name.to_string(),
                text: text.to_string(),
            },
            tx,
            event_log,
        );
    };

    let mut session_id: Option<String> = None;
    for line in BufReader::new(reader).lines() {
        // Any read error ends the stream (some PTYs report an error instead of
        // a clean EOF when the child side closes).
        let Ok(line) = line else { break };
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }

        let json: serde_json::Value = match serde_json::from_str(trimmed) {
            Ok(j) => j,
            Err(_) => {
                // Non-JSON output (terminal escapes etc.) — send as raw output
                emit_output(trimmed);
                continue;
            }
        };

        match json.get("type").and_then(|t| t.as_str()).unwrap_or("") {
            "system" => {
                // Only overwrite when this event actually carries an id, so a
                // later system event without one can't erase a captured id.
                if let Some(id) = json.get("session_id").and_then(|s| s.as_str()) {
                    session_id = Some(id.to_string());
                }
            }
            "assistant" => {
                if let Some(message) = json.get("message")
                    && let Some(content) = message.get("content").and_then(|c| c.as_array())
                {
                    for block in content {
                        if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
                            emit_output(text);
                        }
                    }
                }
            }
            _ => {}
        }

        // Forward all JSON events
        emit_event(
            AgentEvent::AgentJson {
                story_id: story_id.to_string(),
                agent_name: agent_name.to_string(),
                data: json,
            },
            tx,
            event_log,
        );
    }

    // Output is exhausted; make sure the process is gone and reaped. Errors are
    // ignored because the child has usually exited already.
    let _ = child.kill();
    let _ = child.wait();
    eprintln!(
        "[agent:{story_id}:{agent_name}] Done. Session: {:?}",
        session_id
    );
    Ok(session_id)
}