Skip to content

Commit

Permalink
optimize branch logic
Browse files Browse the repository at this point in the history
  • Loading branch information
mcroomp committed Feb 18, 2024
1 parent f116e52 commit 5f4bc49
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/structs/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,32 +79,20 @@ impl Branch {
// non-overflow case is easy
self.counts += 1;
} else {
// special case where it is all trues
if self.counts <= 0x01ff {
// corner case since the original implementation
// insists on setting the probabily to zero,
// although the probability calculation would
// return 1.
self.counts = 0x00ff;
} else {
// normalize, except special case where it is all trues
if self.counts != 0x01ff {
self.counts = (((self.counts as u32 + 0x100) >> 1) & 0xff00) as u16 | 129;
}
}
}

#[inline(always)]
pub fn record_and_update_false_obs(&mut self) {
if self.counts == 0x00ff {
// handle corner case where prob was set to zero (purely for compatibility, remove this if there is a breaking change in the format)
self.counts = 0x02ff;
return;
}

let (result, overflow) = self.counts.overflowing_add(0x100);
if !overflow {
self.counts = result;
} else {
// special case where it is all falses
// normalize, except in special case where it is all falses
if self.counts != 0xff01 {
self.counts = ((1 + (self.counts & 0xff) as u32) >> 1) as u16 | 0x8100;
}
Expand Down Expand Up @@ -176,7 +164,15 @@ fn test_all_probabilities() {
for _k in 0..10 {
old_f.record_obs_and_update(false);
new_f.record_and_update_false_obs();
assert_eq!(old_f.probability, new_f.get_probability());
if old_f.probability == 0 {
// there is a change of behavior here compared to the C++ version,
// but because of the way split is calculated it doesn't result in an
// overall change in the way that encoding is done, but it does simplify
// one of the corner cases.
assert_eq!(old_f.probability, 1);
} else {
assert_eq!(old_f.probability, new_f.get_probability());
}
}

let mut old_t = OriginalImplForTest {
Expand All @@ -189,7 +185,15 @@ fn test_all_probabilities() {
old_t.record_obs_and_update(true);
new_t.record_and_update_true_obs();

assert_eq!(old_t.probability, new_t.get_probability());
if old_f.probability == 0 {
// there is a change of behavior here compared to the C++ version,
// but because of the way split is calculated it doesn't result in an
// overall change in the way that encoding is done, but it does simplify
// one of the corner cases.
assert_eq!(old_f.probability, 1);
} else {
assert_eq!(old_f.probability, new_f.get_probability());
}
}
}
}

0 comments on commit 5f4bc49

Please sign in to comment.