infotheory/backends/match_model.rs

1use ahash::AHashMap;
2
3#[derive(Clone, Debug)]
4/// Local match predictor with configurable contiguous or gapped matching.
5pub struct MatchModel {
6    hash_bits: usize,
7    min_len: usize,
8    max_len: usize,
9    stride_min: usize,
10    stride_max: usize,
11    base_mix: f64,
12    confidence_scale: f64,
13    history: Vec<u8>,
14    frozen_anchor: usize,
15    tables: Vec<AHashMap<u64, (usize, usize)>>,
16    pdf: [f64; 256],
17    cdf: [f64; 257],
18    valid: bool,
19    cdf_valid: bool,
20    predicted: Option<u8>,
21    match_len: usize,
22}
23
24impl MatchModel {
25    /// Create a match model with an inclusive stride range `[gap_min+1, gap_max+1]`.
26    pub fn new(
27        hash_bits: usize,
28        min_len: usize,
29        max_len: usize,
30        gap_min: usize,
31        gap_max: usize,
32        base_mix: f64,
33        confidence_scale: f64,
34    ) -> Self {
35        let stride_min = gap_min.saturating_add(1);
36        let stride_max = gap_max.saturating_add(1).max(stride_min);
37        let mut tables = Vec::new();
38        for _ in stride_min..=stride_max {
39            tables.push(AHashMap::new());
40        }
41        Self {
42            hash_bits,
43            min_len: min_len.max(1),
44            max_len: max_len.max(min_len.max(1)),
45            stride_min,
46            stride_max,
47            base_mix: base_mix.clamp(1e-6, 0.99),
48            confidence_scale: confidence_scale.max(0.0),
49            history: Vec::new(),
50            frozen_anchor: 0,
51            tables,
52            pdf: [1.0 / 256.0; 256],
53            cdf: uniform_cdf(),
54            valid: false,
55            cdf_valid: false,
56            predicted: None,
57            match_len: 0,
58        }
59    }
60
61    /// Convenience constructor for contiguous matching (`gap_min = gap_max = 0`).
62    pub fn new_contiguous(
63        hash_bits: usize,
64        min_len: usize,
65        max_len: usize,
66        base_mix: f64,
67        confidence_scale: f64,
68    ) -> Self {
69        Self::new(
70            hash_bits,
71            min_len,
72            max_len,
73            0,
74            0,
75            base_mix,
76            confidence_scale,
77        )
78    }
79
80    /// Fill `out` with the current normalized byte PDF.
81    pub fn fill_pdf(&mut self, out: &mut [f64; 256]) {
82        self.ensure_pdf_inner(false);
83        out.copy_from_slice(&self.pdf);
84    }
85
86    /// Borrow the current normalized byte PDF.
87    pub fn pdf(&mut self) -> &[f64; 256] {
88        self.ensure_pdf_inner(false);
89        &self.pdf
90    }
91
92    /// Borrow the cumulative distribution derived from the current PDF.
93    pub fn cdf(&mut self) -> &[f64; 257] {
94        self.ensure_pdf_inner(true);
95        &self.cdf
96    }
97
98    /// Return `ln(max(P(symbol), min_prob))`.
99    pub fn log_prob(&mut self, symbol: u8, min_prob: f64) -> f64 {
100        self.ensure_pdf_inner(false);
101        self.pdf[symbol as usize].max(min_prob).ln()
102    }
103
104    /// Observe one symbol and update match tables/history.
105    pub fn update(&mut self, symbol: u8) {
106        if self.frozen_anchor > 0 {
107            self.frozen_anchor = 0;
108        }
109        self.history.push(symbol);
110        for stride in self.stride_min..=self.stride_max {
111            if let Some(key) = self.suffix_key(stride) {
112                let end = self.history.len() - 1;
113                self.tables[stride - self.stride_min]
114                    .entry(key)
115                    .and_modify(|entry| {
116                        entry.1 = entry.0;
117                        entry.0 = end;
118                    })
119                    .or_insert((end, usize::MAX));
120            }
121        }
122        self.valid = false;
123        self.cdf_valid = false;
124    }
125
126    /// Reset conditioning while preserving learned match tables and fitted corpus bytes.
127    pub fn reset_history(&mut self) {
128        if self.frozen_anchor > 0 {
129            self.history.truncate(self.frozen_anchor);
130        } else {
131            self.frozen_anchor = self.history.len();
132        }
133        self.valid = false;
134        self.cdf_valid = false;
135        self.predicted = None;
136        self.match_len = 0;
137        self.pdf.fill(1.0 / 256.0);
138        self.cdf = uniform_cdf();
139    }
140
141    /// Advance conditioning history without updating learned match tables.
142    pub fn update_history_only(&mut self, symbol: u8) {
143        if self.frozen_anchor == 0 {
144            self.frozen_anchor = self.history.len();
145        }
146        self.history.push(symbol);
147        self.valid = false;
148        self.cdf_valid = false;
149    }
150
151    /// Length of the best match used for the last computed distribution.
152    pub fn match_len(&mut self) -> usize {
153        self.ensure_pdf_inner(false);
154        self.match_len
155    }
156
157    /// Predicted next byte from the best match, if any.
158    pub fn predicted_byte(&mut self) -> Option<u8> {
159        self.ensure_pdf_inner(false);
160        self.predicted
161    }
162
163    fn ensure_pdf_inner(&mut self, want_cdf: bool) {
164        if self.valid {
165            if want_cdf && !self.cdf_valid {
166                build_cdf_from_pdf(&self.pdf, &mut self.cdf);
167                self.cdf_valid = true;
168            }
169            return;
170        }
171        self.predicted = None;
172        self.match_len = 0;
173        self.pdf.fill(1.0 / 256.0);
174        let active_len = self.history.len().saturating_sub(self.frozen_anchor);
175        if active_len < self.min_len {
176            self.valid = true;
177            if want_cdf {
178                self.cdf = uniform_cdf();
179                self.cdf_valid = true;
180            } else {
181                self.cdf_valid = false;
182            }
183            return;
184        }
185
186        let mut best = None;
187        let history_limit = if self.frozen_anchor > 0 {
188            self.frozen_anchor
189        } else {
190            self.history.len()
191        };
192        for stride in self.stride_min..=self.stride_max {
193            let Some(key) = self.suffix_key(stride) else {
194                continue;
195            };
196            let Some(&(latest, previous)) = self.tables[stride - self.stride_min].get(&key) else {
197                continue;
198            };
199            let current_end = self.history.len() - 1;
200            let candidate_end = if latest == current_end {
201                previous
202            } else {
203                latest
204            };
205            if self.frozen_anchor > 0 && candidate_end >= self.frozen_anchor {
206                continue;
207            }
208            if candidate_end == usize::MAX || candidate_end + stride >= history_limit {
209                continue;
210            }
211            let matched = self.extend_match(candidate_end, stride);
212            if matched < self.min_len {
213                continue;
214            }
215            let predicted = self.history[candidate_end + stride];
216            match best {
217                Some((best_len, _, _)) if matched <= best_len => {}
218                _ => best = Some((matched, predicted, stride)),
219            }
220        }
221
222        if let Some((match_len, predicted, _stride)) = best {
223            self.predicted = Some(predicted);
224            self.match_len = match_len;
225            let base = 1.0 / 256.0;
226            let span = self.max_len.saturating_sub(self.min_len).max(1);
227            let covered = match_len.saturating_sub(self.min_len).min(span);
228            let confidence = ((covered as f64) / (span as f64)).sqrt() * self.confidence_scale;
229            let p_copy = (base + confidence.clamp(0.0, 1.0) * ((1.0 - self.base_mix) - base))
230                .clamp(base, 1.0 - self.base_mix);
231            let rest = ((1.0 - p_copy) / 255.0).max(0.0);
232            self.pdf.fill(rest);
233            self.pdf[predicted as usize] = p_copy;
234        }
235        if want_cdf {
236            build_cdf_from_pdf(&self.pdf, &mut self.cdf);
237        }
238        self.valid = true;
239        self.cdf_valid = want_cdf;
240    }
241
242    fn suffix_key(&self, stride: usize) -> Option<u64> {
243        let need = self
244            .min_len
245            .checked_sub(1)?
246            .saturating_mul(stride)
247            .saturating_add(1);
248        if self.history.len().saturating_sub(self.frozen_anchor) < need {
249            return None;
250        }
251        let mut h = 0x517C_C1B7_2722_0A95u64;
252        let mut idx = self.history.len() - 1;
253        for step in 0..self.min_len {
254            h ^= self.history[idx] as u64;
255            h = h.rotate_left(7).wrapping_mul(0x9E37_79B1);
256            if step + 1 == self.min_len {
257                break;
258            }
259            let boundary = self.frozen_anchor.saturating_add(stride);
260            if idx < boundary {
261                return None;
262            }
263            idx -= stride;
264        }
265        let bits = self.hash_bits.clamp(4, 63);
266        Some(h & ((1u64 << bits) - 1))
267    }
268
269    fn extend_match(&self, candidate_end: usize, stride: usize) -> usize {
270        let current_end = self.history.len() - 1;
271        let mut matched = self.min_len;
272        while matched < self.max_len {
273            let step = matched.saturating_mul(stride);
274            let Some(current_idx) = current_end.checked_sub(step) else {
275                break;
276            };
277            if current_idx < self.frozen_anchor {
278                break;
279            }
280            let Some(candidate_idx) = candidate_end.checked_sub(step) else {
281                break;
282            };
283            if self.history[current_idx] != self.history[candidate_idx] {
284                break;
285            }
286            matched += 1;
287        }
288        matched
289    }
290}
291
/// CDF of the uniform byte distribution: `cdf[i] = i / 256` for `i` in `0..=256`.
#[inline]
fn uniform_cdf() -> [f64; 257] {
    // Dividing by 256 (a power of two) is exact, matching multiplication by 1/256.
    std::array::from_fn(|i| i as f64 / 256.0)
}
301
/// Write the running sum of `pdf` into `cdf`, so that `cdf[0] = 0` and
/// `cdf[i + 1] = cdf[i] + pdf[i]`.
#[inline]
fn build_cdf_from_pdf(pdf: &[f64; 256], cdf: &mut [f64; 257]) {
    let mut running = 0.0;
    cdf[0] = running;
    for (slot, &mass) in cdf[1..].iter_mut().zip(pdf.iter()) {
        running += mass;
        *slot = running;
    }
}
309
#[cfg(test)]
mod tests {
    use super::MatchModel;

    /// Fit a contiguous model on `abcabcX`, freeze it, then condition it on `abcabc`.
    fn fitted_and_conditioned() -> MatchModel {
        let mut model = MatchModel::new_contiguous(32, 3, 16, 0.02, 1.0);
        b"abcabcX".iter().for_each(|&byte| model.update(byte));
        model.reset_history();
        b"abcabc".iter().for_each(|&byte| model.update_history_only(byte));
        model
    }

    #[test]
    fn reset_history_preserves_fit_corpus_for_frozen_conditioning() {
        let mut model = fitted_and_conditioned();

        assert_eq!(model.predicted_byte(), Some(b'X'));
        assert!(model.match_len() >= 3);
    }

    #[test]
    fn reset_history_drops_previous_conditioning() {
        let mut model = fitted_and_conditioned();
        assert_eq!(model.predicted_byte(), Some(b'X'));

        // A second reset discards the conditioning bytes, leaving too little context.
        model.reset_history();
        assert_eq!(model.predicted_byte(), None);
    }
}
344        assert_eq!(model.predicted_byte(), None);
345    }
346}