Files
huskies/server/src/crdt_state/state.rs
T

536 lines
18 KiB
Rust
Raw Normal View History

//! Internal CRDT state struct, statics, initialisation, and central write primitive.
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Mutex, OnceLock};
use bft_json_crdt::json_crdt::*;
use bft_json_crdt::keypair::make_keypair;
use fastcrypto::ed25519::Ed25519KeyPair;
use fastcrypto::traits::ToFromBytes;
use serde_json::json;
use sqlx::SqlitePool;
use sqlx::sqlite::SqliteConnectOptions;
use tokio::sync::{broadcast, mpsc};
use super::VectorClock;
use super::hex;
use super::types::{CrdtEvent, PipelineDoc};
use crate::slog;
// ── Sync broadcast channels ──────────────────────────────────────────
/// Broadcast sender for [`CrdtEvent`]s.
///
/// Unset until `init` / `init_for_test` installs a channel; `emit_event`
/// publishes through it and does nothing while it is unset.
pub(super) static CRDT_EVENT_TX: OnceLock<broadcast::Sender<CrdtEvent>> = OnceLock::new();
/// Broadcast sender for signed ops produced locally.
///
/// Unset until `init` / `init_for_test` installs a channel; `apply_and_persist`
/// sends every op it signs through it.
pub(super) static SYNC_TX: OnceLock<broadcast::Sender<SignedOp>> = OnceLock::new();
/// All persisted ops as JSON strings, in causal (insertion) order.
///
/// `pub(crate)` so that `crdt_snapshot` can access it for compaction.
pub(crate) static ALL_OPS: OnceLock<Mutex<Vec<String>>> = OnceLock::new();
/// Live vector clock tracking op counts per author.
///
/// Updated in lockstep with `ALL_OPS` — every time an op is appended to the
/// journal, the corresponding author's count is incremented here. This avoids
/// re-parsing all ops when a peer requests `our_vector_clock()`.
pub(crate) static VECTOR_CLOCK: OnceLock<Mutex<super::VectorClock>> = OnceLock::new();
/// Record a journaled op in both bookkeeping statics.
///
/// Pushes `json` onto `ALL_OPS`, then increments the entry for the op's
/// author (hex-encoded) in `VECTOR_CLOCK`. Keeping both updates in one place
/// is what keeps the two statics in step with each other.
///
/// Either update is silently skipped when its static has not been initialised
/// or when its mutex cannot be locked. The two locks are never held at the
/// same time.
pub(super) fn track_op(signed: &SignedOp, json: String) {
    // Journal first; the guard is released at the end of this `if`.
    if let Some(mut journal) = ALL_OPS.get().and_then(|ops| ops.lock().ok()) {
        journal.push(json);
    }

    // Then bump the per-author counter.
    if let Some(mut counts) = VECTOR_CLOCK.get().and_then(|vc| vc.lock().ok()) {
        let author_key = super::hex::encode(&signed.author());
        *counts.entry(author_key).or_insert(0) += 1;
    }
}
/// In-memory CRDT state shared by the read and write paths of this module.
///
/// Bundles the CRDT document with the signing keypair, the lookup indices
/// derived from the document, and the handle used to queue ops for persistence.
pub(super) struct CrdtState {
/// The CRDT document itself; ops are applied to it via `crdt.apply`.
pub(super) crdt: BaseCrdt<PipelineDoc>,
/// Keypair used to sign ops created on this node (see `apply_and_persist`).
pub(super) keypair: Ed25519KeyPair,
/// Maps story_id → index in the items ListCrdt for O(1) lookup.
pub(super) index: HashMap<String, usize>,
/// Maps node_id (hex) → index in the nodes ListCrdt for O(1) lookup.
pub(super) node_index: HashMap<String, usize>,
/// Channel sender for fire-and-forget op persistence.
pub(super) persist_tx: mpsc::UnboundedSender<SignedOp>,
}
/// Process-wide CRDT state, installed once by `init`.
static CRDT_STATE: OnceLock<Mutex<CrdtState>> = OnceLock::new();
// Per-thread CRDT state used only in test builds. `init_for_test` fills it and
// the test build of `get_crdt` prefers it over `CRDT_STATE`, so each test
// thread works against its own document.
#[cfg(test)]
thread_local! {
static CRDT_STATE_TL: OnceLock<Mutex<CrdtState>> = const { OnceLock::new() };
}
/// Return the process-wide CRDT state, or `None` if `init` has not stored it yet.
#[cfg(not(test))]
pub(super) fn get_crdt() -> Option<&'static Mutex<CrdtState>> {
CRDT_STATE.get()
}
/// Test-build lookup of the CRDT state.
///
/// Prefers the calling thread's `CRDT_STATE_TL` when it has been initialised
/// (see `init_for_test`) and otherwise falls back to the process-wide
/// `CRDT_STATE`.
#[cfg(test)]
pub(super) fn get_crdt() -> Option<&'static Mutex<CrdtState>> {
    // A reference cannot escape `LocalKey::with`, so smuggle out a raw pointer
    // to the initialised mutex instead.
    let thread_state: Option<*const Mutex<CrdtState>> = CRDT_STATE_TL
        .with(|cell| cell.get().map(|mutex| mutex as *const Mutex<CrdtState>));

    match thread_state {
        // SAFETY: The thread-local lives as long as the thread, which outlives
        // any test using it, and a `OnceLock` that has been set is never
        // cleared through a shared reference. We only need 'static for the
        // return type; the reference must not be used after the thread exits.
        Some(ptr) => Some(unsafe { &*ptr }),
        None => CRDT_STATE.get(),
    }
}
/// Initialise the CRDT state layer.
///
/// Opens the SQLite database, loads or creates a node keypair, replays any
/// persisted ops to reconstruct state, and spawns a background persistence
/// task. Safe to call only once; subsequent calls are no-ops.
pub async fn init(db_path: &Path) -> Result<(), sqlx::Error> {
if CRDT_STATE.get().is_some() {
return Ok(());
}
let options = SqliteConnectOptions::new()
.filename(db_path)
.create_if_missing(true);
let pool = SqlitePool::connect_with(options).await?;
sqlx::migrate!("./migrations").run(&pool).await?;
// Load or create the node keypair.
let keypair = load_or_create_keypair(&pool).await?;
let mut crdt = BaseCrdt::<PipelineDoc>::new(&keypair);
// Replay persisted ops to reconstruct state.
let rows: Vec<(String,)> = sqlx::query_as("SELECT op_json FROM crdt_ops ORDER BY rowid ASC")
.fetch_all(&pool)
.await?;
let mut all_ops_vec = Vec::with_capacity(rows.len());
let mut vector_clock = VectorClock::new();
for (op_json,) in &rows {
if let Ok(signed_op) = serde_json::from_str::<SignedOp>(op_json) {
let author_hex = hex::encode(&signed_op.author());
*vector_clock.entry(author_hex).or_insert(0) += 1;
crdt.apply(signed_op);
all_ops_vec.push(op_json.clone());
} else {
slog!("[crdt] Warning: failed to deserialize stored op");
}
}
let _ = ALL_OPS.set(Mutex::new(all_ops_vec));
let _ = VECTOR_CLOCK.set(Mutex::new(vector_clock));
// Build the indices from the reconstructed state.
let index = rebuild_index(&crdt);
let node_index = rebuild_node_index(&crdt);
slog!(
"[crdt] Initialised: {} ops replayed, {} items indexed, {} nodes indexed",
rows.len(),
index.len(),
node_index.len()
);
// Spawn background persistence task.
let (persist_tx, mut persist_rx) = mpsc::unbounded_channel::<SignedOp>();
tokio::spawn(async move {
while let Some(op) = persist_rx.recv().await {
let op_json = match serde_json::to_string(&op) {
Ok(j) => j,
Err(e) => {
slog!("[crdt] Failed to serialize op: {e}");
continue;
}
};
let op_id = hex::encode(&op.id());
let seq = op.inner.seq as i64;
let now = chrono::Utc::now().to_rfc3339();
let result = sqlx::query(
"INSERT INTO crdt_ops (op_id, seq, op_json, created_at) \
VALUES (?1, ?2, ?3, ?4) \
ON CONFLICT(op_id) DO NOTHING",
)
.bind(&op_id)
.bind(seq)
.bind(&op_json)
.bind(&now)
.execute(&pool)
.await;
if let Err(e) = result {
slog!("[crdt] Failed to persist op {}: {e}", &op_id[..12]);
}
}
});
let state = CrdtState {
crdt,
keypair,
index,
node_index,
persist_tx,
};
let _ = CRDT_STATE.set(Mutex::new(state));
// Initialise the CRDT event broadcast channel.
let (event_tx, _) = broadcast::channel::<CrdtEvent>(256);
let _ = CRDT_EVENT_TX.set(event_tx);
// Initialise the sync broadcast channel for outgoing ops.
let (sync_tx, _) = broadcast::channel::<SignedOp>(1024);
let _ = SYNC_TX.set(sync_tx);
Ok(())
}
/// Set up a minimal in-memory CRDT state for unit tests.
///
/// Skips the async SQLite setup done by `init()`. The persistence channel's
/// receiver is dropped right away, so ops sent to it are never persisted.
/// Calling this repeatedly is fine: the thread-local state and every shared
/// static are only created the first time (OnceLock).
#[cfg(test)]
pub fn init_for_test() {
    // Per-thread CRDT for test isolation: the first call on a thread builds a
    // fresh document, later calls on that thread reuse it.
    CRDT_STATE_TL.with(|cell| {
        cell.get_or_init(|| {
            let keypair = make_keypair();
            let crdt = BaseCrdt::<PipelineDoc>::new(&keypair);
            // `_rx` is dropped when this closure returns, so nothing is persisted.
            let (persist_tx, _rx) = mpsc::unbounded_channel();
            Mutex::new(CrdtState {
                crdt,
                keypair,
                index: HashMap::new(),
                node_index: HashMap::new(),
                persist_tx,
            })
        });
    });

    // Shared statics: created once per process, reused afterwards.
    let _ = CRDT_EVENT_TX.get_or_init(|| {
        let (event_tx, _) = broadcast::channel::<CrdtEvent>(256);
        event_tx
    });
    let _ = SYNC_TX.get_or_init(|| {
        let (sync_tx, _) = broadcast::channel::<SignedOp>(1024);
        sync_tx
    });
    let _ = ALL_OPS.get_or_init(|| Mutex::new(Vec::new()));
    let _ = VECTOR_CLOCK.get_or_init(|| Mutex::new(VectorClock::new()));
}
/// Fetch this node's Ed25519 keypair from the database, creating one if needed.
///
/// Reads the seed stored under `id = 1` in `crdt_node_identity`. If no row
/// exists, or the stored bytes cannot be turned back into a keypair, a new
/// keypair is generated and its bytes are upserted into that row.
///
/// # Errors
///
/// Returns the `sqlx::Error` from either the lookup or the upsert.
async fn load_or_create_keypair(pool: &SqlitePool) -> Result<Ed25519KeyPair, sqlx::Error> {
    let stored: Option<(Vec<u8>,)> =
        sqlx::query_as("SELECT seed FROM crdt_node_identity WHERE id = 1")
            .fetch_optional(pool)
            .await?;

    if let Some((stored_seed,)) = stored {
        // Reconstruct from stored seed. The seed is the 32-byte private key.
        match Ed25519KeyPair::from_bytes(&stored_seed) {
            Ok(existing) => return Ok(existing),
            Err(_) => {
                slog!("[crdt] Stored keypair invalid, regenerating");
            }
        }
    }

    // Nothing usable on disk: mint a new identity and store it, replacing any
    // invalid seed already in the row.
    let fresh = make_keypair();
    let fresh_seed = fresh.as_bytes().to_vec();
    sqlx::query(
        "INSERT INTO crdt_node_identity (id, seed) VALUES (1, ?1) ON CONFLICT(id) DO UPDATE SET seed = excluded.seed",
    )
    .bind(&fresh_seed)
    .execute(pool)
    .await?;
    Ok(fresh)
}
/// Derive the story_id → list position lookup table from the CRDT document.
///
/// Walks `crdt.doc.items` in order and records the position of every item
/// whose `story_id` views as a JSON string; other items are left out. If two
/// items share a story_id, the later position wins.
pub(super) fn rebuild_index(crdt: &BaseCrdt<PipelineDoc>) -> HashMap<String, usize> {
    crdt.doc
        .items
        .iter()
        .enumerate()
        .filter_map(|(position, item)| match item.story_id.view() {
            JsonValue::String(ref story_id) => Some((story_id.clone(), position)),
            _ => None,
        })
        .collect()
}
/// Derive the node_id → list position lookup table from the CRDT document.
///
/// Walks `crdt.doc.nodes` in order and records the position of every node
/// whose `node_id` views as a JSON string; other nodes are left out. If two
/// nodes share a node_id, the later position wins.
pub(super) fn rebuild_node_index(crdt: &BaseCrdt<PipelineDoc>) -> HashMap<String, usize> {
    crdt.doc
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(position, node)| match node.node_id.view() {
            JsonValue::String(ref node_id) => Some((node_id.clone(), position)),
            _ => None,
        })
        .collect()
}
// ── Write path ───────────────────────────────────────────────────────
/// Create a CRDT op via `op_fn`, sign it, apply it, and send it to the
/// persistence channel. The closure receives `&mut CrdtState` so it can
/// mutably access the CRDT document, while `sign` only needs `&keypair`.
///
/// After applying the op locally this also records it in `ALL_OPS` /
/// `VECTOR_CLOCK` (via `track_op`) and broadcasts it on `SYNC_TX` if that
/// channel exists.
///
/// Every step after the local apply is best-effort: a failed hand-off to the
/// persist task or a failed serialization for the journal is logged as an
/// error, and a failed broadcast send is ignored. None of them aborts the
/// write.
pub(super) fn apply_and_persist<F>(state: &mut CrdtState, op_fn: F)
where
    F: FnOnce(&mut CrdtState) -> bft_json_crdt::op::Op<JsonValue>,
{
    let raw_op = op_fn(state);
    let signed = raw_op.sign(&state.keypair);
    state.crdt.apply(signed.clone());

    if let Err(e) = state.persist_tx.send(signed.clone()) {
        crate::slog_error!(
            "[crdt] Failed to send op to persist task: {e}; persist task may be dead. \
             In-memory state is now ahead of persisted state."
        );
    }

    // Track in ALL_OPS + VECTOR_CLOCK, then broadcast to sync peers.
    match serde_json::to_string(&signed) {
        Ok(json) => track_op(&signed, json),
        Err(e) => {
            // Previously swallowed: without this the journal and vector clock
            // would silently fall behind the in-memory document.
            crate::slog_error!(
                "[crdt] Failed to serialize op for journal tracking: {e}; \
                 ALL_OPS and VECTOR_CLOCK are now behind in-memory state."
            );
        }
    }

    if let Some(tx) = SYNC_TX.get() {
        // The result is deliberately ignored; the write must not fail on it.
        let _ = tx.send(signed);
    }
}
/// Publish `event` on the CRDT event broadcast channel.
///
/// Does nothing when `CRDT_EVENT_TX` has not been initialised, and ignores the
/// result of the send.
pub(super) fn emit_event(event: CrdtEvent) {
    let Some(sender) = CRDT_EVENT_TX.get() else {
        return;
    };
    let _ = sender.send(event);
}
// Unit tests for op replay, index rebuilding, the SQLite persistence
// round-trip, and error logging when the persist channel is closed.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): `PipelineItemCrdt`, `write_item`, `read_item` and `OpState`
// are not referenced by any test visible in this module — confirm they are
// still needed.
use super::super::types::PipelineItemCrdt;
use super::super::write::write_item;
use super::super::read::{extract_item_view, read_item};
use bft_json_crdt::json_crdt::OpState;
use bft_json_crdt::keypair::make_keypair;
use bft_json_crdt::op::ROOT_ID;
use super::super::hex;
use serde_json::json;
/// Applying the same signed ops, in the same order, to a fresh CRDT yields
/// the same `stage` and `name` views as the CRDT that produced them.
#[test]
fn crdt_ops_replay_reconstructs_state() {
let kp = make_keypair();
let mut crdt1 = BaseCrdt::<PipelineDoc>::new(&kp);
// Build state with a series of ops.
let item_json: JsonValue = json!({
"story_id": "30_story_replay",
"stage": "1_backlog",
"name": "Replay Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op1 = crdt1.doc.items.insert(ROOT_ID, item_json).sign(&kp);
crdt1.apply(op1.clone());
let op2 = crdt1.doc.items[0]
.stage
.set("2_current".to_string())
.sign(&kp);
crdt1.apply(op2.clone());
let op3 = crdt1.doc.items[0]
.name
.set("Updated Name".to_string())
.sign(&kp);
crdt1.apply(op3.clone());
// Replay ops on a fresh CRDT.
let mut crdt2 = BaseCrdt::<PipelineDoc>::new(&kp);
crdt2.apply(op1);
crdt2.apply(op2);
crdt2.apply(op3);
assert_eq!(
crdt1.doc.items[0].stage.view(),
crdt2.doc.items[0].stage.view()
);
assert_eq!(
crdt1.doc.items[0].name.view(),
crdt2.doc.items[0].name.view()
);
}
/// `rebuild_index` produces one entry per inserted item, keyed by story_id.
#[test]
fn rebuild_index_maps_story_ids() {
let kp = make_keypair();
let mut crdt = BaseCrdt::<PipelineDoc>::new(&kp);
for (sid, stage) in &[("10_story_a", "1_backlog"), ("20_story_b", "2_current")] {
let item: JsonValue = json!({
"story_id": sid,
"stage": stage,
"name": "",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let op = crdt.doc.items.insert(ROOT_ID, item).sign(&kp);
crdt.apply(op);
}
let index = rebuild_index(&crdt);
assert_eq!(index.len(), 2);
assert!(index.contains_key("10_story_a"));
assert!(index.contains_key("20_story_b"));
}
/// Writes one signed op to a temporary SQLite database, reads it back, and
/// checks that replaying it on a fresh CRDT reproduces the item.
#[tokio::test]
async fn init_and_write_read_roundtrip() {
let tmp = tempfile::tempdir().unwrap();
let db_path = tmp.path().join("crdt_test.db");
// Init directly (not via the global singleton, for test isolation).
let options = SqliteConnectOptions::new()
.filename(&db_path)
.create_if_missing(true);
let pool = SqlitePool::connect_with(options).await.unwrap();
sqlx::migrate!("./migrations").run(&pool).await.unwrap();
let keypair = make_keypair();
let mut crdt = BaseCrdt::<PipelineDoc>::new(&keypair);
// Insert and update like write_item does.
let item_json: JsonValue = json!({
"story_id": "50_story_roundtrip",
"stage": "1_backlog",
"name": "Roundtrip",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
let insert_op = crdt.doc.items.insert(ROOT_ID, item_json).sign(&keypair);
crdt.apply(insert_op.clone());
// Persist the op.
let op_json = serde_json::to_string(&insert_op).unwrap();
let op_id = hex::encode(&insert_op.id());
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO crdt_ops (op_id, seq, op_json, created_at) VALUES (?1, ?2, ?3, ?4)",
)
.bind(&op_id)
.bind(insert_op.inner.seq as i64)
.bind(&op_json)
.bind(&now)
.execute(&pool)
.await
.unwrap();
// Reconstruct from DB.
let rows: Vec<(String,)> =
sqlx::query_as("SELECT op_json FROM crdt_ops ORDER BY rowid ASC")
.fetch_all(&pool)
.await
.unwrap();
let mut crdt2 = BaseCrdt::<PipelineDoc>::new(&keypair);
for (json_str,) in &rows {
let op: SignedOp = serde_json::from_str(json_str).unwrap();
crdt2.apply(op);
}
let view = extract_item_view(&crdt2.doc.items[0]).unwrap();
assert_eq!(view.story_id, "50_story_roundtrip");
assert_eq!(view.stage, "1_backlog");
assert_eq!(view.name.as_deref(), Some("Roundtrip"));
}
/// With the persist receiver dropped, `apply_and_persist` must add an
/// ERROR log entry that mentions persistence and the resulting divergence.
#[test]
fn persist_tx_send_failure_logs_error() {
let kp = make_keypair();
let crdt = BaseCrdt::<PipelineDoc>::new(&kp);
let (persist_tx, persist_rx) = mpsc::unbounded_channel::<SignedOp>();
let mut state = CrdtState {
crdt,
keypair: kp,
index: HashMap::new(),
node_index: HashMap::new(),
persist_tx,
};
// Drop the receiver so that the next send fails immediately.
drop(persist_rx);
let item_json: JsonValue = json!({
"story_id": "518_story_persist_fail",
"stage": "1_backlog",
"name": "Persist Fail Test",
"agent": "",
"retry_count": 0.0,
"blocked": false,
"depends_on": "",
"claimed_by": "",
"claimed_at": 0.0,
})
.into();
// Snapshot the error count first so entries logged earlier are not
// mistaken for the one this call should add.
let before_errors = crate::log_buffer::global()
.get_recent_entries(1000, None, Some(&crate::log_buffer::LogLevel::Error))
.len();
apply_and_persist(&mut state, |s| s.crdt.doc.items.insert(ROOT_ID, item_json));
let error_entries = crate::log_buffer::global().get_recent_entries(
1000,
None,
Some(&crate::log_buffer::LogLevel::Error),
);
assert!(
error_entries.len() > before_errors,
"expected an ERROR log entry when persist_tx send fails, but none was added"
);
// NOTE(review): assumes the newest entry is last in the returned list — confirm
// against `log_buffer`.
let last_error = &error_entries[error_entries.len() - 1];
assert!(
last_error.message.contains("persist"),
"error message should mention persist: {}",
last_error.message
);
assert!(
last_error.message.contains("ahead") || last_error.message.contains("diverged"),
"error message should note in-memory/persisted divergence: {}",
last_error.message
);
}
}