308 lines
8.8 KiB
Rust
308 lines
8.8 KiB
Rust
|
|
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::collections::HashMap;
|
||
|
|
use std::io::{BufRead, BufReader};
|
||
|
|
use std::sync::Mutex;
|
||
|
|
|
||
|
|
/// Manages multiple concurrent Claude Code agent sessions.
|
||
|
|
///
|
||
|
|
/// Each agent is identified by a string name (e.g., "coder-1", "coder-2").
|
||
|
|
/// Agents run `claude -p` in a PTY for Max subscription billing.
|
||
|
|
/// Sessions can be resumed for multi-turn conversations.
|
||
|
|
pub struct AgentPool {
|
||
|
|
agents: Mutex<HashMap<String, AgentState>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Serialize)]
|
||
|
|
pub struct AgentInfo {
|
||
|
|
pub name: String,
|
||
|
|
pub role: String,
|
||
|
|
pub cwd: String,
|
||
|
|
pub session_id: Option<String>,
|
||
|
|
pub status: AgentStatus,
|
||
|
|
pub message_count: usize,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Serialize)]
|
||
|
|
#[serde(rename_all = "snake_case")]
|
||
|
|
pub enum AgentStatus {
|
||
|
|
Idle,
|
||
|
|
Running,
|
||
|
|
}
|
||
|
|
|
||
|
|
struct AgentState {
|
||
|
|
role: String,
|
||
|
|
cwd: String,
|
||
|
|
session_id: Option<String>,
|
||
|
|
message_count: usize,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct CreateAgentRequest {
|
||
|
|
pub name: String,
|
||
|
|
pub role: String,
|
||
|
|
pub cwd: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Deserialize)]
|
||
|
|
pub struct SendMessageRequest {
|
||
|
|
pub message: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Serialize)]
|
||
|
|
pub struct AgentResponse {
|
||
|
|
pub agent: String,
|
||
|
|
pub text: String,
|
||
|
|
pub session_id: Option<String>,
|
||
|
|
pub model: Option<String>,
|
||
|
|
pub api_key_source: Option<String>,
|
||
|
|
pub rate_limit_type: Option<String>,
|
||
|
|
pub cost_usd: Option<f64>,
|
||
|
|
pub input_tokens: Option<u64>,
|
||
|
|
pub output_tokens: Option<u64>,
|
||
|
|
pub duration_ms: Option<u64>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl AgentPool {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
agents: Mutex::new(HashMap::new()),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn create_agent(&self, req: CreateAgentRequest) -> Result<AgentInfo, String> {
|
||
|
|
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
|
||
|
|
|
||
|
|
if agents.contains_key(&req.name) {
|
||
|
|
return Err(format!("Agent '{}' already exists", req.name));
|
||
|
|
}
|
||
|
|
|
||
|
|
let state = AgentState {
|
||
|
|
role: req.role.clone(),
|
||
|
|
cwd: req.cwd.clone(),
|
||
|
|
session_id: None,
|
||
|
|
message_count: 0,
|
||
|
|
};
|
||
|
|
|
||
|
|
let info = AgentInfo {
|
||
|
|
name: req.name.clone(),
|
||
|
|
role: req.role,
|
||
|
|
cwd: req.cwd,
|
||
|
|
session_id: None,
|
||
|
|
status: AgentStatus::Idle,
|
||
|
|
message_count: 0,
|
||
|
|
};
|
||
|
|
|
||
|
|
agents.insert(req.name, state);
|
||
|
|
Ok(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn list_agents(&self) -> Result<Vec<AgentInfo>, String> {
|
||
|
|
let agents = self.agents.lock().map_err(|e| e.to_string())?;
|
||
|
|
Ok(agents
|
||
|
|
.iter()
|
||
|
|
.map(|(name, state)| AgentInfo {
|
||
|
|
name: name.clone(),
|
||
|
|
role: state.role.clone(),
|
||
|
|
cwd: state.cwd.clone(),
|
||
|
|
session_id: state.session_id.clone(),
|
||
|
|
status: AgentStatus::Idle,
|
||
|
|
message_count: state.message_count,
|
||
|
|
})
|
||
|
|
.collect())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Send a message to an agent and wait for the complete response.
|
||
|
|
/// This spawns a `claude -p` process in a PTY, optionally resuming
|
||
|
|
/// a previous session for multi-turn conversations.
|
||
|
|
pub async fn send_message(
|
||
|
|
&self,
|
||
|
|
agent_name: &str,
|
||
|
|
message: &str,
|
||
|
|
) -> Result<AgentResponse, String> {
|
||
|
|
let (cwd, role, session_id) = {
|
||
|
|
let agents = self.agents.lock().map_err(|e| e.to_string())?;
|
||
|
|
let state = agents
|
||
|
|
.get(agent_name)
|
||
|
|
.ok_or_else(|| format!("Agent '{}' not found", agent_name))?;
|
||
|
|
(
|
||
|
|
state.cwd.clone(),
|
||
|
|
state.role.clone(),
|
||
|
|
state.session_id.clone(),
|
||
|
|
)
|
||
|
|
};
|
||
|
|
|
||
|
|
let agent = agent_name.to_string();
|
||
|
|
let msg = message.to_string();
|
||
|
|
let role_clone = role.clone();
|
||
|
|
|
||
|
|
let result = tokio::task::spawn_blocking(move || {
|
||
|
|
run_agent_pty(&agent, &msg, &cwd, &role_clone, session_id.as_deref())
|
||
|
|
})
|
||
|
|
.await
|
||
|
|
.map_err(|e| format!("Agent task panicked: {e}"))??;
|
||
|
|
|
||
|
|
// Update session_id for next message
|
||
|
|
if let Some(ref sid) = result.session_id {
|
||
|
|
let mut agents = self.agents.lock().map_err(|e| e.to_string())?;
|
||
|
|
if let Some(state) = agents.get_mut(agent_name) {
|
||
|
|
state.session_id = Some(sid.clone());
|
||
|
|
state.message_count += 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(result)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn run_agent_pty(
|
||
|
|
agent_name: &str,
|
||
|
|
message: &str,
|
||
|
|
cwd: &str,
|
||
|
|
role: &str,
|
||
|
|
resume_session: Option<&str>,
|
||
|
|
) -> Result<AgentResponse, String> {
|
||
|
|
let pty_system = native_pty_system();
|
||
|
|
|
||
|
|
let pair = pty_system
|
||
|
|
.openpty(PtySize {
|
||
|
|
rows: 50,
|
||
|
|
cols: 200,
|
||
|
|
pixel_width: 0,
|
||
|
|
pixel_height: 0,
|
||
|
|
})
|
||
|
|
.map_err(|e| format!("Failed to open PTY: {e}"))?;
|
||
|
|
|
||
|
|
let mut cmd = CommandBuilder::new("claude");
|
||
|
|
cmd.arg("-p");
|
||
|
|
cmd.arg(message);
|
||
|
|
cmd.arg("--output-format");
|
||
|
|
cmd.arg("stream-json");
|
||
|
|
cmd.arg("--verbose");
|
||
|
|
|
||
|
|
// Append role as system prompt context
|
||
|
|
cmd.arg("--append-system-prompt");
|
||
|
|
cmd.arg(format!(
|
||
|
|
"You are agent '{}' with role: {}. Work autonomously on the task given.",
|
||
|
|
agent_name, role
|
||
|
|
));
|
||
|
|
|
||
|
|
// Resume previous session if available
|
||
|
|
if let Some(session_id) = resume_session {
|
||
|
|
cmd.arg("--resume");
|
||
|
|
cmd.arg(session_id);
|
||
|
|
}
|
||
|
|
|
||
|
|
cmd.cwd(cwd);
|
||
|
|
cmd.env("NO_COLOR", "1");
|
||
|
|
|
||
|
|
eprintln!(
|
||
|
|
"[agent:{}] Spawning claude -p (session: {:?})",
|
||
|
|
agent_name,
|
||
|
|
resume_session.unwrap_or("new")
|
||
|
|
);
|
||
|
|
|
||
|
|
let mut child = pair
|
||
|
|
.slave
|
||
|
|
.spawn_command(cmd)
|
||
|
|
.map_err(|e| format!("Failed to spawn claude for agent {agent_name}: {e}"))?;
|
||
|
|
|
||
|
|
drop(pair.slave);
|
||
|
|
|
||
|
|
let reader = pair
|
||
|
|
.master
|
||
|
|
.try_clone_reader()
|
||
|
|
.map_err(|e| format!("Failed to clone PTY reader: {e}"))?;
|
||
|
|
|
||
|
|
drop(pair.master);
|
||
|
|
|
||
|
|
let buf_reader = BufReader::new(reader);
|
||
|
|
let mut response = AgentResponse {
|
||
|
|
agent: agent_name.to_string(),
|
||
|
|
text: String::new(),
|
||
|
|
session_id: None,
|
||
|
|
model: None,
|
||
|
|
api_key_source: None,
|
||
|
|
rate_limit_type: None,
|
||
|
|
cost_usd: None,
|
||
|
|
input_tokens: None,
|
||
|
|
output_tokens: None,
|
||
|
|
duration_ms: None,
|
||
|
|
};
|
||
|
|
|
||
|
|
for line in buf_reader.lines() {
|
||
|
|
let line = match line {
|
||
|
|
Ok(l) => l,
|
||
|
|
Err(_) => break,
|
||
|
|
};
|
||
|
|
|
||
|
|
let trimmed = line.trim();
|
||
|
|
if trimmed.is_empty() {
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
|
||
|
|
let json: serde_json::Value = match serde_json::from_str(trimmed) {
|
||
|
|
Ok(j) => j,
|
||
|
|
Err(_) => continue, // skip non-JSON (terminal escapes)
|
||
|
|
};
|
||
|
|
|
||
|
|
let event_type = json.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||
|
|
|
||
|
|
match event_type {
|
||
|
|
"system" => {
|
||
|
|
response.session_id = json
|
||
|
|
.get("session_id")
|
||
|
|
.and_then(|s| s.as_str())
|
||
|
|
.map(|s| s.to_string());
|
||
|
|
response.model = json
|
||
|
|
.get("model")
|
||
|
|
.and_then(|s| s.as_str())
|
||
|
|
.map(|s| s.to_string());
|
||
|
|
response.api_key_source = json
|
||
|
|
.get("apiKeySource")
|
||
|
|
.and_then(|s| s.as_str())
|
||
|
|
.map(|s| s.to_string());
|
||
|
|
}
|
||
|
|
"rate_limit_event" => {
|
||
|
|
if let Some(info) = json.get("rate_limit_info") {
|
||
|
|
response.rate_limit_type = info
|
||
|
|
.get("rateLimitType")
|
||
|
|
.and_then(|s| s.as_str())
|
||
|
|
.map(|s| s.to_string());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
"assistant" => {
|
||
|
|
if let Some(message) = json.get("message") {
|
||
|
|
if let Some(content) = message.get("content").and_then(|c| c.as_array()) {
|
||
|
|
for block in content {
|
||
|
|
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
||
|
|
response.text.push_str(text);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
"result" => {
|
||
|
|
response.cost_usd = json.get("total_cost_usd").and_then(|c| c.as_f64());
|
||
|
|
response.duration_ms = json.get("duration_ms").and_then(|d| d.as_u64());
|
||
|
|
if let Some(usage) = json.get("usage") {
|
||
|
|
response.input_tokens =
|
||
|
|
usage.get("input_tokens").and_then(|t| t.as_u64());
|
||
|
|
response.output_tokens =
|
||
|
|
usage.get("output_tokens").and_then(|t| t.as_u64());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
let _ = child.kill();
|
||
|
|
|
||
|
|
eprintln!(
|
||
|
|
"[agent:{}] Done. Session: {:?}, tokens: {:?}/{:?}",
|
||
|
|
agent_name, response.session_id, response.input_tokens, response.output_tokens
|
||
|
|
);
|
||
|
|
|
||
|
|
Ok(response)
|
||
|
|
}
|