use crate::http::context::{AppContext, OpenApiResult, bad_request}; use crate::io::fs; use crate::llm::chat; use poem_openapi::{Object, OpenApi, Tags, param::Query, payload::Json}; use serde::Deserialize; use std::sync::Arc; #[derive(Tags)] enum ModelTags { Model, } #[derive(Deserialize, Object)] struct ModelPayload { model: String, } pub struct ModelApi { pub ctx: Arc, } #[OpenApi(tag = "ModelTags::Model")] impl ModelApi { /// Get the currently selected model preference, if any. #[oai(path = "/model", method = "get")] async fn get_model_preference(&self) -> OpenApiResult>> { let result = fs::get_model_preference(self.ctx.store.as_ref()).map_err(bad_request)?; Ok(Json(result)) } /// Persist the selected model preference. #[oai(path = "/model", method = "post")] async fn set_model_preference(&self, payload: Json) -> OpenApiResult> { fs::set_model_preference(payload.0.model, self.ctx.store.as_ref()).map_err(bad_request)?; Ok(Json(true)) } /// Fetch available model names from an Ollama server. /// Optionally override the base URL via query string. /// Returns an empty list when Ollama is unreachable so the UI stays functional. #[oai(path = "/ollama/models", method = "get")] async fn get_ollama_models( &self, base_url: Query>, ) -> OpenApiResult>> { let models = chat::get_ollama_models(base_url.0) .await .unwrap_or_default(); Ok(Json(models)) } } #[cfg(test)] mod tests { use super::*; use crate::http::context::AppContext; use tempfile::TempDir; fn make_api(dir: &TempDir) -> ModelApi { ModelApi { ctx: Arc::new(AppContext::new_test(dir.path().to_path_buf())), } } #[tokio::test] async fn get_model_preference_returns_none_when_unset() { let dir = TempDir::new().unwrap(); let api = make_api(&dir); let result = api.get_model_preference().await.unwrap(); assert!(result.0.is_none()); } #[tokio::test] async fn set_model_preference_returns_true() { let dir = TempDir::new().unwrap(); let api = make_api(&dir); let payload = Json(ModelPayload { model: "claude-3-sonnet".to_string(), }); let result = api.set_model_preference(payload).await.unwrap(); assert!(result.0); } #[tokio::test] async fn get_model_preference_returns_value_after_set() { let dir = TempDir::new().unwrap(); let api = make_api(&dir); let payload = Json(ModelPayload { model: "claude-3-sonnet".to_string(), }); api.set_model_preference(payload).await.unwrap(); let result = api.get_model_preference().await.unwrap(); assert_eq!(result.0, Some("claude-3-sonnet".to_string())); } #[tokio::test] async fn set_model_preference_overwrites_previous_value() { let dir = TempDir::new().unwrap(); let api = make_api(&dir); api.set_model_preference(Json(ModelPayload { model: "model-a".to_string(), })) .await .unwrap(); api.set_model_preference(Json(ModelPayload { model: "model-b".to_string(), })) .await .unwrap(); let result = api.get_model_preference().await.unwrap(); assert_eq!(result.0, Some("model-b".to_string())); } #[tokio::test] async fn get_ollama_models_returns_empty_list_for_unreachable_url() { let dir = TempDir::new().unwrap(); let api = make_api(&dir); // Port 1 is reserved and should immediately refuse the connection. let base_url = Query(Some("http://127.0.0.1:1".to_string())); let result = api.get_ollama_models(base_url).await; assert!(result.is_ok()); assert_eq!(result.unwrap().0, Vec::::new()); } }