diff --git a/proximity-rs/Cargo.lock b/proximity-rs/Cargo.lock new file mode 100644 index 0000000000000000000000000000000000000000..cd391e5fcac2891f35f91b53047564b314ebcdb0 --- /dev/null +++ b/proximity-rs/Cargo.lock @@ -0,0 +1,531 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cpufeatures" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "env_logger" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "libc" +version = "0.2.169" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "npyz" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13f27ea175875c472b3df61ece89a6d6ef4e0627f43704e400c782f174681ebd" +dependencies = [ + "byteorder", + "num-bigint", + "py_literal", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "pest" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "816518421cfc6887a0d62bf441b6ffb4536fcc926395a69e1a85852d4363f57e" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d1396fd3a870fc7838768d171b4616d5c91f6cc25e377b673d714567d99377b" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e58089ea25d717bfd31fb534e4f3afcc2cc569c70de3e239778991ea3b7dea" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + +[[package]] +name = "portable-atomic" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proximity-cache" +version = "0.1.0" +dependencies = [ + "npyz", + "pyo3", + "quickcheck", + "rand", +] + +[[package]] +name = "py_literal" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" +dependencies = [ + "num-bigint", + "num-complex", + "num-traits", + "pest", + "pest_derive", +] + +[[package]] +name = "pyo3" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e484fd2c8b4cb67ab05a318f1fd6fa8f199fcc30819f08f07d200809dba26c15" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0e0469a84f208e20044b98965e1561028180219e35352a2afaf2b942beff3b" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb1547a7f9966f6f1a0f0227564a9945fe36b90da5a93b3933fc3dc03fae372d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb6da8ec6fa5cedd1626c886fc8749bdcbb09424a86461eb8cdf096b7c33257" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38a385202ff5a92791168b1136afae5059d3ac118457bb7bc304c197c2d33e7d" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quickcheck" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" +dependencies = [ + "env_logger", + "log", + "rand", +] + +[[package]] +name = "quote" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "syn" +version = "2.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "thiserror" +version = "2.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + +[[package]] +name = "unicode-ident" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/proximity-rs/Cargo.toml b/proximity-rs/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..061865a880705d8901b53952099aff0a13c8752e --- /dev/null +++ b/proximity-rs/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "proximity-cache" +version = "0.1.0" +edition = "2021" +authors = ["SaCS laboratory, EPFL. Correspond with mathis[d o t]randl[a t]epfl.ch"] +description = "Experiments on approximate vector search in high-dimensional spaces" +readme = "README.md" +license = "MIT" +repository = "https://gitlab.epfl.ch/randl/proximity" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "proximipy" +crate-type = ["cdylib"] +path = "src/lib.rs" + +[[bin]] +name = "proximitybin" +path = "src/main.rs" + +[dev-dependencies] +quickcheck = "1.0.3" + +[dependencies] +pyo3 = "0.23.3" +npyz = "0.8.3" +rand = "0.8.5" \ No newline at end of file diff --git a/proximity-rs/rust-toolchain.toml b/proximity-rs/rust-toolchain.toml new file mode 100644 index 0000000000000000000000000000000000000000..271800cb2f3791b3adc24328e71c9e2550b439db --- /dev/null +++ b/proximity-rs/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" \ No newline at end of file diff --git a/proximity-rs/src/caching/approximate_cache.rs b/proximity-rs/src/caching/approximate_cache.rs new file mode 100644 index 0000000000000000000000000000000000000000..d574bd78529a1c2e74e6766a7137bdffbd64b585 --- /dev/null +++ b/proximity-rs/src/caching/approximate_cache.rs @@ -0,0 +1,17 @@ +use crate::numerics::comp::ApproxComparable; + +// size of caches in implementations where that should be known at comptime +pub const COMPTIME_CACHE_SIZE: usize = 1024; + +pub trait ApproximateCache<K, V> +where + K: ApproxComparable, + V: Clone, +{ + fn find(&mut self, key: &K) -> Option<V>; + fn insert(&mut self, key: K, value: V); + fn len(&self) -> usize; + fn is_empty(&self) -> bool { + self.len() == 0 + } +} diff --git a/proximity-rs/src/caching/bounded/bounded_linear_cache.rs b/proximity-rs/src/caching/bounded/bounded_linear_cache.rs new file mode 100644 index 0000000000000000000000000000000000000000..7e7e98870549be181855ac0578c389b9a2c34473 --- /dev/null +++ b/proximity-rs/src/caching/bounded/bounded_linear_cache.rs @@ -0,0 +1,202 @@ +use std::collections::HashMap; +use std::hash::Hash; + +use pyo3::{pyclass, pymethods}; + +use crate::numerics::comp::ApproxComparable; + +use crate::caching::approximate_cache::ApproximateCache; + +use super::linked_list::DoublyLinkedList; +use super::list_node::{Node, SharedNode}; + +/// `BoundedLinearCache` is a bounded cache with approximate key matching support. +/// +/// The cache enforces a maximum capacity, and when the capacity is exceeded, the least recently used (LRU) element is evicted. +/// +/// # Approximate Key Matching +/// Keys must implement the `ApproxComparable` trait, which allows approximate equality comparisons based on the provided `tolerance`. +/// This enables the cache to retrieve values even when the queried key is not an exact match but is "close enough." +/// +/// # Example Usage +/// ``` +/// use proximitylib::caching::bounded::bounded_linear_cache::BoundedLinearCache; +/// use proximitylib::caching::approximate_cache::ApproximateCache; +/// +/// let mut cache = BoundedLinearCache::new(3, 2.0); +/// +/// cache.insert(10 as i16, "Value 1"); +/// cache.insert(20, "Value 2"); +/// cache.insert(30, "Value 3"); +/// +/// assert_eq!(cache.find(&11), Some("Value 1")); +/// assert_eq!(cache.len(), 3); +/// +/// cache.insert(40, "Value 4"); // Evicts the least recently used (Key(20)) +/// assert!(cache.find(&20).is_none()); +/// ``` +/// +/// # Type Parameters +/// - `K`: The type of the keys, which must implement `ApproxComparable`, `Eq`, `Hash`, and `Clone`. +/// - `V`: The type of the values, which must implement `Clone`. +/// +/// # Methods +/// - `new(max_capacity: usize, tolerance: f32) -> Self`: Creates a new `BoundedLinearCache` with the specified maximum capacity and tolerance. +/// - `find(&mut self, key: &K) -> Option<V>`: Attempts to find a value matching the given key approximately. Promotes the found key to the head of the list. +/// - `insert(&mut self, key: K, value: V)`: Inserts a key-value pair into the cache. Evicts the least recently used item if the cache is full. +/// - `len(&self) -> usize`: Returns the current size of the cache. +pub struct BoundedLinearCache<K, V> { + max_capacity: usize, + map: HashMap<K, SharedNode<K, V>>, + list: DoublyLinkedList<K, V>, + tolerance: f32, +} + +impl<K, V> ApproximateCache<K, V> for BoundedLinearCache<K, V> +where + K: ApproxComparable + Eq + Hash + Clone, + V: Clone, +{ + fn find(&mut self, key: &K) -> Option<V> { + let matching = self + .map + .keys() + .find(|&k| key.roughly_matches(k, self.tolerance))?; + let node: SharedNode<K, V> = self.map.get(matching).cloned()?; + self.list.remove(node.clone()); + self.list.add_to_head(node.clone()); + return Some(node.borrow().value.clone()); + } + + fn insert(&mut self, key: K, value: V) { + if self.len() == self.max_capacity { + if let Some(tail) = self.list.remove_tail() { + self.map.remove(&tail.borrow().key); + } + } + let new_node = Node::new(key.clone(), value); + self.list.add_to_head(new_node.clone()); + self.map.insert(key, new_node); + } + + fn len(&self) -> usize { + self.map.len() + } +} + +impl<K, V> BoundedLinearCache<K, V> { + pub fn new(max_capacity: usize, tolerance: f32) -> Self { + assert!(max_capacity > 0); + assert!(tolerance > 0.0); + Self { + max_capacity, + map: HashMap::new(), + list: DoublyLinkedList::new(), + tolerance, + } + } +} + +macro_rules! create_pythonized_interface { + ($name: ident, $type: ident) => { + // unsendable == should hard-crash if Python tries to access it from + // two different Python threads. + // + // The implementation is very much thread-unsafe anyways (lots of mutations), + // so this is an OK behavior, we will detect it with a nice backtrace + // and without UB. + // + // Even in the case where we want the cache to be multithreaded, this would + // happen on the Rust side and will not be visible to the Python ML pipeline. + #[pyclass(unsendable)] + pub struct $name { + inner: BoundedLinearCache<$type, $type>, + } + + #[pymethods] + impl $name { + #[new] + pub fn new(max_capacity: usize, tolerance: f32) -> Self { + Self { + inner: BoundedLinearCache::new(max_capacity, tolerance), + } + } + + fn find(&mut self, k: $type) -> Option<$type> { + self.inner.find(&k) + } + + fn insert(&mut self, key: $type, value: $type) { + self.inner.insert(key, value) + } + + fn len(&self) -> usize { + self.inner.len() + } + } + }; +} + +create_pythonized_interface!(I16Cache, i16); + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_TOLERANCE: f32 = 1e-8; + #[test] + fn test_lru_cache_basic_operations() { + let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(2, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + cache.insert(2, 2); // Cache is {1=1, 2=2} + assert_eq!(cache.find(&1), Some(1)); // Returns 1, Cache is {2=2, 1=1} + cache.insert(3, 3); // Evicts key 2, Cache is {1=1, 3=3} + assert_eq!(cache.find(&2), None); // Key 2 not found + cache.insert(4, 4); // Evicts key 1, Cache is {3=3, 4=4} + assert_eq!(cache.find(&1), None); // Key 1 not found + assert_eq!(cache.find(&3), Some(3)); // Returns 3 + assert_eq!(cache.find(&4), Some(4)); // Returns 4 + } + + #[test] + fn test_lru_cache_eviction_order() { + let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(3, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + cache.insert(2, 2); // Cache is {1=1, 2=2} + cache.insert(3, 3); // Cache is {1=1, 2=2, 3=3} + cache.find(&1); // Access key 1, Cache is {2=2, 3=3, 1=1} + cache.insert(4, 4); // Evicts key 2, Cache is {3=3, 1=1, 4=4} + assert_eq!(cache.find(&2), None); // Key 2 not found + assert_eq!(cache.find(&3), Some(3)); // Returns 3 + assert_eq!(cache.find(&4), Some(4)); // Returns 4 + assert_eq!(cache.find(&1), Some(1)); // Returns 1 + } + + #[test] + fn test_lru_cache_overwrite() { + let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(2, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + cache.insert(2, 2); // Cache is {1=1, 2=2} + cache.insert(1, 10); // Overwrites key 1, Cache is {2=2, 1=10} + assert_eq!(cache.find(&1), Some(10)); // Returns 10 + cache.insert(3, 3); // Evicts key 2, Cache is {1=10, 3=3} + assert_eq!(cache.find(&2), None); // Key 2 not found + assert_eq!(cache.find(&3), Some(3)); // Returns 3 + } + + #[test] + fn test_lru_cache_capacity_one() { + let mut cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(1, TEST_TOLERANCE); + cache.insert(1, 1); // Cache is {1=1} + assert_eq!(cache.find(&1), Some(1)); // Returns 1 + cache.insert(2, 2); // Evicts key 1, Cache is {2=2} + assert_eq!(cache.find(&1), None); // Key 1 not found + assert_eq!(cache.find(&2), Some(2)); // Returns 2 + } + + #[test] + #[should_panic] + fn test_lru_cache_empty() { + let _cache: BoundedLinearCache<i16, i16> = BoundedLinearCache::new(0, TEST_TOLERANCE); + } +} diff --git a/proximity-rs/src/caching/bounded/linked_list.rs b/proximity-rs/src/caching/bounded/linked_list.rs new file mode 100644 index 0000000000000000000000000000000000000000..3b560aec1cf06ef47d95e2d1471845383ee253b6 --- /dev/null +++ b/proximity-rs/src/caching/bounded/linked_list.rs @@ -0,0 +1,156 @@ +use std::rc::Rc; + +use crate::caching::bounded::list_node::SharedNode; + +pub struct DoublyLinkedList<K, V> { + head: Option<SharedNode<K, V>>, + tail: Option<SharedNode<K, V>>, +} + +impl<K, V> DoublyLinkedList<K, V> { + pub(crate) fn new() -> Self { + Self { + head: None, + tail: None, + } + } + + pub(crate) fn add_to_head(&mut self, node: SharedNode<K, V>) { + node.borrow_mut().next = self.head.clone(); + node.borrow_mut().prev = None; + + if let Some(head) = self.head.clone() { + head.borrow_mut().prev = Some(Rc::downgrade(&node)); + } + + self.head = Some(node.clone()); + + if self.tail.is_none() { + self.tail = Some(node); + } + } + + pub(crate) fn remove(&mut self, node: SharedNode<K, V>) { + let prev = node.borrow().prev.clone(); + let next = node.borrow().next.clone(); + + if let Some(prev_node) = prev.as_ref().and_then(|weak| weak.upgrade()) { + prev_node.borrow_mut().next = next.clone(); + } else { + self.head = next.clone(); + } + + if let Some(next_node) = next { + next_node.borrow_mut().prev = prev; + } else { + self.tail = prev.and_then(|weak| weak.upgrade()); + } + } + + pub(crate) fn remove_tail(&mut self) -> Option<SharedNode<K, V>> { + if let Some(tail) = self.tail.clone() { + self.remove(tail.clone()); + Some(tail) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::caching::bounded::list_node::Node; + + #[test] + fn test_add_to_head() { + let mut list = DoublyLinkedList::new(); + let node1 = Node::new(1, 10); + let node2 = Node::new(2, 20); + + list.add_to_head(node1.clone()); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 1); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); + + list.add_to_head(node2.clone()); + assert_eq!(list.head.as_ref().unwrap().borrow().key, 2); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); + } + + #[test] + fn test_remove_node() { + let mut list = DoublyLinkedList::new(); + let node1 = Node::new(1, 10); + let node2 = Node::new(2, 20); + let node3 = Node::new(3, 30); + + list.add_to_head(node1.clone()); + list.add_to_head(node2.clone()); + list.add_to_head(node3.clone()); + + // List is now: {3, 2, 1} + list.remove(node2.clone()); + // List should now be: {3, 1} + assert_eq!(list.head.as_ref().unwrap().borrow().key, 3); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); + + list.remove(node3.clone()); + // List should now be: {1} + assert_eq!(list.head.as_ref().unwrap().borrow().key, 1); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); + + list.remove(node1.clone()); + // List should now be empty + assert!(list.head.is_none()); + assert!(list.tail.is_none()); + } + + #[test] + fn test_remove_tail() { + let mut list = DoublyLinkedList::new(); + let node1 = Node::new(1, 10); + let node2 = Node::new(2, 20); + + list.add_to_head(node1.clone()); + list.add_to_head(node2.clone()); + + // List is now: {2, 1} + let removed_tail = list.remove_tail().unwrap(); + assert_eq!(removed_tail.borrow().key, 1); + // List should now be: {2} + assert_eq!(list.head.as_ref().unwrap().borrow().key, 2); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 2); + + let removed_tail = list.remove_tail().unwrap(); + assert_eq!(removed_tail.borrow().key, 2); + // List should now be empty + assert!(list.head.is_none()); + assert!(list.tail.is_none()); + } + + #[test] + fn test_add_and_remove_combination() { + let mut list = DoublyLinkedList::new(); + let node1 = Node::new(1, 10); + let node2 = Node::new(2, 20); + let node3 = Node::new(3, 30); + + list.add_to_head(node1.clone()); + list.add_to_head(node2.clone()); + list.add_to_head(node3.clone()); + + // List is now: {3, 2, 1} + assert_eq!(list.head.as_ref().unwrap().borrow().key, 3); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 1); + + list.remove(node1.clone()); + // List should now be: {3, 2} + assert_eq!(list.head.as_ref().unwrap().borrow().key, 3); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 2); + + list.add_to_head(node1.clone()); + // List should now be: {1, 3, 2} + assert_eq!(list.head.as_ref().unwrap().borrow().key, 1); + assert_eq!(list.tail.as_ref().unwrap().borrow().key, 2); + } +} diff --git a/proximity-rs/src/caching/bounded/list_node.rs b/proximity-rs/src/caching/bounded/list_node.rs new file mode 100644 index 0000000000000000000000000000000000000000..dafe9f6394610c7e95aeba68b225da5d0a6c25e3 --- /dev/null +++ b/proximity-rs/src/caching/bounded/list_node.rs @@ -0,0 +1,26 @@ +use std::{ + cell::RefCell, + rc::{Rc, Weak}, +}; + +pub(crate) type SharedNode<K, V> = Rc<RefCell<Node<K, V>>>; +pub(crate) type WeakSharedNode<K, V> = Weak<RefCell<Node<K, V>>>; + +#[derive(Debug)] +pub struct Node<K, V> { + pub(crate) key: K, + pub(crate) value: V, + pub(crate) prev: Option<WeakSharedNode<K, V>>, + pub(crate) next: Option<SharedNode<K, V>>, +} + +impl<K, V> Node<K, V> { + pub fn new(key: K, value: V) -> Rc<RefCell<Self>> { + Rc::new(RefCell::new(Node { + key, + value, + prev: None, + next: None, + })) + } +} diff --git a/proximity-rs/src/caching/bounded/mod.rs b/proximity-rs/src/caching/bounded/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..b2b1881a91a5fe79a9634011518242ed279dc1e8 --- /dev/null +++ b/proximity-rs/src/caching/bounded/mod.rs @@ -0,0 +1,3 @@ +pub mod bounded_linear_cache; +mod linked_list; +mod list_node; diff --git a/proximity-rs/src/caching/mod.rs b/proximity-rs/src/caching/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..713a6cc41362e8e873222b1bd1a3dcda3c9839fc --- /dev/null +++ b/proximity-rs/src/caching/mod.rs @@ -0,0 +1,3 @@ +pub mod approximate_cache; +pub mod bounded; +pub mod unbounded_linear_cache; diff --git a/proximity-rs/src/caching/unbounded_linear_cache.rs b/proximity-rs/src/caching/unbounded_linear_cache.rs new file mode 100644 index 0000000000000000000000000000000000000000..2b20bd84977f95cce262005f43d914d007dafbd7 --- /dev/null +++ b/proximity-rs/src/caching/unbounded_linear_cache.rs @@ -0,0 +1,156 @@ +use crate::caching::approximate_cache::ApproximateCache; +use crate::numerics::comp::ApproxComparable; +/// A cache implementation that checks all entries one-by-one, without eviction +/// ## Generic Types +/// The types K and V are used for the cache keys and values respectively. +/// +/// K should be `ApproxComparable`, i.e. the compiler should know how to +/// decide that two K's are 'close enough' given a certain tolerance. +/// +/// V should be `Clone` so that the user can do whatever they want with a returned +/// value without messing with the actual cache line. +/// +/// ## Constructors +/// Use the ```from``` method to create a new cache. You will be asked to provide a +/// tolerance for the search and (optionally) an initial allocated capacity in memory. +/// ```tolerance``` indicates the searching sensitivity (see `ApproxComparable`), +/// which is a constant w.r.t. to the queried K (for now). +pub struct UnboundedLinearCache<K, V> +where + K: ApproxComparable, + V: Clone, +{ + keys: Vec<K>, + values: Vec<V>, + tolerance: f32, +} + +impl<K, V> UnboundedLinearCache<K, V> +where + K: ApproxComparable, + V: Clone, +{ + pub fn new(tolerance: f32) -> Self { + UnboundedLinearCache { + keys: Vec::new(), + values: Vec::new(), + tolerance, + } + } + + pub fn with_initial_capacity(tolerance: f32, capacity: usize) -> Self { + UnboundedLinearCache { + keys: Vec::with_capacity(capacity), + values: Vec::with_capacity(capacity), + tolerance, + } + } +} + +impl<K, V> ApproximateCache<K, V> for UnboundedLinearCache<K, V> +where + K: ApproxComparable, + V: Clone, +{ + // to find a match in an unbounded cache, iterate over all cache lines + // and return early if you have something + fn find(&mut self, to_find: &K) -> Option<V> { + let potential_match = self + .keys + .iter() + .position(|key| to_find.roughly_matches(key, self.tolerance)); + + potential_match.map(|i| self.values[i].clone()) + } + + // inserting a new value in a linear cache == pushing it at the end for future scans + fn insert(&mut self, key: K, value: V) { + self.keys.push(key); + self.values.push(value); + } + + fn len(&self) -> usize { + self.keys.len() + } +} + +#[cfg(test)] +mod tests { + use crate::caching::approximate_cache::COMPTIME_CACHE_SIZE; + + use super::*; + use quickcheck::{QuickCheck, TestResult}; + + const TEST_TOLERANCE: f32 = 1e-8; + const TEST_MAX_SIZE: usize = COMPTIME_CACHE_SIZE; + + #[test] + fn start_always_matches_exactly() { + fn qc_start_always_matches_exactly( + start_state: Vec<(f32, u8)>, + key: f32, + value: u8, + ) -> TestResult { + let mut ulc = UnboundedLinearCache::<f32, u8>::new(TEST_TOLERANCE); + if !key.is_finite() || start_state.len() > TEST_MAX_SIZE { + return TestResult::discard(); + } + + ulc.insert(key, value); + for &(k, v) in start_state.iter() { + ulc.insert(k, v); + } + + assert!(ulc.len() == start_state.len() + 1); + + if let Some(x) = ulc.find(&key) { + TestResult::from_bool(x == value) + } else { + TestResult::failed() + } + } + + QuickCheck::new() + .tests(10_000) + .min_tests_passed(1_000) + .quickcheck( + qc_start_always_matches_exactly as fn(Vec<(f32, u8)>, f32, u8) -> TestResult, + ); + } + + #[test] + fn middle_always_matches() { + fn qc_middle_always_matches( + start_state: Vec<(f32, u8)>, + key: f32, + value: u8, + end_state: Vec<(f32, u8)>, + ) -> TestResult { + let mut ulc = UnboundedLinearCache::<f32, u8>::new(TEST_TOLERANCE); + if !key.is_finite() || start_state.len() > TEST_MAX_SIZE { + return TestResult::discard(); + } + + for &(k, v) in start_state.iter() { + ulc.insert(k, v); + } + ulc.insert(key, value); + for &(k, v) in end_state.iter() { + ulc.insert(k, v); + } + + assert!(ulc.len() == start_state.len() + end_state.len() + 1); + + // we should match on something but we can't know on what + TestResult::from_bool(ulc.find(&key).is_some()) + } + + QuickCheck::new() + .tests(10_000) + .min_tests_passed(1_000) + .quickcheck( + qc_middle_always_matches + as fn(Vec<(f32, u8)>, f32, u8, Vec<(f32, u8)>) -> TestResult, + ); + } +} diff --git a/proximity-rs/src/fs/file_manager.rs b/proximity-rs/src/fs/file_manager.rs new file mode 100644 index 0000000000000000000000000000000000000000..785126b985d21301f36c0257f2bacde2e3940ffa --- /dev/null +++ b/proximity-rs/src/fs/file_manager.rs @@ -0,0 +1,20 @@ +use std::{fs, path::Path}; + +pub fn read_from_file_f32(path: &Path) -> Vec<f32> { + //todo plenty of unnecessary copying going on here + let file_u8 = fs::read(path).unwrap(); + let chunks = file_u8.array_chunks::<516>(); + chunks.flat_map(handle_f32_raw_vec).collect::<Vec<_>>() +} + +fn handle_f32_raw_vec(v: &[u8; 516]) -> Vec<f32> { + let chunks = v[4..].array_chunks::<4>(); + chunks.map(|&chk| f32::from_le_bytes(chk)).collect() +} + +pub fn read_from_npy(path: &Path) -> Vec<u8> { + let bytes = std::fs::read(path).unwrap(); + + let npy = npyz::NpyFile::new(&bytes[..]).unwrap(); + npy.into_vec().unwrap() +} diff --git a/proximity-rs/src/fs/mod.rs b/proximity-rs/src/fs/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..38751b5dbb915c352334ccd41c5a5fd9db359b8f --- /dev/null +++ b/proximity-rs/src/fs/mod.rs @@ -0,0 +1,2 @@ +pub mod file_manager; +pub mod vector_type; diff --git a/proximity-rs/src/fs/vector_type.rs b/proximity-rs/src/fs/vector_type.rs new file mode 100644 index 0000000000000000000000000000000000000000..1e06d2b5400a38991d9202809c90c4719beb5599 --- /dev/null +++ b/proximity-rs/src/fs/vector_type.rs @@ -0,0 +1,5 @@ +#[derive(Eq, PartialEq)] +pub enum VectorType { + F32, + I8, +} diff --git a/proximity-rs/src/lib.rs b/proximity-rs/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..9aa51bed92e048f47e55065763bc792c2965bb9e --- /dev/null +++ b/proximity-rs/src/lib.rs @@ -0,0 +1,24 @@ +#![feature(portable_simd, test, array_chunks)] + +use caching::bounded::bounded_linear_cache::I16Cache; +use pyo3::prelude::*; + +extern crate npyz; +extern crate rand; +extern crate test; + +pub mod caching; +pub mod fs; +pub mod numerics; + +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult<String> { + Ok((a + b).to_string()) +} + +#[pymodule] +fn proximipy(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + m.add_class::<I16Cache>()?; + Ok(()) +} diff --git a/proximity-rs/src/main.rs b/proximity-rs/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..f966bb10329ca05126a99c39fbde559b42596080 --- /dev/null +++ b/proximity-rs/src/main.rs @@ -0,0 +1,82 @@ +#![allow(dead_code)] +#![feature(portable_simd, test, array_chunks)] + +use std::path::Path; + +use caching::approximate_cache::ApproximateCache; +use caching::bounded::bounded_linear_cache::BoundedLinearCache; +use fs::file_manager; +use numerics::f32vector::F32Vector; +use std::fs::File; +use std::io::Write; + +extern crate npyz; +extern crate rand; +extern crate test; + +mod caching; +mod fs; +mod numerics; + +const PATH_TO_ROOT: &str = "/Users/matrix/Documents/proximity/"; +fn main() { + let vecs = + file_manager::read_from_npy(Path::new(&(PATH_TO_ROOT.to_owned() + "res/sift/p9.npy"))); + let mut file = File::create("res/sift/p9res.txt").unwrap(); + + let vecs_f: Vec<f32> = vecs.into_iter().map(f32::from).collect(); + println!("{:?}", vecs_f.chunks_exact(128).next().unwrap()); + + let mut ulc = BoundedLinearCache::<F32Vector, usize>::new(10000, 15_000.0); + let mut count: u32 = 0; + let mut scanned: usize = 0; + + for index in 0..50_000 { + let f32v = F32Vector::from(&vecs_f[index * 128..(index + 1) * 128]); + let find = ulc.find(&f32v); + let found = find.is_some(); + + if found { + count += 1; + } else { + scanned += ulc.len(); + ulc.insert(f32v, index); + } + + writeln!(file, "{} {}", index, find.unwrap_or(50001)).unwrap(); + } + + println!("count : {}, scanned lower bound : {}", count, scanned) +} + +#[cfg(test)] +mod tests { + use std::hint::black_box; + + use rand::Rng as _; + + use crate::numerics::f32vector::F32Vector; + + const VEC_SIZE: usize = (u32::MAX / 128) as usize; + + #[bench] + fn perftest(b: &mut test::Bencher) { + let mut rng = rand::thread_rng(); + let v1: Vec<_> = (0..128) + .map(|_| f32::from(rng.gen_range(-20 as i16..20))) + .collect(); + let v2s: Vec<f32> = (0 as u64..(128 * 10_000)) + .map(|_| f32::from(rng.gen_range(-20 as i16..20))) + .collect(); + + assert!(v1.len() == 128); + let v1_f32v = F32Vector::from(&v1[..]); + + b.iter(|| { + for v2_i in v2s.chunks_exact(128) { + let v2_f32v = F32Vector::from(v2_i); + black_box(v1_f32v.l2_dist_squared(&v2_f32v)); + } + }) + } +} diff --git a/proximity-rs/src/numerics/comp.rs b/proximity-rs/src/numerics/comp.rs new file mode 100644 index 0000000000000000000000000000000000000000..2a060d2e381ca7e7af1bd99af027f3d6d0c61d03 --- /dev/null +++ b/proximity-rs/src/numerics/comp.rs @@ -0,0 +1,26 @@ +use super::f32vector::F32Vector; + +// rust Ord trait has some issues +pub trait ApproxComparable { + fn roughly_matches(&self, instore: &Self, tolerance: f32) -> bool; +} + +impl ApproxComparable for f32 { + fn roughly_matches(&self, target: &f32, tolerance: f32) -> bool { + (self - target).abs() < tolerance + } +} + +impl<'a> ApproxComparable for F32Vector<'a> { + fn roughly_matches(&self, target: &F32Vector<'a>, square_tolerance: f32) -> bool { + self.l2_dist_squared(target) < square_tolerance + } +} + +impl ApproxComparable for i16 { + fn roughly_matches(&self, instore: &Self, tolerance: f32) -> bool { + let fself = f32::from(*self); + let foth = f32::from(*instore); + fself.roughly_matches(&foth, tolerance) + } +} diff --git a/proximity-rs/src/numerics/f32vector.rs b/proximity-rs/src/numerics/f32vector.rs new file mode 100644 index 0000000000000000000000000000000000000000..ed7a98502b918dbb1c07a9047aecc3c89cc9b3f0 --- /dev/null +++ b/proximity-rs/src/numerics/f32vector.rs @@ -0,0 +1,189 @@ +use std::simd::{num::SimdFloat, Simd}; + +const SIMD_LANECOUNT: usize = 8; +type SimdF32 = Simd<f32, SIMD_LANECOUNT>; + +#[derive(Debug, Clone)] +pub struct F32Vector<'a> { + array: &'a [f32], +} + +impl<'a> F32Vector<'a> { + pub fn len(&self) -> usize { + self.array.len() + } + + pub fn is_empty(&self) -> bool { + self.array.is_empty() + } + + /// # Usage + /// Computes the **SQUARED** L2 distance between two vectors. + /// This is cheaper to compute than the regular L2 distance. + /// This is typically useful when comparing two distances : + /// + /// dist(u,v) < dist(w, x) ⇔ dist(u,v) ** 2 < dist(w,x) ** 2 + /// + /// # Panics + /// + /// Panics in debug mode if the two vectors have different lengths. + /// In release mode, the longest vector will be silently truncated. + #[inline] + pub fn l2_dist_squared(&self, othr: &F32Vector<'a>) -> f32 { + debug_assert!(self.len() == othr.len()); + debug_assert!(self.len() % SIMD_LANECOUNT == 0); + + let mut intermediate_sum_x8 = Simd::<f32, SIMD_LANECOUNT>::splat(0.0); + + let self_chunks = self.array.chunks_exact(SIMD_LANECOUNT); + let othr_chunks = othr.array.chunks_exact(SIMD_LANECOUNT); + + for (slice_self, slice_othr) in self_chunks.zip(othr_chunks) { + let f32x8_slf = SimdF32::from_slice(slice_self); + let f32x8_oth = SimdF32::from_slice(slice_othr); + let diff = f32x8_slf - f32x8_oth; + intermediate_sum_x8 += diff * diff; + } + + intermediate_sum_x8.reduce_sum() // 8-to-1 sum + } + + /// # Usage + /// Computes the L2 distance between two vectors. + /// + /// # Panics + /// + /// Panics in debug mode if the two vectors have different lengths. + /// In release mode, the longest vector will be silently truncated. + #[inline] + pub fn l2_dist(&self, other: &F32Vector<'a>) -> f32 { + self.l2_dist_squared(other).sqrt() + } +} + +impl<'a> From<&'a [f32]> for F32Vector<'a> { + fn from(value: &'a [f32]) -> Self { + F32Vector { array: value } + } +} + +impl PartialEq for F32Vector<'_> { + fn eq(&self, other: &Self) -> bool { + self.array + .iter() + .zip(other.array.iter()) + .all(|(&a, &b)| a == b) + } +} + +impl Eq for F32Vector<'_> {} + +impl std::hash::Hash for F32Vector<'_> { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + // Iterate through each element of the slice and hash it + for &value in self.array { + value.to_bits().hash(state); // Convert `f32` to its bit representation for consistent hashing + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::{QuickCheck, TestResult}; + + const TOLERANCE: f32 = 1e-8; + + fn close(actual: f32, target: f32) -> bool { + (target - actual).abs() < TOLERANCE + } + + fn is_valid_l2(suspect: f32) -> bool { + suspect.is_finite() && suspect >= 0.0 + } + + fn l2_spec<'a>(v1: F32Vector<'a>, v2: F32Vector<'a>) -> f32 { + v1.array + .iter() + .zip(v2.array.iter()) + .map(|(&x, &y)| { + let diff = x - y; + diff * diff + }) + .sum() + } + + #[test] + fn self_sim_is_zero() { + fn qc_self_sim_is_zero(totest: Vec<f32>) -> TestResult { + let usable_length = totest.len() / 8 * 8; + if totest[0..usable_length].iter().any(|x| !x.is_finite()) { + return TestResult::discard(); + } + let testvec = F32Vector::from(&totest[0..usable_length]); + let selfsim = testvec.l2_dist(&testvec); + let to_check = is_valid_l2(selfsim) && close(selfsim, 0.0); + return TestResult::from_bool(to_check); + } + + QuickCheck::new() + .tests(10_000) + // force that less than 90% of tests are discarded due to precondition violations + // i.e. at least 10% of inputs should be valid so that we cover a good range + .min_tests_passed(1_000) + .quickcheck(qc_self_sim_is_zero as fn(Vec<f32>) -> TestResult); + } + + #[test] + // verifies the claim in the documentation of l2_dist_squared + // i.e. dist(u,v) < dist(w, x) ⇔ dist(u,v) ** 2 < dist(w,x) ** 2 + fn squared_invariant() { + fn qc_squared_invariant(u: Vec<f32>, v: Vec<f32>, w: Vec<f32>, x: Vec<f32>) -> TestResult { + let all_vecs = [u, v, w, x]; //no need to check for NaNs in this case + let min_length = all_vecs.iter().map(|x| x.len()).min().unwrap() / 8 * 8; + let all_vectors: Vec<F32Vector> = all_vecs + .iter() + .map(|vec| F32Vector::from(&vec[..min_length])) + .collect(); + + let d1_squared = all_vectors[0].l2_dist_squared(&all_vectors[1]); + let d2_squared = all_vectors[2].l2_dist_squared(&all_vectors[3]); + + let d1_root = all_vectors[0].l2_dist(&all_vectors[1]); + let d2_root = all_vectors[2].l2_dist(&all_vectors[3]); + + let sanity_check1 = (d1_squared < d2_squared) == (d1_root < d2_root); + let sanity_check2 = (d1_squared <= d2_squared) == (d1_root <= d2_root); + TestResult::from_bool(sanity_check1 && sanity_check2) + } + + QuickCheck::new().tests(10_000).quickcheck( + qc_squared_invariant as fn(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) -> TestResult, + ); + } + + #[test] + fn simd_matches_spec() { + fn qc_simd_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult { + let min_length = u.len().min(v.len()) / 8 * 8; + let (u_f32v, v_f32v) = ( + F32Vector::from(&u[0..min_length]), + F32Vector::from(&v[0..min_length]), + ); + let simd = u_f32v.l2_dist_squared(&v_f32v); + let spec = l2_spec(u_f32v, v_f32v); + + if simd.is_infinite() { + TestResult::from_bool(spec.is_infinite()) + } else if simd.is_nan() { + TestResult::from_bool(spec.is_nan()) + } else { + TestResult::from_bool(close(simd, spec)) + } + } + + QuickCheck::new() + .tests(10_000) + .quickcheck(qc_simd_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult); + } +} diff --git a/proximity-rs/src/numerics/mod.rs b/proximity-rs/src/numerics/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..319fa33484ffd698b906feb99a640106453bf428 --- /dev/null +++ b/proximity-rs/src/numerics/mod.rs @@ -0,0 +1,2 @@ +pub mod comp; +pub mod f32vector;