//! OAuth service — domain logic for the Anthropic OAuth 2.0 PKCE flow.
//!
//! Extracts business logic from `http/oauth.rs` following the conventions in
//! `docs/architecture/service-modules.md`:
//! - `mod.rs` (this file) — public API, typed `Error`, `OAuthState`, orchestration
//! - `io.rs` — the ONLY place that performs side effects (HTTP, filesystem, clock)
//! - `pkce.rs` — pure PKCE helpers: generation, challenge, encoding
//! - `flow.rs` — pure flow types and token-expiry decision logic
|
|
|
|
pub mod flow;
|
|
pub(super) mod io;
|
|
pub mod pkce;
|
|
|
|
pub use flow::AccountInfo;
|
|
|
|
use flow::PendingFlow;
|
|
use std::collections::HashMap;
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
// ── Error type ────────────────────────────────────────────────────────────────
|
|
|
|
/// Typed errors returned by `service::oauth` functions.
///
/// Implements [`std::fmt::Display`] for human-readable messages and
/// [`std::error::Error`] so values can be boxed as `dyn Error` or propagated
/// with `?` into generic error types.
///
/// HTTP handlers map these to status codes:
/// - [`Error::InvalidGrant`] → 400 Bad Request
/// - [`Error::Network`] → 500 Internal Server Error
/// - [`Error::TokenExpired`] → 401 Unauthorized
/// - [`Error::TokenStorage`] → 500 Internal Server Error
/// - [`Error::InvalidState`] → 400 Bad Request
/// - [`Error::MissingCode`] → 400 Bad Request
/// - [`Error::MissingState`] → 400 Bad Request
/// - [`Error::AuthorizationDenied`] → 400 Bad Request
/// - [`Error::Parse`] → 500 Internal Server Error
#[derive(Debug)]
// NOTE(review): presumably some variants are only constructed outside this
// module (or not yet at all) — confirm before removing the allow.
#[allow(dead_code)]
pub enum Error {
    /// The OAuth provider rejected the authorization code (invalid-grant).
    InvalidGrant(String),
    /// A network error occurred communicating with the OAuth provider.
    Network(String),
    /// The access token has expired and cannot be refreshed.
    TokenExpired(String),
    /// Failed to read or write the credential storage file.
    TokenStorage(String),
    /// The CSRF state parameter does not match any pending flow.
    InvalidState(String),
    /// No authorization code was provided in the callback.
    MissingCode,
    /// No state parameter was provided in the callback.
    MissingState,
    /// The OAuth provider returned an explicit error (e.g. user denied access).
    AuthorizationDenied(String),
    /// The token response could not be parsed.
    Parse(String),
}

impl std::fmt::Display for Error {
    /// Writes a short category prefix followed by the variant's detail
    /// message; the two payload-free variants print a fixed sentence.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidGrant(msg) => write!(f, "Invalid grant: {msg}"),
            Self::Network(msg) => write!(f, "Network error: {msg}"),
            Self::TokenExpired(msg) => write!(f, "Token expired: {msg}"),
            Self::TokenStorage(msg) => write!(f, "Token storage error: {msg}"),
            Self::InvalidState(msg) => write!(f, "Invalid state: {msg}"),
            Self::MissingCode => write!(f, "Missing authorization code"),
            Self::MissingState => write!(f, "Missing state parameter"),
            Self::AuthorizationDenied(msg) => write!(f, "Authorization denied: {msg}"),
            Self::Parse(msg) => write!(f, "Parse error: {msg}"),
        }
    }
}

// Every variant carries only a message (or nothing), so there is no underlying
// source error to expose; the default `source()` returning `None` is correct.
impl std::error::Error for Error {}
|
|
|
|
// ── OAuthState ────────────────────────────────────────────────────────────────
|
|
|
|
/// In-memory store for pending PKCE flows, keyed by CSRF state parameter.
|
|
///
|
|
/// Injected into Poem route handlers via `Data<Arc<OAuthState>>`.
|
|
#[derive(Clone)]
|
|
pub struct OAuthState {
|
|
/// Maps CSRF state → pending PKCE flow data.
|
|
pending: Arc<Mutex<HashMap<String, PendingFlow>>>,
|
|
/// Server port, used to build the `redirect_uri`.
|
|
port: u16,
|
|
}
|
|
|
|
impl OAuthState {
|
|
/// Create a new `OAuthState` for the server listening on `port`.
|
|
pub fn new(port: u16) -> Self {
|
|
Self {
|
|
pending: Arc::new(Mutex::new(HashMap::new())),
|
|
port,
|
|
}
|
|
}
|
|
|
|
/// Return the OAuth callback URL for this server instance.
|
|
pub(crate) fn callback_url(&self) -> String {
|
|
format!("http://localhost:{}/callback", self.port)
|
|
}
|
|
}
|
|
|
|
// ── Public API ────────────────────────────────────────────────────────────────
|
|
|
|
/// Initiate a new OAuth PKCE flow.
|
|
///
|
|
/// Generates a code verifier, CSRF state token, and PKCE challenge; stores
|
|
/// the pending flow; and returns `(csrf_state, authorize_url)` for the caller
|
|
/// to redirect the browser to.
|
|
pub fn initiate_flow(state: &OAuthState) -> (String, String) {
|
|
use pkce::{build_authorize_url, compute_code_challenge, random_string};
|
|
|
|
let code_verifier = random_string(128);
|
|
let code_challenge = compute_code_challenge(&code_verifier);
|
|
let csrf_state = random_string(32);
|
|
let redirect_uri = state.callback_url();
|
|
|
|
crate::slog!("[oauth] Starting OAuth flow, state={}", csrf_state);
|
|
|
|
state.pending.lock().unwrap().insert(
|
|
csrf_state.clone(),
|
|
PendingFlow {
|
|
code_verifier,
|
|
redirect_uri: redirect_uri.clone(),
|
|
},
|
|
);
|
|
|
|
let url = build_authorize_url(&redirect_uri, &code_challenge, &csrf_state);
|
|
(csrf_state, url)
|
|
}
|
|
|
|
/// Exchange an authorization code for tokens and persist the credentials.
|
|
///
|
|
/// Looks up the pending PKCE flow for `csrf_state`, exchanges the code with
|
|
/// Anthropic's token endpoint, and writes the result to
|
|
/// `~/.claude/.credentials.json`.
|
|
///
|
|
/// # Errors
|
|
/// - [`Error::InvalidState`] if `csrf_state` is unknown or already consumed.
|
|
/// - [`Error::Network`] if the token endpoint is unreachable.
|
|
/// - [`Error::InvalidGrant`] if Anthropic rejects the code (non-2xx response).
|
|
/// - [`Error::Parse`] if the token response cannot be parsed.
|
|
/// - [`Error::TokenStorage`] if writing credentials to disk fails.
|
|
pub async fn exchange_code(state: &OAuthState, code: &str, csrf_state: &str) -> Result<(), Error> {
|
|
crate::slog!("[oauth] Received callback, exchanging code for tokens");
|
|
|
|
let pending = state.pending.lock().unwrap().remove(csrf_state);
|
|
let flow = pending.ok_or_else(|| {
|
|
crate::slog!("[oauth] Unknown state parameter: {}", csrf_state);
|
|
Error::InvalidState(
|
|
"Unknown or expired state parameter. Please try logging in again.".to_string(),
|
|
)
|
|
})?;
|
|
|
|
let token =
|
|
io::exchange_code_for_tokens(code, &flow.redirect_uri, &flow.code_verifier, csrf_state)
|
|
.await?;
|
|
let now_ms = io::current_time_ms();
|
|
|
|
// Attempt to resolve the email for this account; silently fall back to an
|
|
// empty string so that credential storage always succeeds.
|
|
let email = io::fetch_user_email(&token.access_token)
|
|
.await
|
|
.unwrap_or_default();
|
|
|
|
io::save_credentials(&token, now_ms, &email)?;
|
|
|
|
crate::slog!("[oauth] Successfully authenticated and saved credentials");
|
|
Ok(())
|
|
}
|
|
|
|
/// Return status information for every account in the login pool.
|
|
///
|
|
/// If no pool exists yet, falls back to the legacy single-account credentials
|
|
/// file so that existing deployments continue to work. Returns an empty `Vec`
|
|
/// when neither the pool nor the legacy file is present.
|
|
pub fn check_all_accounts() -> Vec<AccountInfo> {
|
|
io::load_all_accounts()
|
|
}
|
|
|
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // ── OAuthState ───────────────────────────────────────────────────────

    /// The callback URL embeds the port given to `OAuthState::new`.
    #[test]
    fn oauth_state_new_sets_port() {
        let url = OAuthState::new(3001).callback_url();
        assert_eq!(url, "http://localhost:3001/callback");
    }

    /// A different port yields a correspondingly different callback URL.
    #[test]
    fn oauth_state_different_ports() {
        let url = OAuthState::new(9876).callback_url();
        assert_eq!(url, "http://localhost:9876/callback");
    }

    // ── initiate_flow ────────────────────────────────────────────────────

    /// Starting a flow records it under the returned CSRF token, and the
    /// authorize URL carries that same token.
    #[test]
    fn initiate_flow_stores_pending_entry() {
        let oauth = OAuthState::new(3001);
        let (token, authorize_url) = initiate_flow(&oauth);

        assert!(!token.is_empty());
        assert!(authorize_url.contains(&token));

        let pending = oauth.pending.lock().unwrap();
        assert!(pending.contains_key(&token));
    }

    /// Two consecutive flows must not share a CSRF token.
    #[test]
    fn initiate_flow_generates_unique_states() {
        let oauth = OAuthState::new(3001);
        let (first, _) = initiate_flow(&oauth);
        let (second, _) = initiate_flow(&oauth);
        assert_ne!(first, second);
    }

    // ── Error Display ────────────────────────────────────────────────────

    #[test]
    fn error_display_invalid_grant() {
        let rendered = Error::InvalidGrant(String::from("bad code")).to_string();
        assert_eq!(rendered, "Invalid grant: bad code");
    }

    #[test]
    fn error_display_network_error() {
        let rendered = Error::Network(String::from("timeout")).to_string();
        assert_eq!(rendered, "Network error: timeout");
    }

    #[test]
    fn error_display_token_expired() {
        let rendered = Error::TokenExpired(String::from("expired")).to_string();
        assert_eq!(rendered, "Token expired: expired");
    }

    #[test]
    fn error_display_token_storage() {
        let rendered = Error::TokenStorage(String::from("disk full")).to_string();
        assert_eq!(rendered, "Token storage error: disk full");
    }

    #[test]
    fn error_display_invalid_state() {
        let rendered = Error::InvalidState(String::from("unknown")).to_string();
        assert_eq!(rendered, "Invalid state: unknown");
    }

    #[test]
    fn error_display_missing_code() {
        let rendered = Error::MissingCode.to_string();
        assert_eq!(rendered, "Missing authorization code");
    }

    #[test]
    fn error_display_missing_state() {
        let rendered = Error::MissingState.to_string();
        assert_eq!(rendered, "Missing state parameter");
    }

    #[test]
    fn error_display_authorization_denied() {
        let rendered = Error::AuthorizationDenied(String::from("access_denied")).to_string();
        assert_eq!(rendered, "Authorization denied: access_denied");
    }

    #[test]
    fn error_display_parse_error() {
        let rendered = Error::Parse(String::from("bad json")).to_string();
        assert_eq!(rendered, "Parse error: bad json");
    }

    // ── exchange_code ────────────────────────────────────────────────────

    /// With no pending flow registered, `exchange_code` must fail with
    /// `InvalidState` before any token exchange is attempted.
    #[test]
    fn exchange_code_returns_invalid_state_for_unknown_csrf() {
        // The pending map is left empty, so the lookup for "unknownstate"
        // finds nothing. A throwaway runtime drives the async fn to completion.
        let oauth = OAuthState::new(3001);
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let outcome = runtime.block_on(exchange_code(&oauth, "somecode", "unknownstate"));
        assert!(matches!(outcome, Err(Error::InvalidState(_))));
    }
}
|