diff --git a/aw-datastore/Cargo.toml b/aw-datastore/Cargo.toml
index 73d768b5..51f0a003 100644
--- a/aw-datastore/Cargo.toml
+++ b/aw-datastore/Cargo.toml
@@ -7,6 +7,8 @@ edition = "2021"
 [features]
 default = [] # no features by default
 legacy_import_tests = []
+# Test-only feature to enable panic triggering for testing catch_unwind behavior
+test-panic = []
 
 [dependencies]
 dirs = "6"
diff --git a/aw-datastore/src/worker.rs b/aw-datastore/src/worker.rs
index 18eaf665..81eac165 100644
--- a/aw-datastore/src/worker.rs
+++ b/aw-datastore/src/worker.rs
@@ -78,6 +78,10 @@ pub enum Command {
     SetKeyValue(String, String),
     DeleteKeyValue(String),
     Close(),
+    /// Test-only command to trigger a panic in the worker thread.
+    /// Only available with the `test-panic` feature.
+    #[cfg(feature = "test-panic")]
+    TriggerPanic(String),
 }
 
 fn _unwrap_response(
@@ -294,10 +298,26 @@ impl DatastoreWorker {
                 self.quit = true;
                 Ok(Response::Empty())
             }
+            #[cfg(feature = "test-panic")]
+            Command::TriggerPanic(msg) => {
+                panic!("{}", msg);
+            }
         }
     }
 }
 
+/// Extracts a human-readable message from a panic payload.
+/// Handles the two common payload types (&str and String), with a fallback for other types.
+fn extract_panic_message(panic_info: &Box<dyn std::any::Any + Send>) -> String {
+    if let Some(s) = panic_info.downcast_ref::<&str>() {
+        s.to_string()
+    } else if let Some(s) = panic_info.downcast_ref::<String>() {
+        s.clone()
+    } else {
+        "Unknown panic".to_string()
+    }
+}
+
 impl Datastore {
     pub fn new(dbpath: String, legacy_import: bool) -> Self {
         let method = DatastoreMethod::File(dbpath);
@@ -314,7 +334,22 @@ impl Datastore {
             mpsc_requests::channel::<Command, Result<Response, DatastoreError>>();
         let _thread = thread::spawn(move || {
             let mut di = DatastoreWorker::new(responder, legacy_import);
-            di.work_loop(method);
+            // Wrap work_loop in catch_unwind to handle any unexpected panics gracefully.
+            // This prevents panics from poisoning locks and leaving the server in an
+            // unusable state. Instead, the worker exits cleanly and the channel closes,
+            // allowing clients to receive proper errors instead of "poisoned lock" errors.
+            // See: https://github.com/ActivityWatch/aw-server-rust/issues/405
+            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+                di.work_loop(method);
+            }));
+            if let Err(panic_info) = result {
+                let panic_msg = extract_panic_message(&panic_info);
+                error!(
+                    "Datastore worker panicked: {}. Worker shutting down gracefully.",
+                    panic_msg
+                );
+            }
+            // Worker exits, channel closes, future requests get clean errors
         });
         Datastore { requester }
     }
@@ -527,4 +562,56 @@ impl Datastore {
             Err(e) => panic!("Error closing database: {:?}", e),
         }
     }
+
+    /// Test-only method to trigger a panic in the worker thread.
+    /// Used to verify that catch_unwind properly handles panics.
+    /// Only available with the `test-panic` feature.
+    #[cfg(feature = "test-panic")]
+    pub fn trigger_panic(&self, msg: &str) -> Result<(), DatastoreError> {
+        let cmd = Command::TriggerPanic(msg.to_string());
+        match self.requester.request(cmd) {
+            Ok(receiver) => match receiver.collect() {
+                Ok(result) => match result {
+                    Ok(_) => Ok(()),
+                    Err(e) => Err(e),
+                },
+                Err(_) => Err(DatastoreError::InternalError(
+                    "Channel closed (worker panicked)".to_string(),
+                )),
+            },
+            Err(_) => Err(DatastoreError::InternalError(
+                "Failed to send command (channel closed)".to_string(),
+            )),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::extract_panic_message;
+
+    /// &str panic payload (produced by panic!("literal"))
+    #[test]
+    fn test_extract_panic_message_str_literal() {
+        let panic_info = std::panic::catch_unwind(|| panic!("static str panic")).unwrap_err();
+        assert_eq!(extract_panic_message(&panic_info), "static str panic");
+    }
+
+    /// String panic payload (produced by panic!("{}", expr))
+    #[test]
+    fn test_extract_panic_message_string() {
+        let msg = format!("formatted panic {}", 42);
+        let panic_info =
+            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| panic!("{}", msg)))
+                .unwrap_err();
+        assert_eq!(extract_panic_message(&panic_info), "formatted panic 42");
+    }
+
+    /// Unknown payload type falls back to "Unknown panic"
+    #[test]
+    fn test_extract_panic_message_unknown() {
+        let panic_info =
+            std::panic::catch_unwind(|| std::panic::panic_any(42u32)).unwrap_err();
+        assert_eq!(extract_panic_message(&panic_info), "Unknown panic");
+    }
 }
diff --git a/aw-datastore/tests/datastore.rs b/aw-datastore/tests/datastore.rs
index 739e8ade..9560daa2 100644
--- a/aw-datastore/tests/datastore.rs
+++ b/aw-datastore/tests/datastore.rs
@@ -531,4 +531,116 @@ mod datastore_tests {
             );
         }
     }
+
+    /// Test that when the worker thread panics, it shuts down gracefully
+    /// and subsequent requests receive clean errors instead of "poisoned lock" errors.
+    /// This tests the catch_unwind wrapper around work_loop().
+    #[cfg(feature = "test-panic")]
+    #[test]
+    fn test_worker_panic_graceful_shutdown() {
+        // Setup datastore
+        let ds = Datastore::new_in_memory(false);
+        let bucket = create_test_bucket(&ds);
+
+        // Verify datastore is working normally
+        let fetched_bucket = ds.get_bucket(&bucket.id).unwrap();
+        assert_eq!(fetched_bucket.id, bucket.id);
+
+        // Trigger a panic in the worker thread
+        let panic_result = ds.trigger_panic("Test panic for graceful shutdown");
+
+        // The panic should cause the worker to exit and close the channel.
+        // We may or may not get an error from the trigger_panic call itself
+        // depending on timing (whether the response was sent before the panic).
+        info!("trigger_panic result: {:?}", panic_result);
+
+        // Give the worker thread a moment to fully shut down
+        std::thread::sleep(std::time::Duration::from_millis(100));
+
+        // Subsequent requests will fail. The current API uses .unwrap() which panics
+        // on closed channels, so we need to catch the panic and verify it's NOT
+        // a "poisoned lock" error (which would indicate catch_unwind didn't work).
+        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+            ds.get_buckets()
+        }));
+
+        // We expect either:
+        // 1. A panic with "SendError" (channel closed) - this is expected
+        // 2. An Ok(Err(...)) with a clean error - this is also fine
+        // What we must NOT see is a "poisoned" error which would mean catch_unwind failed
+        match result {
+            Ok(inner_result) => {
+                // If we got through without panic, check the error doesn't mention poison
+                assert!(inner_result.is_err(), "Expected error after worker panic");
+                if let Err(e) = inner_result {
+                    let err_msg = format!("{:?}", e);
+                    assert!(
+                        !err_msg.to_lowercase().contains("poison"),
+                        "Error should not be a poisoned lock error, got: {}",
+                        err_msg
+                    );
+                    info!("Got expected clean error: {}", err_msg);
+                }
+            }
+            Err(panic_info) => {
+                // We got a panic - extract the message and verify it's not about poisoned locks
+                let panic_msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
+                    s.to_string()
+                } else if let Some(s) = panic_info.downcast_ref::<String>() {
+                    s.clone()
+                } else {
+                    format!("{:?}", panic_info)
+                };
+                assert!(
+                    !panic_msg.to_lowercase().contains("poison"),
+                    "Panic should not be about poisoned lock, got: {}",
+                    panic_msg
+                );
+                info!(
+                    "Got expected panic (channel closed, not poisoned lock): {}",
+                    panic_msg
+                );
+            }
+        }
+    }
+
+    /// Test that a worker panic carrying a `String` payload leaves the datastore unusable but not poisoned.
+    #[cfg(feature = "test-panic")]
+    #[test]
+    fn test_worker_panic_with_string_message() {
+        let ds = Datastore::new_in_memory(false);
+
+        // The trigger_panic method accepts a &str which gets converted to String,
+        // so this tests the String panic message extraction path.
+        let _ = ds.trigger_panic("String panic message test");
+
+        std::thread::sleep(std::time::Duration::from_millis(100));
+
+        // Verify the datastore is no longer usable (channel closed)
+        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+            ds.get_buckets()
+        }));
+
+        // Should fail in some way (panic or error), just not with "poisoned"
+        match result {
+            Ok(inner_result) => {
+                assert!(inner_result.is_err(), "Expected error after worker panic");
+            }
+            Err(panic_info) => {
+                // Got a panic (expected behavior with current API that uses unwrap)
+                let panic_msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
+                    s.to_string()
+                } else if let Some(s) = panic_info.downcast_ref::<String>() {
+                    s.clone()
+                } else {
+                    format!("{:?}", panic_info)
+                };
+                assert!(
+                    !panic_msg.to_lowercase().contains("poison"),
+                    "Panic should not be about poisoned lock, got: {}",
+                    panic_msg
+                );
+            }
+        }
+    }
 }