diff --git a/sub_crates/rmath/src/utils.rs b/sub_crates/rmath/src/utils.rs index bb5bd23..da746f8 100644 --- a/sub_crates/rmath/src/utils.rs +++ b/sub_crates/rmath/src/utils.rs @@ -1,5 +1,3 @@ -const TOP_BIT: u32 = 1 << 31; - /// Compute how different two floats are in ulps. /// /// Notes: @@ -7,9 +5,12 @@ const TOP_BIT: u32 = 1 << 31; /// implications of that to the rest of the numbers. E.g. the numbers /// just before and after 0.0/-0.0 are only two ulps apart, not three. /// - Infinity is one ulp past float max, and converse for -infinity. -/// - This doesn't handle NaNs in any really useful way. +/// - If either number is NaN, returns `u32::MAX`. #[inline(always)] pub fn ulp_diff(a: f32, b: f32) -> u32 { + const TOP_BIT: u32 = 1 << 31; + const NAN_THRESHOLD: u32 = 0x7f800000; + let a = a.to_bits(); let b = b.to_bits(); @@ -18,7 +19,10 @@ pub fn ulp_diff(a: f32, b: f32) -> u32 { let a_abs = a & !TOP_BIT; let b_abs = b & !TOP_BIT; - if a_sign == b_sign { + if a_abs > NAN_THRESHOLD || b_abs > NAN_THRESHOLD { + // NaNs always return maximum ulps apart. + u32::MAX + } else if a_sign == b_sign { a_abs.max(b_abs) - a_abs.min(b_abs) } else { a_abs + b_abs @@ -28,7 +32,8 @@ pub fn ulp_diff(a: f32, b: f32) -> u32 { /// Checks if two floats are approximately equal, within `max_ulps`. #[inline(always)] pub fn ulps_eq(a: f32, b: f32, max_ulps: u32) -> bool { - !a.is_nan() && !b.is_nan() && (ulp_diff(a, b) <= max_ulps) + // The minimum ensures that NaNs never return true. + ulp_diff(a, b) <= max_ulps.min(u32::MAX - 1) } #[cfg(test)] @@ -49,10 +54,12 @@ mod tests { assert_eq!(ulp_diff(-0.0, 1.0), 0x3f800000); assert_eq!(ulp_diff(0.0, -1.0), 0x3f800000); assert_eq!(ulp_diff(-0.0, -1.0), 0x3f800000); - assert_eq!( - ulp_diff(std::f32::INFINITY, -std::f32::INFINITY), - 0xff000000 - ); + assert_eq!(ulp_diff(f32::INFINITY, -f32::INFINITY), 0xff000000); + assert_eq!(ulp_diff(f32::NAN, f32::NAN), 0xffffffff); + assert_eq!(ulp_diff(f32::NAN, 1.0), 0xffffffff); + assert_eq!(ulp_diff(1.0, f32::NAN), 0xffffffff); + assert_eq!(ulp_diff(-f32::NAN, 1.0), 0xffffffff); + assert_eq!(ulp_diff(1.0, -f32::NAN), 0xffffffff); assert_eq!(ulp_diff(0.0, f32::from_bits(0.0f32.to_bits() + 1)), 1); assert_eq!(ulp_diff(-0.0, f32::from_bits(0.0f32.to_bits() + 1)), 1); } @@ -82,6 +89,7 @@ mod tests { assert!(!ulps_eq(std::f32::NAN, std::f32::NAN, 0)); assert!(!ulps_eq(-std::f32::NAN, -std::f32::NAN, 0)); + assert!(!ulps_eq(std::f32::NAN, std::f32::NAN, u32::MAX)); assert!(!ulps_eq(std::f32::NAN, std::f32::INFINITY, 1 << 31)); assert!(!ulps_eq(std::f32::INFINITY, std::f32::NAN, 1 << 31)); }