diff --git a/src/accel/bvh4_simd.rs b/src/accel/bvh4_simd.rs index 95b042f..2ad0848 100644 --- a/src/accel/bvh4_simd.rs +++ b/src/accel/bvh4_simd.rs @@ -25,6 +25,7 @@ use super::{ }; use bvh_order::{calc_traversal_code, SplitAxes, TRAVERSAL_TABLE}; +use float4::Bool4; pub fn ray_code(dir: Vector) -> usize { let ray_sign_is_neg = [dir.x() < 0.0, dir.y() < 0.0, dir.z() < 0.0]; @@ -128,27 +129,25 @@ impl<'a> BVH4<'a> { children, traversal_code, } => { - let mut all_hits = 0; + let mut all_hits = Bool4::new(); // Ray testing ray_stack.pop_do_next_task(children.len(), |ray_idx| { if rays.is_done(ray_idx) { ([0; 4], 0) } else { - let hits = lerp_slice(bounds, rays.time(ray_idx)) - .intersect_ray( - rays.orig_local(ray_idx), - rays.dir_inv_local(ray_idx), - rays.max_t(ray_idx), - ) - .to_bitmask(); + let hits = lerp_slice(bounds, rays.time(ray_idx)).intersect_ray( + rays.orig_local(ray_idx), + rays.dir_inv_local(ray_idx), + rays.max_t(ray_idx), + ); - if hits != 0 { - all_hits |= hits; + if !hits.all_false() { + all_hits = all_hits | hits; let mut lanes = [0u8; 4]; let mut lane_count = 0; for i in 0..children.len() { - if (hits >> i) & 1 != 0 { + if hits.get_n(i) { lanes[lane_count] = i as u8; lane_count += 1; } @@ -161,14 +160,14 @@ impl<'a> BVH4<'a> { }); // If there were any intersections, create tasks. 
- if all_hits > 0 { + if !all_hits.all_false() { let order_code = traversal_table[traversal_code as usize]; let mut lanes = [0usize; 4]; let mut lane_count = 0; for i in 0..children.len() { let inv_i = (children.len() - 1) - i; let child_i = ((order_code >> (inv_i * 2)) & 3) as usize; - if ((all_hits >> child_i) & 1) != 0 { + if all_hits.get_n(child_i) { node_stack[stack_ptr + lane_count] = &children[child_i]; lanes[lane_count] = child_i; lane_count += 1; diff --git a/sub_crates/float4/src/lib.rs b/sub_crates/float4/src/lib.rs index 4006301..99c0417 100644 --- a/sub_crates/float4/src/lib.rs +++ b/sub_crates/float4/src/lib.rs @@ -620,6 +620,14 @@ mod x86_64_sse { } impl Bool4 { + #[inline(always)] + pub fn new() -> Bool4 { + use std::arch::x86_64::_mm_set1_ps; + Bool4 { + data: unsafe { _mm_set1_ps(0.0) }, + } + } + /// Returns the value of the nth element. #[inline(always)] pub fn get_n(&self, n: usize) -> bool { @@ -637,24 +645,33 @@ mod x86_64_sse { self.get_n(0) } - /// Returns the value of the 1th element. + /// Returns the value of the 1st element. #[inline(always)] pub fn get_1(&self) -> bool { self.get_n(1) } - /// Returns the value of the 2th element. + /// Returns the value of the 2nd element. #[inline(always)] pub fn get_2(&self) -> bool { self.get_n(2) } - /// Returns the value of the 3th element. + /// Returns the value of the 3rd element. #[inline(always)] pub fn get_3(&self) -> bool { self.get_n(3) } + /// Returns whether all four bools are false. + /// + /// This is the negation of the `OR` of all the contained bools: if even + /// one bool is true, this returns false. + pub fn all_false(&self) -> bool { + let a = unsafe { *(&self.data as *const __m128 as *const u128) }; + a == 0 + } + #[inline] pub fn to_bitmask(&self) -> u8 { let a = unsafe { *(&self.data as *const __m128 as *const u8).offset(0) };