infotheory/backends/
zpaq_rate.rs1#[cfg(feature = "backend-zpaq")]
7use std::f64::consts::LN_2;
8
/// Probability floor substituted when `ZpaqRateModel::new` receives a
/// non-finite or non-positive `min_prob`; the value is exactly 2^-24.
#[cfg(feature = "backend-zpaq")]
const DEFAULT_MIN_PROB: f64 = 1.0 / 16_777_216.0;
11
12#[cfg(feature = "backend-zpaq")]
13mod imp {
14 use super::{DEFAULT_MIN_PROB, LN_2};
15 use zpaq_rs::StreamingCompressor;
16
17 struct ZpaqStreaming {
18 compressor: StreamingCompressor,
19 last_bits: f64,
20 }
21
22 pub struct ZpaqRateModel {
24 stream: ZpaqStreaming,
25 history: Vec<u8>,
26 pending_symbol: Option<u8>,
27 pending_bits: f64,
28 min_prob: f64,
29 method: String,
30 }
31
32 impl ZpaqRateModel {
33 pub fn new(method: impl Into<String>, min_prob: f64) -> Self {
37 let method = method.into();
38 let min_prob = if min_prob.is_finite() && min_prob > 0.0 {
39 min_prob
40 } else {
41 DEFAULT_MIN_PROB
42 };
43
44 let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
45 panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
46 });
47
48 Self {
49 stream: ZpaqStreaming {
50 compressor,
51 last_bits: 0.0,
52 },
53 history: Vec::new(),
54 pending_symbol: None,
55 pending_bits: 0.0,
56 min_prob,
57 method,
58 }
59 }
60
61 pub fn reset(&mut self) {
63 let method = self.method.clone();
64 let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
65 panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
66 });
67 self.stream = ZpaqStreaming {
68 compressor,
69 last_bits: 0.0,
70 };
71 self.history.clear();
72 self.pending_symbol = None;
73 self.pending_bits = 0.0;
74 }
75
76 fn rebuild_stream_from_history(&mut self) {
77 let method = self.method.clone();
78 let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
79 panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
80 });
81 self.stream = ZpaqStreaming {
82 compressor,
83 last_bits: 0.0,
84 };
85 let history = self.history.clone();
86 for b in history {
87 let _ = self.encode_bits(b);
88 }
89 self.pending_symbol = None;
90 self.pending_bits = 0.0;
91 }
92
93 fn log_prob_from_history(&self, symbol: u8) -> f64 {
94 let mut compressor =
95 StreamingCompressor::new(self.method.as_str()).expect("zpaq streaming new failed");
96 for &b in &self.history {
97 compressor
98 .push(b)
99 .expect("zpaq streaming compression failed");
100 }
101 let before = compressor.bits();
102 compressor
103 .push(symbol)
104 .expect("zpaq streaming compression failed");
105 let bits = (compressor.bits() - before).max(0.0);
106 let logp = -(bits * LN_2);
107 logp.max(self.min_prob.ln())
108 }
109
110 fn encode_bits(&mut self, symbol: u8) -> f64 {
111 let before = self.stream.last_bits;
112 self.stream
113 .compressor
114 .push(symbol)
115 .expect("zpaq streaming compression failed");
116 let after = self.stream.compressor.bits();
117 self.stream.last_bits = after;
118 (after - before).max(0.0)
119 }
120
121 pub fn log_prob(&mut self, symbol: u8) -> f64 {
125 if let Some(pending) = self.pending_symbol {
126 if pending == symbol {
127 let logp = -(self.pending_bits * LN_2);
128 return logp.max(self.min_prob.ln());
129 }
130 self.rebuild_stream_from_history();
132 }
133
134 let bits = self.encode_bits(symbol);
135 self.pending_symbol = Some(symbol);
136 self.pending_bits = bits;
137 let logp = -(bits * LN_2);
138 logp.max(self.min_prob.ln())
139 }
140
141 pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
143 self.rebuild_stream_from_history();
145 for (sym, slot) in out.iter_mut().enumerate() {
146 *slot = self.log_prob_from_history(sym as u8);
147 }
148 }
149
150 pub fn update(&mut self, symbol: u8) {
152 if let Some(pending) = self.pending_symbol
153 && pending == symbol
154 {
155 self.pending_symbol = None;
156 self.history.push(symbol);
157 return;
158 }
159 if self.pending_symbol.is_some() {
160 self.rebuild_stream_from_history();
161 }
162 let _ = self.encode_bits(symbol);
163 self.pending_symbol = None;
164 self.pending_bits = 0.0;
165 self.history.push(symbol);
166 }
167
168 pub fn update_and_score(&mut self, data: &[u8]) -> f64 {
170 if data.is_empty() {
171 return 0.0;
172 }
173 self.pending_symbol = None;
174 self.pending_bits = 0.0;
175 let mut bits = 0.0;
176 for &b in data {
177 bits += self.encode_bits(b);
178 self.history.push(b);
179 }
180 bits
181 }
182 }
183
184 impl Clone for ZpaqRateModel {
185 fn clone(&self) -> Self {
186 let mut cloned = Self::new(self.method.clone(), self.min_prob);
187 if !self.history.is_empty() {
188 let _ = cloned.update_and_score(&self.history);
189 }
190 if let Some(symbol) = self.pending_symbol {
193 let bits = cloned.encode_bits(symbol);
194 cloned.pending_symbol = Some(symbol);
195 cloned.pending_bits = bits;
196 } else {
197 cloned.pending_symbol = None;
198 cloned.pending_bits = 0.0;
199 }
200 cloned
201 }
202 }
203
204 pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
206 StreamingCompressor::new(method)
207 .map(|_| ())
208 .map_err(|e| e.to_string())
209 }
210
211 #[cfg(test)]
212 mod tests {
213 use super::*;
214
215 #[test]
216 fn zpaq_log_prob_update_matches_update_and_score() {
217 let data = b"the quick brown fox jumps over the lazy dog";
218 let mut model_a = ZpaqRateModel::new("1", 1e-9);
219 let mut bits_a = 0.0;
220 for &b in data {
221 let logp = model_a.log_prob(b);
222 bits_a += -logp / LN_2;
223 model_a.update(b);
224 }
225
226 let mut model_b = ZpaqRateModel::new("1", 1e-9);
227 let bits_b = model_b.update_and_score(data);
228
229 let diff = (bits_a - bits_b).abs();
230 assert!(diff < 1e-6, "bits mismatch: {bits_a} vs {bits_b}");
231 }
232
233 #[test]
234 fn zpaq_fill_log_probs_is_non_mutating() {
235 let history = b"zpaq fill non mutating";
236 let mut model_a = ZpaqRateModel::new("1", 1e-9);
237 let mut model_b = ZpaqRateModel::new("1", 1e-9);
238 for &b in history {
239 model_a.update(b);
240 model_b.update(b);
241 }
242
243 let mut row = [0.0f64; 256];
244 model_b.fill_log_probs(&mut row);
245
246 let sym = b'x';
247 let lp_a = model_a.log_prob(sym);
248 let lp_b = model_b.log_prob(sym);
249 assert!((lp_a - lp_b).abs() < 1e-9, "lp_a={lp_a} lp_b={lp_b}");
250 assert!((row[sym as usize] - lp_a).abs() < 1e-9);
251
252 model_a.update(sym);
253 model_b.update(sym);
254 let next_sym = b'y';
255 let lp_a2 = model_a.log_prob(next_sym);
256 let lp_b2 = model_b.log_prob(next_sym);
257 assert!((lp_a2 - lp_b2).abs() < 1e-9, "lp_a2={lp_a2} lp_b2={lp_b2}");
258 }
259
260 #[test]
261 fn zpaq_clone_preserves_pending_prediction_state() {
262 let mut model_a = ZpaqRateModel::new("1", 1e-9);
263 for &b in b"clone preserves pending state" {
264 model_a.update(b);
265 }
266
267 let probe = b'x';
268 let lp_a = model_a.log_prob(probe);
269 let mut model_b = model_a.clone();
270 let lp_b = model_b.log_prob(probe);
271 assert!((lp_a - lp_b).abs() < 1e-9, "lp_a={lp_a} lp_b={lp_b}");
272
273 model_a.update(probe);
274 model_b.update(probe);
275 let next = b'y';
276 let lp_a2 = model_a.log_prob(next);
277 let lp_b2 = model_b.log_prob(next);
278 assert!((lp_a2 - lp_b2).abs() < 1e-9, "lp_a2={lp_a2} lp_b2={lp_b2}");
279 }
280 }
281}
282
/// Fallback implementation compiled when the `backend-zpaq` feature is disabled.
/// It keeps no state and assigns every byte the same floor log-probability.
#[cfg(not(feature = "backend-zpaq"))]
mod imp {
    /// Stateless stand-in for the ZPAQ-backed rate model.
    #[derive(Clone)]
    pub struct ZpaqRateModel {
        /// Natural log of the probability floor; returned for every query.
        min_log_prob: f64,
    }

    impl ZpaqRateModel {
        /// Creates the stub. `_method` is ignored. A non-finite or non-positive
        /// `min_prob` falls back to `1e-12`.
        pub fn new(_method: impl Into<String>, min_prob: f64) -> Self {
            let floor = Some(min_prob)
                .filter(|p| p.is_finite() && *p > 0.0)
                .unwrap_or(1e-12);
            Self {
                min_log_prob: floor.ln(),
            }
        }

        /// No-op: there is no state to clear.
        pub fn reset(&mut self) {}

        /// Returns the floor log-probability regardless of `_symbol`.
        pub fn log_prob(&mut self, _symbol: u8) -> f64 {
            self.min_log_prob
        }

        /// Sets every entry of `out` to the floor log-probability.
        pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
            for slot in out.iter_mut() {
                *slot = self.min_log_prob;
            }
        }

        /// No-op: the stub does not track history.
        pub fn update(&mut self, _symbol: u8) {}

        /// Charges every byte of `data` `-ln(min_prob) / ln(2)` bits.
        pub fn update_and_score(&mut self, data: &[u8]) -> f64 {
            (data.len() as f64) * (-self.min_log_prob / std::f64::consts::LN_2)
        }
    }

    /// Always fails: no ZPAQ method is usable without the `backend-zpaq` feature.
    pub fn validate_zpaq_rate_method(_method: &str) -> Result<(), String> {
        Err(String::from("zpaq backend disabled at compile time"))
    }
}
324
325pub use imp::ZpaqRateModel;
327pub use imp::validate_zpaq_rate_method;