126 lines
3.9 KiB
Rust
126 lines
3.9 KiB
Rust
//! Typed agent model — replaces raw model strings throughout the agent subsystem.
|
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
use std::fmt;
|
|
|
|
/// Supported agent model families.
///
/// Serialises to the canonical short name (`"sonnet"`, `"opus"`, `"haiku"`) or,
/// for [`AgentModel::Other`], the original string verbatim. Deserialises from
/// any string via [`AgentModel::from_api_str`]:
///
/// - A Claude family is recognised when its name appears as a whole token of
///   the model ID, compared ASCII case-insensitively. Examples:
///   `"claude-sonnet-4-6"`, `"us.anthropic.claude-3-5-sonnet-20241022-v2:0"`.
/// - Everything else becomes `Other(string)`, so non-Claude runtimes (Gemini,
///   OpenAI, etc.) survive a config round-trip without error.
///
/// Note that a full Claude ID is *not* preserved: `"claude-opus-4-5"` reads as
/// `Opus` and is written back as `"opus"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AgentModel {
    /// Claude Sonnet family.
    Sonnet,
    /// Claude Opus family.
    Opus,
    /// Claude Haiku family.
    Haiku,
    /// Any model string not recognised as a Claude family name.
    Other(String),
}

impl AgentModel {
    /// Canonical short name used for serialisation and CLI `--model` flags.
    ///
    /// For [`AgentModel::Other`] this is the original model string, unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Sonnet => "sonnet",
            Self::Opus => "opus",
            Self::Haiku => "haiku",
            Self::Other(s) => s.as_str(),
        }
    }

    /// Parse any model string into a variant — always succeeds.
    ///
    /// The string is split on every non-alphanumeric character (`-`, `.`, `/`,
    /// `@`, `:`, whitespace, …). A Claude family is selected only when one of
    /// the resulting tokens equals `haiku`, `sonnet` or `opus`, ignoring ASCII
    /// case. If several family tokens are present, the first match in the
    /// order haiku → sonnet → opus wins. Anything else, including the empty
    /// string, is kept verbatim as [`AgentModel::Other`].
    ///
    /// Whole-token matching, rather than a plain substring search, matters for
    /// non-Claude names that merely contain a family word. For example,
    /// `"octopus-v2"` must not be misread as Opus and then silently rewritten
    /// to `"opus"` on the next save.
    pub fn from_api_str(s: &str) -> Self {
        // True when `family` appears as a complete token of `s`.
        let has_family = |family: &str| {
            s.split(|c: char| !c.is_ascii_alphanumeric())
                .any(|token| token.eq_ignore_ascii_case(family))
        };

        if has_family("haiku") {
            Self::Haiku
        } else if has_family("sonnet") {
            Self::Sonnet
        } else if has_family("opus") {
            Self::Opus
        } else {
            Self::Other(s.to_string())
        }
    }
}
|
|
|
|
impl fmt::Display for AgentModel {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.write_str(self.as_str())
|
|
}
|
|
}
|
|
|
|
impl Serialize for AgentModel {
|
|
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
|
s.serialize_str(self.as_str())
|
|
}
|
|
}
|
|
|
|
impl<'de> Deserialize<'de> for AgentModel {
|
|
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
|
let s = String::deserialize(d)?;
|
|
Ok(Self::from_api_str(&s))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialise a JSON document into an `AgentModel`, panicking on failure.
    fn parse(json: &str) -> AgentModel {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn short_names_deserialise() {
        let cases = [
            ("\"sonnet\"", AgentModel::Sonnet),
            ("\"opus\"", AgentModel::Opus),
            ("\"haiku\"", AgentModel::Haiku),
        ];
        for (json, expected) in cases {
            assert_eq!(parse(json), expected);
        }
    }

    #[test]
    fn long_names_deserialise() {
        let cases = [
            ("\"claude-sonnet-4-6\"", AgentModel::Sonnet),
            ("\"claude-haiku-4-5-20251001\"", AgentModel::Haiku),
            ("\"claude-opus-4-5\"", AgentModel::Opus),
        ];
        for (json, expected) in cases {
            assert_eq!(parse(json), expected);
        }
    }

    #[test]
    fn serialises_to_short_name() {
        let cases = [
            (AgentModel::Sonnet, "\"sonnet\""),
            (AgentModel::Opus, "\"opus\""),
            (AgentModel::Haiku, "\"haiku\""),
        ];
        for (model, expected_json) in cases {
            assert_eq!(serde_json::to_string(&model).unwrap(), expected_json);
        }
    }

    #[test]
    fn unknown_string_becomes_other() {
        let model = parse("\"gemini-2.5-pro\"");
        assert_eq!(model, AgentModel::Other("gemini-2.5-pro".to_string()));
        assert_eq!(model.as_str(), "gemini-2.5-pro");

        // The unrecognised name must be written back exactly as it was read.
        let written = serde_json::to_string(&model).unwrap();
        assert_eq!(written, "\"gemini-2.5-pro\"");
    }

    #[test]
    fn option_none_round_trips() {
        let absent: Option<AgentModel> = serde_json::from_str("null").unwrap();
        assert!(absent.is_none());
    }
}
|