//! OAuth endpoints — Anthropic OAuth callback and token exchange flow. use crate::llm::oauth; use crate::slog; use poem::handler; use poem::http::StatusCode; use poem::web::{Data, Query, Redirect}; use serde::Deserialize; use std::collections::HashMap; use std::sync::{Arc, Mutex}; /// Anthropic OAuth configuration. const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"; /// Claude.ai authorize URL (for Max/Pro subscriptions). const AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize"; const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token"; const SCOPES: &str = "user:inference user:profile user:mcp_servers user:sessions:claude_code user:file_upload"; /// In-memory store for pending PKCE flows, keyed by state parameter. #[derive(Clone)] pub struct OAuthState { /// Maps state → (code_verifier, redirect_uri) pending: Arc>>, /// The port the server is listening on (for building redirect_uri). port: u16, } struct PendingFlow { code_verifier: String, redirect_uri: String, } impl OAuthState { pub fn new(port: u16) -> Self { Self { pending: Arc::new(Mutex::new(HashMap::new())), port, } } fn callback_url(&self) -> String { format!("http://localhost:{}/callback", self.port) } } /// Generate a random alphanumeric string of the given length. fn random_string(len: usize) -> String { use std::collections::hash_map::RandomState; use std::hash::{BuildHasher, Hasher}; let mut s = String::with_capacity(len); let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; for _ in 0..len { let hasher = RandomState::new().build_hasher(); let idx = hasher.finish() as usize % chars.len(); s.push(chars[idx] as char); } s } /// Compute the S256 PKCE code challenge from a code verifier. fn compute_code_challenge(verifier: &str) -> String { use sha2::{Digest, Sha256}; let hash = Sha256::digest(verifier.as_bytes()); base64url_encode(&hash) } /// Base64url-encode without padding (RFC 7636). fn base64url_encode(data: &[u8]) -> String { // Standard base64 then convert to base64url const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; let mut result = String::new(); let mut i = 0; while i < data.len() { let b0 = data[i] as u32; let b1 = if i + 1 < data.len() { data[i + 1] as u32 } else { 0 }; let b2 = if i + 2 < data.len() { data[i + 2] as u32 } else { 0 }; let triple = (b0 << 16) | (b1 << 8) | b2; result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char); if i + 1 < data.len() { result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char); } if i + 2 < data.len() { result.push(CHARS[(triple & 0x3F) as usize] as char); } i += 3; } // Convert to base64url: replace + with -, / with _ result.replace('+', "-").replace('/', "_") } /// `GET /oauth/authorize` — Initiates the OAuth flow. /// /// Generates PKCE parameters, stores them, and redirects the browser to /// Anthropic's authorization page. #[handler] pub async fn oauth_authorize(state: Data<&Arc>) -> Redirect { let code_verifier = random_string(128); let code_challenge = compute_code_challenge(&code_verifier); let csrf_state = random_string(32); let redirect_uri = state.callback_url(); slog!("[oauth] Starting OAuth flow, state={}", csrf_state); // Store the pending flow state.pending.lock().unwrap().insert( csrf_state.clone(), PendingFlow { code_verifier, redirect_uri: redirect_uri.clone(), }, ); let authorize_url = format!( "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}", AUTHORIZE_URL, CLIENT_ID, percent_encode(&redirect_uri), percent_encode(SCOPES), percent_encode(&code_challenge), percent_encode(&csrf_state), ); Redirect::temporary(authorize_url) } #[derive(Deserialize)] pub struct CallbackParams { code: Option, state: Option, error: Option, error_description: Option, } /// Response from the Anthropic OAuth token endpoint. #[derive(Deserialize)] struct TokenResponse { access_token: String, refresh_token: Option, expires_in: u64, #[allow(dead_code)] token_type: Option, #[allow(dead_code)] scope: Option, } /// `GET /oauth/callback` — Handles the OAuth redirect from Anthropic. /// /// Exchanges the authorization code for tokens and writes them to /// `~/.claude/.credentials.json`. #[handler] pub async fn oauth_callback( state: Data<&Arc>, Query(params): Query, ) -> poem::Response { // Handle errors from Anthropic if let Some(err) = ¶ms.error { let desc = params .error_description .as_deref() .unwrap_or("Unknown error"); slog!("[oauth] Authorization denied: {} - {}", err, desc); return html_response( StatusCode::BAD_REQUEST, "Authentication Failed", &format!("Anthropic denied the request: {desc}"), ); } let code = match ¶ms.code { Some(c) => c, None => { return html_response( StatusCode::BAD_REQUEST, "Missing Code", "No authorization code received from Anthropic.", ); } }; let csrf_state = match ¶ms.state { Some(s) => s, None => { return html_response( StatusCode::BAD_REQUEST, "Missing State", "No state parameter received. Possible CSRF attack.", ); } }; // Look up and remove the pending flow let pending = state.pending.lock().unwrap().remove(csrf_state); let flow = match pending { Some(f) => f, None => { slog!("[oauth] Unknown state parameter: {}", csrf_state); return html_response( StatusCode::BAD_REQUEST, "Invalid State", "Unknown or expired state parameter. Please try logging in again.", ); } }; slog!("[oauth] Received callback, exchanging code for tokens"); // Exchange the authorization code for tokens let client = reqwest::Client::new(); let resp = client .post(TOKEN_ENDPOINT) .header("Content-Type", "application/json") .json(&serde_json::json!({ "grant_type": "authorization_code", "code": code, "client_id": CLIENT_ID, "redirect_uri": &flow.redirect_uri, "code_verifier": &flow.code_verifier, "state": csrf_state, })) .send() .await; let resp = match resp { Ok(r) => r, Err(e) => { slog!("[oauth] Token exchange request failed: {}", e); return html_response( StatusCode::INTERNAL_SERVER_ERROR, "Token Exchange Failed", &format!("Failed to contact Anthropic: {e}"), ); } }; let status = resp.status(); let body = resp.text().await.unwrap_or_default(); slog!( "[oauth] Token exchange response (HTTP {}): {}", status, body ); if !status.is_success() { return html_response( StatusCode::INTERNAL_SERVER_ERROR, "Token Exchange Failed", &format!("Anthropic returned HTTP {status}. Please try again."), ); } let token_resp: TokenResponse = match serde_json::from_str(&body) { Ok(t) => t, Err(e) => { slog!("[oauth] Failed to parse token response: {}", e); return html_response( StatusCode::INTERNAL_SERVER_ERROR, "Token Parse Failed", "Received an unexpected response from Anthropic.", ); } }; let now_ms = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_millis() as u64) .unwrap_or(0); let creds = oauth::CredentialsFile { claude_ai_oauth: oauth::OAuthCredentials { access_token: token_resp.access_token, refresh_token: token_resp.refresh_token.unwrap_or_default(), expires_at: now_ms + (token_resp.expires_in * 1000), scopes: SCOPES.split(' ').map(|s| s.to_string()).collect(), subscription_type: None, rate_limit_tier: None, }, }; if let Err(e) = oauth::write_credentials(&creds) { slog!("[oauth] Failed to write credentials: {}", e); return html_response( StatusCode::INTERNAL_SERVER_ERROR, "Credential Write Failed", &format!("Tokens received but failed to save: {e}"), ); } slog!("[oauth] Successfully authenticated and saved credentials"); html_response( StatusCode::OK, "Authenticated!", "Claude OAuth login successful. You can close this tab and return to Huskies.", ) } /// Check whether valid (non-expired) OAuth credentials exist. #[handler] pub async fn oauth_status() -> poem::Response { match oauth::read_credentials() { Ok(creds) => { let now_ms = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_millis() as u64) .unwrap_or(0); let expired = now_ms > creds.claude_ai_oauth.expires_at; let body = serde_json::json!({ "authenticated": true, "expired": expired, "expires_at": creds.claude_ai_oauth.expires_at, "has_refresh_token": !creds.claude_ai_oauth.refresh_token.is_empty(), }); poem::Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/json") .body(body.to_string()) } Err(_) => { let body = serde_json::json!({ "authenticated": false, "expired": false, "expires_at": 0, "has_refresh_token": false, }); poem::Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/json") .body(body.to_string()) } } } /// Percent-encode a string for use in URL query parameters. fn percent_encode(input: &str) -> String { let mut encoded = String::with_capacity(input.len() * 3); for byte in input.bytes() { match byte { b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { encoded.push(byte as char); } _ => { encoded.push_str(&format!("%{byte:02X}")); } } } encoded } fn html_response(status: StatusCode, title: &str, message: &str) -> poem::Response { let html = format!( r#" {title}

{title}

{message}

"# ); poem::Response::builder() .status(status) .header("Content-Type", "text/html; charset=utf-8") .body(html) } #[cfg(test)] mod tests { use super::*; #[test] fn base64url_encode_basic() { // Test vector: "Hello" → base64 "SGVsbG8=" → base64url "SGVsbG8" let encoded = base64url_encode(b"Hello"); assert_eq!(encoded, "SGVsbG8"); } #[test] fn base64url_encode_no_padding() { // Ensure no '=' padding characters let encoded = base64url_encode(b"a"); assert!(!encoded.contains('=')); } #[test] fn base64url_encode_no_plus_or_slash() { // Encode bytes that would produce + and / in standard base64 let data: Vec = (0..=255).collect(); let encoded = base64url_encode(&data); assert!(!encoded.contains('+')); assert!(!encoded.contains('/')); } #[test] fn compute_code_challenge_returns_nonempty() { let challenge = compute_code_challenge("test_verifier_string"); assert!(!challenge.is_empty()); } #[test] fn compute_code_challenge_is_deterministic() { let a = compute_code_challenge("same_input"); let b = compute_code_challenge("same_input"); assert_eq!(a, b); } #[test] fn random_string_length() { let s = random_string(64); assert_eq!(s.len(), 64); } #[test] fn random_string_is_alphanumeric() { let s = random_string(100); assert!(s.chars().all(|c| c.is_ascii_alphanumeric())); } #[test] fn oauth_state_callback_url() { let state = OAuthState::new(3001); assert_eq!(state.callback_url(), "http://localhost:3001/callback"); } #[test] fn oauth_state_callback_url_uses_given_port() { // Ensure OAuthState::new uses the port passed to it, not a hardcoded value. let state = OAuthState::new(9876); assert_eq!(state.callback_url(), "http://localhost:9876/callback"); } #[tokio::test] async fn html_response_contains_title_and_message() { let resp = html_response(StatusCode::OK, "Test Title", "Test message"); let body = resp.into_body().into_string().await.unwrap(); assert!(body.contains("Test Title")); assert!(body.contains("Test message")); } }