diff --git a/sub_crates/rmath/src/wide4/fallback.rs b/sub_crates/rmath/src/wide4/fallback.rs new file mode 100644 index 0000000..076bcf5 --- /dev/null +++ b/sub_crates/rmath/src/wide4/fallback.rs @@ -0,0 +1,439 @@ +use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, Mul, Neg, Not, Sub}; + +use crate::FMulAdd; + +//============================================================= +// Float4 + +#[derive(Debug, Copy, Clone)] +#[repr(C, align(16))] +pub struct Float4([f32; 4]); + +impl Float4 { + /// Create a new `Float4` with the given components. + #[inline(always)] + pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self { + Self([a, b, c, d]) + } + + /// Create a new `Float4` with all elements set to `n`. + #[inline(always)] + pub fn splat(n: f32) -> Self { + Self([n, n, n, n]) + } + + /// Component-wise fused multiply-add. + /// + /// `(self * a) + b` with only one rounding error. + #[inline(always)] + pub fn mul_add(self, a: Self, b: Self) -> Self { + Self([ + self.0[0].mul_add(a.0[0], b.0[0]), + self.0[1].mul_add(a.0[1], b.0[1]), + self.0[2].mul_add(a.0[2], b.0[2]), + self.0[3].mul_add(a.0[3], b.0[3]), + ]) + } + + /// Vertical minimum. + #[inline(always)] + pub fn min(self, a: Self) -> Self { + Self([ + self.0[0].min(a.0[0]), + self.0[1].min(a.0[1]), + self.0[2].min(a.0[2]), + self.0[3].min(a.0[3]), + ]) + } + + /// Vertical maximum. + #[inline(always)] + pub fn max(self, a: Self) -> Self { + Self([ + self.0[0].max(a.0[0]), + self.0[1].max(a.0[1]), + self.0[2].max(a.0[2]), + self.0[3].max(a.0[3]), + ]) + } + + /// Horizontal minimum. + #[inline(always)] + pub fn min_element(self) -> f32 { + let a = self.0[0].min(self.0[1]); + let b = self.0[2].min(self.0[3]); + a.min(b) + } + + /// Horizontal maximum. + #[inline(always)] + pub fn max_element(self) -> f32 { + let a = self.0[0].max(self.0[1]); + let b = self.0[2].max(self.0[3]); + a.max(b) + } + + /// 1.0 / self + #[inline(always)] + pub fn recip(self) -> Self { + Float4::splat(1.0) / self + } + + #[inline(always)] + pub fn abs(self) -> Self { + Float4::new( + self.a().abs(), + self.b().abs(), + self.c().abs(), + self.d().abs(), + ) + } + + //----------------------------------------------------- + // Comparisons. + + /// Less than. + #[inline(always)] + pub fn cmplt(self, rhs: Self) -> Bool4 { + Bool4([ + self.0[0] < rhs.0[0], + self.0[1] < rhs.0[1], + self.0[2] < rhs.0[2], + self.0[3] < rhs.0[3], + ]) + } + + /// Less than or equal. + #[inline(always)] + pub fn cmplte(self, rhs: Self) -> Bool4 { + Bool4([ + self.0[0] <= rhs.0[0], + self.0[1] <= rhs.0[1], + self.0[2] <= rhs.0[2], + self.0[3] <= rhs.0[3], + ]) + } + + /// Greater than. + #[inline(always)] + pub fn cmpgt(self, rhs: Self) -> Bool4 { + Bool4([ + self.0[0] > rhs.0[0], + self.0[1] > rhs.0[1], + self.0[2] > rhs.0[2], + self.0[3] > rhs.0[3], + ]) + } + + /// Greater than or equal. + #[inline(always)] + pub fn cmpgte(self, rhs: Self) -> Bool4 { + Bool4([ + self.0[0] >= rhs.0[0], + self.0[1] >= rhs.0[1], + self.0[2] >= rhs.0[2], + self.0[3] >= rhs.0[3], + ]) + } + + /// Equal. + #[inline(always)] + pub fn cmpeq(self, rhs: Self) -> Bool4 { + Bool4([ + self.0[0] == rhs.0[0], + self.0[1] == rhs.0[1], + self.0[2] == rhs.0[2], + self.0[3] == rhs.0[3], + ]) + } + + //----------------------------------------------------- + // Individual components. + + #[inline(always)] + pub fn a(self) -> f32 { + self.0[0] + } + + #[inline(always)] + pub fn b(self) -> f32 { + self.0[1] + } + + #[inline(always)] + pub fn c(self) -> f32 { + self.0[2] + } + + #[inline(always)] + pub fn d(self) -> f32 { + self.0[3] + } + + #[inline(always)] + #[must_use] + pub fn set_a(self, n: f32) -> Self { + Self([n, self.0[1], self.0[2], self.0[3]]) + } + + #[inline(always)] + #[must_use] + pub fn set_b(self, n: f32) -> Self { + Self([self.0[0], n, self.0[2], self.0[3]]) + } + + #[inline(always)] + #[must_use] + pub fn set_c(self, n: f32) -> Self { + Self([self.0[0], self.0[1], n, self.0[3]]) + } + + #[inline(always)] + #[must_use] + pub fn set_d(self, n: f32) -> Self { + Self([self.0[0], self.0[1], self.0[2], n]) + } + + //----------------------------------------------------- + // Shuffles. + + #[inline(always)] + pub fn aaaa(self) -> Self { + let a = self.0[0]; + Self([a, a, a, a]) + } + + #[inline(always)] + pub fn bbbb(self) -> Self { + let b = self.0[1]; + Self([b, b, b, b]) + } + + #[inline(always)] + pub fn cccc(self) -> Self { + let c = self.0[2]; + Self([c, c, c, c]) + } + + #[inline(always)] + pub fn dddd(self) -> Self { + let d = self.0[3]; + Self([d, d, d, d]) + } + + #[inline(always)] + pub fn bcad(self) -> Self { + let a = self.0[0]; + let b = self.0[1]; + let c = self.0[2]; + let d = self.0[3]; + Self([b, c, a, d]) + } + + #[inline(always)] + pub fn cabd(self) -> Self { + let a = self.0[0]; + let b = self.0[1]; + let c = self.0[2]; + let d = self.0[3]; + Self([c, a, b, d]) + } +} + +impl Index for Float4 { + type Output = f32; + + #[inline(always)] + fn index(&self, idx: usize) -> &f32 { + &self.0[idx] + } +} + +impl Add for Float4 { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + ]) + } +} + +impl Sub for Float4 { + type Output = Self; + + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0] - rhs.0[0], + self.0[1] - rhs.0[1], + self.0[2] - rhs.0[2], + self.0[3] - rhs.0[3], + ]) + } +} + +impl Mul for Float4 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self([ + self.0[0] * rhs.0[0], + self.0[1] * rhs.0[1], + self.0[2] * rhs.0[2], + self.0[3] * rhs.0[3], + ]) + } +} + +impl Mul for Float4 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: f32) -> Self { + Self([ + self.0[0] * rhs, + self.0[1] * rhs, + self.0[2] * rhs, + self.0[3] * rhs, + ]) + } +} + +impl Div for Float4 { + type Output = Self; + + #[inline(always)] + fn div(self, rhs: Self) -> Self { + Self([ + self.0[0] / rhs.0[0], + self.0[1] / rhs.0[1], + self.0[2] / rhs.0[2], + self.0[3] / rhs.0[3], + ]) + } +} + +impl Div for Float4 { + type Output = Self; + + #[inline(always)] + fn div(self, rhs: f32) -> Self { + Self([ + self.0[0] / rhs, + self.0[1] / rhs, + self.0[2] / rhs, + self.0[3] / rhs, + ]) + } +} + +impl Neg for Float4 { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]]) + } +} + +impl FMulAdd for Float4 { + #[inline(always)] + fn fma(self, b: Self, c: Self) -> Self { + self.mul_add(b, c) + } +} + +//============================================================= +// Bool4 + +#[derive(Debug, Copy, Clone)] +#[repr(transparent)] +pub struct Bool4([bool; 4]); + +impl Bool4 { + #[inline(always)] + pub fn new_false() -> Self { + Self([false, false, false, false]) + } + + #[inline(always)] + pub fn to_bools(self) -> [bool; 4] { + self.0 + } + + /// Note: `a` goes to the least significant bit. + #[inline(always)] + pub fn bitmask(self) -> u8 { + self.0[0] as u8 + | ((self.0[1] as u8) << 1) + | ((self.0[2] as u8) << 2) + | ((self.0[3] as u8) << 3) + } + + #[inline(always)] + pub fn any(self) -> bool { + self.0[0] | &self.0[1] | self.0[2] | self.0[3] + } + + #[inline(always)] + pub fn all(self) -> bool { + self.0[0] & &self.0[1] & self.0[2] & self.0[3] + } +} + +impl BitAnd for Bool4 { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + Self([ + self.0[0] & rhs.0[0], + self.0[1] & rhs.0[1], + self.0[2] & rhs.0[2], + self.0[3] & rhs.0[3], + ]) + } +} + +impl BitOr for Bool4 { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + Self([ + self.0[0] | rhs.0[0], + self.0[1] | rhs.0[1], + self.0[2] | rhs.0[2], + self.0[3] | rhs.0[3], + ]) + } +} + +impl BitXor for Bool4 { + type Output = Self; + + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + Self([ + self.0[0] ^ rhs.0[0], + self.0[1] ^ rhs.0[1], + self.0[2] ^ rhs.0[2], + self.0[3] ^ rhs.0[3], + ]) + } +} + +impl Not for Bool4 { + type Output = Self; + + #[inline(always)] + fn not(self) -> Self { + Self([!self.0[0], !self.0[1], !self.0[2], !self.0[3]]) + } +} diff --git a/sub_crates/rmath/src/wide4.rs b/sub_crates/rmath/src/wide4/mod.rs similarity index 50% rename from sub_crates/rmath/src/wide4.rs rename to sub_crates/rmath/src/wide4/mod.rs index 4e449e2..a0ebe35 100644 --- a/sub_crates/rmath/src/wide4.rs +++ b/sub_crates/rmath/src/wide4/mod.rs @@ -1,453 +1,25 @@ -use std::ops::{ - AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, DivAssign, MulAssign, SubAssign, +use std::{ + cmp::PartialEq, + ops::{AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, DivAssign, MulAssign, SubAssign}, }; -use std::cmp::PartialEq; - use crate::utils::ulps_eq; use crate::{difference_of_products, two_prod, two_sum}; +//------------------------------------------------------------- +// Which implementation to use. + +#[cfg(not(any(target_arch = "x86_64")))] +mod fallback; +#[cfg(not(any(target_arch = "x86_64")))] pub use fallback::{Bool4, Float4}; -mod fallback { - use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, Mul, Neg, Not, Sub}; - use crate::FMulAdd; - - #[derive(Debug, Copy, Clone)] - #[repr(C, align(16))] - pub struct Float4([f32; 4]); - - impl Float4 { - /// Create a new `Float4` with the given components. - #[inline(always)] - pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self { - Self([a, b, c, d]) - } - - /// Create a new `Float4` with all elements set to `n`. - #[inline(always)] - pub fn splat(n: f32) -> Self { - Self([n, n, n, n]) - } - - /// Component-wise fused multiply-add. - /// - /// `(self * a) + b` with only one rounding error. - #[inline(always)] - pub fn mul_add(self, a: Self, b: Self) -> Self { - Self([ - self.0[0].mul_add(a.0[0], b.0[0]), - self.0[1].mul_add(a.0[1], b.0[1]), - self.0[2].mul_add(a.0[2], b.0[2]), - self.0[3].mul_add(a.0[3], b.0[3]), - ]) - } - - /// Vertical minimum. - #[inline(always)] - pub fn min(self, a: Self) -> Self { - Self([ - self.0[0].min(a.0[0]), - self.0[1].min(a.0[1]), - self.0[2].min(a.0[2]), - self.0[3].min(a.0[3]), - ]) - } - - /// Vertical maximum. - #[inline(always)] - pub fn max(self, a: Self) -> Self { - Self([ - self.0[0].max(a.0[0]), - self.0[1].max(a.0[1]), - self.0[2].max(a.0[2]), - self.0[3].max(a.0[3]), - ]) - } - - /// Horizontal minimum. - #[inline(always)] - pub fn min_element(self) -> f32 { - let a = self.0[0].min(self.0[1]); - let b = self.0[2].min(self.0[3]); - a.min(b) - } - - /// Horizontal maximum. - #[inline(always)] - pub fn max_element(self) -> f32 { - let a = self.0[0].max(self.0[1]); - let b = self.0[2].max(self.0[3]); - a.max(b) - } - - /// 1.0 / self - #[inline(always)] - pub fn recip(self) -> Self { - Float4::splat(1.0) / self - } - - #[inline(always)] - pub fn abs(self) -> Self { - Float4::new( - self.a().abs(), - self.b().abs(), - self.c().abs(), - self.d().abs(), - ) - } - - //----------------------------------------------------- - // Comparisons. - - /// Less than. - #[inline(always)] - pub fn cmplt(self, rhs: Self) -> Bool4 { - Bool4([ - self.0[0] < rhs.0[0], - self.0[1] < rhs.0[1], - self.0[2] < rhs.0[2], - self.0[3] < rhs.0[3], - ]) - } - - /// Less than or equal. - #[inline(always)] - pub fn cmplte(self, rhs: Self) -> Bool4 { - Bool4([ - self.0[0] <= rhs.0[0], - self.0[1] <= rhs.0[1], - self.0[2] <= rhs.0[2], - self.0[3] <= rhs.0[3], - ]) - } - - /// Greater than. - #[inline(always)] - pub fn cmpgt(self, rhs: Self) -> Bool4 { - Bool4([ - self.0[0] > rhs.0[0], - self.0[1] > rhs.0[1], - self.0[2] > rhs.0[2], - self.0[3] > rhs.0[3], - ]) - } - - /// Greater than or equal. - #[inline(always)] - pub fn cmpgte(self, rhs: Self) -> Bool4 { - Bool4([ - self.0[0] >= rhs.0[0], - self.0[1] >= rhs.0[1], - self.0[2] >= rhs.0[2], - self.0[3] >= rhs.0[3], - ]) - } - - /// Equal. - #[inline(always)] - pub fn cmpeq(self, rhs: Self) -> Bool4 { - Bool4([ - self.0[0] == rhs.0[0], - self.0[1] == rhs.0[1], - self.0[2] == rhs.0[2], - self.0[3] == rhs.0[3], - ]) - } - - //----------------------------------------------------- - // Individual components. - - #[inline(always)] - pub fn a(self) -> f32 { - self.0[0] - } - - #[inline(always)] - pub fn b(self) -> f32 { - self.0[1] - } - - #[inline(always)] - pub fn c(self) -> f32 { - self.0[2] - } - - #[inline(always)] - pub fn d(self) -> f32 { - self.0[3] - } - - #[inline(always)] - #[must_use] - pub fn set_a(self, n: f32) -> Self { - Self([n, self.0[1], self.0[2], self.0[3]]) - } - - #[inline(always)] - #[must_use] - pub fn set_b(self, n: f32) -> Self { - Self([self.0[0], n, self.0[2], self.0[3]]) - } - - #[inline(always)] - #[must_use] - pub fn set_c(self, n: f32) -> Self { - Self([self.0[0], self.0[1], n, self.0[3]]) - } - - #[inline(always)] - #[must_use] - pub fn set_d(self, n: f32) -> Self { - Self([self.0[0], self.0[1], self.0[2], n]) - } - - //----------------------------------------------------- - // Shuffles. - - #[inline(always)] - pub fn aaaa(self) -> Self { - let a = self.0[0]; - Self([a, a, a, a]) - } - - #[inline(always)] - pub fn bbbb(self) -> Self { - let b = self.0[1]; - Self([b, b, b, b]) - } - - #[inline(always)] - pub fn cccc(self) -> Self { - let c = self.0[2]; - Self([c, c, c, c]) - } - - #[inline(always)] - pub fn dddd(self) -> Self { - let d = self.0[3]; - Self([d, d, d, d]) - } - - #[inline(always)] - pub fn bcad(self) -> Self { - let a = self.0[0]; - let b = self.0[1]; - let c = self.0[2]; - let d = self.0[3]; - Self([b, c, a, d]) - } - - #[inline(always)] - pub fn cabd(self) -> Self { - let a = self.0[0]; - let b = self.0[1]; - let c = self.0[2]; - let d = self.0[3]; - Self([c, a, b, d]) - } - } - - impl Index for Float4 { - type Output = f32; - - #[inline(always)] - fn index(&self, idx: usize) -> &f32 { - &self.0[idx] - } - } - - impl Add for Float4 { - type Output = Self; - - #[inline(always)] - fn add(self, rhs: Self) -> Self { - Self([ - self.0[0] + rhs.0[0], - self.0[1] + rhs.0[1], - self.0[2] + rhs.0[2], - self.0[3] + rhs.0[3], - ]) - } - } - - impl Sub for Float4 { - type Output = Self; - - #[inline(always)] - fn sub(self, rhs: Self) -> Self { - Self([ - self.0[0] - rhs.0[0], - self.0[1] - rhs.0[1], - self.0[2] - rhs.0[2], - self.0[3] - rhs.0[3], - ]) - } - } - - impl Mul for Float4 { - type Output = Self; - - #[inline(always)] - fn mul(self, rhs: Self) -> Self { - Self([ - self.0[0] * rhs.0[0], - self.0[1] * rhs.0[1], - self.0[2] * rhs.0[2], - self.0[3] * rhs.0[3], - ]) - } - } - - impl Mul for Float4 { - type Output = Self; - - #[inline(always)] - fn mul(self, rhs: f32) -> Self { - Self([ - self.0[0] * rhs, - self.0[1] * rhs, - self.0[2] * rhs, - self.0[3] * rhs, - ]) - } - } - - impl Div for Float4 { - type Output = Self; - - #[inline(always)] - fn div(self, rhs: Self) -> Self { - Self([ - self.0[0] / rhs.0[0], - self.0[1] / rhs.0[1], - self.0[2] / rhs.0[2], - self.0[3] / rhs.0[3], - ]) - } - } - - impl Div for Float4 { - type Output = Self; - - #[inline(always)] - fn div(self, rhs: f32) -> Self { - Self([ - self.0[0] / rhs, - self.0[1] / rhs, - self.0[2] / rhs, - self.0[3] / rhs, - ]) - } - } - - impl Neg for Float4 { - type Output = Self; - - #[inline(always)] - fn neg(self) -> Self { - Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]]) - } - } - - impl FMulAdd for Float4 { - #[inline(always)] - fn fma(self, b: Self, c: Self) -> Self { - self.mul_add(b, c) - } - } - - //--------------------------------------------------------- - - #[derive(Debug, Copy, Clone)] - #[repr(transparent)] - pub struct Bool4([bool; 4]); - - impl Bool4 { - #[inline(always)] - pub fn new_false() -> Self { - Self([false, false, false, false]) - } - - #[inline(always)] - pub fn to_bools(self) -> [bool; 4] { - self.0 - } - - /// Note: `a` goes to the least significant bit. - #[inline(always)] - pub fn bitmask(self) -> u8 { - self.0[0] as u8 - | ((self.0[1] as u8) << 1) - | ((self.0[2] as u8) << 2) - | ((self.0[3] as u8) << 3) - } - - #[inline(always)] - pub fn any(self) -> bool { - self.0[0] | &self.0[1] | self.0[2] | self.0[3] - } - - #[inline(always)] - pub fn all(self) -> bool { - self.0[0] & &self.0[1] & self.0[2] & self.0[3] - } - } - - impl BitAnd for Bool4 { - type Output = Self; - - #[inline(always)] - fn bitand(self, rhs: Self) -> Self { - Self([ - self.0[0] & rhs.0[0], - self.0[1] & rhs.0[1], - self.0[2] & rhs.0[2], - self.0[3] & rhs.0[3], - ]) - } - } - - impl BitOr for Bool4 { - type Output = Self; - - #[inline(always)] - fn bitor(self, rhs: Self) -> Self { - Self([ - self.0[0] | rhs.0[0], - self.0[1] | rhs.0[1], - self.0[2] | rhs.0[2], - self.0[3] | rhs.0[3], - ]) - } - } - - impl BitXor for Bool4 { - type Output = Self; - - #[inline(always)] - fn bitxor(self, rhs: Self) -> Self { - Self([ - self.0[0] ^ rhs.0[0], - self.0[1] ^ rhs.0[1], - self.0[2] ^ rhs.0[2], - self.0[3] ^ rhs.0[3], - ]) - } - } - - impl Not for Bool4 { - type Output = Self; - - #[inline(always)] - fn not(self) -> Self { - Self([!self.0[0], !self.0[1], !self.0[2], !self.0[3]]) - } - } -} +#[cfg(target_arch = "x86_64")] +mod sse; +#[cfg(target_arch = "x86_64")] +pub use sse::{Bool4, Float4}; //------------------------------------------------------------- -// Impls that don't depend on inner representation. impl Float4 { /// 3D dot product (only uses the first 3 components). @@ -838,4 +410,34 @@ mod tests { assert!(Float4::aprx_eq(c, d, 575)); assert!(!Float4::aprx_eq(c, d, 565)); } + + #[test] + fn index() { + let v = Float4::new(0.0, 1.0, 2.0, 3.0); + assert_eq!(v[0], 0.0); + assert_eq!(v[1], 1.0); + assert_eq!(v[2], 2.0); + assert_eq!(v[3], 3.0); + } + + #[test] + fn shuffle() { + let v = Float4::new(0.0, 1.0, 2.0, 3.0); + + assert_eq!(v.aaaa(), Float4::splat(0.0)); + assert_eq!(v.bbbb(), Float4::splat(1.0)); + assert_eq!(v.cccc(), Float4::splat(2.0)); + assert_eq!(v.dddd(), Float4::splat(3.0)); + + assert_eq!(v.bcad(), Float4::new(1.0, 2.0, 0.0, 3.0)); + assert_eq!(v.cabd(), Float4::new(2.0, 0.0, 1.0, 3.0)); + } + + #[test] + fn bitmask() { + let v1 = Float4::new(0.0, 1.0, 2.0, 3.0); + let v2 = Float4::new(9.0, 1.0, 9.0, 3.0); + + assert_eq!(v1.cmpeq(v2).bitmask(), 0b1010); + } } diff --git a/sub_crates/rmath/src/wide4/sse.rs b/sub_crates/rmath/src/wide4/sse.rs new file mode 100644 index 0000000..9835ce5 --- /dev/null +++ b/sub_crates/rmath/src/wide4/sse.rs @@ -0,0 +1,370 @@ +use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, Mul, Neg, Not, Sub}; + +use std::arch::x86_64::{ + __m128, _mm_add_ps, _mm_and_ps, _mm_castsi128_ps, _mm_cmpeq_ps, _mm_cmpge_ps, _mm_cmpgt_ps, + _mm_cmple_ps, _mm_cmplt_ps, _mm_div_ps, _mm_fmadd_ps, _mm_max_ps, _mm_min_ps, _mm_movemask_ps, + _mm_mul_ps, _mm_or_ps, _mm_rcp_ps, _mm_set1_epi32, _mm_set1_ps, _mm_set_ps, _mm_setzero_ps, + _mm_shuffle_ps, _mm_storeu_ps, _mm_sub_ps, _mm_xor_ps, +}; + +use crate::FMulAdd; + +//============================================================= +// Float4 + +#[derive(Debug, Copy, Clone)] +#[repr(transparent)] +pub struct Float4(__m128); + +impl Float4 { + /// Create a new `Float4` with the given components. + #[inline(always)] + pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self { + Self(unsafe { _mm_set_ps(d, c, b, a) }) + } + + /// Create a new `Float4` with all elements set to `n`. + #[inline(always)] + pub fn splat(n: f32) -> Self { + Self(unsafe { _mm_set1_ps(n) }) + } + + /// Component-wise fused multiply-add. + /// + /// `(self * a) + b` with only one rounding error. + #[inline(always)] + pub fn mul_add(self, a: Self, b: Self) -> Self { + if cfg!(feature = "fma") { + Self(unsafe { _mm_fmadd_ps(self.0, a.0, b.0) }) + } else { + Self::new( + self.a().mul_add(a.a(), b.a()), + self.b().mul_add(a.b(), b.b()), + self.c().mul_add(a.c(), b.c()), + self.d().mul_add(a.d(), b.d()), + ) + } + } + + /// Vertical minimum. + #[inline(always)] + pub fn min(self, rhs: Self) -> Self { + Self(unsafe { _mm_min_ps(self.0, rhs.0) }) + } + + /// Vertical maximum. + #[inline(always)] + pub fn max(self, rhs: Self) -> Self { + Self(unsafe { _mm_max_ps(self.0, rhs.0) }) + } + + /// Horizontal minimum. + #[inline(always)] + pub fn min_element(self) -> f32 { + let a = self.a().min(self.b()); + let b = self.c().min(self.d()); + a.min(b) + } + + /// Horizontal maximum. + #[inline(always)] + pub fn max_element(self) -> f32 { + let a = self.a().max(self.b()); + let b = self.c().max(self.d()); + a.max(b) + } + + /// 1.0 / self + #[inline(always)] + pub fn recip(self) -> Self { + Self(unsafe { _mm_rcp_ps(self.0) }) + } + + #[inline(always)] + pub fn abs(self) -> Self { + Self(unsafe { + let abs_mask = _mm_castsi128_ps(_mm_set1_epi32(!(1 << 31))); + _mm_and_ps(self.0, abs_mask) + }) + } + + //----------------------------------------------------- + // Comparisons. + + /// Less than. + #[inline(always)] + pub fn cmplt(self, rhs: Self) -> Bool4 { + Bool4(unsafe { _mm_cmplt_ps(self.0, rhs.0) }) + } + + /// Less than or equal. + #[inline(always)] + pub fn cmplte(self, rhs: Self) -> Bool4 { + Bool4(unsafe { _mm_cmple_ps(self.0, rhs.0) }) + } + + /// Greater than. + #[inline(always)] + pub fn cmpgt(self, rhs: Self) -> Bool4 { + Bool4(unsafe { _mm_cmpgt_ps(self.0, rhs.0) }) + } + + /// Greater than or equal. + #[inline(always)] + pub fn cmpgte(self, rhs: Self) -> Bool4 { + Bool4(unsafe { _mm_cmpge_ps(self.0, rhs.0) }) + } + + /// Equal. + #[inline(always)] + pub fn cmpeq(self, rhs: Self) -> Bool4 { + Bool4(unsafe { _mm_cmpeq_ps(self.0, rhs.0) }) + } + + //----------------------------------------------------- + // Individual components. + + #[inline(always)] + pub fn a(self) -> f32 { + self[0] + } + + #[inline(always)] + pub fn b(self) -> f32 { + self[1] + } + + #[inline(always)] + pub fn c(self) -> f32 { + self[2] + } + + #[inline(always)] + pub fn d(self) -> f32 { + self[3] + } + + #[inline(always)] + #[must_use] + pub fn set_a(self, n: f32) -> Self { + Self::new(n, self.b(), self.c(), self.d()) + } + + #[inline(always)] + #[must_use] + pub fn set_b(self, n: f32) -> Self { + Self::new(self.a(), n, self.c(), self.d()) + } + + #[inline(always)] + #[must_use] + pub fn set_c(self, n: f32) -> Self { + Self::new(self.a(), self.b(), n, self.d()) + } + + #[inline(always)] + #[must_use] + pub fn set_d(self, n: f32) -> Self { + Self::new(self.a(), self.b(), self.c(), n) + } + + //----------------------------------------------------- + // Shuffles. + + #[inline(always)] + pub fn aaaa(self) -> Self { + Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b00_00_00_00) }) + } + + #[inline(always)] + pub fn bbbb(self) -> Self { + Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b01_01_01_01) }) + } + + #[inline(always)] + pub fn cccc(self) -> Self { + Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b10_10_10_10) }) + } + + #[inline(always)] + pub fn dddd(self) -> Self { + Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b11_11_11_11) }) + } + + #[inline(always)] + pub fn bcad(self) -> Self { + Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b11_00_10_01) }) + } + + #[inline(always)] + pub fn cabd(self) -> Self { + Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b11_01_00_10) }) + } +} + +impl Index for Float4 { + type Output = f32; + + #[inline(always)] + fn index(&self, idx: usize) -> &f32 { + let elements: &[f32; 4] = unsafe { std::mem::transmute(&self.0) }; + match idx { + 0 => &elements[0], + 1 => &elements[1], + 2 => &elements[2], + 3 => &elements[3], + _ => panic!("Out of bounds access of Float4 elements."), + } + } +} + +impl Add for Float4 { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(unsafe { _mm_add_ps(self.0, rhs.0) }) + } +} + +impl Sub for Float4 { + type Output = Self; + + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(unsafe { _mm_sub_ps(self.0, rhs.0) }) + } +} + +impl Mul for Float4 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self(unsafe { _mm_mul_ps(self.0, rhs.0) }) + } +} + +impl Mul for Float4 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: f32) -> Self { + Self(unsafe { _mm_mul_ps(self.0, _mm_set1_ps(rhs)) }) + } +} + +impl Div for Float4 { + type Output = Self; + + #[inline(always)] + fn div(self, rhs: Self) -> Self { + Self(unsafe { _mm_div_ps(self.0, rhs.0) }) + } +} + +impl Div for Float4 { + type Output = Self; + + #[inline(always)] + fn div(self, rhs: f32) -> Self { + Self(unsafe { _mm_div_ps(self.0, _mm_set1_ps(rhs)) }) + } +} + +impl Neg for Float4 { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self { + Self(unsafe { _mm_mul_ps(self.0, _mm_set1_ps(-1.0)) }) + } +} + +impl FMulAdd for Float4 { + #[inline(always)] + fn fma(self, b: Self, c: Self) -> Self { + self.mul_add(b, c) + } +} + +//============================================================= +// Bool4 + +#[derive(Debug, Copy, Clone)] +#[repr(transparent)] +pub struct Bool4(__m128); + +impl Bool4 { + #[inline(always)] + pub fn new_false() -> Self { + Self(unsafe { _mm_setzero_ps() }) + } + + #[inline(always)] + pub fn to_bools(self) -> [bool; 4] { + let mut v = [0.0f32; 4]; + unsafe { _mm_storeu_ps((&mut v[..]).as_mut_ptr(), self.0) } + [ + v[0].to_bits() != 0, + v[1].to_bits() != 0, + v[2].to_bits() != 0, + v[3].to_bits() != 0, + ] + } + + /// Note: `a` goes to the least significant bit. + #[inline(always)] + pub fn bitmask(self) -> u8 { + unsafe { _mm_movemask_ps(self.0) as u8 } + } + + #[inline(always)] + pub fn any(self) -> bool { + self.bitmask() != 0 + } + + #[inline(always)] + pub fn all(self) -> bool { + self.bitmask() == 0b1111 + } +} + +impl BitAnd for Bool4 { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + Self(unsafe { _mm_and_ps(self.0, rhs.0) }) + } +} + +impl BitOr for Bool4 { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + Self(unsafe { _mm_or_ps(self.0, rhs.0) }) + } +} + +impl BitXor for Bool4 { + type Output = Self; + + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + Self(unsafe { _mm_xor_ps(self.0, rhs.0) }) + } +} + +impl Not for Bool4 { + type Output = Self; + + #[inline(always)] + fn not(self) -> Self { + Self(unsafe { + let ones = _mm_castsi128_ps(_mm_set1_epi32(!0)); + _mm_xor_ps(self.0, ones) + }) + } +}