diff --git a/bindings/src/api.rs b/bindings/src/api.rs index e6b154d29f3351aa3f454c82c2e8471d9c2e1890..4224dc20d4eda675f265ee55f750acbd09a8afb8 100644 --- a/bindings/src/api.rs +++ b/bindings/src/api.rs @@ -10,7 +10,7 @@ use pyo3::{ }; macro_rules! create_pythonized_interface { - ($name: ident, $keytype: ident, $valuetype : ident) => { + ($name: ident, $keytype: ident, $valuetype : ident, $best_match : expr) => { // unsendable == should hard-crash if Python tries to access it from // two different Python threads. // @@ -22,7 +22,7 @@ macro_rules! create_pythonized_interface { // happen on the Rust side and will not be visible to the Python ML pipeline. #[pyclass(unsendable)] pub struct $name { - inner: BoundedLinearCache<$keytype, $valuetype>, + inner: BoundedLinearCache<$keytype, $valuetype, $best_match>, } #[pymethods] @@ -126,12 +126,13 @@ impl ApproxComparable for VecPy<f32> { F32Vector::from(&self.inner as &[f32]) .roughly_matches(&F32Vector::from(&instore.inner as &[f32]), tolerance) } + fn fuzziness(&self, instore: &Self) -> f32 { + F32Vector::from(&self.inner as &[f32]).fuzziness(&F32Vector::from(&instore.inner as &[f32])) + } } type F32VecPy = VecPy<f32>; -type U32VecPy = VecPy<u32>; type UsizeVecPy = VecPy<usize>; -create_pythonized_interface!(I16ToF32VectorCache, i16, F32VecPy); -create_pythonized_interface!(FVecToU32VectorCache, F32VecPy, U32VecPy); -create_pythonized_interface!(FVecToUsizeVectorCache, F32VecPy, UsizeVecPy); +create_pythonized_interface!(FVecToUsizeVectorAny, F32VecPy, UsizeVecPy, false); +create_pythonized_interface!(FVecToUsizeVectorBest, F32VecPy, UsizeVecPy, true); diff --git a/bindings/src/lib.rs b/bindings/src/lib.rs index 8a64c312602613cd000cd79fdd8d34f59c7132c2..caec528f2ded7f77b9b851af6130caa8d4ede0bb 100644 --- a/bindings/src/lib.rs +++ b/bindings/src/lib.rs @@ -1,4 +1,4 @@ -use api::{FVecToU32VectorCache, FVecToUsizeVectorCache, I16ToF32VectorCache}; +use api::{FVecToUsizeVectorAny, FVecToUsizeVectorBest}; use pyo3::prelude::*; mod api; @@ -6,8 +6,7 @@ mod api; /// A Python module implemented in Rust. #[pymodule] fn proximipy(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::<I16ToF32VectorCache>()?; - m.add_class::<FVecToU32VectorCache>()?; - m.add_class::<FVecToUsizeVectorCache>()?; + m.add_class::<FVecToUsizeVectorBest>()?; + m.add_class::<FVecToUsizeVectorAny>()?; Ok(()) } diff --git a/core/src/caching/bounded/bounded_linear_cache.rs b/core/src/caching/bounded/bounded_linear_cache.rs index 3fed6b9729656f98e60e4d08e1bff340b7a88d75..7bf92be9b460fdeddcb9531929a2824240c9e175 100644 --- a/core/src/caching/bounded/bounded_linear_cache.rs +++ b/core/src/caching/bounded/bounded_linear_cache.rs @@ -21,7 +21,7 @@ use super::list_node::{Node, SharedNode}; /// use proximipy::caching::bounded::bounded_linear_cache::BoundedLinearCache; /// use proximipy::caching::approximate_cache::ApproximateCache; /// -/// let mut cache = BoundedLinearCache::new(3, 2.0); +/// let mut cache = BoundedLinearCache::<_,_,true>::new(3, 2.0); /// /// cache.insert(10 as i16, "Value 1"); /// cache.insert(20, "Value 2"); @@ -43,23 +43,35 @@ use super::list_node::{Node, SharedNode}; /// - `find(&mut self, key: &K) -> Option<V>`: Attempts to find a value matching the given key approximately. Promotes the found key to the head of the list. /// - `insert(&mut self, key: K, value: V)`: Inserts a key-value pair into the cache. Evicts the least recently used item if the cache is full. /// - `len(&self) -> usize`: Returns the current size of the cache. -pub struct BoundedLinearCache<K, V> { +pub struct BoundedLinearCache<K, V, const BEST_MATCH: bool> { max_capacity: usize, map: HashMap<K, SharedNode<K, V>>, list: DoublyLinkedList<K, V>, tolerance: f32, } -impl<K, V> ApproximateCache<K, V> for BoundedLinearCache<K, V> +impl<K, V, const BEST_MATCH: bool> ApproximateCache<K, V> for BoundedLinearCache<K, V, BEST_MATCH> where K: ApproxComparable + Eq + Hash + Clone, V: Clone, { fn find(&mut self, key: &K) -> Option<V> { - let matching = self - .map - .keys() - .find(|&k| key.roughly_matches(k, self.tolerance))?; + let matching = if BEST_MATCH { + self.map + .keys() + .find(|&k| key.roughly_matches(k, self.tolerance))? + } else { + let potential_candi = self + .map + .keys() + .min_by(|&x, &y| key.fuzziness(x).partial_cmp(&key.fuzziness(y)).unwrap())?; + + if potential_candi.roughly_matches(key, self.tolerance) { + potential_candi + } else { + return None; + } + }; let node: SharedNode<K, V> = self.map.get(matching).cloned()?; self.list.remove(node.clone()); self.list.add_to_head(node.clone()); @@ -82,7 +94,7 @@ where } } -impl<K, V> BoundedLinearCache<K, V> { +impl<K, V, const BESTMATCH: bool> BoundedLinearCache<K, V, BESTMATCH> { pub fn new(max_capacity: usize, tolerance: f32) -> Self { assert!(max_capacity > 0); assert!(tolerance > 0.0); @@ -102,57 +114,78 @@ mod tests { const TEST_TOLERANCE: f32 = 1e-8; #[test] fn test_lru_cache_basic_operations() { - let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(2, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} - assert_eq!(cache.find(&1), Some(1)); // Returns 1, Cache is {2=2, 1=1} - cache.insert(3, 3); // Evicts key 2, Cache is {1=1, 3=3} - assert_eq!(cache.find(&2), None); // Key 2 not found - cache.insert(4, 4); // Evicts key 1, Cache is {3=3, 4=4} - assert_eq!(cache.find(&1), None); // Key 1 not found - assert_eq!(cache.find(&3), Some(3)); // Returns 3 - assert_eq!(cache.find(&4), Some(4)); // Returns 4 + fn test_lru_cache_basic_operations_best_match<const BEST_MATCH: bool>() { + let mut cache: BoundedLinearCache<i16, i16, BEST_MATCH> = + BoundedLinearCache::new(2, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + cache.insert(2, 2); // Cache is {1=1, 2=2} + assert_eq!(cache.find(&1), Some(1)); // Returns 1, Cache is {2=2, 1=1} + cache.insert(3, 3); // Evicts key 2, Cache is {1=1, 3=3} + assert_eq!(cache.find(&2), None); // Key 2 not found + cache.insert(4, 4); // Evicts key 1, Cache is {3=3, 4=4} + assert_eq!(cache.find(&1), None); // Key 1 not found + assert_eq!(cache.find(&3), Some(3)); // Returns 3 + assert_eq!(cache.find(&4), Some(4)); // Returns 4 + } + test_lru_cache_basic_operations_best_match::<true>(); + test_lru_cache_basic_operations_best_match::<false>(); } #[test] fn test_lru_cache_eviction_order() { - let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(3, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} - cache.insert(3, 3); // Cache is {1=1, 2=2, 3=3} - cache.find(&1); // Access key 1, Cache is {2=2, 3=3, 1=1} - cache.insert(4, 4); // Evicts key 2, Cache is {3=3, 1=1, 4=4} - assert_eq!(cache.find(&2), None); // Key 2 not found - assert_eq!(cache.find(&3), Some(3)); // Returns 3 - assert_eq!(cache.find(&4), Some(4)); // Returns 4 - assert_eq!(cache.find(&1), Some(1)); // Returns 1 + fn test_lru_cache_eviction_order_best_match<const BEST_MATCH: bool>() { + let mut cache: BoundedLinearCache<i16, i16, BEST_MATCH> = + BoundedLinearCache::new(3, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + cache.insert(2, 2); // Cache is {1=1, 2=2} + cache.insert(3, 3); // Cache is {1=1, 2=2, 3=3} + cache.find(&1); // Access key 1, Cache is {2=2, 3=3, 1=1} + cache.insert(4, 4); // Evicts key 2, Cache is {3=3, 1=1, 4=4} + assert_eq!(cache.find(&2), None); // Key 2 not found + assert_eq!(cache.find(&3), Some(3)); // Returns 3 + assert_eq!(cache.find(&4), Some(4)); // Returns 4 + assert_eq!(cache.find(&1), Some(1)); // Returns 1 + } + + test_lru_cache_eviction_order_best_match::<true>(); + test_lru_cache_eviction_order_best_match::<false>(); } #[test] fn test_lru_cache_overwrite() { - let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(2, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} - cache.insert(1, 10); // Overwrites key 1, Cache is {2=2, 1=10} - assert_eq!(cache.find(&1), Some(10)); // Returns 10 - cache.insert(3, 3); // Evicts key 2, Cache is {1=10, 3=3} - assert_eq!(cache.find(&2), None); // Key 2 not found - assert_eq!(cache.find(&3), Some(3)); // Returns 3 + fn test_lru_cache_overwrite_best_match<const BEST_MATCH: bool>() { + let mut cache: BoundedLinearCache<i16, i16, BEST_MATCH> = + BoundedLinearCache::new(2, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + cache.insert(2, 2); // Cache is {1=1, 2=2} + cache.insert(1, 10); // Overwrites key 1, Cache is {2=2, 1=10} + assert_eq!(cache.find(&1), Some(10)); // Returns 10 + cache.insert(3, 3); // Evicts key 2, Cache is {1=10, 3=3} + assert_eq!(cache.find(&2), None); // Key 2 not found + assert_eq!(cache.find(&3), Some(3)); // Returns 3 + } + test_lru_cache_overwrite_best_match::<true>(); + test_lru_cache_overwrite_best_match::<false>(); } #[test] fn test_lru_cache_capacity_one() { - let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(1, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - assert_eq!(cache.find(&1), Some(1)); // Returns 1 - cache.insert(2, 2); // Evicts key 1, Cache is {2=2} - assert_eq!(cache.find(&1), None); // Key 1 not found - assert_eq!(cache.find(&2), Some(2)); // Returns 2 + fn test_lru_cache_capacity_one_best_match<const BEST_MATCH: bool>() { + let mut cache: BoundedLinearCache<i16, i16, BEST_MATCH> = + BoundedLinearCache::new(1, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + assert_eq!(cache.find(&1), Some(1)); // Returns 1 + cache.insert(2, 2); // Evicts key 1, Cache is {2=2} + assert_eq!(cache.find(&1), None); // Key 1 not found + assert_eq!(cache.find(&2), Some(2)); // Returns 2 + } + test_lru_cache_capacity_one_best_match::<true>(); + test_lru_cache_capacity_one_best_match::<false>(); } #[test] #[should_panic] fn test_lru_cache_empty() { - let _cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(0, TEST_TOLERANCE); + let _cache: BoundedLinearCache<i16, i16, true> = BoundedLinearCache::new(0, TEST_TOLERANCE); } } diff --git a/core/src/main.rs b/core/src/main.rs index f966bb10329ca05126a99c39fbde559b42596080..28b62b45434cc523994c5b07d381e3ff7815ed8c 100644 --- a/core/src/main.rs +++ b/core/src/main.rs @@ -27,7 +27,7 @@ fn main() { let vecs_f: Vec<f32> = vecs.into_iter().map(f32::from).collect(); println!("{:?}", vecs_f.chunks_exact(128).next().unwrap()); - let mut ulc = BoundedLinearCache::<F32Vector, usize>::new(10000, 15_000.0); + let mut ulc = BoundedLinearCache::<F32Vector, usize, true>::new(10000, 15_000.0); let mut count: u32 = 0; let mut scanned: usize = 0; diff --git a/core/src/numerics/comp.rs b/core/src/numerics/comp.rs index df0d7b354f0722a96e4f8fb29788dbffb3763f6d..e5f1845631308272894f95a3007234fb00bd59f4 100644 --- a/core/src/numerics/comp.rs +++ b/core/src/numerics/comp.rs @@ -1,26 +1,33 @@ use super::f32vector::F32Vector; pub trait ApproxComparable { - fn roughly_matches(&self, instore: &Self, tolerance: f32) -> bool; + #[inline] + fn roughly_matches(&self, instore: &Self, tolerance: f32) -> bool { + self.fuzziness(instore) < tolerance + } + fn fuzziness(&self, instore: &Self) -> f32; } impl ApproxComparable for f32 { - fn roughly_matches(&self, target: &f32, tolerance: f32) -> bool { - (self - target).abs() < tolerance + fn fuzziness(&self, instore: &Self) -> f32 { + (self - instore).abs() } } impl<'a> ApproxComparable for F32Vector<'a> { - #[inline] fn roughly_matches(&self, target: &F32Vector<'a>, tolerance: f32) -> bool { self.l2_dist_squared(target) < tolerance * tolerance } + + fn fuzziness(&self, instore: &Self) -> f32 { + self.l2_dist_squared(instore).sqrt() + } } impl ApproxComparable for i16 { - fn roughly_matches(&self, instore: &Self, tolerance: f32) -> bool { + fn fuzziness(&self, instore: &Self) -> f32 { let fself = f32::from(*self); let foth = f32::from(*instore); - fself.roughly_matches(&foth, tolerance) + fself.fuzziness(&foth) } }