Moved matrix transpose and inverse code into Float4 crate.

This allows for more optimized implementations that take advantage of
SIMD intrinsics.
Nathan Vegdahl 2018-06-24 21:06:32 -07:00
parent 8e791259b3
commit df27f7b829
2 changed files with 311 additions and 98 deletions


@@ -1,6 +1,6 @@
#![allow(dead_code)]
/// Implementation of Float4 for x86_64 platforms with sse support
/// Implementation of Float4 for x86_64 platforms with SSE support.
#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
mod x86_64_sse {
use std::arch::x86_64::__m128;
@@ -343,6 +343,8 @@ mod x86_64_sse {
}
}
// Free functions for Float4
#[inline(always)]
pub fn v_min(a: Float4, b: Float4) -> Float4 {
a.v_min(b)
@@ -353,6 +355,196 @@ mod x86_64_sse {
a.v_max(b)
}
/// Transposes a 4x4 matrix in-place.
#[inline(always)]
pub fn transpose(matrix: &mut [Float4; 4]) {
use std::arch::x86_64::_MM_TRANSPOSE4_PS;
// The weird &mut/*mut gymnastics below are to get around
// the borrow-checker. We know statically that these references
// are non-overlapping, so it's safe.
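// _MM_TRANSPOSE4_PS is the classic 4x4 transpose helper from the SSE
// headers; it performs the whole transpose in registers with unpack and
// shuffle operations.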
unsafe {
_MM_TRANSPOSE4_PS(
&mut *(&mut matrix[0].data as *mut __m128),
&mut *(&mut matrix[1].data as *mut __m128),
&mut *(&mut matrix[2].data as *mut __m128),
&mut *(&mut matrix[3].data as *mut __m128),
)
};
}
/// Inverts a 4x4 matrix and returns the determinant.
#[inline(always)]
pub fn invert(matrix: &mut [Float4; 4]) -> f32 {
// Code pulled from "Streaming SIMD Extensions - Inverse of 4x4 Matrix"
// by Intel.
// ftp://download.intel.com/design/PentiumIII/sml/24504301.pdf
// Ported to Rust.
// TODO: once __m64 and its accompanying intrinsics are stabilized, switch
// to using those (see the commented-out lines below).
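// Overview: the straight-line code below first transposes the input (the
// paired loads and shuffles), then accumulates the four cofactor rows
// (minor0..minor3), reduces row0 * minor0 to the determinant, and finally
// scales the cofactors by an approximate reciprocal of the determinant
// (Cramer's rule).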
use std::arch::x86_64::{
_mm_add_ps,
_mm_add_ss,
_mm_cvtss_f32,
_mm_mul_ps,
_mm_mul_ss,
_mm_rcp_ss,
// _mm_loadh_pi,
// _mm_loadl_pi,
// _mm_storeh_pi,
// _mm_storel_pi,
_mm_set_ps,
_mm_shuffle_ps,
_mm_sub_ps,
_mm_sub_ss,
};
use std::mem::transmute;
let mut minor0: __m128;
let mut minor1: __m128;
let mut minor2: __m128;
let mut minor3: __m128;
let row0: __m128;
let mut row1: __m128;
let mut row2: __m128;
let mut row3: __m128;
let mut det: __m128;
let mut tmp1: __m128;
unsafe {
// tmp1 = _mm_loadh_pi(_mm_loadl_pi(tmp1, (__m64*)(src)), (__m64*)(src+ 4));
tmp1 = _mm_set_ps(
matrix[1].get_1(),
matrix[1].get_0(),
matrix[0].get_1(),
matrix[0].get_0(),
);
// row1 = _mm_loadh_pi(_mm_loadl_pi(row1, (__m64*)(src+8)), (__m64*)(src+12));
row1 = _mm_set_ps(
matrix[3].get_1(),
matrix[3].get_0(),
matrix[2].get_1(),
matrix[2].get_0(),
);
row0 = _mm_shuffle_ps(tmp1, row1, 0x88);
row1 = _mm_shuffle_ps(row1, tmp1, 0xDD);
// tmp1 = _mm_loadh_pi(_mm_loadl_pi(tmp1, (__m64*)(src+ 2)), (__m64*)(src+ 6));
tmp1 = _mm_set_ps(
matrix[1].get_3(),
matrix[1].get_2(),
matrix[0].get_3(),
matrix[0].get_2(),
);
// row3 = _mm_loadh_pi(_mm_loadl_pi(row3, (__m64*)(src+10)), (__m64*)(src+14));
row3 = _mm_set_ps(
matrix[3].get_3(),
matrix[3].get_2(),
matrix[2].get_3(),
matrix[2].get_2(),
);
row2 = _mm_shuffle_ps(tmp1, row3, 0x88);
row3 = _mm_shuffle_ps(row3, tmp1, 0xDD);
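// The loads and 0x88/0xDD shuffles above transpose the source matrix into
// row0..row3 (0x88 gathers the even lanes of its two operands, 0xDD the
// odd lanes; rows 1 and 3 come out with their halves swapped, which the
// cofactor code below expects). From here on, mask 0xB1 swaps lanes within
// each 64-bit pair ([1,0,3,2]) and 0x4E swaps the two halves ([2,3,0,1]).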
// -----------------------------------------------
tmp1 = _mm_mul_ps(row2, row3);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0xB1);
minor0 = _mm_mul_ps(row1, tmp1);
minor1 = _mm_mul_ps(row0, tmp1);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0x4E);
minor0 = _mm_sub_ps(_mm_mul_ps(row1, tmp1), minor0);
minor1 = _mm_sub_ps(_mm_mul_ps(row0, tmp1), minor1);
minor1 = _mm_shuffle_ps(minor1, minor1, 0x4E);
// -----------------------------------------------
tmp1 = _mm_mul_ps(row1, row2);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0xB1);
minor0 = _mm_add_ps(_mm_mul_ps(row3, tmp1), minor0);
minor3 = _mm_mul_ps(row0, tmp1);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0x4E);
minor0 = _mm_sub_ps(minor0, _mm_mul_ps(row3, tmp1));
minor3 = _mm_sub_ps(_mm_mul_ps(row0, tmp1), minor3);
minor3 = _mm_shuffle_ps(minor3, minor3, 0x4E);
// -----------------------------------------------
tmp1 = _mm_mul_ps(_mm_shuffle_ps(row1, row1, 0x4E), row3);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0xB1);
row2 = _mm_shuffle_ps(row2, row2, 0x4E);
minor0 = _mm_add_ps(_mm_mul_ps(row2, tmp1), minor0);
minor2 = _mm_mul_ps(row0, tmp1);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0x4E);
minor0 = _mm_sub_ps(minor0, _mm_mul_ps(row2, tmp1));
minor2 = _mm_sub_ps(_mm_mul_ps(row0, tmp1), minor2);
minor2 = _mm_shuffle_ps(minor2, minor2, 0x4E);
// -----------------------------------------------
tmp1 = _mm_mul_ps(row0, row1);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0xB1);
minor2 = _mm_add_ps(_mm_mul_ps(row3, tmp1), minor2);
minor3 = _mm_sub_ps(_mm_mul_ps(row2, tmp1), minor3);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0x4E);
minor2 = _mm_sub_ps(_mm_mul_ps(row3, tmp1), minor2);
minor3 = _mm_sub_ps(minor3, _mm_mul_ps(row2, tmp1));
// -----------------------------------------------
tmp1 = _mm_mul_ps(row0, row3);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0xB1);
minor1 = _mm_sub_ps(minor1, _mm_mul_ps(row2, tmp1));
minor2 = _mm_add_ps(_mm_mul_ps(row1, tmp1), minor2);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0x4E);
minor1 = _mm_add_ps(_mm_mul_ps(row2, tmp1), minor1);
minor2 = _mm_sub_ps(minor2, _mm_mul_ps(row1, tmp1));
// -----------------------------------------------
tmp1 = _mm_mul_ps(row0, row2);
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0xB1);
minor1 = _mm_add_ps(_mm_mul_ps(row3, tmp1), minor1);
minor3 = _mm_sub_ps(minor3, _mm_mul_ps(row1, tmp1));
tmp1 = _mm_shuffle_ps(tmp1, tmp1, 0x4E);
minor1 = _mm_sub_ps(minor1, _mm_mul_ps(row3, tmp1));
minor3 = _mm_add_ps(_mm_mul_ps(row1, tmp1), minor3);
// -----------------------------------------------
det = _mm_mul_ps(row0, minor0);
det = _mm_add_ps(_mm_shuffle_ps(det, det, 0x4E), det);
det = _mm_add_ss(_mm_shuffle_ps(det, det, 0xB1), det);
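// Lane 0 of `det` now holds the determinant (the dot product of row0 and
// minor0). _mm_rcp_ss only gives a ~12-bit approximation of 1/det, so one
// Newton-Raphson step, x' = x * (2 - det * x), refines it below.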
// Save the scalar determinant before `det` is overwritten with its
// (refined) reciprocal, so the function returns the actual determinant
// as documented (the fallback implementation does the same).
let det_scalar = _mm_cvtss_f32(det);
tmp1 = _mm_rcp_ss(det);
det = _mm_sub_ss(
_mm_add_ss(tmp1, tmp1),
_mm_mul_ss(det, _mm_mul_ss(tmp1, tmp1)),
);
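// Broadcast the refined 1/det to all four lanes, then scale each cofactor
// row and write it back. The transmute + _mm_set_ps round-trip stands in
// for the unstable __m64 stores from the original C code.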
det = _mm_shuffle_ps(det, det, 0x00);
minor0 = _mm_mul_ps(det, minor0);
// _mm_storel_pi((__m64*)(src), minor0);
// _mm_storeh_pi((__m64*)(src+2), minor0);
let minor0 = transmute::<__m128, [f32; 4]>(minor0);
matrix[0].data = _mm_set_ps(minor0[3], minor0[2], minor0[1], minor0[0]);
minor1 = _mm_mul_ps(det, minor1);
// _mm_storel_pi((__m64*)(src+4), minor1);
// _mm_storeh_pi((__m64*)(src+6), minor1);
let minor1 = transmute::<__m128, [f32; 4]>(minor1);
matrix[1].data = _mm_set_ps(minor1[3], minor1[2], minor1[1], minor1[0]);
minor2 = _mm_mul_ps(det, minor2);
// _mm_storel_pi((__m64*)(src+ 8), minor2);
// _mm_storeh_pi((__m64*)(src+10), minor2);
let minor2 = transmute::<__m128, [f32; 4]>(minor2);
matrix[2].data = _mm_set_ps(minor2[3], minor2[2], minor2[1], minor2[0]);
minor3 = _mm_mul_ps(det, minor3);
// _mm_storel_pi((__m64*)(src+12), minor3);
// _mm_storeh_pi((__m64*)(src+14), minor3);
let minor3 = transmute::<__m128, [f32; 4]>(minor3);
matrix[3].data = _mm_set_ps(minor3[3], minor3[2], minor3[1], minor3[0]);
det_scalar
}
}
/// Essentially a tuple of four bools, which will use SIMD operations
/// where the platform supports them.
#[derive(Debug, Copy, Clone)]
@@ -824,6 +1016,7 @@ mod fallback {
}
}
// Free functions for Float4
#[inline(always)]
pub fn v_min(a: Float4, b: Float4) -> Float4 {
a.v_min(b)
@@ -834,6 +1027,93 @@ mod fallback {
a.v_max(b)
}
/// Transposes a 4x4 matrix in-place.
#[inline(always)]
pub fn transpose(matrix: &mut [Float4; 4]) {
let m = [
Float4::new(
matrix[0].get_0(),
matrix[1].get_0(),
matrix[2].get_0(),
matrix[3].get_0(),
),
Float4::new(
matrix[0].get_1(),
matrix[1].get_1(),
matrix[2].get_1(),
matrix[3].get_1(),
),
Float4::new(
matrix[0].get_2(),
matrix[1].get_2(),
matrix[2].get_2(),
matrix[3].get_2(),
),
Float4::new(
matrix[0].get_3(),
matrix[1].get_3(),
matrix[2].get_3(),
matrix[3].get_3(),
),
];
*matrix = m;
}
/// Inverts a 4x4 matrix and returns the determinant.
#[inline(always)]
pub fn invert(matrix: &mut [Float4; 4]) -> f32 {
let m = *matrix;
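// s0..s5 are the six 2x2 sub-determinants of the top two rows, and c0..c5
// the six 2x2 sub-determinants of the bottom two rows; the determinant and
// the cofactor entries below are assembled from them.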
let s0 = (m[0].get_0() * m[1].get_1()) - (m[1].get_0() * m[0].get_1());
let s1 = (m[0].get_0() * m[1].get_2()) - (m[1].get_0() * m[0].get_2());
let s2 = (m[0].get_0() * m[1].get_3()) - (m[1].get_0() * m[0].get_3());
let s3 = (m[0].get_1() * m[1].get_2()) - (m[1].get_1() * m[0].get_2());
let s4 = (m[0].get_1() * m[1].get_3()) - (m[1].get_1() * m[0].get_3());
let s5 = (m[0].get_2() * m[1].get_3()) - (m[1].get_2() * m[0].get_3());
let c5 = (m[2].get_2() * m[3].get_3()) - (m[3].get_2() * m[2].get_3());
let c4 = (m[2].get_1() * m[3].get_3()) - (m[3].get_1() * m[2].get_3());
let c3 = (m[2].get_1() * m[3].get_2()) - (m[3].get_1() * m[2].get_2());
let c2 = (m[2].get_0() * m[3].get_3()) - (m[3].get_0() * m[2].get_3());
let c1 = (m[2].get_0() * m[3].get_2()) - (m[3].get_0() * m[2].get_2());
let c0 = (m[2].get_0() * m[3].get_1()) - (m[3].get_0() * m[2].get_1());
// We don't check for 0.0 determinant, as that is expected to be handled
// by the calling code.
let det = (s0 * c5) - (s1 * c4) + (s2 * c3) + (s3 * c2) - (s4 * c1) + (s5 * c0);
let invdet = 1.0 / det;
*matrix = [
Float4::new(
((m[1].get_1() * c5) - (m[1].get_2() * c4) + (m[1].get_3() * c3)) * invdet,
((-m[0].get_1() * c5) + (m[0].get_2() * c4) - (m[0].get_3() * c3)) * invdet,
((m[3].get_1() * s5) - (m[3].get_2() * s4) + (m[3].get_3() * s3)) * invdet,
((-m[2].get_1() * s5) + (m[2].get_2() * s4) - (m[2].get_3() * s3)) * invdet,
),
Float4::new(
((-m[1].get_0() * c5) + (m[1].get_2() * c2) - (m[1].get_3() * c1)) * invdet,
((m[0].get_0() * c5) - (m[0].get_2() * c2) + (m[0].get_3() * c1)) * invdet,
((-m[3].get_0() * s5) + (m[3].get_2() * s2) - (m[3].get_3() * s1)) * invdet,
((m[2].get_0() * s5) - (m[2].get_2() * s2) + (m[2].get_3() * s1)) * invdet,
),
Float4::new(
((m[1].get_0() * c4) - (m[1].get_1() * c2) + (m[1].get_3() * c0)) * invdet,
((-m[0].get_0() * c4) + (m[0].get_1() * c2) - (m[0].get_3() * c0)) * invdet,
((m[3].get_0() * s4) - (m[3].get_1() * s2) + (m[3].get_3() * s0)) * invdet,
((-m[2].get_0() * s4) + (m[2].get_1() * s2) - (m[2].get_3() * s0)) * invdet,
),
Float4::new(
((-m[1].get_0() * c3) + (m[1].get_1() * c1) - (m[1].get_2() * c0)) * invdet,
((m[0].get_0() * c3) - (m[0].get_1() * c1) + (m[0].get_2() * c0)) * invdet,
((-m[3].get_0() * s3) + (m[3].get_1() * s1) - (m[3].get_2() * s0)) * invdet,
((m[2].get_0() * s3) - (m[2].get_1() * s1) + (m[2].get_2() * s0)) * invdet,
),
];
det
}
/// Essentially a tuple of four bools, which will use SIMD operations
/// where the platform supports them.
#[cfg(feature = "simd_perf")]
@@ -928,10 +1208,10 @@ mod fallback {
//===========================================================================
#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
pub use x86_64_sse::{v_max, v_min, Bool4, Float4};
pub use x86_64_sse::{invert, transpose, v_max, v_min, Bool4, Float4};
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse")))]
pub use fallback::{v_max, v_min, Bool4, Float4};
pub use fallback::{invert, transpose, v_max, v_min, Bool4, Float4};
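// Both backends export the same free-function surface, so downstream code
// can `use float4::{invert, transpose}` without caring which implementation
// was compiled in.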
//===========================================================================
@@ -1107,6 +1387,26 @@ mod tests {
assert_eq!(r.get_3(), true);
}
#[test]
fn matrix_transpose() {
let mut m1 = [
Float4::new(1.0, 2.0, 3.0, 4.0),
Float4::new(5.0, 6.0, 7.0, 8.0),
Float4::new(9.0, 10.0, 11.0, 12.0),
Float4::new(13.0, 14.0, 15.0, 16.0),
];
let m2 = [
Float4::new(1.0, 5.0, 9.0, 13.0),
Float4::new(2.0, 6.0, 10.0, 14.0),
Float4::new(3.0, 7.0, 11.0, 15.0),
Float4::new(4.0, 8.0, 12.0, 16.0),
];
transpose(&mut m1);
assert_eq!(m1, m2);
}
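#[test]
fn matrix_invert() {
// Companion sketch, not part of the original commit: inverting the
// identity matrix should leave it unchanged and return a determinant of
// 1.0. The SSE path refines _mm_rcp_ss with a single Newton-Raphson
// step, so compare with a small tolerance rather than exactly.
let mut m = [
Float4::new(1.0, 0.0, 0.0, 0.0),
Float4::new(0.0, 1.0, 0.0, 0.0),
Float4::new(0.0, 0.0, 1.0, 0.0),
Float4::new(0.0, 0.0, 0.0, 1.0),
];
let det = invert(&mut m);
assert!((det - 1.0).abs() < 1e-4);
assert!((m[0].get_0() - 1.0).abs() < 1e-4);
assert!(m[0].get_1().abs() < 1e-4);
}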
#[test]
fn bool4_bitmask_01() {
let f1 = Float4::new(0.0, 0.0, 0.0, 0.0);


@@ -3,7 +3,7 @@
use std;
use std::ops::{Index, IndexMut, Mul};
use float4::Float4;
use float4::{Float4, transpose, invert};
use super::Point;
@@ -109,105 +109,18 @@ impl Matrix4x4 {
/// Returns the transpose of the matrix.
#[inline]
pub fn transposed(&self) -> Matrix4x4 {
Matrix4x4 {
values: {
[
Float4::new(
self[0].get_0(),
self[1].get_0(),
self[2].get_0(),
self[3].get_0(),
),
Float4::new(
self[0].get_1(),
self[1].get_1(),
self[2].get_1(),
self[3].get_1(),
),
Float4::new(
self[0].get_2(),
self[1].get_2(),
self[2].get_2(),
self[3].get_2(),
),
Float4::new(
self[0].get_3(),
self[1].get_3(),
self[2].get_3(),
self[3].get_3(),
),
]
},
}
let mut m = *self;
transpose(&mut m.values);
m
}
/// Returns the inverse of the matrix.
#[inline]
pub fn inverse(&self) -> Matrix4x4 {
let s0 = (self[0].get_0() * self[1].get_1()) - (self[1].get_0() * self[0].get_1());
let s1 = (self[0].get_0() * self[1].get_2()) - (self[1].get_0() * self[0].get_2());
let s2 = (self[0].get_0() * self[1].get_3()) - (self[1].get_0() * self[0].get_3());
let s3 = (self[0].get_1() * self[1].get_2()) - (self[1].get_1() * self[0].get_2());
let s4 = (self[0].get_1() * self[1].get_3()) - (self[1].get_1() * self[0].get_3());
let s5 = (self[0].get_2() * self[1].get_3()) - (self[1].get_2() * self[0].get_3());
let c5 = (self[2].get_2() * self[3].get_3()) - (self[3].get_2() * self[2].get_3());
let c4 = (self[2].get_1() * self[3].get_3()) - (self[3].get_1() * self[2].get_3());
let c3 = (self[2].get_1() * self[3].get_2()) - (self[3].get_1() * self[2].get_2());
let c2 = (self[2].get_0() * self[3].get_3()) - (self[3].get_0() * self[2].get_3());
let c1 = (self[2].get_0() * self[3].get_2()) - (self[3].get_0() * self[2].get_2());
let c0 = (self[2].get_0() * self[3].get_1()) - (self[3].get_0() * self[2].get_1());
// TODO: handle 0.0 determinant
let det = (s0 * c5) - (s1 * c4) + (s2 * c3) + (s3 * c2) - (s4 * c1) + (s5 * c0);
let invdet = 1.0 / det;
Matrix4x4 {
values: {
[
Float4::new(
((self[1].get_1() * c5) - (self[1].get_2() * c4) + (self[1].get_3() * c3))
* invdet,
((-self[0].get_1() * c5) + (self[0].get_2() * c4) - (self[0].get_3() * c3))
* invdet,
((self[3].get_1() * s5) - (self[3].get_2() * s4) + (self[3].get_3() * s3))
* invdet,
((-self[2].get_1() * s5) + (self[2].get_2() * s4) - (self[2].get_3() * s3))
* invdet,
),
Float4::new(
((-self[1].get_0() * c5) + (self[1].get_2() * c2) - (self[1].get_3() * c1))
* invdet,
((self[0].get_0() * c5) - (self[0].get_2() * c2) + (self[0].get_3() * c1))
* invdet,
((-self[3].get_0() * s5) + (self[3].get_2() * s2) - (self[3].get_3() * s1))
* invdet,
((self[2].get_0() * s5) - (self[2].get_2() * s2) + (self[2].get_3() * s1))
* invdet,
),
Float4::new(
((self[1].get_0() * c4) - (self[1].get_1() * c2) + (self[1].get_3() * c0))
* invdet,
((-self[0].get_0() * c4) + (self[0].get_1() * c2) - (self[0].get_3() * c0))
* invdet,
((self[3].get_0() * s4) - (self[3].get_1() * s2) + (self[3].get_3() * s0))
* invdet,
((-self[2].get_0() * s4) + (self[2].get_1() * s2) - (self[2].get_3() * s0))
* invdet,
),
Float4::new(
((-self[1].get_0() * c3) + (self[1].get_1() * c1) - (self[1].get_2() * c0))
* invdet,
((self[0].get_0() * c3) - (self[0].get_1() * c1) + (self[0].get_2() * c0))
* invdet,
((-self[3].get_0() * s3) + (self[3].get_1() * s1) - (self[3].get_2() * s0))
* invdet,
((self[2].get_0() * s3) - (self[2].get_1() * s1) + (self[2].get_2() * s0))
* invdet,
),
]
},
}
let mut m = *self;
let det = invert(&mut m.values);
debug_assert_ne!(det, 0.0);
m
}
}