diff --git a/src/iterative.rs b/src/iterative.rs index ca068a9..7d35233 100644 --- a/src/iterative.rs +++ b/src/iterative.rs @@ -181,6 +181,39 @@ impl BinarySearchTree for IterativeBST { self.size += 1; } } + + ///Inserts the given value as a node or updates an existing node with the same value. + /// + /// If a node with the given value already exists, the provided function `f` is called + /// with a mutable reference to the existing value to update it. + /// + /// # Example + /// + /// ```rust + /// use bst_rs::{BinarySearchTree, RecursiveBST}; + /// + /// let mut bst = RecursiveBST::new(); + /// + /// bst.insert((1, 10)); + /// assert_eq!(bst.retrieve(&(1, 10)), Some(&(1, 10))); + /// + /// bst.insert_or_update((1, 10), |tuple| tuple.1 = 30); + /// assert_eq!(bst.retrieve(&(1, 10)), None); + /// assert_eq!(bst.retrieve(&(1, 30)), Some(&(1, 30))); + /// + /// bst.insert_or_update((2, 20), |_| {}); + /// assert_eq!(bst.size(), 2); + /// ``` + fn insert_or_update(&mut self, value: T, f: F) + where + F: Fn(&mut T) { + if let Ok(inserted) = Node::iterative_insert_or_update(&mut self.root, value, f) { + if inserted { + self.size += 1; + } + } + } + /// Returns `true` if the binary search tree contains an element with the given value. /// @@ -1588,4 +1621,30 @@ mod tests { assert_eq!(actual_bst, expected_bst); } + + #[test] + fn successfully_insert_or_update_elements() { + let mut bst = IterativeBST::new(); + + bst.insert((1, 10)); + bst.insert((2, 20)); + bst.insert((3, 30)); + assert_eq!(bst.size(), 3); + + bst.insert_or_update((1, 50), |pair| pair.1 = 15); + assert_eq!(bst.retrieve(&(1, 50)), Some(&(1, 50))); + bst.insert_or_update((1, 50), |pair| pair.1 = 15); + assert_eq!(bst.retrieve(&(1, 50)), None); + println!("{}", bst); + assert_eq!(bst.size(), 4); + + bst.insert_or_update((4, 40), |_| {}); + assert_eq!(bst.size(), 5); + + bst.insert_or_update((4, 40), |pair| { + pair.1 *= 2; + }); + assert_eq!(bst.retrieve(&(4, 40)), None); + assert_eq!(bst.retrieve(&(4, 80)), Some(&(4, 80))); + } } diff --git a/src/lib.rs b/src/lib.rs index e31dad4..30af7bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -244,6 +244,13 @@ pub trait BinarySearchTree { /// **Duplicate values are _not allowed_**. fn insert(&mut self, value: T); + /// Inserts given value as a node or updates T with some function + /// + /// **Duplicate values are allowed**. + fn insert_or_update(&mut self, value: T, f: F) + where + F: Fn(&mut T); + /// Returns `true` if the binary search tree contains an element with the given value. fn contains(&self, value: &T) -> bool; diff --git a/src/node.rs b/src/node.rs index 0241cb7..bfc2f78 100644 --- a/src/node.rs +++ b/src/node.rs @@ -32,6 +32,24 @@ impl Node { Ok(()) } + pub(crate) fn iterative_insert_or_update(mut root: &mut HeapNode, value: T, f: F) -> Result + where + F: Fn(&mut T), + { + while let Some(ref mut node) = root { + match value.cmp(&node.value) { + Ordering::Equal => { + f(&mut node.value); + return Ok(false); + } + Ordering::Less => root = &mut node.left, + Ordering::Greater => root = &mut node.right, + } + } + *root = Some(Box::new(Node::new(value))); + Ok(true) + } + pub(crate) fn recursive_insert(&mut self, value: T) -> Result<(), ()> { match value.cmp(&self.value) { Ordering::Equal => Err(()), @@ -52,6 +70,32 @@ impl Node { } } + pub(crate) fn recursive_insert_or_update(&mut self, value: T, f: F) -> Result + where + F: Fn(&mut T), + { + match value.cmp(&self.value) { + Ordering::Equal => { + f(&mut self.value); + Ok(false) + } + Ordering::Less => match self.left { + None => { + self.left = Some(Box::from(Node::new(value))); + Ok(true) + } + Some(ref mut node) => node.recursive_insert_or_update(value, f), + }, + Ordering::Greater => match self.right { + None => { + self.right = Some(Box::from(Node::new(value))); + Ok(true) + } + Some(ref mut node) => node.recursive_insert_or_update(value, f), + }, + } + } + pub(crate) fn iterative_contains(mut root: &HeapNode, value: &T) -> bool { while let Some(current) = root { match value.cmp(¤t.value) { diff --git a/src/recursive.rs b/src/recursive.rs index c92fb52..68d1088 100644 --- a/src/recursive.rs +++ b/src/recursive.rs @@ -2,8 +2,8 @@ use std::fmt::{Debug, Display, Formatter}; use std::vec::IntoIter; use crate::BinarySearchTree; -use crate::Node; use crate::HeapNode; +use crate::Node; /// Recursive Binary Search Tree implementation. /// /// # Important @@ -193,6 +193,47 @@ impl BinarySearchTree for RecursiveBST { } } + /// Inserts the given value as a node or updates an existing node with the same value. + /// + /// If a node with the given value already exists, the provided function `f` is called + /// with a mutable reference to the existing value to update it. + /// + /// # Example + /// + /// ```rust + /// use bst_rs::{BinarySearchTree, RecursiveBST}; + /// + /// let mut bst = RecursiveBST::new(); + /// + /// bst.insert((1, 10)); + /// assert_eq!(bst.retrieve(&(1, 10)), Some(&(1, 10))); + /// + /// bst.insert_or_update((1, 10), |tuple| tuple.1 = 30); + /// assert_eq!(bst.retrieve(&(1, 10)), None); + /// assert_eq!(bst.retrieve(&(1, 30)), Some(&(1, 30))); + /// + /// bst.insert_or_update((2, 20), |_| {}); + /// assert_eq!(bst.size(), 2); + /// ``` + fn insert_or_update(&mut self, value: T, f: F) + where + F: Fn(&mut T), + { + match self.root { + None => { + self.root = Some(Box::from(Node::new(value))); + self.size += 1; + } + Some(ref mut node) => { + if let Ok(inserted) = node.recursive_insert_or_update(value, f) { + if inserted { + self.size += 1; + } + } + } + } + } + /// Returns `true` if the binary search tree contains an element with the given value. /// /// # Example @@ -1652,4 +1693,30 @@ mod tests { assert_eq!(actual_bst, expected_bst); } + + #[test] + fn successfully_insert_or_update_elements() { + let mut bst = RecursiveBST::new(); + + bst.insert((1, 10)); + bst.insert((2, 20)); + bst.insert((3, 30)); + assert_eq!(bst.size(), 3); + + bst.insert_or_update((1, 50), |pair| pair.1 = 15); + assert_eq!(bst.retrieve(&(1, 50)), Some(&(1, 50))); + bst.insert_or_update((1, 50), |pair| pair.1 = 15); + assert_eq!(bst.retrieve(&(1, 50)), None); + println!("{}", bst); + assert_eq!(bst.size(), 4); + + bst.insert_or_update((4, 40), |_| {}); + assert_eq!(bst.size(), 5); + + bst.insert_or_update((4, 40), |pair| { + pair.1 *= 2; + }); + assert_eq!(bst.retrieve(&(4, 40)), None); + assert_eq!(bst.retrieve(&(4, 80)), Some(&(4, 80))); + } }