diff --git a/CHANGELOG.md b/CHANGELOG.md index d86ca7022f..d63d8b2a5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ ### Logins - add checkpoint API: `set_checkpoint(checkpoint)` and `get_checkpoint()` for desktop's rolling migration +- add `count()` method to return the number of logins [Full Changelog](In progress) diff --git a/components/logins/src/db.rs b/components/logins/src/db.rs index 2319c1e4ad..d3b86f7382 100644 --- a/components/logins/src/db.rs +++ b/components/logins/src/db.rs @@ -138,6 +138,12 @@ impl LoginDb { rows.collect::>() } + pub fn count_all(&self) -> Result { + let mut stmt = self.db.prepare_cached(&COUNT_ALL_SQL)?; + let count: i64 = stmt.query_row([], |row| row.get(0))?; + Ok(count) + } + pub fn get_by_base_domain(&self, base_domain: &str) -> Result> { // We first parse the input string as a host so it is normalized. let base_host = match Host::parse(base_domain) { @@ -817,6 +823,13 @@ lazy_static! { SELECT {common_cols} FROM loginsM WHERE is_overridden = 0", common_cols = schema::COMMON_COLS, ); + static ref COUNT_ALL_SQL: String = format!( + "SELECT COUNT (*) FROM ( + SELECT guid FROM loginsL WHERE is_deleted = 0 + UNION ALL + SELECT guid FROM loginsM WHERE is_overridden = 0 + )" + ); static ref GET_BY_GUID_SQL: String = format!( "SELECT {common_cols} FROM loginsL @@ -1058,6 +1071,36 @@ mod tests { assert_eq!(db.get_all().unwrap().len(), 2); } + #[test] + fn test_count_all() { + ensure_initialized(); + + let login_a = LoginEntry { + origin: "https://a.example.com".into(), + http_realm: Some("https://www.example.com".into()), + username: "test".into(), + password: "sekret".into(), + ..LoginEntry::default() + }; + + let login_b = LoginEntry { + origin: "https://b.example.com".into(), + http_realm: Some("https://www.example.com".into()), + username: "test".into(), + password: "sekret".into(), + ..LoginEntry::default() + }; + + let db = LoginDb::open_in_memory().unwrap(); + + db.add_many(vec![login_a.clone(), login_b.clone()], &*TEST_ENCDEC) + .expect("should be able to add logins"); + + let count = db.count_all().expect("should work"); + + assert_eq!(count, 2); + } + #[test] fn test_add_many() { ensure_initialized(); diff --git a/components/logins/src/logins.udl b/components/logins/src/logins.udl index 6b7e9287c5..0b59d9acad 100644 --- a/components/logins/src/logins.udl +++ b/components/logins/src/logins.udl @@ -223,6 +223,9 @@ interface LoginStore { [Throws=LoginsApiError] sequence list(); + [Throws=LoginsApiError] + i64 count(); + [Throws=LoginsApiError] sequence get_by_base_domain([ByRef] string base_domain); diff --git a/components/logins/src/store.rs b/components/logins/src/store.rs index 1f0e937fa6..474f23329f 100644 --- a/components/logins/src/store.rs +++ b/components/logins/src/store.rs @@ -116,6 +116,11 @@ impl LoginStore { }) } + #[handle_error(Error)] + pub fn count(&self) -> ApiResult { + self.db.lock().count_all() + } + #[handle_error(Error)] pub fn get(&self, id: &str) -> ApiResult> { match self.db.lock().get_by_id(id) {