infotheory/backends/
ppmd.rs1use ahash::AHashMap;
2use std::collections::VecDeque;
3
// Floor applied to pdf entries during sanitization/normalization; aliases the
// crate-wide default so this backend agrees with the mixture layer.
const PDF_MIN: f64 = crate::mixture::DEFAULT_MIN_PROB;
5
/// Frequency statistics for a single PPM context.
#[derive(Clone, Debug, Default)]
struct ContextStats {
    // (symbol, count) pairs; a linear scan is fine since there are at most
    // 256 distinct byte symbols per context.
    counts: Vec<(u8, u16)>,
    // Running sum of all counts, maintained so callers never re-sum.
    total: u32,
}

impl ContextStats {
    /// Record one occurrence of `symbol`, rescaling once the total grows
    /// large so counters cannot saturate and old statistics decay.
    fn observe(&mut self, symbol: u8) {
        match self.counts.iter_mut().find(|(s, _)| *s == symbol) {
            Some((_, count)) => *count = count.saturating_add(1),
            None => self.counts.push((symbol, 1)),
        }
        self.total = self.total.saturating_add(1);
        if self.total > 4096 {
            self.rescale();
        }
    }

    /// Halve every count (rounding up) and recompute `total`.
    ///
    /// `div_ceil(2)` maps any count >= 1 to >= 1, so no symbol is ever
    /// forgotten by a rescale. (The old code used `retain_mut` that always
    /// returned `true` purely to get mutable iteration, plus a redundant
    /// `.max(1)`; a plain `iter_mut` loop says what it means.)
    fn rescale(&mut self) {
        self.total = 0;
        for (_, count) in self.counts.iter_mut() {
            *count = (*count).div_ceil(2);
            self.total += *count as u32;
        }
    }
}
34
/// PPM-style adaptive byte model blending contexts of order 0..=`order`.
#[derive(Clone, Debug)]
pub struct PpmdModel {
    // Maximum context order (number of preceding bytes used as context).
    order: usize,
    // Cap on the total number of stored contexts across all orders.
    max_contexts: usize,
    // One map per order 0..=order, keyed by a hash of the context bytes.
    contexts: Vec<AHashMap<u64, ContextStats>>,
    // Insertion-ordered (order, key) pairs; pruning evicts from the front.
    queue: VecDeque<(usize, u64)>,
    // Observed symbol history; context keys hash its trailing bytes.
    history: Vec<u8>,
    // Cached predictive distribution over the next byte.
    pdf: [f64; 256],
    // Cached cumulative distribution; cdf[0] = 0.0.
    cdf: [f64; 257],
    // True when `pdf` reflects the current history.
    valid: bool,
    // True when `cdf` was rebuilt from the current `pdf`.
    cdf_valid: bool,
}
48
49impl PpmdModel {
50 pub fn new(order: usize, memory_mb: usize) -> Self {
52 let order = order.max(1);
53 let max_contexts = (memory_mb.max(1) * 1024 * 1024) / 96;
54 Self {
55 order,
56 max_contexts: max_contexts.max(1024),
57 contexts: (0..=order).map(|_| AHashMap::new()).collect(),
58 queue: VecDeque::new(),
59 history: Vec::new(),
60 pdf: [1.0 / 256.0; 256],
61 cdf: uniform_cdf(),
62 valid: false,
63 cdf_valid: false,
64 }
65 }
66
67 pub fn fill_pdf(&mut self, out: &mut [f64; 256]) {
69 self.ensure_pdf_inner(false);
70 out.copy_from_slice(&self.pdf);
71 }
72
73 pub fn pdf(&mut self) -> &[f64; 256] {
75 self.ensure_pdf_inner(false);
76 &self.pdf
77 }
78
79 pub fn cdf(&mut self) -> &[f64; 257] {
81 self.ensure_pdf_inner(true);
82 &self.cdf
83 }
84
85 pub fn log_prob(&mut self, symbol: u8, min_prob: f64) -> f64 {
87 self.ensure_pdf_inner(false);
88 self.pdf[symbol as usize].max(min_prob).ln()
89 }
90
91 pub fn update(&mut self, symbol: u8) {
93 let max_order = self.order.min(self.history.len());
94 for ord in 0..=max_order {
95 let key = self.context_key(ord);
96 let map = &mut self.contexts[ord];
97 if !map.contains_key(&key) {
98 map.insert(key, ContextStats::default());
99 self.queue.push_back((ord, key));
100 }
101 if let Some(ctx) = map.get_mut(&key) {
102 ctx.observe(symbol);
103 }
104 }
105 self.prune();
106 self.history.push(symbol);
107 self.valid = false;
108 self.cdf_valid = false;
109 }
110
111 pub fn reset_history(&mut self) {
113 self.history.clear();
114 self.valid = false;
115 self.cdf_valid = false;
116 self.pdf.fill(1.0 / 256.0);
117 self.cdf = uniform_cdf();
118 }
119
120 pub fn update_history_only(&mut self, symbol: u8) {
122 self.history.push(symbol);
123 self.valid = false;
124 self.cdf_valid = false;
125 }
126
127 fn ensure_pdf_inner(&mut self, want_cdf: bool) {
128 if self.valid {
129 if want_cdf && !self.cdf_valid {
130 build_cdf_from_pdf(&self.pdf, &mut self.cdf);
131 self.cdf_valid = true;
132 }
133 return;
134 }
135 let mut lower = [1.0 / 256.0; 256];
136 let max_order = self.order.min(self.history.len());
137 for ord in 0..=max_order {
138 let key = self.context_key(ord);
139 if let Some(ctx) = self.contexts[ord].get(&key) {
140 lower = interpolate_context(ctx, &lower);
141 }
142 }
143 self.pdf.copy_from_slice(&lower);
144 normalize_pdf_and_maybe_cdf(
145 &mut self.pdf,
146 if want_cdf { Some(&mut self.cdf) } else { None },
147 );
148 self.valid = true;
149 self.cdf_valid = want_cdf;
150 }
151
152 fn prune(&mut self) {
153 let mut total_contexts: usize = self.contexts.iter().map(|m| m.len()).sum();
154 while total_contexts > self.max_contexts {
155 let Some((ord, key)) = self.queue.pop_front() else {
156 break;
157 };
158 if self.contexts[ord].remove(&key).is_some() {
159 total_contexts -= 1;
160 }
161 }
162 }
163
164 fn context_key(&self, ord: usize) -> u64 {
165 if ord == 0 {
166 return 0;
167 }
168 let start = self.history.len() - ord;
169 hash_bytes(&self.history[start..])
170 }
171}
172
173fn interpolate_context(ctx: &ContextStats, lower: &[f64; 256]) -> [f64; 256] {
174 let distinct = ctx.counts.len() as f64;
175 let denom = (ctx.total as f64) + distinct + 1.0;
176 let escape = (distinct + 1.0) / denom;
177 let mut out = [0.0; 256];
178 for i in 0..256 {
179 out[i] = lower[i] * escape;
180 }
181 for &(symbol, count) in &ctx.counts {
182 out[symbol as usize] += (count as f64) / denom;
183 }
184 out
185}
186
187fn normalize_pdf_and_maybe_cdf(pdf: &mut [f64; 256], mut cdf: Option<&mut [f64; 257]>) {
188 let mut sum = 0.0;
189 for p in pdf.iter_mut() {
190 *p = if p.is_finite() {
191 (*p).max(PDF_MIN)
192 } else {
193 PDF_MIN
194 };
195 sum += *p;
196 }
197 if !(sum.is_finite()) || sum <= 0.0 {
198 let u = 1.0 / 256.0;
199 pdf.fill(u);
200 if let Some(cdf) = cdf.as_deref_mut() {
201 *cdf = uniform_cdf();
202 }
203 return;
204 }
205 let inv = 1.0 / sum;
206 if let Some(cdf) = cdf.as_deref_mut() {
207 cdf[0] = 0.0;
208 let mut acc = 0.0;
209 for i in 0..256 {
210 pdf[i] *= inv;
211 acc += pdf[i];
212 cdf[i + 1] = acc;
213 }
214 } else {
215 for p in pdf.iter_mut() {
216 *p *= inv;
217 }
218 }
219}
220
/// CDF of the uniform byte distribution: `cdf[i] = i / 256`, so `cdf[0]` is
/// 0.0 and `cdf[256]` is exactly 1.0 (256 is a power of two, so every entry
/// is exact).
#[inline]
fn uniform_cdf() -> [f64; 257] {
    std::array::from_fn(|i| i as f64 / 256.0)
}
230
/// Rebuild `cdf` as the running prefix sum of `pdf` with `cdf[0] == 0.0`.
/// No normalization happens here; the caller owns `pdf`'s scaling.
#[inline]
fn build_cdf_from_pdf(pdf: &[f64; 256], cdf: &mut [f64; 257]) {
    cdf[0] = 0.0;
    let mut acc = 0.0;
    for (slot, &p) in cdf[1..].iter_mut().zip(pdf.iter()) {
        acc += p;
        *slot = acc;
    }
}
240
/// 64-bit FNV-1a hash of `bytes`.
fn hash_bytes(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    // Bug fix: the previous multiplier 0x1000_0000_01B3 had one zero too
    // many. The genuine 64-bit FNV prime is 2^40 + 2^8 + 0xB3 =
    // 0x100_0000_01B3, which is what gives FNV-1a its mixing quality.
    const FNV_PRIME: u64 = 0x100_0000_01B3;
    let mut h = FNV_OFFSET_BASIS;
    for &b in bytes {
        h ^= u64::from(b);
        h = h.wrapping_mul(FNV_PRIME);
    }
    h
}