infotheory/backends/
zpaq_rate.rs

//! ZPAQ-backed sequential rate model.
//!
//! This backend estimates `log p(x_t | x_{<t})` by measuring how much a
//! streaming ZPAQ compression grows as each symbol is appended to it.
5
6#[cfg(feature = "backend-zpaq")]
7use std::f64::consts::LN_2;
8
9#[cfg(feature = "backend-zpaq")]
/// Probability floor substituted when the caller supplies a non-finite or
/// non-positive `min_prob`. The value is exactly 2^-24 (about 5.96e-8).
const DEFAULT_MIN_PROB: f64 = 1.0 / 16_777_216.0;
11
12#[cfg(feature = "backend-zpaq")]
13mod imp {
14    use super::{DEFAULT_MIN_PROB, LN_2};
15    use zpaq_rs::StreamingCompressor;
16
17    struct ZpaqStreaming {
18        compressor: StreamingCompressor,
19        last_bits: f64,
20    }
21
22    /// Stateful ZPAQ-backed estimator of sequential symbol log-probabilities.
23    pub struct ZpaqRateModel {
24        stream: ZpaqStreaming,
25        history: Vec<u8>,
26        pending_symbol: Option<u8>,
27        pending_bits: f64,
28        min_prob: f64,
29        method: String,
30    }
31
32    impl ZpaqRateModel {
33        /// Create a new model with the provided streamable ZPAQ `method`.
34        ///
35        /// `min_prob` clamps very small probabilities for numerical stability.
36        pub fn new(method: impl Into<String>, min_prob: f64) -> Self {
37            let method = method.into();
38            let min_prob = if min_prob.is_finite() && min_prob > 0.0 {
39                min_prob
40            } else {
41                DEFAULT_MIN_PROB
42            };
43
44            let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
45                panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
46            });
47
48            Self {
49                stream: ZpaqStreaming {
50                    compressor,
51                    last_bits: 0.0,
52                },
53                history: Vec::new(),
54                pending_symbol: None,
55                pending_bits: 0.0,
56                min_prob,
57                method,
58            }
59        }
60
61        /// Reset model state and clear any pending prediction cache.
62        pub fn reset(&mut self) {
63            let method = self.method.clone();
64            let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
65                panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
66            });
67            self.stream = ZpaqStreaming {
68                compressor,
69                last_bits: 0.0,
70            };
71            self.history.clear();
72            self.pending_symbol = None;
73            self.pending_bits = 0.0;
74        }
75
76        fn rebuild_stream_from_history(&mut self) {
77            let method = self.method.clone();
78            let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
79                panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
80            });
81            self.stream = ZpaqStreaming {
82                compressor,
83                last_bits: 0.0,
84            };
85            let history = self.history.clone();
86            for b in history {
87                let _ = self.encode_bits(b);
88            }
89            self.pending_symbol = None;
90            self.pending_bits = 0.0;
91        }
92
93        fn log_prob_from_history(&self, symbol: u8) -> f64 {
94            let mut compressor =
95                StreamingCompressor::new(self.method.as_str()).expect("zpaq streaming new failed");
96            for &b in &self.history {
97                compressor
98                    .push(b)
99                    .expect("zpaq streaming compression failed");
100            }
101            let before = compressor.bits();
102            compressor
103                .push(symbol)
104                .expect("zpaq streaming compression failed");
105            let bits = (compressor.bits() - before).max(0.0);
106            let logp = -(bits * LN_2);
107            logp.max(self.min_prob.ln())
108        }
109
110        fn encode_bits(&mut self, symbol: u8) -> f64 {
111            let before = self.stream.last_bits;
112            self.stream
113                .compressor
114                .push(symbol)
115                .expect("zpaq streaming compression failed");
116            let after = self.stream.compressor.bits();
117            self.stream.last_bits = after;
118            (after - before).max(0.0)
119        }
120
121        /// Return `ln p(symbol | history)` under the current model state.
122        ///
123        /// This may cache the encoded-bit result for a matching immediate `update`.
124        pub fn log_prob(&mut self, symbol: u8) -> f64 {
125            if let Some(pending) = self.pending_symbol {
126                if pending == symbol {
127                    let logp = -(self.pending_bits * LN_2);
128                    return logp.max(self.min_prob.ln());
129                }
130                // We cannot rollback `StreamingCompressor`; rebuild to committed history.
131                self.rebuild_stream_from_history();
132            }
133
134            let bits = self.encode_bits(symbol);
135            self.pending_symbol = Some(symbol);
136            self.pending_bits = bits;
137            let logp = -(bits * LN_2);
138            logp.max(self.min_prob.ln())
139        }
140
141        /// Fill 256-way log-probabilities for the current committed history without mutation.
142        pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
143            // Treat fill as a read-only query of committed history.
144            self.rebuild_stream_from_history();
145            for (sym, slot) in out.iter_mut().enumerate() {
146                *slot = self.log_prob_from_history(sym as u8);
147            }
148        }
149
150        /// Advance model state with one observed symbol.
151        pub fn update(&mut self, symbol: u8) {
152            if let Some(pending) = self.pending_symbol
153                && pending == symbol
154            {
155                self.pending_symbol = None;
156                self.history.push(symbol);
157                return;
158            }
159            if self.pending_symbol.is_some() {
160                self.rebuild_stream_from_history();
161            }
162            let _ = self.encode_bits(symbol);
163            self.pending_symbol = None;
164            self.pending_bits = 0.0;
165            self.history.push(symbol);
166        }
167
168        /// Score and consume an entire byte slice, returning total code length in bits.
169        pub fn update_and_score(&mut self, data: &[u8]) -> f64 {
170            if data.is_empty() {
171                return 0.0;
172            }
173            self.pending_symbol = None;
174            self.pending_bits = 0.0;
175            let mut bits = 0.0;
176            for &b in data {
177                bits += self.encode_bits(b);
178                self.history.push(b);
179            }
180            bits
181        }
182    }
183
184    impl Clone for ZpaqRateModel {
185        fn clone(&self) -> Self {
186            let mut cloned = Self::new(self.method.clone(), self.min_prob);
187            if !self.history.is_empty() {
188                let _ = cloned.update_and_score(&self.history);
189            }
190            // Preserve speculative pending state so clone() is state-equivalent
191            // even when called between log_prob() and update().
192            if let Some(symbol) = self.pending_symbol {
193                let bits = cloned.encode_bits(symbol);
194                cloned.pending_symbol = Some(symbol);
195                cloned.pending_bits = bits;
196            } else {
197                cloned.pending_symbol = None;
198                cloned.pending_bits = 0.0;
199            }
200            cloned
201        }
202    }
203
204    /// Validate that `method` is streamable and accepted by the ZPAQ backend.
205    pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
206        StreamingCompressor::new(method)
207            .map(|_| ())
208            .map_err(|e| e.to_string())
209    }
210
211    #[cfg(test)]
212    mod tests {
213        use super::*;
214
215        #[test]
216        fn zpaq_log_prob_update_matches_update_and_score() {
217            let data = b"the quick brown fox jumps over the lazy dog";
218            let mut model_a = ZpaqRateModel::new("1", 1e-9);
219            let mut bits_a = 0.0;
220            for &b in data {
221                let logp = model_a.log_prob(b);
222                bits_a += -logp / LN_2;
223                model_a.update(b);
224            }
225
226            let mut model_b = ZpaqRateModel::new("1", 1e-9);
227            let bits_b = model_b.update_and_score(data);
228
229            let diff = (bits_a - bits_b).abs();
230            assert!(diff < 1e-6, "bits mismatch: {bits_a} vs {bits_b}");
231        }
232
233        #[test]
234        fn zpaq_fill_log_probs_is_non_mutating() {
235            let history = b"zpaq fill non mutating";
236            let mut model_a = ZpaqRateModel::new("1", 1e-9);
237            let mut model_b = ZpaqRateModel::new("1", 1e-9);
238            for &b in history {
239                model_a.update(b);
240                model_b.update(b);
241            }
242
243            let mut row = [0.0f64; 256];
244            model_b.fill_log_probs(&mut row);
245
246            let sym = b'x';
247            let lp_a = model_a.log_prob(sym);
248            let lp_b = model_b.log_prob(sym);
249            assert!((lp_a - lp_b).abs() < 1e-9, "lp_a={lp_a} lp_b={lp_b}");
250            assert!((row[sym as usize] - lp_a).abs() < 1e-9);
251
252            model_a.update(sym);
253            model_b.update(sym);
254            let next_sym = b'y';
255            let lp_a2 = model_a.log_prob(next_sym);
256            let lp_b2 = model_b.log_prob(next_sym);
257            assert!((lp_a2 - lp_b2).abs() < 1e-9, "lp_a2={lp_a2} lp_b2={lp_b2}");
258        }
259
260        #[test]
261        fn zpaq_clone_preserves_pending_prediction_state() {
262            let mut model_a = ZpaqRateModel::new("1", 1e-9);
263            for &b in b"clone preserves pending state" {
264                model_a.update(b);
265            }
266
267            let probe = b'x';
268            let lp_a = model_a.log_prob(probe);
269            let mut model_b = model_a.clone();
270            let lp_b = model_b.log_prob(probe);
271            assert!((lp_a - lp_b).abs() < 1e-9, "lp_a={lp_a} lp_b={lp_b}");
272
273            model_a.update(probe);
274            model_b.update(probe);
275            let next = b'y';
276            let lp_a2 = model_a.log_prob(next);
277            let lp_b2 = model_b.log_prob(next);
278            assert!((lp_a2 - lp_b2).abs() < 1e-9, "lp_a2={lp_a2} lp_b2={lp_b2}");
279        }
280    }
281}
282
#[cfg(not(feature = "backend-zpaq"))]
mod imp {
    /// Placeholder model compiled when the `backend-zpaq` feature is off.
    ///
    /// It keeps no history: every symbol is reported at the configured
    /// probability floor.
    #[derive(Clone)]
    pub struct ZpaqRateModel {
        /// Natural log of the probability floor chosen in `new`.
        min_log_prob: f64,
    }

    impl ZpaqRateModel {
        /// Create the placeholder; `_method` is ignored.
        ///
        /// A non-finite or non-positive `min_prob` is replaced by `1e-12`
        /// (note: the feature-enabled backend falls back to 2^-24 instead).
        pub fn new(_method: impl Into<String>, min_prob: f64) -> Self {
            let usable = min_prob.is_finite() && min_prob > 0.0;
            let floor = if usable { min_prob } else { 1e-12 };
            Self {
                min_log_prob: floor.ln(),
            }
        }

        /// No-op: there is no state to reset.
        pub fn reset(&mut self) {}

        /// Return the floor log-probability regardless of `_symbol`.
        pub fn log_prob(&mut self, _symbol: u8) -> f64 {
            self.min_log_prob
        }

        /// Write the floor log-probability into all 256 slots.
        pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
            for slot in out.iter_mut() {
                *slot = self.min_log_prob;
            }
        }

        /// No-op: observed symbols are not recorded.
        pub fn update(&mut self, _symbol: u8) {}

        /// Return `data.len()` times the floor code length, in bits.
        pub fn update_and_score(&mut self, data: &[u8]) -> f64 {
            let symbol_count = data.len() as f64;
            let bits_per_symbol = -self.min_log_prob / std::f64::consts::LN_2;
            bits_per_symbol * symbol_count
        }
    }

    /// Always fails: no ZPAQ method can be validated without the backend.
    pub fn validate_zpaq_rate_method(_method: &str) -> Result<(), String> {
        Err(String::from("zpaq backend disabled at compile time"))
    }
}
324
325/// Stateful ZPAQ-based rate estimator.
326pub use imp::ZpaqRateModel;
327/// Validate that a ZPAQ method string is streamable and usable for rate modeling.
328pub use imp::validate_zpaq_rate_method;