huskies: merge 609_story_extract_oauth_service
This commit is contained in:
+68
-274
@@ -1,102 +1,23 @@
|
|||||||
//! OAuth endpoints — Anthropic OAuth callback and token exchange flow.
|
//! OAuth endpoints — thin HTTP adapters over `service::oauth`.
|
||||||
use crate::llm::oauth;
|
//!
|
||||||
|
//! Business logic lives in `service::oauth`. These handlers only:
|
||||||
|
//! 1. Extract parameters from the HTTP request.
|
||||||
|
//! 2. Call the service layer.
|
||||||
|
//! 3. Map service errors to HTTP responses.
|
||||||
|
use crate::service::oauth as svc;
|
||||||
use crate::slog;
|
use crate::slog;
|
||||||
use poem::handler;
|
use poem::handler;
|
||||||
use poem::http::StatusCode;
|
use poem::http::StatusCode;
|
||||||
use poem::web::{Data, Query, Redirect};
|
use poem::web::{Data, Query, Redirect};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::collections::HashMap;
|
use std::sync::Arc;
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
/// Anthropic OAuth configuration.
|
// Re-export service types so that existing tests in this file continue to
|
||||||
const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
|
// compile unchanged (they use `use super::*` and call these by name).
|
||||||
/// Claude.ai authorize URL (for Max/Pro subscriptions).
|
pub(crate) use svc::OAuthState;
|
||||||
const AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize";
|
// Re-exported for tests only (tests use `use super::*` to call these by name).
|
||||||
const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token";
|
#[cfg(test)]
|
||||||
const SCOPES: &str =
|
pub(crate) use svc::pkce::{base64url_encode, compute_code_challenge, random_string};
|
||||||
"user:inference user:profile user:mcp_servers user:sessions:claude_code user:file_upload";
|
|
||||||
|
|
||||||
/// In-memory store for pending PKCE flows, keyed by state parameter.
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct OAuthState {
|
|
||||||
/// Maps state → (code_verifier, redirect_uri)
|
|
||||||
pending: Arc<Mutex<HashMap<String, PendingFlow>>>,
|
|
||||||
/// The port the server is listening on (for building redirect_uri).
|
|
||||||
port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct PendingFlow {
|
|
||||||
code_verifier: String,
|
|
||||||
redirect_uri: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OAuthState {
|
|
||||||
pub fn new(port: u16) -> Self {
|
|
||||||
Self {
|
|
||||||
pending: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
port,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn callback_url(&self) -> String {
|
|
||||||
format!("http://localhost:{}/callback", self.port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate a random alphanumeric string of the given length.
|
|
||||||
fn random_string(len: usize) -> String {
|
|
||||||
use std::collections::hash_map::RandomState;
|
|
||||||
use std::hash::{BuildHasher, Hasher};
|
|
||||||
let mut s = String::with_capacity(len);
|
|
||||||
let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
|
||||||
for _ in 0..len {
|
|
||||||
let hasher = RandomState::new().build_hasher();
|
|
||||||
let idx = hasher.finish() as usize % chars.len();
|
|
||||||
s.push(chars[idx] as char);
|
|
||||||
}
|
|
||||||
s
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compute the S256 PKCE code challenge from a code verifier.
|
|
||||||
fn compute_code_challenge(verifier: &str) -> String {
|
|
||||||
use sha2::{Digest, Sha256};
|
|
||||||
let hash = Sha256::digest(verifier.as_bytes());
|
|
||||||
base64url_encode(&hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Base64url-encode without padding (RFC 7636).
|
|
||||||
fn base64url_encode(data: &[u8]) -> String {
|
|
||||||
// Standard base64 then convert to base64url
|
|
||||||
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
||||||
let mut result = String::new();
|
|
||||||
let mut i = 0;
|
|
||||||
while i < data.len() {
|
|
||||||
let b0 = data[i] as u32;
|
|
||||||
let b1 = if i + 1 < data.len() {
|
|
||||||
data[i + 1] as u32
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
let b2 = if i + 2 < data.len() {
|
|
||||||
data[i + 2] as u32
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
let triple = (b0 << 16) | (b1 << 8) | b2;
|
|
||||||
|
|
||||||
result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char);
|
|
||||||
result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char);
|
|
||||||
if i + 1 < data.len() {
|
|
||||||
result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char);
|
|
||||||
}
|
|
||||||
if i + 2 < data.len() {
|
|
||||||
result.push(CHARS[(triple & 0x3F) as usize] as char);
|
|
||||||
}
|
|
||||||
i += 3;
|
|
||||||
}
|
|
||||||
// Convert to base64url: replace + with -, / with _
|
|
||||||
result.replace('+', "-").replace('/', "_")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `GET /oauth/authorize` — Initiates the OAuth flow.
|
/// `GET /oauth/authorize` — Initiates the OAuth flow.
|
||||||
///
|
///
|
||||||
@@ -104,35 +25,11 @@ fn base64url_encode(data: &[u8]) -> String {
|
|||||||
/// Anthropic's authorization page.
|
/// Anthropic's authorization page.
|
||||||
#[handler]
|
#[handler]
|
||||||
pub async fn oauth_authorize(state: Data<&Arc<OAuthState>>) -> Redirect {
|
pub async fn oauth_authorize(state: Data<&Arc<OAuthState>>) -> Redirect {
|
||||||
let code_verifier = random_string(128);
|
let (_, url) = svc::initiate_flow(&state);
|
||||||
let code_challenge = compute_code_challenge(&code_verifier);
|
Redirect::temporary(url)
|
||||||
let csrf_state = random_string(32);
|
|
||||||
let redirect_uri = state.callback_url();
|
|
||||||
|
|
||||||
slog!("[oauth] Starting OAuth flow, state={}", csrf_state);
|
|
||||||
|
|
||||||
// Store the pending flow
|
|
||||||
state.pending.lock().unwrap().insert(
|
|
||||||
csrf_state.clone(),
|
|
||||||
PendingFlow {
|
|
||||||
code_verifier,
|
|
||||||
redirect_uri: redirect_uri.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let authorize_url = format!(
|
|
||||||
"{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
|
||||||
AUTHORIZE_URL,
|
|
||||||
CLIENT_ID,
|
|
||||||
percent_encode(&redirect_uri),
|
|
||||||
percent_encode(SCOPES),
|
|
||||||
percent_encode(&code_challenge),
|
|
||||||
percent_encode(&csrf_state),
|
|
||||||
);
|
|
||||||
|
|
||||||
Redirect::temporary(authorize_url)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Query parameters received on the OAuth callback URL.
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct CallbackParams {
|
pub struct CallbackParams {
|
||||||
code: Option<String>,
|
code: Option<String>,
|
||||||
@@ -141,18 +38,6 @@ pub struct CallbackParams {
|
|||||||
error_description: Option<String>,
|
error_description: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Response from the Anthropic OAuth token endpoint.
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct TokenResponse {
|
|
||||||
access_token: String,
|
|
||||||
refresh_token: Option<String>,
|
|
||||||
expires_in: u64,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
token_type: Option<String>,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
scope: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `GET /oauth/callback` — Handles the OAuth redirect from Anthropic.
|
/// `GET /oauth/callback` — Handles the OAuth redirect from Anthropic.
|
||||||
///
|
///
|
||||||
/// Exchanges the authorization code for tokens and writes them to
|
/// Exchanges the authorization code for tokens and writes them to
|
||||||
@@ -162,7 +47,7 @@ pub async fn oauth_callback(
|
|||||||
state: Data<&Arc<OAuthState>>,
|
state: Data<&Arc<OAuthState>>,
|
||||||
Query(params): Query<CallbackParams>,
|
Query(params): Query<CallbackParams>,
|
||||||
) -> poem::Response {
|
) -> poem::Response {
|
||||||
// Handle errors from Anthropic
|
// Handle provider-side errors (e.g. user denied access).
|
||||||
if let Some(err) = ¶ms.error {
|
if let Some(err) = ¶ms.error {
|
||||||
let desc = params
|
let desc = params
|
||||||
.error_description
|
.error_description
|
||||||
@@ -177,7 +62,7 @@ pub async fn oauth_callback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let code = match ¶ms.code {
|
let code = match ¶ms.code {
|
||||||
Some(c) => c,
|
Some(c) => c.clone(),
|
||||||
None => {
|
None => {
|
||||||
return html_response(
|
return html_response(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
@@ -188,7 +73,7 @@ pub async fn oauth_callback(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let csrf_state = match ¶ms.state {
|
let csrf_state = match ¶ms.state {
|
||||||
Some(s) => s,
|
Some(s) => s.clone(),
|
||||||
None => {
|
None => {
|
||||||
return html_response(
|
return html_response(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
@@ -198,163 +83,72 @@ pub async fn oauth_callback(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Look up and remove the pending flow
|
match svc::exchange_code(&state, &code, &csrf_state).await {
|
||||||
let pending = state.pending.lock().unwrap().remove(csrf_state);
|
Ok(()) => html_response(
|
||||||
let flow = match pending {
|
|
||||||
Some(f) => f,
|
|
||||||
None => {
|
|
||||||
slog!("[oauth] Unknown state parameter: {}", csrf_state);
|
|
||||||
return html_response(
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
"Invalid State",
|
|
||||||
"Unknown or expired state parameter. Please try logging in again.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
slog!("[oauth] Received callback, exchanging code for tokens");
|
|
||||||
|
|
||||||
// Exchange the authorization code for tokens
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let resp = client
|
|
||||||
.post(TOKEN_ENDPOINT)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.json(&serde_json::json!({
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"client_id": CLIENT_ID,
|
|
||||||
"redirect_uri": &flow.redirect_uri,
|
|
||||||
"code_verifier": &flow.code_verifier,
|
|
||||||
"state": csrf_state,
|
|
||||||
}))
|
|
||||||
.send()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let resp = match resp {
|
|
||||||
Ok(r) => r,
|
|
||||||
Err(e) => {
|
|
||||||
slog!("[oauth] Token exchange request failed: {}", e);
|
|
||||||
return html_response(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Token Exchange Failed",
|
|
||||||
&format!("Failed to contact Anthropic: {e}"),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let status = resp.status();
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
slog!(
|
|
||||||
"[oauth] Token exchange response (HTTP {}): {}",
|
|
||||||
status,
|
|
||||||
body
|
|
||||||
);
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return html_response(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Token Exchange Failed",
|
|
||||||
&format!("Anthropic returned HTTP {status}. Please try again."),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let token_resp: TokenResponse = match serde_json::from_str(&body) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => {
|
|
||||||
slog!("[oauth] Failed to parse token response: {}", e);
|
|
||||||
return html_response(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Token Parse Failed",
|
|
||||||
"Received an unexpected response from Anthropic.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let now_ms = std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.map(|d| d.as_millis() as u64)
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let creds = oauth::CredentialsFile {
|
|
||||||
claude_ai_oauth: oauth::OAuthCredentials {
|
|
||||||
access_token: token_resp.access_token,
|
|
||||||
refresh_token: token_resp.refresh_token.unwrap_or_default(),
|
|
||||||
expires_at: now_ms + (token_resp.expires_in * 1000),
|
|
||||||
scopes: SCOPES.split(' ').map(|s| s.to_string()).collect(),
|
|
||||||
subscription_type: None,
|
|
||||||
rate_limit_tier: None,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = oauth::write_credentials(&creds) {
|
|
||||||
slog!("[oauth] Failed to write credentials: {}", e);
|
|
||||||
return html_response(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Credential Write Failed",
|
|
||||||
&format!("Tokens received but failed to save: {e}"),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
slog!("[oauth] Successfully authenticated and saved credentials");
|
|
||||||
|
|
||||||
html_response(
|
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
"Authenticated!",
|
"Authenticated!",
|
||||||
"Claude OAuth login successful. You can close this tab and return to Huskies.",
|
"Claude OAuth login successful. You can close this tab and return to Huskies.",
|
||||||
)
|
),
|
||||||
|
Err(e) => map_service_error(e),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check whether valid (non-expired) OAuth credentials exist.
|
/// `GET /oauth/status` — Check whether valid (non-expired) OAuth credentials exist.
|
||||||
#[handler]
|
#[handler]
|
||||||
pub async fn oauth_status() -> poem::Response {
|
pub async fn oauth_status() -> poem::Response {
|
||||||
match oauth::read_credentials() {
|
let status = svc::check_status();
|
||||||
Ok(creds) => {
|
|
||||||
let now_ms = std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.map(|d| d.as_millis() as u64)
|
|
||||||
.unwrap_or(0);
|
|
||||||
let expired = now_ms > creds.claude_ai_oauth.expires_at;
|
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
"authenticated": true,
|
"authenticated": status.authenticated,
|
||||||
"expired": expired,
|
"expired": status.expired,
|
||||||
"expires_at": creds.claude_ai_oauth.expires_at,
|
"expires_at": status.expires_at,
|
||||||
"has_refresh_token": !creds.claude_ai_oauth.refresh_token.is_empty(),
|
"has_refresh_token": status.has_refresh_token,
|
||||||
});
|
});
|
||||||
poem::Response::builder()
|
poem::Response::builder()
|
||||||
.status(StatusCode::OK)
|
.status(StatusCode::OK)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.body(body.to_string())
|
.body(body.to_string())
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
let body = serde_json::json!({
|
|
||||||
"authenticated": false,
|
|
||||||
"expired": false,
|
|
||||||
"expires_at": 0,
|
|
||||||
"has_refresh_token": false,
|
|
||||||
});
|
|
||||||
poem::Response::builder()
|
|
||||||
.status(StatusCode::OK)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.body(body.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Percent-encode a string for use in URL query parameters.
|
// ── Private helpers ───────────────────────────────────────────────────────────
|
||||||
fn percent_encode(input: &str) -> String {
|
|
||||||
let mut encoded = String::with_capacity(input.len() * 3);
|
/// Map a service-layer `Error` to an HTML HTTP response.
|
||||||
for byte in input.bytes() {
|
fn map_service_error(e: svc::Error) -> poem::Response {
|
||||||
match byte {
|
use svc::Error;
|
||||||
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
|
match e {
|
||||||
encoded.push(byte as char);
|
Error::MissingCode => html_response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Missing Code",
|
||||||
|
"No authorization code received.",
|
||||||
|
),
|
||||||
|
Error::MissingState => html_response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Missing State",
|
||||||
|
"No state parameter received.",
|
||||||
|
),
|
||||||
|
Error::InvalidState(msg) => html_response(StatusCode::BAD_REQUEST, "Invalid State", &msg),
|
||||||
|
Error::AuthorizationDenied(msg) => {
|
||||||
|
html_response(StatusCode::BAD_REQUEST, "Authentication Failed", &msg)
|
||||||
}
|
}
|
||||||
_ => {
|
Error::InvalidGrant(msg) => {
|
||||||
encoded.push_str(&format!("%{byte:02X}"));
|
html_response(StatusCode::BAD_REQUEST, "Token Exchange Failed", &msg)
|
||||||
}
|
}
|
||||||
|
Error::Network(msg) => html_response(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"Token Exchange Failed",
|
||||||
|
&msg,
|
||||||
|
),
|
||||||
|
Error::TokenExpired(msg) => html_response(StatusCode::UNAUTHORIZED, "Token Expired", &msg),
|
||||||
|
Error::TokenStorage(msg) => html_response(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"Credential Write Failed",
|
||||||
|
&msg,
|
||||||
|
),
|
||||||
|
Error::Parse(msg) => html_response(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"Token Parse Failed",
|
||||||
|
&msg,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
|
||||||
encoded
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn html_response(status: StatusCode, title: &str, message: &str) -> poem::Response {
|
fn html_response(status: StatusCode, title: &str, message: &str) -> poem::Response {
|
||||||
|
|||||||
@@ -11,5 +11,6 @@ pub mod bot_command;
|
|||||||
pub mod events;
|
pub mod events;
|
||||||
pub mod file_io;
|
pub mod file_io;
|
||||||
pub mod health;
|
pub mod health;
|
||||||
|
pub mod oauth;
|
||||||
pub mod project;
|
pub mod project;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|||||||
@@ -0,0 +1,124 @@
|
|||||||
|
//! OAuth flow state types and pure decision logic.
|
||||||
|
//!
|
||||||
|
//! All functions here are pure — no I/O, no network, no clocks.
|
||||||
|
//! Side-effectful operations live exclusively in `io.rs`.
|
||||||
|
|
||||||
|
use crate::llm::oauth::CredentialsFile;
|
||||||
|
|
||||||
|
/// A pending PKCE flow waiting for an OAuth callback.
|
||||||
|
pub struct PendingFlow {
|
||||||
|
/// The PKCE code verifier generated at flow initiation.
|
||||||
|
pub code_verifier: String,
|
||||||
|
/// The redirect URI sent to the authorization endpoint.
|
||||||
|
pub redirect_uri: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Current OAuth credential status, computed without I/O from already-loaded credentials.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct FlowStatus {
|
||||||
|
/// Whether valid credentials were found on disk.
|
||||||
|
pub authenticated: bool,
|
||||||
|
/// Whether the access token is past its expiry timestamp.
|
||||||
|
pub expired: bool,
|
||||||
|
/// The Unix-epoch millisecond expiry timestamp (0 when unauthenticated).
|
||||||
|
pub expires_at: u64,
|
||||||
|
/// Whether a non-empty refresh token is present.
|
||||||
|
pub has_refresh_token: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determine whether `expires_at` (Unix epoch ms) has passed, given `now_ms`.
|
||||||
|
///
|
||||||
|
/// Returns `true` when `now_ms > expires_at`.
|
||||||
|
pub fn is_token_expired(expires_at: u64, now_ms: u64) -> bool {
|
||||||
|
now_ms > expires_at
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a `FlowStatus` from loaded credentials and the current time.
|
||||||
|
pub fn build_flow_status(creds: &CredentialsFile, now_ms: u64) -> FlowStatus {
|
||||||
|
let expires_at = creds.claude_ai_oauth.expires_at;
|
||||||
|
FlowStatus {
|
||||||
|
authenticated: true,
|
||||||
|
expired: is_token_expired(expires_at, now_ms),
|
||||||
|
expires_at,
|
||||||
|
has_refresh_token: !creds.claude_ai_oauth.refresh_token.is_empty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the unauthenticated `FlowStatus` (no credentials on disk).
|
||||||
|
pub fn unauthenticated_status() -> FlowStatus {
|
||||||
|
FlowStatus {
|
||||||
|
authenticated: false,
|
||||||
|
expired: false,
|
||||||
|
expires_at: 0,
|
||||||
|
has_refresh_token: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_token_expired_when_past_expiry() {
|
||||||
|
assert!(is_token_expired(1000, 2000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_token_not_expired_when_before_expiry() {
|
||||||
|
assert!(!is_token_expired(2000, 1000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_token_not_expired_at_exact_boundary() {
|
||||||
|
// expires_at == now_ms → not expired
|
||||||
|
assert!(!is_token_expired(1000, 1000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unauthenticated_status_is_not_authenticated() {
|
||||||
|
let s = unauthenticated_status();
|
||||||
|
assert!(!s.authenticated);
|
||||||
|
assert!(!s.expired);
|
||||||
|
assert_eq!(s.expires_at, 0);
|
||||||
|
assert!(!s.has_refresh_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_flow_status_authenticated_not_expired() {
|
||||||
|
use crate::llm::oauth::{CredentialsFile, OAuthCredentials};
|
||||||
|
let creds = CredentialsFile {
|
||||||
|
claude_ai_oauth: OAuthCredentials {
|
||||||
|
access_token: "tok".to_string(),
|
||||||
|
refresh_token: "ref".to_string(),
|
||||||
|
expires_at: 5000,
|
||||||
|
scopes: vec![],
|
||||||
|
subscription_type: None,
|
||||||
|
rate_limit_tier: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let status = build_flow_status(&creds, 1000);
|
||||||
|
assert!(status.authenticated);
|
||||||
|
assert!(!status.expired);
|
||||||
|
assert_eq!(status.expires_at, 5000);
|
||||||
|
assert!(status.has_refresh_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_flow_status_authenticated_expired() {
|
||||||
|
use crate::llm::oauth::{CredentialsFile, OAuthCredentials};
|
||||||
|
let creds = CredentialsFile {
|
||||||
|
claude_ai_oauth: OAuthCredentials {
|
||||||
|
access_token: "tok".to_string(),
|
||||||
|
refresh_token: String::new(),
|
||||||
|
expires_at: 1000,
|
||||||
|
scopes: vec![],
|
||||||
|
subscription_type: None,
|
||||||
|
rate_limit_tier: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let status = build_flow_status(&creds, 9999);
|
||||||
|
assert!(status.authenticated);
|
||||||
|
assert!(status.expired);
|
||||||
|
assert!(!status.has_refresh_token);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
//! OAuth I/O — the ONLY place in `service/oauth/` that may perform side effects.
|
||||||
|
//!
|
||||||
|
//! Side effects here include: reading the system clock, making HTTP requests to
|
||||||
|
//! the Anthropic token endpoint, and reading/writing `~/.claude/.credentials.json`.
|
||||||
|
//! All business logic and branching belong in `mod.rs`, `pkce.rs`, or `flow.rs`.
|
||||||
|
|
||||||
|
use super::Error;
|
||||||
|
use super::flow::FlowStatus;
|
||||||
|
use super::pkce::SCOPES;
|
||||||
|
use crate::llm::oauth::{self, CredentialsFile};
|
||||||
|
use crate::slog;
|
||||||
|
|
||||||
|
/// Raw token exchange result returned by the Anthropic OAuth endpoint.
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
pub(super) struct TokenExchangeResult {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_in: u64,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub token_type: Option<String>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub scope: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the current Unix-epoch time in milliseconds.
|
||||||
|
pub(super) fn current_time_ms() -> u64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_millis() as u64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exchange an authorization code for tokens via the Anthropic token endpoint.
|
||||||
|
///
|
||||||
|
/// Returns the raw token response on success. Network or HTTP errors are
|
||||||
|
/// mapped to typed [`Error`] variants.
|
||||||
|
pub(super) async fn exchange_code_for_tokens(
|
||||||
|
code: &str,
|
||||||
|
redirect_uri: &str,
|
||||||
|
code_verifier: &str,
|
||||||
|
csrf_state: &str,
|
||||||
|
) -> Result<TokenExchangeResult, Error> {
|
||||||
|
use super::pkce::CLIENT_ID;
|
||||||
|
const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token";
|
||||||
|
|
||||||
|
slog!("[oauth] Exchanging authorization code for tokens");
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(TOKEN_ENDPOINT)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
"state": csrf_state,
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::Network(format!("Failed to contact Anthropic: {e}")))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
slog!(
|
||||||
|
"[oauth] Token exchange response (HTTP {}): {}",
|
||||||
|
status,
|
||||||
|
body
|
||||||
|
);
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(Error::InvalidGrant(format!(
|
||||||
|
"Anthropic returned HTTP {status}. Please try again."
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
serde_json::from_str(&body)
|
||||||
|
.map_err(|e| Error::Parse(format!("Unexpected response from Anthropic: {e}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Persist a token exchange result to `~/.claude/.credentials.json`.
|
||||||
|
///
|
||||||
|
/// Builds a [`CredentialsFile`] from the token response and `now_ms`, then
|
||||||
|
/// delegates to [`oauth::write_credentials`].
|
||||||
|
pub(super) fn save_credentials(token: &TokenExchangeResult, now_ms: u64) -> Result<(), Error> {
|
||||||
|
let creds = CredentialsFile {
|
||||||
|
claude_ai_oauth: oauth::OAuthCredentials {
|
||||||
|
access_token: token.access_token.clone(),
|
||||||
|
refresh_token: token.refresh_token.clone().unwrap_or_default(),
|
||||||
|
expires_at: now_ms + (token.expires_in * 1000),
|
||||||
|
scopes: SCOPES.split(' ').map(|s| s.to_string()).collect(),
|
||||||
|
subscription_type: None,
|
||||||
|
rate_limit_tier: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
oauth::write_credentials(&creds).map_err(Error::TokenStorage)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load OAuth credentials from disk and compute a [`FlowStatus`].
|
||||||
|
///
|
||||||
|
/// Returns `Ok(None)` when no credentials file exists yet (user not logged in).
|
||||||
|
pub(super) fn load_status() -> FlowStatus {
|
||||||
|
match oauth::read_credentials() {
|
||||||
|
Ok(creds) => {
|
||||||
|
let now_ms = current_time_ms();
|
||||||
|
super::flow::build_flow_status(&creds, now_ms)
|
||||||
|
}
|
||||||
|
Err(_) => super::flow::unauthenticated_status(),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,269 @@
|
|||||||
|
//! OAuth service — domain logic for the Anthropic OAuth 2.0 PKCE flow.
|
||||||
|
//!
|
||||||
|
//! Extracts business logic from `http/oauth.rs` following the conventions in
|
||||||
|
//! `docs/architecture/service-modules.md`:
|
||||||
|
//! - `mod.rs` (this file) — public API, typed `Error`, `OAuthState`, orchestration
|
||||||
|
//! - `io.rs` — the ONLY place that performs side effects (HTTP, filesystem, clock)
|
||||||
|
//! - `pkce.rs` — pure PKCE helpers: generation, challenge, encoding
|
||||||
|
//! - `flow.rs` — pure flow types and token-expiry decision logic
|
||||||
|
|
||||||
|
pub mod flow;
|
||||||
|
pub(super) mod io;
|
||||||
|
pub mod pkce;
|
||||||
|
|
||||||
|
pub use flow::FlowStatus;
|
||||||
|
|
||||||
|
use flow::PendingFlow;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
// ── Error type ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Typed errors returned by `service::oauth` functions.
|
||||||
|
///
|
||||||
|
/// HTTP handlers map these to status codes:
|
||||||
|
/// - [`Error::InvalidGrant`] → 400 Bad Request
|
||||||
|
/// - [`Error::Network`] → 500 Internal Server Error
|
||||||
|
/// - [`Error::TokenExpired`] → 401 Unauthorized
|
||||||
|
/// - [`Error::TokenStorage`] → 500 Internal Server Error
|
||||||
|
/// - [`Error::InvalidState`] → 400 Bad Request
|
||||||
|
/// - [`Error::MissingCode`] → 400 Bad Request
|
||||||
|
/// - [`Error::MissingState`] → 400 Bad Request
|
||||||
|
/// - [`Error::AuthorizationDenied`] → 400 Bad Request
|
||||||
|
/// - [`Error::Parse`] → 500 Internal Server Error
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub enum Error {
|
||||||
|
/// The OAuth provider rejected the authorization code (invalid-grant).
|
||||||
|
InvalidGrant(String),
|
||||||
|
/// A network error occurred communicating with the OAuth provider.
|
||||||
|
Network(String),
|
||||||
|
/// The access token has expired and cannot be refreshed.
|
||||||
|
TokenExpired(String),
|
||||||
|
/// Failed to read or write the credential storage file.
|
||||||
|
TokenStorage(String),
|
||||||
|
/// The CSRF state parameter does not match any pending flow.
|
||||||
|
InvalidState(String),
|
||||||
|
/// No authorization code was provided in the callback.
|
||||||
|
MissingCode,
|
||||||
|
/// No state parameter was provided in the callback.
|
||||||
|
MissingState,
|
||||||
|
/// The OAuth provider returned an explicit error (e.g. user denied access).
|
||||||
|
AuthorizationDenied(String),
|
||||||
|
/// The token response could not be parsed.
|
||||||
|
Parse(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Error {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::InvalidGrant(msg) => write!(f, "Invalid grant: {msg}"),
|
||||||
|
Self::Network(msg) => write!(f, "Network error: {msg}"),
|
||||||
|
Self::TokenExpired(msg) => write!(f, "Token expired: {msg}"),
|
||||||
|
Self::TokenStorage(msg) => write!(f, "Token storage error: {msg}"),
|
||||||
|
Self::InvalidState(msg) => write!(f, "Invalid state: {msg}"),
|
||||||
|
Self::MissingCode => write!(f, "Missing authorization code"),
|
||||||
|
Self::MissingState => write!(f, "Missing state parameter"),
|
||||||
|
Self::AuthorizationDenied(msg) => write!(f, "Authorization denied: {msg}"),
|
||||||
|
Self::Parse(msg) => write!(f, "Parse error: {msg}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── OAuthState ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// In-memory store for pending PKCE flows, keyed by CSRF state parameter.
|
||||||
|
///
|
||||||
|
/// Injected into Poem route handlers via `Data<Arc<OAuthState>>`.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OAuthState {
|
||||||
|
/// Maps CSRF state → pending PKCE flow data.
|
||||||
|
pending: Arc<Mutex<HashMap<String, PendingFlow>>>,
|
||||||
|
/// Server port, used to build the `redirect_uri`.
|
||||||
|
port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthState {
|
||||||
|
/// Create a new `OAuthState` for the server listening on `port`.
|
||||||
|
pub fn new(port: u16) -> Self {
|
||||||
|
Self {
|
||||||
|
pending: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the OAuth callback URL for this server instance.
|
||||||
|
pub(crate) fn callback_url(&self) -> String {
|
||||||
|
format!("http://localhost:{}/callback", self.port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Public API ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Initiate a new OAuth PKCE flow.
|
||||||
|
///
|
||||||
|
/// Generates a code verifier, CSRF state token, and PKCE challenge; stores
|
||||||
|
/// the pending flow; and returns `(csrf_state, authorize_url)` for the caller
|
||||||
|
/// to redirect the browser to.
|
||||||
|
pub fn initiate_flow(state: &OAuthState) -> (String, String) {
|
||||||
|
use pkce::{build_authorize_url, compute_code_challenge, random_string};
|
||||||
|
|
||||||
|
let code_verifier = random_string(128);
|
||||||
|
let code_challenge = compute_code_challenge(&code_verifier);
|
||||||
|
let csrf_state = random_string(32);
|
||||||
|
let redirect_uri = state.callback_url();
|
||||||
|
|
||||||
|
crate::slog!("[oauth] Starting OAuth flow, state={}", csrf_state);
|
||||||
|
|
||||||
|
state.pending.lock().unwrap().insert(
|
||||||
|
csrf_state.clone(),
|
||||||
|
PendingFlow {
|
||||||
|
code_verifier,
|
||||||
|
redirect_uri: redirect_uri.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let url = build_authorize_url(&redirect_uri, &code_challenge, &csrf_state);
|
||||||
|
(csrf_state, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exchange an authorization code for tokens and persist the credentials.
|
||||||
|
///
|
||||||
|
/// Looks up the pending PKCE flow for `csrf_state`, exchanges the code with
|
||||||
|
/// Anthropic's token endpoint, and writes the result to
|
||||||
|
/// `~/.claude/.credentials.json`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// - [`Error::InvalidState`] if `csrf_state` is unknown or already consumed.
|
||||||
|
/// - [`Error::Network`] if the token endpoint is unreachable.
|
||||||
|
/// - [`Error::InvalidGrant`] if Anthropic rejects the code (non-2xx response).
|
||||||
|
/// - [`Error::Parse`] if the token response cannot be parsed.
|
||||||
|
/// - [`Error::TokenStorage`] if writing credentials to disk fails.
|
||||||
|
pub async fn exchange_code(state: &OAuthState, code: &str, csrf_state: &str) -> Result<(), Error> {
|
||||||
|
crate::slog!("[oauth] Received callback, exchanging code for tokens");
|
||||||
|
|
||||||
|
let pending = state.pending.lock().unwrap().remove(csrf_state);
|
||||||
|
let flow = pending.ok_or_else(|| {
|
||||||
|
crate::slog!("[oauth] Unknown state parameter: {}", csrf_state);
|
||||||
|
Error::InvalidState(
|
||||||
|
"Unknown or expired state parameter. Please try logging in again.".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let token =
|
||||||
|
io::exchange_code_for_tokens(code, &flow.redirect_uri, &flow.code_verifier, csrf_state)
|
||||||
|
.await?;
|
||||||
|
let now_ms = io::current_time_ms();
|
||||||
|
io::save_credentials(&token, now_ms)?;
|
||||||
|
|
||||||
|
crate::slog!("[oauth] Successfully authenticated and saved credentials");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the current OAuth credential status without performing any I/O beyond
|
||||||
|
/// reading the credentials file.
|
||||||
|
///
|
||||||
|
/// Returns an unauthenticated [`FlowStatus`] when no credentials file exists.
|
||||||
|
pub fn check_status() -> FlowStatus {
|
||||||
|
io::load_status()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_state_new_sets_port() {
|
||||||
|
let s = OAuthState::new(3001);
|
||||||
|
assert_eq!(s.callback_url(), "http://localhost:3001/callback");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_state_different_ports() {
|
||||||
|
let s = OAuthState::new(9876);
|
||||||
|
assert_eq!(s.callback_url(), "http://localhost:9876/callback");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn initiate_flow_stores_pending_entry() {
|
||||||
|
let state = OAuthState::new(3001);
|
||||||
|
let (csrf_state, url) = initiate_flow(&state);
|
||||||
|
assert!(!csrf_state.is_empty());
|
||||||
|
assert!(url.contains(&csrf_state));
|
||||||
|
assert!(state.pending.lock().unwrap().contains_key(&csrf_state));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn initiate_flow_generates_unique_states() {
|
||||||
|
let state = OAuthState::new(3001);
|
||||||
|
let (s1, _) = initiate_flow(&state);
|
||||||
|
let (s2, _) = initiate_flow(&state);
|
||||||
|
assert_ne!(s1, s2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_invalid_grant() {
|
||||||
|
let e = Error::InvalidGrant("bad code".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Invalid grant: bad code");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_network_error() {
|
||||||
|
let e = Error::Network("timeout".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Network error: timeout");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_token_expired() {
|
||||||
|
let e = Error::TokenExpired("expired".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Token expired: expired");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_token_storage() {
|
||||||
|
let e = Error::TokenStorage("disk full".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Token storage error: disk full");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_invalid_state() {
|
||||||
|
let e = Error::InvalidState("unknown".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Invalid state: unknown");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_missing_code() {
|
||||||
|
let e = Error::MissingCode;
|
||||||
|
assert_eq!(e.to_string(), "Missing authorization code");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_missing_state() {
|
||||||
|
let e = Error::MissingState;
|
||||||
|
assert_eq!(e.to_string(), "Missing state parameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_authorization_denied() {
|
||||||
|
let e = Error::AuthorizationDenied("access_denied".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Authorization denied: access_denied");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display_parse_error() {
|
||||||
|
let e = Error::Parse("bad json".to_string());
|
||||||
|
assert_eq!(e.to_string(), "Parse error: bad json");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn exchange_code_returns_invalid_state_for_unknown_csrf() {
|
||||||
|
// Can test the InvalidState path synchronously by driving the pending map directly
|
||||||
|
let state = OAuthState::new(3001);
|
||||||
|
// No pending flow inserted — exchange_code will find no match
|
||||||
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||||
|
let result = rt.block_on(exchange_code(&state, "somecode", "unknownstate"));
|
||||||
|
assert!(matches!(result, Err(Error::InvalidState(_))));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
//! PKCE (Proof Key for Code Exchange) helpers — pure functions with no side effects.
|
||||||
|
//!
|
||||||
|
//! Covers code verifier/challenge generation, base64url encoding,
|
||||||
|
//! URL percent-encoding, and authorization URL construction.
|
||||||
|
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
/// The Anthropic authorize endpoint.
|
||||||
|
const AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize";
|
||||||
|
/// The OAuth client ID used by Claude Code.
|
||||||
|
pub(crate) const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
|
||||||
|
/// The OAuth scopes requested.
|
||||||
|
pub const SCOPES: &str =
|
||||||
|
"user:inference user:profile user:mcp_servers user:sessions:claude_code user:file_upload";
|
||||||
|
|
||||||
|
/// Generate a random alphanumeric string of the given length.
|
||||||
|
///
|
||||||
|
/// Used to produce PKCE code verifiers (128 chars) and CSRF state tokens (32 chars).
|
||||||
|
pub fn random_string(len: usize) -> String {
|
||||||
|
use std::collections::hash_map::RandomState;
|
||||||
|
use std::hash::{BuildHasher, Hasher};
|
||||||
|
let mut s = String::with_capacity(len);
|
||||||
|
let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
||||||
|
for _ in 0..len {
|
||||||
|
let hasher = RandomState::new().build_hasher();
|
||||||
|
let idx = hasher.finish() as usize % chars.len();
|
||||||
|
s.push(chars[idx] as char);
|
||||||
|
}
|
||||||
|
s
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the S256 PKCE code challenge from a code verifier.
|
||||||
|
///
|
||||||
|
/// Returns the base64url-encoded SHA-256 hash of `verifier` (no padding).
|
||||||
|
pub fn compute_code_challenge(verifier: &str) -> String {
|
||||||
|
let hash = Sha256::digest(verifier.as_bytes());
|
||||||
|
base64url_encode(&hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Base64url-encode `data` without padding (RFC 7636).
|
||||||
|
pub fn base64url_encode(data: &[u8]) -> String {
|
||||||
|
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut i = 0;
|
||||||
|
while i < data.len() {
|
||||||
|
let b0 = data[i] as u32;
|
||||||
|
let b1 = if i + 1 < data.len() {
|
||||||
|
data[i + 1] as u32
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let b2 = if i + 2 < data.len() {
|
||||||
|
data[i + 2] as u32
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let triple = (b0 << 16) | (b1 << 8) | b2;
|
||||||
|
|
||||||
|
result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char);
|
||||||
|
result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char);
|
||||||
|
if i + 1 < data.len() {
|
||||||
|
result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
if i + 2 < data.len() {
|
||||||
|
result.push(CHARS[(triple & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
i += 3;
|
||||||
|
}
|
||||||
|
result.replace('+', "-").replace('/', "_")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Percent-encode `input` for use in URL query parameters (RFC 3986 unreserved chars).
|
||||||
|
pub fn percent_encode(input: &str) -> String {
|
||||||
|
let mut encoded = String::with_capacity(input.len() * 3);
|
||||||
|
for byte in input.bytes() {
|
||||||
|
match byte {
|
||||||
|
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
|
||||||
|
encoded.push(byte as char);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
encoded.push_str(&format!("%{byte:02X}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the full authorization URL to redirect the browser to.
|
||||||
|
///
|
||||||
|
/// `redirect_uri` — the callback URL (`http://localhost:<port>/callback`)
|
||||||
|
/// `code_challenge` — the S256 code challenge
|
||||||
|
/// `csrf_state` — the random CSRF state token
|
||||||
|
pub fn build_authorize_url(redirect_uri: &str, code_challenge: &str, csrf_state: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||||
|
AUTHORIZE_URL,
|
||||||
|
CLIENT_ID,
|
||||||
|
percent_encode(redirect_uri),
|
||||||
|
percent_encode(SCOPES),
|
||||||
|
percent_encode(code_challenge),
|
||||||
|
percent_encode(csrf_state),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn base64url_encode_basic() {
|
||||||
|
assert_eq!(base64url_encode(b"Hello"), "SGVsbG8");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn base64url_encode_no_padding() {
|
||||||
|
assert!(!base64url_encode(b"a").contains('='));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn base64url_encode_no_plus_or_slash() {
|
||||||
|
let data: Vec<u8> = (0..=255).collect();
|
||||||
|
let encoded = base64url_encode(&data);
|
||||||
|
assert!(!encoded.contains('+'));
|
||||||
|
assert!(!encoded.contains('/'));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compute_code_challenge_returns_nonempty() {
|
||||||
|
assert!(!compute_code_challenge("test_verifier").is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compute_code_challenge_is_deterministic() {
|
||||||
|
assert_eq!(
|
||||||
|
compute_code_challenge("same"),
|
||||||
|
compute_code_challenge("same")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn random_string_length() {
|
||||||
|
assert_eq!(random_string(64).len(), 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn random_string_is_alphanumeric() {
|
||||||
|
assert!(
|
||||||
|
random_string(100)
|
||||||
|
.chars()
|
||||||
|
.all(|c| c.is_ascii_alphanumeric())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn percent_encode_unreserved_chars_unchanged() {
|
||||||
|
assert_eq!(percent_encode("abc-_.~"), "abc-_.~");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn percent_encode_space_becomes_percent_20() {
|
||||||
|
assert_eq!(percent_encode("hello world"), "hello%20world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_authorize_url_contains_client_id() {
|
||||||
|
let url = build_authorize_url("http://localhost:3001/callback", "challenge", "state");
|
||||||
|
assert!(url.contains(CLIENT_ID));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_authorize_url_contains_state() {
|
||||||
|
let url = build_authorize_url("http://localhost:3001/callback", "challenge", "mystate");
|
||||||
|
assert!(url.contains("mystate"));
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user