huskies: merge 1015

This commit is contained in:
dave
2026-05-13 23:33:30 +00:00
parent 69b207872a
commit 5ed1438ab9
22 changed files with 227 additions and 73 deletions
+9 -5
View File
@@ -16,6 +16,10 @@ pub mod session_store;
/// Token-usage tracking and budget estimation. /// Token-usage tracking and budget estimation.
pub mod token_usage; pub mod token_usage;
/// Typed agent model enum (Sonnet/Opus/Haiku).
pub mod model;
pub use model::AgentModel;
use crate::config::AgentConfig; use crate::config::AgentConfig;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -205,13 +209,13 @@ impl TokenUsage {
/// data in the agent log (since `total_cost_usd` is only available in the /// data in the agent log (since `total_cost_usd` is only available in the
/// `result` event at session end). Uses conservative (high) pricing when /// `result` event at session end). Uses conservative (high) pricing when
/// the model is unknown so budget limits are hit sooner rather than later. /// the model is unknown so budget limits are hit sooner rather than later.
pub fn estimate_cost_usd(&self, model: Option<&str>) -> f64 { pub fn estimate_cost_usd(&self, model: Option<&AgentModel>) -> f64 {
// Per-million-token pricing (input, output, cache_read, cache_create). // Per-million-token pricing (input, output, cache_read, cache_create).
let (inp, out, cr, cc) = match model { let (inp, out, cr, cc) = match model {
Some(m) if m.contains("haiku") => (0.80, 4.0, 0.08, 1.00), Some(AgentModel::Haiku) => (0.80, 4.0, 0.08, 1.00),
Some(m) if m.contains("sonnet") => (3.0, 15.0, 0.30, 3.75), Some(AgentModel::Sonnet) => (3.0, 15.0, 0.30, 3.75),
// Opus or unknown → most expensive = conservative. // Opus, Other, or unknown → most expensive = conservative.
_ => (15.0, 75.0, 1.50, 18.75), Some(AgentModel::Opus) | Some(AgentModel::Other(_)) | None => (15.0, 75.0, 1.50, 18.75),
}; };
(self.input_tokens as f64 * inp (self.input_tokens as f64 * inp
+ self.output_tokens as f64 * out + self.output_tokens as f64 * out
+125
View File
@@ -0,0 +1,125 @@
//! Typed agent model — replaces raw model strings throughout the agent subsystem.
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
/// Supported agent model families.
///
/// Serialises to the canonical short name ("sonnet", "opus", "haiku") or, for
/// `Other`, the original string verbatim. Deserialises from any string:
/// Claude family names are matched by substring (e.g. "claude-sonnet-4-6"),
/// everything else becomes `Other(string)` so non-Claude runtimes (Gemini,
/// OpenAI, etc.) survive a config round-trip without error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentModel {
/// Claude Sonnet family.
Sonnet,
/// Claude Opus family.
Opus,
/// Claude Haiku family.
Haiku,
/// Any model string not recognised as a Claude family name.
Other(String),
}
impl AgentModel {
/// Canonical short name used for serialisation and CLI `--model` flags.
pub fn as_str(&self) -> &str {
match self {
Self::Sonnet => "sonnet",
Self::Opus => "opus",
Self::Haiku => "haiku",
Self::Other(s) => s.as_str(),
}
}
/// Parse any model string into a variant — always succeeds.
///
/// Claude family names are matched by substring; everything else becomes
/// `Other`.
pub fn from_api_str(s: &str) -> Self {
if s.contains("haiku") {
Self::Haiku
} else if s.contains("sonnet") {
Self::Sonnet
} else if s.contains("opus") {
Self::Opus
} else {
Self::Other(s.to_string())
}
}
}
impl fmt::Display for AgentModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl Serialize for AgentModel {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
s.serialize_str(self.as_str())
}
}
impl<'de> Deserialize<'de> for AgentModel {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
Ok(Self::from_api_str(&s))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn short_names_deserialise() {
let s: AgentModel = serde_json::from_str("\"sonnet\"").unwrap();
assert_eq!(s, AgentModel::Sonnet);
let o: AgentModel = serde_json::from_str("\"opus\"").unwrap();
assert_eq!(o, AgentModel::Opus);
let h: AgentModel = serde_json::from_str("\"haiku\"").unwrap();
assert_eq!(h, AgentModel::Haiku);
}
#[test]
fn long_names_deserialise() {
let s: AgentModel = serde_json::from_str("\"claude-sonnet-4-6\"").unwrap();
assert_eq!(s, AgentModel::Sonnet);
let h: AgentModel = serde_json::from_str("\"claude-haiku-4-5-20251001\"").unwrap();
assert_eq!(h, AgentModel::Haiku);
let o: AgentModel = serde_json::from_str("\"claude-opus-4-5\"").unwrap();
assert_eq!(o, AgentModel::Opus);
}
#[test]
fn serialises_to_short_name() {
assert_eq!(
serde_json::to_string(&AgentModel::Sonnet).unwrap(),
"\"sonnet\""
);
assert_eq!(
serde_json::to_string(&AgentModel::Opus).unwrap(),
"\"opus\""
);
assert_eq!(
serde_json::to_string(&AgentModel::Haiku).unwrap(),
"\"haiku\""
);
}
#[test]
fn unknown_string_becomes_other() {
let r: AgentModel = serde_json::from_str("\"gemini-2.5-pro\"").unwrap();
assert_eq!(r, AgentModel::Other("gemini-2.5-pro".to_string()));
assert_eq!(r.as_str(), "gemini-2.5-pro");
// Round-trips verbatim
assert_eq!(serde_json::to_string(&r).unwrap(), "\"gemini-2.5-pro\"");
}
#[test]
fn option_none_round_trips() {
let v: Option<AgentModel> = serde_json::from_str("null").unwrap();
assert!(v.is_none());
}
}
+2 -4
View File
@@ -120,11 +120,9 @@ pub(in crate::agents::pool) fn find_free_agent_for_stage<'a>(
// model matches. This keeps opus agents reserved for explicit requests. // model matches. This keeps opus agents reserved for explicit requests.
if *stage == PipelineStage::Coder if *stage == PipelineStage::Coder
&& let Some(ref default_model) = config.default_coder_model && let Some(ref default_model) = config.default_coder_model
&& agent_config.model.as_ref() != Some(default_model)
{ {
let agent_model = agent_config.model.as_deref().unwrap_or(""); continue;
if agent_model != default_model {
continue;
}
} }
let is_busy = agents.values().any(|a| { let is_busy = agents.values().any(|a| {
a.agent_name == agent_config.name a.agent_name == agent_config.name
@@ -67,7 +67,10 @@ pub(crate) fn compute_budget_from_single_log(path: &Path) -> f64 {
&& let Some(message) = data.get("message") && let Some(message) = data.get("message")
&& let Some(usage) = message.get("usage") && let Some(usage) = message.get("usage")
{ {
let model = message.get("model").and_then(|v| v.as_str()); let model = message
.get("model")
.and_then(|v| v.as_str())
.map(crate::agents::AgentModel::from_api_str);
let token_usage = TokenUsage { let token_usage = TokenUsage {
input_tokens: usage input_tokens: usage
.get("input_tokens") .get("input_tokens")
@@ -87,7 +90,7 @@ pub(crate) fn compute_budget_from_single_log(path: &Path) -> f64 {
.unwrap_or(0), .unwrap_or(0),
total_cost_usd: 0.0, total_cost_usd: 0.0,
}; };
cost += token_usage.estimate_cost_usd(model); cost += token_usage.estimate_cost_usd(model.as_ref());
} }
} }
cost cost
+3 -3
View File
@@ -364,13 +364,13 @@ impl AgentPool {
let effective_session_id = session_id_to_resume.or_else(|| { let effective_session_id = session_id_to_resume.or_else(|| {
let model = config let model = config
.find_agent(&resolved_name) .find_agent(&resolved_name)
.and_then(|a| a.model.clone()) .and_then(|a| a.model.as_ref().map(|m| m.as_str()))
.unwrap_or_default(); .unwrap_or("");
crate::agents::session_store::lookup_session( crate::agents::session_store::lookup_session(
project_root, project_root,
story_id, story_id,
&resolved_name, &resolved_name,
&model, model,
) )
}); });
+11 -8
View File
@@ -384,8 +384,7 @@ pub(super) async fn run_agent_spawn(
// passed to RuntimeContext for eager session recording (bug 967). // passed to RuntimeContext for eager session recording (bug 967).
let agent_model = config_clone let agent_model = config_clone
.find_agent(&aname) .find_agent(&aname)
.and_then(|a| a.model.clone()) .and_then(|a| a.model.clone());
.unwrap_or_default();
let run_result = match runtime_name { let run_result = match runtime_name {
"claude-code" => { "claude-code" => {
@@ -463,11 +462,15 @@ pub(super) async fn run_agent_spawn(
&& let Some(agent) = agents.get(&key_clone) && let Some(agent) = agents.get(&key_clone)
&& let Some(ref pr) = agent.project_root && let Some(ref pr) = agent.project_root
{ {
let model = config_clone let model_for_record = config_clone
.find_agent(&aname) .find_agent(&aname)
.and_then(|a| a.model.clone()); .and_then(|a| a.model.clone());
let record = let record = crate::agents::token_usage::build_record(
crate::agents::token_usage::build_record(&sid, &aname, model, usage.clone()); &sid,
&aname,
model_for_record,
usage.clone(),
);
if let Err(e) = crate::agents::token_usage::append_record(pr, &record) { if let Err(e) = crate::agents::token_usage::append_record(pr, &record) {
slog_error!( slog_error!(
"[agents] Failed to persist token usage for \ "[agents] Failed to persist token usage for \
@@ -480,13 +483,13 @@ pub(super) async fn run_agent_spawn(
if let Some(ref sess_id) = result.session_id { if let Some(ref sess_id) = result.session_id {
let model = config_clone let model = config_clone
.find_agent(&aname) .find_agent(&aname)
.and_then(|a| a.model.clone()) .and_then(|a| a.model.as_ref().map(|m| m.as_str()))
.unwrap_or_default(); .unwrap_or("");
crate::agents::session_store::record_session( crate::agents::session_store::record_session(
&project_root_clone, &project_root_clone,
&sid, &sid,
&aname, &aname,
&model, model,
sess_id, sess_id,
); );
} }
+4 -5
View File
@@ -42,11 +42,10 @@ impl AgentRuntime for ClaudeCodeRuntime {
event_log: Arc<Mutex<Vec<AgentEvent>>>, event_log: Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>, log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
) -> Result<RuntimeResult, String> { ) -> Result<RuntimeResult, String> {
let eager_record = if ctx.model.is_empty() { let eager_record = ctx
None .model
} else { .as_ref()
Some((ctx.project_root.clone(), ctx.model.clone())) .map(|m| (ctx.project_root.clone(), m.as_str().to_string()));
};
let pty_result = super::super::pty::run_agent_pty_streaming( let pty_result = super::super::pty::run_agent_pty_streaming(
&ctx.story_id, &ctx.story_id,
&ctx.agent_name, &ctx.agent_name,
+2 -2
View File
@@ -118,7 +118,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
let instruction = build_system_instruction(&ctx); let instruction = build_system_instruction(&ctx);
@@ -139,7 +139,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
let instruction = build_system_instruction(&ctx); let instruction = build_system_instruction(&ctx);
+1 -1
View File
@@ -385,7 +385,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
// The model extraction logic is inside start(), but we test the // The model extraction logic is inside start(), but we test the
+4 -4
View File
@@ -46,9 +46,9 @@ pub struct RuntimeContext {
/// Eager recording ensures the session survives a watchdog kill that aborts /// Eager recording ensures the session survives a watchdog kill that aborts
/// the tokio task before `run_agent_spawn`'s `record_session()` call runs. /// the tokio task before `run_agent_spawn`'s `record_session()` call runs.
pub project_root: std::path::PathBuf, pub project_root: std::path::PathBuf,
/// Agent model name — forms part of the session store key used for eager /// Agent model — forms part of the session store key used for eager
/// recording (bug 967). An empty string disables eager recording. /// recording (bug 967). `None` disables eager recording.
pub model: String, pub model: Option<crate::agents::AgentModel>,
} }
/// Result returned by a runtime after the agent session completes. /// Result returned by a runtime after the agent session completes.
@@ -134,7 +134,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: "sonnet".to_string(), model: Some(crate::agents::AgentModel::Sonnet),
}; };
assert_eq!(ctx.story_id, "42_story_foo"); assert_eq!(ctx.story_id, "42_story_foo");
assert_eq!(ctx.agent_name, "coder-1"); assert_eq!(ctx.agent_name, "coder-1");
+4 -4
View File
@@ -561,7 +561,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
assert_eq!(build_system_text(&ctx), "Custom system prompt"); assert_eq!(build_system_text(&ctx), "Custom system prompt");
@@ -581,7 +581,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
let text = build_system_text(&ctx); let text = build_system_text(&ctx);
@@ -634,7 +634,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
assert!(ctx.command.starts_with("gpt")); assert!(ctx.command.starts_with("gpt"));
} }
@@ -653,7 +653,7 @@ mod tests {
session_id_to_resume: None, session_id_to_resume: None,
fresh_prompt: None, fresh_prompt: None,
project_root: std::path::PathBuf::from("/tmp/project"), project_root: std::path::PathBuf::from("/tmp/project"),
model: String::new(), model: None,
}; };
assert!(ctx.command.starts_with("o")); assert!(ctx.command.starts_with("o"));
} }
+3 -3
View File
@@ -5,7 +5,7 @@ use std::path::Path;
use chrono::Utc; use chrono::Utc;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::TokenUsage; use super::{AgentModel, TokenUsage};
/// A single token usage record persisted to disk. /// A single token usage record persisted to disk.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -14,7 +14,7 @@ pub struct TokenUsageRecord {
pub agent_name: String, pub agent_name: String,
pub timestamp: String, pub timestamp: String,
#[serde(default)] #[serde(default)]
pub model: Option<String>, pub model: Option<AgentModel>,
pub usage: TokenUsage, pub usage: TokenUsage,
} }
@@ -75,7 +75,7 @@ pub fn read_all(project_root: &Path) -> Result<Vec<TokenUsageRecord>, String> {
pub fn build_record( pub fn build_record(
story_id: &str, story_id: &str,
agent_name: &str, agent_name: &str,
model: Option<String>, model: Option<AgentModel>,
usage: TokenUsage, usage: TokenUsage,
) -> TokenUsageRecord { ) -> TokenUsageRecord {
TokenUsageRecord { TokenUsageRecord {
+1 -1
View File
@@ -340,7 +340,7 @@ fn render_item_line(
let model_str = config let model_str = config
.as_ref() .as_ref()
.and_then(|cfg| cfg.find_agent(&agent.agent_name)) .and_then(|cfg| cfg.find_agent(&agent.agent_name))
.and_then(|ac| ac.model.as_deref()) .and_then(|ac| ac.model.as_ref().map(|m| m.as_str()))
.unwrap_or("?"); .unwrap_or("?");
format!( format!(
" {dot}{display}{cost_suffix}{dep_suffix} — {} ({model_str})\n", " {dot}{display}{cost_suffix}{dep_suffix} — {} ({model_str})\n",
+6 -5
View File
@@ -4,6 +4,7 @@
pub mod agent_name; pub mod agent_name;
pub use agent_name::AgentName; pub use agent_name::AgentName;
use crate::agents::AgentModel;
use crate::slog; use crate::slog;
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashSet; use std::collections::HashSet;
@@ -22,12 +23,12 @@ pub struct ProjectConfig {
/// Per-story `qa` front matter overrides this. Default: "server". /// Per-story `qa` front matter overrides this. Default: "server".
#[serde(default = "default_qa")] #[serde(default = "default_qa")]
pub default_qa: String, pub default_qa: String,
/// Default model for coder-stage agents (e.g. "sonnet"). /// Default model for coder-stage agents (e.g. `Sonnet`).
/// When set, `find_free_agent_for_stage` only considers coder agents whose /// When set, `find_free_agent_for_stage` only considers coder agents whose
/// model matches this value, so opus agents are only used when explicitly /// model matches this value, so opus agents are only used when explicitly
/// requested via story front matter `agent:` field. /// requested via story front matter `agent:` field.
#[serde(default)] #[serde(default)]
pub default_coder_model: Option<String>, pub default_coder_model: Option<AgentModel>,
/// Maximum number of concurrent coder-stage agents. /// Maximum number of concurrent coder-stage agents.
/// When set, `auto_assign_available_work` will not start more than this many /// When set, `auto_assign_available_work` will not start more than this many
/// coder agents at once. Stories wait in `2_current/` until a slot frees up. /// coder agents at once. Stories wait in `2_current/` until a slot frees up.
@@ -240,7 +241,7 @@ pub struct AgentConfig {
#[serde(default = "default_agent_prompt")] #[serde(default = "default_agent_prompt")]
pub prompt: String, pub prompt: String,
#[serde(default)] #[serde(default)]
pub model: Option<String>, pub model: Option<AgentModel>,
#[serde(default)] #[serde(default)]
pub allowed_tools: Option<Vec<String>>, pub allowed_tools: Option<Vec<String>>,
#[serde(default)] #[serde(default)]
@@ -312,7 +313,7 @@ struct LegacyProjectConfig {
#[serde(default = "default_qa")] #[serde(default = "default_qa")]
default_qa: String, default_qa: String,
#[serde(default)] #[serde(default)]
default_coder_model: Option<String>, default_coder_model: Option<AgentModel>,
#[serde(default)] #[serde(default)]
max_coders: Option<usize>, max_coders: Option<usize>,
#[serde(default = "default_max_retries")] #[serde(default = "default_max_retries")]
@@ -583,7 +584,7 @@ impl ProjectConfig {
// Append structured CLI flags // Append structured CLI flags
if let Some(ref model) = agent.model { if let Some(ref model) = agent.model {
args.push("--model".to_string()); args.push("--model".to_string());
args.push(model.clone()); args.push(model.as_str().to_string());
} }
if let Some(ref tools) = agent.allowed_tools if let Some(ref tools) = agent.allowed_tools
&& !tools.is_empty() && !tools.is_empty()
+22 -7
View File
@@ -39,7 +39,7 @@ max_budget_usd = 5.00
assert_eq!(config.agent.len(), 2); assert_eq!(config.agent.len(), 2);
assert_eq!(config.agent[0].name, "supervisor"); assert_eq!(config.agent[0].name, "supervisor");
assert_eq!(config.agent[0].role, "Coordinates work"); assert_eq!(config.agent[0].role, "Coordinates work");
assert_eq!(config.agent[0].model, Some("opus".to_string())); assert_eq!(config.agent[0].model, Some(crate::agents::AgentModel::Opus));
assert_eq!(config.agent[0].max_turns, Some(50)); assert_eq!(config.agent[0].max_turns, Some(50));
assert_eq!(config.agent[0].max_budget_usd, Some(10.0)); assert_eq!(config.agent[0].max_budget_usd, Some(10.0));
assert_eq!( assert_eq!(
@@ -47,7 +47,10 @@ max_budget_usd = 5.00
Some("You are a senior engineer".to_string()) Some("You are a senior engineer".to_string())
); );
assert_eq!(config.agent[1].name, "coder-1"); assert_eq!(config.agent[1].name, "coder-1");
assert_eq!(config.agent[1].model, Some("sonnet".to_string())); assert_eq!(
config.agent[1].model,
Some(crate::agents::AgentModel::Sonnet)
);
assert_eq!(config.component.len(), 1); assert_eq!(config.component.len(), 1);
} }
@@ -237,7 +240,10 @@ model = "sonnet"
assert_eq!(config.component[1].setup, vec!["pnpm install"]); assert_eq!(config.component[1].setup, vec!["pnpm install"]);
assert_eq!(config.agent.len(), 1); assert_eq!(config.agent.len(), 1);
assert_eq!(config.agent[0].name, "main"); assert_eq!(config.agent[0].name, "main");
assert_eq!(config.agent[0].model, Some("sonnet".to_string())); assert_eq!(
config.agent[0].model,
Some(crate::agents::AgentModel::Sonnet)
);
} }
#[test] #[test]
@@ -269,7 +275,7 @@ model = "opus"
let config = ProjectConfig::load(tmp.path()).unwrap(); let config = ProjectConfig::load(tmp.path()).unwrap();
assert_eq!(config.agent.len(), 1); assert_eq!(config.agent.len(), 1);
assert_eq!(config.agent[0].name, "from-agents-toml"); assert_eq!(config.agent[0].name, "from-agents-toml");
assert_eq!(config.agent[0].model, Some("opus".to_string())); assert_eq!(config.agent[0].model, Some(crate::agents::AgentModel::Opus));
} }
#[test] #[test]
@@ -438,7 +444,10 @@ stage = "coder"
model = "opus" model = "opus"
"#; "#;
let config = ProjectConfig::parse(toml_str).unwrap(); let config = ProjectConfig::parse(toml_str).unwrap();
assert_eq!(config.default_coder_model, Some("sonnet".to_string())); assert_eq!(
config.default_coder_model,
Some(crate::agents::AgentModel::Sonnet)
);
assert_eq!(config.max_coders, Some(3)); assert_eq!(config.max_coders, Some(3));
} }
@@ -459,7 +468,10 @@ fn project_toml_has_default_coder_model_and_max_coders() {
let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")); let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
let project_root = manifest_dir.parent().unwrap(); let project_root = manifest_dir.parent().unwrap();
let config = ProjectConfig::load(project_root).unwrap(); let config = ProjectConfig::load(project_root).unwrap();
assert_eq!(config.default_coder_model, Some("sonnet".to_string())); assert_eq!(
config.default_coder_model,
Some(crate::agents::AgentModel::Sonnet)
);
assert_eq!(config.max_coders, Some(3)); assert_eq!(config.max_coders, Some(3));
} }
@@ -617,7 +629,10 @@ fn project_toml_has_three_sonnet_coders() {
let sonnet_coders: Vec<_> = config let sonnet_coders: Vec<_> = config
.agent .agent
.iter() .iter()
.filter(|a| a.stage.as_deref() == Some("coder") && a.model.as_deref() == Some("sonnet")) .filter(|a| {
a.stage.as_deref() == Some("coder")
&& a.model.as_ref() == Some(&crate::agents::AgentModel::Sonnet)
})
.collect(); .collect();
assert_eq!( assert_eq!(
+4 -2
View File
@@ -694,7 +694,9 @@ async fn handle_settings_put_project(params: Value) -> Value {
}; };
let domain = crate::service::settings::ProjectSettings { let domain = crate::service::settings::ProjectSettings {
default_qa: typed.default_qa, default_qa: typed.default_qa,
default_coder_model: typed.default_coder_model, default_coder_model: typed
.default_coder_model
.map(|s| crate::agents::AgentModel::from_api_str(&s)),
max_coders: typed.max_coders, max_coders: typed.max_coders,
max_retries: typed.max_retries, max_retries: typed.max_retries,
base_branch: typed.base_branch, base_branch: typed.base_branch,
@@ -713,7 +715,7 @@ async fn handle_settings_put_project(params: Value) -> Value {
match crate::service::settings::load_project_settings(&root) { match crate::service::settings::load_project_settings(&root) {
Ok(s) => serde_json::to_value(ProjectSettingsPayload { Ok(s) => serde_json::to_value(ProjectSettingsPayload {
default_qa: s.default_qa, default_qa: s.default_qa,
default_coder_model: s.default_coder_model, default_coder_model: s.default_coder_model.map(|m| m.as_str().to_string()),
max_coders: s.max_coders, max_coders: s.max_coders,
max_retries: s.max_retries, max_retries: s.max_retries,
base_branch: s.base_branch, base_branch: s.base_branch,
+1 -1
View File
@@ -10,7 +10,7 @@ use std::path::Path;
#[derive(Clone, Debug, Serialize)] #[derive(Clone, Debug, Serialize)]
pub struct AgentAssignment { pub struct AgentAssignment {
pub agent_name: crate::config::AgentName, pub agent_name: crate::config::AgentName,
pub model: Option<String>, pub model: Option<crate::agents::AgentModel>,
pub status: crate::agents::AgentStatus, pub status: crate::agents::AgentStatus,
} }
+2 -2
View File
@@ -68,7 +68,7 @@ pub struct AgentConfigEntry {
pub name: String, pub name: String,
pub role: String, pub role: String,
pub stage: Option<String>, pub stage: Option<String>,
pub model: Option<String>, pub model: Option<crate::agents::AgentModel>,
pub allowed_tools: Option<Vec<String>>, pub allowed_tools: Option<Vec<String>>,
pub max_turns: Option<u32>, pub max_turns: Option<u32>,
pub max_budget_usd: Option<f64>, pub max_budget_usd: Option<f64>,
@@ -283,7 +283,7 @@ max_budget_usd = 5.0
let entries = get_agent_config(tmp.path()).unwrap(); let entries = get_agent_config(tmp.path()).unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert_eq!(entries[0].name, "coder-1"); assert_eq!(entries[0].name, "coder-1");
assert_eq!(entries[0].model, Some("sonnet".to_string())); assert_eq!(entries[0].model, Some(crate::agents::AgentModel::Sonnet));
assert_eq!(entries[0].max_turns, Some(30)); assert_eq!(entries[0].max_turns, Some(30));
} }
+6 -3
View File
@@ -9,7 +9,7 @@ use std::collections::HashMap;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct AgentTokenCost { pub struct AgentTokenCost {
pub agent_name: String, pub agent_name: String,
pub model: Option<String>, pub model: Option<crate::agents::AgentModel>,
pub input_tokens: u64, pub input_tokens: u64,
pub output_tokens: u64, pub output_tokens: u64,
pub cache_creation_input_tokens: u64, pub cache_creation_input_tokens: u64,
@@ -153,8 +153,11 @@ mod tests {
#[test] #[test]
fn aggregate_preserves_model_from_first_record() { fn aggregate_preserves_model_from_first_record() {
let mut r = make_record("42_story_foo", "coder-1", 1.0); let mut r = make_record("42_story_foo", "coder-1", 1.0);
r.model = Some("claude-sonnet".to_string()); r.model = Some(crate::agents::AgentModel::Sonnet);
let summary = aggregate_for_story(&[r], "42_story_foo"); let summary = aggregate_for_story(&[r], "42_story_foo");
assert_eq!(summary.agents[0].model, Some("claude-sonnet".to_string())); assert_eq!(
summary.agents[0].model,
Some(crate::agents::AgentModel::Sonnet)
);
} }
} }
+9 -8
View File
@@ -5,6 +5,7 @@
//! [`merge_settings_into_toml`] (the pure TOML key-updating logic used by the //! [`merge_settings_into_toml`] (the pure TOML key-updating logic used by the
//! write path in `mod.rs` + `io.rs`). //! write path in `mod.rs` + `io.rs`).
use crate::agents::AgentModel;
use crate::config::ProjectConfig; use crate::config::ProjectConfig;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -17,8 +18,8 @@ use serde::{Deserialize, Serialize};
pub struct ProjectSettings { pub struct ProjectSettings {
/// Project-wide default QA mode: "server", "agent", or "human". Default: "server". /// Project-wide default QA mode: "server", "agent", or "human". Default: "server".
pub default_qa: String, pub default_qa: String,
/// Default model for coder-stage agents (e.g. "sonnet"). /// Default model for coder-stage agents (e.g. `Sonnet`).
pub default_coder_model: Option<String>, pub default_coder_model: Option<AgentModel>,
/// Maximum number of concurrent coder-stage agents. /// Maximum number of concurrent coder-stage agents.
pub max_coders: Option<u32>, pub max_coders: Option<u32>,
/// Maximum retries per story per pipeline stage before marking as blocked. Default: 2. /// Maximum retries per story per pipeline stage before marking as blocked. Default: 2.
@@ -89,7 +90,7 @@ pub fn merge_settings_into_toml(
Some(v) => { Some(v) => {
table.insert( table.insert(
"default_coder_model".to_string(), "default_coder_model".to_string(),
toml::Value::String(v.clone()), toml::Value::String(v.as_str().to_string()),
); );
} }
None => { None => {
@@ -180,7 +181,7 @@ mod tests {
fn settings_from_config_copies_all_scalar_fields() { fn settings_from_config_copies_all_scalar_fields() {
let cfg = ProjectConfig { let cfg = ProjectConfig {
default_qa: "human".to_string(), default_qa: "human".to_string(),
default_coder_model: Some("opus".to_string()), default_coder_model: Some(AgentModel::Opus),
max_coders: Some(4), max_coders: Some(4),
max_retries: 5, max_retries: 5,
base_branch: Some("main".to_string()), base_branch: Some("main".to_string()),
@@ -196,7 +197,7 @@ mod tests {
let s = settings_from_config(&cfg); let s = settings_from_config(&cfg);
assert_eq!(s.default_qa, "human"); assert_eq!(s.default_qa, "human");
assert_eq!(s.default_coder_model, Some("opus".to_string())); assert_eq!(s.default_coder_model, Some(AgentModel::Opus));
assert_eq!(s.max_coders, Some(4)); assert_eq!(s.max_coders, Some(4));
assert_eq!(s.max_retries, 5); assert_eq!(s.max_retries, 5);
assert_eq!(s.base_branch, Some("main".to_string())); assert_eq!(s.base_branch, Some("main".to_string()));
@@ -247,7 +248,7 @@ mod tests {
let mut val = empty_toml(); let mut val = empty_toml();
let s = ProjectSettings { let s = ProjectSettings {
default_qa: "server".to_string(), default_qa: "server".to_string(),
default_coder_model: Some("sonnet".to_string()), default_coder_model: Some(AgentModel::Sonnet),
max_coders: Some(2), max_coders: Some(2),
max_retries: 2, max_retries: 2,
base_branch: Some("main".to_string()), base_branch: Some("main".to_string()),
@@ -273,7 +274,7 @@ mod tests {
// First set them // First set them
let s_with = ProjectSettings { let s_with = ProjectSettings {
default_qa: "server".to_string(), default_qa: "server".to_string(),
default_coder_model: Some("sonnet".to_string()), default_coder_model: Some(AgentModel::Sonnet),
max_coders: Some(3), max_coders: Some(3),
max_retries: 2, max_retries: 2,
base_branch: Some("master".to_string()), base_branch: Some("master".to_string()),
@@ -372,7 +373,7 @@ path = "."
let mut val = empty_toml(); let mut val = empty_toml();
let s = ProjectSettings { let s = ProjectSettings {
default_qa: "human".to_string(), default_qa: "human".to_string(),
default_coder_model: Some("opus".to_string()), default_coder_model: Some(AgentModel::Opus),
max_coders: Some(2), max_coders: Some(2),
max_retries: 4, max_retries: 4,
base_branch: Some("develop".to_string()), base_branch: Some("develop".to_string()),
+1 -1
View File
@@ -117,7 +117,7 @@ mod tests {
fn valid_settings_with_all_optional_fields_set() { fn valid_settings_with_all_optional_fields_set() {
let s = ProjectSettings { let s = ProjectSettings {
default_qa: "agent".to_string(), default_qa: "agent".to_string(),
default_coder_model: Some("opus".to_string()), default_coder_model: Some(crate::agents::AgentModel::Opus),
max_coders: Some(4), max_coders: Some(4),
max_retries: 5, max_retries: 5,
base_branch: Some("main".to_string()), base_branch: Some("main".to_string()),
+2 -2
View File
@@ -296,7 +296,7 @@ mod tests {
merge_failure: None, merge_failure: None,
agent: Some(crate::http::workflow::pipeline::AgentAssignment { agent: Some(crate::http::workflow::pipeline::AgentAssignment {
agent_name: crate::config::AgentName::Coder1, agent_name: crate::config::AgentName::Coder1,
model: Some("claude-3-5-sonnet".to_string()), model: Some(crate::agents::AgentModel::Sonnet),
status: crate::agents::AgentStatus::Running, status: crate::agents::AgentStatus::Running,
}), }),
review_hold: None, review_hold: None,
@@ -316,7 +316,7 @@ mod tests {
let resp: WsResponse = state.into(); let resp: WsResponse = state.into();
let json = serde_json::to_value(&resp).unwrap(); let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["current"][0]["agent"]["agent_name"], "coder-1"); assert_eq!(json["current"][0]["agent"]["agent_name"], "coder-1");
assert_eq!(json["current"][0]["agent"]["model"], "claude-3-5-sonnet"); assert_eq!(json["current"][0]["agent"]["model"], "sonnet");
assert_eq!(json["current"][0]["agent"]["status"], "running"); assert_eq!(json["current"][0]["agent"]["status"], "running");
} }