RMath: first pass at an SSE implementation.

This commit is contained in:
Nathan Vegdahl 2022-07-16 00:03:09 -07:00
parent 08e2e6eb06
commit 8dcf093dbb
3 changed files with 852 additions and 441 deletions

View File

@ -0,0 +1,439 @@
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, Mul, Neg, Not, Sub};
use crate::FMulAdd;
//=============================================================
// Float4
#[derive(Debug, Copy, Clone)]
#[repr(C, align(16))]
pub struct Float4([f32; 4]);
impl Float4 {
/// Create a new `Float4` with the given components.
#[inline(always)]
pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
Self([a, b, c, d])
}
/// Create a new `Float4` with all elements set to `n`.
#[inline(always)]
pub fn splat(n: f32) -> Self {
Self([n, n, n, n])
}
/// Component-wise fused multiply-add.
///
/// `(self * a) + b` with only one rounding error.
#[inline(always)]
pub fn mul_add(self, a: Self, b: Self) -> Self {
Self([
self.0[0].mul_add(a.0[0], b.0[0]),
self.0[1].mul_add(a.0[1], b.0[1]),
self.0[2].mul_add(a.0[2], b.0[2]),
self.0[3].mul_add(a.0[3], b.0[3]),
])
}
/// Vertical minimum.
#[inline(always)]
pub fn min(self, a: Self) -> Self {
Self([
self.0[0].min(a.0[0]),
self.0[1].min(a.0[1]),
self.0[2].min(a.0[2]),
self.0[3].min(a.0[3]),
])
}
/// Vertical maximum.
#[inline(always)]
pub fn max(self, a: Self) -> Self {
Self([
self.0[0].max(a.0[0]),
self.0[1].max(a.0[1]),
self.0[2].max(a.0[2]),
self.0[3].max(a.0[3]),
])
}
/// Horizontal minimum.
#[inline(always)]
pub fn min_element(self) -> f32 {
let a = self.0[0].min(self.0[1]);
let b = self.0[2].min(self.0[3]);
a.min(b)
}
/// Horizontal maximum.
#[inline(always)]
pub fn max_element(self) -> f32 {
let a = self.0[0].max(self.0[1]);
let b = self.0[2].max(self.0[3]);
a.max(b)
}
/// 1.0 / self
#[inline(always)]
pub fn recip(self) -> Self {
Float4::splat(1.0) / self
}
#[inline(always)]
pub fn abs(self) -> Self {
Float4::new(
self.a().abs(),
self.b().abs(),
self.c().abs(),
self.d().abs(),
)
}
//-----------------------------------------------------
// Comparisons.
/// Less than.
#[inline(always)]
pub fn cmplt(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] < rhs.0[0],
self.0[1] < rhs.0[1],
self.0[2] < rhs.0[2],
self.0[3] < rhs.0[3],
])
}
/// Less than or equal.
#[inline(always)]
pub fn cmplte(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] <= rhs.0[0],
self.0[1] <= rhs.0[1],
self.0[2] <= rhs.0[2],
self.0[3] <= rhs.0[3],
])
}
/// Greater than.
#[inline(always)]
pub fn cmpgt(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] > rhs.0[0],
self.0[1] > rhs.0[1],
self.0[2] > rhs.0[2],
self.0[3] > rhs.0[3],
])
}
/// Greater than or equal.
#[inline(always)]
pub fn cmpgte(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] >= rhs.0[0],
self.0[1] >= rhs.0[1],
self.0[2] >= rhs.0[2],
self.0[3] >= rhs.0[3],
])
}
/// Equal.
#[inline(always)]
pub fn cmpeq(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] == rhs.0[0],
self.0[1] == rhs.0[1],
self.0[2] == rhs.0[2],
self.0[3] == rhs.0[3],
])
}
//-----------------------------------------------------
// Individual components.
#[inline(always)]
pub fn a(self) -> f32 {
self.0[0]
}
#[inline(always)]
pub fn b(self) -> f32 {
self.0[1]
}
#[inline(always)]
pub fn c(self) -> f32 {
self.0[2]
}
#[inline(always)]
pub fn d(self) -> f32 {
self.0[3]
}
#[inline(always)]
#[must_use]
pub fn set_a(self, n: f32) -> Self {
Self([n, self.0[1], self.0[2], self.0[3]])
}
#[inline(always)]
#[must_use]
pub fn set_b(self, n: f32) -> Self {
Self([self.0[0], n, self.0[2], self.0[3]])
}
#[inline(always)]
#[must_use]
pub fn set_c(self, n: f32) -> Self {
Self([self.0[0], self.0[1], n, self.0[3]])
}
#[inline(always)]
#[must_use]
pub fn set_d(self, n: f32) -> Self {
Self([self.0[0], self.0[1], self.0[2], n])
}
//-----------------------------------------------------
// Shuffles.
#[inline(always)]
pub fn aaaa(self) -> Self {
let a = self.0[0];
Self([a, a, a, a])
}
#[inline(always)]
pub fn bbbb(self) -> Self {
let b = self.0[1];
Self([b, b, b, b])
}
#[inline(always)]
pub fn cccc(self) -> Self {
let c = self.0[2];
Self([c, c, c, c])
}
#[inline(always)]
pub fn dddd(self) -> Self {
let d = self.0[3];
Self([d, d, d, d])
}
#[inline(always)]
pub fn bcad(self) -> Self {
let a = self.0[0];
let b = self.0[1];
let c = self.0[2];
let d = self.0[3];
Self([b, c, a, d])
}
#[inline(always)]
pub fn cabd(self) -> Self {
let a = self.0[0];
let b = self.0[1];
let c = self.0[2];
let d = self.0[3];
Self([c, a, b, d])
}
}
impl Index<usize> for Float4 {
type Output = f32;
#[inline(always)]
fn index(&self, idx: usize) -> &f32 {
&self.0[idx]
}
}
impl Add for Float4 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self {
Self([
self.0[0] + rhs.0[0],
self.0[1] + rhs.0[1],
self.0[2] + rhs.0[2],
self.0[3] + rhs.0[3],
])
}
}
impl Sub for Float4 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self {
Self([
self.0[0] - rhs.0[0],
self.0[1] - rhs.0[1],
self.0[2] - rhs.0[2],
self.0[3] - rhs.0[3],
])
}
}
impl Mul for Float4 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self {
Self([
self.0[0] * rhs.0[0],
self.0[1] * rhs.0[1],
self.0[2] * rhs.0[2],
self.0[3] * rhs.0[3],
])
}
}
impl Mul<f32> for Float4 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: f32) -> Self {
Self([
self.0[0] * rhs,
self.0[1] * rhs,
self.0[2] * rhs,
self.0[3] * rhs,
])
}
}
impl Div for Float4 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self {
Self([
self.0[0] / rhs.0[0],
self.0[1] / rhs.0[1],
self.0[2] / rhs.0[2],
self.0[3] / rhs.0[3],
])
}
}
impl Div<f32> for Float4 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: f32) -> Self {
Self([
self.0[0] / rhs,
self.0[1] / rhs,
self.0[2] / rhs,
self.0[3] / rhs,
])
}
}
impl Neg for Float4 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self {
Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]])
}
}
impl FMulAdd for Float4 {
#[inline(always)]
fn fma(self, b: Self, c: Self) -> Self {
self.mul_add(b, c)
}
}
//=============================================================
// Bool4
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
pub struct Bool4([bool; 4]);
impl Bool4 {
#[inline(always)]
pub fn new_false() -> Self {
Self([false, false, false, false])
}
#[inline(always)]
pub fn to_bools(self) -> [bool; 4] {
self.0
}
/// Note: `a` goes to the least significant bit.
#[inline(always)]
pub fn bitmask(self) -> u8 {
self.0[0] as u8
| ((self.0[1] as u8) << 1)
| ((self.0[2] as u8) << 2)
| ((self.0[3] as u8) << 3)
}
#[inline(always)]
pub fn any(self) -> bool {
self.0[0] | &self.0[1] | self.0[2] | self.0[3]
}
#[inline(always)]
pub fn all(self) -> bool {
self.0[0] & &self.0[1] & self.0[2] & self.0[3]
}
}
impl BitAnd for Bool4 {
type Output = Self;
#[inline(always)]
fn bitand(self, rhs: Self) -> Self {
Self([
self.0[0] & rhs.0[0],
self.0[1] & rhs.0[1],
self.0[2] & rhs.0[2],
self.0[3] & rhs.0[3],
])
}
}
impl BitOr for Bool4 {
type Output = Self;
#[inline(always)]
fn bitor(self, rhs: Self) -> Self {
Self([
self.0[0] | rhs.0[0],
self.0[1] | rhs.0[1],
self.0[2] | rhs.0[2],
self.0[3] | rhs.0[3],
])
}
}
impl BitXor for Bool4 {
type Output = Self;
#[inline(always)]
fn bitxor(self, rhs: Self) -> Self {
Self([
self.0[0] ^ rhs.0[0],
self.0[1] ^ rhs.0[1],
self.0[2] ^ rhs.0[2],
self.0[3] ^ rhs.0[3],
])
}
}
impl Not for Bool4 {
type Output = Self;
#[inline(always)]
fn not(self) -> Self {
Self([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
}
}

View File

@ -1,453 +1,25 @@
use std::ops::{
AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, DivAssign, MulAssign, SubAssign,
use std::{
cmp::PartialEq,
ops::{AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, DivAssign, MulAssign, SubAssign},
};
use std::cmp::PartialEq;
use crate::utils::ulps_eq;
use crate::{difference_of_products, two_prod, two_sum};
//-------------------------------------------------------------
// Which implementation to use.
#[cfg(not(any(target_arch = "x86_64")))]
mod fallback;
#[cfg(not(any(target_arch = "x86_64")))]
pub use fallback::{Bool4, Float4};
mod fallback {
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, Mul, Neg, Not, Sub};
use crate::FMulAdd;
#[derive(Debug, Copy, Clone)]
#[repr(C, align(16))]
pub struct Float4([f32; 4]);
impl Float4 {
/// Create a new `Float4` with the given components.
#[inline(always)]
pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
Self([a, b, c, d])
}
/// Create a new `Float4` with all elements set to `n`.
#[inline(always)]
pub fn splat(n: f32) -> Self {
Self([n, n, n, n])
}
/// Component-wise fused multiply-add.
///
/// `(self * a) + b` with only one rounding error.
#[inline(always)]
pub fn mul_add(self, a: Self, b: Self) -> Self {
Self([
self.0[0].mul_add(a.0[0], b.0[0]),
self.0[1].mul_add(a.0[1], b.0[1]),
self.0[2].mul_add(a.0[2], b.0[2]),
self.0[3].mul_add(a.0[3], b.0[3]),
])
}
/// Vertical minimum.
#[inline(always)]
pub fn min(self, a: Self) -> Self {
Self([
self.0[0].min(a.0[0]),
self.0[1].min(a.0[1]),
self.0[2].min(a.0[2]),
self.0[3].min(a.0[3]),
])
}
/// Vertical maximum.
#[inline(always)]
pub fn max(self, a: Self) -> Self {
Self([
self.0[0].max(a.0[0]),
self.0[1].max(a.0[1]),
self.0[2].max(a.0[2]),
self.0[3].max(a.0[3]),
])
}
/// Horizontal minimum.
#[inline(always)]
pub fn min_element(self) -> f32 {
let a = self.0[0].min(self.0[1]);
let b = self.0[2].min(self.0[3]);
a.min(b)
}
/// Horizontal maximum.
#[inline(always)]
pub fn max_element(self) -> f32 {
let a = self.0[0].max(self.0[1]);
let b = self.0[2].max(self.0[3]);
a.max(b)
}
/// 1.0 / self
#[inline(always)]
pub fn recip(self) -> Self {
Float4::splat(1.0) / self
}
#[inline(always)]
pub fn abs(self) -> Self {
Float4::new(
self.a().abs(),
self.b().abs(),
self.c().abs(),
self.d().abs(),
)
}
//-----------------------------------------------------
// Comparisons.
/// Less than.
#[inline(always)]
pub fn cmplt(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] < rhs.0[0],
self.0[1] < rhs.0[1],
self.0[2] < rhs.0[2],
self.0[3] < rhs.0[3],
])
}
/// Less than or equal.
#[inline(always)]
pub fn cmplte(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] <= rhs.0[0],
self.0[1] <= rhs.0[1],
self.0[2] <= rhs.0[2],
self.0[3] <= rhs.0[3],
])
}
/// Greater than.
#[inline(always)]
pub fn cmpgt(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] > rhs.0[0],
self.0[1] > rhs.0[1],
self.0[2] > rhs.0[2],
self.0[3] > rhs.0[3],
])
}
/// Greater than or equal.
#[inline(always)]
pub fn cmpgte(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] >= rhs.0[0],
self.0[1] >= rhs.0[1],
self.0[2] >= rhs.0[2],
self.0[3] >= rhs.0[3],
])
}
/// Equal.
#[inline(always)]
pub fn cmpeq(self, rhs: Self) -> Bool4 {
Bool4([
self.0[0] == rhs.0[0],
self.0[1] == rhs.0[1],
self.0[2] == rhs.0[2],
self.0[3] == rhs.0[3],
])
}
//-----------------------------------------------------
// Individual components.
#[inline(always)]
pub fn a(self) -> f32 {
self.0[0]
}
#[inline(always)]
pub fn b(self) -> f32 {
self.0[1]
}
#[inline(always)]
pub fn c(self) -> f32 {
self.0[2]
}
#[inline(always)]
pub fn d(self) -> f32 {
self.0[3]
}
#[inline(always)]
#[must_use]
pub fn set_a(self, n: f32) -> Self {
Self([n, self.0[1], self.0[2], self.0[3]])
}
#[inline(always)]
#[must_use]
pub fn set_b(self, n: f32) -> Self {
Self([self.0[0], n, self.0[2], self.0[3]])
}
#[inline(always)]
#[must_use]
pub fn set_c(self, n: f32) -> Self {
Self([self.0[0], self.0[1], n, self.0[3]])
}
#[inline(always)]
#[must_use]
pub fn set_d(self, n: f32) -> Self {
Self([self.0[0], self.0[1], self.0[2], n])
}
//-----------------------------------------------------
// Shuffles.
#[inline(always)]
pub fn aaaa(self) -> Self {
let a = self.0[0];
Self([a, a, a, a])
}
#[inline(always)]
pub fn bbbb(self) -> Self {
let b = self.0[1];
Self([b, b, b, b])
}
#[inline(always)]
pub fn cccc(self) -> Self {
let c = self.0[2];
Self([c, c, c, c])
}
#[inline(always)]
pub fn dddd(self) -> Self {
let d = self.0[3];
Self([d, d, d, d])
}
#[inline(always)]
pub fn bcad(self) -> Self {
let a = self.0[0];
let b = self.0[1];
let c = self.0[2];
let d = self.0[3];
Self([b, c, a, d])
}
#[inline(always)]
pub fn cabd(self) -> Self {
let a = self.0[0];
let b = self.0[1];
let c = self.0[2];
let d = self.0[3];
Self([c, a, b, d])
}
}
impl Index<usize> for Float4 {
type Output = f32;
#[inline(always)]
fn index(&self, idx: usize) -> &f32 {
&self.0[idx]
}
}
impl Add for Float4 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self {
Self([
self.0[0] + rhs.0[0],
self.0[1] + rhs.0[1],
self.0[2] + rhs.0[2],
self.0[3] + rhs.0[3],
])
}
}
impl Sub for Float4 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self {
Self([
self.0[0] - rhs.0[0],
self.0[1] - rhs.0[1],
self.0[2] - rhs.0[2],
self.0[3] - rhs.0[3],
])
}
}
impl Mul for Float4 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self {
Self([
self.0[0] * rhs.0[0],
self.0[1] * rhs.0[1],
self.0[2] * rhs.0[2],
self.0[3] * rhs.0[3],
])
}
}
impl Mul<f32> for Float4 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: f32) -> Self {
Self([
self.0[0] * rhs,
self.0[1] * rhs,
self.0[2] * rhs,
self.0[3] * rhs,
])
}
}
impl Div for Float4 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self {
Self([
self.0[0] / rhs.0[0],
self.0[1] / rhs.0[1],
self.0[2] / rhs.0[2],
self.0[3] / rhs.0[3],
])
}
}
impl Div<f32> for Float4 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: f32) -> Self {
Self([
self.0[0] / rhs,
self.0[1] / rhs,
self.0[2] / rhs,
self.0[3] / rhs,
])
}
}
impl Neg for Float4 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self {
Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]])
}
}
impl FMulAdd for Float4 {
#[inline(always)]
fn fma(self, b: Self, c: Self) -> Self {
self.mul_add(b, c)
}
}
//---------------------------------------------------------
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
pub struct Bool4([bool; 4]);
impl Bool4 {
#[inline(always)]
pub fn new_false() -> Self {
Self([false, false, false, false])
}
#[inline(always)]
pub fn to_bools(self) -> [bool; 4] {
self.0
}
/// Note: `a` goes to the least significant bit.
#[inline(always)]
pub fn bitmask(self) -> u8 {
self.0[0] as u8
| ((self.0[1] as u8) << 1)
| ((self.0[2] as u8) << 2)
| ((self.0[3] as u8) << 3)
}
#[inline(always)]
pub fn any(self) -> bool {
self.0[0] | &self.0[1] | self.0[2] | self.0[3]
}
#[inline(always)]
pub fn all(self) -> bool {
self.0[0] & &self.0[1] & self.0[2] & self.0[3]
}
}
impl BitAnd for Bool4 {
type Output = Self;
#[inline(always)]
fn bitand(self, rhs: Self) -> Self {
Self([
self.0[0] & rhs.0[0],
self.0[1] & rhs.0[1],
self.0[2] & rhs.0[2],
self.0[3] & rhs.0[3],
])
}
}
impl BitOr for Bool4 {
type Output = Self;
#[inline(always)]
fn bitor(self, rhs: Self) -> Self {
Self([
self.0[0] | rhs.0[0],
self.0[1] | rhs.0[1],
self.0[2] | rhs.0[2],
self.0[3] | rhs.0[3],
])
}
}
impl BitXor for Bool4 {
type Output = Self;
#[inline(always)]
fn bitxor(self, rhs: Self) -> Self {
Self([
self.0[0] ^ rhs.0[0],
self.0[1] ^ rhs.0[1],
self.0[2] ^ rhs.0[2],
self.0[3] ^ rhs.0[3],
])
}
}
impl Not for Bool4 {
type Output = Self;
#[inline(always)]
fn not(self) -> Self {
Self([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
}
}
}
#[cfg(target_arch = "x86_64")]
mod sse;
#[cfg(target_arch = "x86_64")]
pub use sse::{Bool4, Float4};
//-------------------------------------------------------------
// Impls that don't depend on inner representation.
impl Float4 {
/// 3D dot product (only uses the first 3 components).
@ -838,4 +410,34 @@ mod tests {
assert!(Float4::aprx_eq(c, d, 575));
assert!(!Float4::aprx_eq(c, d, 565));
}
#[test]
fn index() {
let v = Float4::new(0.0, 1.0, 2.0, 3.0);
assert_eq!(v[0], 0.0);
assert_eq!(v[1], 1.0);
assert_eq!(v[2], 2.0);
assert_eq!(v[3], 3.0);
}
#[test]
fn shuffle() {
let v = Float4::new(0.0, 1.0, 2.0, 3.0);
assert_eq!(v.aaaa(), Float4::splat(0.0));
assert_eq!(v.bbbb(), Float4::splat(1.0));
assert_eq!(v.cccc(), Float4::splat(2.0));
assert_eq!(v.dddd(), Float4::splat(3.0));
assert_eq!(v.bcad(), Float4::new(1.0, 2.0, 0.0, 3.0));
assert_eq!(v.cabd(), Float4::new(2.0, 0.0, 1.0, 3.0));
}
#[test]
fn bitmask() {
let v1 = Float4::new(0.0, 1.0, 2.0, 3.0);
let v2 = Float4::new(9.0, 1.0, 9.0, 3.0);
assert_eq!(v1.cmpeq(v2).bitmask(), 0b1010);
}
}

View File

@ -0,0 +1,370 @@
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, Mul, Neg, Not, Sub};
use std::arch::x86_64::{
__m128, _mm_add_ps, _mm_and_ps, _mm_castsi128_ps, _mm_cmpeq_ps, _mm_cmpge_ps, _mm_cmpgt_ps,
_mm_cmple_ps, _mm_cmplt_ps, _mm_div_ps, _mm_fmadd_ps, _mm_max_ps, _mm_min_ps, _mm_movemask_ps,
_mm_mul_ps, _mm_or_ps, _mm_rcp_ps, _mm_set1_epi32, _mm_set1_ps, _mm_set_ps, _mm_setzero_ps,
_mm_shuffle_ps, _mm_storeu_ps, _mm_sub_ps, _mm_xor_ps,
};
use crate::FMulAdd;
//=============================================================
// Float4
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
pub struct Float4(__m128);
impl Float4 {
/// Create a new `Float4` with the given components.
#[inline(always)]
pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
Self(unsafe { _mm_set_ps(d, c, b, a) })
}
/// Create a new `Float4` with all elements set to `n`.
#[inline(always)]
pub fn splat(n: f32) -> Self {
Self(unsafe { _mm_set1_ps(n) })
}
/// Component-wise fused multiply-add.
///
/// `(self * a) + b` with only one rounding error.
#[inline(always)]
pub fn mul_add(self, a: Self, b: Self) -> Self {
if cfg!(feature = "fma") {
Self(unsafe { _mm_fmadd_ps(self.0, a.0, b.0) })
} else {
Self::new(
self.a().mul_add(a.a(), b.a()),
self.b().mul_add(a.b(), b.b()),
self.c().mul_add(a.c(), b.c()),
self.d().mul_add(a.d(), b.d()),
)
}
}
/// Vertical minimum.
#[inline(always)]
pub fn min(self, rhs: Self) -> Self {
Self(unsafe { _mm_min_ps(self.0, rhs.0) })
}
/// Vertical maximum.
#[inline(always)]
pub fn max(self, rhs: Self) -> Self {
Self(unsafe { _mm_max_ps(self.0, rhs.0) })
}
/// Horizontal minimum.
#[inline(always)]
pub fn min_element(self) -> f32 {
let a = self.a().min(self.b());
let b = self.c().min(self.d());
a.min(b)
}
/// Horizontal maximum.
#[inline(always)]
pub fn max_element(self) -> f32 {
let a = self.a().max(self.b());
let b = self.c().max(self.d());
a.max(b)
}
/// 1.0 / self
#[inline(always)]
pub fn recip(self) -> Self {
Self(unsafe { _mm_rcp_ps(self.0) })
}
#[inline(always)]
pub fn abs(self) -> Self {
Self(unsafe {
let abs_mask = _mm_castsi128_ps(_mm_set1_epi32(!(1 << 31)));
_mm_and_ps(self.0, abs_mask)
})
}
//-----------------------------------------------------
// Comparisons.
/// Less than.
#[inline(always)]
pub fn cmplt(self, rhs: Self) -> Bool4 {
Bool4(unsafe { _mm_cmplt_ps(self.0, rhs.0) })
}
/// Less than or equal.
#[inline(always)]
pub fn cmplte(self, rhs: Self) -> Bool4 {
Bool4(unsafe { _mm_cmple_ps(self.0, rhs.0) })
}
/// Greater than.
#[inline(always)]
pub fn cmpgt(self, rhs: Self) -> Bool4 {
Bool4(unsafe { _mm_cmpgt_ps(self.0, rhs.0) })
}
/// Greater than or equal.
#[inline(always)]
pub fn cmpgte(self, rhs: Self) -> Bool4 {
Bool4(unsafe { _mm_cmpge_ps(self.0, rhs.0) })
}
/// Equal.
#[inline(always)]
pub fn cmpeq(self, rhs: Self) -> Bool4 {
Bool4(unsafe { _mm_cmpeq_ps(self.0, rhs.0) })
}
//-----------------------------------------------------
// Individual components.
#[inline(always)]
pub fn a(self) -> f32 {
self[0]
}
#[inline(always)]
pub fn b(self) -> f32 {
self[1]
}
#[inline(always)]
pub fn c(self) -> f32 {
self[2]
}
#[inline(always)]
pub fn d(self) -> f32 {
self[3]
}
#[inline(always)]
#[must_use]
pub fn set_a(self, n: f32) -> Self {
Self::new(n, self.b(), self.c(), self.d())
}
#[inline(always)]
#[must_use]
pub fn set_b(self, n: f32) -> Self {
Self::new(self.a(), n, self.c(), self.d())
}
#[inline(always)]
#[must_use]
pub fn set_c(self, n: f32) -> Self {
Self::new(self.a(), self.b(), n, self.d())
}
#[inline(always)]
#[must_use]
pub fn set_d(self, n: f32) -> Self {
Self::new(self.a(), self.b(), self.c(), n)
}
//-----------------------------------------------------
// Shuffles.
#[inline(always)]
pub fn aaaa(self) -> Self {
Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b00_00_00_00) })
}
#[inline(always)]
pub fn bbbb(self) -> Self {
Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b01_01_01_01) })
}
#[inline(always)]
pub fn cccc(self) -> Self {
Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b10_10_10_10) })
}
#[inline(always)]
pub fn dddd(self) -> Self {
Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b11_11_11_11) })
}
#[inline(always)]
pub fn bcad(self) -> Self {
Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b11_00_10_01) })
}
#[inline(always)]
pub fn cabd(self) -> Self {
Self(unsafe { _mm_shuffle_ps(self.0, self.0, 0b11_01_00_10) })
}
}
impl Index<usize> for Float4 {
type Output = f32;
#[inline(always)]
fn index(&self, idx: usize) -> &f32 {
let elements: &[f32; 4] = unsafe { std::mem::transmute(&self.0) };
match idx {
0 => &elements[0],
1 => &elements[1],
2 => &elements[2],
3 => &elements[3],
_ => panic!("Out of bounds access of Float4 elements."),
}
}
}
impl Add for Float4 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self {
Self(unsafe { _mm_add_ps(self.0, rhs.0) })
}
}
impl Sub for Float4 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self {
Self(unsafe { _mm_sub_ps(self.0, rhs.0) })
}
}
impl Mul for Float4 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self {
Self(unsafe { _mm_mul_ps(self.0, rhs.0) })
}
}
impl Mul<f32> for Float4 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: f32) -> Self {
Self(unsafe { _mm_mul_ps(self.0, _mm_set1_ps(rhs)) })
}
}
impl Div for Float4 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self {
Self(unsafe { _mm_div_ps(self.0, rhs.0) })
}
}
impl Div<f32> for Float4 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: f32) -> Self {
Self(unsafe { _mm_div_ps(self.0, _mm_set1_ps(rhs)) })
}
}
impl Neg for Float4 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self {
Self(unsafe { _mm_mul_ps(self.0, _mm_set1_ps(-1.0)) })
}
}
impl FMulAdd for Float4 {
#[inline(always)]
fn fma(self, b: Self, c: Self) -> Self {
self.mul_add(b, c)
}
}
//=============================================================
// Bool4
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
pub struct Bool4(__m128);
impl Bool4 {
#[inline(always)]
pub fn new_false() -> Self {
Self(unsafe { _mm_setzero_ps() })
}
#[inline(always)]
pub fn to_bools(self) -> [bool; 4] {
let mut v = [0.0f32; 4];
unsafe { _mm_storeu_ps((&mut v[..]).as_mut_ptr(), self.0) }
[
v[0].to_bits() != 0,
v[1].to_bits() != 0,
v[2].to_bits() != 0,
v[3].to_bits() != 0,
]
}
/// Note: `a` goes to the least significant bit.
#[inline(always)]
pub fn bitmask(self) -> u8 {
unsafe { _mm_movemask_ps(self.0) as u8 }
}
#[inline(always)]
pub fn any(self) -> bool {
self.bitmask() != 0
}
#[inline(always)]
pub fn all(self) -> bool {
self.bitmask() == 0b1111
}
}
impl BitAnd for Bool4 {
type Output = Self;
#[inline(always)]
fn bitand(self, rhs: Self) -> Self {
Self(unsafe { _mm_and_ps(self.0, rhs.0) })
}
}
impl BitOr for Bool4 {
type Output = Self;
#[inline(always)]
fn bitor(self, rhs: Self) -> Self {
Self(unsafe { _mm_or_ps(self.0, rhs.0) })
}
}
impl BitXor for Bool4 {
type Output = Self;
#[inline(always)]
fn bitxor(self, rhs: Self) -> Self {
Self(unsafe { _mm_xor_ps(self.0, rhs.0) })
}
}
impl Not for Bool4 {
type Output = Self;
#[inline(always)]
fn not(self) -> Self {
Self(unsafe {
let ones = _mm_castsi128_ps(_mm_set1_epi32(!0));
_mm_xor_ps(self.0, ones)
})
}
}