diff --git a/sub_crates/trifloat/src/unsigned32.rs b/sub_crates/trifloat/src/unsigned32.rs index 56baf1f..ff462fc 100644 --- a/sub_crates/trifloat/src/unsigned32.rs +++ b/sub_crates/trifloat/src/unsigned32.rs @@ -2,7 +2,7 @@ //! //! The encoding uses 9 bits of mantissa per number, and 5 bits for the shared //! exponent. The bit layout is [mantissa 1, mantissa 2, mantissa 3, exponent]. -//! The exponent is stored as an unsigned integer with a bias of 10. +//! The exponent is stored as an unsigned integer with a bias of 11. //! //! The largest representable number is `2^21 - 4096`, and the smallest //! representable non-zero number is `2^-19`. @@ -14,22 +14,20 @@ use crate::{fiddle_exp2, fiddle_log2}; /// Largest representable number. -pub const MAX: f32 = 2_093_056.0; +pub const MAX: f32 = ((1u64 << (32 - EXP_BIAS)) - (1 << (32 - EXP_BIAS - 9))) as f32; /// Smallest representable non-zero number. -pub const MIN: f32 = 0.000_001_907_348_6; +pub const MIN: f32 = 1.0 / (1 << (EXP_BIAS + 8)) as f32; /// Difference between 1.0 and the next largest representable number. pub const EPSILON: f32 = 1.0 / 256.0; -const EXP_BIAS: i32 = 10; -const MIN_EXP: i32 = 0 - EXP_BIAS; -const MAX_EXP: i32 = 31 - EXP_BIAS; +const EXP_BIAS: i32 = 11; /// Encodes three floating point values into a signed 32-bit trifloat. /// /// Input floats larger than `MAX` will saturate to `MAX`, including infinity. -/// Values are converted to trifloat precision by rounding. +/// Values are converted to trifloat precision by rounding down. /// /// Warning: negative values and NaN's are _not_ supported by the trifloat /// format. There are debug-only assertions in place to catch such @@ -51,31 +49,19 @@ pub fn encode(floats: (f32, f32, f32)) -> u32 { floats.2 ); - // Find the largest of the three values. - let largest_value = floats.0.max(floats.1.max(floats.2)); - if largest_value <= 0.0 { + let largest = floats.0.max(floats.1.max(floats.2)); + + if largest < MIN { return 0; + } else { + let e = fiddle_log2(largest).max(-EXP_BIAS).min(31 - EXP_BIAS); + let inv_multiplier = fiddle_exp2(-e + 8); + let x = (floats.0 * inv_multiplier).min(511.0) as u32; + let y = (floats.1 * inv_multiplier).min(511.0) as u32; + let z = (floats.2 * inv_multiplier).min(511.0) as u32; + + (x << (9 + 9 + 5)) | (y << (9 + 5)) | (z << 5) | (e + EXP_BIAS) as u32 } - - // Calculate the exponent and 1.0/multiplier for encoding the values. - let mut exponent = (fiddle_log2(largest_value) + 1).max(MIN_EXP).min(MAX_EXP); - let mut inv_multiplier = fiddle_exp2(-exponent + 9); - - // Edge-case: make sure rounding pushes the largest value up - // appropriately if needed. - if (largest_value * inv_multiplier) + 0.5 >= 512.0 { - exponent = (exponent + 1).min(MAX_EXP); - inv_multiplier = fiddle_exp2(-exponent + 9); - } - - // Quantize and encode values. - let x = (floats.0 * inv_multiplier + 0.5).min(511.0) as u32 & 0b1_1111_1111; - let y = (floats.1 * inv_multiplier + 0.5).min(511.0) as u32 & 0b1_1111_1111; - let z = (floats.2 * inv_multiplier + 0.5).min(511.0) as u32 & 0b1_1111_1111; - let e = (exponent + EXP_BIAS) as u32 & 0b1_1111; - - // Pack values into a u32. - (x << (5 + 9 + 9)) | (y << (5 + 9)) | (z << 5) | e } /// Decodes an unsigned 32-bit trifloat into three full floating point numbers. @@ -84,12 +70,12 @@ pub fn encode(floats: (f32, f32, f32)) -> u32 { #[inline] pub fn decode(trifloat: u32) -> (f32, f32, f32) { // Unpack values. - let x = trifloat >> (5 + 9 + 9); - let y = (trifloat >> (5 + 9)) & 0b1_1111_1111; + let x = trifloat >> (9 + 9 + 5); + let y = (trifloat >> (9 + 5)) & 0b1_1111_1111; let z = (trifloat >> 5) & 0b1_1111_1111; let e = trifloat & 0b1_1111; - let multiplier = fiddle_exp2(e as i32 - EXP_BIAS - 9); + let multiplier = fiddle_exp2(e as i32 - EXP_BIAS - 8); ( x as f32 * multiplier, @@ -120,11 +106,11 @@ mod tests { #[test] fn powers_of_two() { let fs = (8.0f32, 128.0f32, 0.5f32); - assert_eq!(round_trip(fs), fs); + assert_eq!(fs, round_trip(fs)); } #[test] - fn accuracy() { + fn accuracy_01() { let mut n = 1.0; for _ in 0..256 { let (x, _, _) = round_trip((n, 0.0, 0.0)); @@ -133,6 +119,17 @@ mod tests { } } + #[test] + #[should_panic] + fn accuracy_02() { + let mut n = 1.0; + for _ in 0..512 { + let (x, _, _) = round_trip((n, 0.0, 0.0)); + assert_eq!(n, x); + n += 1.0 / 512.0; + } + } + #[test] fn integers() { for n in 0..=512 { @@ -142,24 +139,17 @@ mod tests { } #[test] - fn rounding() { + fn precision_floor() { let fs = (7.0f32, 513.0f32, 1.0f32); - assert_eq!(round_trip(fs), (8.0, 514.0, 2.0)); - } - - #[test] - fn rounding_edge_case() { - let fs = (1023.0f32, 0.0f32, 0.0f32); - - assert_eq!(round_trip(fs), (1024.0, 0.0, 0.0),); + assert_eq!((6.0, 512.0, 0.0), round_trip(fs)); } #[test] fn saturate() { let fs = (9999999999.0, 9999999999.0, 9999999999.0); - assert_eq!(round_trip(fs), (MAX, MAX, MAX)); - assert_eq!(decode(0xFFFFFFFF), (MAX, MAX, MAX),); + assert_eq!((MAX, MAX, MAX), round_trip(fs)); + assert_eq!((MAX, MAX, MAX), decode(0xFFFFFFFF)); } #[test] @@ -167,29 +157,29 @@ mod tests { use std::f32::INFINITY; let fs = (INFINITY, 0.0, 0.0); - assert_eq!(round_trip(fs), (MAX, 0.0, 0.0)); - assert_eq!(encode(fs), 0xFF80001F,); + assert_eq!((MAX, 0.0, 0.0), round_trip(fs)); + assert_eq!(0xFF80001F, encode(fs)); } #[test] fn partial_saturate() { let fs = (9999999999.0, 4096.0, 262144.0); - assert_eq!(round_trip(fs), (MAX, 4096.0, 262144.0)); + assert_eq!((MAX, 4096.0, 262144.0), round_trip(fs)); } #[test] fn smallest_value() { - let fs = (MIN, MIN * 0.5, MIN * 0.49); - assert_eq!(round_trip(fs), (MIN, MIN, 0.0)); - assert_eq!(decode(0x00_80_40_00), (MIN, MIN, 0.0)); + let fs = (MIN * 1.5, MIN, MIN * 0.5); + assert_eq!((MIN, MIN, 0.0), round_trip(fs)); + assert_eq!((MIN, MIN, 0.0), decode(0x00_80_40_00)); } #[test] fn underflow() { - let fs = (MIN * 0.49, 0.0, 0.0); - assert_eq!(encode(fs), 0); - assert_eq!(round_trip(fs), (0.0, 0.0, 0.0)); + let fs = (MIN * 0.99, 0.0, 0.0); + assert_eq!(0, encode(fs)); + assert_eq!((0.0, 0.0, 0.0), round_trip(fs)); } #[test]