diff --git a/sub_crates/sobol/src/wide.rs b/sub_crates/sobol/src/wide.rs index f90940f..e112cc9 100644 --- a/sub_crates/sobol/src/wide.rs +++ b/sub_crates/sobol/src/wide.rs @@ -22,6 +22,11 @@ pub(crate) mod sse { } } + pub fn get(self, i: usize) -> u32 { + let n: [u32; 4] = unsafe { std::mem::transmute(self) }; + n[i] + } + /// Converts the full range of a 32 bit integer to a float in [0, 1). #[inline(always)] pub fn to_norm_floats(self) -> [f32; 4] { @@ -203,7 +208,7 @@ pub(crate) mod sse { #[inline(always)] fn shl(self, other: i32) -> Int4 { Int4 { - v: unsafe { _mm_sll_epi32(self.v, _mm_set1_epi32(other)) }, + v: unsafe { _mm_sll_epi32(self.v, _mm_set_epi32(0, 0, 0, other)) }, } } } @@ -214,7 +219,7 @@ pub(crate) mod sse { #[inline(always)] fn shr(self, other: i32) -> Int4 { Int4 { - v: unsafe { _mm_srl_epi32(self.v, _mm_set1_epi32(other)) }, + v: unsafe { _mm_srl_epi32(self.v, _mm_set_epi32(0, 0, 0, other)) }, } } } @@ -227,6 +232,38 @@ pub(crate) mod sse { } } } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn from_array() { + let a = Int4::from([1, 2, 3, 4]); + assert_eq!(a.get(0), 1); + assert_eq!(a.get(1), 2); + assert_eq!(a.get(2), 3); + assert_eq!(a.get(3), 4); + } + + #[test] + fn shr() { + let a = Int4::from([0xffffffff; 4]) >> 16; + assert_eq!(a.get(0), 0x0000ffff); + assert_eq!(a.get(1), 0x0000ffff); + assert_eq!(a.get(2), 0x0000ffff); + assert_eq!(a.get(3), 0x0000ffff); + } + + #[test] + fn shl() { + let a = Int4::from([0xffffffff; 4]) << 16; + assert_eq!(a.get(0), 0xffff0000); + assert_eq!(a.get(1), 0xffff0000); + assert_eq!(a.get(2), 0xffff0000); + assert_eq!(a.get(3), 0xffff0000); + } + } } #[cfg(target_arch = "x86_64")] pub(crate) use sse::Int4;