Since this is generated code, it doesn't need to be formatted nicely; besides, rustfmt was emitting a flood of errors about overly long lines.
187 lines
6.3 KiB
Python
187 lines
6.3 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright (c) 2012 Leonhard Gruenschloss (leonhard@gruenschloss.org)
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights to
|
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
|
# of the Software, and to permit persons to whom the Software is furnished to do
|
|
# so, subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be included in
|
|
# all copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
# Adapted to generate Rust instead of C by Nathan Vegdahl
|
|
# Generate Rust code for evaluating Halton points with Faure-permutations for different bases.
|
|
|
|
# Number of sample dimensions to emit code for. Each dimension gets its own
# prime base, and this value is also written out as MAX_DIMENSION in the
# generated Rust code.
num_dimensions = 128
|
|
|
|
# Check primality. Not optimized, since it's not performance-critical.
|
|
def is_prime(p):
|
|
for i in range(2, p):
|
|
if not p % i:
|
|
return False
|
|
return True
|
|
|
|
# Collect the first num_dimensions primes, in increasing order: one base per
# dimension, so primes[0] == 2 and primes[1] == 3.
primes = []
candidate = 1
while len(primes) < num_dimensions:
    candidate += 1
    if is_prime(candidate):
        primes.append(candidate)
|
|
|
|
# Compute the Faure digit permutation for 0, ..., b - 1.
def get_faure_permutation(b, table=None):
    """Return the Faure permutation of the digits 0, ..., b - 1 as a tuple.

    For b > 2 the permutation is derived from a smaller one: the entry for
    b - 1 when b is odd, or the entry for b // 2 when b is even. Those
    entries are read from `table`, which defaults to the module-level
    `faure` list. The table must therefore already hold the smaller entries,
    for example by being filled in increasing order of b.

    Integer division (`//`) is used so the result contains ints under both
    Python 2 and Python 3. With Python 3's true division, `/` would make the
    even case index a list with a float. It would also make the odd case
    return a float digit.
    """
    if b < 2:
        return (0,)
    if b == 2:
        return (0, 1)

    if table is None:
        table = faure

    if b & 1:  # odd
        # Keep the middle digit c fixed. Map every other digit through the
        # permutation for b - 1, shifting results >= c up by one so that the
        # value c is skipped.
        c = (b - 1) // 2
        prev = table[b - 1]

        def faure_odd(i):
            if i == c:
                return c
            f = prev[i - int(i > c)]
            return f + int(f >= c)

        return tuple(faure_odd(i) for i in range(b))

    # even: the first c digits map to the even values 2 * p[i], and the
    # remaining digits map to the odd values 2 * p[i - c] + 1, where p is the
    # permutation for c = b // 2.
    c = b // 2
    half = table[c]

    def faure_even(i):
        if i < c:
            return 2 * half[i]
        return 2 * half[i - c] + 1

    return tuple(faure_even(i) for i in range(b))
|
|
|
|
# Init Faure permutations: faure[b] for every base b from 0 up to the largest
# prime in use. Entries are appended one at a time, in increasing order,
# because get_faure_permutation(b) looks up smaller bases in this same list.
faure = []
while len(faure) <= primes[-1]:
    faure.append(get_faure_permutation(len(faure)))
|
|
|
|
# Compute the radical inverse with Faure permutations.
def invert(base, index, digits):
    """Return the Faure-scrambled radical inverse of the low digits of index.

    The lowest `digits` base-`base` digits of `index` are consumed least
    significant first and accumulated most significant first. Each digit d is
    replaced by faure[base][d] along the way. Higher digits of `index` are
    ignored.
    """
    acc = 0
    remaining = index
    for _ in range(digits):
        digit = remaining % base
        remaining //= base
        acc = acc * base + faure[base][digit]
    return acc
|
|
|
|
# Print the beginning of the generated file: crate attributes, the license,
# and the MAX_DIMENSION constant. print(...) with a single argument behaves
# identically under Python 2 and Python 3.
print('''#![allow(dead_code)]
#![cfg_attr(rustfmt, rustfmt_skip)]
// Copyright (c) 2012 Leonhard Gruenschloss (leonhard@gruenschloss.org)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
// of the Software, and to permit persons to whom the Software is furnished to do
// so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// This file is automatically generated.

// Compute points of the Halton sequence with Faure-permutations for different bases.

pub const MAX_DIMENSION: u32 = %d;''' % num_dimensions)
|
|
|
|
# Print the public sampling function. It dispatches dimension i to the
# generated halton<primes[i]> function and panics past the last dimension.
# print(...) with a single argument behaves identically under Python 2 and 3.
print('''
pub fn sample(dimension: u32, index: u32) -> f32 {
    match dimension {''')

for i in range(num_dimensions):
    print('        %d => halton%d(index),' % (i, primes[i]))

print('''        _ => panic!("Exceeded max dimensions."),
    }
}
''')
|
|
|
|
# Print the special-cased first dimension (base 2), which reverses the bits
# of the index directly instead of using a permutation table.
#
# The final conversion is clamped: `index as f32` rounds reversed indices of
# 2^32 - 128 and above up to exactly 2^32, which would make the result 1.0.
# The clamp to the largest f32 below 1.0 keeps the [0,1) guarantee that the
# other halton functions provide. Smaller values are unaffected.
print('''
// Special case: radical inverse in base 2, with direct bit reversal.
fn halton2(mut index: u32) -> f32 {
    index = (index << 16) | (index >> 16);
    index = ((index & 0x00ff00ff) << 8) | ((index & 0xff00ff00) >> 8);
    index = ((index & 0x0f0f0f0f) << 4) | ((index & 0xf0f0f0f0) >> 4);
    index = ((index & 0x33333333) << 2) | ((index & 0xcccccccc) >> 2);
    index = ((index & 0x55555555) << 1) | ((index & 0xaaaaaaaa) >> 1);
    // Indices close to 2^32 round up to 2^32 when converted to f32; clamp so
    // the result never reaches 1.0.
    ((index as f32) * (1.0 / ((1u64 << 32) as f32))).min(0.999999940395355224609375f32) // Results in [0,1).
}
''')
|
|
|
|
# Print one halton<base> function for every remaining dimension (base 2 is
# special-cased above).
for i in range(1, num_dimensions):  # Skip base 2.
    base = primes[i]

    # Based on the permutation table size, we process multiple digits at once:
    # pow_base = base**digits is grown while the table stays within 500 entries.
    digits = 1
    pow_base = base
    while pow_base * base <= 500:  # Maximum permutation table size.
        pow_base *= base
        digits += 1

    # max_power is the largest power of pow_base below 2**32 (32-bit unsigned
    # precision). The integer result is divided by it to land in [0, 1).
    max_power = pow_base
    while max_power * pow_base < (1 << 32):
        max_power *= pow_base

    # Build the permutation table: entry j is the scrambled radical inverse of
    # the low `digits` base-`base` digits of j.
    perm = [invert(base, j, digits) for j in range(pow_base)]

    # One table lookup per group of `digits` digits. The group selected by
    # dividing the index by `div` is weighted by `weight`. Weights run from
    # max_power / pow_base down to 1, mirroring each group's position.
    # Floor division (`//`) keeps these integral under Python 2 and 3.
    table = 'PERM%d' % base
    terms = []
    div = 1
    weight = max_power // pow_base
    while weight >= 1:
        group = 'index' if div == 1 else '(index / %d)' % div
        term = '%s[(%s %% %d) as usize] as u32' % (table, group, pow_base)
        if weight > 1:
            term += ' * %d' % weight
        terms.append(term)
        div *= pow_base
        weight //= pow_base

    # The table is a `static`, so runtime indexing reads it in place. A
    # `const` would be materialized at each use. Every index is reduced
    # modulo the table length, so plain (safe) indexing is always in bounds
    # and no `unsafe` is needed.
    print('''
fn halton%d(index: u32) -> f32 {
    // Every index below is reduced modulo the table length, so lookups are always in bounds.
    static %s: [u16; %d] = [%s];

    (%s) as f32 * (0.999999940395355224609375f32 / (%du32 as f32)) // Results in [0,1).
}
''' % (base, table, len(perm), ', '.join(str(k) for k in perm),
       ' +\n     '.join(terms), max_power))
|
|
|