Files
huskies/server/src/agents/model.rs
T

126 lines
3.9 KiB
Rust
Raw Normal View History

2026-05-13 23:33:30 +00:00
//! Typed agent model — replaces raw model strings throughout the agent subsystem.
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
/// Supported agent model families.
///
/// Serialises to the canonical short name ("sonnet", "opus", "haiku") or, for
/// `Other`, the original string verbatim. Deserialises from any string:
/// Claude family names are matched by substring (e.g. "claude-sonnet-4-6"),
/// everything else becomes `Other(string)` so non-Claude runtimes (Gemini,
/// OpenAI, etc.) survive a config round-trip without error.
///
/// Note that Claude version detail is not preserved: "claude-sonnet-4-6"
/// deserialises to `Sonnet` and serialises back as "sonnet".
///
/// Equality is structural (derived), so a value built directly as
/// `Other("sonnet".into())` is *not* equal to `Sonnet`. Build values from
/// strings with `AgentModel::from_api_str` to get the family variant.
///
/// `Hash` is derived alongside `Eq` so models can key a `HashMap`/`HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AgentModel {
    /// Claude Sonnet family.
    Sonnet,
    /// Claude Opus family.
    Opus,
    /// Claude Haiku family.
    Haiku,
    /// Any model string not recognised as a Claude family name, stored verbatim.
    Other(String),
}
impl AgentModel {
    /// Canonical short name used for serialisation and CLI `--model` flags.
    ///
    /// Claude families map to `"sonnet"`, `"opus"` or `"haiku"`; `Other`
    /// returns the string it holds, unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Sonnet => "sonnet",
            Self::Opus => "opus",
            Self::Haiku => "haiku",
            Self::Other(s) => s.as_str(),
        }
    }

    /// Parse any model string into a variant — always succeeds.
    ///
    /// Claude family names are matched by substring, ignoring ASCII case, so
    /// `"sonnet"`, `"Sonnet"` and `"claude-sonnet-4-6"` all yield `Sonnet`.
    /// Families are tested in the order haiku, sonnet, opus, and the first
    /// match wins. Anything else becomes `Other`, holding the input verbatim
    /// (original casing preserved) so non-Claude model ids round-trip unchanged.
    pub fn from_api_str(s: &str) -> Self {
        // Lower-case a copy for matching only; `Other` keeps the caller's spelling.
        let lowered = s.to_ascii_lowercase();
        if lowered.contains("haiku") {
            Self::Haiku
        } else if lowered.contains("sonnet") {
            Self::Sonnet
        } else if lowered.contains("opus") {
            Self::Opus
        } else {
            Self::Other(s.to_string())
        }
    }
}
impl fmt::Display for AgentModel {
    /// Writes `AgentModel::as_str`. Width, fill, alignment and precision flags
    /// are honoured the same way `str` honours them (e.g. `format!("{:<8}", m)`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) applies the caller's formatting spec.
        f.pad(self.as_str())
    }
}
impl Serialize for AgentModel {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
s.serialize_str(self.as_str())
}
}
impl<'de> Deserialize<'de> for AgentModel {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
Ok(Self::from_api_str(&s))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn short_names_deserialise() {
        let s: AgentModel = serde_json::from_str("\"sonnet\"").unwrap();
        assert_eq!(s, AgentModel::Sonnet);
        let o: AgentModel = serde_json::from_str("\"opus\"").unwrap();
        assert_eq!(o, AgentModel::Opus);
        let h: AgentModel = serde_json::from_str("\"haiku\"").unwrap();
        assert_eq!(h, AgentModel::Haiku);
    }

    #[test]
    fn long_names_deserialise() {
        let s: AgentModel = serde_json::from_str("\"claude-sonnet-4-6\"").unwrap();
        assert_eq!(s, AgentModel::Sonnet);
        let h: AgentModel = serde_json::from_str("\"claude-haiku-4-5-20251001\"").unwrap();
        assert_eq!(h, AgentModel::Haiku);
        let o: AgentModel = serde_json::from_str("\"claude-opus-4-5\"").unwrap();
        assert_eq!(o, AgentModel::Opus);
    }

    #[test]
    fn serialises_to_short_name() {
        assert_eq!(
            serde_json::to_string(&AgentModel::Sonnet).unwrap(),
            "\"sonnet\""
        );
        assert_eq!(
            serde_json::to_string(&AgentModel::Opus).unwrap(),
            "\"opus\""
        );
        assert_eq!(
            serde_json::to_string(&AgentModel::Haiku).unwrap(),
            "\"haiku\""
        );
    }

    /// Claude version detail is intentionally collapsed to the family name.
    #[test]
    fn long_names_serialise_to_short_name() {
        let m: AgentModel = serde_json::from_str("\"claude-opus-4-5\"").unwrap();
        assert_eq!(serde_json::to_string(&m).unwrap(), "\"opus\"");
    }

    #[test]
    fn unknown_string_becomes_other() {
        let r: AgentModel = serde_json::from_str("\"gemini-2.5-pro\"").unwrap();
        assert_eq!(r, AgentModel::Other("gemini-2.5-pro".to_string()));
        assert_eq!(r.as_str(), "gemini-2.5-pro");
        // Round-trips verbatim
        assert_eq!(serde_json::to_string(&r).unwrap(), "\"gemini-2.5-pro\"");
    }

    #[test]
    fn display_matches_as_str() {
        for m in [
            AgentModel::Sonnet,
            AgentModel::Opus,
            AgentModel::Haiku,
            AgentModel::Other("gpt-4o".to_string()),
        ] {
            assert_eq!(m.to_string(), m.as_str());
        }
    }

    #[test]
    fn option_none_round_trips() {
        let v: Option<AgentModel> = serde_json::from_str("null").unwrap();
        assert!(v.is_none());
        assert_eq!(serde_json::to_string(&v).unwrap(), "null");
    }

    #[test]
    fn option_some_round_trips() {
        let v: Option<AgentModel> = serde_json::from_str("\"haiku\"").unwrap();
        assert_eq!(v, Some(AgentModel::Haiku));
        assert_eq!(serde_json::to_string(&v).unwrap(), "\"haiku\"");
    }
}