huskies: merge 1015
This commit is contained in:
@@ -16,6 +16,10 @@ pub mod session_store;
|
||||
/// Token-usage tracking and budget estimation.
|
||||
pub mod token_usage;
|
||||
|
||||
/// Typed agent model enum (Sonnet/Opus/Haiku).
|
||||
pub mod model;
|
||||
pub use model::AgentModel;
|
||||
|
||||
use crate::config::AgentConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -205,13 +209,13 @@ impl TokenUsage {
|
||||
/// data in the agent log (since `total_cost_usd` is only available in the
|
||||
/// `result` event at session end). Uses conservative (high) pricing when
|
||||
/// the model is unknown so budget limits are hit sooner rather than later.
|
||||
pub fn estimate_cost_usd(&self, model: Option<&str>) -> f64 {
|
||||
pub fn estimate_cost_usd(&self, model: Option<&AgentModel>) -> f64 {
|
||||
// Per-million-token pricing (input, output, cache_read, cache_create).
|
||||
let (inp, out, cr, cc) = match model {
|
||||
Some(m) if m.contains("haiku") => (0.80, 4.0, 0.08, 1.00),
|
||||
Some(m) if m.contains("sonnet") => (3.0, 15.0, 0.30, 3.75),
|
||||
// Opus or unknown → most expensive = conservative.
|
||||
_ => (15.0, 75.0, 1.50, 18.75),
|
||||
Some(AgentModel::Haiku) => (0.80, 4.0, 0.08, 1.00),
|
||||
Some(AgentModel::Sonnet) => (3.0, 15.0, 0.30, 3.75),
|
||||
// Opus, Other, or unknown → most expensive = conservative.
|
||||
Some(AgentModel::Opus) | Some(AgentModel::Other(_)) | None => (15.0, 75.0, 1.50, 18.75),
|
||||
};
|
||||
(self.input_tokens as f64 * inp
|
||||
+ self.output_tokens as f64 * out
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
//! Typed agent model — replaces raw model strings throughout the agent subsystem.
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
/// Supported agent model families.
///
/// Serialises to the canonical short name ("sonnet", "opus", "haiku") or, for
/// `Other`, the original string verbatim. Deserialises from any string:
/// Claude family names are matched by case-insensitive substring (e.g.
/// "claude-sonnet-4-6" or "Claude-Sonnet-4-6"), everything else becomes
/// `Other(string)` so non-Claude runtimes (Gemini, OpenAI, etc.) survive a
/// config round-trip without error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentModel {
    /// Claude Sonnet family.
    Sonnet,
    /// Claude Opus family.
    Opus,
    /// Claude Haiku family.
    Haiku,
    /// Any model string not recognised as a Claude family name, kept exactly
    /// as it was supplied.
    Other(String),
}

impl AgentModel {
    /// Canonical short name used for serialisation and CLI `--model` flags.
    ///
    /// Returns `"sonnet"`, `"opus"` or `"haiku"` for the Claude families and
    /// the stored string for `Other`.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Sonnet => "sonnet",
            Self::Opus => "opus",
            Self::Haiku => "haiku",
            Self::Other(s) => s.as_str(),
        }
    }

    /// Parse any model string into a variant — always succeeds.
    ///
    /// Claude family names are matched by substring, ignoring ASCII case, and
    /// are checked in the order haiku → sonnet → opus. Everything else becomes
    /// `Other` carrying the input unchanged.
    pub fn from_api_str(s: &str) -> Self {
        // Fold case only for the comparison: a hand-written "Sonnet" must not
        // fall through to `Other`, but `Other` keeps the caller's spelling.
        let folded = s.to_ascii_lowercase();
        if folded.contains("haiku") {
            Self::Haiku
        } else if folded.contains("sonnet") {
            Self::Sonnet
        } else if folded.contains("opus") {
            Self::Opus
        } else {
            Self::Other(s.to_string())
        }
    }
}

impl fmt::Display for AgentModel {
    /// Writes the same text as [`AgentModel::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
|
||||
|
||||
impl Serialize for AgentModel {
|
||||
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
s.serialize_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for AgentModel {
|
||||
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
let s = String::deserialize(d)?;
|
||||
Ok(Self::from_api_str(&s))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode a JSON document into an `AgentModel`, panicking on failure.
    fn decode(json: &str) -> AgentModel {
        serde_json::from_str(json).unwrap()
    }

    /// Encode an `AgentModel` as a JSON document, panicking on failure.
    fn encode(model: &AgentModel) -> String {
        serde_json::to_string(model).unwrap()
    }

    // Bare family names map straight onto their variants.
    #[test]
    fn short_names_deserialise() {
        let cases = [
            ("\"sonnet\"", AgentModel::Sonnet),
            ("\"opus\"", AgentModel::Opus),
            ("\"haiku\"", AgentModel::Haiku),
        ];
        for (json, expected) in cases {
            assert_eq!(decode(json), expected);
        }
    }

    // Full API identifiers are recognised through their family substring.
    #[test]
    fn long_names_deserialise() {
        let cases = [
            ("\"claude-sonnet-4-6\"", AgentModel::Sonnet),
            ("\"claude-haiku-4-5-20251001\"", AgentModel::Haiku),
            ("\"claude-opus-4-5\"", AgentModel::Opus),
        ];
        for (json, expected) in cases {
            assert_eq!(decode(json), expected);
        }
    }

    // Claude variants always serialise to the canonical short name.
    #[test]
    fn serialises_to_short_name() {
        let cases = [
            (AgentModel::Sonnet, "\"sonnet\""),
            (AgentModel::Opus, "\"opus\""),
            (AgentModel::Haiku, "\"haiku\""),
        ];
        for (model, expected) in cases {
            assert_eq!(encode(&model), expected);
        }
    }

    // Unrecognised strings are preserved in `Other` and round-trip verbatim.
    #[test]
    fn unknown_string_becomes_other() {
        let parsed = decode("\"gemini-2.5-pro\"");
        assert_eq!(parsed, AgentModel::Other("gemini-2.5-pro".to_string()));
        assert_eq!(parsed.as_str(), "gemini-2.5-pro");
        assert_eq!(encode(&parsed), "\"gemini-2.5-pro\"");
    }

    // JSON `null` maps to `None` when the field is optional.
    #[test]
    fn option_none_round_trips() {
        let parsed = serde_json::from_str::<Option<AgentModel>>("null").unwrap();
        assert!(parsed.is_none());
    }
}
|
||||
@@ -120,11 +120,9 @@ pub(in crate::agents::pool) fn find_free_agent_for_stage<'a>(
|
||||
// model matches. This keeps opus agents reserved for explicit requests.
|
||||
if *stage == PipelineStage::Coder
|
||||
&& let Some(ref default_model) = config.default_coder_model
|
||||
&& agent_config.model.as_ref() != Some(default_model)
|
||||
{
|
||||
let agent_model = agent_config.model.as_deref().unwrap_or("");
|
||||
if agent_model != default_model {
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let is_busy = agents.values().any(|a| {
|
||||
a.agent_name == agent_config.name
|
||||
|
||||
@@ -67,7 +67,10 @@ pub(crate) fn compute_budget_from_single_log(path: &Path) -> f64 {
|
||||
&& let Some(message) = data.get("message")
|
||||
&& let Some(usage) = message.get("usage")
|
||||
{
|
||||
let model = message.get("model").and_then(|v| v.as_str());
|
||||
let model = message
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(crate::agents::AgentModel::from_api_str);
|
||||
let token_usage = TokenUsage {
|
||||
input_tokens: usage
|
||||
.get("input_tokens")
|
||||
@@ -87,7 +90,7 @@ pub(crate) fn compute_budget_from_single_log(path: &Path) -> f64 {
|
||||
.unwrap_or(0),
|
||||
total_cost_usd: 0.0,
|
||||
};
|
||||
cost += token_usage.estimate_cost_usd(model);
|
||||
cost += token_usage.estimate_cost_usd(model.as_ref());
|
||||
}
|
||||
}
|
||||
cost
|
||||
|
||||
@@ -364,13 +364,13 @@ impl AgentPool {
|
||||
let effective_session_id = session_id_to_resume.or_else(|| {
|
||||
let model = config
|
||||
.find_agent(&resolved_name)
|
||||
.and_then(|a| a.model.clone())
|
||||
.unwrap_or_default();
|
||||
.and_then(|a| a.model.as_ref().map(|m| m.as_str()))
|
||||
.unwrap_or("");
|
||||
crate::agents::session_store::lookup_session(
|
||||
project_root,
|
||||
story_id,
|
||||
&resolved_name,
|
||||
&model,
|
||||
model,
|
||||
)
|
||||
});
|
||||
|
||||
|
||||
@@ -384,8 +384,7 @@ pub(super) async fn run_agent_spawn(
|
||||
// passed to RuntimeContext for eager session recording (bug 967).
|
||||
let agent_model = config_clone
|
||||
.find_agent(&aname)
|
||||
.and_then(|a| a.model.clone())
|
||||
.unwrap_or_default();
|
||||
.and_then(|a| a.model.clone());
|
||||
|
||||
let run_result = match runtime_name {
|
||||
"claude-code" => {
|
||||
@@ -463,11 +462,15 @@ pub(super) async fn run_agent_spawn(
|
||||
&& let Some(agent) = agents.get(&key_clone)
|
||||
&& let Some(ref pr) = agent.project_root
|
||||
{
|
||||
let model = config_clone
|
||||
let model_for_record = config_clone
|
||||
.find_agent(&aname)
|
||||
.and_then(|a| a.model.clone());
|
||||
let record =
|
||||
crate::agents::token_usage::build_record(&sid, &aname, model, usage.clone());
|
||||
let record = crate::agents::token_usage::build_record(
|
||||
&sid,
|
||||
&aname,
|
||||
model_for_record,
|
||||
usage.clone(),
|
||||
);
|
||||
if let Err(e) = crate::agents::token_usage::append_record(pr, &record) {
|
||||
slog_error!(
|
||||
"[agents] Failed to persist token usage for \
|
||||
@@ -480,13 +483,13 @@ pub(super) async fn run_agent_spawn(
|
||||
if let Some(ref sess_id) = result.session_id {
|
||||
let model = config_clone
|
||||
.find_agent(&aname)
|
||||
.and_then(|a| a.model.clone())
|
||||
.unwrap_or_default();
|
||||
.and_then(|a| a.model.as_ref().map(|m| m.as_str()))
|
||||
.unwrap_or("");
|
||||
crate::agents::session_store::record_session(
|
||||
&project_root_clone,
|
||||
&sid,
|
||||
&aname,
|
||||
&model,
|
||||
model,
|
||||
sess_id,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -42,11 +42,10 @@ impl AgentRuntime for ClaudeCodeRuntime {
|
||||
event_log: Arc<Mutex<Vec<AgentEvent>>>,
|
||||
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
|
||||
) -> Result<RuntimeResult, String> {
|
||||
let eager_record = if ctx.model.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((ctx.project_root.clone(), ctx.model.clone()))
|
||||
};
|
||||
let eager_record = ctx
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| (ctx.project_root.clone(), m.as_str().to_string()));
|
||||
let pty_result = super::super::pty::run_agent_pty_streaming(
|
||||
&ctx.story_id,
|
||||
&ctx.agent_name,
|
||||
|
||||
@@ -118,7 +118,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
|
||||
let instruction = build_system_instruction(&ctx);
|
||||
@@ -139,7 +139,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
|
||||
let instruction = build_system_instruction(&ctx);
|
||||
|
||||
@@ -385,7 +385,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
|
||||
// The model extraction logic is inside start(), but we test the
|
||||
|
||||
@@ -46,9 +46,9 @@ pub struct RuntimeContext {
|
||||
/// Eager recording ensures the session survives a watchdog kill that aborts
|
||||
/// the tokio task before `run_agent_spawn`'s `record_session()` call runs.
|
||||
pub project_root: std::path::PathBuf,
|
||||
/// Agent model name — forms part of the session store key used for eager
|
||||
/// recording (bug 967). An empty string disables eager recording.
|
||||
pub model: String,
|
||||
/// Agent model — forms part of the session store key used for eager
|
||||
/// recording (bug 967). `None` disables eager recording.
|
||||
pub model: Option<crate::agents::AgentModel>,
|
||||
}
|
||||
|
||||
/// Result returned by a runtime after the agent session completes.
|
||||
@@ -134,7 +134,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: "sonnet".to_string(),
|
||||
model: Some(crate::agents::AgentModel::Sonnet),
|
||||
};
|
||||
assert_eq!(ctx.story_id, "42_story_foo");
|
||||
assert_eq!(ctx.agent_name, "coder-1");
|
||||
|
||||
@@ -561,7 +561,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
|
||||
assert_eq!(build_system_text(&ctx), "Custom system prompt");
|
||||
@@ -581,7 +581,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
|
||||
let text = build_system_text(&ctx);
|
||||
@@ -634,7 +634,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
assert!(ctx.command.starts_with("gpt"));
|
||||
}
|
||||
@@ -653,7 +653,7 @@ mod tests {
|
||||
session_id_to_resume: None,
|
||||
fresh_prompt: None,
|
||||
project_root: std::path::PathBuf::from("/tmp/project"),
|
||||
model: String::new(),
|
||||
model: None,
|
||||
};
|
||||
assert!(ctx.command.starts_with("o"));
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::path::Path;
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::TokenUsage;
|
||||
use super::{AgentModel, TokenUsage};
|
||||
|
||||
/// A single token usage record persisted to disk.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
@@ -14,7 +14,7 @@ pub struct TokenUsageRecord {
|
||||
pub agent_name: String,
|
||||
pub timestamp: String,
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
pub model: Option<AgentModel>,
|
||||
pub usage: TokenUsage,
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ pub fn read_all(project_root: &Path) -> Result<Vec<TokenUsageRecord>, String> {
|
||||
pub fn build_record(
|
||||
story_id: &str,
|
||||
agent_name: &str,
|
||||
model: Option<String>,
|
||||
model: Option<AgentModel>,
|
||||
usage: TokenUsage,
|
||||
) -> TokenUsageRecord {
|
||||
TokenUsageRecord {
|
||||
|
||||
Reference in New Issue
Block a user