Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

198 changes: 185 additions & 13 deletions packages/sqlite-web-core/src/coordination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@ use std::rc::Rc;
use uuid::Uuid;
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast;
use wasm_bindgen_futures::spawn_local;
use web_sys::BroadcastChannel;
use wasm_bindgen_futures::{spawn_local, JsFuture};
use web_sys::{BroadcastChannel, DedicatedWorkerGlobalScope};

use crate::database::SQLiteDatabase;
use crate::messages::{ChannelMessage, PendingQuery};
use crate::messages::{ChannelMessage, PendingQuery, WORKER_ERROR_TYPE_INITIALIZATION_PENDING};
use crate::util::{js_value_to_string, sanitize_identifier, set_js_property};

// Worker state
pub struct WorkerState {
pub worker_id: String,
pub is_leader: Rc<RefCell<bool>>,
pub has_leader: Rc<RefCell<bool>>,
pub db: Rc<RefCell<Option<SQLiteDatabase>>>,
pub channel: BroadcastChannel,
pub db_name: String,
pub pending_queries: Rc<RefCell<HashMap<String, PendingQuery>>>,
pub follower_timeout_ms: f64,
}

fn reflect_get(target: &JsValue, key: &str) -> Result<JsValue, JsValue> {
Expand All @@ -40,6 +42,32 @@ fn send_channel_message(
})
}

fn post_worker_message(obj: &js_sys::Object) -> Result<(), String> {
let global = js_sys::global();
let scope: DedicatedWorkerGlobalScope = global
.dyn_into()
.map_err(|_| "Failed to access worker scope".to_string())?;
scope
.post_message(obj.as_ref())
.map_err(|err| js_value_to_string(&err))
}

fn send_worker_ready_message() -> Result<(), String> {
let message = js_sys::Object::new();
set_js_property(&message, "type", &JsValue::from_str("worker-ready"))
.map_err(|err| js_value_to_string(&err))?;
post_worker_message(&message)
}

fn send_worker_error_message(error: &str) -> Result<(), String> {
let message = js_sys::Object::new();
set_js_property(&message, "type", &JsValue::from_str("worker-error"))
.map_err(|err| js_value_to_string(&err))?;
set_js_property(&message, "error", &JsValue::from_str(error))
.map_err(|err| js_value_to_string(&err))?;
post_worker_message(&message)
}

impl WorkerState {
pub fn new() -> Result<Self, JsValue> {
fn get_db_name_from_global() -> Result<String, JsValue> {
Expand All @@ -62,31 +90,56 @@ impl WorkerState {
}
}

fn get_follower_timeout_from_global() -> f64 {
let global = js_sys::global();
let val = Reflect::get(&global, &JsValue::from_str("__SQLITE_FOLLOWER_TIMEOUT_MS"))
.unwrap_or(JsValue::UNDEFINED);
if let Some(n) = val.as_f64() {
if n.is_finite() && n >= 0.0 {
return n;
}
}
5000.0
}

let worker_id = Uuid::new_v4().to_string();
let db_name_raw = get_db_name_from_global()?;
let channel_name = format!("sqlite-queries-{}", sanitize_identifier(&db_name_raw));
let channel = BroadcastChannel::new(&channel_name)?;
let follower_timeout_ms = get_follower_timeout_from_global();

Ok(WorkerState {
worker_id,
is_leader: Rc::new(RefCell::new(false)),
has_leader: Rc::new(RefCell::new(false)),
db: Rc::new(RefCell::new(None)),
channel,
db_name: db_name_raw,
pending_queries: Rc::new(RefCell::new(HashMap::new())),
follower_timeout_ms,
})
}

pub fn setup_channel_listener(&self) -> Result<(), JsValue> {
let is_leader = Rc::clone(&self.is_leader);
let has_leader = Rc::clone(&self.has_leader);
let db = Rc::clone(&self.db);
let pending_queries = Rc::clone(&self.pending_queries);
let channel = self.channel.clone();
let worker_id = self.worker_id.clone();

let onmessage = Closure::wrap(Box::new(move |event: web_sys::MessageEvent| {
let data = event.data();
if let Ok(msg) = serde_wasm_bindgen::from_value::<ChannelMessage>(data) {
handle_channel_message(&is_leader, &db, &channel, &pending_queries, msg);
handle_channel_message(
&is_leader,
&has_leader,
&db,
&channel,
&pending_queries,
&worker_id,
msg,
);
}
}) as Box<dyn FnMut(web_sys::MessageEvent)>);

Expand All @@ -96,9 +149,48 @@ impl WorkerState {
Ok(())
}

pub fn start_leader_probe(self: &Rc<Self>) {
if *self.is_leader.borrow() {
return;
}
let has_leader = Rc::clone(&self.has_leader);
let channel = self.channel.clone();
let worker_id = self.worker_id.clone();
let follower_timeout_ms = self.follower_timeout_ms;
spawn_local(async move {
const POLL_INTERVAL_MS: f64 = 250.0;
let max_attempts = if follower_timeout_ms.is_finite() && follower_timeout_ms > 0.0 {
(follower_timeout_ms / POLL_INTERVAL_MS).ceil() as u32
} else {
1
};
let mut attempts = 0;
while attempts < max_attempts {
attempts += 1;
if *has_leader.borrow() {
break;
}
let ping = ChannelMessage::LeaderPing {
requester_id: worker_id.clone(),
};
if let Err(err_msg) = send_channel_message(&channel, &ping) {
let _ = send_worker_error_message(&err_msg);
break;
}
sleep_ms(POLL_INTERVAL_MS as i32).await;
}
if !*has_leader.borrow() {
let timeout = follower_timeout_ms.max(0.0);
let message = format!("Leader election timed out after {:.0}ms", timeout);
let _ = send_worker_error_message(&message);
}
});
}

pub async fn attempt_leadership(&self) -> Result<(), JsValue> {
let worker_id = self.worker_id.clone();
let is_leader = Rc::clone(&self.is_leader);
let has_leader = Rc::clone(&self.has_leader);
let db = Rc::clone(&self.db);
let channel = self.channel.clone();
let db_name_for_handler = self.db_name.clone();
Expand All @@ -113,16 +205,19 @@ impl WorkerState {

let handler = Closure::once(move |_lock: JsValue| -> Promise {
*is_leader.borrow_mut() = true;
*has_leader.borrow_mut() = true;

let db = Rc::clone(&db);
let channel = channel.clone();
let worker_id = worker_id.clone();
let db_name = db_name_for_handler.clone();
let has_leader_inner = Rc::clone(&has_leader);

spawn_local(async move {
match SQLiteDatabase::initialize_opfs(&db_name).await {
Ok(database) => {
*db.borrow_mut() = Some(database);
*has_leader_inner.borrow_mut() = true;

let msg = ChannelMessage::NewLeader {
leader_id: worker_id.clone(),
Expand All @@ -135,8 +230,15 @@ impl WorkerState {
};
let _ = send_channel_message(&channel, &fallback);
}
if let Err(err_msg) = send_worker_ready_message() {
let _ = send_worker_error_message(&err_msg);
}
}
Err(err) => {
let msg = js_value_to_string(&err);
*has_leader_inner.borrow_mut() = false;
let _ = send_worker_error_message(&msg);
}
Err(_e) => {}
}
});

Expand Down Expand Up @@ -169,6 +271,9 @@ impl WorkerState {
if *self.is_leader.borrow() {
exec_on_db(Rc::clone(&self.db), sql, params).await
} else {
if !*self.has_leader.borrow() {
return Err(WORKER_ERROR_TYPE_INITIALIZATION_PENDING.to_string());
}
let query_id = Uuid::new_v4().to_string();

let promise = Promise::new(&mut |resolve, reject| {
Expand All @@ -182,7 +287,7 @@ impl WorkerState {
let timeout_promise = schedule_timeout_promise(
Rc::clone(&self.pending_queries),
query_id.clone(),
5000.0,
self.follower_timeout_ms,
);

let result = wasm_bindgen_futures::JsFuture::from(js_sys::Promise::race(
Expand All @@ -194,17 +299,19 @@ impl WorkerState {
Ok(val) => val
.as_string()
.ok_or_else(|| "Invalid response".to_string()),
Err(e) => Err(format!("{e:?}")),
Err(e) => Err(js_value_to_string(&e)),
}
}
}
}

fn handle_channel_message(
is_leader: &Rc<RefCell<bool>>,
has_leader: &Rc<RefCell<bool>>,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's best to avoid boolean blindness

https://medium.com/@itsme.mittal/boolean-blindness-60937910e40e

db: &Rc<RefCell<Option<SQLiteDatabase>>>,
channel: &BroadcastChannel,
pending_queries: &Rc<RefCell<HashMap<String, PendingQuery>>>,
worker_id: &str,
msg: ChannelMessage,
) {
match msg {
Expand Down Expand Up @@ -235,7 +342,28 @@ fn handle_channel_message(
result,
error,
} => handle_query_response(pending_queries, query_id, result, error),
ChannelMessage::NewLeader { leader_id: _ } => {}
ChannelMessage::NewLeader { leader_id: _ } => {
let mut has_leader_ref = has_leader.borrow_mut();
let already_had_leader = *has_leader_ref;
*has_leader_ref = true;
drop(has_leader_ref);

if !already_had_leader {
if let Err(err_msg) = send_worker_ready_message() {
let _ = send_worker_error_message(&err_msg);
}
}
}
ChannelMessage::LeaderPing { requester_id: _ } => {
if *is_leader.borrow() {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: would inlining the if not work?

Suggested change
ChannelMessage::LeaderPing { requester_id: _ } => {
if *is_leader.borrow() {
ChannelMessage::LeaderPing { requester_id: _ } if *is_leader.borrow() => {

let response = ChannelMessage::NewLeader {
leader_id: worker_id.to_string(),
};
if let Err(err_msg) = send_channel_message(channel, &response) {
let _ = send_worker_error_message(&err_msg);
}
}
}
}
}

Expand Down Expand Up @@ -273,6 +401,46 @@ fn handle_query_response(
}
}

async fn sleep_ms(ms: i32) {
let promise = js_sys::Promise::new(&mut |resolve, _| {
let resolve_for_timeout = resolve.clone();
let closure = Closure::once(move || {
let _ = resolve_for_timeout.call0(&JsValue::NULL);
});

let timeout_result = js_sys::global()
.dyn_into::<DedicatedWorkerGlobalScope>()
.map_err(|_| ())
.and_then(|scope| {
scope
.set_timeout_with_callback_and_timeout_and_arguments_0(
closure.as_ref().unchecked_ref(),
ms,
)
.map(|_| ())
.map_err(|_| ())
})
.or_else(|_| {
web_sys::window().ok_or(()).and_then(|win| {
win.set_timeout_with_callback_and_timeout_and_arguments_0(
closure.as_ref().unchecked_ref(),
ms,
)
.map(|_| ())
.map_err(|_| ())
})
});

if timeout_result.is_err() {
// As a best-effort fallback, resolve immediately.
let _ = resolve.call0(&JsValue::NULL);
}

closure.forget();
});
let _ = JsFuture::from(promise).await;
}

async fn exec_on_db(
db: Rc<RefCell<Option<SQLiteDatabase>>>,
sql: String,
Expand Down Expand Up @@ -501,12 +669,11 @@ mod tests {
.execute_query("SELECT 1".to_string(), None)
.await;
match result {
Err(msg) => assert!(
msg.contains("timeout") || msg.contains("Query timeout"),
"Follower should timeout, got: {}",
msg
Err(msg) => assert_eq!(
msg, WORKER_ERROR_TYPE_INITIALIZATION_PENDING,
"Follower should reject while leader is pending"
),
Ok(_) => panic!("Expected timeout error for follower"),
Ok(_) => panic!("Expected initialization error for follower"),
}
}
}
Expand Down Expand Up @@ -577,4 +744,9 @@ mod tests {
);
}
}

#[wasm_bindgen_test(async)]
async fn test_sleep_ms_completes() {
sleep_ms(0).await;
}
}
Loading