feat: auto-refresh expired OAuth token for Claude Code PTY (story 405)

Detect authentication_failed errors from the Claude Code PTY stream
and automatically refresh the OAuth access token using the stored
refresh token in ~/.claude/.credentials.json.

- New module server/src/llm/oauth.rs: reads credentials, calls
  platform.claude.com/v1/oauth/token with JSON body, writes back
- PTY provider detects "error":"authentication_failed" via AtomicBool
- chat_stream retries once after successful refresh
- Clear error message if refresh also fails

On success the retry is transparent. On failure the user sees:
"OAuth session expired. Please run claude login to re-authenticate."

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Timmy
2026-03-26 19:58:04 +00:00
parent ab4ce2db92
commit 710b604b7c
4 changed files with 421 additions and 75 deletions
+175 -75
View File
@@ -53,10 +53,6 @@ impl ClaudeCodeProvider {
T: FnMut(&str) + Send,
A: FnMut(&str) + Send,
{
let message = user_message.to_string();
let cwd = project_root.to_string();
let resume_id = session_id.map(|s| s.to_string());
let sys_prompt = system_prompt.map(|s| s.to_string());
let cancelled = Arc::new(AtomicBool::new(false));
let cancelled_clone = cancelled.clone();
@@ -70,66 +66,98 @@ impl ClaudeCodeProvider {
}
});
let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let (thinking_tx, mut thinking_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let (activity_tx, mut activity_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let (msg_tx, msg_rx) = std::sync::mpsc::channel::<Message>();
let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
// Attempt up to 2 times: first try, then retry after OAuth refresh.
for attempt in 0..2 {
let message = user_message.to_string();
let cwd = project_root.to_string();
let resume_id = session_id.map(|s| s.to_string());
let sys_prompt = system_prompt.map(|s| s.to_string());
let cancelled_inner = cancelled.clone();
let auth_failed = Arc::new(AtomicBool::new(false));
let auth_failed_clone = auth_failed.clone();
let pty_handle = tokio::task::spawn_blocking(move || {
run_pty_session(
&message,
&cwd,
resume_id.as_deref(),
sys_prompt.as_deref(),
cancelled,
token_tx,
thinking_tx,
activity_tx,
msg_tx,
sid_tx,
)
});
let (token_tx, mut token_rx) =
tokio::sync::mpsc::unbounded_channel::<String>();
let (thinking_tx, mut thinking_rx) =
tokio::sync::mpsc::unbounded_channel::<String>();
let (activity_tx, mut activity_rx) =
tokio::sync::mpsc::unbounded_channel::<String>();
let (msg_tx, msg_rx) = std::sync::mpsc::channel::<Message>();
let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
loop {
tokio::select! {
msg = token_rx.recv() => match msg {
Some(t) => on_token(&t),
None => break,
},
msg = thinking_rx.recv() => if let Some(t) = msg {
on_thinking(&t);
},
msg = activity_rx.recv() => if let Some(name) = msg {
on_activity(&name);
},
let pty_handle = tokio::task::spawn_blocking(move || {
run_pty_session(
&message,
&cwd,
resume_id.as_deref(),
sys_prompt.as_deref(),
cancelled_inner,
auth_failed_clone,
token_tx,
thinking_tx,
activity_tx,
msg_tx,
sid_tx,
)
});
loop {
tokio::select! {
msg = token_rx.recv() => match msg {
Some(t) => on_token(&t),
None => break,
},
msg = thinking_rx.recv() => if let Some(t) = msg {
on_thinking(&t);
},
msg = activity_rx.recv() => if let Some(name) = msg {
on_activity(&name);
},
}
}
// Drain any remaining activity/thinking messages that were buffered
// when the token channel closed.
while let Ok(t) = thinking_rx.try_recv() {
on_thinking(&t);
}
while let Ok(name) = activity_rx.try_recv() {
on_activity(&name);
}
pty_handle
.await
.map_err(|e| format!("PTY task panicked: {e}"))??;
// Check if the PTY session failed due to expired OAuth token.
if auth_failed.load(Ordering::Relaxed) && attempt == 0 {
slog!("[oauth] Authentication failed, attempting token refresh");
match crate::llm::oauth::refresh_access_token().await {
Ok(()) => {
slog!("[oauth] Token refreshed, retrying request");
on_token("\n*Refreshing authentication token...*\n");
continue;
}
Err(e) => {
return Err(format!(
"OAuth session expired. Please run `claude login` to re-authenticate. ({e})"
));
}
}
}
let captured_session_id = sid_rx.await.ok();
slog!("[pty-debug] RECEIVED session_id: {:?}", captured_session_id);
let structured_messages: Vec<Message> = msg_rx.try_iter().collect();
return Ok(ClaudeCodeResult {
messages: structured_messages,
session_id: captured_session_id,
});
}
// Drain any remaining activity/thinking messages that were buffered when
// the token channel closed.
while let Ok(t) = thinking_rx.try_recv() {
on_thinking(&t);
}
// Drain any remaining activity messages that were buffered when the
// token channel closed. The select! loop breaks on token_rx → None,
// but activity_rx may still hold signals sent in the same instant.
while let Ok(name) = activity_rx.try_recv() {
on_activity(&name);
}
pty_handle
.await
.map_err(|e| format!("PTY task panicked: {e}"))??;
let captured_session_id = sid_rx.await.ok();
slog!("[pty-debug] RECEIVED session_id: {:?}", captured_session_id);
let structured_messages: Vec<Message> = msg_rx.try_iter().collect();
Ok(ClaudeCodeResult {
messages: structured_messages,
session_id: captured_session_id,
})
// Should never reach here, but just in case.
Err("Authentication failed after retry".to_string())
}
}
@@ -152,6 +180,7 @@ fn run_pty_session(
resume_session_id: Option<&str>,
_system_prompt: Option<&str>,
cancelled: Arc<AtomicBool>,
auth_failed: Arc<AtomicBool>,
token_tx: tokio::sync::mpsc::UnboundedSender<String>,
thinking_tx: tokio::sync::mpsc::UnboundedSender<String>,
activity_tx: tokio::sync::mpsc::UnboundedSender<String>,
@@ -278,6 +307,7 @@ fn run_pty_session(
&activity_tx,
&msg_tx,
&mut sid_tx,
&auth_failed,
)
{
got_result = true;
@@ -304,6 +334,7 @@ fn run_pty_session(
&activity_tx,
&msg_tx,
&mut sid_tx,
&auth_failed,
);
}
}
@@ -341,6 +372,7 @@ fn run_pty_session(
///
/// Returns `true` if a `result` event was received (signals session completion).
/// Captures the session ID from the first event that carries it.
/// Sets `auth_failed` to `true` if an `authentication_failed` error is detected.
fn process_json_event(
json: &serde_json::Value,
token_tx: &tokio::sync::mpsc::UnboundedSender<String>,
@@ -348,6 +380,7 @@ fn process_json_event(
activity_tx: &tokio::sync::mpsc::UnboundedSender<String>,
msg_tx: &std::sync::mpsc::Sender<Message>,
sid_tx: &mut Option<tokio::sync::oneshot::Sender<String>>,
auth_failed: &AtomicBool,
) -> bool {
let event_type = match json.get("type").and_then(|t| t.as_str()) {
Some(t) => t,
@@ -364,6 +397,12 @@ fn process_json_event(
}
}
// Detect authentication_failed at the top level of any event.
if json.get("error").and_then(|e| e.as_str()) == Some("authentication_failed") {
slog!("[pty-debug] Detected authentication_failed error");
auth_failed.store(true, Ordering::Relaxed);
}
match event_type {
"stream_event" => {
if let Some(event) = json.get("event") {
@@ -916,7 +955,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx_opt
&mut sid_tx_opt,
&AtomicBool::new(false),
));
}
@@ -931,7 +971,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
}
@@ -946,7 +987,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
}
@@ -961,7 +1003,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
}
@@ -976,7 +1019,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
}
@@ -986,7 +1030,7 @@ mod tests {
let (sid_tx, mut sid_rx) = tokio::sync::oneshot::channel::<String>();
let mut sid_tx_opt = Some(sid_tx);
let json = json!({"type": "system", "session_id": "sess-abc-123"});
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt);
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt, &AtomicBool::new(false));
// sid_tx should have been consumed
assert!(sid_tx_opt.is_none());
let received = sid_rx.try_recv().unwrap();
@@ -999,7 +1043,7 @@ mod tests {
let (sid_tx, _sid_rx) = tokio::sync::oneshot::channel::<String>();
let mut sid_tx_opt = Some(sid_tx);
let json = json!({"type": "system"});
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt);
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt, &AtomicBool::new(false));
// sid_tx should still be present since no session_id in event
assert!(sid_tx_opt.is_some());
}
@@ -1023,7 +1067,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(tok_tx);
let tokens: Vec<String> = {
@@ -1058,7 +1103,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(act_tx);
let activities: Vec<String> = {
@@ -1091,7 +1137,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(act_tx);
let activities: Vec<String> = {
@@ -1124,7 +1171,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(act_tx);
let activities: Vec<String> = {
@@ -1154,7 +1202,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(act_tx);
let activities: Vec<String> = {
@@ -1183,7 +1232,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(msg_tx);
let msgs: Vec<Message> = msg_rx.try_iter().collect();
@@ -1207,7 +1257,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(msg_tx);
let msgs: Vec<Message> = msg_rx.try_iter().collect();
@@ -1230,7 +1281,8 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(msg_tx);
let msgs: Vec<Message> = msg_rx.try_iter().collect();
@@ -1248,13 +1300,61 @@ mod tests {
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx
&mut sid_tx,
&AtomicBool::new(false),
));
drop(msg_tx);
let msgs: Vec<Message> = msg_rx.try_iter().collect();
assert!(msgs.is_empty());
}
#[test]
fn process_json_event_detects_authentication_failed() {
let (tok_tx, _tok_rx, thi_tx, _thi_rx, act_tx, _act_rx, msg_tx, _msg_rx) = make_channels();
let mut sid_tx = None::<tokio::sync::oneshot::Sender<String>>;
let auth_failed = AtomicBool::new(false);
let json = json!({
"type": "assistant",
"error": "authentication_failed",
"message": {
"content": [{"type": "text", "text": "Failed to authenticate."}]
}
});
assert!(!process_json_event(
&json,
&tok_tx,
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx,
&auth_failed,
));
assert!(auth_failed.load(Ordering::Relaxed));
}
#[test]
fn process_json_event_no_auth_failed_for_normal_events() {
let (tok_tx, _tok_rx, thi_tx, _thi_rx, act_tx, _act_rx, msg_tx, _msg_rx) = make_channels();
let mut sid_tx = None::<tokio::sync::oneshot::Sender<String>>;
let auth_failed = AtomicBool::new(false);
let json = json!({
"type": "assistant",
"message": {
"content": [{"type": "text", "text": "Hello!"}]
}
});
assert!(!process_json_event(
&json,
&tok_tx,
&thi_tx,
&act_tx,
&msg_tx,
&mut sid_tx,
&auth_failed,
));
assert!(!auth_failed.load(Ordering::Relaxed));
}
#[test]
fn claude_code_provider_new() {
let _provider = ClaudeCodeProvider::new();