Files
storkit/server/src/http/model.rs

130 lines
3.9 KiB
Rust
Raw Normal View History

2026-02-16 16:24:21 +00:00
use crate::http::context::{AppContext, OpenApiResult, bad_request};
use crate::io::fs;
use crate::llm::chat;
2026-02-16 16:50:50 +00:00
use poem_openapi::{Object, OpenApi, Tags, param::Query, payload::Json};
2026-02-16 16:35:25 +00:00
use serde::Deserialize;
use std::sync::Arc;
2026-02-16 16:24:21 +00:00
2026-02-16 16:50:50 +00:00
// OpenAPI tag set for this module's endpoints. Plain `//` comments are used on
// purpose: `///` doc comments on a `#[derive(Tags)]` enum would change the
// generated tag description in the OpenAPI document.
#[derive(Tags)]
enum ModelTags {
    // Referenced by `#[OpenApi(tag = "ModelTags::Model")]` on the `ModelApi` impl,
    // so every handler in that impl is grouped under this tag.
    Model,
}
2026-02-16 16:35:25 +00:00
/// Request body for `POST /model`: the model name to store as the selected model preference.
#[derive(Deserialize, Object)]
struct ModelPayload {
    /// Model identifier to persist (for example `"claude-3-sonnet"`).
    model: String,
}
2026-02-16 16:35:25 +00:00
pub struct ModelApi {
pub ctx: Arc<AppContext>,
2026-02-16 16:24:21 +00:00
}
2026-02-16 16:50:50 +00:00
#[OpenApi(tag = "ModelTags::Model")]
2026-02-16 16:35:25 +00:00
impl ModelApi {
2026-02-16 16:50:50 +00:00
/// Get the currently selected model preference, if any.
2026-02-16 16:35:25 +00:00
#[oai(path = "/model", method = "get")]
async fn get_model_preference(&self) -> OpenApiResult<Json<Option<String>>> {
let result = fs::get_model_preference(self.ctx.store.as_ref()).map_err(bad_request)?;
Ok(Json(result))
}
2026-02-16 16:50:50 +00:00
/// Persist the selected model preference.
2026-02-16 16:35:25 +00:00
#[oai(path = "/model", method = "post")]
async fn set_model_preference(&self, payload: Json<ModelPayload>) -> OpenApiResult<Json<bool>> {
fs::set_model_preference(payload.0.model, self.ctx.store.as_ref()).map_err(bad_request)?;
Ok(Json(true))
}
2026-02-16 16:50:50 +00:00
/// Fetch available model names from an Ollama server.
/// Optionally override the base URL via query string.
/// Returns an empty list when Ollama is unreachable so the UI stays functional.
2026-02-16 16:35:25 +00:00
#[oai(path = "/ollama/models", method = "get")]
async fn get_ollama_models(
&self,
base_url: Query<Option<String>>,
) -> OpenApiResult<Json<Vec<String>>> {
let models = chat::get_ollama_models(base_url.0)
.await
.unwrap_or_default();
2026-02-16 16:35:25 +00:00
Ok(Json(models))
}
2026-02-16 16:24:21 +00:00
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::context::AppContext;
    use tempfile::TempDir;

    /// Builds a `ModelApi` whose test context is rooted in `dir`.
    /// The caller must keep `dir` alive for as long as the API is used.
    fn make_api(dir: &TempDir) -> ModelApi {
        let root = dir.path().to_path_buf();
        ModelApi {
            ctx: Arc::new(AppContext::new_test(root)),
        }
    }

    /// Wraps a model name in the JSON body accepted by `set_model_preference`.
    fn payload_for(model: &str) -> Json<ModelPayload> {
        Json(ModelPayload {
            model: model.to_owned(),
        })
    }

    #[tokio::test]
    async fn get_model_preference_returns_none_when_unset() {
        let tmp = TempDir::new().unwrap();
        let api = make_api(&tmp);
        let Json(stored) = api.get_model_preference().await.unwrap();
        assert_eq!(stored, None);
    }

    #[tokio::test]
    async fn set_model_preference_returns_true() {
        let tmp = TempDir::new().unwrap();
        let api = make_api(&tmp);
        let Json(saved) = api
            .set_model_preference(payload_for("claude-3-sonnet"))
            .await
            .unwrap();
        assert!(saved);
    }

    #[tokio::test]
    async fn get_model_preference_returns_value_after_set() {
        let tmp = TempDir::new().unwrap();
        let api = make_api(&tmp);
        api.set_model_preference(payload_for("claude-3-sonnet"))
            .await
            .unwrap();
        let Json(stored) = api.get_model_preference().await.unwrap();
        assert_eq!(stored.as_deref(), Some("claude-3-sonnet"));
    }

    #[tokio::test]
    async fn set_model_preference_overwrites_previous_value() {
        let tmp = TempDir::new().unwrap();
        let api = make_api(&tmp);
        // Write two values in sequence; only the last one should survive.
        for model in ["model-a", "model-b"] {
            api.set_model_preference(payload_for(model)).await.unwrap();
        }
        let Json(stored) = api.get_model_preference().await.unwrap();
        assert_eq!(stored.as_deref(), Some("model-b"));
    }

    #[tokio::test]
    async fn get_ollama_models_returns_empty_list_for_unreachable_url() {
        let tmp = TempDir::new().unwrap();
        let api = make_api(&tmp);
        // Nothing is expected to listen on local port 1, so the connection should be refused quickly.
        let unreachable = Query(Some(String::from("http://127.0.0.1:1")));
        let outcome = api.get_ollama_models(unreachable).await;
        let Json(models) = outcome.expect("unreachable Ollama must not produce an error response");
        assert!(models.is_empty());
    }
}