Cleaned up the u32 trifloat implementation.

This also makes encoding faster.  However, encoding no longer rounds
to the nearest representable value, and instead rounds down
(floors).  This seems like a reasonable tradeoff: if you want more
precision... you should use a format with more precision.
This commit is contained in:
Nathan Vegdahl 2020-09-18 21:04:16 +09:00
parent f13ffac7bd
commit 96b8dd84b9

View File

@ -2,7 +2,7 @@
//! //!
//! The encoding uses 9 bits of mantissa per number, and 5 bits for the shared //! The encoding uses 9 bits of mantissa per number, and 5 bits for the shared
//! exponent. The bit layout is [mantissa 1, mantissa 2, mantissa 3, exponent]. //! exponent. The bit layout is [mantissa 1, mantissa 2, mantissa 3, exponent].
//! The exponent is stored as an unsigned integer with a bias of 10. //! The exponent is stored as an unsigned integer with a bias of 11.
//! //!
//! The largest representable number is `2^21 - 4096`, and the smallest //! The largest representable number is `2^21 - 4096`, and the smallest
//! representable non-zero number is `2^-19`. //! representable non-zero number is `2^-19`.
@ -14,22 +14,20 @@
use crate::{fiddle_exp2, fiddle_log2}; use crate::{fiddle_exp2, fiddle_log2};
/// Largest representable number. /// Largest representable number.
pub const MAX: f32 = 2_093_056.0; pub const MAX: f32 = ((1u64 << (32 - EXP_BIAS)) - (1 << (32 - EXP_BIAS - 9))) as f32;
/// Smallest representable non-zero number. /// Smallest representable non-zero number.
pub const MIN: f32 = 0.000_001_907_348_6; pub const MIN: f32 = 1.0 / (1 << (EXP_BIAS + 8)) as f32;
/// Difference between 1.0 and the next largest representable number. /// Difference between 1.0 and the next largest representable number.
pub const EPSILON: f32 = 1.0 / 256.0; pub const EPSILON: f32 = 1.0 / 256.0;
const EXP_BIAS: i32 = 10; const EXP_BIAS: i32 = 11;
const MIN_EXP: i32 = 0 - EXP_BIAS;
const MAX_EXP: i32 = 31 - EXP_BIAS;
/// Encodes three floating point values into an unsigned 32-bit trifloat. /// Encodes three floating point values into an unsigned 32-bit trifloat.
/// ///
/// Input floats larger than `MAX` will saturate to `MAX`, including infinity. /// Input floats larger than `MAX` will saturate to `MAX`, including infinity.
/// Values are converted to trifloat precision by rounding. /// Values are converted to trifloat precision by rounding down.
/// ///
/// Warning: negative values and NaN's are _not_ supported by the trifloat /// Warning: negative values and NaN's are _not_ supported by the trifloat
/// format. There are debug-only assertions in place to catch such /// format. There are debug-only assertions in place to catch such
@ -51,31 +49,19 @@ pub fn encode(floats: (f32, f32, f32)) -> u32 {
floats.2 floats.2
); );
// Find the largest of the three values. let largest = floats.0.max(floats.1.max(floats.2));
let largest_value = floats.0.max(floats.1.max(floats.2));
if largest_value <= 0.0 { if largest < MIN {
return 0; return 0;
} else {
let e = fiddle_log2(largest).max(-EXP_BIAS).min(31 - EXP_BIAS);
let inv_multiplier = fiddle_exp2(-e + 8);
let x = (floats.0 * inv_multiplier).min(511.0) as u32;
let y = (floats.1 * inv_multiplier).min(511.0) as u32;
let z = (floats.2 * inv_multiplier).min(511.0) as u32;
(x << (9 + 9 + 5)) | (y << (9 + 5)) | (z << 5) | (e + EXP_BIAS) as u32
} }
// Calculate the exponent and 1.0/multiplier for encoding the values.
let mut exponent = (fiddle_log2(largest_value) + 1).max(MIN_EXP).min(MAX_EXP);
let mut inv_multiplier = fiddle_exp2(-exponent + 9);
// Edge-case: make sure rounding pushes the largest value up
// appropriately if needed.
if (largest_value * inv_multiplier) + 0.5 >= 512.0 {
exponent = (exponent + 1).min(MAX_EXP);
inv_multiplier = fiddle_exp2(-exponent + 9);
}
// Quantize and encode values.
let x = (floats.0 * inv_multiplier + 0.5).min(511.0) as u32 & 0b1_1111_1111;
let y = (floats.1 * inv_multiplier + 0.5).min(511.0) as u32 & 0b1_1111_1111;
let z = (floats.2 * inv_multiplier + 0.5).min(511.0) as u32 & 0b1_1111_1111;
let e = (exponent + EXP_BIAS) as u32 & 0b1_1111;
// Pack values into a u32.
(x << (5 + 9 + 9)) | (y << (5 + 9)) | (z << 5) | e
} }
/// Decodes an unsigned 32-bit trifloat into three full floating point numbers. /// Decodes an unsigned 32-bit trifloat into three full floating point numbers.
@ -84,12 +70,12 @@ pub fn encode(floats: (f32, f32, f32)) -> u32 {
#[inline] #[inline]
pub fn decode(trifloat: u32) -> (f32, f32, f32) { pub fn decode(trifloat: u32) -> (f32, f32, f32) {
// Unpack values. // Unpack values.
let x = trifloat >> (5 + 9 + 9); let x = trifloat >> (9 + 9 + 5);
let y = (trifloat >> (5 + 9)) & 0b1_1111_1111; let y = (trifloat >> (9 + 5)) & 0b1_1111_1111;
let z = (trifloat >> 5) & 0b1_1111_1111; let z = (trifloat >> 5) & 0b1_1111_1111;
let e = trifloat & 0b1_1111; let e = trifloat & 0b1_1111;
let multiplier = fiddle_exp2(e as i32 - EXP_BIAS - 9); let multiplier = fiddle_exp2(e as i32 - EXP_BIAS - 8);
( (
x as f32 * multiplier, x as f32 * multiplier,
@ -120,11 +106,11 @@ mod tests {
#[test] #[test]
fn powers_of_two() { fn powers_of_two() {
let fs = (8.0f32, 128.0f32, 0.5f32); let fs = (8.0f32, 128.0f32, 0.5f32);
assert_eq!(round_trip(fs), fs); assert_eq!(fs, round_trip(fs));
} }
#[test] #[test]
fn accuracy() { fn accuracy_01() {
let mut n = 1.0; let mut n = 1.0;
for _ in 0..256 { for _ in 0..256 {
let (x, _, _) = round_trip((n, 0.0, 0.0)); let (x, _, _) = round_trip((n, 0.0, 0.0));
@ -133,6 +119,17 @@ mod tests {
} }
} }
// Negative accuracy test: steps `n` upward from 1.0 in increments of 1/512,
// which is half of `EPSILON` (1/256, the gap between 1.0 and the next
// representable value). Steps that fine cannot all be represented, so the
// round trip is expected to return a value different from `n` and trip the
// `assert_eq!` -- hence `#[should_panic]`.
#[test]
#[should_panic]
fn accuracy_02() {
let mut n = 1.0;
for _ in 0..512 {
// Only the first channel carries the value under test; the others are zero.
let (x, _, _) = round_trip((n, 0.0, 0.0));
assert_eq!(n, x);
n += 1.0 / 512.0;
}
}
#[test] #[test]
fn integers() { fn integers() {
for n in 0..=512 { for n in 0..=512 {
@ -142,24 +139,17 @@ mod tests {
} }
#[test] #[test]
fn rounding() { fn precision_floor() {
let fs = (7.0f32, 513.0f32, 1.0f32); let fs = (7.0f32, 513.0f32, 1.0f32);
assert_eq!(round_trip(fs), (8.0, 514.0, 2.0)); assert_eq!((6.0, 512.0, 0.0), round_trip(fs));
}
#[test]
fn rounding_edge_case() {
let fs = (1023.0f32, 0.0f32, 0.0f32);
assert_eq!(round_trip(fs), (1024.0, 0.0, 0.0),);
} }
#[test] #[test]
fn saturate() { fn saturate() {
let fs = (9999999999.0, 9999999999.0, 9999999999.0); let fs = (9999999999.0, 9999999999.0, 9999999999.0);
assert_eq!(round_trip(fs), (MAX, MAX, MAX)); assert_eq!((MAX, MAX, MAX), round_trip(fs));
assert_eq!(decode(0xFFFFFFFF), (MAX, MAX, MAX),); assert_eq!((MAX, MAX, MAX), decode(0xFFFFFFFF));
} }
#[test] #[test]
@ -167,29 +157,29 @@ mod tests {
use std::f32::INFINITY; use std::f32::INFINITY;
let fs = (INFINITY, 0.0, 0.0); let fs = (INFINITY, 0.0, 0.0);
assert_eq!(round_trip(fs), (MAX, 0.0, 0.0)); assert_eq!((MAX, 0.0, 0.0), round_trip(fs));
assert_eq!(encode(fs), 0xFF80001F,); assert_eq!(0xFF80001F, encode(fs));
} }
#[test] #[test]
fn partial_saturate() { fn partial_saturate() {
let fs = (9999999999.0, 4096.0, 262144.0); let fs = (9999999999.0, 4096.0, 262144.0);
assert_eq!(round_trip(fs), (MAX, 4096.0, 262144.0)); assert_eq!((MAX, 4096.0, 262144.0), round_trip(fs));
} }
#[test] #[test]
fn smallest_value() { fn smallest_value() {
let fs = (MIN, MIN * 0.5, MIN * 0.49); let fs = (MIN * 1.5, MIN, MIN * 0.5);
assert_eq!(round_trip(fs), (MIN, MIN, 0.0)); assert_eq!((MIN, MIN, 0.0), round_trip(fs));
assert_eq!(decode(0x00_80_40_00), (MIN, MIN, 0.0)); assert_eq!((MIN, MIN, 0.0), decode(0x00_80_40_00));
} }
#[test] #[test]
fn underflow() { fn underflow() {
let fs = (MIN * 0.49, 0.0, 0.0); let fs = (MIN * 0.99, 0.0, 0.0);
assert_eq!(encode(fs), 0); assert_eq!(0, encode(fs));
assert_eq!(round_trip(fs), (0.0, 0.0, 0.0)); assert_eq!((0.0, 0.0, 0.0), round_trip(fs));
} }
#[test] #[test]