Cleaned up the signed48 trifloat code.

This commit is contained in:
Nathan Vegdahl 2020-09-19 08:53:53 +09:00
parent fd98b33333
commit 49c97bf0fe

View File

@ -3,9 +3,9 @@
//! The encoding uses 13 bits of mantissa and 1 sign bit per number, and 6 //! The encoding uses 13 bits of mantissa and 1 sign bit per number, and 6
//! bits for the shared exponent. The bit layout is: [sign 1, mantissa 1, //! bits for the shared exponent. The bit layout is: [sign 1, mantissa 1,
//! sign 2, mantissa 2, sign 3, mantissa 3, exponent]. The exponent is stored //! sign 2, mantissa 2, sign 3, mantissa 3, exponent]. The exponent is stored
//! as an unsigned integer with a bias of 25. //! as an unsigned integer with a bias of 26.
//! //!
//! The largest representable number is `2^38 - 2^25`, and the smallest //! The largest representable number is just under `2^38`, and the smallest
//! representable positive number is `2^-38`. //! representable positive number is `2^-38`.
//! //!
//! Since the exponent is shared between all three values, the precision //! Since the exponent is shared between all three values, the precision
@ -18,32 +18,29 @@
use crate::{fiddle_exp2, fiddle_log2}; use crate::{fiddle_exp2, fiddle_log2};
/// Largest representable number. /// Largest representable number.
pub const MAX: f32 = 274_844_352_512.0; pub const MAX: f32 = ((1u128 << (64 - EXP_BIAS)) - (1 << (64 - EXP_BIAS - 13))) as f32;
/// Smallest representable number. /// Smallest representable number.
/// ///
/// Note this is not the smallest _magnitude_ number. This is a negative /// Note this is not the smallest _magnitude_ number. This is a negative
/// number of large magnitude. /// number of large magnitude.
pub const MIN: f32 = -274_844_352_512.0; pub const MIN: f32 = -MAX;
/// Smallest representable positive number. /// Smallest representable positive number.
/// ///
/// This is the number with the smallest possible magnitude (aside from zero). /// This is the number with the smallest possible magnitude (aside from zero).
#[allow(clippy::excessive_precision)] pub const MIN_POSITIVE: f32 = 1.0 / (1u128 << (EXP_BIAS + 12)) as f32;
pub const MIN_POSITIVE: f32 = 0.000_000_000_003_637_978_807_091_713;
/// Difference between 1.0 and the next larger representable number. /// Difference between 1.0 and the next larger representable number.
pub const EPSILON: f32 = 1.0 / 4096.0; pub const EPSILON: f32 = 1.0 / 4096.0;
const EXP_BIAS: i32 = 25; const EXP_BIAS: i32 = 26;
const MIN_EXP: i32 = 0 - EXP_BIAS;
const MAX_EXP: i32 = 63 - EXP_BIAS;
/// Encodes three floating point values into a signed 48-bit trifloat. /// Encodes three floating point values into a signed 48-bit trifloat.
/// ///
/// Input floats that are larger than `MAX` or smaller than `MIN` will saturate /// Input floats that are larger than `MAX` or smaller than `MIN` will saturate
/// to `MAX` and `MIN` respectively, including +/- infinity. Values are /// to `MAX` and `MIN` respectively, including +/- infinity. Values are
/// converted to trifloat precision by rounding. /// converted to trifloat precision by rounding towards zero.
/// ///
/// Only the lower 48 bits of the return value are used. The highest 16 bits /// Only the lower 48 bits of the return value are used. The highest 16 bits
/// will all be zero and can be safely discarded. /// will all be zero and can be safely discarded.
@ -62,46 +59,31 @@ pub fn encode(floats: (f32, f32, f32)) -> u64 {
floats.2 floats.2
); );
// Find the largest (in magnitude) of the three values. let floats_abs = (floats.0.abs(), floats.1.abs(), floats.2.abs());
let largest_value = {
let mut largest_value: f32 = 0.0;
if floats.0.abs() > largest_value.abs() {
largest_value = floats.0;
}
if floats.1.abs() > largest_value.abs() {
largest_value = floats.1;
}
if floats.2.abs() > largest_value.abs() {
largest_value = floats.2;
}
largest_value
};
// Calculate the exponent and 1.0/multiplier for encoding the values. let largest_abs = floats_abs.0.max(floats_abs.1.max(floats_abs.2));
let (exponent, inv_multiplier) = {
let mut exponent = (fiddle_log2(largest_value) + 1).max(MIN_EXP).min(MAX_EXP);
let mut inv_multiplier = fiddle_exp2(-exponent + 13);
// Edge-case: make sure rounding pushes the largest value up if largest_abs < MIN_POSITIVE {
// appropriately if needed. 0
if (largest_value * inv_multiplier).abs() + 0.5 >= 8192.0 { } else {
exponent = (exponent + 1).min(MAX_EXP); let e = fiddle_log2(largest_abs).max(-EXP_BIAS).min(63 - EXP_BIAS);
inv_multiplier = fiddle_exp2(-exponent + 13); let inv_multiplier = fiddle_exp2(-e + 12);
}
(exponent, inv_multiplier)
};
// Quantize and encode values. let x_sign = (floats.0.to_bits() >> 31) as u64;
let x = (floats.0.abs() * inv_multiplier + 0.5).min(8191.0) as u64 & 0b111_11111_11111; let x = (floats_abs.0 * inv_multiplier).min(8191.0) as u64;
let x_sign = (floats.0.to_bits() >> 31) as u64; let y_sign = (floats.1.to_bits() >> 31) as u64;
let y = (floats.1.abs() * inv_multiplier + 0.5).min(8191.0) as u64 & 0b111_11111_11111; let y = (floats_abs.1 * inv_multiplier).min(8191.0) as u64;
let y_sign = (floats.1.to_bits() >> 31) as u64; let z_sign = (floats.2.to_bits() >> 31) as u64;
let z = (floats.2.abs() * inv_multiplier + 0.5).min(8191.0) as u64 & 0b111_11111_11111; let z = (floats_abs.2 * inv_multiplier).min(8191.0) as u64;
let z_sign = (floats.2.to_bits() >> 31) as u64;
let e = (exponent + EXP_BIAS) as u64 & 0b111_111;
// Pack values into a single u64 and return. (x_sign << 47)
(x_sign << 47) | (x << 34) | (y_sign << 33) | (y << 20) | (z_sign << 19) | (z << 6) | e | (x << 34)
| (y_sign << 33)
| (y << 20)
| (z_sign << 19)
| (z << 6)
| (e + EXP_BIAS) as u64
}
} }
/// Decodes a signed 48-bit trifloat into three full floating point numbers. /// Decodes a signed 48-bit trifloat into three full floating point numbers.
@ -122,7 +104,7 @@ pub fn decode(trifloat: u64) -> (f32, f32, f32) {
let e = trifloat & 0b111_111; let e = trifloat & 0b111_111;
let multiplier = fiddle_exp2(e as i32 - EXP_BIAS - 13); let multiplier = fiddle_exp2(e as i32 - EXP_BIAS - 12);
( (
f32::from_bits((x as f32 * multiplier).to_bits() | x_sign), f32::from_bits((x as f32 * multiplier).to_bits() | x_sign),
@ -153,7 +135,7 @@ mod tests {
#[test] #[test]
fn powers_of_two() { fn powers_of_two() {
let fs = (8.0f32, 128.0f32, 0.5f32); let fs = (8.0f32, 128.0f32, 0.5f32);
assert_eq!(round_trip(fs), fs); assert_eq!(fs, round_trip(fs));
} }
#[test] #[test]
@ -196,18 +178,11 @@ mod tests {
} }
#[test] #[test]
fn rounding() { fn precision_floor() {
let fs = (7.0f32, 8193.0f32, -1.0f32); let fs = (7.0f32, 8193.0f32, -1.0f32);
let fsn = (-7.0f32, -8193.0f32, 1.0f32); let fsn = (-7.0f32, -8193.0f32, 1.0f32);
assert_eq!(round_trip(fs), (8.0, 8194.0, -2.0)); assert_eq!((6.0, 8192.0, -0.0), round_trip(fs));
assert_eq!(round_trip(fsn), (-8.0, -8194.0, 2.0)); assert_eq!((-6.0, -8192.0, 0.0), round_trip(fsn));
}
#[test]
fn rounding_edge_case() {
let fs = (16383.0f32, 0.0f32, 0.0f32);
assert_eq!(round_trip(fs), (16384.0, 0.0, 0.0),);
} }
#[test] #[test]
@ -223,10 +198,10 @@ mod tests {
-99_999_999_999_999.0, -99_999_999_999_999.0,
); );
assert_eq!(round_trip(fs), (MAX, MAX, MAX)); assert_eq!((MAX, MAX, MAX), round_trip(fs));
assert_eq!(round_trip(fsn), (MIN, MIN, MIN)); assert_eq!((MIN, MIN, MIN), round_trip(fsn));
assert_eq!(decode(0x7FFD_FFF7_FFFF), (MAX, MAX, MAX)); assert_eq!((MAX, MAX, MAX), decode(0x7FFD_FFF7_FFFF));
assert_eq!(decode(0xFFFF_FFFF_FFFF), (MIN, MIN, MIN)); assert_eq!((MIN, MIN, MIN), decode(0xFFFF_FFFF_FFFF));
} }
#[test] #[test]
@ -235,10 +210,10 @@ mod tests {
let fs = (INFINITY, 0.0, 0.0); let fs = (INFINITY, 0.0, 0.0);
let fsn = (-INFINITY, 0.0, 0.0); let fsn = (-INFINITY, 0.0, 0.0);
assert_eq!(round_trip(fs), (MAX, 0.0, 0.0)); assert_eq!((MAX, 0.0, 0.0), round_trip(fs));
assert_eq!(round_trip(fsn), (MIN, 0.0, 0.0)); assert_eq!((MIN, 0.0, 0.0), round_trip(fsn));
assert_eq!(encode(fs), 0x7FFC0000003F); assert_eq!(0x7FFC0000003F, encode(fs));
assert_eq!(encode(fsn), 0xFFFC0000003F); assert_eq!(0xFFFC0000003F, encode(fsn));
} }
#[test] #[test]
@ -246,25 +221,25 @@ mod tests {
let fs = (99_999_999_999_999.0, 4294967296.0, -17179869184.0); let fs = (99_999_999_999_999.0, 4294967296.0, -17179869184.0);
let fsn = (-99_999_999_999_999.0, 4294967296.0, -17179869184.0); let fsn = (-99_999_999_999_999.0, 4294967296.0, -17179869184.0);
assert_eq!(round_trip(fs), (MAX, 4294967296.0, -17179869184.0)); assert_eq!((MAX, 4294967296.0, -17179869184.0), round_trip(fs));
assert_eq!(round_trip(fsn), (MIN, 4294967296.0, -17179869184.0)); assert_eq!((MIN, 4294967296.0, -17179869184.0), round_trip(fsn));
} }
#[test] #[test]
fn smallest_value() { fn smallest_value() {
let fs = (MIN_POSITIVE, MIN_POSITIVE * 0.5, MIN_POSITIVE * 0.49); let fs = (MIN_POSITIVE * 1.5, MIN_POSITIVE, MIN_POSITIVE * 0.50);
let fsn = (-MIN_POSITIVE, -MIN_POSITIVE * 0.5, -MIN_POSITIVE * 0.49); let fsn = (-MIN_POSITIVE * 1.5, -MIN_POSITIVE, -MIN_POSITIVE * 0.50);
assert_eq!(decode(0x600100000), (MIN_POSITIVE, -MIN_POSITIVE, 0.0)); assert_eq!((MIN_POSITIVE, -MIN_POSITIVE, 0.0), decode(0x600100000));
assert_eq!(round_trip(fs), (MIN_POSITIVE, MIN_POSITIVE, 0.0)); assert_eq!((MIN_POSITIVE, MIN_POSITIVE, 0.0), round_trip(fs));
assert_eq!(round_trip(fsn), (-MIN_POSITIVE, -MIN_POSITIVE, -0.0)); assert_eq!((-MIN_POSITIVE, -MIN_POSITIVE, -0.0), round_trip(fsn));
} }
#[test] #[test]
fn underflow() { fn underflow() {
let fs = (MIN_POSITIVE * 0.49, -MIN_POSITIVE * 0.49, 0.0); let fs = (MIN_POSITIVE * 0.5, -MIN_POSITIVE * 0.5, MIN_POSITIVE);
assert_eq!(encode(fs), 0x200000000); assert_eq!(0x200000040, encode(fs));
assert_eq!(round_trip(fs), (0.0, -0.0, 0.0)); assert_eq!((0.0, -0.0, MIN_POSITIVE), round_trip(fs));
} }
#[test] #[test]