Files
huskies/server/src/http/oauth.rs
T

247 lines
8.0 KiB
Rust

//! OAuth endpoints — thin HTTP adapters over `service::oauth`.
//!
//! Business logic lives in `service::oauth`. These handlers only:
//! 1. Extract parameters from the HTTP request.
//! 2. Call the service layer.
//! 3. Map service errors to HTTP responses.
use crate::service::oauth as svc;
use crate::slog;
use poem::handler;
use poem::http::StatusCode;
use poem::web::{Data, Query, Redirect};
use serde::Deserialize;
use std::sync::Arc;
// Re-export `OAuthState` from the service layer. The handler signatures below
// name it directly, and the tests in this file also reach it via `use super::*`.
pub(crate) use svc::OAuthState;
// Re-exported for tests only (tests use `use super::*` to call these by name).
#[cfg(test)]
pub(crate) use svc::pkce::{base64url_encode, compute_code_challenge, random_string};
/// `GET /oauth/authorize` — Initiates the OAuth flow.
///
/// Delegates to `svc::initiate_flow`, which generates and stores the PKCE
/// parameters, then sends the browser to the returned Anthropic authorization
/// URL with a temporary redirect.
#[handler]
pub async fn oauth_authorize(state: Data<&Arc<OAuthState>>) -> Redirect {
    // The service returns a pair; this handler only needs the second element,
    // the URL to redirect the browser to.
    Redirect::temporary(svc::initiate_flow(&state).1)
}
/// Query parameters received on the OAuth callback URL.
///
/// Every field is an `Option` because the provider redirects back with one of
/// two shapes: a success (`code` + `state`) or a failure (`error`, optionally
/// with `error_description`). `oauth_callback` inspects which fields are
/// present to decide how to respond.
#[derive(Deserialize)]
pub struct CallbackParams {
    /// Authorization code to be exchanged for tokens via `svc::exchange_code`.
    code: Option<String>,
    /// State value echoed back by the provider; forwarded to
    /// `svc::exchange_code`, which presumably checks it against the stored
    /// flow for CSRF protection — NOTE(review): confirm in `service::oauth`.
    state: Option<String>,
    /// Provider-reported error code (e.g. when the user denies access).
    error: Option<String>,
    /// Free-text explanation accompanying `error`; untrusted input, since it
    /// arrives straight from the query string.
    error_description: Option<String>,
}
/// `GET /oauth/callback` — Handles the OAuth redirect from Anthropic.
///
/// Checks are applied in this order, each short-circuiting with a
/// `400 Bad Request` HTML page:
/// 1. the provider reported an `error` (e.g. the user declined);
/// 2. no `code` was supplied;
/// 3. no `state` was supplied.
///
/// When both `code` and `state` are present they are handed to
/// `svc::exchange_code`, which exchanges the code for tokens and writes them to
/// `~/.claude/.credentials.json`. Service failures are rendered through
/// `map_service_error`.
#[handler]
pub async fn oauth_callback(
    state: Data<&Arc<OAuthState>>,
    Query(params): Query<CallbackParams>,
) -> poem::Response {
    // Provider-side failure: log it and show the provider's description.
    if let Some(err) = params.error.as_ref() {
        let desc = params.error_description.as_deref().unwrap_or("Unknown error");
        slog!("[oauth] Authorization denied: {} - {}", err, desc);
        return html_response(
            StatusCode::BAD_REQUEST,
            "Authentication Failed",
            &format!("Anthropic denied the request: {desc}"),
        );
    }

    let Some(code) = params.code.as_ref() else {
        return html_response(
            StatusCode::BAD_REQUEST,
            "Missing Code",
            "No authorization code received from Anthropic.",
        );
    };

    // Without the echoed state there is nothing to validate the request against.
    let Some(csrf_state) = params.state.as_ref() else {
        return html_response(
            StatusCode::BAD_REQUEST,
            "Missing State",
            "No state parameter received. Possible CSRF attack.",
        );
    };

    let outcome = svc::exchange_code(&state, code, csrf_state).await;
    match outcome {
        Err(e) => map_service_error(e),
        Ok(()) => html_response(
            StatusCode::OK,
            "Authenticated!",
            "Claude OAuth login successful. You can close this tab and return to Huskies.",
        ),
    }
}
/// `GET /oauth/status` — Check whether valid (non-expired) OAuth credentials exist.
///
/// Always answers `200 OK` with a JSON object that mirrors the result of
/// `svc::check_status`: `authenticated`, `expired`, `expires_at` and
/// `has_refresh_token`.
#[handler]
pub async fn oauth_status() -> poem::Response {
    let snapshot = svc::check_status();

    let payload = serde_json::json!({
        "authenticated": snapshot.authenticated,
        "expired": snapshot.expired,
        "expires_at": snapshot.expires_at,
        "has_refresh_token": snapshot.has_refresh_token,
    })
    .to_string();

    poem::Response::builder()
        .header("Content-Type", "application/json")
        .status(StatusCode::OK)
        .body(payload)
}
// ── Private helpers ───────────────────────────────────────────────────────────
/// Map a service-layer `Error` to an HTML HTTP response.
fn map_service_error(e: svc::Error) -> poem::Response {
use svc::Error;
match e {
Error::MissingCode => html_response(
StatusCode::BAD_REQUEST,
"Missing Code",
"No authorization code received.",
),
Error::MissingState => html_response(
StatusCode::BAD_REQUEST,
"Missing State",
"No state parameter received.",
),
Error::InvalidState(msg) => html_response(StatusCode::BAD_REQUEST, "Invalid State", &msg),
Error::AuthorizationDenied(msg) => {
html_response(StatusCode::BAD_REQUEST, "Authentication Failed", &msg)
}
Error::InvalidGrant(msg) => {
html_response(StatusCode::BAD_REQUEST, "Token Exchange Failed", &msg)
}
Error::Network(msg) => html_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Token Exchange Failed",
&msg,
),
Error::TokenExpired(msg) => html_response(StatusCode::UNAUTHORIZED, "Token Expired", &msg),
Error::TokenStorage(msg) => html_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Credential Write Failed",
&msg,
),
Error::Parse(msg) => html_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Token Parse Failed",
&msg,
),
}
}
/// Render a small, self-contained HTML status page as a `poem::Response`.
///
/// `title` is used for both the `<title>` element and the `<h1>` heading, and
/// `message` becomes the paragraph beneath it. The response carries `status`
/// and a `text/html; charset=utf-8` content type.
///
/// Both strings are HTML-escaped before interpolation. Callers pass text this
/// module does not control: `oauth_callback` forwards the `error_description`
/// query parameter (anyone can craft a callback URL), and `map_service_error`
/// forwards service-layer error messages. Inserting them raw would allow
/// reflected XSS on this origin.
fn html_response(status: StatusCode, title: &str, message: &str) -> poem::Response {
    // Shadow the inputs with their escaped forms so the template below can
    // only ever see safe text.
    let title = escape_html(title);
    let message = escape_html(message);
    let html = format!(
        r#"<!DOCTYPE html>
<html>
<head><title>{title}</title>
<style>
body {{ font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; background: #1a1a2e; color: #e0e0e0; }}
.card {{ background: #16213e; padding: 2rem; border-radius: 12px; text-align: center; max-width: 400px; box-shadow: 0 4px 24px rgba(0,0,0,0.3); }}
h1 {{ margin-top: 0; }}
</style>
</head>
<body><div class="card"><h1>{title}</h1><p>{message}</p></div></body>
</html>"#
    );
    poem::Response::builder()
        .status(status)
        .header("Content-Type", "text/html; charset=utf-8")
        .body(html)
}

/// Replace `&`, `<`, `>`, `"` and `'` with HTML character references so `raw`
/// is inert when placed in element content or a quoted attribute value.
/// All other characters are copied through unchanged.
fn escape_html(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;

    // ── PKCE helpers (re-exported from `svc::pkce` for these tests) ─────────

    #[test]
    fn base64url_encode_basic() {
        // Standard base64 of "Hello" is "SGVsbG8="; the URL-safe form drops the '='.
        assert_eq!(base64url_encode(b"Hello"), "SGVsbG8");
    }

    #[test]
    fn base64url_encode_no_padding() {
        // A single input byte would need "==" padding in standard base64.
        let out = base64url_encode(b"a");
        assert!(!out.contains('='));
    }

    #[test]
    fn base64url_encode_no_plus_or_slash() {
        // Every byte value, so the sextets that map to '+' and '/' in the
        // standard alphabet are guaranteed to occur.
        let every_byte: Vec<u8> = (0..=255).collect();
        let out = base64url_encode(&every_byte);
        for forbidden in ['+', '/'] {
            assert!(!out.contains(forbidden));
        }
    }

    #[test]
    fn compute_code_challenge_returns_nonempty() {
        assert!(!compute_code_challenge("test_verifier_string").is_empty());
    }

    #[test]
    fn compute_code_challenge_is_deterministic() {
        // The same verifier must always yield the same challenge.
        assert_eq!(
            compute_code_challenge("same_input"),
            compute_code_challenge("same_input")
        );
    }

    #[test]
    fn random_string_length() {
        assert_eq!(random_string(64).len(), 64);
    }

    #[test]
    fn random_string_is_alphanumeric() {
        let sample = random_string(100);
        assert!(sample.bytes().all(|b| b.is_ascii_alphanumeric()));
    }

    // ── OAuthState ──────────────────────────────────────────────────────────

    #[test]
    fn oauth_state_callback_url() {
        assert_eq!(
            OAuthState::new(3001).callback_url(),
            "http://localhost:3001/callback"
        );
    }

    #[test]
    fn oauth_state_callback_url_uses_given_port() {
        // A second, unrelated port guards against a hardcoded value in `new`.
        let state = OAuthState::new(9876);
        assert_eq!(state.callback_url(), "http://localhost:9876/callback");
    }

    // ── HTML rendering ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn html_response_contains_title_and_message() {
        let body = html_response(StatusCode::OK, "Test Title", "Test message")
            .into_body()
            .into_string()
            .await
            .unwrap();
        for needle in ["Test Title", "Test message"] {
            assert!(body.contains(needle));
        }
    }
}