diff --git a/src/structs/branch.rs b/src/structs/branch.rs index ebae81bb..070e5832 100644 --- a/src/structs/branch.rs +++ b/src/structs/branch.rs @@ -79,14 +79,8 @@ impl Branch { // non-overflow case is easy self.counts += 1; } else { - // special case where it is all trues - if self.counts <= 0x01ff { - // corner case since the original implementation - // insists on setting the probabily to zero, - // although the probability calculation would - // return 1. - self.counts = 0x00ff; - } else { + // normalize, except special case where it is all trues + if self.counts != 0x01ff { self.counts = (((self.counts as u32 + 0x100) >> 1) & 0xff00) as u16 | 129; } } @@ -94,17 +88,11 @@ impl Branch { #[inline(always)] pub fn record_and_update_false_obs(&mut self) { - if self.counts == 0x00ff { - // handle corner case where prob was set to zero (purely for compatibility, remove this if there is a breaking change in the format) - self.counts = 0x02ff; - return; - } - let (result, overflow) = self.counts.overflowing_add(0x100); if !overflow { self.counts = result; } else { - // special case where it is all falses + // normalize, except in special case where it is all falses if self.counts != 0xff01 { self.counts = ((1 + (self.counts & 0xff) as u32) >> 1) as u16 | 0x8100; } @@ -189,7 +177,15 @@ fn test_all_probabilities() { old_t.record_obs_and_update(true); new_t.record_and_update_true_obs(); - assert_eq!(old_t.probability, new_t.get_probability()); + if old_t.probability == 0 { + // there is a change of behavior here compared to the C++ version, + // but because of the way split is calculated it doesn't result in an + // overall change in the way that encoding is done, but it does simplify + // one of the corner cases. + assert_eq!(new_t.get_probability(), 1); + } else { + assert_eq!(old_t.probability, new_t.get_probability()); + } } } }