diff --git a/src/caching/approximate_cache.rs b/src/caching/approximate_cache.rs index 82804c25d558720afc9826f936345b775ac86c29..d574bd78529a1c2e74e6766a7137bdffbd64b585 100644 --- a/src/caching/approximate_cache.rs +++ b/src/caching/approximate_cache.rs @@ -11,4 +11,7 @@ where fn find(&mut self, key: &K) -> Option<V>; fn insert(&mut self, key: K, value: V); fn len(&self) -> usize; + fn is_empty(&self) -> bool { + self.len() == 0 + } } diff --git a/src/caching/bounded/bounded_linear_cache.rs b/src/caching/bounded/bounded_linear_cache.rs index ce8ded8f95e92ee86b5e11b90de131fd514a9b74..7e7e98870549be181855ac0578c389b9a2c34473 100644 --- a/src/caching/bounded/bounded_linear_cache.rs +++ b/src/caching/bounded/bounded_linear_cache.rs @@ -45,7 +45,6 @@ use super::list_node::{Node, SharedNode}; /// - `find(&mut self, key: &K) -> Option<V>`: Attempts to find a value matching the given key approximately. Promotes the found key to the head of the list. /// - `insert(&mut self, key: K, value: V)`: Inserts a key-value pair into the cache. Evicts the least recently used item if the cache is full. /// - `len(&self) -> usize`: Returns the current size of the cache. - pub struct BoundedLinearCache<K, V> { max_capacity: usize, map: HashMap<K, SharedNode<K, V>>, @@ -66,13 +65,13 @@ where let node: SharedNode<K, V> = self.map.get(matching).cloned()?; self.list.remove(node.clone()); self.list.add_to_head(node.clone()); - return Some(node.try_lock().unwrap().value.clone()); + return Some(node.borrow().value.clone()); } fn insert(&mut self, key: K, value: V) { if self.len() == self.max_capacity { if let Some(tail) = self.list.remove_tail() { - self.map.remove(&tail.try_lock().unwrap().key); + self.map.remove(&tail.borrow().key); } } let new_node = Node::new(key.clone(), value); @@ -100,10 +99,20 @@ impl<K, V> BoundedLinearCache<K, V> { macro_rules! create_pythonized_interface { ($name: ident, $type: ident) => { - #[pyclass] + // unsendable == should hard-crash if Python tries to access it from + // two different Python threads. + // + // The implementation is very much thread-unsafe anyways (lots of mutations), + // so this is an OK behavior, we will detect it with a nice backtrace + // and without UB. + // + // Even in the case where we want the cache to be multithreaded, this would + // happen on the Rust side and will not be visible to the Python ML pipeline. + #[pyclass(unsendable)] pub struct $name { - inner: BoundedLinearCache::<$type, $type>, + inner: BoundedLinearCache<$type, $type>, } + #[pymethods] impl $name { #[new] @@ -113,7 +122,7 @@ macro_rules! create_pythonized_interface { } } - fn find(&mut self, k : $type) -> Option<$type> { + fn find(&mut self, k: $type) -> Option<$type> { self.inner.find(&k) } diff --git a/src/caching/bounded/linked_list.rs b/src/caching/bounded/linked_list.rs index d72423968d4e76de9a87396a1e919e916ba18396..3b560aec1cf06ef47d95e2d1471845383ee253b6 100644 --- a/src/caching/bounded/linked_list.rs +++ b/src/caching/bounded/linked_list.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::rc::Rc; use crate::caching::bounded::list_node::SharedNode; @@ -16,11 +16,11 @@ impl<K, V> DoublyLinkedList<K, V> { } pub(crate) fn add_to_head(&mut self, node: SharedNode<K, V>) { - node.try_lock().unwrap().next = self.head.clone(); - node.try_lock().unwrap().prev = None; + node.borrow_mut().next = self.head.clone(); + node.borrow_mut().prev = None; if let Some(head) = self.head.clone() { - head.try_lock().unwrap().prev = Some(Arc::downgrade(&node)); + head.borrow_mut().prev = Some(Rc::downgrade(&node)); } self.head = Some(node.clone()); @@ -31,17 +31,17 @@ impl<K, V> DoublyLinkedList<K, V> { } pub(crate) fn remove(&mut self, node: SharedNode<K, V>) { - let prev = node.try_lock().unwrap().prev.clone(); - let next = node.try_lock().unwrap().next.clone(); + let prev = node.borrow().prev.clone(); + let next = node.borrow().next.clone(); if let Some(prev_node) = prev.as_ref().and_then(|weak| weak.upgrade()) { - prev_node.try_lock().unwrap().next = next.clone(); + prev_node.borrow_mut().next = next.clone(); } else { self.head = next.clone(); } if let Some(next_node) = next { - next_node.try_lock().unwrap().prev = prev; + next_node.borrow_mut().prev = prev; } else { self.tail = prev.and_then(|weak| weak.upgrade()); } @@ -50,9 +50,10 @@ impl<K, V> DoublyLinkedList<K, V> { pub(crate) fn remove_tail(&mut self) -> Option<SharedNode<K, V>> { if let Some(tail) = self.tail.clone() { self.remove(tail.clone()); - return Some(tail); + Some(tail) + } else { + None } - None } } @@ -68,12 +69,12 @@ mod tests { let node2 = Node::new(2, 20); list.add_to_head(node1.clone()); - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 1); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 1); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 1); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); list.add_to_head(node2.clone()); - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 2); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 1); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 2); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); } #[test] @@ -90,13 +91,13 @@ mod tests { // List is now: {3, 2, 1} list.remove(node2.clone()); // List should now be: {3, 1} - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 3); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 1); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 3); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); list.remove(node3.clone()); // List should now be: {1} - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 1); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 1); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 1); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); list.remove(node1.clone()); // List should now be empty @@ -115,13 +116,13 @@ mod tests { // List is now: {2, 1} let removed_tail = list.remove_tail().unwrap(); - assert_eq!(removed_tail.try_lock().unwrap().key, 1); + assert_eq!(removed_tail.borrow().key, 1); // List should now be: {2} - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 2); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 2); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 2); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 2); let removed_tail = list.remove_tail().unwrap(); - assert_eq!(removed_tail.try_lock().unwrap().key, 2); + assert_eq!(removed_tail.borrow().key, 2); // List should now be empty assert!(list.head.is_none()); assert!(list.tail.is_none()); @@ -139,17 +140,17 @@ mod tests { list.add_to_head(node3.clone()); // List is now: {3, 2, 1} - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 3); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 1); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 3); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); list.remove(node1.clone()); // List should now be: {3, 2} - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 3); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 2); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 3); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 2); list.add_to_head(node1.clone()); // List should now be: {1, 3, 2} - assert_eq!(list.head.as_ref().unwrap().try_lock().unwrap().key, 1); - assert_eq!(list.tail.as_ref().unwrap().try_lock().unwrap().key, 2); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 1); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 2); } } diff --git a/src/caching/bounded/list_node.rs b/src/caching/bounded/list_node.rs index 2ca9264d8bb8ad376f4d84f22a1a04514b8cdede..dafe9f6394610c7e95aeba68b225da5d0a6c25e3 100644 --- a/src/caching/bounded/list_node.rs +++ b/src/caching/bounded/list_node.rs @@ -1,7 +1,10 @@ -use std::sync::{Arc, Weak, Mutex}; +use std::{ + cell::RefCell, + rc::{Rc, Weak}, +}; -pub(crate) type SharedNode<K, V> = Arc<Mutex<Node<K, V>>>; -pub(crate) type WeakSharedNode<K, V> = Weak<Mutex<Node<K, V>>>; +pub(crate) type SharedNode<K, V> = Rc<RefCell<Node<K, V>>>; +pub(crate) type WeakSharedNode<K, V> = Weak<RefCell<Node<K, V>>>; #[derive(Debug)] pub struct Node<K, V> { @@ -12,8 +15,8 @@ pub struct Node<K, V> { } impl<K, V> Node<K, V> { - pub fn new(key: K, value: V) -> Arc<Mutex<Self>> { - Arc::new(Mutex::new(Node { + pub fn new(key: K, value: V) -> Rc<RefCell<Self>> { + Rc::new(RefCell::new(Node { key, value, prev: None, diff --git a/src/numerics/f32vector.rs b/src/numerics/f32vector.rs index 0a889149bdad28ccc713ae60c39aa6512f0621a2..ed7a98502b918dbb1c07a9047aecc3c89cc9b3f0 100644 --- a/src/numerics/f32vector.rs +++ b/src/numerics/f32vector.rs @@ -13,6 +13,10 @@ impl<'a> F32Vector<'a> { self.array.len() } + pub fn is_empty(&self) -> bool { + self.array.is_empty() + } + /// # Usage /// Computes the **SQUARED** L2 distance between two vectors. /// This is cheaper to compute than the regular L2 distance.