//! Server-side `/crdt-sync` WebSocket handler.

use bft_json_crdt::json_crdt::SignedOp;
use futures::{SinkExt, StreamExt};
use poem::handler;
use poem::http::StatusCode;
use poem::web::Data;
use poem::web::Query;
use poem::web::websocket::{Message as WsMessage, WebSocket};
use serde::Deserialize;
use std::sync::Arc;

use crate::crdt_snapshot;
use crate::crdt_state;
use crate::crdt_wire;
use crate::http::context::AppContext;
use crate::node_identity;
use crate::slog;
use crate::slog_warn;

use super::auth::{REQUIRE_TOKEN, trusted_keys, validate_join_token};
use super::dispatch::{handle_incoming_binary, handle_incoming_text};
use super::wire::{AuthMessage, ChallengeMessage, SyncMessage};
use super::{AUTH_TIMEOUT_SECS, PING_INTERVAL_SECS, PONG_TIMEOUT_SECS};

// ── Server-side WebSocket handler ───────────────────────────────────

/// Query parameters accepted on the `/crdt-sync` WebSocket upgrade request.
#[derive(Deserialize)]
struct SyncQueryParams {
    /// Optional bearer token. Required when the server is in token-required mode.
    token: Option<String>,
}

/// WebSocket handler for CRDT peer synchronisation.
///
/// Accepts an optional `?token=` query parameter. When the server is
/// configured with `crdt_require_token = true`, a valid token must be
/// supplied or the upgrade is rejected with HTTP 401. When the server is
/// in open-access mode (the default), a token is optional but still
/// validated if present.
#[handler]
pub async fn crdt_sync_handler(
    ws: WebSocket,
    _ctx: Data<&Arc<AppContext>>,
    remote_addr: &poem::web::RemoteAddr,
    Query(params): Query<SyncQueryParams>,
) -> poem::Response {
    // ── Bearer-token check (pre-upgrade) ────────────────────────────
    let require_token = REQUIRE_TOKEN.get().copied().unwrap_or(false);
    match &params.token {
        Some(t) => {
            if !validate_join_token(t) {
                slog!("[crdt-sync] Rejected connection: invalid or expired token");
                return poem::Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body("invalid or expired token");
            }
        }
        None if require_token => {
            slog!("[crdt-sync] Rejected connection: token required but not provided");
            return poem::Response::builder()
                .status(StatusCode::UNAUTHORIZED)
                .body("token required");
        }
        None => {}
    }

    // ── WebSocket upgrade ────────────────────────────────────────────
    use poem::IntoResponse as _;
    let peer_addr = remote_addr.to_string();
    ws.on_upgrade(move |socket| async move {
        let (mut sink, mut stream) = socket.split();
        slog!("[crdt-sync] Peer connected, starting auth handshake");

        // ── Step 1: Send challenge to the connecting peer ─────────
        let challenge = node_identity::generate_challenge();
        let challenge_msg = ChallengeMessage {
            r#type: "challenge".to_string(),
            nonce: challenge.clone(),
        };
        let challenge_json = match serde_json::to_string(&challenge_msg) {
            Ok(j) => j,
            Err(_) => return,
        };
        if sink.send(WsMessage::Text(challenge_json)).await.is_err() {
            return;
        }

        // ── Step 2: Await auth reply within timeout ───────────────
        let auth_result = tokio::time::timeout(
            std::time::Duration::from_secs(AUTH_TIMEOUT_SECS),
            stream.next(),
        )
        .await;
        let auth_text = match auth_result {
            Ok(Some(Ok(WsMessage::Text(text)))) => text,
            Ok(_) | Err(_) => {
                // Timeout or connection closed before auth reply.
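                // Close code 4001 (private-use range) lets the client tell a
                // handshake timeout apart from an auth rejection, which closes
                // with 4002 in close_with_auth_failed below.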
slog!("[crdt-sync] Auth timeout or connection lost during handshake"); let _ = sink .send(WsMessage::Close(Some(( poem::web::websocket::CloseCode::from(4001), "auth_timeout".to_string(), )))) .await; let _ = sink.close().await; return; } }; // ── Step 3: Verify auth reply ───────────────────────────── let auth_msg: AuthMessage = match serde_json::from_str(&auth_text) { Ok(m) => m, Err(_) => { slog!("[crdt-sync] Invalid auth message from peer"); close_with_auth_failed(&mut sink).await; return; } }; // Verify signature AND check allow-list. let sig_valid = node_identity::verify_challenge(&auth_msg.pubkey_hex, &challenge, &auth_msg.signature_hex); let key_trusted = trusted_keys().iter().any(|k| k == &auth_msg.pubkey_hex); if !sig_valid || !key_trusted { slog!("[crdt-sync] Auth failed for peer (sig_valid={sig_valid}, key_trusted={key_trusted})"); close_with_auth_failed(&mut sink).await; return; } slog!( "[crdt-sync] Peer authenticated: {:.12}…", &auth_msg.pubkey_hex ); // ── Auth passed — proceed with CRDT sync ────────────────── // v2 protocol: send our vector clock so the peer can compute the delta. let our_clock = crdt_state::our_vector_clock().unwrap_or_default(); let clock_msg = SyncMessage::Clock { clock: our_clock }; if let Ok(json) = serde_json::to_string(&clock_msg) && sink.send(WsMessage::Text(json)).await.is_err() { return; } // Wait for the peer's first sync message to determine protocol version. let first_msg = tokio::time::timeout( std::time::Duration::from_secs(AUTH_TIMEOUT_SECS), wait_for_sync_text(&mut stream, &mut sink), ) .await; match first_msg { Ok(Some(SyncMessage::Clock { clock: peer_clock })) => { // v2 peer — if we have a snapshot and the peer has an empty // clock (new node), send the snapshot first for onboarding. if peer_clock.is_empty() && let Some(snapshot) = crdt_snapshot::latest_snapshot() { let snap_msg = crdt_snapshot::SnapshotMessage::Snapshot(snapshot); if let Ok(json) = serde_json::to_string(&snap_msg) { if sink.send(WsMessage::Text(json)).await.is_err() { return; } slog!("[crdt-sync] Sent snapshot to new node for onboarding"); } } // Send only the ops the peer is missing. let delta = crdt_state::ops_since(&peer_clock).unwrap_or_default(); slog!( "[crdt-sync] v2 delta sync: sending {} ops (peer missing)", delta.len() ); let msg = SyncMessage::Bulk { ops: delta }; if let Ok(json) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(json)).await.is_err() { return; } } Ok(Some(SyncMessage::Bulk { ops })) => { // v1 peer — apply their bulk and send our full bulk. let mut applied = 0u64; for op_json in &ops { if let Ok(signed_op) = serde_json::from_str::(op_json) && crdt_state::apply_remote_op(signed_op) { applied += 1; } } slog!( "[crdt-sync] v1 bulk sync: received {} ops, applied {applied}", ops.len() ); if let Some(all) = crdt_state::all_ops_json() { let msg = SyncMessage::Bulk { ops: all }; if let Ok(json) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(json)).await.is_err() { return; } } } Ok(Some(SyncMessage::Op { op })) => { // Single op before negotiation — treat as v1. if let Ok(signed_op) = serde_json::from_str::(&op) { crdt_state::apply_remote_op(signed_op); } if let Some(all) = crdt_state::all_ops_json() { let msg = SyncMessage::Bulk { ops: all }; if let Ok(json) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(json)).await.is_err() { return; } } } _ => { // Timeout or error — send full bulk as fallback. 
slog!("[crdt-sync] No sync message from peer; sending full bulk as fallback"); if let Some(all) = crdt_state::all_ops_json() { let msg = SyncMessage::Bulk { ops: all }; if let Ok(json) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(json)).await.is_err() { return; } } } } // Bulk-delta phase complete — signal the peer that we are ready for // real-time op streaming. if let Ok(json) = serde_json::to_string(&SyncMessage::Ready) && sink.send(WsMessage::Text(json)).await.is_err() { return; } // Subscribe to new local ops. let Some(mut op_rx) = crdt_state::subscribe_ops() else { return; }; // Buffer for locally-generated ops produced before the peer's `ready` // arrives. Flushed in-order once the peer signals catch-up. let mut peer_ready = false; let mut op_buffer: Vec = Vec::new(); // ── Keepalive state ─────────────────────────────────────────── let mut pong_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(PONG_TIMEOUT_SECS); let mut ping_ticker = tokio::time::interval_at( tokio::time::Instant::now() + std::time::Duration::from_secs(PING_INTERVAL_SECS), std::time::Duration::from_secs(PING_INTERVAL_SECS), ); loop { tokio::select! { // Send periodic Ping and enforce Pong timeout. _ = ping_ticker.tick() => { if tokio::time::Instant::now() >= pong_deadline { slog_warn!( "[crdt-sync] No pong from peer {} in {}s; disconnecting", peer_addr, PONG_TIMEOUT_SECS ); break; } if sink.send(WsMessage::Ping(vec![])).await.is_err() { break; } } // Forward new local ops to the peer encoded via the wire codec. result = op_rx.recv() => { match result { Ok(signed_op) => { if peer_ready { let bytes = crdt_wire::encode(&signed_op); if sink.send(WsMessage::Binary(bytes)).await.is_err() { break; } } else { // Buffer until the peer signals ready. op_buffer.push(signed_op); } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { // The peer cannot keep up; disconnect so it can // reconnect and receive a fresh bulk state dump. slog!("[crdt-sync] Slow peer lagged {n} ops; disconnecting"); break; } Err(_) => break, } } // Receive ops from the peer. frame = stream.next() => { match frame { Some(Ok(WsMessage::Pong(_))) => { // Reset the pong deadline on every Pong received. pong_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(PONG_TIMEOUT_SECS); } Some(Ok(WsMessage::Ping(data))) => { // Respond to peer's Ping so the peer's keepalive passes. let _ = sink.send(WsMessage::Pong(data)).await; } Some(Ok(WsMessage::Text(text))) => { // Check for the ready signal before other text frames. if let Ok(SyncMessage::Ready) = serde_json::from_str(&text) { peer_ready = true; slog!("[crdt-sync] Peer ready; flushing {} buffered ops", op_buffer.len()); let mut flush_ok = true; for op in op_buffer.drain(..) { let bytes = crdt_wire::encode(&op); if sink.send(WsMessage::Binary(bytes)).await.is_err() { flush_ok = false; break; } } if !flush_ok { break; } } else { // Bulk state dump, legacy op frame, or clock frame. handle_incoming_text(&text); } } Some(Ok(WsMessage::Binary(bytes))) => { // Real-time op encoded via wire codec — applied immediately // regardless of our own ready state. handle_incoming_binary(&bytes); } Some(Ok(WsMessage::Close(_))) | None => break, _ => {} } } } } slog!("[crdt-sync] Peer disconnected"); }) .into_response() } /// Wait for the next text-frame sync message from the peer, handling Ping/Pong /// transparently. /// /// Returns `None` on connection close or read error. 
async fn wait_for_sync_text(
    stream: &mut futures::stream::SplitStream<poem::web::websocket::WebSocketStream>,
    sink: &mut futures::stream::SplitSink<poem::web::websocket::WebSocketStream, WsMessage>,
) -> Option<SyncMessage> {
    loop {
        match stream.next().await {
            Some(Ok(WsMessage::Text(text))) => {
                return serde_json::from_str(&text).ok();
            }
            Some(Ok(WsMessage::Ping(data))) => {
                let _ = sink.send(WsMessage::Pong(data)).await;
            }
            Some(Ok(WsMessage::Pong(_))) => continue,
            _ => return None,
        }
    }
}

/// Close the WebSocket with a generic `auth_failed` reason.
///
/// The close reason is intentionally the same for all auth failures
/// (bad signature, untrusted key, malformed message) to avoid leaking
/// which check failed.
async fn close_with_auth_failed(
    sink: &mut futures::stream::SplitSink<poem::web::websocket::WebSocketStream, WsMessage>,
) {
    let _ = sink
        .send(WsMessage::Close(Some((
            poem::web::websocket::CloseCode::from(4002),
            "auth_failed".to_string(),
        ))))
        .await;
    let _ = sink.close().await;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug)]
    #[allow(dead_code)]
    enum AuthListenerResult {
        Authenticated(String),
        AuthFailed(String),
        AuthTimeout,
        ConnectionLost,
        PeerClosedEarly(Option<String>),
    }

    #[test]
    fn peer_receives_op_encoded_via_wire_codec() {
        use bft_json_crdt::json_crdt::BaseCrdt;
        use bft_json_crdt::keypair::make_keypair;
        use bft_json_crdt::op::ROOT_ID;
        use serde_json::json;

        use crate::crdt_state::PipelineDoc;
        use crate::crdt_wire;

        let kp = make_keypair();
        let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
        let item: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "506_story_lifecycle_test",
            "stage": "1_backlog",
            "name": "Lifecycle Test",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);

        // Simulate what the broadcast handler does: encode via wire codec.
        let bytes = crdt_wire::encode(&op);

        // The bytes must be a versioned JSON envelope, not a SyncMessage wrapper.
        let text = std::str::from_utf8(&bytes).expect("wire output is valid UTF-8");
        assert!(
            text.contains("\"v\":1"),
            "wire codec version tag must be present: {text}"
        );
        assert!(
            !text.contains("\"type\":\"op\""),
            "must not be wrapped in SyncMessage: {text}"
        );

        // The receiving peer can decode and apply the op.
        let decoded = crdt_wire::decode(&bytes).expect("decode must succeed");
        assert_eq!(op, decoded);
    }

    #[tokio::test]
    async fn multiple_peers_all_receive_broadcast_op() {
        use bft_json_crdt::json_crdt::BaseCrdt;
        use bft_json_crdt::keypair::make_keypair;
        use bft_json_crdt::op::ROOT_ID;
        use serde_json::json;
        use tokio::sync::broadcast;

        use crate::crdt_state::PipelineDoc;
        use crate::crdt_wire;

        // Create a broadcast channel (analogous to SYNC_TX).
        let (tx, _) = broadcast::channel::<SignedOp>(16);
        let mut rx_peer1 = tx.subscribe();
        let mut rx_peer2 = tx.subscribe();

        let kp = make_keypair();
        let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
        let item: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "506_story_multi_peer_test",
            "stage": "1_backlog",
            "name": "Multi-Peer Test",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);

        // Broadcast one op.
        tx.send(op.clone()).expect("send must succeed");

        // Both peers receive the same op.
        let received1 = rx_peer1.recv().await.expect("peer 1 must receive");
        let received2 = rx_peer2.recv().await.expect("peer 2 must receive");
        assert_eq!(received1, op);
        assert_eq!(received2, op);

        // Both encode identically via wire codec.
        let bytes1 = crdt_wire::encode(&received1);
        let bytes2 = crdt_wire::encode(&received2);
        assert_eq!(bytes1, bytes2, "wire-encoded bytes must be identical");
    }

    #[test]
    fn disconnected_peer_does_not_panic() {
        use bft_json_crdt::json_crdt::BaseCrdt;
        use bft_json_crdt::keypair::make_keypair;
        use bft_json_crdt::op::ROOT_ID;
        use serde_json::json;
        use tokio::sync::broadcast;

        use crate::crdt_state::PipelineDoc;

        let (tx, rx) = broadcast::channel::<SignedOp>(16);
        // Drop the receiver to simulate a peer that disconnected.
        drop(rx);

        let kp = make_keypair();
        let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
        let item: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "506_story_disconnect_test",
            "stage": "1_backlog",
            "name": "Disconnect Test",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);

        // Sending to a channel with no receivers returns an error; must not panic.
        let _ = tx.send(op);
    }

    #[tokio::test]
    async fn lagged_peer_gets_lagged_error() {
        use bft_json_crdt::json_crdt::BaseCrdt;
        use bft_json_crdt::keypair::make_keypair;
        use bft_json_crdt::op::ROOT_ID;
        use serde_json::json;
        use tokio::sync::broadcast;

        use crate::crdt_state::PipelineDoc;

        // Tiny capacity so we can trigger Lagged easily.
        let (tx, mut rx) = broadcast::channel::<SignedOp>(2);

        let kp = make_keypair();
        let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
        let item: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "506_story_lag_test",
            "stage": "1_backlog",
            "name": "Lag Test",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op1 = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);
        crdt.apply(op1.clone());

        // Overflow the tiny buffer by sending more ops than the capacity.
        let op2 = crdt.doc.items[0]
            .stage
            .set("2_current".to_string())
            .sign(&kp);
        crdt.apply(op2.clone());
        let op3 = crdt.doc.items[0].stage.set("3_qa".to_string()).sign(&kp);
        crdt.apply(op3.clone());
        let op4 = crdt.doc.items[0].stage.set("4_merge".to_string()).sign(&kp);
        crdt.apply(op4.clone());

        // Send more ops than the channel capacity without consuming.
        let _ = tx.send(op1);
        let _ = tx.send(op2);
        let _ = tx.send(op3);
        let _ = tx.send(op4);

        // The slow peer should now see a Lagged error on next recv.
        // Consume until we hit Lagged or run out.
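        // (tokio's broadcast channel reports Lagged(n) once, with n the number
        // of dropped messages, then resumes delivery from the oldest op still
        // held in the buffer.)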
        let mut got_lagged = false;
        for _ in 0..10 {
            match rx.recv().await {
                Err(broadcast::error::RecvError::Lagged(_)) => {
                    got_lagged = true;
                    break;
                }
                Ok(_) => continue,
                Err(broadcast::error::RecvError::Closed) => break,
            }
        }
        assert!(
            got_lagged,
            "slow peer must receive a Lagged error when channel overflows"
        );
    }

    #[tokio::test]
    async fn e2e_convergence_two_websocket_nodes() {
        use bft_json_crdt::json_crdt::{BaseCrdt, CrdtNode, JsonValue as JV, OpState};
        use bft_json_crdt::keypair::make_keypair;
        use bft_json_crdt::op::ROOT_ID;
        use futures::{SinkExt, StreamExt};
        use serde_json::json;
        use std::sync::{Arc, Mutex};
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        use crate::crdt_state::PipelineDoc;

        // ── Node A: build local state ──────────────────────────────────────
        let kp_a = make_keypair();
        let mut crdt_a = BaseCrdt::<PipelineDoc>::new(&kp_a);
        let item: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "508_e2e_convergence",
            "stage": "2_current",
            "name": "E2E Convergence Test",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op1 = crdt_a.doc.items.insert(ROOT_ID, item).sign(&kp_a);
        crdt_a.apply(op1.clone());

        // Serialise A's full state as a bulk message.
        let op1_json = serde_json::to_string(&op1).unwrap();
        let bulk_msg = SyncMessage::Bulk {
            ops: vec![op1_json],
        };
        let bulk_wire = serde_json::to_string(&bulk_msg).unwrap();

        // ── Start Node A's WebSocket server on a random port ───────────────
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let bulk_to_send = bulk_wire.clone();
        let received_by_a: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(vec![]));
        let received_by_a_clone = received_by_a.clone();
        tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let ws_stream = accept_async(tcp_stream).await.unwrap();
            let (mut sink, mut stream) = ws_stream.split();
            // Send bulk state to the connecting peer.
            sink.send(TMsg::Text(bulk_to_send.into())).await.unwrap();
            // Also listen for ops sent by the peer.
            if let Some(Ok(TMsg::Text(txt))) = stream.next().await {
                received_by_a_clone.lock().unwrap().push(txt.to_string());
            }
        });

        // ── Node B: connect to Node A and exchange state ───────────────────
        let kp_b = make_keypair();
        let mut crdt_b = BaseCrdt::<PipelineDoc>::new(&kp_b);
        let url = format!("ws://{addr}");
        let (ws_b, _) = connect_async(&url).await.unwrap();
        let (mut sink_b, mut stream_b) = ws_b.split();

        // Node B receives bulk from A.
        if let Some(Ok(TMsg::Text(txt))) = stream_b.next().await {
            let msg: SyncMessage = serde_json::from_str(txt.as_str()).unwrap();
            match msg {
                SyncMessage::Bulk { ops } => {
                    for op_str in &ops {
                        let signed: bft_json_crdt::json_crdt::SignedOp =
                            serde_json::from_str(op_str).unwrap();
                        let r = crdt_b.apply(signed);
                        assert!(r == OpState::Ok || r == OpState::AlreadySeen);
                    }
                }
                _ => panic!("Expected Bulk from Node A"),
            }
        }

        // Node B also creates a new op and sends it to A.
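        // This exercises the reverse direction: ops must flow B → A over the
        // same socket, not just A → B.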
        let item_b: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "508_e2e_convergence_b",
            "stage": "1_backlog",
            "name": "E2E Convergence B",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op_b1 = crdt_b.doc.items.insert(ROOT_ID, item_b).sign(&kp_b);
        crdt_b.apply(op_b1.clone());
        let op_b1_json = serde_json::to_string(&op_b1).unwrap();
        let msg_to_a = SyncMessage::Op { op: op_b1_json };
        sink_b
            .send(TMsg::Text(serde_json::to_string(&msg_to_a).unwrap().into()))
            .await
            .unwrap();

        // Wait a moment for Node A to process.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // ── Assert convergence ─────────────────────────────────────────────
        // Node B received Node A's item.
        assert_eq!(
            crdt_b.doc.items.view().len(),
            2,
            "Node B must see both items after sync"
        );
        let has_a_item = crdt_b
            .doc
            .items
            .view()
            .iter()
            .any(|item| item.story_id.view() == JV::String("508_e2e_convergence".to_string()));
        assert!(has_a_item, "Node B must have Node A's item");

        // Node A received Node B's op via the WebSocket.
        let a_received = received_by_a.lock().unwrap();
        assert!(
            !a_received.is_empty(),
            "Node A must have received an op from Node B"
        );
    }

    #[tokio::test]
    async fn e2e_partition_healing_websocket() {
        use bft_json_crdt::json_crdt::{BaseCrdt, CrdtNode, JsonValue as JV};
        use bft_json_crdt::keypair::make_keypair;
        use bft_json_crdt::op::ROOT_ID;
        use futures::{SinkExt, StreamExt};
        use serde_json::json;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        use crate::crdt_state::PipelineDoc;

        // ── Phase 1: Both nodes start with op_a1 (before partition) ───────
        let kp_a = make_keypair();
        let kp_b = make_keypair();
        let mut crdt_a = BaseCrdt::<PipelineDoc>::new(&kp_a);
        let mut crdt_b = BaseCrdt::<PipelineDoc>::new(&kp_b);

        let item_a: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "508_heal_a",
            "stage": "1_backlog",
            "name": "Heal A",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op_a1 = crdt_a.doc.items.insert(ROOT_ID, item_a).sign(&kp_a);
        crdt_a.apply(op_a1.clone());
        // B also starts with op_a1 (shared state before partition).
        crdt_b.apply(op_a1.clone());

        // ── Phase 2: Partition — each side mutates independently ──────────
        // A advances its story stage.
        let op_a2 = crdt_a.doc.items[0]
            .stage
            .set("2_current".to_string())
            .sign(&kp_a);
        crdt_a.apply(op_a2.clone());

        // B inserts a new story that A doesn't know about yet.
        let item_b: bft_json_crdt::json_crdt::JsonValue = json!({
            "story_id": "508_heal_b",
            "stage": "1_backlog",
            "name": "Heal B",
            "agent": "",
            "retry_count": 0.0,
            "blocked": false,
            "depends_on": "",
            "claimed_by": "",
            "claimed_at": 0.0,
        })
        .into();
        let op_b1 = crdt_b.doc.items.insert(ROOT_ID, item_b).sign(&kp_b);
        crdt_b.apply(op_b1.clone());

        // Collect B's full state as bulk (what it will send on reconnect).
        let b_ops: Vec<String> = [&op_a1, &op_b1]
            .iter()
            .map(|op| serde_json::to_string(op).unwrap())
            .collect();
        let b_bulk_wire = serde_json::to_string(&SyncMessage::Bulk { ops: b_ops }).unwrap();

        // Collect A's full state as bulk (what it will send on reconnect).
        let a_ops: Vec<String> = [&op_a1, &op_a2]
            .iter()
            .map(|op| serde_json::to_string(op).unwrap())
            .collect();
        let a_bulk_wire = serde_json::to_string(&SyncMessage::Bulk { ops: a_ops }).unwrap();

        // ── Phase 3: Reconnect — use a real WebSocket to exchange bulk ────
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let a_bulk_to_send = a_bulk_wire.clone();
        let a_received_bulk: std::sync::Arc<std::sync::Mutex<Option<String>>> =
            std::sync::Arc::new(std::sync::Mutex::new(None));
        let a_received_clone = a_received_bulk.clone();
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            let (mut sink, mut stream) = ws.split();
            // A sends its bulk state.
            sink.send(TMsg::Text(a_bulk_to_send.into())).await.unwrap();
            // A receives B's bulk state.
            if let Some(Ok(TMsg::Text(txt))) = stream.next().await {
                *a_received_clone.lock().unwrap() = Some(txt.to_string());
            }
        });

        // B connects, exchanges bulk state.
        let (ws_b, _) = connect_async(format!("ws://{addr}")).await.unwrap();
        let (mut sink_b, mut stream_b) = ws_b.split();

        // B receives A's bulk and applies it.
        if let Some(Ok(TMsg::Text(txt))) = stream_b.next().await {
            let msg: SyncMessage = serde_json::from_str(txt.as_str()).unwrap();
            if let SyncMessage::Bulk { ops } = msg {
                for op_str in &ops {
                    let signed: bft_json_crdt::json_crdt::SignedOp =
                        serde_json::from_str(op_str).unwrap();
                    let _ = crdt_b.apply(signed);
                }
            }
        }

        // B sends its bulk state to A.
        sink_b.send(TMsg::Text(b_bulk_wire.into())).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Apply A's received ops into crdt_a.
        if let Some(bulk_str) = a_received_bulk.lock().unwrap().take() {
            let msg: SyncMessage = serde_json::from_str(&bulk_str).unwrap();
            if let SyncMessage::Bulk { ops } = msg {
                for op_str in &ops {
                    let signed: bft_json_crdt::json_crdt::SignedOp =
                        serde_json::from_str(op_str).unwrap();
                    let _ = crdt_a.apply(signed);
                }
            }
        }

        // ── Assert convergence ─────────────────────────────────────────────
        // Both nodes must have 2 items.
        assert_eq!(
            crdt_a.doc.items.view().len(),
            2,
            "A must have 2 items after healing"
        );
        assert_eq!(
            crdt_b.doc.items.view().len(),
            2,
            "B must have 2 items after healing"
        );

        // A must see B's story.
        let b_story_on_a = crdt_a
            .doc
            .items
            .view()
            .iter()
            .any(|item| item.story_id.view() == JV::String("508_heal_b".to_string()));
        assert!(b_story_on_a, "A must have B's story after healing");

        // B must see A's stage advance.
        let a_story_on_b = crdt_b
            .doc
            .items
            .view()
            .iter()
            .any(|item| item.story_id.view() == JV::String("508_heal_a".to_string()));
        assert!(a_story_on_b, "B must have A's story after healing");

        // CRDT views must be byte-identical (convergence).
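        // Serialising the materialised views gives a strict byte-for-byte
        // equality check instead of comparing items field by field.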
        let view_a = serde_json::to_string(&CrdtNode::view(&crdt_a.doc.items)).unwrap();
        let view_b = serde_json::to_string(&CrdtNode::view(&crdt_b.doc.items)).unwrap();
        assert_eq!(
            view_a, view_b,
            "Both nodes must converge to identical state"
        );
    }

    async fn start_auth_listener(
        trusted_keys: Vec<String>,
    ) -> (
        std::net::SocketAddr,
        tokio::sync::oneshot::Receiver<AuthListenerResult>,
    ) {
        use tokio::net::TcpListener;
        use tokio_tungstenite::accept_async;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (result_tx, result_rx) = tokio::sync::oneshot::channel();

        tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let ws_stream = accept_async(tcp_stream).await.unwrap();
            let (mut sink, mut stream) = ws_stream.split();
            use tokio_tungstenite::tungstenite::Message as TMsg;

            // Step 1: Send challenge.
            let challenge = crate::node_identity::generate_challenge();
            let challenge_msg = super::ChallengeMessage {
                r#type: "challenge".to_string(),
                nonce: challenge.clone(),
            };
            let challenge_json = serde_json::to_string(&challenge_msg).unwrap();
            if sink.send(TMsg::Text(challenge_json.into())).await.is_err() {
                let _ = result_tx.send(AuthListenerResult::ConnectionLost);
                return;
            }

            // Step 2: Await auth reply (10s timeout).
            let auth_frame =
                tokio::time::timeout(std::time::Duration::from_secs(10), stream.next()).await;
            let auth_text = match auth_frame {
                Ok(Some(Ok(TMsg::Text(t)))) => t.to_string(),
                Ok(Some(Ok(TMsg::Close(reason)))) => {
                    let _ = result_tx.send(AuthListenerResult::PeerClosedEarly(
                        reason.map(|r| r.reason.to_string()),
                    ));
                    return;
                }
                _ => {
                    let _ = sink
                        .send(TMsg::Close(Some(
                            tokio_tungstenite::tungstenite::protocol::CloseFrame {
                                code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(4001),
                                reason: "auth_timeout".into(),
                            },
                        )))
                        .await;
                    let _ = result_tx.send(AuthListenerResult::AuthTimeout);
                    return;
                }
            };

            // Step 3: Verify.
            let auth_msg: super::AuthMessage = match serde_json::from_str(&auth_text) {
                Ok(m) => m,
                Err(_) => {
                    let _ = close_listener_auth_failed(&mut sink).await;
                    let _ = result_tx.send(AuthListenerResult::AuthFailed("bad_json".into()));
                    return;
                }
            };
            let sig_valid = crate::node_identity::verify_challenge(
                &auth_msg.pubkey_hex,
                &challenge,
                &auth_msg.signature_hex,
            );
            let key_trusted = trusted_keys.iter().any(|k| k == &auth_msg.pubkey_hex);
            if !sig_valid || !key_trusted {
                let _ = close_listener_auth_failed(&mut sink).await;
                let _ = result_tx.send(AuthListenerResult::AuthFailed(format!(
                    "sig_valid={sig_valid}, key_trusted={key_trusted}"
                )));
                return;
            }

            // Auth passed! Send a bulk state with one op to prove sync works.
            let kp = bft_json_crdt::keypair::make_keypair();
            let mut crdt =
                bft_json_crdt::json_crdt::BaseCrdt::<crate::crdt_state::PipelineDoc>::new(&kp);
            let item: bft_json_crdt::json_crdt::JsonValue = serde_json::json!({
                "story_id": "628_auth_test_item",
                "stage": "1_backlog",
                "name": "Auth Test",
                "agent": "",
                "retry_count": 0.0,
                "blocked": false,
                "depends_on": "",
                "claimed_by": "",
                "claimed_at": 0.0,
            })
            .into();
            let op = crdt
                .doc
                .items
                .insert(bft_json_crdt::op::ROOT_ID, item)
                .sign(&kp);
            let op_json = serde_json::to_string(&op).unwrap();
            let bulk = super::SyncMessage::Bulk { ops: vec![op_json] };
            let bulk_json = serde_json::to_string(&bulk).unwrap();
            let _ = sink.send(TMsg::Text(bulk_json.into())).await;
            let _ = result_tx.send(AuthListenerResult::Authenticated(auth_msg.pubkey_hex));
        });

        (addr, result_rx)
    }

    async fn close_listener_auth_failed(
        sink: &mut futures::stream::SplitSink<
            tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
            tokio_tungstenite::tungstenite::Message,
        >,
    ) {
        use tokio_tungstenite::tungstenite::protocol::CloseFrame;
        use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

        let _ = sink
            .send(tokio_tungstenite::tungstenite::Message::Close(Some(
                CloseFrame {
                    code: CloseCode::from(4002),
                    reason: "auth_failed".into(),
                },
            )))
            .await;
    }

    #[tokio::test]
    async fn auth_happy_path_handshake_and_sync() {
        use bft_json_crdt::keypair::make_keypair;
        use futures::{SinkExt, StreamExt};
        use tokio_tungstenite::connect_async;
        use tokio_tungstenite::tungstenite::Message as TMsg;

        let connector_kp = make_keypair();
        let connector_pubkey = crate::node_identity::public_key_hex(&connector_kp);

        // Start listener that trusts the connector's pubkey.
        let (addr, result_rx) = start_auth_listener(vec![connector_pubkey.clone()]).await;

        // Connect and do the handshake.
        let url = format!("ws://{addr}");
        let (ws, _) = connect_async(&url).await.unwrap();
        let (mut sink, mut stream) = ws.split();

        // Receive challenge.
        let challenge_frame = stream.next().await.unwrap().unwrap();
        let challenge_text = match challenge_frame {
            TMsg::Text(t) => t.to_string(),
            other => panic!("Expected text frame, got {other:?}"),
        };
        let challenge_msg: super::ChallengeMessage =
            serde_json::from_str(&challenge_text).unwrap();
        assert_eq!(challenge_msg.r#type, "challenge");
        assert_eq!(
            challenge_msg.nonce.len(),
            64,
            "Challenge must be 64 hex chars"
        );

        // Sign and reply.
        let sig = crate::node_identity::sign_challenge(&connector_kp, &challenge_msg.nonce);
        let auth_msg = super::AuthMessage {
            r#type: "auth".to_string(),
            pubkey_hex: connector_pubkey.clone(),
            signature_hex: sig,
        };
        let auth_json = serde_json::to_string(&auth_msg).unwrap();
        sink.send(TMsg::Text(auth_json.into())).await.unwrap();

        // After auth, we should receive a bulk sync message with at least one op.
        let bulk_frame = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("should receive bulk within 5s")
            .unwrap()
            .unwrap();
        let bulk_text = match bulk_frame {
            TMsg::Text(t) => t.to_string(),
            other => panic!("Expected bulk text frame, got {other:?}"),
        };
        let bulk_msg: super::SyncMessage = serde_json::from_str(&bulk_text).unwrap();
        match bulk_msg {
            super::SyncMessage::Bulk { ops } => {
                assert!(
                    !ops.is_empty(),
                    "Bulk sync must contain at least one op after successful auth"
                );
                // Verify we can deserialize the op.
                let _signed: bft_json_crdt::json_crdt::SignedOp =
                    serde_json::from_str(&ops[0]).unwrap();
            }
            _ => panic!("Expected Bulk message after auth"),
        }

        // Verify listener also reports success.
        let listener_result = result_rx.await.unwrap();
        match listener_result {
            AuthListenerResult::Authenticated(pubkey) => {
                assert_eq!(pubkey, connector_pubkey);
            }
            other => panic!("Expected Authenticated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn auth_untrusted_pubkey_rejected() {
        use bft_json_crdt::keypair::make_keypair;
        use futures::{SinkExt, StreamExt};
        use tokio_tungstenite::connect_async;
        use tokio_tungstenite::tungstenite::Message as TMsg;

        let connector_kp = make_keypair();
        let connector_pubkey = crate::node_identity::public_key_hex(&connector_kp);

        // Listener trusts a DIFFERENT key, not the connector's.
        let other_kp = make_keypair();
        let other_pubkey = crate::node_identity::public_key_hex(&other_kp);
        let (addr, result_rx) = start_auth_listener(vec![other_pubkey]).await;

        let url = format!("ws://{addr}");
        let (ws, _) = connect_async(&url).await.unwrap();
        let (mut sink, mut stream) = ws.split();

        // Receive challenge and sign with our (untrusted) key.
        let challenge_frame = stream.next().await.unwrap().unwrap();
        let challenge_text = match challenge_frame {
            TMsg::Text(t) => t.to_string(),
            _ => panic!("Expected text frame"),
        };
        let challenge_msg: super::ChallengeMessage =
            serde_json::from_str(&challenge_text).unwrap();
        let sig = crate::node_identity::sign_challenge(&connector_kp, &challenge_msg.nonce);
        let auth_msg = super::AuthMessage {
            r#type: "auth".to_string(),
            pubkey_hex: connector_pubkey,
            signature_hex: sig,
        };
        sink.send(TMsg::Text(serde_json::to_string(&auth_msg).unwrap().into()))
            .await
            .unwrap();

        // Should receive a close frame with auth_failed.
        let close_frame = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("should receive close within 5s");
        match close_frame {
            Some(Ok(TMsg::Close(Some(frame)))) => {
                assert_eq!(
                    &*frame.reason,
                    "auth_failed",
                    "Close reason must be 'auth_failed'"
                );
            }
            Some(Ok(TMsg::Close(None))) => {
                // Some implementations omit the close frame payload — that's
                // acceptable as long as no sync data was sent.
            }
            other => {
                // Connection dropped without a close frame is also acceptable.
                // The key assertion is below: no ops were exchanged.
                let _ = other;
            }
        }

        // Verify listener reports auth failure.
        let listener_result = result_rx.await.unwrap();
        match listener_result {
            AuthListenerResult::AuthFailed(reason) => {
                assert!(reason.contains("key_trusted=false"), "Reason: {reason}");
            }
            other => panic!("Expected AuthFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn auth_bad_signature_rejected() {
        use bft_json_crdt::keypair::make_keypair;
        use futures::{SinkExt, StreamExt};
        use tokio_tungstenite::connect_async;
        use tokio_tungstenite::tungstenite::Message as TMsg;

        let legitimate_kp = make_keypair();
        let legitimate_pubkey = crate::node_identity::public_key_hex(&legitimate_kp);
        // A different keypair that will sign the challenge (wrong key).
        let impersonator_kp = make_keypair();

        // Listener trusts the legitimate pubkey.
        let (addr, result_rx) = start_auth_listener(vec![legitimate_pubkey.clone()]).await;

        let url = format!("ws://{addr}");
        let (ws, _) = connect_async(&url).await.unwrap();
        let (mut sink, mut stream) = ws.split();

        // Receive challenge.
        let challenge_frame = stream.next().await.unwrap().unwrap();
        let challenge_text = match challenge_frame {
            TMsg::Text(t) => t.to_string(),
            _ => panic!("Expected text frame"),
        };
        let challenge_msg: super::ChallengeMessage =
            serde_json::from_str(&challenge_text).unwrap();

        // Sign with the WRONG keypair but claim to be the legitimate pubkey.
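        // This models impersonation: presenting an allow-listed identity
        // without possessing its private key.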
        let bad_sig = crate::node_identity::sign_challenge(&impersonator_kp, &challenge_msg.nonce);
        let auth_msg = super::AuthMessage {
            r#type: "auth".to_string(),
            pubkey_hex: legitimate_pubkey,
            signature_hex: bad_sig,
        };
        sink.send(TMsg::Text(serde_json::to_string(&auth_msg).unwrap().into()))
            .await
            .unwrap();

        // Should be rejected.
        let close_frame = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("should receive close within 5s");
        match close_frame {
            Some(Ok(TMsg::Close(Some(frame)))) => {
                assert_eq!(
                    &*frame.reason,
                    "auth_failed",
                    "Close reason must be 'auth_failed' — same as untrusted key"
                );
            }
            _ => {
                // Connection closed is acceptable.
            }
        }

        // Verify listener reports auth failure with sig_valid=false.
        let listener_result = result_rx.await.unwrap();
        match listener_result {
            AuthListenerResult::AuthFailed(reason) => {
                assert!(reason.contains("sig_valid=false"), "Reason: {reason}");
            }
            other => panic!("Expected AuthFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn auth_replay_protection_fresh_nonces() {
        use futures::StreamExt;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        // Start a listener that sends challenges but doesn't complete auth.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (nonce_tx, mut nonce_rx) = tokio::sync::mpsc::channel::<String>(2);
        tokio::spawn(async move {
            for _ in 0..2 {
                let (tcp, _) = listener.accept().await.unwrap();
                let ws = accept_async(tcp).await.unwrap();
                let (mut sink, _stream) = ws.split();
                let challenge = crate::node_identity::generate_challenge();
                let msg = super::ChallengeMessage {
                    r#type: "challenge".to_string(),
                    nonce: challenge.clone(),
                };
                let json = serde_json::to_string(&msg).unwrap();
                let _ = sink.send(TMsg::Text(json.into())).await;
                let _ = nonce_tx.send(challenge).await;
            }
        });

        // Connect twice and collect the nonces.
        let mut nonces = Vec::new();
        for _ in 0..2 {
            let url = format!("ws://{addr}");
            let (ws, _) = connect_async(&url).await.unwrap();
            let (_sink, mut stream) = ws.split();
            let frame = stream.next().await.unwrap().unwrap();
            let text = match frame {
                TMsg::Text(t) => t.to_string(),
                _ => panic!("Expected text"),
            };
            let msg: super::ChallengeMessage = serde_json::from_str(&text).unwrap();
            nonces.push(msg.nonce);
            // Drop connection so listener accepts the next one.
            drop(stream);
        }

        // Also collect nonces from the listener side.
        let server_nonce_1 = nonce_rx.recv().await.unwrap();
        let server_nonce_2 = nonce_rx.recv().await.unwrap();
        assert_ne!(
            nonces[0], nonces[1],
            "Consecutive challenges must be different"
        );
        assert_ne!(
            server_nonce_1, server_nonce_2,
            "Server must generate a fresh nonce per accept"
        );
        assert_eq!(nonces[0], server_nonce_1, "Client/server nonces must match");
        assert_eq!(nonces[1], server_nonce_2, "Client/server nonces must match");
    }

    #[test]
    fn keepalive_constants_are_correct() {
        assert_eq!(
            super::super::PING_INTERVAL_SECS,
            30,
            "Ping interval must be 30 seconds"
        );
        assert_eq!(
            super::super::PONG_TIMEOUT_SECS,
            60,
            "Pong timeout must be 60 seconds"
        );
    }

    #[test]
    fn agent_mode_heartbeat_interval_unchanged() {
        assert_eq!(
            crate::agent_mode::SCAN_INTERVAL_SECS,
            15,
            "Agent-mode heartbeat interval must remain 15s"
        );
    }

    #[test]
    fn reconnect_backoff_constants_unchanged() {
        assert_eq!(
            super::super::RENDEZVOUS_ERROR_THRESHOLD,
            10,
            "Backoff threshold must still be 10"
        );
    }

    #[tokio::test]
    async fn server_sends_ping_to_peer_at_interval() {
        use futures::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        let ping_ms = 100u64;
        let timeout_ms = 400u64;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server task: keepalive sender with short intervals.
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            let (mut sink, mut stream) = ws.split();

            let mut pong_deadline =
                tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
            let mut ticker = tokio::time::interval_at(
                tokio::time::Instant::now() + Duration::from_millis(ping_ms),
                Duration::from_millis(ping_ms),
            );
            loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        if tokio::time::Instant::now() >= pong_deadline {
                            break;
                        }
                        if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
                            break;
                        }
                    }
                    frame = stream.next() => {
                        match frame {
                            Some(Ok(TMsg::Pong(_))) => {
                                pong_deadline = tokio::time::Instant::now()
                                    + Duration::from_millis(timeout_ms);
                            }
                            None | Some(Err(_)) => break,
                            _ => {}
                        }
                    }
                }
            }
        });

        let (ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
        let (_sink_c, mut stream_c) = ws_client.split();

        // Wait for more than one ping interval.
        tokio::time::sleep(Duration::from_millis(ping_ms * 2)).await;

        // Client should receive a Ping from the server.
        let frame = tokio::time::timeout(Duration::from_millis(200), stream_c.next()).await;
        let got_ping = matches!(frame, Ok(Some(Ok(TMsg::Ping(_)))));
        assert!(
            got_ping,
            "Client must receive a Ping frame from the server after the ping interval"
        );
    }

    #[tokio::test]
    async fn client_sends_ping_to_server_at_interval() {
        use futures::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        let ping_ms = 100u64;
        let timeout_ms = 400u64;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (ping_tx, ping_rx) = tokio::sync::oneshot::channel::<()>();

        // Server task: wait for the first Ping the client sends.
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            let (_sink, mut stream) = ws.split();
            loop {
                match stream.next().await {
                    Some(Ok(TMsg::Ping(_))) => {
                        let _ = ping_tx.send(());
                        break;
                    }
                    Some(Ok(_)) => continue,
                    _ => break,
                }
            }
        });

        let (ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
        let (mut sink_c, mut stream_c) = ws_client.split();

        // Client keepalive task.
        tokio::spawn(async move {
            let mut pong_deadline =
                tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
            let mut ticker = tokio::time::interval_at(
                tokio::time::Instant::now() + Duration::from_millis(ping_ms),
                Duration::from_millis(ping_ms),
            );
            loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        if tokio::time::Instant::now() >= pong_deadline {
                            break;
                        }
                        if sink_c.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
                            break;
                        }
                    }
                    frame = stream_c.next() => {
                        match frame {
                            Some(Ok(TMsg::Pong(_))) => {
                                pong_deadline = tokio::time::Instant::now()
                                    + Duration::from_millis(timeout_ms);
                            }
                            None | Some(Err(_)) => break,
                            _ => {}
                        }
                    }
                }
            }
        });

        let result = tokio::time::timeout(Duration::from_millis(ping_ms * 3), ping_rx).await;
        assert!(
            result.is_ok(),
            "Server must receive a Ping from the client after the ping interval"
        );
    }

    #[tokio::test]
    async fn keepalive_disconnects_when_pong_withheld() {
        use futures::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        let ping_ms = 100u64;
        let timeout_ms = 250u64;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (done_tx, done_rx) = tokio::sync::oneshot::channel::<bool>();

        // Server: sends Pings, never receives Pong (client swallows all).
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            let (mut sink, mut stream) = ws.split();

            let pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
            let mut ticker = tokio::time::interval_at(
                tokio::time::Instant::now() + Duration::from_millis(ping_ms),
                Duration::from_millis(ping_ms),
            );
            let timed_out = loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        if tokio::time::Instant::now() >= pong_deadline {
                            break true;
                        }
                        if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
                            break false;
                        }
                    }
                    frame = stream.next() => {
                        match frame {
                            Some(Ok(_)) => {} // swallow — no Pong sent
                            _ => break false,
                        }
                    }
                }
            };
            let _ = done_tx.send(timed_out);
        });

        // Client: connect but never respond to Pings.
        let (_ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();

        let result =
            tokio::time::timeout(Duration::from_millis(timeout_ms + ping_ms * 3), done_rx).await;
        let timed_out = result
            .expect("Server must report within expected wall-clock time")
            .expect("oneshot intact");
        assert!(
            timed_out,
            "Server must disconnect on keepalive timeout when Pong is withheld"
        );
    }

    #[tokio::test]
    async fn keepalive_connection_survives_with_pong_responses() {
        use futures::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        let ping_ms = 100u64;
        let timeout_ms = 250u64;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (result_tx, result_rx) = tokio::sync::oneshot::channel::<bool>();

        // Server: sends Pings, resets deadline on Pong.
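        // Mirrors the production keepalive loop in crdt_sync_handler, with the
        // 30s/60s constants shrunk to milliseconds so the test stays fast.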
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            let (mut sink, mut stream) = ws.split();

            let mut pong_deadline =
                tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
            let mut ticker = tokio::time::interval_at(
                tokio::time::Instant::now() + Duration::from_millis(ping_ms),
                Duration::from_millis(ping_ms),
            );
            let timed_out = loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        if tokio::time::Instant::now() >= pong_deadline {
                            break true;
                        }
                        if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
                            break false;
                        }
                    }
                    frame = stream.next() => {
                        match frame {
                            Some(Ok(TMsg::Pong(_))) => {
                                pong_deadline = tokio::time::Instant::now()
                                    + Duration::from_millis(timeout_ms);
                            }
                            None | Some(Err(_)) => break false, // clean close
                            _ => {}
                        }
                    }
                }
            };
            let _ = result_tx.send(timed_out);
        });

        let (ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
        let (mut sink_c, mut stream_c) = ws_client.split();

        // Client: respond to every Ping with Pong for several intervals.
        let respond_task = tokio::spawn(async move {
            while let Some(Ok(msg)) = stream_c.next().await {
                if let TMsg::Ping(data) = msg
                    && sink_c.send(TMsg::Pong(data)).await.is_err()
                {
                    break;
                }
            }
        });

        // Run for a few intervals, then drop the client.
        tokio::time::sleep(Duration::from_millis(ping_ms * 3)).await;
        respond_task.abort();

        let result = tokio::time::timeout(Duration::from_millis(200), result_rx).await;
        let timed_out = result.unwrap_or(Ok(false)).unwrap_or(false);
        assert!(
            !timed_out,
            "Server must NOT time out when the client responds to Pings with Pongs"
        );
    }

    #[tokio::test]
    async fn two_node_pong_swallow_causes_disconnect_within_timeout() {
        use futures::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio_tungstenite::tungstenite::Message as TMsg;
        use tokio_tungstenite::{accept_async, connect_async};

        let ping_ms = 100u64;
        let timeout_ms = 250u64;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Node A (listener): sends Pings, never receives Pong.
        let (a_done_tx, a_done_rx) = tokio::sync::oneshot::channel::<bool>();
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            let (mut sink, mut stream) = ws.split();

            let pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
            let mut ticker = tokio::time::interval_at(
                tokio::time::Instant::now() + Duration::from_millis(ping_ms),
                Duration::from_millis(ping_ms),
            );
            let timed_out = loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        if tokio::time::Instant::now() >= pong_deadline {
                            break true;
                        }
                        if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
                            break false;
                        }
                    }
                    frame = stream.next() => {
                        match frame {
                            Some(Ok(_)) => {} // swallow all frames
                            _ => break false,
                        }
                    }
                }
            };
            let _ = a_done_tx.send(timed_out);
        });

        // Node B: connects, drains frames silently (swallows Pings, never pongs).
        let (ws_b, _) = connect_async(format!("ws://{addr}")).await.unwrap();
        let (_sink_b, mut stream_b) = ws_b.split();
        tokio::spawn(async move {
            while let Some(Ok(_)) = stream_b.next().await {}
        });

        let result =
            tokio::time::timeout(Duration::from_millis(timeout_ms + ping_ms * 3), a_done_rx).await;
        let timed_out = result
            .expect("Node A must report within expected wall-clock time")
            .expect("channel intact");
        assert!(
            timed_out,
            "Node A must disconnect due to keepalive timeout when Node B swallows Pongs"
        );
    }
}