// huskies/server/src/crdt_sync/server.rs
//! Server-side `/crdt-sync` WebSocket handler.
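//!
//! Connection lifecycle, as implemented below:
//!  1. optional bearer-token check before the WebSocket upgrade,
//!  2. challenge/response authentication against the trusted-key list,
//!  3. v1/v2 sync negotiation (vector-clock delta, or full-bulk fallback),
//!  4. `ready` handshake, then real-time op streaming with Ping/Pong keepalive.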
use bft_json_crdt::json_crdt::SignedOp;
use futures::{SinkExt, StreamExt};
use poem::handler;
use poem::http::StatusCode;
use poem::web::Data;
use poem::web::Query;
use poem::web::websocket::{Message as WsMessage, WebSocket};
use serde::Deserialize;
use std::sync::Arc;
use crate::crdt_snapshot;
use crate::crdt_state;
use crate::crdt_wire;
use crate::http::context::AppContext;
use crate::node_identity;
use crate::slog;
use crate::slog_warn;
use super::auth::{REQUIRE_TOKEN, trusted_keys, validate_join_token};
use super::dispatch::{handle_incoming_binary, handle_incoming_text};
use super::wire::{AuthMessage, ChallengeMessage, SyncMessage};
use super::{AUTH_TIMEOUT_SECS, PING_INTERVAL_SECS, PONG_TIMEOUT_SECS};
// ── Server-side WebSocket handler ───────────────────────────────────
/// Query parameters accepted on the `/crdt-sync` WebSocket upgrade request.
#[derive(Deserialize)]
struct SyncQueryParams {
/// Optional bearer token. Required when the server is in token-required mode.
token: Option<String>,
}
/// WebSocket handler for CRDT peer synchronisation.
///
/// Accepts an optional `?token=<bearer-token>` query parameter. When the
/// server is configured with `crdt_require_token = true`, a valid token must
/// be supplied or the upgrade is rejected with HTTP 401. When the server is
/// in open-access mode (the default), a token is optional but still validated
/// if present.
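///
/// A rough trace of a successful v2 session (message shapes per
/// `super::wire`; values abbreviated, not a normative transcript):
///
/// ```text
/// server -> {"type":"challenge","nonce":"<64 hex chars>"}
/// client -> {"type":"auth","pubkey_hex":"…","signature_hex":"…"}
/// server -> clock frame (our vector clock)
/// client -> clock frame (peer's vector clock)
/// server -> bulk frame  (only the ops the peer is missing; preceded
///                        by a snapshot if the peer's clock is empty)
/// server -> ready frame (real-time op streaming begins)
/// ```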
#[handler]
pub async fn crdt_sync_handler(
ws: WebSocket,
_ctx: Data<&Arc<AppContext>>,
remote_addr: &poem::web::RemoteAddr,
Query(params): Query<SyncQueryParams>,
) -> poem::Response {
// ── Bearer-token check (pre-upgrade) ────────────────────────────
let require_token = REQUIRE_TOKEN.get().copied().unwrap_or(false);
match &params.token {
Some(t) => {
if !validate_join_token(t) {
slog!("[crdt-sync] Rejected connection: invalid or expired token");
return poem::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body("invalid or expired token");
}
}
None if require_token => {
slog!("[crdt-sync] Rejected connection: token required but not provided");
return poem::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body("token required");
}
None => {}
}
// ── WebSocket upgrade ────────────────────────────────────────────
use poem::IntoResponse as _;
let peer_addr = remote_addr.to_string();
ws.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
slog!("[crdt-sync] Peer connected, starting auth handshake");
// ── Step 1: Send challenge to the connecting peer ─────────
let challenge = node_identity::generate_challenge();
let challenge_msg = ChallengeMessage {
r#type: "challenge".to_string(),
nonce: challenge.clone(),
};
let challenge_json = match serde_json::to_string(&challenge_msg) {
Ok(j) => j,
Err(_) => return,
};
if sink.send(WsMessage::Text(challenge_json)).await.is_err() {
return;
}
// ── Step 2: Await auth reply within timeout ───────────────
let auth_result = tokio::time::timeout(
std::time::Duration::from_secs(AUTH_TIMEOUT_SECS),
stream.next(),
)
.await;
let auth_text = match auth_result {
Ok(Some(Ok(WsMessage::Text(text)))) => text,
Ok(_) | Err(_) => {
// Timeout, a non-text frame, or connection closed before the auth reply.
slog!("[crdt-sync] Auth timeout or connection lost during handshake");
let _ = sink
.send(WsMessage::Close(Some((
poem::web::websocket::CloseCode::from(4001),
"auth_timeout".to_string(),
))))
.await;
let _ = sink.close().await;
return;
}
};
// ── Step 3: Verify auth reply ─────────────────────────────
let auth_msg: AuthMessage = match serde_json::from_str(&auth_text) {
Ok(m) => m,
Err(_) => {
slog!("[crdt-sync] Invalid auth message from peer");
close_with_auth_failed(&mut sink).await;
return;
}
};
// Verify signature AND check allow-list.
let sig_valid =
node_identity::verify_challenge(&auth_msg.pubkey_hex, &challenge, &auth_msg.signature_hex);
let key_trusted = trusted_keys().iter().any(|k| k == &auth_msg.pubkey_hex);
if !sig_valid || !key_trusted {
slog!("[crdt-sync] Auth failed for peer (sig_valid={sig_valid}, key_trusted={key_trusted})");
close_with_auth_failed(&mut sink).await;
return;
}
slog!(
"[crdt-sync] Peer authenticated: {:.12}…",
&auth_msg.pubkey_hex
);
// ── Auth passed — proceed with CRDT sync ──────────────────
// v2 protocol: send our vector clock so the peer can compute the delta.
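// Illustration (assumed `ops_since` semantics): if our clock is
// {A:10, B:7} and the peer reports {A:10, B:4}, the Bulk frame we
// send below carries only B's ops 5..=7.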
let our_clock = crdt_state::our_vector_clock().unwrap_or_default();
let clock_msg = SyncMessage::Clock { clock: our_clock };
if let Ok(json) = serde_json::to_string(&clock_msg)
&& sink.send(WsMessage::Text(json)).await.is_err()
{
return;
}
// Wait for the peer's first sync message to determine protocol version.
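// First-frame mapping: Clock → v2 delta sync; Bulk or Op → v1 full
// sync; timeout or an unparseable frame → full-bulk fallback.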
let first_msg = tokio::time::timeout(
std::time::Duration::from_secs(AUTH_TIMEOUT_SECS),
wait_for_sync_text(&mut stream, &mut sink),
)
.await;
match first_msg {
Ok(Some(SyncMessage::Clock { clock: peer_clock })) => {
// v2 peer — if we have a snapshot and the peer has an empty
// clock (new node), send the snapshot first for onboarding.
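// A never-synced node reports an empty clock; it receives the
// snapshot first, then the op delta computed below.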
if peer_clock.is_empty()
&& let Some(snapshot) = crdt_snapshot::latest_snapshot()
{
let snap_msg = crdt_snapshot::SnapshotMessage::Snapshot(snapshot);
if let Ok(json) = serde_json::to_string(&snap_msg) {
if sink.send(WsMessage::Text(json)).await.is_err() {
return;
}
slog!("[crdt-sync] Sent snapshot to new node for onboarding");
}
}
// Send only the ops the peer is missing.
let delta = crdt_state::ops_since(&peer_clock).unwrap_or_default();
slog!(
"[crdt-sync] v2 delta sync: sending {} ops (peer missing)",
delta.len()
);
let msg = SyncMessage::Bulk { ops: delta };
if let Ok(json) = serde_json::to_string(&msg)
&& sink.send(WsMessage::Text(json)).await.is_err()
{
return;
}
}
Ok(Some(SyncMessage::Bulk { ops })) => {
// v1 peer — apply their bulk and send our full bulk.
let mut applied = 0u64;
for op_json in &ops {
if let Ok(signed_op) = serde_json::from_str::<SignedOp>(op_json)
&& crdt_state::apply_remote_op(signed_op)
{
applied += 1;
}
}
slog!(
"[crdt-sync] v1 bulk sync: received {} ops, applied {applied}",
ops.len()
);
if let Some(all) = crdt_state::all_ops_json() {
let msg = SyncMessage::Bulk { ops: all };
if let Ok(json) = serde_json::to_string(&msg)
&& sink.send(WsMessage::Text(json)).await.is_err()
{
return;
}
}
}
Ok(Some(SyncMessage::Op { op })) => {
// Single op before negotiation — treat as v1.
if let Ok(signed_op) = serde_json::from_str::<SignedOp>(&op) {
crdt_state::apply_remote_op(signed_op);
}
if let Some(all) = crdt_state::all_ops_json() {
let msg = SyncMessage::Bulk { ops: all };
if let Ok(json) = serde_json::to_string(&msg)
&& sink.send(WsMessage::Text(json)).await.is_err()
{
return;
}
}
}
_ => {
// Timeout or error — send full bulk as fallback.
slog!("[crdt-sync] No sync message from peer; sending full bulk as fallback");
if let Some(all) = crdt_state::all_ops_json() {
let msg = SyncMessage::Bulk { ops: all };
if let Ok(json) = serde_json::to_string(&msg)
&& sink.send(WsMessage::Text(json)).await.is_err()
{
return;
}
}
}
}
// Bulk-delta phase complete — signal the peer that we are ready for
// real-time op streaming.
if let Ok(json) = serde_json::to_string(&SyncMessage::Ready)
&& sink.send(WsMessage::Text(json)).await.is_err()
{
return;
}
// Subscribe to new local ops.
let Some(mut op_rx) = crdt_state::subscribe_ops() else {
return;
};
// Buffer for locally-generated ops produced before the peer's `ready`
// arrives. Flushed in-order once the peer signals catch-up.
let mut peer_ready = false;
let mut op_buffer: Vec<SignedOp> = Vec::new();
// ── Keepalive state ───────────────────────────────────────────
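// Ping every PING_INTERVAL_SECS; drop the connection if no Pong
// arrives within PONG_TIMEOUT_SECS (see the keepalive tests below).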
let mut pong_deadline = tokio::time::Instant::now()
+ std::time::Duration::from_secs(PONG_TIMEOUT_SECS);
let mut ping_ticker = tokio::time::interval_at(
tokio::time::Instant::now()
+ std::time::Duration::from_secs(PING_INTERVAL_SECS),
std::time::Duration::from_secs(PING_INTERVAL_SECS),
);
loop {
tokio::select! {
// Send periodic Ping and enforce Pong timeout.
_ = ping_ticker.tick() => {
if tokio::time::Instant::now() >= pong_deadline {
slog_warn!(
"[crdt-sync] No pong from peer {} in {}s; disconnecting",
peer_addr,
PONG_TIMEOUT_SECS
);
break;
}
if sink.send(WsMessage::Ping(vec![])).await.is_err() {
break;
}
}
// Forward new local ops to the peer encoded via the wire codec.
result = op_rx.recv() => {
match result {
Ok(signed_op) => {
if peer_ready {
let bytes = crdt_wire::encode(&signed_op);
if sink.send(WsMessage::Binary(bytes)).await.is_err() {
break;
}
} else {
// Buffer until the peer signals ready.
op_buffer.push(signed_op);
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
// The peer cannot keep up; disconnect so it can
// reconnect and receive a fresh bulk state dump.
slog!("[crdt-sync] Slow peer lagged {n} ops; disconnecting");
break;
}
Err(_) => break,
}
}
// Receive ops from the peer.
frame = stream.next() => {
match frame {
Some(Ok(WsMessage::Pong(_))) => {
// Reset the pong deadline on every Pong received.
pong_deadline = tokio::time::Instant::now()
+ std::time::Duration::from_secs(PONG_TIMEOUT_SECS);
}
Some(Ok(WsMessage::Ping(data))) => {
// Respond to peer's Ping so the peer's keepalive passes.
let _ = sink.send(WsMessage::Pong(data)).await;
}
Some(Ok(WsMessage::Text(text))) => {
// Check for the ready signal before other text frames.
if let Ok(SyncMessage::Ready) = serde_json::from_str(&text) {
peer_ready = true;
slog!("[crdt-sync] Peer ready; flushing {} buffered ops", op_buffer.len());
let mut flush_ok = true;
for op in op_buffer.drain(..) {
let bytes = crdt_wire::encode(&op);
if sink.send(WsMessage::Binary(bytes)).await.is_err() {
flush_ok = false;
break;
}
}
if !flush_ok {
break;
}
} else {
// Bulk state dump, legacy op frame, or clock frame.
handle_incoming_text(&text);
}
}
Some(Ok(WsMessage::Binary(bytes))) => {
// Real-time op encoded via wire codec — applied immediately
// regardless of our own ready state.
handle_incoming_binary(&bytes);
}
Some(Ok(WsMessage::Close(_))) | None => break,
_ => {}
}
}
}
}
slog!("[crdt-sync] Peer disconnected");
})
.into_response()
}
/// Wait for the next text-frame sync message from the peer, handling Ping/Pong
/// transparently.
///
/// Returns `None` on connection close or read error.
async fn wait_for_sync_text(
stream: &mut futures::stream::SplitStream<poem::web::websocket::WebSocketStream>,
sink: &mut futures::stream::SplitSink<poem::web::websocket::WebSocketStream, WsMessage>,
) -> Option<SyncMessage> {
loop {
match stream.next().await {
Some(Ok(WsMessage::Text(text))) => {
return serde_json::from_str(&text).ok();
}
Some(Ok(WsMessage::Ping(data))) => {
let _ = sink.send(WsMessage::Pong(data)).await;
}
Some(Ok(WsMessage::Pong(_))) => continue,
_ => return None,
}
}
}
/// Close the WebSocket with a generic `auth_failed` reason.
///
/// The close reason is intentionally the same for all auth failures
/// (bad signature, untrusted key, malformed message) to avoid leaking
/// which check failed.
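///
/// Sends application close code 4002 (`auth_failed`); code 4001 is
/// reserved for auth timeouts in the handler above.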
async fn close_with_auth_failed(
sink: &mut futures::stream::SplitSink<poem::web::websocket::WebSocketStream, WsMessage>,
) {
let _ = sink
.send(WsMessage::Close(Some((
poem::web::websocket::CloseCode::from(4002),
"auth_failed".to_string(),
))))
.await;
let _ = sink.close().await;
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug)]
#[allow(dead_code)]
enum AuthListenerResult {
Authenticated(String),
AuthFailed(String),
AuthTimeout,
ConnectionLost,
PeerClosedEarly(Option<String>),
}
#[test]
fn peer_receives_op_encoded_via_wire_codec() {
use bft_json_crdt::json_crdt::BaseCrdt;
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use serde_json::json;
use crate::crdt_state::PipelineDoc;
use crate::crdt_wire;
let kp = make_keypair();
let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
let item: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "506_story_lifecycle_test",
"stage": "1_backlog",
"name": "Lifecycle Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);
// Simulate what the broadcast handler does: encode via wire codec.
let bytes = crdt_wire::encode(&op);
// The bytes must be a versioned JSON envelope, not a SyncMessage wrapper.
let text = std::str::from_utf8(&bytes).expect("wire output is valid UTF-8");
assert!(
text.contains("\"v\":1"),
"wire codec version tag must be present: {text}"
);
assert!(
!text.contains("\"type\":\"op\""),
"must not be wrapped in SyncMessage: {text}"
);
// The receiving peer can decode and apply the op.
let decoded = crdt_wire::decode(&bytes).expect("decode must succeed");
assert_eq!(op, decoded);
}
#[tokio::test]
async fn multiple_peers_all_receive_broadcast_op() {
use bft_json_crdt::json_crdt::BaseCrdt;
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use serde_json::json;
use tokio::sync::broadcast;
use crate::crdt_state::PipelineDoc;
use crate::crdt_wire;
// Create a broadcast channel (analogous to SYNC_TX).
let (tx, _) = broadcast::channel::<SignedOp>(16);
let mut rx_peer1 = tx.subscribe();
let mut rx_peer2 = tx.subscribe();
let kp = make_keypair();
let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
let item: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "506_story_multi_peer_test",
"stage": "1_backlog",
"name": "Multi-Peer Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);
// Broadcast one op.
tx.send(op.clone()).expect("send must succeed");
// Both peers receive the same op.
let received1 = rx_peer1.recv().await.expect("peer 1 must receive");
let received2 = rx_peer2.recv().await.expect("peer 2 must receive");
assert_eq!(received1, op);
assert_eq!(received2, op);
// Both encode identically via wire codec.
let bytes1 = crdt_wire::encode(&received1);
let bytes2 = crdt_wire::encode(&received2);
assert_eq!(bytes1, bytes2, "wire-encoded bytes must be identical");
}
#[test]
fn disconnected_peer_does_not_panic() {
use bft_json_crdt::json_crdt::BaseCrdt;
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use serde_json::json;
use tokio::sync::broadcast;
use crate::crdt_state::PipelineDoc;
let (tx, rx) = broadcast::channel::<SignedOp>(16);
// Drop the receiver to simulate a peer that disconnected.
drop(rx);
let kp = make_keypair();
let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
let item: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "506_story_disconnect_test",
"stage": "1_backlog",
"name": "Disconnect Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);
// Sending to a channel with no receivers returns an error; must not panic.
let _ = tx.send(op);
}
#[tokio::test]
async fn lagged_peer_gets_lagged_error() {
use bft_json_crdt::json_crdt::BaseCrdt;
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use serde_json::json;
use tokio::sync::broadcast;
use crate::crdt_state::PipelineDoc;
// Tiny capacity so we can trigger Lagged easily.
let (tx, mut rx) = broadcast::channel::<SignedOp>(2);
let kp = make_keypair();
let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
let item: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "506_story_lag_test",
"stage": "1_backlog",
"name": "Lag Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op1 = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);
crdt.apply(op1.clone());
// Overflow the tiny buffer by sending more ops than the capacity.
let op2 = crdt.doc.items[0]
.stage
.set("2_current".to_string())
.sign(&kp);
crdt.apply(op2.clone());
let op3 = crdt.doc.items[0].stage.set("3_qa".to_string()).sign(&kp);
crdt.apply(op3.clone());
let op4 = crdt.doc.items[0].stage.set("4_merge".to_string()).sign(&kp);
crdt.apply(op4.clone());
// Send more ops than the channel capacity without consuming.
let _ = tx.send(op1);
let _ = tx.send(op2);
let _ = tx.send(op3);
let _ = tx.send(op4);
// The slow peer should now see a Lagged error on next recv.
// Consume until we hit Lagged or run out.
let mut got_lagged = false;
for _ in 0..10 {
match rx.recv().await {
Err(broadcast::error::RecvError::Lagged(_)) => {
got_lagged = true;
break;
}
Ok(_) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
assert!(
got_lagged,
"slow peer must receive a Lagged error when channel overflows"
);
}
#[tokio::test]
async fn e2e_convergence_two_websocket_nodes() {
use bft_json_crdt::json_crdt::{BaseCrdt, CrdtNode, JsonValue as JV, OpState};
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use std::sync::{Arc, Mutex};
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
use crate::crdt_state::PipelineDoc;
// ── Node A: build local state ──────────────────────────────────────
let kp_a = make_keypair();
let mut crdt_a = BaseCrdt::<PipelineDoc>::new(&kp_a);
let item: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "508_e2e_convergence",
"stage": "2_current",
"name": "E2E Convergence Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op1 = crdt_a.doc.items.insert(ROOT_ID, item).sign(&kp_a);
crdt_a.apply(op1.clone());
// Serialise A's full state as a bulk message.
let op1_json = serde_json::to_string(&op1).unwrap();
let bulk_msg = SyncMessage::Bulk {
ops: vec![op1_json],
};
let bulk_wire = serde_json::to_string(&bulk_msg).unwrap();
// ── Start Node A's WebSocket server on a random port ───────────────
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let bulk_to_send = bulk_wire.clone();
let received_by_a: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(vec![]));
let received_by_a_clone = received_by_a.clone();
tokio::spawn(async move {
let (tcp_stream, _) = listener.accept().await.unwrap();
let ws_stream = accept_async(tcp_stream).await.unwrap();
let (mut sink, mut stream) = ws_stream.split();
// Send bulk state to the connecting peer.
sink.send(TMsg::Text(bulk_to_send.into())).await.unwrap();
// Also listen for ops sent by the peer.
if let Some(Ok(TMsg::Text(txt))) = stream.next().await {
received_by_a_clone.lock().unwrap().push(txt.to_string());
}
});
// ── Node B: connect to Node A and exchange state ───────────────────
let kp_b = make_keypair();
let mut crdt_b = BaseCrdt::<PipelineDoc>::new(&kp_b);
let url = format!("ws://{addr}");
let (ws_b, _) = connect_async(&url).await.unwrap();
let (mut sink_b, mut stream_b) = ws_b.split();
// Node B receives bulk from A.
if let Some(Ok(TMsg::Text(txt))) = stream_b.next().await {
let msg: SyncMessage = serde_json::from_str(txt.as_str()).unwrap();
match msg {
SyncMessage::Bulk { ops } => {
for op_str in &ops {
let signed: bft_json_crdt::json_crdt::SignedOp =
serde_json::from_str(op_str).unwrap();
let r = crdt_b.apply(signed);
assert!(r == OpState::Ok || r == OpState::AlreadySeen);
}
}
_ => panic!("Expected Bulk from Node A"),
}
}
// Node B also creates a new op and sends it to A.
let item_b: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "508_e2e_convergence_b",
"stage": "1_backlog",
"name": "E2E Convergence B",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op_b1 = crdt_b.doc.items.insert(ROOT_ID, item_b).sign(&kp_b);
crdt_b.apply(op_b1.clone());
let op_b1_json = serde_json::to_string(&op_b1).unwrap();
let msg_to_a = SyncMessage::Op { op: op_b1_json };
sink_b
.send(TMsg::Text(serde_json::to_string(&msg_to_a).unwrap().into()))
.await
.unwrap();
// Wait a moment for Node A to process.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// ── Assert convergence ─────────────────────────────────────────────
// Node B received Node A's item.
assert_eq!(
crdt_b.doc.items.view().len(),
2,
"Node B must see both items after sync"
);
let has_a_item = crdt_b
.doc
.items
.view()
.iter()
.any(|item| item.story_id.view() == JV::String("508_e2e_convergence".to_string()));
assert!(has_a_item, "Node B must have Node A's item");
// Node A received Node B's op via the WebSocket.
let a_received = received_by_a.lock().unwrap();
assert!(
!a_received.is_empty(),
"Node A must have received an op from Node B"
);
}
#[tokio::test]
async fn e2e_partition_healing_websocket() {
use bft_json_crdt::json_crdt::{BaseCrdt, CrdtNode, JsonValue as JV};
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
use crate::crdt_state::PipelineDoc;
// ── Phase 1: Both nodes start with op_a1 (before partition) ───────
let kp_a = make_keypair();
let kp_b = make_keypair();
let mut crdt_a = BaseCrdt::<PipelineDoc>::new(&kp_a);
let mut crdt_b = BaseCrdt::<PipelineDoc>::new(&kp_b);
let item_a: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "508_heal_a",
"stage": "1_backlog",
"name": "Heal A",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op_a1 = crdt_a.doc.items.insert(ROOT_ID, item_a).sign(&kp_a);
crdt_a.apply(op_a1.clone());
// B also starts with op_a1 (shared state before partition).
crdt_b.apply(op_a1.clone());
// ── Phase 2: Partition — each side mutates independently ──────────
// A advances its story stage.
let op_a2 = crdt_a.doc.items[0]
.stage
.set("2_current".to_string())
.sign(&kp_a);
crdt_a.apply(op_a2.clone());
// B inserts a new story that A doesn't know about yet.
let item_b: bft_json_crdt::json_crdt::JsonValue = json!({
"story_id": "508_heal_b",
"stage": "1_backlog",
"name": "Heal B",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op_b1 = crdt_b.doc.items.insert(ROOT_ID, item_b).sign(&kp_b);
crdt_b.apply(op_b1.clone());
// Collect B's full state as bulk (what it will send on reconnect).
let b_ops: Vec<String> = [&op_a1, &op_b1]
.iter()
.map(|op| serde_json::to_string(op).unwrap())
.collect();
let b_bulk_wire = serde_json::to_string(&SyncMessage::Bulk { ops: b_ops }).unwrap();
// Collect A's full state as bulk (what it will send on reconnect).
let a_ops: Vec<String> = [&op_a1, &op_a2]
.iter()
.map(|op| serde_json::to_string(op).unwrap())
.collect();
let a_bulk_wire = serde_json::to_string(&SyncMessage::Bulk { ops: a_ops }).unwrap();
// ── Phase 3: Reconnect — use a real WebSocket to exchange bulk ────
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let a_bulk_to_send = a_bulk_wire.clone();
let a_received_bulk: std::sync::Arc<std::sync::Mutex<Option<String>>> =
std::sync::Arc::new(std::sync::Mutex::new(None));
let a_received_clone = a_received_bulk.clone();
tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (mut sink, mut stream) = ws.split();
// A sends its bulk state.
sink.send(TMsg::Text(a_bulk_to_send.into())).await.unwrap();
// A receives B's bulk state.
if let Some(Ok(TMsg::Text(txt))) = stream.next().await {
*a_received_clone.lock().unwrap() = Some(txt.to_string());
}
});
// B connects, exchanges bulk state.
let (ws_b, _) = connect_async(format!("ws://{addr}")).await.unwrap();
let (mut sink_b, mut stream_b) = ws_b.split();
// B receives A's bulk and applies it.
if let Some(Ok(TMsg::Text(txt))) = stream_b.next().await {
let msg: SyncMessage = serde_json::from_str(txt.as_str()).unwrap();
if let SyncMessage::Bulk { ops } = msg {
for op_str in &ops {
let signed: bft_json_crdt::json_crdt::SignedOp =
serde_json::from_str(op_str).unwrap();
let _ = crdt_b.apply(signed);
}
}
}
// B sends its bulk state to A.
sink_b.send(TMsg::Text(b_bulk_wire.into())).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Apply A's received ops into crdt_a.
if let Some(bulk_str) = a_received_bulk.lock().unwrap().take() {
let msg: SyncMessage = serde_json::from_str(&bulk_str).unwrap();
if let SyncMessage::Bulk { ops } = msg {
for op_str in &ops {
let signed: bft_json_crdt::json_crdt::SignedOp =
serde_json::from_str(op_str).unwrap();
let _ = crdt_a.apply(signed);
}
}
}
// ── Assert convergence ─────────────────────────────────────────────
// Both nodes must have 2 items.
assert_eq!(
crdt_a.doc.items.view().len(),
2,
"A must have 2 items after healing"
);
assert_eq!(
crdt_b.doc.items.view().len(),
2,
"B must have 2 items after healing"
);
// A must see B's story.
let b_story_on_a = crdt_a
.doc
.items
.view()
.iter()
.any(|item| item.story_id.view() == JV::String("508_heal_b".to_string()));
assert!(b_story_on_a, "A must have B's story after healing");
// B must see A's stage advance.
let a_story_on_b = crdt_b
.doc
.items
.view()
.iter()
.any(|item| item.story_id.view() == JV::String("508_heal_a".to_string()));
assert!(a_story_on_b, "B must have A's story after healing");
// CRDT views must be byte-identical (convergence).
let view_a = serde_json::to_string(&CrdtNode::view(&crdt_a.doc.items)).unwrap();
let view_b = serde_json::to_string(&CrdtNode::view(&crdt_b.doc.items)).unwrap();
assert_eq!(
view_a, view_b,
"Both nodes must converge to identical state"
);
}
async fn start_auth_listener(
trusted_keys: Vec<String>,
) -> (
std::net::SocketAddr,
tokio::sync::oneshot::Receiver<AuthListenerResult>,
) {
use tokio::net::TcpListener;
use tokio_tungstenite::accept_async;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (result_tx, result_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let (tcp_stream, _) = listener.accept().await.unwrap();
let ws_stream = accept_async(tcp_stream).await.unwrap();
let (mut sink, mut stream) = ws_stream.split();
use tokio_tungstenite::tungstenite::Message as TMsg;
// Step 1: Send challenge.
let challenge = crate::node_identity::generate_challenge();
let challenge_msg = super::ChallengeMessage {
r#type: "challenge".to_string(),
nonce: challenge.clone(),
};
let challenge_json = serde_json::to_string(&challenge_msg).unwrap();
if sink.send(TMsg::Text(challenge_json.into())).await.is_err() {
let _ = result_tx.send(AuthListenerResult::ConnectionLost);
return;
}
// Step 2: Await auth reply (10s timeout).
let auth_frame =
tokio::time::timeout(std::time::Duration::from_secs(10), stream.next()).await;
let auth_text = match auth_frame {
Ok(Some(Ok(TMsg::Text(t)))) => t.to_string(),
Ok(Some(Ok(TMsg::Close(reason)))) => {
let _ = result_tx.send(AuthListenerResult::PeerClosedEarly(
reason.map(|r| r.reason.to_string()),
));
return;
}
_ => {
let _ = sink
.send(TMsg::Close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame {
code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(4001),
reason: "auth_timeout".into(),
})))
.await;
let _ = result_tx.send(AuthListenerResult::AuthTimeout);
return;
}
};
// Step 3: Verify.
let auth_msg: super::AuthMessage = match serde_json::from_str(&auth_text) {
Ok(m) => m,
Err(_) => {
let _ = close_listener_auth_failed(&mut sink).await;
let _ = result_tx.send(AuthListenerResult::AuthFailed("bad_json".into()));
return;
}
};
let sig_valid = crate::node_identity::verify_challenge(
&auth_msg.pubkey_hex,
&challenge,
&auth_msg.signature_hex,
);
let key_trusted = trusted_keys.iter().any(|k| k == &auth_msg.pubkey_hex);
if !sig_valid || !key_trusted {
let _ = close_listener_auth_failed(&mut sink).await;
let _ = result_tx.send(AuthListenerResult::AuthFailed(format!(
"sig_valid={sig_valid}, key_trusted={key_trusted}"
)));
return;
}
// Auth passed! Send a bulk state with one op to prove sync works.
let kp = bft_json_crdt::keypair::make_keypair();
let mut crdt =
bft_json_crdt::json_crdt::BaseCrdt::<crate::crdt_state::PipelineDoc>::new(&kp);
let item: bft_json_crdt::json_crdt::JsonValue = serde_json::json!({
"story_id": "628_auth_test_item",
"stage": "1_backlog",
"name": "Auth Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op = crdt
.doc
.items
.insert(bft_json_crdt::op::ROOT_ID, item)
.sign(&kp);
let op_json = serde_json::to_string(&op).unwrap();
let bulk = super::SyncMessage::Bulk { ops: vec![op_json] };
let bulk_json = serde_json::to_string(&bulk).unwrap();
let _ = sink.send(TMsg::Text(bulk_json.into())).await;
let _ = result_tx.send(AuthListenerResult::Authenticated(auth_msg.pubkey_hex));
});
(addr, result_rx)
}
async fn close_listener_auth_failed(
sink: &mut futures::stream::SplitSink<
tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
tokio_tungstenite::tungstenite::Message,
>,
) {
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
let _ = sink
.send(tokio_tungstenite::tungstenite::Message::Close(Some(
CloseFrame {
code: CloseCode::from(4002),
reason: "auth_failed".into(),
},
)))
.await;
}
#[tokio::test]
async fn auth_happy_path_handshake_and_sync() {
use bft_json_crdt::keypair::make_keypair;
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as TMsg;
let connector_kp = make_keypair();
let connector_pubkey = crate::node_identity::public_key_hex(&connector_kp);
// Start listener that trusts the connector's pubkey.
let (addr, result_rx) = start_auth_listener(vec![connector_pubkey.clone()]).await;
// Connect and do the handshake.
let url = format!("ws://{addr}");
let (ws, _) = connect_async(&url).await.unwrap();
let (mut sink, mut stream) = ws.split();
// Receive challenge.
let challenge_frame = stream.next().await.unwrap().unwrap();
let challenge_text = match challenge_frame {
TMsg::Text(t) => t.to_string(),
other => panic!("Expected text frame, got {other:?}"),
};
let challenge_msg: super::ChallengeMessage = serde_json::from_str(&challenge_text).unwrap();
assert_eq!(challenge_msg.r#type, "challenge");
assert_eq!(
challenge_msg.nonce.len(),
64,
"Challenge must be 64 hex chars"
);
// Sign and reply.
let sig = crate::node_identity::sign_challenge(&connector_kp, &challenge_msg.nonce);
let auth_msg = super::AuthMessage {
r#type: "auth".to_string(),
pubkey_hex: connector_pubkey.clone(),
signature_hex: sig,
};
let auth_json = serde_json::to_string(&auth_msg).unwrap();
sink.send(TMsg::Text(auth_json.into())).await.unwrap();
// After auth, we should receive a bulk sync message with at least one op.
let bulk_frame = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
.await
.expect("should receive bulk within 5s")
.unwrap()
.unwrap();
let bulk_text = match bulk_frame {
TMsg::Text(t) => t.to_string(),
other => panic!("Expected bulk text frame, got {other:?}"),
};
let bulk_msg: super::SyncMessage = serde_json::from_str(&bulk_text).unwrap();
match bulk_msg {
super::SyncMessage::Bulk { ops } => {
assert!(
!ops.is_empty(),
"Bulk sync must contain at least one op after successful auth"
);
// Verify we can deserialize the op.
let _signed: bft_json_crdt::json_crdt::SignedOp =
serde_json::from_str(&ops[0]).unwrap();
}
_ => panic!("Expected Bulk message after auth"),
}
// Verify listener also reports success.
let listener_result = result_rx.await.unwrap();
match listener_result {
AuthListenerResult::Authenticated(pubkey) => {
assert_eq!(pubkey, connector_pubkey);
}
other => panic!("Expected Authenticated, got {other:?}"),
}
}
#[tokio::test]
async fn auth_untrusted_pubkey_rejected() {
use bft_json_crdt::keypair::make_keypair;
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as TMsg;
let connector_kp = make_keypair();
let connector_pubkey = crate::node_identity::public_key_hex(&connector_kp);
// Listener trusts a DIFFERENT key, not the connector's.
let other_kp = make_keypair();
let other_pubkey = crate::node_identity::public_key_hex(&other_kp);
let (addr, result_rx) = start_auth_listener(vec![other_pubkey]).await;
let url = format!("ws://{addr}");
let (ws, _) = connect_async(&url).await.unwrap();
let (mut sink, mut stream) = ws.split();
// Receive challenge and sign with our (untrusted) key.
let challenge_frame = stream.next().await.unwrap().unwrap();
let challenge_text = match challenge_frame {
TMsg::Text(t) => t.to_string(),
_ => panic!("Expected text frame"),
};
let challenge_msg: super::ChallengeMessage = serde_json::from_str(&challenge_text).unwrap();
let sig = crate::node_identity::sign_challenge(&connector_kp, &challenge_msg.nonce);
let auth_msg = super::AuthMessage {
r#type: "auth".to_string(),
pubkey_hex: connector_pubkey,
signature_hex: sig,
};
sink.send(TMsg::Text(serde_json::to_string(&auth_msg).unwrap().into()))
.await
.unwrap();
// Should receive a close frame with auth_failed.
let close_frame = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
.await
.expect("should receive close within 5s");
match close_frame {
Some(Ok(TMsg::Close(Some(frame)))) => {
assert_eq!(
&*frame.reason, "auth_failed",
"Close reason must be 'auth_failed'"
);
}
Some(Ok(TMsg::Close(None))) => {
// Some implementations omit the close frame payload — that's acceptable
// as long as no sync data was sent.
}
other => {
// Connection dropped without close frame is also acceptable.
// The key assertion is below: no ops were exchanged.
let _ = other;
}
}
// Verify listener reports auth failure.
let listener_result = result_rx.await.unwrap();
match listener_result {
AuthListenerResult::AuthFailed(reason) => {
assert!(reason.contains("key_trusted=false"), "Reason: {reason}");
}
other => panic!("Expected AuthFailed, got {other:?}"),
}
}
#[tokio::test]
async fn auth_bad_signature_rejected() {
use bft_json_crdt::keypair::make_keypair;
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as TMsg;
let legitimate_kp = make_keypair();
let legitimate_pubkey = crate::node_identity::public_key_hex(&legitimate_kp);
// A different keypair that will sign the challenge (wrong key).
let impersonator_kp = make_keypair();
// Listener trusts the legitimate pubkey.
let (addr, result_rx) = start_auth_listener(vec![legitimate_pubkey.clone()]).await;
let url = format!("ws://{addr}");
let (ws, _) = connect_async(&url).await.unwrap();
let (mut sink, mut stream) = ws.split();
// Receive challenge.
let challenge_frame = stream.next().await.unwrap().unwrap();
let challenge_text = match challenge_frame {
TMsg::Text(t) => t.to_string(),
_ => panic!("Expected text frame"),
};
let challenge_msg: super::ChallengeMessage = serde_json::from_str(&challenge_text).unwrap();
// Sign with the WRONG keypair but claim to be the legitimate pubkey.
let bad_sig = crate::node_identity::sign_challenge(&impersonator_kp, &challenge_msg.nonce);
let auth_msg = super::AuthMessage {
r#type: "auth".to_string(),
pubkey_hex: legitimate_pubkey,
signature_hex: bad_sig,
};
sink.send(TMsg::Text(serde_json::to_string(&auth_msg).unwrap().into()))
.await
.unwrap();
// Should be rejected.
let close_frame = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
.await
.expect("should receive close within 5s");
match close_frame {
Some(Ok(TMsg::Close(Some(frame)))) => {
assert_eq!(
&*frame.reason, "auth_failed",
"Close reason must be 'auth_failed' — same as untrusted key"
);
}
_ => {
// Connection closed is acceptable.
}
}
// Verify listener reports auth failure with sig_valid=false.
let listener_result = result_rx.await.unwrap();
match listener_result {
AuthListenerResult::AuthFailed(reason) => {
assert!(reason.contains("sig_valid=false"), "Reason: {reason}");
}
other => panic!("Expected AuthFailed, got {other:?}"),
}
}
#[tokio::test]
async fn auth_replay_protection_fresh_nonces() {
use futures::StreamExt;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
// Start a listener that sends challenges but doesn't complete auth.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (nonce_tx, mut nonce_rx) = tokio::sync::mpsc::channel::<String>(2);
tokio::spawn(async move {
for _ in 0..2 {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (mut sink, _stream) = ws.split();
let challenge = crate::node_identity::generate_challenge();
let msg = super::ChallengeMessage {
r#type: "challenge".to_string(),
nonce: challenge.clone(),
};
let json = serde_json::to_string(&msg).unwrap();
let _ = sink.send(TMsg::Text(json.into())).await;
let _ = nonce_tx.send(challenge).await;
}
});
// Connect twice and collect the nonces.
let mut nonces = Vec::new();
for _ in 0..2 {
let url = format!("ws://{addr}");
let (ws, _) = connect_async(&url).await.unwrap();
let (_sink, mut stream) = ws.split();
let frame = stream.next().await.unwrap().unwrap();
let text = match frame {
TMsg::Text(t) => t.to_string(),
_ => panic!("Expected text"),
};
let msg: super::ChallengeMessage = serde_json::from_str(&text).unwrap();
nonces.push(msg.nonce);
// Drop connection so listener accepts the next one.
drop(stream);
}
// Also collect nonces from the listener side.
let server_nonce_1 = nonce_rx.recv().await.unwrap();
let server_nonce_2 = nonce_rx.recv().await.unwrap();
assert_ne!(
nonces[0], nonces[1],
"Consecutive challenges must be different"
);
assert_ne!(
server_nonce_1, server_nonce_2,
"Server must generate fresh nonce per accept"
);
assert_eq!(nonces[0], server_nonce_1, "Client/server nonces must match");
assert_eq!(nonces[1], server_nonce_2, "Client/server nonces must match");
}
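// A minimal sketch exercising the handshake's crypto primitives
// directly, with no socket involved. It assumes only the
// `node_identity` functions already used by the tests above.
#[test]
fn challenge_sign_verify_roundtrip() {
use bft_json_crdt::keypair::make_keypair;
let kp = make_keypair();
let pubkey = crate::node_identity::public_key_hex(&kp);
let challenge = crate::node_identity::generate_challenge();
assert_eq!(challenge.len(), 64, "challenge must be 64 hex chars");
let sig = crate::node_identity::sign_challenge(&kp, &challenge);
assert!(
crate::node_identity::verify_challenge(&pubkey, &challenge, &sig),
"signature over the issued challenge must verify"
);
// Replaying the same signature against a fresh nonce must fail.
let fresh = crate::node_identity::generate_challenge();
assert!(
!crate::node_identity::verify_challenge(&pubkey, &fresh, &sig),
"signature must not verify against a different nonce"
);
}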
#[test]
fn keepalive_constants_are_correct() {
assert_eq!(
super::super::PING_INTERVAL_SECS,
30,
"Ping interval must be 30 seconds"
);
assert_eq!(
super::super::PONG_TIMEOUT_SECS,
60,
"Pong timeout must be 60 seconds"
);
}
#[test]
fn agent_mode_heartbeat_interval_unchanged() {
assert_eq!(
crate::agent_mode::SCAN_INTERVAL_SECS,
15,
"Agent-mode heartbeat interval must remain 15s"
);
}
#[test]
fn reconnect_backoff_constants_unchanged() {
assert_eq!(
super::super::RENDEZVOUS_ERROR_THRESHOLD,
10,
"Backoff threshold must still be 10"
);
}
#[tokio::test]
async fn server_sends_ping_to_peer_at_interval() {
use futures::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
let ping_ms = 100u64;
let timeout_ms = 400u64;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Server task: keepalive sender with short intervals.
tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (mut sink, mut stream) = ws.split();
let mut pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
let mut ticker = tokio::time::interval_at(
tokio::time::Instant::now() + Duration::from_millis(ping_ms),
Duration::from_millis(ping_ms),
);
loop {
tokio::select! {
_ = ticker.tick() => {
if tokio::time::Instant::now() >= pong_deadline { break; }
if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() { break; }
}
frame = stream.next() => {
match frame {
Some(Ok(TMsg::Pong(_))) => {
pong_deadline = tokio::time::Instant::now()
+ Duration::from_millis(timeout_ms);
}
None | Some(Err(_)) => break,
_ => {}
}
}
}
}
});
let (ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
let (_sink_c, mut stream_c) = ws_client.split();
// Wait for more than one ping interval.
tokio::time::sleep(Duration::from_millis(ping_ms * 2)).await;
// Client should receive a Ping from the server.
let frame = tokio::time::timeout(Duration::from_millis(200), stream_c.next()).await;
let got_ping = matches!(frame, Ok(Some(Ok(TMsg::Ping(_)))));
assert!(
got_ping,
"Client must receive a Ping frame from the server after the ping interval"
);
}
#[tokio::test]
async fn client_sends_ping_to_server_at_interval() {
use futures::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
let ping_ms = 100u64;
let timeout_ms = 400u64;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (ping_tx, ping_rx) = tokio::sync::oneshot::channel::<()>();
// Server task: wait for the first Ping the client sends.
tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (_sink, mut stream) = ws.split();
loop {
match stream.next().await {
Some(Ok(TMsg::Ping(_))) => {
let _ = ping_tx.send(());
break;
}
Some(Ok(_)) => continue,
_ => break,
}
}
});
let (ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
let (mut sink_c, mut stream_c) = ws_client.split();
// Client keepalive task.
tokio::spawn(async move {
let mut pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
let mut ticker = tokio::time::interval_at(
tokio::time::Instant::now() + Duration::from_millis(ping_ms),
Duration::from_millis(ping_ms),
);
loop {
tokio::select! {
_ = ticker.tick() => {
if tokio::time::Instant::now() >= pong_deadline { break; }
if sink_c.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() { break; }
}
frame = stream_c.next() => {
match frame {
Some(Ok(TMsg::Pong(_))) => {
pong_deadline = tokio::time::Instant::now()
+ Duration::from_millis(timeout_ms);
}
None | Some(Err(_)) => break,
_ => {}
}
}
}
}
});
let result = tokio::time::timeout(Duration::from_millis(ping_ms * 3), ping_rx).await;
assert!(
result.is_ok(),
"Server must receive a Ping from the client after the ping interval"
);
}
#[tokio::test]
async fn keepalive_disconnects_when_pong_withheld() {
use futures::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
let ping_ms = 100u64;
let timeout_ms = 250u64;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (done_tx, done_rx) = tokio::sync::oneshot::channel::<bool>();
// Server: sends Pings, never receives Pong (client swallows all).
tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (mut sink, mut stream) = ws.split();
let pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
let mut ticker = tokio::time::interval_at(
tokio::time::Instant::now() + Duration::from_millis(ping_ms),
Duration::from_millis(ping_ms),
);
let timed_out = loop {
tokio::select! {
_ = ticker.tick() => {
if tokio::time::Instant::now() >= pong_deadline { break true; }
if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
break false;
}
}
frame = stream.next() => {
match frame {
Some(Ok(_)) => {} // swallow — no Pong sent
_ => break false,
}
}
}
};
let _ = done_tx.send(timed_out);
});
// Client: connect but never respond to Pings.
let (_ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
let result =
tokio::time::timeout(Duration::from_millis(timeout_ms + ping_ms * 3), done_rx).await;
let timed_out = result
.expect("Server must report within expected wall-clock time")
.expect("oneshot intact");
assert!(
timed_out,
"Server must disconnect on keepalive timeout when Pong is withheld"
);
}
#[tokio::test]
async fn keepalive_connection_survives_with_pong_responses() {
use futures::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
let ping_ms = 100u64;
let timeout_ms = 250u64;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (result_tx, result_rx) = tokio::sync::oneshot::channel::<bool>();
// Server: sends Pings, resets deadline on Pong.
tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (mut sink, mut stream) = ws.split();
let mut pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
let mut ticker = tokio::time::interval_at(
tokio::time::Instant::now() + Duration::from_millis(ping_ms),
Duration::from_millis(ping_ms),
);
let timed_out = loop {
tokio::select! {
_ = ticker.tick() => {
if tokio::time::Instant::now() >= pong_deadline { break true; }
if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
break false;
}
}
frame = stream.next() => {
match frame {
Some(Ok(TMsg::Pong(_))) => {
pong_deadline = tokio::time::Instant::now()
+ Duration::from_millis(timeout_ms);
}
None | Some(Err(_)) => break false, // clean close
_ => {}
}
}
}
};
let _ = result_tx.send(timed_out);
});
let (ws_client, _) = connect_async(format!("ws://{addr}")).await.unwrap();
let (mut sink_c, mut stream_c) = ws_client.split();
// Client: respond to every Ping with Pong for several intervals.
let respond_task = tokio::spawn(async move {
while let Some(Ok(msg)) = stream_c.next().await {
if let TMsg::Ping(data) = msg
&& sink_c.send(TMsg::Pong(data)).await.is_err()
{
break;
}
}
});
// Run for a few intervals, then drop the client.
tokio::time::sleep(Duration::from_millis(ping_ms * 3)).await;
respond_task.abort();
let result = tokio::time::timeout(Duration::from_millis(200), result_rx).await;
let timed_out = result.unwrap_or(Ok(false)).unwrap_or(false);
assert!(
!timed_out,
"Server must NOT timeout when the client responds to Pings with Pongs"
);
}
#[tokio::test]
async fn two_node_pong_swallow_causes_disconnect_within_timeout() {
use futures::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as TMsg;
use tokio_tungstenite::{accept_async, connect_async};
let ping_ms = 100u64;
let timeout_ms = 250u64;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Node A (listener): sends Pings, never receives Pong.
let (a_done_tx, a_done_rx) = tokio::sync::oneshot::channel::<bool>();
tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let ws = accept_async(tcp).await.unwrap();
let (mut sink, mut stream) = ws.split();
let pong_deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
let mut ticker = tokio::time::interval_at(
tokio::time::Instant::now() + Duration::from_millis(ping_ms),
Duration::from_millis(ping_ms),
);
let timed_out = loop {
tokio::select! {
_ = ticker.tick() => {
if tokio::time::Instant::now() >= pong_deadline { break true; }
if sink.send(TMsg::Ping(bytes::Bytes::new())).await.is_err() {
break false;
}
}
frame = stream.next() => {
match frame {
Some(Ok(_)) => {} // swallow all frames
_ => break false,
}
}
}
};
let _ = a_done_tx.send(timed_out);
});
// Node B: connects, drains frames silently (swallows Pings, never pongs).
let (ws_b, _) = connect_async(format!("ws://{addr}")).await.unwrap();
let (_sink_b, mut stream_b) = ws_b.split();
tokio::spawn(async move { while let Some(Ok(_)) = stream_b.next().await {} });
let result =
tokio::time::timeout(Duration::from_millis(timeout_ms + ping_ms * 3), a_done_rx).await;
let timed_out = result
.expect("Node A must report within expected wall-clock time")
.expect("channel intact");
assert!(
timed_out,
"Node A must disconnect due to keepalive timeout when Node B swallows Pongs"
);
}
}