//! OAuth service — domain logic for the Anthropic OAuth 2.0 PKCE flow. //! //! Extracts business logic from `http/oauth.rs` following the conventions in //! `docs/architecture/service-modules.md`: //! - `mod.rs` (this file) — public API, typed `Error`, `OAuthState`, orchestration //! - `io.rs` — the ONLY place that performs side effects (HTTP, filesystem, clock) //! - `pkce.rs` — pure PKCE helpers: generation, challenge, encoding //! - `flow.rs` — pure flow types and token-expiry decision logic pub mod flow; pub(super) mod io; pub mod pkce; pub use flow::AccountInfo; use flow::PendingFlow; use std::collections::HashMap; use std::sync::{Arc, Mutex}; // ── Error type ──────────────────────────────────────────────────────────────── /// Typed errors returned by `service::oauth` functions. /// /// HTTP handlers map these to status codes: /// - [`Error::InvalidGrant`] → 400 Bad Request /// - [`Error::Network`] → 500 Internal Server Error /// - [`Error::TokenExpired`] → 401 Unauthorized /// - [`Error::TokenStorage`] → 500 Internal Server Error /// - [`Error::InvalidState`] → 400 Bad Request /// - [`Error::MissingCode`] → 400 Bad Request /// - [`Error::MissingState`] → 400 Bad Request /// - [`Error::AuthorizationDenied`] → 400 Bad Request /// - [`Error::Parse`] → 500 Internal Server Error #[derive(Debug)] #[allow(dead_code)] pub enum Error { /// The OAuth provider rejected the authorization code (invalid-grant). InvalidGrant(String), /// A network error occurred communicating with the OAuth provider. Network(String), /// The access token has expired and cannot be refreshed. TokenExpired(String), /// Failed to read or write the credential storage file. TokenStorage(String), /// The CSRF state parameter does not match any pending flow. InvalidState(String), /// No authorization code was provided in the callback. MissingCode, /// No state parameter was provided in the callback. MissingState, /// The OAuth provider returned an explicit error (e.g. user denied access). AuthorizationDenied(String), /// The token response could not be parsed. Parse(String), } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::InvalidGrant(msg) => write!(f, "Invalid grant: {msg}"), Self::Network(msg) => write!(f, "Network error: {msg}"), Self::TokenExpired(msg) => write!(f, "Token expired: {msg}"), Self::TokenStorage(msg) => write!(f, "Token storage error: {msg}"), Self::InvalidState(msg) => write!(f, "Invalid state: {msg}"), Self::MissingCode => write!(f, "Missing authorization code"), Self::MissingState => write!(f, "Missing state parameter"), Self::AuthorizationDenied(msg) => write!(f, "Authorization denied: {msg}"), Self::Parse(msg) => write!(f, "Parse error: {msg}"), } } } // ── OAuthState ──────────────────────────────────────────────────────────────── /// In-memory store for pending PKCE flows, keyed by CSRF state parameter. /// /// Injected into Poem route handlers via `Data>`. #[derive(Clone)] pub struct OAuthState { /// Maps CSRF state → pending PKCE flow data. pending: Arc>>, /// Server port, used to build the `redirect_uri`. port: u16, } impl OAuthState { /// Create a new `OAuthState` for the server listening on `port`. pub fn new(port: u16) -> Self { Self { pending: Arc::new(Mutex::new(HashMap::new())), port, } } /// Return the OAuth callback URL for this server instance. pub(crate) fn callback_url(&self) -> String { format!("http://localhost:{}/callback", self.port) } } // ── Public API ──────────────────────────────────────────────────────────────── /// Initiate a new OAuth PKCE flow. /// /// Generates a code verifier, CSRF state token, and PKCE challenge; stores /// the pending flow; and returns `(csrf_state, authorize_url)` for the caller /// to redirect the browser to. pub fn initiate_flow(state: &OAuthState) -> (String, String) { use pkce::{build_authorize_url, compute_code_challenge, random_string}; let code_verifier = random_string(128); let code_challenge = compute_code_challenge(&code_verifier); let csrf_state = random_string(32); let redirect_uri = state.callback_url(); crate::slog!("[oauth] Starting OAuth flow, state={}", csrf_state); state.pending.lock().unwrap().insert( csrf_state.clone(), PendingFlow { code_verifier, redirect_uri: redirect_uri.clone(), }, ); let url = build_authorize_url(&redirect_uri, &code_challenge, &csrf_state); (csrf_state, url) } /// Exchange an authorization code for tokens and persist the credentials. /// /// Looks up the pending PKCE flow for `csrf_state`, exchanges the code with /// Anthropic's token endpoint, and writes the result to /// `~/.claude/.credentials.json`. /// /// # Errors /// - [`Error::InvalidState`] if `csrf_state` is unknown or already consumed. /// - [`Error::Network`] if the token endpoint is unreachable. /// - [`Error::InvalidGrant`] if Anthropic rejects the code (non-2xx response). /// - [`Error::Parse`] if the token response cannot be parsed. /// - [`Error::TokenStorage`] if writing credentials to disk fails. pub async fn exchange_code(state: &OAuthState, code: &str, csrf_state: &str) -> Result<(), Error> { crate::slog!("[oauth] Received callback, exchanging code for tokens"); let pending = state.pending.lock().unwrap().remove(csrf_state); let flow = pending.ok_or_else(|| { crate::slog!("[oauth] Unknown state parameter: {}", csrf_state); Error::InvalidState( "Unknown or expired state parameter. Please try logging in again.".to_string(), ) })?; let token = io::exchange_code_for_tokens(code, &flow.redirect_uri, &flow.code_verifier, csrf_state) .await?; let now_ms = io::current_time_ms(); // Attempt to resolve the email for this account; silently fall back to an // empty string so that credential storage always succeeds. let email = io::fetch_user_email(&token.access_token) .await .unwrap_or_default(); io::save_credentials(&token, now_ms, &email)?; crate::slog!("[oauth] Successfully authenticated and saved credentials"); Ok(()) } /// Return status information for every account in the login pool. /// /// If no pool exists yet, falls back to the legacy single-account credentials /// file so that existing deployments continue to work. Returns an empty `Vec` /// when neither the pool nor the legacy file is present. pub fn check_all_accounts() -> Vec { io::load_all_accounts() } // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] mod tests { use super::*; #[test] fn oauth_state_new_sets_port() { let s = OAuthState::new(3001); assert_eq!(s.callback_url(), "http://localhost:3001/callback"); } #[test] fn oauth_state_different_ports() { let s = OAuthState::new(9876); assert_eq!(s.callback_url(), "http://localhost:9876/callback"); } #[test] fn initiate_flow_stores_pending_entry() { let state = OAuthState::new(3001); let (csrf_state, url) = initiate_flow(&state); assert!(!csrf_state.is_empty()); assert!(url.contains(&csrf_state)); assert!(state.pending.lock().unwrap().contains_key(&csrf_state)); } #[test] fn initiate_flow_generates_unique_states() { let state = OAuthState::new(3001); let (s1, _) = initiate_flow(&state); let (s2, _) = initiate_flow(&state); assert_ne!(s1, s2); } #[test] fn error_display_invalid_grant() { let e = Error::InvalidGrant("bad code".to_string()); assert_eq!(e.to_string(), "Invalid grant: bad code"); } #[test] fn error_display_network_error() { let e = Error::Network("timeout".to_string()); assert_eq!(e.to_string(), "Network error: timeout"); } #[test] fn error_display_token_expired() { let e = Error::TokenExpired("expired".to_string()); assert_eq!(e.to_string(), "Token expired: expired"); } #[test] fn error_display_token_storage() { let e = Error::TokenStorage("disk full".to_string()); assert_eq!(e.to_string(), "Token storage error: disk full"); } #[test] fn error_display_invalid_state() { let e = Error::InvalidState("unknown".to_string()); assert_eq!(e.to_string(), "Invalid state: unknown"); } #[test] fn error_display_missing_code() { let e = Error::MissingCode; assert_eq!(e.to_string(), "Missing authorization code"); } #[test] fn error_display_missing_state() { let e = Error::MissingState; assert_eq!(e.to_string(), "Missing state parameter"); } #[test] fn error_display_authorization_denied() { let e = Error::AuthorizationDenied("access_denied".to_string()); assert_eq!(e.to_string(), "Authorization denied: access_denied"); } #[test] fn error_display_parse_error() { let e = Error::Parse("bad json".to_string()); assert_eq!(e.to_string(), "Parse error: bad json"); } #[test] fn exchange_code_returns_invalid_state_for_unknown_csrf() { // Can test the InvalidState path synchronously by driving the pending map directly let state = OAuthState::new(3001); // No pending flow inserted — exchange_code will find no match let rt = tokio::runtime::Runtime::new().unwrap(); let result = rt.block_on(exchange_code(&state, "somecode", "unknownstate")); assert!(matches!(result, Err(Error::InvalidState(_)))); } }