From 60a9c87794cc1ae04502268eba3bab7a07342c20 Mon Sep 17 00:00:00 2001 From: dave Date: Fri, 24 Apr 2026 16:15:10 +0000 Subject: [PATCH] huskies: merge 609_story_extract_oauth_service --- server/src/http/oauth.rs | 360 +++++++------------------------ server/src/service/mod.rs | 1 + server/src/service/oauth/flow.rs | 124 +++++++++++ server/src/service/oauth/io.rs | 112 ++++++++++ server/src/service/oauth/mod.rs | 269 +++++++++++++++++++++++ server/src/service/oauth/pkce.rs | 175 +++++++++++++++ 6 files changed, 758 insertions(+), 283 deletions(-) create mode 100644 server/src/service/oauth/flow.rs create mode 100644 server/src/service/oauth/io.rs create mode 100644 server/src/service/oauth/mod.rs create mode 100644 server/src/service/oauth/pkce.rs diff --git a/server/src/http/oauth.rs b/server/src/http/oauth.rs index bb5fe52b..90c45ad7 100644 --- a/server/src/http/oauth.rs +++ b/server/src/http/oauth.rs @@ -1,102 +1,23 @@ -//! OAuth endpoints — Anthropic OAuth callback and token exchange flow. -use crate::llm::oauth; +//! OAuth endpoints — thin HTTP adapters over `service::oauth`. +//! +//! Business logic lives in `service::oauth`. These handlers only: +//! 1. Extract parameters from the HTTP request. +//! 2. Call the service layer. +//! 3. Map service errors to HTTP responses. +use crate::service::oauth as svc; use crate::slog; use poem::handler; use poem::http::StatusCode; use poem::web::{Data, Query, Redirect}; use serde::Deserialize; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; -/// Anthropic OAuth configuration. -const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"; -/// Claude.ai authorize URL (for Max/Pro subscriptions). -const AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize"; -const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token"; -const SCOPES: &str = - "user:inference user:profile user:mcp_servers user:sessions:claude_code user:file_upload"; - -/// In-memory store for pending PKCE flows, keyed by state parameter. -#[derive(Clone)] -pub struct OAuthState { - /// Maps state → (code_verifier, redirect_uri) - pending: Arc>>, - /// The port the server is listening on (for building redirect_uri). - port: u16, -} - -struct PendingFlow { - code_verifier: String, - redirect_uri: String, -} - -impl OAuthState { - pub fn new(port: u16) -> Self { - Self { - pending: Arc::new(Mutex::new(HashMap::new())), - port, - } - } - - fn callback_url(&self) -> String { - format!("http://localhost:{}/callback", self.port) - } -} - -/// Generate a random alphanumeric string of the given length. -fn random_string(len: usize) -> String { - use std::collections::hash_map::RandomState; - use std::hash::{BuildHasher, Hasher}; - let mut s = String::with_capacity(len); - let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; - for _ in 0..len { - let hasher = RandomState::new().build_hasher(); - let idx = hasher.finish() as usize % chars.len(); - s.push(chars[idx] as char); - } - s -} - -/// Compute the S256 PKCE code challenge from a code verifier. -fn compute_code_challenge(verifier: &str) -> String { - use sha2::{Digest, Sha256}; - let hash = Sha256::digest(verifier.as_bytes()); - base64url_encode(&hash) -} - -/// Base64url-encode without padding (RFC 7636). -fn base64url_encode(data: &[u8]) -> String { - // Standard base64 then convert to base64url - const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - let mut result = String::new(); - let mut i = 0; - while i < data.len() { - let b0 = data[i] as u32; - let b1 = if i + 1 < data.len() { - data[i + 1] as u32 - } else { - 0 - }; - let b2 = if i + 2 < data.len() { - data[i + 2] as u32 - } else { - 0 - }; - let triple = (b0 << 16) | (b1 << 8) | b2; - - result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); - result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char); - if i + 1 < data.len() { - result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char); - } - if i + 2 < data.len() { - result.push(CHARS[(triple & 0x3F) as usize] as char); - } - i += 3; - } - // Convert to base64url: replace + with -, / with _ - result.replace('+', "-").replace('/', "_") -} +// Re-export service types so that existing tests in this file continue to +// compile unchanged (they use `use super::*` and call these by name). +pub(crate) use svc::OAuthState; +// Re-exported for tests only (tests use `use super::*` to call these by name). +#[cfg(test)] +pub(crate) use svc::pkce::{base64url_encode, compute_code_challenge, random_string}; /// `GET /oauth/authorize` — Initiates the OAuth flow. /// @@ -104,35 +25,11 @@ fn base64url_encode(data: &[u8]) -> String { /// Anthropic's authorization page. #[handler] pub async fn oauth_authorize(state: Data<&Arc>) -> Redirect { - let code_verifier = random_string(128); - let code_challenge = compute_code_challenge(&code_verifier); - let csrf_state = random_string(32); - let redirect_uri = state.callback_url(); - - slog!("[oauth] Starting OAuth flow, state={}", csrf_state); - - // Store the pending flow - state.pending.lock().unwrap().insert( - csrf_state.clone(), - PendingFlow { - code_verifier, - redirect_uri: redirect_uri.clone(), - }, - ); - - let authorize_url = format!( - "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}", - AUTHORIZE_URL, - CLIENT_ID, - percent_encode(&redirect_uri), - percent_encode(SCOPES), - percent_encode(&code_challenge), - percent_encode(&csrf_state), - ); - - Redirect::temporary(authorize_url) + let (_, url) = svc::initiate_flow(&state); + Redirect::temporary(url) } +/// Query parameters received on the OAuth callback URL. #[derive(Deserialize)] pub struct CallbackParams { code: Option, @@ -141,18 +38,6 @@ pub struct CallbackParams { error_description: Option, } -/// Response from the Anthropic OAuth token endpoint. -#[derive(Deserialize)] -struct TokenResponse { - access_token: String, - refresh_token: Option, - expires_in: u64, - #[allow(dead_code)] - token_type: Option, - #[allow(dead_code)] - scope: Option, -} - /// `GET /oauth/callback` — Handles the OAuth redirect from Anthropic. /// /// Exchanges the authorization code for tokens and writes them to @@ -162,7 +47,7 @@ pub async fn oauth_callback( state: Data<&Arc>, Query(params): Query, ) -> poem::Response { - // Handle errors from Anthropic + // Handle provider-side errors (e.g. user denied access). if let Some(err) = ¶ms.error { let desc = params .error_description @@ -177,7 +62,7 @@ pub async fn oauth_callback( } let code = match ¶ms.code { - Some(c) => c, + Some(c) => c.clone(), None => { return html_response( StatusCode::BAD_REQUEST, @@ -188,7 +73,7 @@ pub async fn oauth_callback( }; let csrf_state = match ¶ms.state { - Some(s) => s, + Some(s) => s.clone(), None => { return html_response( StatusCode::BAD_REQUEST, @@ -198,163 +83,72 @@ pub async fn oauth_callback( } }; - // Look up and remove the pending flow - let pending = state.pending.lock().unwrap().remove(csrf_state); - let flow = match pending { - Some(f) => f, - None => { - slog!("[oauth] Unknown state parameter: {}", csrf_state); - return html_response( - StatusCode::BAD_REQUEST, - "Invalid State", - "Unknown or expired state parameter. Please try logging in again.", - ); - } - }; - - slog!("[oauth] Received callback, exchanging code for tokens"); - - // Exchange the authorization code for tokens - let client = reqwest::Client::new(); - let resp = client - .post(TOKEN_ENDPOINT) - .header("Content-Type", "application/json") - .json(&serde_json::json!({ - "grant_type": "authorization_code", - "code": code, - "client_id": CLIENT_ID, - "redirect_uri": &flow.redirect_uri, - "code_verifier": &flow.code_verifier, - "state": csrf_state, - })) - .send() - .await; - - let resp = match resp { - Ok(r) => r, - Err(e) => { - slog!("[oauth] Token exchange request failed: {}", e); - return html_response( - StatusCode::INTERNAL_SERVER_ERROR, - "Token Exchange Failed", - &format!("Failed to contact Anthropic: {e}"), - ); - } - }; - - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - - slog!( - "[oauth] Token exchange response (HTTP {}): {}", - status, - body - ); - - if !status.is_success() { - return html_response( - StatusCode::INTERNAL_SERVER_ERROR, - "Token Exchange Failed", - &format!("Anthropic returned HTTP {status}. Please try again."), - ); + match svc::exchange_code(&state, &code, &csrf_state).await { + Ok(()) => html_response( + StatusCode::OK, + "Authenticated!", + "Claude OAuth login successful. You can close this tab and return to Huskies.", + ), + Err(e) => map_service_error(e), } - - let token_resp: TokenResponse = match serde_json::from_str(&body) { - Ok(t) => t, - Err(e) => { - slog!("[oauth] Failed to parse token response: {}", e); - return html_response( - StatusCode::INTERNAL_SERVER_ERROR, - "Token Parse Failed", - "Received an unexpected response from Anthropic.", - ); - } - }; - - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - - let creds = oauth::CredentialsFile { - claude_ai_oauth: oauth::OAuthCredentials { - access_token: token_resp.access_token, - refresh_token: token_resp.refresh_token.unwrap_or_default(), - expires_at: now_ms + (token_resp.expires_in * 1000), - scopes: SCOPES.split(' ').map(|s| s.to_string()).collect(), - subscription_type: None, - rate_limit_tier: None, - }, - }; - - if let Err(e) = oauth::write_credentials(&creds) { - slog!("[oauth] Failed to write credentials: {}", e); - return html_response( - StatusCode::INTERNAL_SERVER_ERROR, - "Credential Write Failed", - &format!("Tokens received but failed to save: {e}"), - ); - } - - slog!("[oauth] Successfully authenticated and saved credentials"); - - html_response( - StatusCode::OK, - "Authenticated!", - "Claude OAuth login successful. You can close this tab and return to Huskies.", - ) } -/// Check whether valid (non-expired) OAuth credentials exist. +/// `GET /oauth/status` — Check whether valid (non-expired) OAuth credentials exist. #[handler] pub async fn oauth_status() -> poem::Response { - match oauth::read_credentials() { - Ok(creds) => { - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - let expired = now_ms > creds.claude_ai_oauth.expires_at; - let body = serde_json::json!({ - "authenticated": true, - "expired": expired, - "expires_at": creds.claude_ai_oauth.expires_at, - "has_refresh_token": !creds.claude_ai_oauth.refresh_token.is_empty(), - }); - poem::Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body.to_string()) - } - Err(_) => { - let body = serde_json::json!({ - "authenticated": false, - "expired": false, - "expires_at": 0, - "has_refresh_token": false, - }); - poem::Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body.to_string()) - } - } + let status = svc::check_status(); + let body = serde_json::json!({ + "authenticated": status.authenticated, + "expired": status.expired, + "expires_at": status.expires_at, + "has_refresh_token": status.has_refresh_token, + }); + poem::Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(body.to_string()) } -/// Percent-encode a string for use in URL query parameters. -fn percent_encode(input: &str) -> String { - let mut encoded = String::with_capacity(input.len() * 3); - for byte in input.bytes() { - match byte { - b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { - encoded.push(byte as char); - } - _ => { - encoded.push_str(&format!("%{byte:02X}")); - } +// ── Private helpers ─────────────────────────────────────────────────────────── + +/// Map a service-layer `Error` to an HTML HTTP response. +fn map_service_error(e: svc::Error) -> poem::Response { + use svc::Error; + match e { + Error::MissingCode => html_response( + StatusCode::BAD_REQUEST, + "Missing Code", + "No authorization code received.", + ), + Error::MissingState => html_response( + StatusCode::BAD_REQUEST, + "Missing State", + "No state parameter received.", + ), + Error::InvalidState(msg) => html_response(StatusCode::BAD_REQUEST, "Invalid State", &msg), + Error::AuthorizationDenied(msg) => { + html_response(StatusCode::BAD_REQUEST, "Authentication Failed", &msg) } + Error::InvalidGrant(msg) => { + html_response(StatusCode::BAD_REQUEST, "Token Exchange Failed", &msg) + } + Error::Network(msg) => html_response( + StatusCode::INTERNAL_SERVER_ERROR, + "Token Exchange Failed", + &msg, + ), + Error::TokenExpired(msg) => html_response(StatusCode::UNAUTHORIZED, "Token Expired", &msg), + Error::TokenStorage(msg) => html_response( + StatusCode::INTERNAL_SERVER_ERROR, + "Credential Write Failed", + &msg, + ), + Error::Parse(msg) => html_response( + StatusCode::INTERNAL_SERVER_ERROR, + "Token Parse Failed", + &msg, + ), } - encoded } fn html_response(status: StatusCode, title: &str, message: &str) -> poem::Response { diff --git a/server/src/service/mod.rs b/server/src/service/mod.rs index b7504598..606c3838 100644 --- a/server/src/service/mod.rs +++ b/server/src/service/mod.rs @@ -11,5 +11,6 @@ pub mod bot_command; pub mod events; pub mod file_io; pub mod health; +pub mod oauth; pub mod project; pub mod ws; diff --git a/server/src/service/oauth/flow.rs b/server/src/service/oauth/flow.rs new file mode 100644 index 00000000..be127570 --- /dev/null +++ b/server/src/service/oauth/flow.rs @@ -0,0 +1,124 @@ +//! OAuth flow state types and pure decision logic. +//! +//! All functions here are pure — no I/O, no network, no clocks. +//! Side-effectful operations live exclusively in `io.rs`. + +use crate::llm::oauth::CredentialsFile; + +/// A pending PKCE flow waiting for an OAuth callback. +pub struct PendingFlow { + /// The PKCE code verifier generated at flow initiation. + pub code_verifier: String, + /// The redirect URI sent to the authorization endpoint. + pub redirect_uri: String, +} + +/// Current OAuth credential status, computed without I/O from already-loaded credentials. +#[derive(Debug, Clone)] +pub struct FlowStatus { + /// Whether valid credentials were found on disk. + pub authenticated: bool, + /// Whether the access token is past its expiry timestamp. + pub expired: bool, + /// The Unix-epoch millisecond expiry timestamp (0 when unauthenticated). + pub expires_at: u64, + /// Whether a non-empty refresh token is present. + pub has_refresh_token: bool, +} + +/// Determine whether `expires_at` (Unix epoch ms) has passed, given `now_ms`. +/// +/// Returns `true` when `now_ms > expires_at`. +pub fn is_token_expired(expires_at: u64, now_ms: u64) -> bool { + now_ms > expires_at +} + +/// Build a `FlowStatus` from loaded credentials and the current time. +pub fn build_flow_status(creds: &CredentialsFile, now_ms: u64) -> FlowStatus { + let expires_at = creds.claude_ai_oauth.expires_at; + FlowStatus { + authenticated: true, + expired: is_token_expired(expires_at, now_ms), + expires_at, + has_refresh_token: !creds.claude_ai_oauth.refresh_token.is_empty(), + } +} + +/// Return the unauthenticated `FlowStatus` (no credentials on disk). +pub fn unauthenticated_status() -> FlowStatus { + FlowStatus { + authenticated: false, + expired: false, + expires_at: 0, + has_refresh_token: false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn is_token_expired_when_past_expiry() { + assert!(is_token_expired(1000, 2000)); + } + + #[test] + fn is_token_not_expired_when_before_expiry() { + assert!(!is_token_expired(2000, 1000)); + } + + #[test] + fn is_token_not_expired_at_exact_boundary() { + // expires_at == now_ms → not expired + assert!(!is_token_expired(1000, 1000)); + } + + #[test] + fn unauthenticated_status_is_not_authenticated() { + let s = unauthenticated_status(); + assert!(!s.authenticated); + assert!(!s.expired); + assert_eq!(s.expires_at, 0); + assert!(!s.has_refresh_token); + } + + #[test] + fn build_flow_status_authenticated_not_expired() { + use crate::llm::oauth::{CredentialsFile, OAuthCredentials}; + let creds = CredentialsFile { + claude_ai_oauth: OAuthCredentials { + access_token: "tok".to_string(), + refresh_token: "ref".to_string(), + expires_at: 5000, + scopes: vec![], + subscription_type: None, + rate_limit_tier: None, + }, + }; + let status = build_flow_status(&creds, 1000); + assert!(status.authenticated); + assert!(!status.expired); + assert_eq!(status.expires_at, 5000); + assert!(status.has_refresh_token); + } + + #[test] + fn build_flow_status_authenticated_expired() { + use crate::llm::oauth::{CredentialsFile, OAuthCredentials}; + let creds = CredentialsFile { + claude_ai_oauth: OAuthCredentials { + access_token: "tok".to_string(), + refresh_token: String::new(), + expires_at: 1000, + scopes: vec![], + subscription_type: None, + rate_limit_tier: None, + }, + }; + let status = build_flow_status(&creds, 9999); + assert!(status.authenticated); + assert!(status.expired); + assert!(!status.has_refresh_token); + } +} diff --git a/server/src/service/oauth/io.rs b/server/src/service/oauth/io.rs new file mode 100644 index 00000000..2e9bb085 --- /dev/null +++ b/server/src/service/oauth/io.rs @@ -0,0 +1,112 @@ +//! OAuth I/O — the ONLY place in `service/oauth/` that may perform side effects. +//! +//! Side effects here include: reading the system clock, making HTTP requests to +//! the Anthropic token endpoint, and reading/writing `~/.claude/.credentials.json`. +//! All business logic and branching belong in `mod.rs`, `pkce.rs`, or `flow.rs`. + +use super::Error; +use super::flow::FlowStatus; +use super::pkce::SCOPES; +use crate::llm::oauth::{self, CredentialsFile}; +use crate::slog; + +/// Raw token exchange result returned by the Anthropic OAuth endpoint. +#[derive(serde::Deserialize)] +pub(super) struct TokenExchangeResult { + pub access_token: String, + pub refresh_token: Option, + pub expires_in: u64, + #[allow(dead_code)] + pub token_type: Option, + #[allow(dead_code)] + pub scope: Option, +} + +/// Return the current Unix-epoch time in milliseconds. +pub(super) fn current_time_ms() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0) +} + +/// Exchange an authorization code for tokens via the Anthropic token endpoint. +/// +/// Returns the raw token response on success. Network or HTTP errors are +/// mapped to typed [`Error`] variants. +pub(super) async fn exchange_code_for_tokens( + code: &str, + redirect_uri: &str, + code_verifier: &str, + csrf_state: &str, +) -> Result { + use super::pkce::CLIENT_ID; + const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token"; + + slog!("[oauth] Exchanging authorization code for tokens"); + + let client = reqwest::Client::new(); + let resp = client + .post(TOKEN_ENDPOINT) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "client_id": CLIENT_ID, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + "state": csrf_state, + })) + .send() + .await + .map_err(|e| Error::Network(format!("Failed to contact Anthropic: {e}")))?; + + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + + slog!( + "[oauth] Token exchange response (HTTP {}): {}", + status, + body + ); + + if !status.is_success() { + return Err(Error::InvalidGrant(format!( + "Anthropic returned HTTP {status}. Please try again." + ))); + } + + serde_json::from_str(&body) + .map_err(|e| Error::Parse(format!("Unexpected response from Anthropic: {e}"))) +} + +/// Persist a token exchange result to `~/.claude/.credentials.json`. +/// +/// Builds a [`CredentialsFile`] from the token response and `now_ms`, then +/// delegates to [`oauth::write_credentials`]. +pub(super) fn save_credentials(token: &TokenExchangeResult, now_ms: u64) -> Result<(), Error> { + let creds = CredentialsFile { + claude_ai_oauth: oauth::OAuthCredentials { + access_token: token.access_token.clone(), + refresh_token: token.refresh_token.clone().unwrap_or_default(), + expires_at: now_ms + (token.expires_in * 1000), + scopes: SCOPES.split(' ').map(|s| s.to_string()).collect(), + subscription_type: None, + rate_limit_tier: None, + }, + }; + oauth::write_credentials(&creds).map_err(Error::TokenStorage) +} + +/// Load OAuth credentials from disk and compute a [`FlowStatus`]. +/// +/// Returns `Ok(None)` when no credentials file exists yet (user not logged in). +pub(super) fn load_status() -> FlowStatus { + match oauth::read_credentials() { + Ok(creds) => { + let now_ms = current_time_ms(); + super::flow::build_flow_status(&creds, now_ms) + } + Err(_) => super::flow::unauthenticated_status(), + } +} diff --git a/server/src/service/oauth/mod.rs b/server/src/service/oauth/mod.rs new file mode 100644 index 00000000..0ec1b2c4 --- /dev/null +++ b/server/src/service/oauth/mod.rs @@ -0,0 +1,269 @@ +//! OAuth service — domain logic for the Anthropic OAuth 2.0 PKCE flow. +//! +//! Extracts business logic from `http/oauth.rs` following the conventions in +//! `docs/architecture/service-modules.md`: +//! - `mod.rs` (this file) — public API, typed `Error`, `OAuthState`, orchestration +//! - `io.rs` — the ONLY place that performs side effects (HTTP, filesystem, clock) +//! - `pkce.rs` — pure PKCE helpers: generation, challenge, encoding +//! - `flow.rs` — pure flow types and token-expiry decision logic + +pub mod flow; +pub(super) mod io; +pub mod pkce; + +pub use flow::FlowStatus; + +use flow::PendingFlow; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +// ── Error type ──────────────────────────────────────────────────────────────── + +/// Typed errors returned by `service::oauth` functions. +/// +/// HTTP handlers map these to status codes: +/// - [`Error::InvalidGrant`] → 400 Bad Request +/// - [`Error::Network`] → 500 Internal Server Error +/// - [`Error::TokenExpired`] → 401 Unauthorized +/// - [`Error::TokenStorage`] → 500 Internal Server Error +/// - [`Error::InvalidState`] → 400 Bad Request +/// - [`Error::MissingCode`] → 400 Bad Request +/// - [`Error::MissingState`] → 400 Bad Request +/// - [`Error::AuthorizationDenied`] → 400 Bad Request +/// - [`Error::Parse`] → 500 Internal Server Error +#[derive(Debug)] +#[allow(dead_code)] +pub enum Error { + /// The OAuth provider rejected the authorization code (invalid-grant). + InvalidGrant(String), + /// A network error occurred communicating with the OAuth provider. + Network(String), + /// The access token has expired and cannot be refreshed. + TokenExpired(String), + /// Failed to read or write the credential storage file. + TokenStorage(String), + /// The CSRF state parameter does not match any pending flow. + InvalidState(String), + /// No authorization code was provided in the callback. + MissingCode, + /// No state parameter was provided in the callback. + MissingState, + /// The OAuth provider returned an explicit error (e.g. user denied access). + AuthorizationDenied(String), + /// The token response could not be parsed. + Parse(String), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidGrant(msg) => write!(f, "Invalid grant: {msg}"), + Self::Network(msg) => write!(f, "Network error: {msg}"), + Self::TokenExpired(msg) => write!(f, "Token expired: {msg}"), + Self::TokenStorage(msg) => write!(f, "Token storage error: {msg}"), + Self::InvalidState(msg) => write!(f, "Invalid state: {msg}"), + Self::MissingCode => write!(f, "Missing authorization code"), + Self::MissingState => write!(f, "Missing state parameter"), + Self::AuthorizationDenied(msg) => write!(f, "Authorization denied: {msg}"), + Self::Parse(msg) => write!(f, "Parse error: {msg}"), + } + } +} + +// ── OAuthState ──────────────────────────────────────────────────────────────── + +/// In-memory store for pending PKCE flows, keyed by CSRF state parameter. +/// +/// Injected into Poem route handlers via `Data>`. +#[derive(Clone)] +pub struct OAuthState { + /// Maps CSRF state → pending PKCE flow data. + pending: Arc>>, + /// Server port, used to build the `redirect_uri`. + port: u16, +} + +impl OAuthState { + /// Create a new `OAuthState` for the server listening on `port`. + pub fn new(port: u16) -> Self { + Self { + pending: Arc::new(Mutex::new(HashMap::new())), + port, + } + } + + /// Return the OAuth callback URL for this server instance. + pub(crate) fn callback_url(&self) -> String { + format!("http://localhost:{}/callback", self.port) + } +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Initiate a new OAuth PKCE flow. +/// +/// Generates a code verifier, CSRF state token, and PKCE challenge; stores +/// the pending flow; and returns `(csrf_state, authorize_url)` for the caller +/// to redirect the browser to. +pub fn initiate_flow(state: &OAuthState) -> (String, String) { + use pkce::{build_authorize_url, compute_code_challenge, random_string}; + + let code_verifier = random_string(128); + let code_challenge = compute_code_challenge(&code_verifier); + let csrf_state = random_string(32); + let redirect_uri = state.callback_url(); + + crate::slog!("[oauth] Starting OAuth flow, state={}", csrf_state); + + state.pending.lock().unwrap().insert( + csrf_state.clone(), + PendingFlow { + code_verifier, + redirect_uri: redirect_uri.clone(), + }, + ); + + let url = build_authorize_url(&redirect_uri, &code_challenge, &csrf_state); + (csrf_state, url) +} + +/// Exchange an authorization code for tokens and persist the credentials. +/// +/// Looks up the pending PKCE flow for `csrf_state`, exchanges the code with +/// Anthropic's token endpoint, and writes the result to +/// `~/.claude/.credentials.json`. +/// +/// # Errors +/// - [`Error::InvalidState`] if `csrf_state` is unknown or already consumed. +/// - [`Error::Network`] if the token endpoint is unreachable. +/// - [`Error::InvalidGrant`] if Anthropic rejects the code (non-2xx response). +/// - [`Error::Parse`] if the token response cannot be parsed. +/// - [`Error::TokenStorage`] if writing credentials to disk fails. +pub async fn exchange_code(state: &OAuthState, code: &str, csrf_state: &str) -> Result<(), Error> { + crate::slog!("[oauth] Received callback, exchanging code for tokens"); + + let pending = state.pending.lock().unwrap().remove(csrf_state); + let flow = pending.ok_or_else(|| { + crate::slog!("[oauth] Unknown state parameter: {}", csrf_state); + Error::InvalidState( + "Unknown or expired state parameter. Please try logging in again.".to_string(), + ) + })?; + + let token = + io::exchange_code_for_tokens(code, &flow.redirect_uri, &flow.code_verifier, csrf_state) + .await?; + let now_ms = io::current_time_ms(); + io::save_credentials(&token, now_ms)?; + + crate::slog!("[oauth] Successfully authenticated and saved credentials"); + Ok(()) +} + +/// Return the current OAuth credential status without performing any I/O beyond +/// reading the credentials file. +/// +/// Returns an unauthenticated [`FlowStatus`] when no credentials file exists. +pub fn check_status() -> FlowStatus { + io::load_status() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn oauth_state_new_sets_port() { + let s = OAuthState::new(3001); + assert_eq!(s.callback_url(), "http://localhost:3001/callback"); + } + + #[test] + fn oauth_state_different_ports() { + let s = OAuthState::new(9876); + assert_eq!(s.callback_url(), "http://localhost:9876/callback"); + } + + #[test] + fn initiate_flow_stores_pending_entry() { + let state = OAuthState::new(3001); + let (csrf_state, url) = initiate_flow(&state); + assert!(!csrf_state.is_empty()); + assert!(url.contains(&csrf_state)); + assert!(state.pending.lock().unwrap().contains_key(&csrf_state)); + } + + #[test] + fn initiate_flow_generates_unique_states() { + let state = OAuthState::new(3001); + let (s1, _) = initiate_flow(&state); + let (s2, _) = initiate_flow(&state); + assert_ne!(s1, s2); + } + + #[test] + fn error_display_invalid_grant() { + let e = Error::InvalidGrant("bad code".to_string()); + assert_eq!(e.to_string(), "Invalid grant: bad code"); + } + + #[test] + fn error_display_network_error() { + let e = Error::Network("timeout".to_string()); + assert_eq!(e.to_string(), "Network error: timeout"); + } + + #[test] + fn error_display_token_expired() { + let e = Error::TokenExpired("expired".to_string()); + assert_eq!(e.to_string(), "Token expired: expired"); + } + + #[test] + fn error_display_token_storage() { + let e = Error::TokenStorage("disk full".to_string()); + assert_eq!(e.to_string(), "Token storage error: disk full"); + } + + #[test] + fn error_display_invalid_state() { + let e = Error::InvalidState("unknown".to_string()); + assert_eq!(e.to_string(), "Invalid state: unknown"); + } + + #[test] + fn error_display_missing_code() { + let e = Error::MissingCode; + assert_eq!(e.to_string(), "Missing authorization code"); + } + + #[test] + fn error_display_missing_state() { + let e = Error::MissingState; + assert_eq!(e.to_string(), "Missing state parameter"); + } + + #[test] + fn error_display_authorization_denied() { + let e = Error::AuthorizationDenied("access_denied".to_string()); + assert_eq!(e.to_string(), "Authorization denied: access_denied"); + } + + #[test] + fn error_display_parse_error() { + let e = Error::Parse("bad json".to_string()); + assert_eq!(e.to_string(), "Parse error: bad json"); + } + + #[test] + fn exchange_code_returns_invalid_state_for_unknown_csrf() { + // Can test the InvalidState path synchronously by driving the pending map directly + let state = OAuthState::new(3001); + // No pending flow inserted — exchange_code will find no match + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = rt.block_on(exchange_code(&state, "somecode", "unknownstate")); + assert!(matches!(result, Err(Error::InvalidState(_)))); + } +} diff --git a/server/src/service/oauth/pkce.rs b/server/src/service/oauth/pkce.rs new file mode 100644 index 00000000..e8ccc42c --- /dev/null +++ b/server/src/service/oauth/pkce.rs @@ -0,0 +1,175 @@ +//! PKCE (Proof Key for Code Exchange) helpers — pure functions with no side effects. +//! +//! Covers code verifier/challenge generation, base64url encoding, +//! URL percent-encoding, and authorization URL construction. + +use sha2::{Digest, Sha256}; + +/// The Anthropic authorize endpoint. +const AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize"; +/// The OAuth client ID used by Claude Code. +pub(crate) const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"; +/// The OAuth scopes requested. +pub const SCOPES: &str = + "user:inference user:profile user:mcp_servers user:sessions:claude_code user:file_upload"; + +/// Generate a random alphanumeric string of the given length. +/// +/// Used to produce PKCE code verifiers (128 chars) and CSRF state tokens (32 chars). +pub fn random_string(len: usize) -> String { + use std::collections::hash_map::RandomState; + use std::hash::{BuildHasher, Hasher}; + let mut s = String::with_capacity(len); + let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + for _ in 0..len { + let hasher = RandomState::new().build_hasher(); + let idx = hasher.finish() as usize % chars.len(); + s.push(chars[idx] as char); + } + s +} + +/// Compute the S256 PKCE code challenge from a code verifier. +/// +/// Returns the base64url-encoded SHA-256 hash of `verifier` (no padding). +pub fn compute_code_challenge(verifier: &str) -> String { + let hash = Sha256::digest(verifier.as_bytes()); + base64url_encode(&hash) +} + +/// Base64url-encode `data` without padding (RFC 7636). +pub fn base64url_encode(data: &[u8]) -> String { + const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut result = String::new(); + let mut i = 0; + while i < data.len() { + let b0 = data[i] as u32; + let b1 = if i + 1 < data.len() { + data[i + 1] as u32 + } else { + 0 + }; + let b2 = if i + 2 < data.len() { + data[i + 2] as u32 + } else { + 0 + }; + let triple = (b0 << 16) | (b1 << 8) | b2; + + result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); + result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char); + if i + 1 < data.len() { + result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char); + } + if i + 2 < data.len() { + result.push(CHARS[(triple & 0x3F) as usize] as char); + } + i += 3; + } + result.replace('+', "-").replace('/', "_") +} + +/// Percent-encode `input` for use in URL query parameters (RFC 3986 unreserved chars). +pub fn percent_encode(input: &str) -> String { + let mut encoded = String::with_capacity(input.len() * 3); + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + encoded.push(byte as char); + } + _ => { + encoded.push_str(&format!("%{byte:02X}")); + } + } + } + encoded +} + +/// Build the full authorization URL to redirect the browser to. +/// +/// `redirect_uri` — the callback URL (`http://localhost:/callback`) +/// `code_challenge` — the S256 code challenge +/// `csrf_state` — the random CSRF state token +pub fn build_authorize_url(redirect_uri: &str, code_challenge: &str, csrf_state: &str) -> String { + format!( + "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}", + AUTHORIZE_URL, + CLIENT_ID, + percent_encode(redirect_uri), + percent_encode(SCOPES), + percent_encode(code_challenge), + percent_encode(csrf_state), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn base64url_encode_basic() { + assert_eq!(base64url_encode(b"Hello"), "SGVsbG8"); + } + + #[test] + fn base64url_encode_no_padding() { + assert!(!base64url_encode(b"a").contains('=')); + } + + #[test] + fn base64url_encode_no_plus_or_slash() { + let data: Vec = (0..=255).collect(); + let encoded = base64url_encode(&data); + assert!(!encoded.contains('+')); + assert!(!encoded.contains('/')); + } + + #[test] + fn compute_code_challenge_returns_nonempty() { + assert!(!compute_code_challenge("test_verifier").is_empty()); + } + + #[test] + fn compute_code_challenge_is_deterministic() { + assert_eq!( + compute_code_challenge("same"), + compute_code_challenge("same") + ); + } + + #[test] + fn random_string_length() { + assert_eq!(random_string(64).len(), 64); + } + + #[test] + fn random_string_is_alphanumeric() { + assert!( + random_string(100) + .chars() + .all(|c| c.is_ascii_alphanumeric()) + ); + } + + #[test] + fn percent_encode_unreserved_chars_unchanged() { + assert_eq!(percent_encode("abc-_.~"), "abc-_.~"); + } + + #[test] + fn percent_encode_space_becomes_percent_20() { + assert_eq!(percent_encode("hello world"), "hello%20world"); + } + + #[test] + fn build_authorize_url_contains_client_id() { + let url = build_authorize_url("http://localhost:3001/callback", "challenge", "state"); + assert!(url.contains(CLIENT_ID)); + } + + #[test] + fn build_authorize_url_contains_state() { + let url = build_authorize_url("http://localhost:3001/callback", "challenge", "mystate"); + assert!(url.contains("mystate")); + } +}