use rotation to remove branching in inner loop (#63)
* use rotation to remove branching

* fix comments

* fix comments

* add more comments

* added comments as suggested and remove some leftover special case code that no longer applies

* fix comments to remove references to removed special cases
mcroomp authored Apr 16, 2024
1 parent 12508d3 commit 983873c
Showing 3 changed files with 108 additions and 47 deletions.
146 changes: 103 additions & 43 deletions src/structs/branch.rs
@@ -5,25 +5,20 @@
*--------------------------------------------------------------------------------------------*/

/*
The logic here is different from the C++ version, resulting in
a 2x speed increase. Nothing magic: the main change is to not
store the probability, since it is deterministically determined
from the true/false counts. Instead of doing the calculation,
we just look up the 16-bit counts value in a lookup table to get the
corresponding probability.
- The only corner case is 255 trues and 1 false, where the C++
- version decides to set the probability to 0 for the next true
- value, which differs from the formula ((a << 8) / (a + b)).
- To handle this, we use 0 as a special value to indicate this corner
- case, which maps to 0 in the lookup table. On subsequent calls,
- we make sure that we immediately transition back to (255,1) before
- executing any further logic.
*/

pub struct Branch {
/// The top byte is the number of false bits seen so far
/// and the bottom byte is the number of true bits seen.
/// On overflow both values are normalized by dividing by 2 (rounding up).
///
/// Both counts are never less than 1, so we start off with 0x0101.
counts: u16,
}
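
To make the packing concrete, here is a small illustrative helper (not part of the commit; the names are made up) that packs a (false, true) count pair the way counts does, together with the probability formula that the table stores for that index:

// Illustrative only: pack a (false, true) count pair the same way as
// Branch::counts, top byte = false count, bottom byte = true count.
fn packed_counts(false_count: u8, true_count: u8) -> u16 {
    ((false_count as u16) << 8) | true_count as u16
}

// Probability of the next bit being false, scaled to 0..=255,
// per the (f * 256) / (f + t) formula documented below.
fn prob_for(false_count: u32, true_count: u32) -> u8 {
    ((false_count << 8) / (false_count + true_count)) as u8
}

For example, a fresh branch packs to packed_counts(1, 1) == 0x0101, and prob_for(1, 1) == 128, an even prediction.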

@@ -33,7 +28,7 @@ impl Default for Branch {
}
}

- // used to precalculate the probabilities
+ /// used to precalculate the probabilities and store them as a const array
const fn problookup() -> [u8; 65536] {
let mut retval = [0; 65536];
let mut i = 1i32;
@@ -48,58 +43,123 @@ const fn problookup() -> [u8; 65536] {
return retval;
}

/// precalculated probabilities for the next bit being false
static PROB_LOOKUP: [u8; 65536] = problookup();
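
The middle of problookup() is collapsed in this view. As a rough sketch (an assumption consistent with the documented count layout and the (a << 8) / (a + b) formula, not necessarily the commit's hidden lines), the full table construction might look like this:

// Sketch of a complete table builder; an assumption, not the hidden code.
const fn problookup_sketch() -> [u8; 65536] {
    let mut retval = [0u8; 65536];
    let mut i = 1i32;
    while i < 65536 {
        let f = i >> 8; // top byte of the index: false count
        let t = i & 0xff; // bottom byte of the index: true count
        // probability (0..=255) that the next bit is false; indices where
        // either count is zero are never used at runtime, since both
        // counts start at 1
        retval[i as usize] = ((f << 8) / (f + t)) as u8;
        i += 1;
    }
    retval
}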

impl Branch {
pub fn new() -> Self {
Branch { counts: 0x0101 }
}

- // used for debugging
+ /// used for debugging to keep the state for hashing
#[allow(dead_code)]
pub fn get_u64(&self) -> u64 {
- let mut c = self.counts;
- if c == 0 {
- c = 0x01ff;
- }

+ let c = self.counts;
return ((PROB_LOOKUP[self.counts as usize] as u64) << 16) + c as u64;
}

/// Returns the probability of the next bit being a false as a value between 1 and 255
///
/// Calculated by looking up the probability in a precalculated table
/// where 'f' is the number of false bits and 't' is the number of true bits seen.
///
/// (f * 256) / (f + t)
#[inline(always)]
pub fn get_probability(&self) -> u8 {
// index 0x00ff would return probability 0, but it can never occur
// since both counts always start at 1
PROB_LOOKUP[self.counts as usize]
}
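
As a concrete sanity check, an illustrative test (not part of the commit) in the style of the tests further down:

#[test]
fn probability_tracks_counts() {
    // a fresh branch has seen one false and one true: even odds
    assert_eq!(Branch { counts: 0x0101 }.get_probability(), 128);
    // 255 falses and 1 true: the next bit is almost certainly false
    assert_eq!(Branch { counts: 0xff01 }.get_probability(), 255);
}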

/// Updates the counters when we encounter a 1 or 0. If a counter hits 255,
/// we normalize both counts (divide by 2, rounding up), except when the other
/// count is 1, in which case we leave the counts untouched. This biases the
/// probability to get better results when there are long runs of 1s or 0s.
///
/// This function merges the true and false counter updates into a single
/// code path by swapping the top and bottom bytes of the 16-bit value.
///
/// The update algorithm looks like this (with top and bottom swapped depending on 'bit'):
///
/// if top_byte < 0xff {
/// top_byte += 1;
/// } else if bottom_byte != 1 {
/// top_byte = 0x81;
/// bottom_byte = (bottom_byte + 1) >> 1;
/// }
#[inline(always)]
- pub fn record_and_update_true_obs(&mut self) {
- if (self.counts & 0xff) != 0xff {
- // non-overflow case is easy
- self.counts += 1;
- } else {
- // normalize, except special case where it is all trues
- if self.counts != 0x01ff {
- self.counts = (((self.counts as u32 + 0x100) >> 1) & 0xff00) as u16 | 129;
- }
+ pub fn record_and_update_bit(&mut self, bit: bool) {
+ // rotation is used to update either the true or false counter,
+ // which allows the same code to be used without branching
+ // and makes the CPU about 20% happier.
+ //
+ // Since the bits are randomly 1/0, the CPU branch predictor does
+ // a terrible job and ends up wasting a lot of time. Branches are
+ // normally a better idea when they are very predictable, but here
+ // it is better to always pay the price of the extra rotation to
+ // avoid the branch.
+ let orig = self.counts.rotate_left(bit as u32 * 8);
+ let (mut sum, o) = orig.overflowing_add(0x100);
+ if o {
+ // normalize, except in the special case where we have seen 0xff
+ // or more of the same bit in a row, in which case we want to bias
+ // the probability to get better compression
+ //
+ // CPU branch prediction soon realizes that this section is not
+ // often executed and will optimize for the common case where the
+ // counts are not 0xff.
+ let mask = if orig == 0xff01 { 0xff00 } else { 0x8100 };
+
+ // the upper byte of sum is 0 since we incremented 0xffxx, so we
+ // don't have to mask it
+ sum = ((1 + sum) >> 1) | mask;
}
- }
-
- #[inline(always)]
- pub fn record_and_update_false_obs(&mut self) {
- let (result, overflow) = self.counts.overflowing_add(0x100);
- if !overflow {
- self.counts = result;
- } else {
- // normalize, except in special case where it is all falses
- if self.counts != 0xff01 {
- self.counts = ((1 + (self.counts & 0xff) as u32) >> 1) as u16 | 0x8100;
- }
- }
+ self.counts = sum.rotate_left(bit as u32 * 8);
}
}
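
To see why the double rotation works: rotate_left(8) swaps the two bytes of a u16 and rotate_left(0) is the identity, so applying the same rotation before and after the update either leaves everything in place (bit == false) or moves the true count into the top byte and back (bit == true). A worked trace, written as an illustrative test (not part of the commit):

#[test]
fn trace_rotation_trick() {
    // counts == 0x05ff: 5 falses, 255 trues; record another true.
    //   orig = 0x05ff.rotate_left(8)        == 0xff05 (true count on top)
    //   orig + 0x100 overflows, so sum      == 0x0005
    //   orig != 0xff01, so mask             == 0x8100
    //   sum  = ((1 + 0x0005) >> 1) | 0x8100 == 0x8103
    //   counts = 0x8103.rotate_left(8)      == 0x0381 (3 falses, 129 trues)
    let mut b = Branch { counts: 0x05ff };
    b.record_and_update_bit(true);
    assert_eq!(b.counts, 0x0381);
}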

#[test]
fn test_branch_update_false() {
let mut b = Branch { counts: 0x0101 };
b.record_and_update_bit(false);
assert_eq!(b.counts, 0x0201);

b.counts = 0x80ff;
b.record_and_update_bit(false);
assert_eq!(b.counts, 0x81ff);

b.counts = 0xff01;
b.record_and_update_bit(false);
assert_eq!(b.counts, 0xff01);

b.counts = 0xff02;
b.record_and_update_bit(false);
assert_eq!(b.counts, 0x8101);

b.counts = 0xffff;
b.record_and_update_bit(false);
assert_eq!(b.counts, 0x8180);
}

#[test]
fn test_branch_update_true() {
let mut b = Branch { counts: 0x0101 };
b.record_and_update_bit(true);
assert_eq!(b.counts, 0x0102);

b.counts = 0xff80;
b.record_and_update_bit(true);
assert_eq!(b.counts, 0xff81);

b.counts = 0x01ff;
b.record_and_update_bit(true);
assert_eq!(b.counts, 0x01ff);

b.counts = 0x02ff;
b.record_and_update_bit(true);
assert_eq!(b.counts, 0x0181);

b.counts = 0xffff;
b.record_and_update_bit(true);
assert_eq!(b.counts, 0x8081);
}

/// run through all the possible combinations of counts and ensure that the probability is the same
#[test]
fn test_all_probabilities() {
@@ -163,7 +223,7 @@

for _k in 0..10 {
old_f.record_obs_and_update(false);
- new_f.record_and_update_false_obs();
+ new_f.record_and_update_bit(false);
assert_eq!(old_f.probability, new_f.get_probability());
}

@@ -175,7 +235,7 @@

for _k in 0..10 {
old_t.record_obs_and_update(true);
- new_t.record_and_update_true_obs();
+ new_t.record_and_update_bit(true);

if old_t.probability == 0 {
// there is a change of behavior here compared to the C++ version,
5 changes: 3 additions & 2 deletions src/structs/vpx_bool_reader.rs
@@ -149,8 +149,10 @@ impl<R: Read> VPXBoolReader<R> {
let bit = tmp_value >= big_split;

let shift;

+ branch.record_and_update_bit(bit);

if bit {
- branch.record_and_update_true_obs();
tmp_range -= split;
tmp_value -= big_split;

@@ -166,7 +168,6 @@

shift = tmp_range.leading_zeros() as i32 - 24;
} else {
- branch.record_and_update_false_obs();
tmp_range = split;

// optimizer understands that split > 0
4 changes: 2 additions & 2 deletions src/structs/vpx_bool_writer.rs
@@ -161,14 +161,14 @@ impl<W: Write> VPXBoolWriter<W> {
let mut tmp_low_value = self.low_value;

let mut shift;
+ branch.record_and_update_bit(value);

if value {
- branch.record_and_update_true_obs();
tmp_low_value += split;
tmp_range -= split;

shift = (tmp_range as u8).leading_zeros() as i32;
} else {
- branch.record_and_update_false_obs();
tmp_range = split;

// optimizer understands that split > 0, so it can optimize this
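
Both call sites now follow the same shape: the counter update is hoisted above the data-dependent if, so the hot decode/encode loop updates the Branch unconditionally and only the range arithmetic still depends on the bit. A minimal sketch of the pattern (hypothetical and heavily simplified from the two call sites above):

// Hypothetical simplification of the call-site pattern: the model update
// no longer branches on an unpredictable bit.
fn coder_step(branch: &mut Branch, bit: bool, range: &mut u32, split: u32) {
    branch.record_and_update_bit(bit); // unconditional, branch-free
    if bit {
        *range -= split; // only the range arithmetic still branches
    } else {
        *range = split;
    }
}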
