//! infotheory/backends/ppmd.rs

1use ahash::AHashMap;
2use std::collections::VecDeque;
3
/// Floor applied to each per-symbol probability during normalization.
const PDF_MIN: f64 = crate::mixture::DEFAULT_MIN_PROB;
5
/// Frequency table for a single PPM context: sparse per-symbol counts plus
/// their running sum.
#[derive(Clone, Debug, Default)]
struct ContextStats {
    /// (symbol, count) pairs in first-seen order; small in practice, so a
    /// linear scan in `observe` is acceptable.
    counts: Vec<(u8, u16)>,
    /// Sum of all counts, kept in sync by `observe`/`rescale`.
    total: u32,
}

impl ContextStats {
    /// Record one occurrence of `symbol`, rescaling once the total exceeds
    /// 4096 so old statistics decay and the u16 counters cannot saturate in
    /// normal operation.
    fn observe(&mut self, symbol: u8) {
        if let Some((_, count)) = self.counts.iter_mut().find(|(s, _)| *s == symbol) {
            *count = count.saturating_add(1);
        } else {
            self.counts.push((symbol, 1));
        }
        self.total = self.total.saturating_add(1);
        if self.total > 4096 {
            self.rescale();
        }
    }

    /// Halve every count (rounding up, so no entry drops to zero) and
    /// recompute `total`.
    ///
    /// The original used `retain_mut` with an always-`true` predicate as a
    /// mutate-all loop; a plain `iter_mut` loop states the intent directly.
    fn rescale(&mut self) {
        self.total = 0;
        for (_, count) in self.counts.iter_mut() {
            // div_ceil keeps every surviving count >= 1; max(1) is a
            // defensive no-op retained from the original.
            *count = count.div_ceil(2).max(1);
            self.total += u32::from(*count);
        }
    }
}
34
#[derive(Clone, Debug)]
/// Bounded-memory PPMD-inspired byte model with interpolation across orders.
pub struct PpmdModel {
    /// Maximum context order (number of conditioning bytes); clamped to >= 1.
    order: usize,
    /// Cap on the total number of stored contexts across all orders.
    max_contexts: usize,
    /// One map per order 0..=order, keyed by a hash of the history suffix.
    contexts: Vec<AHashMap<u64, ContextStats>>,
    /// FIFO of (order, key) insertions used for oldest-first eviction.
    queue: VecDeque<(usize, u64)>,
    /// Conditioning byte history.
    // NOTE(review): never truncated except by reset_history — grows without
    // bound on long streams; confirm callers reset periodically.
    history: Vec<u8>,
    /// Cached normalized PDF over the 256 byte values.
    pdf: [f64; 256],
    /// Cached CDF (257 entries, cdf[0] = 0.0) derived from `pdf`.
    cdf: [f64; 257],
    /// True when `pdf` reflects the current history.
    valid: bool,
    /// True when `cdf` matches the current `pdf`.
    cdf_valid: bool,
}
48
49impl PpmdModel {
50    /// Create a model with maximum `order` and approximate memory budget in MiB.
51    pub fn new(order: usize, memory_mb: usize) -> Self {
52        let order = order.max(1);
53        let max_contexts = (memory_mb.max(1) * 1024 * 1024) / 96;
54        Self {
55            order,
56            max_contexts: max_contexts.max(1024),
57            contexts: (0..=order).map(|_| AHashMap::new()).collect(),
58            queue: VecDeque::new(),
59            history: Vec::new(),
60            pdf: [1.0 / 256.0; 256],
61            cdf: uniform_cdf(),
62            valid: false,
63            cdf_valid: false,
64        }
65    }
66
67    /// Fill `out` with the current normalized byte PDF.
68    pub fn fill_pdf(&mut self, out: &mut [f64; 256]) {
69        self.ensure_pdf_inner(false);
70        out.copy_from_slice(&self.pdf);
71    }
72
73    /// Borrow the current normalized byte PDF.
74    pub fn pdf(&mut self) -> &[f64; 256] {
75        self.ensure_pdf_inner(false);
76        &self.pdf
77    }
78
79    /// Borrow the cumulative distribution derived from the current PDF.
80    pub fn cdf(&mut self) -> &[f64; 257] {
81        self.ensure_pdf_inner(true);
82        &self.cdf
83    }
84
85    /// Return `ln(max(P(symbol), min_prob))`.
86    pub fn log_prob(&mut self, symbol: u8, min_prob: f64) -> f64 {
87        self.ensure_pdf_inner(false);
88        self.pdf[symbol as usize].max(min_prob).ln()
89    }
90
91    /// Observe one symbol and update all active contexts up to model order.
92    pub fn update(&mut self, symbol: u8) {
93        let max_order = self.order.min(self.history.len());
94        for ord in 0..=max_order {
95            let key = self.context_key(ord);
96            let map = &mut self.contexts[ord];
97            if !map.contains_key(&key) {
98                map.insert(key, ContextStats::default());
99                self.queue.push_back((ord, key));
100            }
101            if let Some(ctx) = map.get_mut(&key) {
102                ctx.observe(symbol);
103            }
104        }
105        self.prune();
106        self.history.push(symbol);
107        self.valid = false;
108        self.cdf_valid = false;
109    }
110
111    /// Reset only the conditioning history while preserving fitted contexts.
112    pub fn reset_history(&mut self) {
113        self.history.clear();
114        self.valid = false;
115        self.cdf_valid = false;
116        self.pdf.fill(1.0 / 256.0);
117        self.cdf = uniform_cdf();
118    }
119
120    /// Advance conditioning history without updating fitted context counts.
121    pub fn update_history_only(&mut self, symbol: u8) {
122        self.history.push(symbol);
123        self.valid = false;
124        self.cdf_valid = false;
125    }
126
127    fn ensure_pdf_inner(&mut self, want_cdf: bool) {
128        if self.valid {
129            if want_cdf && !self.cdf_valid {
130                build_cdf_from_pdf(&self.pdf, &mut self.cdf);
131                self.cdf_valid = true;
132            }
133            return;
134        }
135        let mut lower = [1.0 / 256.0; 256];
136        let max_order = self.order.min(self.history.len());
137        for ord in 0..=max_order {
138            let key = self.context_key(ord);
139            if let Some(ctx) = self.contexts[ord].get(&key) {
140                lower = interpolate_context(ctx, &lower);
141            }
142        }
143        self.pdf.copy_from_slice(&lower);
144        normalize_pdf_and_maybe_cdf(
145            &mut self.pdf,
146            if want_cdf { Some(&mut self.cdf) } else { None },
147        );
148        self.valid = true;
149        self.cdf_valid = want_cdf;
150    }
151
152    fn prune(&mut self) {
153        let mut total_contexts: usize = self.contexts.iter().map(|m| m.len()).sum();
154        while total_contexts > self.max_contexts {
155            let Some((ord, key)) = self.queue.pop_front() else {
156                break;
157            };
158            if self.contexts[ord].remove(&key).is_some() {
159                total_contexts -= 1;
160            }
161        }
162    }
163
164    fn context_key(&self, ord: usize) -> u64 {
165        if ord == 0 {
166            return 0;
167        }
168        let start = self.history.len() - ord;
169        hash_bytes(&self.history[start..])
170    }
171}
172
173fn interpolate_context(ctx: &ContextStats, lower: &[f64; 256]) -> [f64; 256] {
174    let distinct = ctx.counts.len() as f64;
175    let denom = (ctx.total as f64) + distinct + 1.0;
176    let escape = (distinct + 1.0) / denom;
177    let mut out = [0.0; 256];
178    for i in 0..256 {
179        out[i] = lower[i] * escape;
180    }
181    for &(symbol, count) in &ctx.counts {
182        out[symbol as usize] += (count as f64) / denom;
183    }
184    out
185}
186
187fn normalize_pdf_and_maybe_cdf(pdf: &mut [f64; 256], mut cdf: Option<&mut [f64; 257]>) {
188    let mut sum = 0.0;
189    for p in pdf.iter_mut() {
190        *p = if p.is_finite() {
191            (*p).max(PDF_MIN)
192        } else {
193            PDF_MIN
194        };
195        sum += *p;
196    }
197    if !(sum.is_finite()) || sum <= 0.0 {
198        let u = 1.0 / 256.0;
199        pdf.fill(u);
200        if let Some(cdf) = cdf.as_deref_mut() {
201            *cdf = uniform_cdf();
202        }
203        return;
204    }
205    let inv = 1.0 / sum;
206    if let Some(cdf) = cdf.as_deref_mut() {
207        cdf[0] = 0.0;
208        let mut acc = 0.0;
209        for i in 0..256 {
210            pdf[i] *= inv;
211            acc += pdf[i];
212            cdf[i + 1] = acc;
213        }
214    } else {
215        for p in pdf.iter_mut() {
216            *p *= inv;
217        }
218    }
219}
220
#[inline]
/// Cumulative distribution of the uniform byte PDF: cdf[i] = i / 256.
/// Endpoints are exactly 0.0 and 1.0 because 1/256 is a power of two.
fn uniform_cdf() -> [f64; 257] {
    std::array::from_fn(|i| i as f64 * (1.0 / 256.0))
}
230
#[inline]
/// Write the prefix sums of `pdf` into `cdf`, with cdf[0] = 0.0 and
/// cdf[i + 1] = sum of pdf[0..=i].
fn build_cdf_from_pdf(pdf: &[f64; 256], cdf: &mut [f64; 257]) {
    cdf[0] = 0.0;
    let mut running = 0.0;
    for (slot, &p) in cdf[1..].iter_mut().zip(pdf.iter()) {
        running += p;
        *slot = running;
    }
}
240
/// FNV-1a 64-bit hash of `bytes`.
///
/// Uses the standard offset basis 0xCBF2_9CE4_8422_2325 and the FNV-64 prime
/// 2^40 + 2^8 + 0xB3 = 0x100_0000_01B3. The previous multiplier
/// (0x1000_0000_01B3) contained an extra zero digit and was not the FNV
/// prime, so the function did not match the FNV-1a specification.
fn hash_bytes(bytes: &[u8]) -> u64 {
    let mut h = 0xCBF2_9CE4_8422_2325u64;
    for &b in bytes {
        h ^= u64::from(b);
        h = h.wrapping_mul(0x0000_0100_0000_01B3);
    }
    h
}