From 4f4187a80a27952ac9e77547b68570705dc710d7 Mon Sep 17 00:00:00 2001
From: Mathis Randl <mathis.randl@epfl.ch>
Date: Wed, 20 Nov 2024 13:20:43 +0100
Subject: [PATCH] simd cleanup + specified the behavior wrt non-simd impl

---
 src/numerics/f32vector.rs | 64 +++++++++++++++++++++++++++++++--------
 1 file changed, 52 insertions(+), 12 deletions(-)

diff --git a/src/numerics/f32vector.rs b/src/numerics/f32vector.rs
index bae4b75..a1fd06b 100644
--- a/src/numerics/f32vector.rs
+++ b/src/numerics/f32vector.rs
@@ -22,19 +22,23 @@ impl<'a> F32Vector<'a> {
     /// Panics in debug mode if the two vectors have different lengths.
     /// In release mode, the longest vector will be silently truncated.
     #[inline]
-    pub fn l2_dist_squared(&self, other: &F32Vector<'a>) -> f32 {
-        let mut sum = f32x8::splat(0.0);
-        for (chka, chkb) in self
-            .array
-            .chunks_exact(f32x8::LEN)
-            .zip(other.array.chunks_exact(f32x8::LEN))
-        {
-            let simd_a = f32x8::from_slice(chka);
-            let simd_b = f32x8::from_slice(chkb);
-            let diff = simd_a - simd_b;
-            sum += diff * diff;
+    pub fn l2_dist_squared(&self, othr: &F32Vector<'a>) -> f32 {
+        debug_assert!(self.len() == othr.len());
+        debug_assert!(self.len() % f32x8::LEN == 0);
+
+        let mut intermediate_sum_x8: f32x8 = f32x8::splat(0.0);
+
+        let self_chunks = self.array.chunks_exact(f32x8::LEN);
+        let othr_chunks = othr.array.chunks_exact(f32x8::LEN);
+
+        for (slice_self, slice_othr) in self_chunks.zip(othr_chunks) {
+            let f32x8_slf = f32x8::from_slice(slice_self);
+            let f32x8_oth = f32x8::from_slice(slice_othr);
+            let diff = f32x8_slf - f32x8_oth;
+            intermediate_sum_x8 += diff * diff;
         }
-        sum.reduce_sum()
+
+        intermediate_sum_x8.reduce_sum() // 8-to-1 sum
     }
 
     /// # Usage
@@ -71,6 +75,17 @@ mod tests {
         suspect.is_finite() && suspect >= 0.0
     }
 
+    fn l2_spec<'a>(v1: F32Vector<'a>, v2: F32Vector<'a>) -> f32 {
+        v1.array
+            .iter()
+            .zip(v2.array.iter())
+            .map(|(&x, &y)| {
+                let diff = x - y;
+                diff * diff
+            })
+            .sum()
+    }
+
     #[test]
     fn self_sim_is_zero() {
         fn qc_self_sim_is_zero(totest: Vec<f32>) -> TestResult {
@@ -118,4 +133,29 @@ mod tests {
             qc_squared_invariant as fn(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) -> TestResult,
         );
     }
+
+    #[test]
+    fn simd_matches_spec() {
+        fn qc_simd_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
+            let min_length = u.len().min(v.len()) / 8 * 8;
+            let (u_f32v, v_f32v) = (
+                F32Vector::from(&u[0..min_length]),
+                F32Vector::from(&v[0..min_length]),
+            );
+            let simd = u_f32v.l2_dist_squared(&v_f32v);
+            let spec = l2_spec(u_f32v, v_f32v);
+
+            if simd.is_infinite() {
+                TestResult::from_bool(spec.is_infinite())
+            } else if simd.is_nan() {
+                TestResult::from_bool(spec.is_nan())
+            } else {
+                TestResult::from_bool(close(simd, spec))
+            }
+        }
+
+        QuickCheck::new()
+            .tests(10_000)
+            .quickcheck(qc_simd_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
+    }
 }
-- 
GitLab