diff --git a/README.md b/README.md index 18f3a5decd45dbbbf2ee9eae57c03dbc825b85a8..61ed21f9f0244b867b94893a3f4653d482d1595c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# proximity-cache +# Proximity Proximity is a research project exploring the optimization and speed-recall tradeoffs of approximate vector search in high-dimensional spaces. We provide an approximate cache for vector databases that is written in Rust and exposes Python bindings. @@ -44,4 +44,3 @@ todo This project is licensed under the MIT License. See LICENSE for details. -This code is meant as a beta/development playground. It should not be used for production systems. diff --git a/bindings/Cargo.toml b/bindings/Cargo.toml index ac7aeae78774522d89c9e33d5653a4147d4965c1..51bfe176682108cf5147b09a0dd2418fae887be3 100644 --- a/bindings/Cargo.toml +++ b/bindings/Cargo.toml @@ -9,5 +9,5 @@ name = "proximipy" crate-type = ["cdylib"] [dependencies] -pyo3 = "0.23.3" +pyo3 = {version = "0.23.3", features = ["py-clone"]} proximity-cache = { path = "../core"} diff --git a/bindings/src/api.rs b/bindings/src/api.rs index 15a33fb879255da2af6423b55627642c5a668802..5386e0eb43681da86766bd9aa9cb6d69c8707e12 100644 --- a/bindings/src/api.rs +++ b/bindings/src/api.rs @@ -1,10 +1,12 @@ use std::hash::{Hash, Hasher}; +use proximipy::caching::approximate_cache::ApproximateCache; use proximipy::caching::bounded::fifo::fifo_cache::FifoCache as FifoInternal; use proximipy::caching::bounded::lru::lru_cache::LRUCache as LruInternal; - +use proximipy::numerics::comp::ApproxComparable; use proximipy::numerics::f32vector::F32Vector; -use proximipy::{caching::approximate_cache::ApproximateCache, numerics::comp::ApproxComparable}; + +use pyo3::PyObject; use pyo3::{ pyclass, pymethods, types::{PyAnyMethods, PyList}, @@ -12,7 +14,7 @@ use pyo3::{ }; macro_rules! create_pythonized_interface { - ($internal : ident, $name: ident, $keytype: ident, $valuetype : ident) => { + ($internal : ident, $name: ident, $keytype: ident) => { // unsendable == should hard-crash if Python tries to access it from // two different Python threads. // @@ -24,36 +26,33 @@ macro_rules! create_pythonized_interface { // happen on the Rust side and will not be visible to the Python ML pipeline. 
#[pyclass(unsendable)] pub struct $name { - inner: $internal<$keytype, $valuetype>, + inner: $internal<$keytype, PyObject>, } #[pymethods] impl $name { #[new] - pub fn new(max_capacity: usize, tolerance: f32) -> Self { + pub fn new(max_capacity: usize) -> Self { Self { - inner: $internal::new(max_capacity, tolerance), + inner: $internal::new(max_capacity), } } - fn find(&mut self, k: $keytype) -> Option<$valuetype> { + fn find(&mut self, k: $keytype) -> Option<PyObject> { self.inner.find(&k) } - fn batch_find(&mut self, ks: Vec<$keytype>) -> Vec<Option<$valuetype>> { + fn batch_find(&mut self, ks: Vec<$keytype>) -> Vec<Option<PyObject>> { + // more efficient than a python for loop ks.into_iter().map(|k| self.find(k)).collect() } - fn insert(&mut self, key: $keytype, value: $valuetype) { - self.inner.insert(key, value) - } - - fn len(&self) -> usize { - self.inner.len() + fn insert(&mut self, key: $keytype, value: PyObject, tolerance: f32) { + self.inner.insert(key, value, tolerance) } fn __len__(&self) -> usize { - self.len() + self.inner.len() } } }; @@ -138,8 +137,6 @@ impl ApproxComparable for VecPy<f32> { } type F32VecPy = VecPy<f32>; -type UsizeVecPy = VecPy<usize>; -type UsizeWithRankingVecPy = (UsizeVecPy, F32VecPy); -create_pythonized_interface!(LruInternal, LRUCache, F32VecPy, UsizeWithRankingVecPy); -create_pythonized_interface!(FifoInternal, FifoCache, F32VecPy, UsizeWithRankingVecPy); +create_pythonized_interface!(LruInternal, LRUCache, F32VecPy); +create_pythonized_interface!(FifoInternal, FifoCache, F32VecPy); diff --git a/core/src/caching/approximate_cache.rs b/core/src/caching/approximate_cache.rs index d574bd78529a1c2e74e6766a7137bdffbd64b585..c8983ce1992fee2b98959a868a3b8404be113894 100644 --- a/core/src/caching/approximate_cache.rs +++ b/core/src/caching/approximate_cache.rs @@ -1,15 +1,13 @@ use crate::numerics::comp::ApproxComparable; -// size of caches in implementations where that should be known at comptime -pub const COMPTIME_CACHE_SIZE: usize = 1024; +pub type Tolerance = f32; pub trait ApproximateCache<K, V> where K: ApproxComparable, - V: Clone, { - fn find(&mut self, key: &K) -> Option<V>; - fn insert(&mut self, key: K, value: V); + fn find(&mut self, target: &K) -> Option<V>; + fn insert(&mut self, key: K, value: V, tolerance: f32); fn len(&self) -> usize; fn is_empty(&self) -> bool { self.len() == 0 diff --git a/core/src/caching/bounded/fifo/fifo_cache.rs b/core/src/caching/bounded/fifo/fifo_cache.rs index 60cb5b00d9949fa97e189e0cba55186f3e9bd5f8..6dc388f1c9553ffd3e07d98a4fbb2e211d5ea0a2 100644 --- a/core/src/caching/bounded/fifo/fifo_cache.rs +++ b/core/src/caching/bounded/fifo/fifo_cache.rs @@ -1,11 +1,19 @@ use std::collections::VecDeque; -use crate::{caching::approximate_cache::ApproximateCache, numerics::comp::ApproxComparable}; +use crate::caching::approximate_cache::ApproximateCache; +use crate::caching::approximate_cache::Tolerance; +use crate::numerics::comp::ApproxComparable; + +#[derive(Clone)] +struct CacheLine<K, V> { + key: K, + tol: Tolerance, + value: V, +} pub struct FifoCache<K, V> { max_capacity: usize, - items: VecDeque<(K, V)>, - tolerance: f32, + items: VecDeque<CacheLine<K, V>>, } impl<K, V> ApproximateCache<K, V> for FifoCache<K, V> @@ -13,21 +21,27 @@ where K: ApproxComparable, V: Clone, { - fn find(&mut self, key: &K) -> Option<V> { + fn find(&mut self, target: &K) -> Option<V> { let candidate = self .items .iter() - .min_by(|&(x, _), &(y, _)| key.fuzziness(x).partial_cmp(&key.fuzziness(y)).unwrap())?; - let (candidate_key, 
candidate_value) = candidate; - if candidate_key.roughly_matches(key, self.tolerance) { - Some(candidate_value.clone()) - } else { - None - } + .filter(|&entry| entry.key.roughly_matches(target, entry.tol)) + .min_by(|&x, &y| { + target + .fuzziness(&x.key) + .partial_cmp(&target.fuzziness(&y.key)) + .unwrap() + })?; + Some(candidate.value.clone()) } - fn insert(&mut self, key: K, value: V) { - self.items.push_back((key, value)); + fn insert(&mut self, key: K, value: V, tolerance: f32) { + let new_entry = CacheLine { + key, + tol: tolerance, + value, + }; + self.items.push_back(new_entry); if self.items.len() > self.max_capacity { self.items.pop_front(); } @@ -39,13 +53,11 @@ where } impl<K, V> FifoCache<K, V> { - pub fn new(max_capacity: usize, tolerance: f32) -> Self { + pub fn new(max_capacity: usize) -> Self { assert!(max_capacity > 0); - assert!(tolerance > 0.0); Self { max_capacity, items: VecDeque::with_capacity(max_capacity), - tolerance, } } } @@ -57,11 +69,11 @@ mod tests { const TEST_TOLERANCE: f32 = 1e-8; #[test] fn test_fifo_cache_basic_operations() { - let mut cache = FifoCache::new(2, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} + let mut cache = FifoCache::new(2); + cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1} + cache.insert(2, 2, TEST_TOLERANCE); // Cache is {1=1, 2=2} assert_eq!(cache.find(&1), Some(1)); // Returns 1, Cache is {1=1, 2=2} - cache.insert(3, 3); // Evicts key 1, Cache is {2=2, 3=3} + cache.insert(3, 3, TEST_TOLERANCE); // Evicts key 1, Cache is {2=2, 3=3} assert_eq!(cache.find(&1), None); // Key 1 not found assert_eq!(cache.find(&2), Some(2)); // Returns 2 assert_eq!(cache.find(&3), Some(3)); // Returns 3 @@ -69,11 +81,11 @@ mod tests { #[test] fn test_fifo_cache_eviction_order() { - let mut cache = FifoCache::new(3, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} - cache.insert(3, 3); // Cache is {1=1, 2=2, 3=3} - cache.insert(4, 4); // Evicts key 1, Cache is {2=2, 3=3, 4=4} + let mut cache = FifoCache::new(3); + cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1} + cache.insert(2, 2, TEST_TOLERANCE); // Cache is {1=1, 2=2} + cache.insert(3, 3, TEST_TOLERANCE); // Cache is {1=1, 2=2, 3=3} + cache.insert(4, 4, TEST_TOLERANCE); // Evicts key 1, Cache is {2=2, 3=3, 4=4} assert_eq!(cache.find(&1), None); // Key 1 not found assert_eq!(cache.find(&2), Some(2)); // Returns 2 assert_eq!(cache.find(&3), Some(3)); // Returns 3 @@ -82,22 +94,22 @@ mod tests { #[test] fn test_fifo_cache_overwrite() { - let mut cache = FifoCache::new(2, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} - cache.insert(1, 10); // Overwrites key 1, Cache is {2=2, 1=10} + let mut cache = FifoCache::new(2); + cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1} + cache.insert(2, 2, TEST_TOLERANCE); // Cache is {1=1, 2=2} + cache.insert(1, 10, TEST_TOLERANCE); // Overwrites key 1, Cache is {2=2, 1=10} assert_eq!(cache.find(&1), Some(10)); // Returns 10 - cache.insert(3, 3); // Evicts key 2, Cache is {1=10, 3=3} + cache.insert(3, 3, TEST_TOLERANCE); // Evicts key 2, Cache is {1=10, 3=3} assert_eq!(cache.find(&2), None); // Key 2 not found assert_eq!(cache.find(&3), Some(3)); // Returns 3 } #[test] fn test_fifo_cache_capacity_one() { - let mut cache = FifoCache::new(1, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} + let mut cache = FifoCache::new(1); + cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1} 
assert_eq!(cache.find(&1), Some(1)); // Returns 1 - cache.insert(2, 2); // Evicts key 1, Cache is {2=2} + cache.insert(2, 2, TEST_TOLERANCE); // Evicts key 1, Cache is {2=2} assert_eq!(cache.find(&1), None); // Key 1 not found assert_eq!(cache.find(&2), Some(2)); // Returns 2 } @@ -105,6 +117,6 @@ mod tests { #[test] #[should_panic] fn test_fifo_cache_empty() { - let _cache: FifoCache<i16, i16> = FifoCache::new(0, TEST_TOLERANCE); + let _cache: FifoCache<i16, i16> = FifoCache::new(0); } } diff --git a/core/src/caching/bounded/lru/lru_cache.rs b/core/src/caching/bounded/lru/lru_cache.rs index b60b5af9ef41967b3de06fe7b10878fd9c4f6eda..cd8245136319da27bb60cb693b897be3cddb6641 100644 --- a/core/src/caching/bounded/lru/lru_cache.rs +++ b/core/src/caching/bounded/lru/lru_cache.rs @@ -7,6 +7,7 @@ use crate::caching::approximate_cache::ApproximateCache; use super::linked_list::DoublyLinkedList; use super::list_node::{Node, SharedNode}; +use super::map_entry::MapEntry; /// `LRUCache` is a bounded cache with approximate key matching support and LRU eviction. /// @@ -19,16 +20,17 @@ use super::list_node::{Node, SharedNode}; /// use proximipy::caching::bounded::lru::lru_cache::LRUCache; /// use proximipy::caching::approximate_cache::ApproximateCache; /// -/// let mut cache = LRUCache::new(3, 2.0); +/// let mut cache = LRUCache::new(3); +/// const TEST_TOL: f32 = 2.0; /// -/// cache.insert(10 as i16, "Value 1"); -/// cache.insert(20, "Value 2"); -/// cache.insert(30, "Value 3"); +/// cache.insert(10 as i16, "Value 1", TEST_TOL); +/// cache.insert(20, "Value 2", TEST_TOL); +/// cache.insert(30, "Value 3", TEST_TOL); /// /// assert_eq!(cache.find(&11), Some("Value 1")); /// assert_eq!(cache.len(), 3); /// -/// cache.insert(40, "Value 4"); // Evicts the least recently used (Key(20)) +/// cache.insert(40, "Value 4", TEST_TOL); // Evicts the least recently used (Key(20)) /// assert!(cache.find(&20).is_none()); /// ``` /// @@ -43,9 +45,8 @@ use super::list_node::{Node, SharedNode}; /// - `len(&self) -> usize`: Returns the current size of the cache. 
 pub struct LRUCache<K, V> {
     max_capacity: usize,
-    map: HashMap<K, SharedNode<K, V>>,
-    list: DoublyLinkedList<K, V>,
-    tolerance: f32,
+    map: HashMap<MapEntry<K>, SharedNode<MapEntry<K>, V>>,
+    list: DoublyLinkedList<MapEntry<K>, V>,
 }
 impl<K, V> ApproximateCache<K, V> for LRUCache<K, V>
@@ -53,33 +54,37 @@ where
     K: ApproxComparable + Eq + Hash + Clone,
     V: Clone,
 {
-    fn find(&mut self, key: &K) -> Option<V> {
-        let potential_candi = self
+    fn find(&mut self, target: &K) -> Option<V> {
+        let candidate = self
             .map
             .keys()
-            .min_by(|&x, &y| key.fuzziness(x).partial_cmp(&key.fuzziness(y)).unwrap())?;
-
-        let matching = if potential_candi.roughly_matches(key, self.tolerance) {
-            Some(potential_candi)
-        } else {
-            None
-        }?;
-
-        let node: SharedNode<K, V> = self.map.get(matching).cloned()?;
+            .filter(|&entry| entry.key.roughly_matches(target, entry.tolerance))
+            .min_by(|&xentry, &yentry| {
+                target
+                    .fuzziness(&xentry.key)
+                    .partial_cmp(&target.fuzziness(&yentry.key))
+                    .unwrap()
+            })?;
+
+        let node: SharedNode<MapEntry<K>, V> = self.map.get(candidate).cloned()?;
         self.list.remove(node.clone());
         self.list.add_to_head(node.clone());
         return Some(node.borrow().value.clone());
     }
-    fn insert(&mut self, key: K, value: V) {
+    fn insert(&mut self, key: K, value: V, tolerance: f32) {
         if self.len() >= self.max_capacity {
             if let Some(tail) = self.list.remove_tail() {
                 self.map.remove(&tail.borrow().key);
             }
         }
-        let new_node = Node::new(key.clone(), value);
+        let map_entry = MapEntry {
+            key: key.clone(),
+            tolerance,
+        };
+        let new_node = Node::new(map_entry.clone(), value);
         self.list.add_to_head(new_node.clone());
-        self.map.insert(key, new_node);
+        self.map.insert(map_entry, new_node);
     }
     fn len(&self) -> usize {
@@ -88,14 +93,12 @@ where
 }
 impl<K, V> LRUCache<K, V> {
-    pub fn new(max_capacity: usize, tolerance: f32) -> Self {
+    pub fn new(max_capacity: usize) -> Self {
         assert!(max_capacity > 0);
-        assert!(tolerance > 0.0);
         Self {
             max_capacity,
             map: HashMap::with_capacity(max_capacity),
             list: DoublyLinkedList::new(),
-            tolerance,
         }
     }
 }
@@ -107,13 +110,13 @@ mod tests {
     const TEST_TOLERANCE: f32 = 1e-8;
     #[test]
     fn test_lru_cache_basic_operations() {
-        let mut cache = LRUCache::new(2, TEST_TOLERANCE);
-        cache.insert(1, 1); // Cache is {1=1}
-        cache.insert(2, 2); // Cache is {1=1, 2=2}
+        let mut cache = LRUCache::new(2);
+        cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1}
+        cache.insert(2, 2, TEST_TOLERANCE); // Cache is {1=1, 2=2}
         assert_eq!(cache.find(&1), Some(1)); // Returns 1, Cache is {2=2, 1=1}
-        cache.insert(3, 3); // Evicts key 2, Cache is {1=1, 3=3}
+        cache.insert(3, 3, TEST_TOLERANCE); // Evicts key 2, Cache is {1=1, 3=3}
         assert_eq!(cache.find(&2), None); // Key 2 not found
-        cache.insert(4, 4); // Evicts key 1, Cache is {3=3, 4=4}
+        cache.insert(4, 4, TEST_TOLERANCE); // Evicts key 1, Cache is {3=3, 4=4}
         assert_eq!(cache.find(&1), None); // Key 1 not found
         assert_eq!(cache.find(&3), Some(3)); // Returns 3
         assert_eq!(cache.find(&4), Some(4)); // Returns 4
@@ -121,12 +124,12 @@
     #[test]
     fn test_lru_cache_eviction_order() {
-        let mut cache = LRUCache::new(3, TEST_TOLERANCE);
-        cache.insert(1, 1); // Cache is {1=1}
-        cache.insert(2, 2); // Cache is {1=1, 2=2}
-        cache.insert(3, 3); // Cache is {1=1, 2=2, 3=3}
+        let mut cache = LRUCache::new(3);
+        cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1}
+        cache.insert(2, 2, TEST_TOLERANCE); // Cache is {1=1, 2=2}
+        cache.insert(3, 3, TEST_TOLERANCE); // Cache is {1=1, 2=2, 3=3}
         cache.find(&1); // Access key 1, Cache is {2=2, 3=3, 1=1}
-
cache.insert(4, 4); // Evicts key 2, Cache is {3=3, 1=1, 4=4} + cache.insert(4, 4, TEST_TOLERANCE); // Evicts key 2, Cache is {3=3, 1=1, 4=4} assert_eq!(cache.find(&2), None); // Key 2 not found assert_eq!(cache.find(&3), Some(3)); // Returns 3 assert_eq!(cache.find(&4), Some(4)); // Returns 4 @@ -135,22 +138,22 @@ mod tests { #[test] fn test_lru_cache_overwrite() { - let mut cache = LRUCache::new(2, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} - cache.insert(2, 2); // Cache is {1=1, 2=2} - cache.insert(1, 10); // Overwrites key 1, Cache is {2=2, 1=10} + let mut cache = LRUCache::new(2); + cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1} + cache.insert(2, 2, TEST_TOLERANCE); // Cache is {1=1, 2=2} + cache.insert(1, 10, TEST_TOLERANCE); // Overwrites key 1, Cache is {2=2, 1=10} assert_eq!(cache.find(&1), Some(10)); // Returns 10 - cache.insert(3, 3); // Evicts key 2, Cache is {1=10, 3=3} + cache.insert(3, 3, TEST_TOLERANCE); // Evicts key 2, Cache is {1=10, 3=3} assert_eq!(cache.find(&2), None); // Key 2 not found assert_eq!(cache.find(&3), Some(3)); // Returns 3 } #[test] fn test_lru_cache_capacity_one() { - let mut cache = LRUCache::new(1, TEST_TOLERANCE); - cache.insert(1, 1); // Cache is {1=1} + let mut cache = LRUCache::new(1); + cache.insert(1, 1, TEST_TOLERANCE); // Cache is {1=1} assert_eq!(cache.find(&1), Some(1)); // Returns 1 - cache.insert(2, 2); // Evicts key 1, Cache is {2=2} + cache.insert(2, 2, TEST_TOLERANCE); // Evicts key 1, Cache is {2=2} assert_eq!(cache.find(&1), None); // Key 1 not found assert_eq!(cache.find(&2), Some(2)); // Returns 2 } @@ -158,6 +161,6 @@ mod tests { #[test] #[should_panic] fn test_lru_cache_empty() { - let _cache: LRUCache<i16, i16> = LRUCache::new(0, TEST_TOLERANCE); + let _cache: LRUCache<i16, i16> = LRUCache::new(0); } } diff --git a/core/src/caching/bounded/lru/map_entry.rs b/core/src/caching/bounded/lru/map_entry.rs new file mode 100644 index 0000000000000000000000000000000000000000..ea723f5cccbc92122702f29058fc2cf08a98f48c --- /dev/null +++ b/core/src/caching/bounded/lru/map_entry.rs @@ -0,0 +1,22 @@ +use std::hash::Hash; + +#[derive(Clone)] +pub struct MapEntry<K> { + pub key: K, + pub tolerance: f32, +} + +impl<K: Eq> PartialEq for MapEntry<K> { + fn eq(&self, other: &Self) -> bool { + self.key == other.key && self.tolerance == other.tolerance + } +} + +impl<K: Eq> Eq for MapEntry<K> {} + +impl<K: Hash> Hash for MapEntry<K> { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.key.hash(state); + self.tolerance.to_bits().hash(state); + } +} diff --git a/core/src/caching/bounded/lru/mod.rs b/core/src/caching/bounded/lru/mod.rs index 272f6adf4c26ef228f4a1d0e7f72d391588f13d1..445862905260747920ac8d0f663ba7f70c9909a3 100644 --- a/core/src/caching/bounded/lru/mod.rs +++ b/core/src/caching/bounded/lru/mod.rs @@ -1,3 +1,4 @@ mod linked_list; mod list_node; pub mod lru_cache; +mod map_entry; diff --git a/core/src/caching/mod.rs b/core/src/caching/mod.rs index 713a6cc41362e8e873222b1bd1a3dcda3c9839fc..85010f40f97ef55438f185684a35909817ae6805 100644 --- a/core/src/caching/mod.rs +++ b/core/src/caching/mod.rs @@ -1,3 +1,2 @@ pub mod approximate_cache; pub mod bounded; -pub mod unbounded_linear_cache; diff --git a/core/src/caching/unbounded_linear_cache.rs b/core/src/caching/unbounded_linear_cache.rs deleted file mode 100644 index 2b20bd84977f95cce262005f43d914d007dafbd7..0000000000000000000000000000000000000000 --- a/core/src/caching/unbounded_linear_cache.rs +++ /dev/null @@ -1,156 +0,0 @@ -use 
crate::caching::approximate_cache::ApproximateCache; -use crate::numerics::comp::ApproxComparable; -/// A cache implementation that checks all entries one-by-one, without eviction -/// ## Generic Types -/// The types K and V are used for the cache keys and values respectively. -/// -/// K should be `ApproxComparable`, i.e. the compiler should know how to -/// decide that two K's are 'close enough' given a certain tolerance. -/// -/// V should be `Clone` so that the user can do whatever they want with a returned -/// value without messing with the actual cache line. -/// -/// ## Constructors -/// Use the ```from``` method to create a new cache. You will be asked to provide a -/// tolerance for the search and (optionally) an initial allocated capacity in memory. -/// ```tolerance``` indicates the searching sensitivity (see `ApproxComparable`), -/// which is a constant w.r.t. to the queried K (for now). -pub struct UnboundedLinearCache<K, V> -where - K: ApproxComparable, - V: Clone, -{ - keys: Vec<K>, - values: Vec<V>, - tolerance: f32, -} - -impl<K, V> UnboundedLinearCache<K, V> -where - K: ApproxComparable, - V: Clone, -{ - pub fn new(tolerance: f32) -> Self { - UnboundedLinearCache { - keys: Vec::new(), - values: Vec::new(), - tolerance, - } - } - - pub fn with_initial_capacity(tolerance: f32, capacity: usize) -> Self { - UnboundedLinearCache { - keys: Vec::with_capacity(capacity), - values: Vec::with_capacity(capacity), - tolerance, - } - } -} - -impl<K, V> ApproximateCache<K, V> for UnboundedLinearCache<K, V> -where - K: ApproxComparable, - V: Clone, -{ - // to find a match in an unbounded cache, iterate over all cache lines - // and return early if you have something - fn find(&mut self, to_find: &K) -> Option<V> { - let potential_match = self - .keys - .iter() - .position(|key| to_find.roughly_matches(key, self.tolerance)); - - potential_match.map(|i| self.values[i].clone()) - } - - // inserting a new value in a linear cache == pushing it at the end for future scans - fn insert(&mut self, key: K, value: V) { - self.keys.push(key); - self.values.push(value); - } - - fn len(&self) -> usize { - self.keys.len() - } -} - -#[cfg(test)] -mod tests { - use crate::caching::approximate_cache::COMPTIME_CACHE_SIZE; - - use super::*; - use quickcheck::{QuickCheck, TestResult}; - - const TEST_TOLERANCE: f32 = 1e-8; - const TEST_MAX_SIZE: usize = COMPTIME_CACHE_SIZE; - - #[test] - fn start_always_matches_exactly() { - fn qc_start_always_matches_exactly( - start_state: Vec<(f32, u8)>, - key: f32, - value: u8, - ) -> TestResult { - let mut ulc = UnboundedLinearCache::<f32, u8>::new(TEST_TOLERANCE); - if !key.is_finite() || start_state.len() > TEST_MAX_SIZE { - return TestResult::discard(); - } - - ulc.insert(key, value); - for &(k, v) in start_state.iter() { - ulc.insert(k, v); - } - - assert!(ulc.len() == start_state.len() + 1); - - if let Some(x) = ulc.find(&key) { - TestResult::from_bool(x == value) - } else { - TestResult::failed() - } - } - - QuickCheck::new() - .tests(10_000) - .min_tests_passed(1_000) - .quickcheck( - qc_start_always_matches_exactly as fn(Vec<(f32, u8)>, f32, u8) -> TestResult, - ); - } - - #[test] - fn middle_always_matches() { - fn qc_middle_always_matches( - start_state: Vec<(f32, u8)>, - key: f32, - value: u8, - end_state: Vec<(f32, u8)>, - ) -> TestResult { - let mut ulc = UnboundedLinearCache::<f32, u8>::new(TEST_TOLERANCE); - if !key.is_finite() || start_state.len() > TEST_MAX_SIZE { - return TestResult::discard(); - } - - for &(k, v) in start_state.iter() { - 
ulc.insert(k, v); - } - ulc.insert(key, value); - for &(k, v) in end_state.iter() { - ulc.insert(k, v); - } - - assert!(ulc.len() == start_state.len() + end_state.len() + 1); - - // we should match on something but we can't know on what - TestResult::from_bool(ulc.find(&key).is_some()) - } - - QuickCheck::new() - .tests(10_000) - .min_tests_passed(1_000) - .quickcheck( - qc_middle_always_matches - as fn(Vec<(f32, u8)>, f32, u8, Vec<(f32, u8)>) -> TestResult, - ); - } -} diff --git a/core/src/main.rs b/core/src/main.rs index 37eab5232a5d83ba20d5bedb8ec5c5d5650c7ffb..03776d879ab1026b7c1e8ec30db7cc7f46e5afbd 100644 --- a/core/src/main.rs +++ b/core/src/main.rs @@ -27,7 +27,7 @@ fn main() { let vecs_f: Vec<f32> = vecs.into_iter().map(f32::from).collect(); println!("{:?}", vecs_f.chunks_exact(128).next().unwrap()); - let mut ulc = LRUCache::<F32Vector, usize>::new(10000, 15_000.0); + let mut ulc = LRUCache::<F32Vector, usize>::new(10000); let mut count: u32 = 0; let mut scanned: usize = 0; @@ -40,7 +40,7 @@ fn main() { count += 1; } else { scanned += ulc.len(); - ulc.insert(f32v, index); + ulc.insert(f32v, index, 15_000.0); } writeln!(file, "{} {}", index, find.unwrap_or(50001)).unwrap();
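For reference, here is a minimal sketch of how the revised Python-facing API could be exercised once the bindings are built. The module name `proximipy` is taken from `bindings/Cargo.toml`, and passing keys as plain Python lists of floats follows from the `VecPy`/`PyList` handling in `bindings/src/api.rs`; the module registration itself is not part of this diff, so treat the names below as assumptions rather than a confirmed interface.

```python
# Hypothetical usage of the new bindings API; module and key-conversion details are
# assumptions (see note above), not something this diff shows directly.
from proximipy import LRUCache

cache = LRUCache(1000)                        # capacity only; the cache-wide tolerance is gone
key = [0.1] * 128                             # a float-vector key
cache.insert(key, {"ids": [1, 2, 3]}, 0.05)   # value may be any Python object; tolerance is per entry

hit = cache.find([0.1] * 128)                 # closest stored key within its own tolerance, else None
batch = cache.batch_find([[0.1] * 128, [9.9] * 128])  # batched lookups loop on the Rust side
print(hit, batch, len(cache))                 # __len__ is exposed directly
```

The per-entry tolerance mirrors the new `CacheLine` and `MapEntry` structs above: each cached key carries its own matching radius instead of a single cache-wide setting.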