diff --git a/sub_crates/rmath/src/lib.rs b/sub_crates/rmath/src/lib.rs index 25e343c..fee0aca 100644 --- a/sub_crates/rmath/src/lib.rs +++ b/sub_crates/rmath/src/lib.rs @@ -8,17 +8,26 @@ mod vector; pub mod wide4; mod xform; +use std::ops::{Add, Mul, Neg, Sub}; + pub use self::{normal::Normal, point::Point, vector::Vector, xform::Xform, xform::XformFull}; -// /// Trait for calculating dot products. -// pub trait DotProduct { -// fn dot(self, other: Self) -> f32; -// } +/// Trait for calculating dot products. +pub trait DotProduct { + fn dot(self, other: Self) -> f32; -// #[inline] -// pub fn dot(a: T, b: T) -> f32 { -// a.dot(b) -// } + fn dot_fast(self, other: Self) -> f32; +} + +#[inline(always)] +pub fn dot(a: T, b: T) -> f32 { + a.dot(b) +} + +#[inline(always)] +pub fn dot_fast(a: T, b: T) -> f32 { + a.dot_fast(b) +} // /// Trait for calculating cross products. // pub trait CrossProduct { @@ -29,3 +38,68 @@ pub use self::{normal::Normal, point::Point, vector::Vector, xform::Xform, xform // pub fn cross(a: T, b: T) -> T { // a.cross(b) // } + +//------------------------------------------------------------- + +/// Trait representing types that can do fused multiply-add. +trait FMulAdd { + /// `(self * b) + c` with only one floating point rounding error. + fn fma(self, b: Self, c: Self) -> Self; +} + +impl FMulAdd for f32 { + fn fma(self, b: Self, c: Self) -> Self { + self.mul_add(b, c) + } +} + +/// `(a * b) - (c * d)` but done with high precision via floating point tricks. +/// +/// See https://pharr.org/matt/blog/2019/11/03/difference-of-floats +#[inline(always)] +fn difference_of_products(a: T, b: T, c: T, d: T) -> T +where + T: Copy + FMulAdd + Add + Mul + Neg, +{ + let cd = c * d; + let dop = a.fma(b, -cd); + let err = (-c).fma(d, cd); + dop + err +} + +/// `(a * b) + (c * d)` but done with high precision via floating point tricks. +#[inline(always)] +fn sum_of_products(a: T, b: T, c: T, d: T) -> T +where + T: Copy + FMulAdd + Add + Mul + Neg, +{ + let cd = c * d; + let sop = a.fma(b, cd); + let err = c.fma(d, -cd); + sop + err +} + +/// `a * b` but also returns a rounding error for precise composition +/// with other operations. +#[inline(always)] +fn two_prod(a: T, b: T) -> (T, T) +// (product, rounding_err) +where + T: Copy + FMulAdd + Mul + Neg, +{ + let ab = a * b; + (ab, a.fma(b, -ab)) +} + +/// `a + b` but also returns a rounding error for precise composition +/// with other operations. +#[inline(always)] +fn two_sum(a: T, b: T) -> (T, T) +// (sum, rounding_err) +where + T: Copy + Add + Sub, +{ + let sum = a + b; + let delta = sum - a; + (sum, (a - (sum - delta)) + (b - delta)) +} diff --git a/sub_crates/rmath/src/normal.rs b/sub_crates/rmath/src/normal.rs index 64c1bcd..8ae05fe 100644 --- a/sub_crates/rmath/src/normal.rs +++ b/sub_crates/rmath/src/normal.rs @@ -3,7 +3,7 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use crate::wide4::Float4; - +use crate::DotProduct; use crate::Vector; /// A surface normal in 3D space. @@ -100,17 +100,6 @@ impl Mul for Normal { } } -// impl Mul for Normal { -// type Output = Normal; - -// #[inline] -// fn mul(self, other: Transform) -> Normal { -// Normal { -// co: other.0.matrix3.inverse().transpose().mul_vec3a(self.co), -// } -// } -// } - impl Div for Normal { type Output = Self; @@ -129,12 +118,17 @@ impl Neg for Normal { } } -// impl DotProduct for Normal { -// #[inline(always)] -// fn dot(self, other: Normal) -> f32 { -// self.co.dot(other.co) -// } -// } +impl DotProduct for Normal { + #[inline(always)] + fn dot(self, other: Self) -> f32 { + Float4::dot_3(self.0, other.0) + } + + #[inline(always)] + fn dot_fast(self, other: Self) -> f32 { + Float4::dot_3_fast(self.0, other.0) + } +} // impl CrossProduct for Normal { // #[inline] diff --git a/sub_crates/rmath/src/vector.rs b/sub_crates/rmath/src/vector.rs index 111cd87..0e37c6e 100644 --- a/sub_crates/rmath/src/vector.rs +++ b/sub_crates/rmath/src/vector.rs @@ -5,6 +5,7 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use crate::normal::Normal; use crate::point::Point; use crate::wide4::Float4; +use crate::DotProduct; /// A direction vector in 3D space. #[derive(Debug, Copy, Clone)] @@ -105,15 +106,6 @@ impl Mul for Vector { } } -// impl Mul for Vector { -// type Output = Self; - -// #[inline] -// fn mul(self, other: Transform) -> Self { -// Self(other.0.transform_vector3a(self.0)) -// } -// } - impl Div for Vector { type Output = Self; @@ -132,12 +124,17 @@ impl Neg for Vector { } } -// impl DotProduct for Vector { -// #[inline(always)] -// fn dot(self, other: Self) -> f32 { -// self.co.dot(other.co) -// } -// } +impl DotProduct for Vector { + #[inline(always)] + fn dot(self, other: Self) -> f32 { + Float4::dot_3(self.0, other.0) + } + + #[inline(always)] + fn dot_fast(self, other: Self) -> f32 { + Float4::dot_3_fast(self.0, other.0) + } +} // impl CrossProduct for Vector { // #[inline] diff --git a/sub_crates/rmath/src/wide4.rs b/sub_crates/rmath/src/wide4.rs index cdbd9ec..0bb3aa0 100644 --- a/sub_crates/rmath/src/wide4.rs +++ b/sub_crates/rmath/src/wide4.rs @@ -1,9 +1,13 @@ use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign}; +use crate::{difference_of_products, two_prod, two_sum}; + pub use fallback::Float4; mod fallback { use std::ops::{Add, Div, Mul, Neg, Sub}; + use crate::FMulAdd; + #[allow(non_camel_case_types)] #[derive(Debug, Copy, Clone)] #[repr(C, align(16))] @@ -81,31 +85,6 @@ mod fallback { // a.max(b) // } - //----------------------------------------------------- - // For matrix stuff. - - #[inline(always)] - pub fn transpose_3x3(m: [Self; 3]) -> [Self; 3] { - [ - // The fourth component in each row below is arbitrary, - // but in this case chosen so that it matches the - // behavior of the SSE version of transpose_3x3. - Self::new(m[0].a(), m[1].a(), m[2].a(), m[2].d()), - Self::new(m[0].b(), m[1].b(), m[2].b(), m[2].d()), - Self::new(m[0].c(), m[1].c(), m[2].c(), m[2].d()), - ] - } - - #[inline] - pub fn invert_3x3(_m: [Self; 3]) -> [Self; 3] { - todo!() - } - - #[inline] - pub fn invert_3x3_precise(_m: [Self; 3]) -> [Self; 3] { - todo!() - } - //----------------------------------------------------- // Individual components. @@ -187,6 +166,24 @@ mod fallback { let d = self.n[3]; Self { n: [d, d, d, d] } } + + #[inline(always)] + pub fn bcad(self) -> Self { + let a = self.n[0]; + let b = self.n[1]; + let c = self.n[2]; + let d = self.n[3]; + Self { n: [b, c, a, d] } + } + + #[inline(always)] + pub fn cabd(self) -> Self { + let a = self.n[0]; + let b = self.n[1]; + let c = self.n[2]; + let d = self.n[3]; + Self { n: [c, a, b, d] } + } } impl Add for Float4 { @@ -295,9 +292,168 @@ mod fallback { } } } + + impl FMulAdd for Float4 { + fn fma(self, b: Self, c: Self) -> Self { + self.mul_add(b, c) + } + } } //------------------------------------------------------------- +// Float4 impls that don't depend on its inner representation. + +impl Float4 { + /// 3D dot product (only uses the first 3 components). + #[inline(always)] + pub fn dot_3(a: Self, b: Self) -> f32 { + let (p, p_err) = two_prod(a, b); + + // Products. + let (x, x_err) = (p.a(), p_err.a()); + let (y, y_err) = (p.b(), p_err.b()); + let (z, z_err) = (p.c(), p_err.c()); + + // Sums. + let (s1, s1_err) = two_sum(x, y); + let err1 = x_err + (y_err + s1_err); + + let (s2, s2_err) = two_sum(s1, z); + let err2 = z_err + (err1 + s2_err); + + // Final result with rounding error compensation. + s2 + err2 + } + + /// 3D dot product (only uses the first 3 components). + /// + /// Faster but less precise version. + #[inline(always)] + pub fn dot_3_fast(a: Self, b: Self) -> f32 { + let c = a * b; + c.a() + c.b() + c.c() + } + + #[inline(always)] + pub fn transpose_3x3(m: [Self; 3]) -> [Self; 3] { + [ + // The fourth component in each row below is arbitrary, + // but in this case chosen so that it matches the + // behavior of the SSE version of transpose_3x3. + Self::new(m[0].a(), m[1].a(), m[2].a(), m[2].d()), + Self::new(m[0].b(), m[1].b(), m[2].b(), m[2].d()), + Self::new(m[0].c(), m[1].c(), m[2].c(), m[2].d()), + ] + } + + /// Invert a 3x3 matrix. + /// + /// Returns `None` if not invertible. + #[inline] + pub fn invert_3x3(m: [Self; 3]) -> Option<[Self; 3]> { + // let a = difference_of_products(m[1].b(), m[2].c(), m[1].c(), m[2].b()); + // let b = difference_of_products(m[1].c(), m[2].a(), m[1].a(), m[2].c()); + // let c = difference_of_products(m[1].a(), m[2].b(), m[1].b(), m[2].a()); + // let d = difference_of_products(m[2].b(), m[0].c(), m[2].c(), m[0].b()); + // let e = difference_of_products(m[2].c(), m[0].a(), m[2].a(), m[0].c()); + // let f = difference_of_products(m[2].a(), m[0].b(), m[2].b(), m[0].a()); + // let g = difference_of_products(m[0].b(), m[1].c(), m[0].c(), m[1].b()); + // let h = difference_of_products(m[0].c(), m[1].a(), m[0].a(), m[1].c()); + // let i = difference_of_products(m[0].a(), m[1].b(), m[0].b(), m[1].a()); + + let m0_bca = m[0].bcad(); + let m1_bca = m[1].bcad(); + let m2_bca = m[2].bcad(); + let m0_cab = m[0].cabd(); + let m1_cab = m[1].cabd(); + let m2_cab = m[2].cabd(); + let abc = difference_of_products(m1_bca, m2_cab, m1_cab, m2_bca); + let def = difference_of_products(m2_bca, m0_cab, m2_cab, m0_bca); + let ghi = difference_of_products(m0_bca, m1_cab, m0_cab, m1_bca); + + // TODO: use precise inner product. + let det = Self::dot_3( + Self::new(abc.a(), def.a(), ghi.a(), 0.0), + Self::new(m[0].a(), m[1].a(), m[2].a(), 0.0), + ); + + if det == 0.0 { + None + } else { + Some(Self::transpose_3x3([abc / det, def / det, ghi / det])) + } + } + + /// Invert a 3x3 matrix. Faster but less precise version. + /// + /// Returns `None` if not invertible. + #[inline] + pub fn invert_3x3_fast(m: [Self; 3]) -> Option<[Self; 3]> { + // let a = (m[1].b() * m[2].c()) - (m[1].c() * m[2].b()); + // let b = (m[1].c() * m[2].a()) - (m[1].a() * m[2].c()); + // let c = (m[1].a() * m[2].b()) - (m[1].b() * m[2].a()); + // let e = (m[2].c() * m[0].a()) - (m[2].a() * m[0].c()); + // let f = (m[2].a() * m[0].b()) - (m[2].b() * m[0].a()); + // let g = (m[0].b() * m[1].c()) - (m[0].c() * m[1].b()); + // let h = (m[0].c() * m[1].a()) - (m[0].a() * m[1].c()); + // let i = (m[0].a() * m[1].b()) - (m[0].b() * m[1].a()); + + let m0_bca = m[0].bcad(); + let m1_bca = m[1].bcad(); + let m2_bca = m[2].bcad(); + let m0_cab = m[0].cabd(); + let m1_cab = m[1].cabd(); + let m2_cab = m[2].cabd(); + let abc = (m1_bca * m2_cab) - (m1_cab * m2_bca); + let def = (m2_bca * m0_cab) - (m2_cab * m0_bca); + let ghi = (m0_bca * m1_cab) - (m0_cab * m1_bca); + + let det = Self::dot_3( + Self::new(abc.a(), def.a(), ghi.a(), 0.0), + Self::new(m[0].a(), m[1].a(), m[2].a(), 0.0), + ); + + if det == 0.0 { + None + } else { + Some(Self::transpose_3x3([abc / det, def / det, ghi / det])) + } + } + + /// Multiplies a 3D vector with a 3x3 matrix. + #[inline] + pub fn vec3_mul_3x3(_v: Self, _m: &[Self; 3]) -> Self { + todo!() + } + + /// Multiplies a 3D vector with a 3x3 matrix. + /// + /// Faster but less precise version. + #[inline] + pub fn vec3_mul_3x3_fast(_v: Self, _m: &[Self; 3]) -> Self { + todo!() + } + + /// Transforms a 3d point by an affine transform. + /// + /// `m` is the 3x3 part of the affine transform, `t` is the translation part. + #[inline] + pub fn pnt3_mul_affine(_p: Self, _m: &[Self; 3], _t: Self) -> Self { + todo!() + } + + /// Transforms a 3d point by an affine transform, except it applies + /// the translation part before the 3x3 part. + /// + /// This is primarily useful for performing efficient inverse transforms by + /// passing an inverted 3x3 part and a negated translation part. + /// + /// `m` is the 3x3 part of the affine transform, `t` is the translation part. + #[inline] + pub fn pnt3_mul_affine_rev(_p: Self, _m: &[Self; 3], _t: Self) -> Self { + todo!() + } +} impl AddAssign for Float4 { #[inline(always)] diff --git a/sub_crates/rmath/src/xform.rs b/sub_crates/rmath/src/xform.rs index 79f7cdb..b0a3593 100644 --- a/sub_crates/rmath/src/xform.rs +++ b/sub_crates/rmath/src/xform.rs @@ -96,22 +96,31 @@ impl Xform { eq } - /// Returns the dual transform, which can do inverse transforms. + /// Computes a "full" version of the transform, which can do both + /// forward and inverse transforms. #[inline] - pub fn compute_dual(self) -> XformFull { + pub fn compute_full(self) -> XformFull { XformFull { m: self.m, - m_inv: Float4::invert_3x3(self.m), + m_inv: Float4::invert_3x3(self.m).unwrap_or([ + Float4::new(1.0, 0.0, 0.0, 0.0), + Float4::new(0.0, 1.0, 0.0, 0.0), + Float4::new(0.0, 0.0, 1.0, 0.0), + ]), t: self.t, } } - /// Slower but precise version of `compute_dual()`. + /// Faster but less precise version of `compute_full()`. #[inline] - pub fn compute_dual_precise(self) -> XformFull { + pub fn compute_full_fast(self) -> XformFull { XformFull { m: self.m, - m_inv: Float4::invert_3x3_precise(self.m), + m_inv: Float4::invert_3x3_fast(self.m).unwrap_or([ + Float4::new(1.0, 0.0, 0.0, 0.0), + Float4::new(0.0, 1.0, 0.0, 0.0), + Float4::new(0.0, 0.0, 1.0, 0.0), + ]), t: self.t, } }