From d34ab6a1483e558f61a3a64221125238440cd8a2 Mon Sep 17 00:00:00 2001 From: Mathis Randl <mathis.randl@epfl.ch> Date: Wed, 8 Jan 2025 16:33:05 +0100 Subject: [PATCH] doubly generic cache + vector handling --- bindings/src/api.rs | 54 ++++++++++++++++++++++++++++++++++++++++----- bindings/src/lib.rs | 4 ++-- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/bindings/src/api.rs b/bindings/src/api.rs index e91cb73..e175834 100644 --- a/bindings/src/api.rs +++ b/bindings/src/api.rs @@ -1,9 +1,9 @@ use proximipy::caching::approximate_cache::ApproximateCache; use proximipy::caching::bounded::bounded_linear_cache::BoundedLinearCache; -use pyo3::{pyclass, pymethods}; +use pyo3::{pyclass, pymethods, types::{PyAnyMethods, PyList}, Bound, FromPyObject, IntoPyObject, PyErr}; macro_rules! create_pythonized_interface { - ($name: ident, $type: ident) => { + ($name: ident, $keytype: ident, $valuetype : ident) => { // unsendable == should hard-crash if Python tries to access it from // two different Python threads. // @@ -15,7 +15,7 @@ macro_rules! create_pythonized_interface { // happen on the Rust side and will not be visible to the Python ML pipeline. #[pyclass(unsendable)] pub struct $name { - inner: BoundedLinearCache<$type, $type>, + inner: BoundedLinearCache<$keytype, $valuetype>, } #[pymethods] @@ -27,19 +27,61 @@ macro_rules! create_pythonized_interface { } } - fn find(&mut self, k: $type) -> Option<$type> { + fn find(&mut self, k: $keytype) -> Option<$valuetype> { self.inner.find(&k) } - fn insert(&mut self, key: $type, value: $type) { + fn insert(&mut self, key: $keytype, value: $valuetype) { self.inner.insert(key, value) } fn len(&self) -> usize { self.inner.len() } + + fn __len__(&self) -> usize { + self.len() + } } }; } +struct F32VecPy { + inner : Vec<f32> +} + +/// Explain to Rust how to parse some random python object into an actual Rust vector +/// This involves new allocations because Python cannot be trusted to keep this +/// reference alive. +/// +/// This can fail if the random object in question is not a list of numbers, +/// in which case it is automatically reported by raising a TypeError exception +/// in the Python code +impl <'a> FromPyObject <'a> for F32VecPy { + fn extract_bound(ob: &pyo3::Bound<'a, pyo3::PyAny>) -> pyo3::PyResult<Self> { + let list : Vec<f32> = ob.downcast::<PyList>()?.extract()?; + Ok(F32VecPy {inner : list}) + } +} + +// Cast back the list of floats to a Python list +impl <'a> IntoPyObject<'a> for F32VecPy { + type Target = PyList; + type Output = Bound<'a, PyList>; + type Error = PyErr; + + fn into_pyobject(self, py: pyo3::Python<'a>) -> Result<Self::Output, Self::Error> { + let internal = self.inner; + PyList::new(py, internal) + } +} + +impl Clone for F32VecPy { + fn clone(&self) -> Self { + F32VecPy { + inner : self.inner.clone() + } + } +} + -create_pythonized_interface!(I16Cache, i16); +create_pythonized_interface!(I16ToVectorCache, i16, F32VecPy); \ No newline at end of file diff --git a/bindings/src/lib.rs b/bindings/src/lib.rs index 5edb994..13e37e3 100644 --- a/bindings/src/lib.rs +++ b/bindings/src/lib.rs @@ -1,4 +1,4 @@ -use api::I16Cache; +use api::I16ToVectorCache; use pyo3::prelude::*; mod api; @@ -6,6 +6,6 @@ mod api; /// A Python module implemented in Rust. #[pymodule] fn proximipy(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::<I16Cache>()?; + m.add_class::<I16ToVectorCache>()?; Ok(()) } -- GitLab