From 5ed1438ab935e8c5ea72717e08cd2445a9cfa25c Mon Sep 17 00:00:00 2001 From: dave Date: Wed, 13 May 2026 23:33:30 +0000 Subject: [PATCH] huskies: merge 1015 --- server/src/agents/mod.rs | 14 +- server/src/agents/model.rs | 125 ++++++++++++++++++ server/src/agents/pool/auto_assign/scan.rs | 6 +- .../pool/auto_assign/watchdog/budget.rs | 7 +- server/src/agents/pool/start/mod.rs | 6 +- server/src/agents/pool/start/spawn.rs | 19 +-- server/src/agents/runtime/claude_code.rs | 9 +- server/src/agents/runtime/gemini/api.rs | 4 +- server/src/agents/runtime/gemini/mod.rs | 2 +- server/src/agents/runtime/mod.rs | 8 +- server/src/agents/runtime/openai.rs | 8 +- server/src/agents/token_usage.rs | 6 +- server/src/chat/commands/status/render.rs | 2 +- server/src/config/mod.rs | 11 +- server/src/config/tests.rs | 29 +++- server/src/crdt_sync/rpc.rs | 6 +- server/src/http/workflow/pipeline.rs | 2 +- server/src/service/agents/mod.rs | 4 +- server/src/service/agents/token.rs | 9 +- server/src/service/settings/project.rs | 17 +-- server/src/service/settings/validate.rs | 2 +- server/src/service/ws/message/convert.rs | 4 +- 22 files changed, 227 insertions(+), 73 deletions(-) create mode 100644 server/src/agents/model.rs diff --git a/server/src/agents/mod.rs b/server/src/agents/mod.rs index 6e5cde03..e3326277 100644 --- a/server/src/agents/mod.rs +++ b/server/src/agents/mod.rs @@ -16,6 +16,10 @@ pub mod session_store; /// Token-usage tracking and budget estimation. pub mod token_usage; +/// Typed agent model enum (Sonnet/Opus/Haiku). +pub mod model; +pub use model::AgentModel; + use crate::config::AgentConfig; use serde::{Deserialize, Serialize}; @@ -205,13 +209,13 @@ impl TokenUsage { /// data in the agent log (since `total_cost_usd` is only available in the /// `result` event at session end). Uses conservative (high) pricing when /// the model is unknown so budget limits are hit sooner rather than later. - pub fn estimate_cost_usd(&self, model: Option<&str>) -> f64 { + pub fn estimate_cost_usd(&self, model: Option<&AgentModel>) -> f64 { // Per-million-token pricing (input, output, cache_read, cache_create). let (inp, out, cr, cc) = match model { - Some(m) if m.contains("haiku") => (0.80, 4.0, 0.08, 1.00), - Some(m) if m.contains("sonnet") => (3.0, 15.0, 0.30, 3.75), - // Opus or unknown → most expensive = conservative. - _ => (15.0, 75.0, 1.50, 18.75), + Some(AgentModel::Haiku) => (0.80, 4.0, 0.08, 1.00), + Some(AgentModel::Sonnet) => (3.0, 15.0, 0.30, 3.75), + // Opus, Other, or unknown → most expensive = conservative. + Some(AgentModel::Opus) | Some(AgentModel::Other(_)) | None => (15.0, 75.0, 1.50, 18.75), }; (self.input_tokens as f64 * inp + self.output_tokens as f64 * out diff --git a/server/src/agents/model.rs b/server/src/agents/model.rs new file mode 100644 index 00000000..8bc137a0 --- /dev/null +++ b/server/src/agents/model.rs @@ -0,0 +1,125 @@ +//! Typed agent model — replaces raw model strings throughout the agent subsystem. +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; + +/// Supported agent model families. +/// +/// Serialises to the canonical short name ("sonnet", "opus", "haiku") or, for +/// `Other`, the original string verbatim. Deserialises from any string: +/// Claude family names are matched by substring (e.g. "claude-sonnet-4-6"), +/// everything else becomes `Other(string)` so non-Claude runtimes (Gemini, +/// OpenAI, etc.) survive a config round-trip without error. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentModel { + /// Claude Sonnet family. + Sonnet, + /// Claude Opus family. + Opus, + /// Claude Haiku family. + Haiku, + /// Any model string not recognised as a Claude family name. + Other(String), +} + +impl AgentModel { + /// Canonical short name used for serialisation and CLI `--model` flags. + pub fn as_str(&self) -> &str { + match self { + Self::Sonnet => "sonnet", + Self::Opus => "opus", + Self::Haiku => "haiku", + Self::Other(s) => s.as_str(), + } + } + + /// Parse any model string into a variant — always succeeds. + /// + /// Claude family names are matched by substring; everything else becomes + /// `Other`. + pub fn from_api_str(s: &str) -> Self { + if s.contains("haiku") { + Self::Haiku + } else if s.contains("sonnet") { + Self::Sonnet + } else if s.contains("opus") { + Self::Opus + } else { + Self::Other(s.to_string()) + } + } +} + +impl fmt::Display for AgentModel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl Serialize for AgentModel { + fn serialize(&self, s: S) -> Result { + s.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for AgentModel { + fn deserialize>(d: D) -> Result { + let s = String::deserialize(d)?; + Ok(Self::from_api_str(&s)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn short_names_deserialise() { + let s: AgentModel = serde_json::from_str("\"sonnet\"").unwrap(); + assert_eq!(s, AgentModel::Sonnet); + let o: AgentModel = serde_json::from_str("\"opus\"").unwrap(); + assert_eq!(o, AgentModel::Opus); + let h: AgentModel = serde_json::from_str("\"haiku\"").unwrap(); + assert_eq!(h, AgentModel::Haiku); + } + + #[test] + fn long_names_deserialise() { + let s: AgentModel = serde_json::from_str("\"claude-sonnet-4-6\"").unwrap(); + assert_eq!(s, AgentModel::Sonnet); + let h: AgentModel = serde_json::from_str("\"claude-haiku-4-5-20251001\"").unwrap(); + assert_eq!(h, AgentModel::Haiku); + let o: AgentModel = serde_json::from_str("\"claude-opus-4-5\"").unwrap(); + assert_eq!(o, AgentModel::Opus); + } + + #[test] + fn serialises_to_short_name() { + assert_eq!( + serde_json::to_string(&AgentModel::Sonnet).unwrap(), + "\"sonnet\"" + ); + assert_eq!( + serde_json::to_string(&AgentModel::Opus).unwrap(), + "\"opus\"" + ); + assert_eq!( + serde_json::to_string(&AgentModel::Haiku).unwrap(), + "\"haiku\"" + ); + } + + #[test] + fn unknown_string_becomes_other() { + let r: AgentModel = serde_json::from_str("\"gemini-2.5-pro\"").unwrap(); + assert_eq!(r, AgentModel::Other("gemini-2.5-pro".to_string())); + assert_eq!(r.as_str(), "gemini-2.5-pro"); + // Round-trips verbatim + assert_eq!(serde_json::to_string(&r).unwrap(), "\"gemini-2.5-pro\""); + } + + #[test] + fn option_none_round_trips() { + let v: Option = serde_json::from_str("null").unwrap(); + assert!(v.is_none()); + } +} diff --git a/server/src/agents/pool/auto_assign/scan.rs b/server/src/agents/pool/auto_assign/scan.rs index 0faa2273..b5138eae 100644 --- a/server/src/agents/pool/auto_assign/scan.rs +++ b/server/src/agents/pool/auto_assign/scan.rs @@ -120,11 +120,9 @@ pub(in crate::agents::pool) fn find_free_agent_for_stage<'a>( // model matches. This keeps opus agents reserved for explicit requests. if *stage == PipelineStage::Coder && let Some(ref default_model) = config.default_coder_model + && agent_config.model.as_ref() != Some(default_model) { - let agent_model = agent_config.model.as_deref().unwrap_or(""); - if agent_model != default_model { - continue; - } + continue; } let is_busy = agents.values().any(|a| { a.agent_name == agent_config.name diff --git a/server/src/agents/pool/auto_assign/watchdog/budget.rs b/server/src/agents/pool/auto_assign/watchdog/budget.rs index 5c11dc23..7aea3f02 100644 --- a/server/src/agents/pool/auto_assign/watchdog/budget.rs +++ b/server/src/agents/pool/auto_assign/watchdog/budget.rs @@ -67,7 +67,10 @@ pub(crate) fn compute_budget_from_single_log(path: &Path) -> f64 { && let Some(message) = data.get("message") && let Some(usage) = message.get("usage") { - let model = message.get("model").and_then(|v| v.as_str()); + let model = message + .get("model") + .and_then(|v| v.as_str()) + .map(crate::agents::AgentModel::from_api_str); let token_usage = TokenUsage { input_tokens: usage .get("input_tokens") @@ -87,7 +90,7 @@ pub(crate) fn compute_budget_from_single_log(path: &Path) -> f64 { .unwrap_or(0), total_cost_usd: 0.0, }; - cost += token_usage.estimate_cost_usd(model); + cost += token_usage.estimate_cost_usd(model.as_ref()); } } cost diff --git a/server/src/agents/pool/start/mod.rs b/server/src/agents/pool/start/mod.rs index a66c02ac..6c46de91 100644 --- a/server/src/agents/pool/start/mod.rs +++ b/server/src/agents/pool/start/mod.rs @@ -364,13 +364,13 @@ impl AgentPool { let effective_session_id = session_id_to_resume.or_else(|| { let model = config .find_agent(&resolved_name) - .and_then(|a| a.model.clone()) - .unwrap_or_default(); + .and_then(|a| a.model.as_ref().map(|m| m.as_str())) + .unwrap_or(""); crate::agents::session_store::lookup_session( project_root, story_id, &resolved_name, - &model, + model, ) }); diff --git a/server/src/agents/pool/start/spawn.rs b/server/src/agents/pool/start/spawn.rs index e81d7690..d5e8e423 100644 --- a/server/src/agents/pool/start/spawn.rs +++ b/server/src/agents/pool/start/spawn.rs @@ -384,8 +384,7 @@ pub(super) async fn run_agent_spawn( // passed to RuntimeContext for eager session recording (bug 967). let agent_model = config_clone .find_agent(&aname) - .and_then(|a| a.model.clone()) - .unwrap_or_default(); + .and_then(|a| a.model.clone()); let run_result = match runtime_name { "claude-code" => { @@ -463,11 +462,15 @@ pub(super) async fn run_agent_spawn( && let Some(agent) = agents.get(&key_clone) && let Some(ref pr) = agent.project_root { - let model = config_clone + let model_for_record = config_clone .find_agent(&aname) .and_then(|a| a.model.clone()); - let record = - crate::agents::token_usage::build_record(&sid, &aname, model, usage.clone()); + let record = crate::agents::token_usage::build_record( + &sid, + &aname, + model_for_record, + usage.clone(), + ); if let Err(e) = crate::agents::token_usage::append_record(pr, &record) { slog_error!( "[agents] Failed to persist token usage for \ @@ -480,13 +483,13 @@ pub(super) async fn run_agent_spawn( if let Some(ref sess_id) = result.session_id { let model = config_clone .find_agent(&aname) - .and_then(|a| a.model.clone()) - .unwrap_or_default(); + .and_then(|a| a.model.as_ref().map(|m| m.as_str())) + .unwrap_or(""); crate::agents::session_store::record_session( &project_root_clone, &sid, &aname, - &model, + model, sess_id, ); } diff --git a/server/src/agents/runtime/claude_code.rs b/server/src/agents/runtime/claude_code.rs index 08e038b2..ea619d30 100644 --- a/server/src/agents/runtime/claude_code.rs +++ b/server/src/agents/runtime/claude_code.rs @@ -42,11 +42,10 @@ impl AgentRuntime for ClaudeCodeRuntime { event_log: Arc>>, log_writer: Option>>, ) -> Result { - let eager_record = if ctx.model.is_empty() { - None - } else { - Some((ctx.project_root.clone(), ctx.model.clone())) - }; + let eager_record = ctx + .model + .as_ref() + .map(|m| (ctx.project_root.clone(), m.as_str().to_string())); let pty_result = super::super::pty::run_agent_pty_streaming( &ctx.story_id, &ctx.agent_name, diff --git a/server/src/agents/runtime/gemini/api.rs b/server/src/agents/runtime/gemini/api.rs index b17ec4d7..ac9ef3f0 100644 --- a/server/src/agents/runtime/gemini/api.rs +++ b/server/src/agents/runtime/gemini/api.rs @@ -118,7 +118,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; let instruction = build_system_instruction(&ctx); @@ -139,7 +139,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; let instruction = build_system_instruction(&ctx); diff --git a/server/src/agents/runtime/gemini/mod.rs b/server/src/agents/runtime/gemini/mod.rs index c1487685..b1b1f587 100644 --- a/server/src/agents/runtime/gemini/mod.rs +++ b/server/src/agents/runtime/gemini/mod.rs @@ -385,7 +385,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; // The model extraction logic is inside start(), but we test the diff --git a/server/src/agents/runtime/mod.rs b/server/src/agents/runtime/mod.rs index f782bce5..ebbda696 100644 --- a/server/src/agents/runtime/mod.rs +++ b/server/src/agents/runtime/mod.rs @@ -46,9 +46,9 @@ pub struct RuntimeContext { /// Eager recording ensures the session survives a watchdog kill that aborts /// the tokio task before `run_agent_spawn`'s `record_session()` call runs. pub project_root: std::path::PathBuf, - /// Agent model name — forms part of the session store key used for eager - /// recording (bug 967). An empty string disables eager recording. - pub model: String, + /// Agent model — forms part of the session store key used for eager + /// recording (bug 967). `None` disables eager recording. + pub model: Option, } /// Result returned by a runtime after the agent session completes. @@ -134,7 +134,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: "sonnet".to_string(), + model: Some(crate::agents::AgentModel::Sonnet), }; assert_eq!(ctx.story_id, "42_story_foo"); assert_eq!(ctx.agent_name, "coder-1"); diff --git a/server/src/agents/runtime/openai.rs b/server/src/agents/runtime/openai.rs index cd192b2a..800fe584 100644 --- a/server/src/agents/runtime/openai.rs +++ b/server/src/agents/runtime/openai.rs @@ -561,7 +561,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; assert_eq!(build_system_text(&ctx), "Custom system prompt"); @@ -581,7 +581,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; let text = build_system_text(&ctx); @@ -634,7 +634,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; assert!(ctx.command.starts_with("gpt")); } @@ -653,7 +653,7 @@ mod tests { session_id_to_resume: None, fresh_prompt: None, project_root: std::path::PathBuf::from("/tmp/project"), - model: String::new(), + model: None, }; assert!(ctx.command.starts_with("o")); } diff --git a/server/src/agents/token_usage.rs b/server/src/agents/token_usage.rs index 67ea1a24..048a2a69 100644 --- a/server/src/agents/token_usage.rs +++ b/server/src/agents/token_usage.rs @@ -5,7 +5,7 @@ use std::path::Path; use chrono::Utc; use serde::{Deserialize, Serialize}; -use super::TokenUsage; +use super::{AgentModel, TokenUsage}; /// A single token usage record persisted to disk. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -14,7 +14,7 @@ pub struct TokenUsageRecord { pub agent_name: String, pub timestamp: String, #[serde(default)] - pub model: Option, + pub model: Option, pub usage: TokenUsage, } @@ -75,7 +75,7 @@ pub fn read_all(project_root: &Path) -> Result, String> { pub fn build_record( story_id: &str, agent_name: &str, - model: Option, + model: Option, usage: TokenUsage, ) -> TokenUsageRecord { TokenUsageRecord { diff --git a/server/src/chat/commands/status/render.rs b/server/src/chat/commands/status/render.rs index 05cb6ec2..6d2ea558 100644 --- a/server/src/chat/commands/status/render.rs +++ b/server/src/chat/commands/status/render.rs @@ -340,7 +340,7 @@ fn render_item_line( let model_str = config .as_ref() .and_then(|cfg| cfg.find_agent(&agent.agent_name)) - .and_then(|ac| ac.model.as_deref()) + .and_then(|ac| ac.model.as_ref().map(|m| m.as_str())) .unwrap_or("?"); format!( " {dot}{display}{cost_suffix}{dep_suffix} — {} ({model_str})\n", diff --git a/server/src/config/mod.rs b/server/src/config/mod.rs index 24a5ae2e..397b8866 100644 --- a/server/src/config/mod.rs +++ b/server/src/config/mod.rs @@ -4,6 +4,7 @@ pub mod agent_name; pub use agent_name::AgentName; +use crate::agents::AgentModel; use crate::slog; use serde::Deserialize; use std::collections::HashSet; @@ -22,12 +23,12 @@ pub struct ProjectConfig { /// Per-story `qa` front matter overrides this. Default: "server". #[serde(default = "default_qa")] pub default_qa: String, - /// Default model for coder-stage agents (e.g. "sonnet"). + /// Default model for coder-stage agents (e.g. `Sonnet`). /// When set, `find_free_agent_for_stage` only considers coder agents whose /// model matches this value, so opus agents are only used when explicitly /// requested via story front matter `agent:` field. #[serde(default)] - pub default_coder_model: Option, + pub default_coder_model: Option, /// Maximum number of concurrent coder-stage agents. /// When set, `auto_assign_available_work` will not start more than this many /// coder agents at once. Stories wait in `2_current/` until a slot frees up. @@ -240,7 +241,7 @@ pub struct AgentConfig { #[serde(default = "default_agent_prompt")] pub prompt: String, #[serde(default)] - pub model: Option, + pub model: Option, #[serde(default)] pub allowed_tools: Option>, #[serde(default)] @@ -312,7 +313,7 @@ struct LegacyProjectConfig { #[serde(default = "default_qa")] default_qa: String, #[serde(default)] - default_coder_model: Option, + default_coder_model: Option, #[serde(default)] max_coders: Option, #[serde(default = "default_max_retries")] @@ -583,7 +584,7 @@ impl ProjectConfig { // Append structured CLI flags if let Some(ref model) = agent.model { args.push("--model".to_string()); - args.push(model.clone()); + args.push(model.as_str().to_string()); } if let Some(ref tools) = agent.allowed_tools && !tools.is_empty() diff --git a/server/src/config/tests.rs b/server/src/config/tests.rs index dcd256b8..a614fee6 100644 --- a/server/src/config/tests.rs +++ b/server/src/config/tests.rs @@ -39,7 +39,7 @@ max_budget_usd = 5.00 assert_eq!(config.agent.len(), 2); assert_eq!(config.agent[0].name, "supervisor"); assert_eq!(config.agent[0].role, "Coordinates work"); - assert_eq!(config.agent[0].model, Some("opus".to_string())); + assert_eq!(config.agent[0].model, Some(crate::agents::AgentModel::Opus)); assert_eq!(config.agent[0].max_turns, Some(50)); assert_eq!(config.agent[0].max_budget_usd, Some(10.0)); assert_eq!( @@ -47,7 +47,10 @@ max_budget_usd = 5.00 Some("You are a senior engineer".to_string()) ); assert_eq!(config.agent[1].name, "coder-1"); - assert_eq!(config.agent[1].model, Some("sonnet".to_string())); + assert_eq!( + config.agent[1].model, + Some(crate::agents::AgentModel::Sonnet) + ); assert_eq!(config.component.len(), 1); } @@ -237,7 +240,10 @@ model = "sonnet" assert_eq!(config.component[1].setup, vec!["pnpm install"]); assert_eq!(config.agent.len(), 1); assert_eq!(config.agent[0].name, "main"); - assert_eq!(config.agent[0].model, Some("sonnet".to_string())); + assert_eq!( + config.agent[0].model, + Some(crate::agents::AgentModel::Sonnet) + ); } #[test] @@ -269,7 +275,7 @@ model = "opus" let config = ProjectConfig::load(tmp.path()).unwrap(); assert_eq!(config.agent.len(), 1); assert_eq!(config.agent[0].name, "from-agents-toml"); - assert_eq!(config.agent[0].model, Some("opus".to_string())); + assert_eq!(config.agent[0].model, Some(crate::agents::AgentModel::Opus)); } #[test] @@ -438,7 +444,10 @@ stage = "coder" model = "opus" "#; let config = ProjectConfig::parse(toml_str).unwrap(); - assert_eq!(config.default_coder_model, Some("sonnet".to_string())); + assert_eq!( + config.default_coder_model, + Some(crate::agents::AgentModel::Sonnet) + ); assert_eq!(config.max_coders, Some(3)); } @@ -459,7 +468,10 @@ fn project_toml_has_default_coder_model_and_max_coders() { let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")); let project_root = manifest_dir.parent().unwrap(); let config = ProjectConfig::load(project_root).unwrap(); - assert_eq!(config.default_coder_model, Some("sonnet".to_string())); + assert_eq!( + config.default_coder_model, + Some(crate::agents::AgentModel::Sonnet) + ); assert_eq!(config.max_coders, Some(3)); } @@ -617,7 +629,10 @@ fn project_toml_has_three_sonnet_coders() { let sonnet_coders: Vec<_> = config .agent .iter() - .filter(|a| a.stage.as_deref() == Some("coder") && a.model.as_deref() == Some("sonnet")) + .filter(|a| { + a.stage.as_deref() == Some("coder") + && a.model.as_ref() == Some(&crate::agents::AgentModel::Sonnet) + }) .collect(); assert_eq!( diff --git a/server/src/crdt_sync/rpc.rs b/server/src/crdt_sync/rpc.rs index d7c2f2a1..d38e4f28 100644 --- a/server/src/crdt_sync/rpc.rs +++ b/server/src/crdt_sync/rpc.rs @@ -694,7 +694,9 @@ async fn handle_settings_put_project(params: Value) -> Value { }; let domain = crate::service::settings::ProjectSettings { default_qa: typed.default_qa, - default_coder_model: typed.default_coder_model, + default_coder_model: typed + .default_coder_model + .map(|s| crate::agents::AgentModel::from_api_str(&s)), max_coders: typed.max_coders, max_retries: typed.max_retries, base_branch: typed.base_branch, @@ -713,7 +715,7 @@ async fn handle_settings_put_project(params: Value) -> Value { match crate::service::settings::load_project_settings(&root) { Ok(s) => serde_json::to_value(ProjectSettingsPayload { default_qa: s.default_qa, - default_coder_model: s.default_coder_model, + default_coder_model: s.default_coder_model.map(|m| m.as_str().to_string()), max_coders: s.max_coders, max_retries: s.max_retries, base_branch: s.base_branch, diff --git a/server/src/http/workflow/pipeline.rs b/server/src/http/workflow/pipeline.rs index 49e59380..6f52028b 100644 --- a/server/src/http/workflow/pipeline.rs +++ b/server/src/http/workflow/pipeline.rs @@ -10,7 +10,7 @@ use std::path::Path; #[derive(Clone, Debug, Serialize)] pub struct AgentAssignment { pub agent_name: crate::config::AgentName, - pub model: Option, + pub model: Option, pub status: crate::agents::AgentStatus, } diff --git a/server/src/service/agents/mod.rs b/server/src/service/agents/mod.rs index a44b17a9..f4124799 100644 --- a/server/src/service/agents/mod.rs +++ b/server/src/service/agents/mod.rs @@ -68,7 +68,7 @@ pub struct AgentConfigEntry { pub name: String, pub role: String, pub stage: Option, - pub model: Option, + pub model: Option, pub allowed_tools: Option>, pub max_turns: Option, pub max_budget_usd: Option, @@ -283,7 +283,7 @@ max_budget_usd = 5.0 let entries = get_agent_config(tmp.path()).unwrap(); assert_eq!(entries.len(), 1); assert_eq!(entries[0].name, "coder-1"); - assert_eq!(entries[0].model, Some("sonnet".to_string())); + assert_eq!(entries[0].model, Some(crate::agents::AgentModel::Sonnet)); assert_eq!(entries[0].max_turns, Some(30)); } diff --git a/server/src/service/agents/token.rs b/server/src/service/agents/token.rs index 8e3b2328..25f79fce 100644 --- a/server/src/service/agents/token.rs +++ b/server/src/service/agents/token.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; #[derive(Debug, Clone, PartialEq)] pub struct AgentTokenCost { pub agent_name: String, - pub model: Option, + pub model: Option, pub input_tokens: u64, pub output_tokens: u64, pub cache_creation_input_tokens: u64, @@ -153,8 +153,11 @@ mod tests { #[test] fn aggregate_preserves_model_from_first_record() { let mut r = make_record("42_story_foo", "coder-1", 1.0); - r.model = Some("claude-sonnet".to_string()); + r.model = Some(crate::agents::AgentModel::Sonnet); let summary = aggregate_for_story(&[r], "42_story_foo"); - assert_eq!(summary.agents[0].model, Some("claude-sonnet".to_string())); + assert_eq!( + summary.agents[0].model, + Some(crate::agents::AgentModel::Sonnet) + ); } } diff --git a/server/src/service/settings/project.rs b/server/src/service/settings/project.rs index 6d04dc5a..b32374cf 100644 --- a/server/src/service/settings/project.rs +++ b/server/src/service/settings/project.rs @@ -5,6 +5,7 @@ //! [`merge_settings_into_toml`] (the pure TOML key-updating logic used by the //! write path in `mod.rs` + `io.rs`). +use crate::agents::AgentModel; use crate::config::ProjectConfig; use serde::{Deserialize, Serialize}; @@ -17,8 +18,8 @@ use serde::{Deserialize, Serialize}; pub struct ProjectSettings { /// Project-wide default QA mode: "server", "agent", or "human". Default: "server". pub default_qa: String, - /// Default model for coder-stage agents (e.g. "sonnet"). - pub default_coder_model: Option, + /// Default model for coder-stage agents (e.g. `Sonnet`). + pub default_coder_model: Option, /// Maximum number of concurrent coder-stage agents. pub max_coders: Option, /// Maximum retries per story per pipeline stage before marking as blocked. Default: 2. @@ -89,7 +90,7 @@ pub fn merge_settings_into_toml( Some(v) => { table.insert( "default_coder_model".to_string(), - toml::Value::String(v.clone()), + toml::Value::String(v.as_str().to_string()), ); } None => { @@ -180,7 +181,7 @@ mod tests { fn settings_from_config_copies_all_scalar_fields() { let cfg = ProjectConfig { default_qa: "human".to_string(), - default_coder_model: Some("opus".to_string()), + default_coder_model: Some(AgentModel::Opus), max_coders: Some(4), max_retries: 5, base_branch: Some("main".to_string()), @@ -196,7 +197,7 @@ mod tests { let s = settings_from_config(&cfg); assert_eq!(s.default_qa, "human"); - assert_eq!(s.default_coder_model, Some("opus".to_string())); + assert_eq!(s.default_coder_model, Some(AgentModel::Opus)); assert_eq!(s.max_coders, Some(4)); assert_eq!(s.max_retries, 5); assert_eq!(s.base_branch, Some("main".to_string())); @@ -247,7 +248,7 @@ mod tests { let mut val = empty_toml(); let s = ProjectSettings { default_qa: "server".to_string(), - default_coder_model: Some("sonnet".to_string()), + default_coder_model: Some(AgentModel::Sonnet), max_coders: Some(2), max_retries: 2, base_branch: Some("main".to_string()), @@ -273,7 +274,7 @@ mod tests { // First set them let s_with = ProjectSettings { default_qa: "server".to_string(), - default_coder_model: Some("sonnet".to_string()), + default_coder_model: Some(AgentModel::Sonnet), max_coders: Some(3), max_retries: 2, base_branch: Some("master".to_string()), @@ -372,7 +373,7 @@ path = "." let mut val = empty_toml(); let s = ProjectSettings { default_qa: "human".to_string(), - default_coder_model: Some("opus".to_string()), + default_coder_model: Some(AgentModel::Opus), max_coders: Some(2), max_retries: 4, base_branch: Some("develop".to_string()), diff --git a/server/src/service/settings/validate.rs b/server/src/service/settings/validate.rs index 4620c7c9..77bfb3bd 100644 --- a/server/src/service/settings/validate.rs +++ b/server/src/service/settings/validate.rs @@ -117,7 +117,7 @@ mod tests { fn valid_settings_with_all_optional_fields_set() { let s = ProjectSettings { default_qa: "agent".to_string(), - default_coder_model: Some("opus".to_string()), + default_coder_model: Some(crate::agents::AgentModel::Opus), max_coders: Some(4), max_retries: 5, base_branch: Some("main".to_string()), diff --git a/server/src/service/ws/message/convert.rs b/server/src/service/ws/message/convert.rs index 070ba58a..6e0e5b8e 100644 --- a/server/src/service/ws/message/convert.rs +++ b/server/src/service/ws/message/convert.rs @@ -296,7 +296,7 @@ mod tests { merge_failure: None, agent: Some(crate::http::workflow::pipeline::AgentAssignment { agent_name: crate::config::AgentName::Coder1, - model: Some("claude-3-5-sonnet".to_string()), + model: Some(crate::agents::AgentModel::Sonnet), status: crate::agents::AgentStatus::Running, }), review_hold: None, @@ -316,7 +316,7 @@ mod tests { let resp: WsResponse = state.into(); let json = serde_json::to_value(&resp).unwrap(); assert_eq!(json["current"][0]["agent"]["agent_name"], "coder-1"); - assert_eq!(json["current"][0]["agent"]["model"], "claude-3-5-sonnet"); + assert_eq!(json["current"][0]["agent"]["model"], "sonnet"); assert_eq!(json["current"][0]["agent"]["status"], "running"); }