infotheory/compression/
mod.rs

1//! Rate-coded compression helpers (AC/rANS) with optional framing.
2//!
3//! The functions in this module implement lossless byte compression by combining:
4//! - a predictive rate model (`RateBackend`) that emits per-symbol PDFs,
5//! - an entropy coder (`AC` or `rANS`),
6//! - optional framing metadata for robust decompression.
7
8use anyhow::{Result, bail};
9
10use crate::backends::calibration::CalibratorCore;
11use crate::backends::match_model::MatchModel;
12use crate::backends::ppmd::PpmdModel;
13use crate::backends::sparse_match::SparseMatchModel;
14use crate::backends::text_context::TextContextAnalyzer;
15use crate::coders::{
16    ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
17    CDF_TOTAL, Cdf, CoderType, crc32, quantize_pdf_to_rans_cdf_with_buffer,
18};
19use crate::ctw::FacContextTree;
20#[cfg(feature = "backend-mamba")]
21use crate::mambazip;
22use crate::mixture::DEFAULT_MIN_PROB;
23use crate::neural_mix::NeuralMixCore;
24use crate::rosaplus::RosaPlus;
25#[cfg(feature = "backend-rwkv")]
26use crate::rwkvzip;
27use crate::zpaq_rate::ZpaqRateModel;
28use crate::{CalibratedSpec, MixtureKind, MixtureSpec, RateBackend};
29
/// Magic tag for framed payloads; its little-endian bytes spell "FITC".
const FRAMED_MAGIC: u32 = 0x4354_4946; // "FITC"
/// Version byte written into framed headers.
const FRAMED_VERSION: u8 = 1;
/// Floor applied to per-symbol probabilities during PDF normalization.
const PDF_MIN: f64 = DEFAULT_MIN_PROB;
33
34#[inline]
35fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
36    CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
37}
38
/// Wire format mode for rate-coded payloads.
///
/// `Framed` (the default) prepends an integrity header so decompression can
/// validate the payload; `Raw` emits only the coder bytes.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FramingMode {
    /// Emit only coder payload bytes (no integrity/length header).
    Raw,
    /// Emit framed payload with magic/version/length/checksum header.
    #[default]
    Framed,
}
48
/// In-memory form of the framed wire header; see `FramedHeader::SIZE` for
/// the serialized little-endian layout.
#[derive(Clone, Copy, Debug)]
struct FramedHeader {
    magic: u32,        // FRAMED_MAGIC tag
    version: u8,       // wire-format version (FRAMED_VERSION)
    coder: u8,         // entropy-coder tag: 0 = AC, 1 = rANS
    original_len: u64, // uncompressed payload length in bytes
    crc32: u32,        // CRC32 carried in the frame (presumably over the original bytes — confirm at call sites)
}
57
58impl FramedHeader {
59    const SIZE: usize = 4 + 1 + 1 + 8 + 4;
60
61    fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
62        Self {
63            magic: FRAMED_MAGIC,
64            version: FRAMED_VERSION,
65            coder: match coder {
66                CoderType::AC => 0,
67                CoderType::RANS => 1,
68            },
69            original_len,
70            crc32,
71        }
72    }
73
74    fn write(&self, out: &mut Vec<u8>) {
75        out.extend_from_slice(&self.magic.to_le_bytes());
76        out.push(self.version);
77        out.push(self.coder);
78        out.extend_from_slice(&self.original_len.to_le_bytes());
79        out.extend_from_slice(&self.crc32.to_le_bytes());
80    }
81
82    fn read(input: &[u8]) -> Result<Self> {
83        if input.len() < Self::SIZE {
84            bail!("framed payload too short");
85        }
86        let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
87        if magic != FRAMED_MAGIC {
88            bail!("invalid framed magic: expected 0x{FRAMED_MAGIC:08X}, got 0x{magic:08X}");
89        }
90        let version = input[4];
91        if version != FRAMED_VERSION {
92            bail!("unsupported framed version: {version}");
93        }
94        let coder = input[5];
95        let original_len = u64::from_le_bytes([
96            input[6], input[7], input[8], input[9], input[10], input[11], input[12], input[13],
97        ]);
98        let crc32 = u32::from_le_bytes([input[14], input[15], input[16], input[17]]);
99        Ok(Self {
100            magic,
101            version,
102            coder,
103            original_len,
104            crc32,
105        })
106    }
107
108    fn coder_type(&self) -> CoderType {
109        match self.coder {
110            0 => CoderType::AC,
111            _ => CoderType::RANS,
112        }
113    }
114}
115
/// CTW/FAC context-tree predictor producing per-byte PDFs.
#[derive(Clone)]
struct CtwPredictor {
    tree: FacContextTree,    // underlying context tree
    bits_per_symbol: usize,  // bits consumed per symbol (constructors use 1..=8)
    msb_first: bool,         // bit consumption order (true for byte-level CTW)
    pdf: Vec<f64>,           // cached 256-entry next-byte PDF
    pattern_logps: Vec<f64>, // scratch: conditional log-prob per bit pattern
    valid: bool,             // whether `pdf` matches the current tree state
}
125
impl CtwPredictor {
    /// Classic byte-level CTW: 8 bits per symbol, consumed MSB-first.
    fn new_ctw(depth: usize) -> Self {
        Self {
            tree: FacContextTree::new(depth, 8),
            bits_per_symbol: 8,
            msb_first: true,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// FAC variant: `bits_per_symbol` bits per symbol, consumed LSB-first.
    fn new_fac(base_depth: usize, bits_per_symbol: usize) -> Self {
        Self {
            tree: FacContextTree::new(base_depth, bits_per_symbol),
            bits_per_symbol,
            msb_first: false,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// Computes the conditional log-probability of every possible next bit
    /// pattern and stores it in `self.pattern_logps[..2^bits]`.
    ///
    /// The recursion speculatively `update`s the tree down both bit branches
    /// at each depth and `revert`s on the way back, so the tree state is
    /// unchanged on return. Each leaf records the change in the tree's log
    /// block probability after inserting the full pattern. Returns the number
    /// of patterns filled (`2^bits`).
    fn fill_pattern_log_probs(&mut self) -> usize {
        fn rec(
            tree: &mut FacContextTree,
            bits: usize,
            msb_first: bool,
            depth: usize,
            pattern: usize,
            log_before: f64,
            out: &mut [f64],
        ) {
            if depth == bits {
                // Leaf: log P(pattern) = log P(block after) - log P(block before).
                out[pattern] = tree.get_log_block_probability() - log_before;
                return;
            }
            for bit in [false, true] {
                tree.update(bit, depth);
                // Accumulate the pattern index MSB-first or LSB-first to
                // match the bit order used by `update`/`pdf_next`.
                let next_pattern = if msb_first {
                    (pattern << 1) | (bit as usize)
                } else {
                    pattern | ((bit as usize) << depth)
                };
                rec(
                    tree,
                    bits,
                    msb_first,
                    depth + 1,
                    next_pattern,
                    log_before,
                    out,
                );
                tree.revert(depth);
            }
        }

        let bits = self.bits_per_symbol.clamp(1, 8);
        let patterns = 1usize << bits;
        let log_before = self.tree.get_log_block_probability();
        self.pattern_logps[..patterns].fill(f64::NEG_INFINITY);
        rec(
            &mut self.tree,
            bits,
            self.msb_first,
            0,
            0,
            log_before,
            &mut self.pattern_logps[..patterns],
        );
        patterns
    }

    /// Test-only reference: scores one symbol by walking its bits through
    /// the tree and reverting, for cross-checking `fill_pattern_log_probs`.
    #[cfg(test)]
    fn log_prob_symbol_bruteforce(&mut self, symbol: u8) -> f64 {
        let bits = self.bits_per_symbol.clamp(1, 8);
        let before = self.tree.get_log_block_probability();
        if self.msb_first {
            for bit_idx in 0..bits {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            // Revert in reverse order so the tree returns to its prior state.
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        } else {
            for bit_idx in 0..bits {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        }
    }

    /// Clamps entries to at least `PDF_MIN` (non-finite values become the
    /// floor) and rescales the slice to sum to 1; falls back to a uniform
    /// distribution when the sum is non-positive or non-finite.
    fn normalize_pdf(pdf: &mut [f64]) {
        let mut sum = 0.0f64;
        for p in pdf.iter_mut() {
            let v = if p.is_finite() { *p } else { 0.0 };
            *p = v.max(PDF_MIN);
            sum += *p;
        }
        if sum <= 0.0 || !sum.is_finite() {
            let u = 1.0 / (pdf.len() as f64);
            for p in pdf.iter_mut() {
                *p = u;
            }
            return;
        }
        let inv = 1.0 / sum;
        for p in pdf.iter_mut() {
            *p *= inv;
        }
    }

    /// Returns the cached 256-entry next-byte PDF, rebuilding it from the
    /// tree when stale.
    ///
    /// When `bits_per_symbol < 8`, each bit pattern is shared by
    /// `2^(8 - bits)` byte values ("aliases"), so the pattern's probability
    /// mass is split evenly among them.
    fn pdf_next(&mut self) -> &[f64] {
        if !self.valid {
            let bits = self.bits_per_symbol.clamp(1, 8);
            let patterns = self.fill_pattern_log_probs();
            if bits == 8 {
                for sym in 0..256usize {
                    self.pdf[sym] = self.pattern_logps[sym].exp();
                }
            } else {
                let aliases = 1usize << (8 - bits);
                for byte in 0..256usize {
                    // MSB-first patterns live in the top bits of the byte;
                    // LSB-first patterns in the bottom bits.
                    let pat = if self.msb_first {
                        byte >> (8 - bits)
                    } else {
                        byte & (patterns - 1)
                    };
                    self.pdf[byte] = self.pattern_logps[pat].exp() / (aliases as f64);
                }
            }
            Self::normalize_pdf(&mut self.pdf);
            self.valid = true;
        }
        &self.pdf
    }

    /// Commits `symbol` to the tree (bit order per `msb_first`) and
    /// invalidates the cached PDF.
    ///
    /// NOTE(review): unlike `pdf_next`, this loop uses `bits_per_symbol`
    /// unclamped; the constructors keep it in 1..=8 so the shifts stay in
    /// range — confirm if new constructors are added.
    fn update(&mut self, symbol: u8) {
        if self.msb_first {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        } else {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        }
        self.valid = false;
    }

    /// True when the byte-level MSB-first fast path for bitwise arithmetic
    /// coding applies.
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        self.bits_per_symbol == 8 && self.msb_first
    }

    /// Probability that the next bit (MSB-first position `bit_idx`) is 1,
    /// clamped away from 0 and 1 for coder stability.
    #[inline]
    fn bit_prob_one_msb(&mut self, bit_idx: usize) -> f64 {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.predict_one(bit_idx).clamp(PDF_MIN, 1.0 - PDF_MIN)
    }

    /// Commits a single bit on the fast path and invalidates the cached PDF.
    #[inline]
    fn update_bit_msb(&mut self, bit_idx: usize, bit: bool) {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.update_predicted(bit, bit_idx);
        self.valid = false;
    }
}
304
/// RosaPlus-backed byte predictor with cached PDF and CDF rows.
#[derive(Clone)]
struct RosaPredictor {
    model: RosaPlus,  // underlying language model
    pdf: Vec<f64>,    // cached 256-entry next-byte PDF
    cdf: [f64; 257],  // cached CDF row derived from `pdf`
    valid: bool,      // whether `pdf` is current
    cdf_valid: bool,  // whether `cdf` matches the current `pdf`
}
313
314impl RosaPredictor {
315    fn new(max_order: i64) -> Self {
316        let mut model = RosaPlus::new(max_order, false, 0, 42);
317        model.build_lm_full_bytes_no_finalize_endpos();
318        Self {
319            model,
320            pdf: vec![0.0; 256],
321            cdf: uniform_cdf_row(),
322            valid: false,
323            cdf_valid: false,
324        }
325    }
326
327    fn pdf_next(&mut self) -> &[f64] {
328        self.ensure_pdf(false);
329        &self.pdf
330    }
331
332    fn cdf_next(&mut self) -> &[f64; 257] {
333        self.ensure_pdf(true);
334        &self.cdf
335    }
336
337    fn ensure_pdf(&mut self, want_cdf: bool) {
338        if self.valid {
339            if want_cdf && !self.cdf_valid {
340                build_cdf_row_from_pdf_slice(&self.pdf, &mut self.cdf);
341                self.cdf_valid = true;
342            }
343            return;
344        }
345        self.model.fill_probs_for_last_bytes(&mut self.pdf);
346        normalize_pdf_vec_and_maybe_build_cdf(
347            &mut self.pdf,
348            if want_cdf { Some(&mut self.cdf) } else { None },
349        );
350        self.valid = true;
351        self.cdf_valid = want_cdf;
352    }
353
354    fn update(&mut self, symbol: u8) {
355        self.model.train_byte(symbol);
356        self.valid = false;
357        self.cdf_valid = false;
358    }
359
360    fn begin_stream(&mut self, total_len: usize) {
361        self.model.reserve_for_stream(total_len);
362    }
363}
364
/// Mamba-model-backed predictor (requires the `backend-mamba` feature).
#[derive(Clone)]
#[cfg(feature = "backend-mamba")]
struct MambaPredictor {
    compressor: mambazip::Compressor, // model plus online-update state
    primed: bool,                     // whether the initial forward pass ran
    pdf: Vec<f64>,                    // vocab-sized next-symbol PDF
    cdf: [f64; 257],                  // CDF row over the first 256 PDF entries
    valid: bool,                      // whether `pdf` is current
    cdf_valid: bool,                  // whether `cdf` matches `pdf`
}
375
/// RWKV-model-backed predictor (requires the `backend-rwkv` feature).
/// The PDF lives in `compressor.pdf_buffer`; only the CDF row is cached here.
#[derive(Clone)]
#[cfg(feature = "backend-rwkv")]
struct RwkvPredictor {
    compressor: rwkvzip::Compressor, // model plus its internal PDF buffer
    primed: bool,                    // whether `reset_and_prime` has run
    cdf: [f64; 257],                 // CDF row over the first 256 PDF entries
    cdf_valid: bool,                 // whether `cdf` matches the current PDF
}
384
/// ZPAQ-rate-model predictor that rescored the full history for each
/// candidate symbol when a PDF is requested.
#[derive(Clone)]
struct ZpaqPredictor {
    method: String,   // ZPAQ method string used to build each rate model
    history: Vec<u8>, // all symbols observed so far
    pdf: Vec<f64>,    // cached 256-entry next-byte PDF
    valid: bool,      // whether `pdf` reflects the current history
}
392
impl ZpaqPredictor {
    /// Creates a predictor that scores candidate bytes with fresh
    /// `ZpaqRateModel` instances replayed over the accumulated history.
    fn new(method: String) -> Self {
        Self {
            method,
            history: Vec::new(),
            pdf: vec![0.0; 256],
            valid: false,
        }
    }

    /// Returns the cached next-byte PDF, rebuilding it when stale.
    ///
    /// NOTE(review): each rebuild constructs a new `ZpaqRateModel` and
    /// replays the whole history once per candidate symbol (256 replays per
    /// predicted byte), so total work grows quadratically with stream
    /// length. If `log_prob` does not mutate the model, a single shared
    /// replay per step would suffice — verify against `ZpaqRateModel`
    /// before changing this.
    fn pdf_next(&mut self) -> &[f64] {
        if !self.valid {
            for sym in 0..256usize {
                let mut model = ZpaqRateModel::new(self.method.clone(), PDF_MIN);
                if !self.history.is_empty() {
                    let _ = model.update_and_score(&self.history);
                }
                let logp = model.log_prob(sym as u8);
                // Floor at PDF_MIN so no symbol is assigned zero mass.
                self.pdf[sym] = logp.exp().max(PDF_MIN);
            }
            normalize_pdf(&mut self.pdf);
            self.valid = true;
        }
        &self.pdf
    }

    /// Appends `symbol` to the history and invalidates the cached PDF.
    fn update(&mut self, symbol: u8) {
        self.history.push(symbol);
        self.valid = false;
    }
}
424
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Wraps an already-loaded Mamba model; the PDF length follows the
    /// model's vocabulary size.
    fn from_model(model: std::sync::Arc<mambazip::Model>) -> Self {
        let compressor = mambazip::Compressor::new_from_model(model);
        let vocab = compressor.vocab_size();
        Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        }
    }

    /// Builds a predictor by instantiating a compressor from a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = mambazip::Compressor::new_from_method(method)?;
        let vocab = compressor.vocab_size();
        Ok(Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        })
    }

    /// Ensures the cached PDF — and, when requested, the 256-entry CDF row —
    /// is current, running the initial forward pass on first use.
    ///
    /// The CDF-build step was previously triplicated across three branches;
    /// it is consolidated into a single tail step. Behavior is unchanged:
    /// once `valid`, the PDF is whatever `update` (or the priming pass)
    /// left in `self.pdf`.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.valid {
            if !self.primed {
                // First prediction: seed the model with token 0 to obtain an
                // initial PDF; later PDFs are produced by `update`.
                self.compressor.forward_to_pdf(0, &mut self.pdf);
                self.primed = true;
            }
            self.valid = true;
            self.cdf_valid = false;
        }
        if want_cdf && !self.cdf_valid {
            debug_assert!(self.pdf.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.pdf[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Next-symbol PDF over the model vocabulary.
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.pdf
    }

    /// Next-byte CDF row derived from the first 256 PDF entries.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Trains on `symbol` and advances the model, leaving `self.pdf` holding
    /// the prediction for the following symbol.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.online_update_from_pdf(symbol, &self.pdf)?;
        self.compressor.forward_to_pdf(symbol as u32, &mut self.pdf);
        self.valid = true;
        self.cdf_valid = false;
        Ok(())
    }

    /// Announces the upcoming stream length to the online update policy.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }
}
507
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Wraps an already-loaded RWKV model in a fresh predictor.
    fn from_model(model: std::sync::Arc<rwkvzip::Model>) -> Self {
        Self {
            compressor: rwkvzip::Compressor::new_from_model(model),
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        }
    }

    /// Builds a predictor by instantiating a compressor from a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = rwkvzip::Compressor::new_from_method(method)?;
        Ok(Self {
            compressor,
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        })
    }

    /// Primes the compressor on first use and, when requested, refreshes the
    /// cached CDF row from the compressor's PDF buffer.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.primed {
            self.compressor.reset_and_prime();
            self.primed = true;
            self.cdf_valid = false;
        }
        let needs_cdf_refresh = want_cdf && !self.cdf_valid;
        if needs_cdf_refresh {
            debug_assert!(self.compressor.pdf_buffer.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.compressor.pdf_buffer[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Next-symbol PDF, borrowed directly from the compressor's buffer.
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.compressor.pdf_buffer
    }

    /// Next-byte CDF row derived from the first 256 PDF entries.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Feeds `symbol` to the model and invalidates the cached CDF row.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.observe_symbol_from_current_pdf(symbol)?;
        self.cdf_valid = false;
        Ok(())
    }

    /// Announces the upcoming stream length to the online update policy.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }

    /// Flushes the online update policy at end of stream.
    fn finish_stream(&mut self) -> Result<()> {
        self.compressor.finish_online_policy_stream()
    }
}
569
/// One expert inside a `MixturePredictor`.
#[derive(Clone)]
struct MixExpert {
    predictor: Box<RatePdfPredictor>, // the expert's underlying model
    log_weight: f64,                  // current mixture log-weight
    log_prior: f64,                   // prior log-weight (used by Switching)
    cum_log_loss: f64,                // accumulated -ln(p) over observed symbols
}
577
/// Mixes several `RatePdfPredictor` experts into a single next-byte PDF.
#[derive(Clone)]
struct MixturePredictor {
    kind: MixtureKind,                  // weighting scheme (Bayes/Switching/Neural/...)
    alpha: f64,                         // switch probability (Switching kind)
    decay: f64,                         // log-weight decay factor (FadingBayes kind)
    experts: Vec<MixExpert>,
    neural: NeuralMixCore,              // learned mixer (Neural kind)
    analyzer: TextContextAnalyzer,      // context features feeding `neural`
    neural_logps: Vec<f64>,             // scratch: per-expert log-probs / bit probs
    neural_bit_modes: Vec<u8>,          // scratch: per-expert mode for the bitwise AC path
    neural_lo: Vec<usize>,              // scratch: per-expert symbol-interval low bound
    neural_hi: Vec<usize>,              // scratch: per-expert symbol-interval high bound
    neural_pdf_cdf_rows: Vec<Vec<f64>>, // scratch: per-expert 257-entry CDF rows
    scratch: Vec<f64>,                  // general-purpose scratch buffer
    scratch2: Vec<f64>,                 // second general-purpose scratch buffer
    pdf: Vec<f64>,                      // cached mixed 256-entry PDF
    valid: bool,                        // whether `pdf` is current
}
596
597impl MixturePredictor {
598    fn new(spec: &MixtureSpec) -> Result<Self> {
599        if spec.experts.is_empty() {
600            bail!("mixture spec must include at least one expert");
601        }
602        let mut experts = Vec::with_capacity(spec.experts.len());
603        for e in &spec.experts {
604            experts.push(MixExpert {
605                predictor: Box::new(RatePdfPredictor::from_rate_backend(
606                    e.backend.clone(),
607                    e.max_order,
608                )?),
609                log_weight: e.log_prior,
610                log_prior: e.log_prior,
611                cum_log_loss: 0.0,
612            });
613        }
614        let m = logsumexp_expert_weights(&experts);
615        for e in &mut experts {
616            e.log_weight -= m;
617        }
618
619        let mut prior_weights = vec![0.0; experts.len()];
620        for (i, e) in experts.iter().enumerate() {
621            let p = (e.log_weight).exp().clamp(PDF_MIN, 1.0 - PDF_MIN);
622            prior_weights[i] = p;
623        }
624
625        let base_lr = spec.alpha.abs().clamp(1e-6, 1.0);
626        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
627        let analyzer = TextContextAnalyzer::new();
628        let mut neural = NeuralMixCore::new(
629            experts.len(),
630            &prior_weights,
631            effective_lr * 0.5,
632            effective_lr,
633            1e-5,
634        );
635        neural.set_context_state(analyzer.state());
636        Ok(Self {
637            kind: spec.kind,
638            alpha: spec.alpha.clamp(1e-12, 1.0 - 1e-12),
639            decay: spec.decay.unwrap_or(1.0).clamp(0.0, 1.0),
640            experts,
641            neural,
642            analyzer,
643            neural_logps: vec![0.0; spec.experts.len()],
644            neural_bit_modes: vec![0; spec.experts.len()],
645            neural_lo: vec![0; spec.experts.len()],
646            neural_hi: vec![256; spec.experts.len()],
647            neural_pdf_cdf_rows: vec![vec![0.0; 257]; spec.experts.len()],
648            scratch: Vec::new(),
649            scratch2: Vec::new(),
650            pdf: vec![0.0; 256],
651            valid: false,
652        })
653    }
654
655    fn ensure_pdf(&mut self) -> Result<&[f64]> {
656        if self.valid {
657            return Ok(&self.pdf);
658        }
659        match self.kind {
660            MixtureKind::Neural => {
661                if self.experts.len() == 1 {
662                    self.pdf.fill(0.0);
663                    let epdf = self.experts[0].predictor.pdf_next()?;
664                    self.pdf.copy_from_slice(epdf);
665                    normalize_pdf(&mut self.pdf);
666                    self.valid = true;
667                    return Ok(&self.pdf);
668                }
669                self.neural.set_context_state(self.analyzer.state());
670                self.neural.evaluate_expert_weights();
671                let n = self.experts.len();
672                self.scratch.resize(n, 0.0);
673                self.scratch.copy_from_slice(self.neural.expert_weights());
674                self.pdf.fill(0.0);
675                for i in 0..n {
676                    let epdf = self.experts[i].predictor.pdf_next()?;
677                    let w = self.scratch[i];
678                    for (pdf_slot, &p) in self.pdf.iter_mut().zip(epdf.iter()) {
679                        *pdf_slot += w * p;
680                    }
681                }
682                normalize_pdf(&mut self.pdf);
683                self.valid = true;
684                return Ok(&self.pdf);
685            }
686            _ => {
687                self.pdf.fill(0.0);
688
689                let lw_norm = logsumexp_expert_weights(&self.experts);
690                for e in &mut self.experts {
691                    let w = (e.log_weight - lw_norm).exp();
692                    let epdf = e.predictor.pdf_next()?;
693                    for (i, p) in epdf.iter().enumerate().take(256) {
694                        self.pdf[i] += w * *p;
695                    }
696                }
697            }
698        }
699
700        normalize_pdf(&mut self.pdf);
701        self.valid = true;
702        Ok(&self.pdf)
703    }
704
705    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
706        for expert in &mut self.experts {
707            match &mut *expert.predictor {
708                // Direct CTW benefits from pre-reserving, but inside mixtures that extra
709                // headroom can dominate peak RSS without a proportional runtime gain.
710                RatePdfPredictor::Ctw(_) | RatePdfPredictor::FacCtw(_) => {}
711                _ => expert.predictor.begin_stream(total_len)?,
712            }
713        }
714        Ok(())
715    }
716
    /// Feeds `symbol` to every expert and refreshes the mixture weights.
    ///
    /// Per `MixtureKind`:
    /// - `Bayes`: exact Bayesian posterior update of the log-weights.
    /// - `FadingBayes`: Bayes with log-weights scaled by `self.decay`.
    /// - `Switching`: each weight is first mixed back toward the expert's
    ///   prior with switch probability `alpha`, then updated by likelihood.
    /// - `Mdl`: only tracks cumulative log-loss; weights stay untouched.
    /// - `Neural`: delegates weighting to `NeuralMixCore`, driven by the
    ///   text-context analyzer state.
    ///
    /// Every path accumulates each expert's code length (nats) into
    /// `cum_log_loss` and invalidates the cached mixture PDF.
    fn update(&mut self, symbol: u8) -> Result<()> {
        // Make sure expert PDFs are consistent before scoring the symbol.
        let _ = self.ensure_pdf()?;

        match self.kind {
            MixtureKind::Bayes => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                // scratch[i] = ln p_i(symbol); scratch2[i] = log posterior numerator.
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                    self.scratch2[i] = e.log_weight + lp;
                }
                let log_mix = logsumexp_slice(&self.scratch2[..n]);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.log_weight = e.log_weight + self.scratch[i] - log_mix;
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            MixtureKind::FadingBayes => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                    // NOTE(review): this scratch2 write is immediately
                    // overwritten by the loop below; harmless but dead.
                    self.scratch2[i] = e.log_weight + lp;
                }
                // Decay old log-weights before applying the likelihood.
                for (i, e) in self.experts.iter_mut().enumerate() {
                    self.scratch2[i] = self.decay * e.log_weight + self.scratch[i];
                }
                let log_mix = logsumexp_slice(&self.scratch2[..n]);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.log_weight = self.decay * e.log_weight + self.scratch[i] - log_mix;
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            MixtureKind::Switching => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                    self.scratch2[i] = e.log_weight + lp;
                }
                let log_alpha = self.alpha.ln();
                let log_1m_alpha = (1.0 - self.alpha).ln();
                for (i, e) in self.experts.iter_mut().enumerate() {
                    // (1 - alpha) * current weight + alpha * prior, in log space.
                    let switched = logsumexp2(log_1m_alpha + e.log_weight, log_alpha + e.log_prior);
                    self.scratch2[i] = switched + self.scratch[i];
                }
                let log_mix = logsumexp_slice(&self.scratch2[..n]);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.log_weight = self.scratch2[i] - log_mix;
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            MixtureKind::Mdl => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                }
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            MixtureKind::Neural => {
                let y = symbol as usize;
                if self.experts.len() == 1 {
                    // Single expert: no mixing to learn; just book-keep.
                    let lp = self.experts[0].predictor.pdf_next()?[y].max(PDF_MIN).ln();
                    self.experts[0].cum_log_loss -= lp;
                    self.experts[0].predictor.update(symbol)?;
                    self.analyzer.update(symbol);
                    self.neural.set_context_state(self.analyzer.state());
                    self.valid = false;
                    return Ok(());
                }
                let n = self.experts.len();
                self.neural.set_context_state(self.analyzer.state());
                self.neural_logps.resize(n, 0.0);
                for i in 0..n {
                    let p = self.experts[i].predictor.pdf_next()?[y].max(PDF_MIN);
                    let lp = p.ln();
                    self.neural_logps[i] = lp;
                    self.experts[i].cum_log_loss -= lp;
                }
                // Score the observed symbol, then adapt the mixer weights.
                self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
                self.neural
                    .update_weights_symbol(&self.neural_logps, PDF_MIN);
                for e in &mut self.experts {
                    e.predictor.update(symbol)?;
                }
                self.analyzer.update(symbol);
                self.neural.set_context_state(self.analyzer.state());
            }
        }

        self.valid = false;
        Ok(())
    }
828
829    fn finish_stream(&mut self) -> Result<()> {
830        for expert in &mut self.experts {
831            expert.predictor.finish_stream()?;
832        }
833        Ok(())
834    }
835
836    #[inline]
837    fn can_fast_ac_bitwise(&self) -> bool {
838        self.experts.iter().any(|e| {
839            if let RatePdfPredictor::Ctw(ctw) = &*e.predictor {
840                ctw.can_fast_ac_bitwise()
841            } else {
842                false
843            }
844        })
845    }
846
847    fn ac_step_bitwise<F>(&mut self, mut choose_bit: F) -> Result<u8>
848    where
849        F: FnMut(usize, f64) -> Result<u8>,
850    {
851        let n = self.experts.len();
852        self.scratch.resize(n, 0.0);
853        match self.kind {
854            MixtureKind::Neural if n > 1 => {
855                self.neural.set_context_state(self.analyzer.state());
856                self.neural.evaluate_expert_weights();
857                self.scratch.copy_from_slice(self.neural.expert_weights());
858            }
859            _ => {
860                let lw_norm = logsumexp_expert_weights(&self.experts);
861                for (i, expert) in self.experts.iter().enumerate() {
862                    self.scratch[i] = (expert.log_weight - lw_norm).exp();
863                }
864            }
865        }
866        self.scratch2.resize(n, 1.0);
867        self.scratch2.fill(1.0);
868        self.neural_logps.resize(n, 0.0);
869        self.neural_bit_modes.resize(n, 0);
870        self.neural_lo.resize(n, 0);
871        self.neural_hi.resize(n, 256);
872        if self.neural_pdf_cdf_rows.len() < n {
873            self.neural_pdf_cdf_rows.resize_with(n, || vec![0.0; 257]);
874        }
875
876        for i in 0..n {
877            self.neural_bit_modes[i] = 1;
878            self.neural_lo[i] = 0;
879            self.neural_hi[i] = 256;
880
881            let mut handled_ctw = false;
882            if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor
883                && ctw.can_fast_ac_bitwise()
884            {
885                self.neural_bit_modes[i] = 0;
886                handled_ctw = true;
887            }
888            if handled_ctw {
889                continue;
890            }
891
892            if self.experts[i]
893                .predictor
894                .prepare_cached_cdf_fast_bitwise()?
895            {
896                self.neural_bit_modes[i] = 2;
897                continue;
898            }
899
900            let pdf = self.experts[i].predictor.pdf_next()?;
901            let row = &mut self.neural_pdf_cdf_rows[i];
902            if row.len() != 257 {
903                row.resize(257, 0.0);
904            }
905            row[0] = 0.0;
906            for b in 0..256usize {
907                row[b + 1] = row[b] + pdf[b].max(PDF_MIN);
908            }
909            if !row[256].is_finite() || row[256] <= 0.0 {
910                for (j, v) in row.iter_mut().enumerate() {
911                    *v = (j as f64) / 256.0;
912                }
913            }
914        }
915
916        let mut symbol = 0u8;
917        for bit_idx in 0..8usize {
918            let mut denom = 0.0;
919            let mut numer1 = 0.0;
920
921            for i in 0..n {
922                let p1 = if self.neural_bit_modes[i] == 0 {
923                    match &mut *self.experts[i].predictor {
924                        RatePdfPredictor::Ctw(ctw) => ctw.bit_prob_one_msb(bit_idx),
925                        _ => 0.5,
926                    }
927                } else if self.neural_bit_modes[i] == 2 {
928                    self.experts[i]
929                        .predictor
930                        .cached_cdf_bit_prob_one_msb(self.neural_lo[i], self.neural_hi[i])
931                        .unwrap_or(0.5)
932                } else {
933                    let lo = self.neural_lo[i];
934                    let hi = self.neural_hi[i];
935                    let mid = (lo + hi) >> 1;
936                    let row = &self.neural_pdf_cdf_rows[i];
937                    let total = (row[hi] - row[lo]).max(PDF_MIN);
938                    let one = (row[hi] - row[mid]).max(0.0);
939                    (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
940                };
941                self.neural_logps[i] = p1;
942                let wp = self.scratch[i] * self.scratch2[i];
943                denom += wp;
944                numer1 += wp * p1;
945            }
946
947            let p1_mix = if denom.is_finite() && denom > 0.0 {
948                (numer1 / denom).clamp(PDF_MIN, 1.0 - PDF_MIN)
949            } else {
950                0.5
951            };
952            let bit = choose_bit(bit_idx, p1_mix)? & 1;
953            symbol |= bit << (7 - bit_idx);
954
955            for i in 0..n {
956                let p1 = self.neural_logps[i];
957                let pb = if bit == 1 { p1 } else { 1.0 - p1 };
958                self.scratch2[i] = (self.scratch2[i] * pb).max(PDF_MIN);
959
960                if self.neural_bit_modes[i] == 0 {
961                    if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor {
962                        ctw.update_bit_msb(bit_idx, bit == 1);
963                    }
964                } else {
965                    let lo = self.neural_lo[i];
966                    let hi = self.neural_hi[i];
967                    let mid = (lo + hi) >> 1;
968                    if bit == 1 {
969                        self.neural_lo[i] = mid;
970                        self.neural_hi[i] = hi;
971                    } else {
972                        self.neural_lo[i] = lo;
973                        self.neural_hi[i] = mid;
974                    }
975                }
976            }
977        }
978
979        for i in 0..n {
980            let lp = self.scratch2[i].max(PDF_MIN).ln();
981            self.neural_logps[i] = lp;
982            self.experts[i].cum_log_loss -= lp;
983            if self.neural_bit_modes[i] != 0 {
984                self.experts[i].predictor.update(symbol)?;
985            }
986        }
987
988        match self.kind {
989            MixtureKind::Bayes => {
990                for i in 0..n {
991                    self.scratch[i] = self.experts[i].log_weight + self.neural_logps[i];
992                }
993                let log_mix = logsumexp_slice(&self.scratch[..n]);
994                for i in 0..n {
995                    self.experts[i].log_weight += self.neural_logps[i] - log_mix;
996                }
997            }
998            MixtureKind::FadingBayes => {
999                for i in 0..n {
1000                    self.scratch[i] =
1001                        self.decay * self.experts[i].log_weight + self.neural_logps[i];
1002                }
1003                let log_mix = logsumexp_slice(&self.scratch[..n]);
1004                for i in 0..n {
1005                    self.experts[i].log_weight = self.scratch[i] - log_mix;
1006                }
1007            }
1008            MixtureKind::Switching => {
1009                let log_alpha = self.alpha.ln();
1010                let log_1m_alpha = (1.0 - self.alpha).ln();
1011                for i in 0..n {
1012                    let expert = &self.experts[i];
1013                    let switched = logsumexp2(
1014                        log_1m_alpha + expert.log_weight,
1015                        log_alpha + expert.log_prior,
1016                    );
1017                    self.scratch[i] = switched + self.neural_logps[i];
1018                }
1019                let log_mix = logsumexp_slice(&self.scratch[..n]);
1020                for i in 0..n {
1021                    self.experts[i].log_weight = self.scratch[i] - log_mix;
1022                }
1023            }
1024            MixtureKind::Mdl => {}
1025            MixtureKind::Neural => {
1026                if n > 1 {
1027                    self.neural.set_context_state(self.analyzer.state());
1028                    self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
1029                    self.neural
1030                        .update_weights_symbol(&self.neural_logps, PDF_MIN);
1031                }
1032                self.analyzer.update(symbol);
1033                self.neural.set_context_state(self.analyzer.state());
1034            }
1035        }
1036        self.valid = false;
1037        Ok(symbol)
1038    }
1039}
1040
/// Per-stream symbol predictor: one variant per supported rate backend.
///
/// Every variant can emit a 256-entry PDF for the next byte (`pdf_next`)
/// and absorb an observed byte (`update`). `Calibrated` wraps any other
/// predictor and post-processes its PDF through a learned calibrator,
/// caching the result until the next update.
#[derive(Clone)]
#[allow(clippy::large_enum_variant)]
enum RatePdfPredictor {
    /// ROSA+ predictive model.
    Rosa(RosaPredictor),
    /// Contiguous match model.
    Match {
        model: MatchModel,
    },
    /// Match model allowing gaps between matched spans.
    SparseMatch {
        model: SparseMatchModel,
    },
    /// PPMd-style context model.
    Ppmd {
        model: PpmdModel,
    },
    /// Context-tree weighting over raw bytes.
    Ctw(CtwPredictor),
    /// Factored CTW over a reduced per-symbol bit encoding.
    FacCtw(CtwPredictor),
    #[cfg(feature = "backend-mamba")]
    /// Mamba neural backend (feature-gated).
    Mamba(MambaPredictor),
    #[cfg(feature = "backend-rwkv")]
    /// RWKV-7 neural backend (feature-gated).
    Rwkv(RwkvPredictor),
    /// ZPAQ rate model.
    Zpaq(ZpaqPredictor),
    /// Weighted mixture of expert predictors.
    Mixture(MixturePredictor),
    /// Particle-filter runtime predictor.
    Particle(crate::particle::ParticleRuntime),
    /// Any base predictor with calibrator post-processing.
    Calibrated {
        base: Box<RatePdfPredictor>,
        // Calibrator state applied on top of `base`'s PDF.
        core: CalibratorCore,
        // Cached calibrated PDF; refreshed lazily while `valid` is false.
        pdf: Vec<f64>,
        valid: bool,
    },
}
1070
impl RatePdfPredictor {
    /// Build a predictor tree from a [`RateBackend`] spec.
    ///
    /// `max_order` is only consumed by the ROSA+ backend; every other backend
    /// carries its parameters inside its spec variant. `Calibrated` recurses
    /// to construct its base predictor and starts from a uniform cached PDF.
    fn from_rate_backend(backend: RateBackend, max_order: i64) -> Result<Self> {
        match backend {
            RateBackend::RosaPlus => Ok(Self::Rosa(RosaPredictor::new(max_order))),
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Ok(Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Ok(Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::Ppmd { order, memory_mb } => Ok(Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
            }),
            RateBackend::Ctw { depth } => Ok(Self::Ctw(CtwPredictor::new_ctw(depth))),
            RateBackend::FacCtw {
                base_depth,
                // Not used by the predictor construction.
                num_percept_bits: _,
                encoding_bits,
            } => {
                // Keep the per-symbol encoding within a representable 1..=8 bits.
                let bits = encoding_bits.clamp(1, 8);
                Ok(Self::FacCtw(CtwPredictor::new_fac(base_depth, bits)))
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => Ok(Self::Mamba(MambaPredictor::from_model(model))),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                Ok(Self::Mamba(MambaPredictor::from_method(&method)?))
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => Ok(Self::Rwkv(RwkvPredictor::from_model(model))),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                Ok(Self::Rwkv(RwkvPredictor::from_method(&method)?))
            }
            RateBackend::Zpaq { method } => Ok(Self::Zpaq(ZpaqPredictor::new(method))),
            RateBackend::Mixture { spec } => {
                Ok(Self::Mixture(MixturePredictor::new(spec.as_ref())?))
            }
            RateBackend::Particle { spec } => Ok(Self::Particle(
                crate::particle::ParticleRuntime::new(spec.as_ref()),
            )),
            RateBackend::Calibrated { spec } => Ok(Self::Calibrated {
                base: Box::new(Self::from_rate_backend(spec.base.clone(), max_order)?),
                core: build_calibrator(spec.as_ref()),
                // Uniform PDF placeholder until the first `pdf_next` call.
                pdf: vec![1.0 / 256.0; 256],
                valid: false,
            }),
        }
    }

    /// Prepare for a new stream of `total_len` bytes.
    ///
    /// Calls [`Self::finish_stream`] first so a predictor can be reused across
    /// streams; backends that pre-size internal storage get the length hint.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.finish_stream()?;
        match self {
            Self::Rosa(m) => {
                m.begin_stream(total_len);
                Ok(())
            }
            // These backends need no per-stream setup.
            Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            Self::Ctw(m) | Self::FacCtw(m) => {
                // Pre-size the context tree for the expected symbol count.
                m.tree.reserve_for_symbols(total_len);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.begin_stream(total_len),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.begin_stream(total_len),
            Self::Mixture(m) => m.begin_stream(total_len),
            Self::Calibrated { base, .. } => base.begin_stream(total_len),
        }
    }

    /// Tear down per-stream state; a no-op for backends without any.
    fn finish_stream(&mut self) -> Result<()> {
        match self {
            Self::Rosa(_)
            | Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Ctw(_)
            | Self::FacCtw(_)
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(_) => Ok(()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.finish_stream(),
            Self::Mixture(m) => m.finish_stream(),
            Self::Calibrated { base, .. } => base.finish_stream(),
        }
    }

    /// Return the 256-entry PDF for the next symbol.
    ///
    /// For `Calibrated`, the base PDF is computed lazily, passed through the
    /// calibrator, renormalized, and cached until the next [`Self::update`].
    fn pdf_next(&mut self) -> Result<&[f64]> {
        match self {
            Self::Rosa(m) => Ok(m.pdf_next()),
            Self::Match { model } => Ok(model.pdf()),
            Self::Ctw(m) => Ok(m.pdf_next()),
            Self::FacCtw(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Ok(m.pdf_next()),
            Self::Zpaq(m) => Ok(m.pdf_next()),
            Self::Mixture(m) => m.ensure_pdf(),
            Self::Particle(m) => Ok(m.pdf_next()),
            Self::SparseMatch { model } => Ok(model.pdf()),
            Self::Ppmd { model } => Ok(model.pdf()),
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                if !*valid {
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                    *valid = true;
                }
                Ok(pdf)
            }
        }
    }

    /// Fold the observed `symbol` into the model state.
    ///
    /// `Calibrated` first ensures a calibrated PDF exists (the calibrator's
    /// own update consumes it), then updates calibrator and base, and
    /// invalidates the cached PDF.
    fn update(&mut self, symbol: u8) -> Result<()> {
        match self {
            Self::Rosa(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::Match { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::SparseMatch { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Ppmd { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Ctw(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::FacCtw(m) => {
                m.update(symbol);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.update(symbol),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.update(symbol),
            Self::Zpaq(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::Mixture(m) => m.update(symbol),
            Self::Particle(m) => {
                m.step(symbol);
                Ok(())
            }
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                // Recompute the calibrated PDF if nothing cached it yet —
                // `core.update` needs the PDF that was (or would have been)
                // used to code this symbol.
                if !*valid {
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                }
                core.update(symbol, pdf);
                base.update(symbol)?;
                *valid = false;
                Ok(())
            }
        }
    }

    /// Materialize the backend's cached CDF for per-bit queries.
    ///
    /// Returns `Ok(true)` when [`Self::cached_cdf_bit_prob_one_msb`] may be
    /// used for the current symbol; `Ok(false)` for backends without a cached
    /// CDF (callers then fall back to the dense PDF path).
    fn prepare_cached_cdf_fast_bitwise(&mut self) -> Result<bool> {
        match self {
            Self::Rosa(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            Self::Match { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            Self::SparseMatch { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            Self::Ppmd { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// P(next bit = 1) for the symbol range `[lo, hi)` under the CDF cached by
    /// [`Self::prepare_cached_cdf_fast_bitwise`]; `None` when the backend has
    /// no cached CDF.
    fn cached_cdf_bit_prob_one_msb(&mut self, lo: usize, hi: usize) -> Option<f64> {
        match self {
            Self::Rosa(m) => Some(cdf_bit_prob_one_msb(&m.cdf, lo, hi)),
            Self::Match { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            Self::SparseMatch { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            Self::Ppmd { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
            _ => None,
        }
    }

    /// Whether this predictor supports the fast per-bit AC path
    /// ([`Self::ac_step_fast_bitwise`]).
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        match self {
            Self::Ctw(m) => m.can_fast_ac_bitwise(),
            Self::Mixture(m) => m.can_fast_ac_bitwise(),
            _ => false,
        }
    }

    /// Code one byte as eight MSB-first binary decisions via `choose_bit`.
    ///
    /// Must only be called when [`Self::can_fast_ac_bitwise`] returned true;
    /// any other variant is a logic error and panics.
    fn ac_step_fast_bitwise<F>(&mut self, choose_bit: F) -> Result<u8>
    where
        F: FnMut(usize, f64) -> Result<u8>,
    {
        match self {
            Self::Ctw(m) => ctw_ac_step_bitwise(m, choose_bit),
            Self::Mixture(m) => m.ac_step_bitwise(choose_bit),
            _ => unreachable!("fast bitwise path requested for unsupported predictor"),
        }
    }
}
1349
1350fn ctw_ac_step_bitwise<F>(ctw: &mut CtwPredictor, mut choose_bit: F) -> Result<u8>
1351where
1352    F: FnMut(usize, f64) -> Result<u8>,
1353{
1354    debug_assert!(ctw.can_fast_ac_bitwise());
1355    let mut symbol = 0u8;
1356    for bit_idx in 0..8usize {
1357        let p1 = ctw.bit_prob_one_msb(bit_idx);
1358        let bit = choose_bit(bit_idx, p1)? & 1;
1359        symbol |= bit << (7 - bit_idx);
1360        ctw.update_bit_msb(bit_idx, bit == 1);
1361    }
1362    Ok(symbol)
1363}
1364
1365#[inline]
1366fn binary_split_from_prob_one(p1: f64) -> u32 {
1367    let p1 = p1.clamp(PDF_MIN, 1.0 - PDF_MIN);
1368    let p0 = 1.0 - p1;
1369    let mut split = (p0 * (CDF_TOTAL as f64)) as u32;
1370    if split == 0 {
1371        split = 1;
1372    } else if split >= CDF_TOTAL {
1373        split = CDF_TOTAL - 1;
1374    }
1375    split
1376}
1377
/// Encode `data` with the arithmetic coder driven by `predictor`.
///
/// Two paths:
/// - fast bitwise: predictors exposing per-bit probabilities
///   (`can_fast_ac_bitwise`) code each byte MSB-first as eight binary
///   decisions;
/// - generic: the full 256-symbol PDF is quantized to an integer CDF and each
///   byte coded as a single symbol.
///
/// Returns the raw coder payload (no framing header). The decoder must take
/// the same branch, which it does because it sees the same predictor state.
fn encode_payload_ac(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
    predictor.begin_stream(data.len())?;
    if predictor.can_fast_ac_bitwise() {
        let mut out = Vec::new();
        {
            let mut enc = ArithmeticEncoder::new(&mut out);
            for &symbol in data {
                predictor.ac_step_fast_bitwise(|bit_idx, p1_mix| {
                    // Encoder side: the bit is known from `symbol`; code its
                    // interval and return it so the predictor can update.
                    let bit = (symbol >> (7 - bit_idx)) & 1;
                    let split = binary_split_from_prob_one(p1_mix);
                    if bit == 0 {
                        enc.encode_counts(0, split as u64, CDF_TOTAL as u64)?;
                    } else {
                        enc.encode_counts(split as u64, CDF_TOTAL as u64, CDF_TOTAL as u64)?;
                    }
                    Ok(bit)
                })?;
            }
            let _ = enc.finish()?;
        }
        predictor.finish_stream()?;
        return Ok(out);
    }

    let mut out = Vec::new();
    {
        let mut enc = ArithmeticEncoder::new(&mut out);
        // Reused 257-entry integer CDF buffer (one quantization per byte).
        let mut cdf = vec![0u32; 257];
        for &b in data {
            let pdf = predictor.pdf_next()?;
            crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
                pdf, CDF_TOTAL, &mut cdf,
            );
            let sym = b as usize;
            enc.encode_counts(cdf[sym] as u64, cdf[sym + 1] as u64, CDF_TOTAL as u64)?;
            predictor.update(b)?;
        }
        let _ = enc.finish()?;
    }
    predictor.finish_stream()?;
    Ok(out)
}
1420
/// Decode `out_len` bytes from an arithmetic-coded payload produced by
/// [`encode_payload_ac`].
///
/// Mirrors the encoder's two paths; the predictor takes the same fast-bitwise
/// vs. generic branch it took during encoding, keeping the streams in sync.
fn decode_payload_ac(
    payload: &[u8],
    out_len: usize,
    predictor: &mut RatePdfPredictor,
) -> Result<Vec<u8>> {
    predictor.begin_stream(out_len)?;
    if predictor.can_fast_ac_bitwise() {
        let mut dec = ArithmeticDecoder::new(payload)?;
        let mut out = Vec::with_capacity(out_len);
        for _ in 0..out_len {
            let symbol = predictor.ac_step_fast_bitwise(|_, p1_mix| {
                // Decoder side: rebuild the same binary CDF the encoder used
                // and let the decoder resolve the bit.
                let split = binary_split_from_prob_one(p1_mix);
                let cdf = [0u32, split, CDF_TOTAL];
                Ok(dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8)
            })?;
            out.push(symbol);
        }
        predictor.finish_stream()?;
        return Ok(out);
    }

    let mut dec = ArithmeticDecoder::new(payload)?;
    let mut out = Vec::with_capacity(out_len);
    // Reused 257-entry integer CDF buffer (one quantization per byte).
    let mut cdf = vec![0u32; 257];
    for _ in 0..out_len {
        let pdf = predictor.pdf_next()?;
        crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
            pdf, CDF_TOTAL, &mut cdf,
        );
        let sym = dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8;
        out.push(sym);
        predictor.update(sym)?;
    }
    predictor.finish_stream()?;
    Ok(out)
}
1457
/// Encode `data` with the blocked rANS coder driven by `predictor`.
///
/// Wire layout: little-endian `u32` block count, then for each block a
/// little-endian `u32` byte length followed by the block bytes. This is the
/// directory parsed by [`decode_payload_rans`].
fn encode_payload_rans(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
    predictor.begin_stream(data.len())?;
    let mut encoder = BlockedRansEncoder::new();
    // Reusable quantization buffers: integer CDF and working frequencies.
    let mut cdf = vec![0u32; 257];
    let mut freq = vec![0i64; 256];

    for &b in data {
        let pdf = predictor.pdf_next()?;
        quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
        let s = b as usize;
        encoder.encode(Cdf::new(cdf[s], cdf[s + 1], ANS_TOTAL));
        predictor.update(b)?;
    }

    // Serialize the length-prefixed block directory.
    let blocks = encoder.finish();
    let mut out = Vec::new();
    out.extend_from_slice(&(blocks.len() as u32).to_le_bytes());
    for block in blocks {
        out.extend_from_slice(&(block.len() as u32).to_le_bytes());
        out.extend_from_slice(&block);
    }
    predictor.finish_stream()?;
    Ok(out)
}
1482
1483fn decode_payload_rans(
1484    payload: &[u8],
1485    out_len: usize,
1486    predictor: &mut RatePdfPredictor,
1487) -> Result<Vec<u8>> {
1488    predictor.begin_stream(out_len)?;
1489    if payload.len() < 4 {
1490        bail!("rANS payload too short");
1491    }
1492    let block_count = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]) as usize;
1493    let mut pos = 4usize;
1494    let mut blocks = Vec::with_capacity(block_count);
1495    for _ in 0..block_count {
1496        if pos + 4 > payload.len() {
1497            bail!("truncated rANS block header");
1498        }
1499        let len = u32::from_le_bytes([
1500            payload[pos],
1501            payload[pos + 1],
1502            payload[pos + 2],
1503            payload[pos + 3],
1504        ]) as usize;
1505        pos += 4;
1506        if pos + len > payload.len() {
1507            bail!("truncated rANS block data");
1508        }
1509        blocks.push(&payload[pos..pos + len]);
1510        pos += len;
1511    }
1512
1513    let mut dec = BlockedRansDecoder::new(blocks, out_len)?;
1514    let mut out = Vec::with_capacity(out_len);
1515    let mut cdf = vec![0u32; 257];
1516    let mut freq = vec![0i64; 256];
1517
1518    for _ in 0..out_len {
1519        let pdf = predictor.pdf_next()?;
1520        quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
1521        let sym = dec.decode(&cdf)? as u8;
1522        out.push(sym);
1523        predictor.update(sym)?;
1524    }
1525    predictor.finish_stream()?;
1526    Ok(out)
1527}
1528
1529/// Compress bytes using a predictive rate backend and entropy coder.
1530///
1531/// When `framing` is [`FramingMode::Framed`], output includes a compact header
1532/// with payload metadata and CRC for safer transport/storage.
1533pub fn compress_rate_bytes(
1534    data: &[u8],
1535    rate_backend: &RateBackend,
1536    max_order: i64,
1537    coder: CoderType,
1538    framing: FramingMode,
1539) -> Result<Vec<u8>> {
1540    let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
1541    let payload = match coder {
1542        CoderType::AC => encode_payload_ac(data, &mut predictor)?,
1543        CoderType::RANS => encode_payload_rans(data, &mut predictor)?,
1544    };
1545
1546    if framing == FramingMode::Raw {
1547        return Ok(payload);
1548    }
1549
1550    let mut out = Vec::with_capacity(FramedHeader::SIZE + payload.len());
1551    let hdr = FramedHeader::new(coder, data.len() as u64, crc32(data));
1552    hdr.write(&mut out);
1553    out.extend_from_slice(&payload);
1554    Ok(out)
1555}
1556
1557/// Return compressed size (in bytes) for `data` using rate coding.
1558pub fn compress_rate_size(
1559    data: &[u8],
1560    rate_backend: &RateBackend,
1561    max_order: i64,
1562    coder: CoderType,
1563    framing: FramingMode,
1564) -> Result<u64> {
1565    let encoded = compress_rate_bytes(data, rate_backend, max_order, coder, framing)?;
1566    Ok(encoded.len() as u64)
1567}
1568
1569/// Return compressed size (in bytes) for concatenated slices under one stream.
1570pub fn compress_rate_size_chain(
1571    parts: &[&[u8]],
1572    rate_backend: &RateBackend,
1573    max_order: i64,
1574    coder: CoderType,
1575    framing: FramingMode,
1576) -> Result<u64> {
1577    let total = parts.iter().map(|p| p.len()).sum();
1578    let mut data = Vec::with_capacity(total);
1579    for p in parts {
1580        data.extend_from_slice(p);
1581    }
1582    compress_rate_size(&data, rate_backend, max_order, coder, framing)
1583}
1584
1585/// Decompress bytes produced by [`compress_rate_bytes`].
1586pub fn decompress_rate_bytes(
1587    input: &[u8],
1588    rate_backend: &RateBackend,
1589    max_order: i64,
1590    _coder: CoderType,
1591    framing: FramingMode,
1592) -> Result<Vec<u8>> {
1593    let (payload, coder, out_len, expected_crc) = if framing == FramingMode::Framed {
1594        let hdr = FramedHeader::read(input)?;
1595        (
1596            &input[FramedHeader::SIZE..],
1597            hdr.coder_type(),
1598            hdr.original_len as usize,
1599            Some(hdr.crc32),
1600        )
1601    } else {
1602        bail!("raw payload decompression requires explicit output length and is not supported");
1603    };
1604
1605    let _ = coder;
1606    let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
1607    let decoded = match coder {
1608        CoderType::AC => decode_payload_ac(payload, out_len, &mut predictor)?,
1609        CoderType::RANS => decode_payload_rans(payload, out_len, &mut predictor)?,
1610    };
1611
1612    if let Some(crc) = expected_crc {
1613        let got = crc32(&decoded);
1614        if got != crc {
1615            bail!("CRC32 mismatch: expected 0x{crc:08X}, got 0x{got:08X}");
1616        }
1617    }
1618
1619    Ok(decoded)
1620}
1621
1622fn normalize_pdf(pdf: &mut [f64]) {
1623    let mut sum = 0.0;
1624    for p in pdf.iter_mut() {
1625        *p = if p.is_finite() {
1626            (*p).max(PDF_MIN)
1627        } else {
1628            PDF_MIN
1629        };
1630        sum += *p;
1631    }
1632    if !(sum.is_finite()) || sum <= 0.0 {
1633        let u = 1.0 / (pdf.len() as f64);
1634        for p in pdf.iter_mut() {
1635            *p = u;
1636        }
1637        return;
1638    }
1639    let inv = 1.0 / sum;
1640    for p in pdf.iter_mut() {
1641        *p *= inv;
1642    }
1643}
1644
#[inline]
/// CDF of the uniform byte distribution: entry `i` is exactly `i / 256`.
fn uniform_cdf_row() -> [f64; 257] {
    let step = 1.0 / 256.0;
    std::array::from_fn(|i| (i as f64) * step)
}
1654
#[inline]
/// Accumulate the first 256 entries of `pdf` into `cdf`: `cdf[0] == 0` and
/// `cdf[i + 1]` is the running sum through `pdf[i]`. Panics if `pdf` has
/// fewer than 256 entries.
fn build_cdf_row_from_pdf_slice(pdf: &[f64], cdf: &mut [f64; 257]) {
    cdf[0] = 0.0;
    let mut running = 0.0;
    for (i, slot) in cdf[1..].iter_mut().enumerate() {
        running += pdf[i];
        *slot = running;
    }
}
1664
1665fn normalize_pdf_vec_and_maybe_build_cdf(pdf: &mut [f64], mut cdf: Option<&mut [f64; 257]>) {
1666    let mut sum = 0.0;
1667    for p in pdf.iter_mut() {
1668        *p = if p.is_finite() {
1669            (*p).max(PDF_MIN)
1670        } else {
1671            PDF_MIN
1672        };
1673        sum += *p;
1674    }
1675    if !(sum.is_finite()) || sum <= 0.0 {
1676        let u = 1.0 / (pdf.len() as f64);
1677        pdf.fill(u);
1678        if let Some(cdf) = cdf.as_deref_mut() {
1679            *cdf = uniform_cdf_row();
1680        }
1681        return;
1682    }
1683    let inv = 1.0 / sum;
1684    if let Some(cdf) = cdf.as_deref_mut() {
1685        cdf[0] = 0.0;
1686        let mut acc = 0.0;
1687        for i in 0..256 {
1688            pdf[i] *= inv;
1689            acc += pdf[i];
1690            cdf[i + 1] = acc;
1691        }
1692    } else {
1693        for p in pdf.iter_mut() {
1694            *p *= inv;
1695        }
1696    }
1697}
1698
1699#[inline]
1700fn cdf_bit_prob_one_msb(cdf: &[f64; 257], lo: usize, hi: usize) -> f64 {
1701    let mid = (lo + hi) >> 1;
1702    let total = (cdf[hi] - cdf[lo]).max(PDF_MIN);
1703    let one = (cdf[hi] - cdf[mid]).max(0.0);
1704    (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
1705}
1706
#[inline]
/// Numerically stable `ln(sum(exp(v)))` over a slice.
///
/// Shifts by the maximum before exponentiating; an empty slice (or one whose
/// maximum is infinite) short-circuits to that maximum.
fn logsumexp_slice(vals: &[f64]) -> f64 {
    let peak = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        return peak;
    }
    let total: f64 = vals.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
1724
1725#[inline]
1726fn logsumexp_expert_weights(experts: &[MixExpert]) -> f64 {
1727    let mut m = f64::NEG_INFINITY;
1728    for e in experts {
1729        if e.log_weight > m {
1730            m = e.log_weight;
1731        }
1732    }
1733    if !m.is_finite() {
1734        return m;
1735    }
1736    let mut s = 0.0;
1737    for e in experts {
1738        s += (e.log_weight - m).exp();
1739    }
1740    m + s.ln()
1741}
1742
/// Numerically stable `ln(exp(a) + exp(b))` for two values.
fn logsumexp2(a: f64, b: f64) -> f64 {
    let peak = a.max(b);
    if !peak.is_finite() {
        return peak;
    }
    peak + ((a - peak).exp() + (b - peak).exp()).ln()
}
1750
#[allow(dead_code)]
// NOTE(review): appears to exist only to keep the `ZpaqRateModel` import
// referenced (e.g. under feature combinations that skip the ZPAQ backend);
// confirm before removing.
fn _zpaq_marker(_: &ZpaqRateModel) {}
1753
1754#[cfg(test)]
1755mod tests {
1756    use super::*;
1757    use std::sync::Arc;
1758
1759    fn assert_pdf_close(lhs: &[f64], rhs: &[f64], tol: f64) {
1760        assert_eq!(lhs.len(), rhs.len());
1761        for (idx, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
1762            let delta = (a - b).abs();
1763            assert!(
1764                delta <= tol,
1765                "pdf mismatch at symbol {idx}: lhs={a} rhs={b} delta={delta}"
1766            );
1767        }
1768    }
1769
    /// Recompute the model's 256-symbol PDF the slow way, asking the model
    /// for each symbol's log-probability individually. Serves as a reference
    /// against the fast `pdf_next` path.
    fn brute_force_pdf(predictor: &mut CtwPredictor) -> Vec<f64> {
        let bits = predictor.bits_per_symbol.clamp(1, 8);
        let mut out = vec![0.0; 256];

        if bits == 8 {
            // Full-byte model: each byte is its own pattern.
            for (sym, slot) in out.iter_mut().enumerate().take(256usize) {
                *slot = predictor.log_prob_symbol_bruteforce(sym as u8).exp();
            }
        } else {
            // Sub-byte model: only `2^bits` distinct patterns exist, and every
            // byte aliasing the same pattern shares that pattern's mass.
            let patterns = 1usize << bits;
            let aliases = 1usize << (8 - bits);
            let mut pat_prob = vec![0.0; patterns];
            for (pat, value) in pat_prob.iter_mut().enumerate() {
                // Place the pattern in the high or low bits to match the
                // model's bit ordering.
                let symbol = if predictor.msb_first {
                    (pat as u8) << (8 - bits)
                } else {
                    pat as u8
                };
                *value = predictor.log_prob_symbol_bruteforce(symbol).exp();
            }
            for (byte, slot) in out.iter_mut().enumerate().take(256usize) {
                let pat = if predictor.msb_first {
                    byte >> (8 - bits)
                } else {
                    byte & (patterns - 1)
                };
                // Split the pattern's mass uniformly over its byte aliases.
                *slot = pat_prob[pat] / (aliases as f64);
            }
        }

        CtwPredictor::normalize_pdf(&mut out);
        out
    }
1803
1804    #[test]
1805    fn ctw_pdf_fast_matches_bruteforce() {
1806        let mut predictor = CtwPredictor::new_ctw(6);
1807        for &b in b"ctw fast-path regression corpus 1234567890" {
1808            predictor.update(b);
1809        }
1810
1811        let fast = predictor.pdf_next().to_vec();
1812        predictor.valid = false;
1813        let brute = brute_force_pdf(&mut predictor);
1814
1815        for i in 0..256usize {
1816            let delta = (fast[i] - brute[i]).abs();
1817            assert!(
1818                delta < 1e-12,
1819                "symbol={i} fast={} brute={} delta={delta}",
1820                fast[i],
1821                brute[i]
1822            );
1823        }
1824    }
1825
1826    #[test]
1827    fn fac_pdf_fast_matches_bruteforce_subbyte() {
1828        let mut predictor = CtwPredictor::new_fac(5, 5);
1829        for &b in b"fac ctw subbyte regression corpus abcdefghijklmnopqrstuvwxyz" {
1830            predictor.update(b);
1831        }
1832
1833        let fast = predictor.pdf_next().to_vec();
1834        predictor.valid = false;
1835        let brute = brute_force_pdf(&mut predictor);
1836
1837        for i in 0..256usize {
1838            let delta = (fast[i] - brute[i]).abs();
1839            assert!(
1840                delta < 1e-12,
1841                "symbol={i} fast={} brute={} delta={delta}",
1842                fast[i],
1843                brute[i]
1844            );
1845        }
1846    }
1847
1848    fn assert_ctw_pdf_next_preserves_state(mut predictor: CtwPredictor) {
1849        for &b in b"ctw predictor state preservation payload" {
1850            predictor.update(b);
1851        }
1852        let mut before_p0 = [0.0f64; 8];
1853        let mut before_p1 = [0.0f64; 8];
1854        for bit_idx in 0..8usize {
1855            before_p0[bit_idx] = predictor.tree.predict(false, bit_idx);
1856            before_p1[bit_idx] = predictor.tree.predict(true, bit_idx);
1857        }
1858        let log_before = predictor.tree.get_log_block_probability();
1859        let _ = predictor.pdf_next();
1860        let log_after = predictor.tree.get_log_block_probability();
1861        assert!(
1862            (log_before - log_after).abs() < 1e-12,
1863            "log drift: before={log_before} after={log_after}"
1864        );
1865        for bit_idx in 0..8usize {
1866            let after_p0 = predictor.tree.predict(false, bit_idx);
1867            let after_p1 = predictor.tree.predict(true, bit_idx);
1868            assert!(
1869                (before_p0[bit_idx] - after_p0).abs() < 1e-12,
1870                "bit {bit_idx} p0 drift: {} vs {}",
1871                before_p0[bit_idx],
1872                after_p0
1873            );
1874            assert!(
1875                (before_p1[bit_idx] - after_p1).abs() < 1e-12,
1876                "bit {bit_idx} p1 drift: {} vs {}",
1877                before_p1[bit_idx],
1878                after_p1
1879            );
1880        }
1881    }
1882
    #[test]
    fn ctw_pdf_next_preserves_state() {
        // Plain byte-level CTW (depth 7): querying the PDF must not mutate state.
        assert_ctw_pdf_next_preserves_state(CtwPredictor::new_ctw(7));
    }
1887
    #[test]
    fn fac_pdf_next_preserves_state() {
        // FAC variant (depth 7, 8 percept bits): same no-mutation guarantee.
        assert_ctw_pdf_next_preserves_state(CtwPredictor::new_fac(7, 8));
    }
1892
1893    fn assert_fill_pattern_preserves_symbol_log_probs(mut predictor: CtwPredictor) {
1894        for &b in b"fill-pattern preservation regression payload" {
1895            predictor.update(b);
1896        }
1897        let mut baseline = [0.0f64; 256];
1898        for (sym, slot) in baseline.iter_mut().enumerate() {
1899            *slot = predictor.log_prob_symbol_bruteforce(sym as u8);
1900        }
1901        let _ = predictor.fill_pattern_log_probs();
1902        for (sym, &expected) in baseline.iter().enumerate() {
1903            let got = predictor.log_prob_symbol_bruteforce(sym as u8);
1904            let diff = (expected - got).abs();
1905            assert!(
1906                diff < 1e-12,
1907                "symbol={sym} expected={expected} got={got} diff={diff}"
1908            );
1909        }
1910    }
1911
    #[test]
    fn ctw_fill_pattern_preserves_symbol_log_probs() {
        // fill_pattern_log_probs must be read-only for the plain CTW model.
        assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_ctw(7));
    }
1916
    #[test]
    fn fac_fill_pattern_preserves_symbol_log_probs() {
        // Same read-only guarantee for the FAC variant (depth 7, 8 percept bits).
        assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_fac(7, 8));
    }
1921
1922    fn assert_pdf_then_update_matches_plain_update(mut base: CtwPredictor) {
1923        for &b in b"pdf then update parity payload" {
1924            base.update(b);
1925        }
1926        let observed = b'n';
1927        let mut with_pdf = base.clone();
1928        let mut plain = base;
1929
1930        let _ = with_pdf.pdf_next();
1931        with_pdf.update(observed);
1932        plain.update(observed);
1933
1934        for sym in 0u8..=255u8 {
1935            let lp_with_pdf = with_pdf.log_prob_symbol_bruteforce(sym);
1936            let lp_plain = plain.log_prob_symbol_bruteforce(sym);
1937            let diff = (lp_with_pdf - lp_plain).abs();
1938            assert!(
1939                diff < 1e-12,
1940                "symbol={sym} with_pdf={lp_with_pdf} plain={lp_plain} diff={diff}"
1941            );
1942        }
1943    }
1944
    #[test]
    fn ctw_pdf_then_update_matches_plain_update() {
        // Querying the PDF first must not perturb a plain CTW update.
        assert_pdf_then_update_matches_plain_update(CtwPredictor::new_ctw(7));
    }
1949
    #[test]
    fn fac_pdf_then_update_matches_plain_update() {
        // Same query-then-update parity for the FAC variant.
        assert_pdf_then_update_matches_plain_update(CtwPredictor::new_fac(7, 8));
    }
1954
1955    #[test]
1956    fn roundtrip_rate_ac_ctw() {
1957        let data = b"ctw backend roundtrip payload";
1958        let backend = RateBackend::Ctw { depth: 8 };
1959        let enc =
1960            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
1961        let dec =
1962            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
1963        assert_eq!(dec, data);
1964    }
1965
1966    #[test]
1967    fn roundtrip_rate_ac_match_family_and_ppmd() {
1968        let data = b"repeat repeat repeat sparse sparse repeat payload";
1969        for backend in [
1970            RateBackend::Match {
1971                hash_bits: 20,
1972                min_len: 4,
1973                max_len: 255,
1974                base_mix: 0.02,
1975                confidence_scale: 1.0,
1976            },
1977            RateBackend::SparseMatch {
1978                hash_bits: 19,
1979                min_len: 3,
1980                max_len: 64,
1981                gap_min: 1,
1982                gap_max: 2,
1983                base_mix: 0.05,
1984                confidence_scale: 1.0,
1985            },
1986            RateBackend::Ppmd {
1987                order: 8,
1988                memory_mb: 8,
1989            },
1990        ] {
1991            let enc = compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed)
1992                .unwrap();
1993            let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
1994                .unwrap();
1995            assert_eq!(dec, data);
1996        }
1997    }
1998
1999    #[test]
2000    fn roundtrip_rate_ac_ppmd_high_order_text_payload() {
2001        let seed = include_bytes!("../../README.md");
2002        let mut data = Vec::with_capacity(4096);
2003        while data.len() < 4096 {
2004            data.extend_from_slice(seed);
2005        }
2006        data.truncate(4096);
2007
2008        let backend = RateBackend::Ppmd {
2009            order: 12,
2010            memory_mb: 256,
2011        };
2012        let enc = compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed)
2013            .expect("ppmd high-order compression");
2014        let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
2015            .expect("ppmd high-order decompression");
2016        assert_eq!(dec, data);
2017    }
2018
2019    #[test]
2020    fn roundtrip_rate_ac_calibrated_backend() {
2021        let data = b"calibration wrapper payload calibration wrapper payload";
2022        let backend = RateBackend::Calibrated {
2023            spec: Arc::new(crate::CalibratedSpec {
2024                base: RateBackend::Ctw { depth: 8 },
2025                context: crate::CalibrationContextKind::Text,
2026                bins: 33,
2027                learning_rate: 0.02,
2028                bias_clip: 4.0,
2029            }),
2030        };
2031        let enc =
2032            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2033        let dec =
2034            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2035        assert_eq!(dec, data);
2036    }
2037
2038    #[test]
2039    fn roundtrip_rate_ac_single_expert_ctw_neural_mixture() {
2040        let data = b"single expert neural ctw fast path payload";
2041        let spec = MixtureSpec::new(
2042            MixtureKind::Neural,
2043            vec![crate::MixtureExpertSpec {
2044                name: Some("ctw".to_string()),
2045                log_prior: 0.0,
2046                max_order: -1,
2047                backend: RateBackend::Ctw { depth: 8 },
2048            }],
2049        )
2050        .with_alpha(0.03);
2051        let backend = RateBackend::Mixture {
2052            spec: Arc::new(spec),
2053        };
2054        let enc =
2055            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2056        let dec =
2057            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2058        assert_eq!(dec, data);
2059    }
2060
2061    #[test]
2062    fn roundtrip_rate_ac_single_expert_ctw_bayes_mixture() {
2063        let data = b"single expert bayes ctw fast path payload";
2064        let spec = MixtureSpec::new(
2065            MixtureKind::Bayes,
2066            vec![crate::MixtureExpertSpec {
2067                name: Some("ctw".to_string()),
2068                log_prior: 0.0,
2069                max_order: -1,
2070                backend: RateBackend::Ctw { depth: 8 },
2071            }],
2072        )
2073        .with_alpha(0.03);
2074        let backend = RateBackend::Mixture {
2075            spec: Arc::new(spec),
2076        };
2077        let enc =
2078            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2079        let dec =
2080            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2081        assert_eq!(dec, data);
2082    }
2083
2084    #[test]
2085    fn roundtrip_rate_rans_recursive_mixture() {
2086        let data = b"recursive mixture payload";
2087        let nested = MixtureSpec::new(
2088            MixtureKind::Bayes,
2089            vec![
2090                crate::MixtureExpertSpec {
2091                    name: Some("ctw".to_string()),
2092                    log_prior: 0.0,
2093                    max_order: -1,
2094                    backend: RateBackend::Ctw { depth: 6 },
2095                },
2096                crate::MixtureExpertSpec {
2097                    name: Some("fac".to_string()),
2098                    log_prior: 0.0,
2099                    max_order: -1,
2100                    backend: RateBackend::FacCtw {
2101                        base_depth: 6,
2102                        num_percept_bits: 8,
2103                        encoding_bits: 8,
2104                    },
2105                },
2106            ],
2107        );
2108        let root = MixtureSpec::new(
2109            MixtureKind::Switching,
2110            vec![
2111                crate::MixtureExpertSpec {
2112                    name: Some("nested".to_string()),
2113                    log_prior: 0.0,
2114                    max_order: -1,
2115                    backend: RateBackend::Mixture {
2116                        spec: Arc::new(nested),
2117                    },
2118                },
2119                crate::MixtureExpertSpec {
2120                    name: Some("zpaq".to_string()),
2121                    log_prior: 0.0,
2122                    max_order: -1,
2123                    backend: RateBackend::Zpaq {
2124                        method: "1".to_string(),
2125                    },
2126                },
2127            ],
2128        )
2129        .with_alpha(0.05);
2130
2131        let backend = RateBackend::Mixture {
2132            spec: Arc::new(root),
2133        };
2134        let enc =
2135            compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
2136        let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
2137            .unwrap();
2138        assert_eq!(dec, data);
2139    }
2140
2141    #[test]
2142    fn roundtrip_rate_ac_recursive_neural_mixture() {
2143        let data = b"neural recursive mixture payload for ac coder";
2144        let inner = MixtureSpec::new(
2145            MixtureKind::Bayes,
2146            vec![
2147                crate::MixtureExpertSpec {
2148                    name: Some("ctw".to_string()),
2149                    log_prior: 0.0,
2150                    max_order: -1,
2151                    backend: RateBackend::Ctw { depth: 6 },
2152                },
2153                crate::MixtureExpertSpec {
2154                    name: Some("fac".to_string()),
2155                    log_prior: 0.0,
2156                    max_order: -1,
2157                    backend: RateBackend::FacCtw {
2158                        base_depth: 6,
2159                        num_percept_bits: 8,
2160                        encoding_bits: 8,
2161                    },
2162                },
2163            ],
2164        );
2165        let root = MixtureSpec::new(
2166            MixtureKind::Neural,
2167            vec![
2168                crate::MixtureExpertSpec {
2169                    name: Some("nested".to_string()),
2170                    log_prior: 0.0,
2171                    max_order: -1,
2172                    backend: RateBackend::Mixture {
2173                        spec: Arc::new(inner),
2174                    },
2175                },
2176                crate::MixtureExpertSpec {
2177                    name: Some("zpaq".to_string()),
2178                    log_prior: 0.0,
2179                    max_order: -1,
2180                    backend: RateBackend::Zpaq {
2181                        method: "1".to_string(),
2182                    },
2183                },
2184            ],
2185        )
2186        .with_alpha(0.03);
2187
2188        let backend = RateBackend::Mixture {
2189            spec: Arc::new(root),
2190        };
2191        let enc =
2192            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2193        let dec =
2194            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2195        assert_eq!(dec, data);
2196    }
2197
2198    #[test]
2199    fn neural_runtime_and_compression_predictor_align() {
2200        let spec = MixtureSpec::new(
2201            MixtureKind::Neural,
2202            vec![
2203                crate::MixtureExpertSpec {
2204                    name: Some("ctw".to_string()),
2205                    log_prior: 0.0,
2206                    max_order: -1,
2207                    backend: RateBackend::Ctw { depth: 7 },
2208                },
2209                crate::MixtureExpertSpec {
2210                    name: Some("fac".to_string()),
2211                    log_prior: 0.0,
2212                    max_order: -1,
2213                    backend: RateBackend::FacCtw {
2214                        base_depth: 7,
2215                        num_percept_bits: 8,
2216                        encoding_bits: 8,
2217                    },
2218                },
2219            ],
2220        )
2221        .with_alpha(0.03);
2222
2223        let backend = RateBackend::Mixture {
2224            spec: Arc::new(spec.clone()),
2225        };
2226        let mut predictor = RatePdfPredictor::from_rate_backend(backend, -1).unwrap();
2227        let experts = spec.build_experts();
2228        let mut runtime = crate::mixture::build_mixture_runtime(&spec, &experts).unwrap();
2229
2230        let data = b"neural alignment check sequence";
2231        for &b in data {
2232            let pdf = predictor.pdf_next().unwrap();
2233            let p_comp = pdf[b as usize];
2234            let p_runtime = runtime.peek_log_prob(b).exp();
2235            assert!(
2236                (p_comp - p_runtime).abs() < 1e-8,
2237                "p_comp={p_comp} p_runtime={p_runtime} symbol={b}"
2238            );
2239            predictor.update(b).unwrap();
2240            runtime.step(b);
2241        }
2242    }
2243
2244    fn assert_cached_cdf_fast_bitwise_matches_pdf_rows(mut predictor: RatePdfPredictor) {
2245        let data = b"cached cdf parity check payload";
2246        for &symbol in data {
2247            let pdf = predictor.pdf_next().unwrap().to_vec();
2248            assert!(predictor.prepare_cached_cdf_fast_bitwise().unwrap());
2249
2250            let mut row = [0.0; 257];
2251            row[0] = 0.0;
2252            for i in 0..256 {
2253                row[i + 1] = row[i] + pdf[i].max(PDF_MIN);
2254            }
2255
2256            let mut stack = vec![(0usize, 256usize)];
2257            while let Some((lo, hi)) = stack.pop() {
2258                if hi - lo <= 1 {
2259                    continue;
2260                }
2261                let expected = cdf_bit_prob_one_msb(&row, lo, hi);
2262                let got = predictor
2263                    .cached_cdf_bit_prob_one_msb(lo, hi)
2264                    .expect("cached cdf branch probability");
2265                let diff = (expected - got).abs();
2266                assert!(
2267                    diff <= 1e-12,
2268                    "lo={lo} hi={hi} expected={expected} got={got} diff={diff}"
2269                );
2270                let mid = (lo + hi) >> 1;
2271                stack.push((lo, mid));
2272                stack.push((mid, hi));
2273            }
2274
2275            predictor.update(symbol).unwrap();
2276        }
2277    }
2278
    #[test]
    fn cached_cdf_fast_bitwise_matches_pdf_rows_for_specialized_predictors() {
        // Each predictor here has a specialized pdf/CDF path; all must agree
        // with the reference rows checked by the shared helper.
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(RateBackend::RosaPlus, -1).unwrap(),
        );
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::Ppmd {
                    order: 6,
                    memory_mb: 8,
                },
                -1,
            )
            .unwrap(),
        );
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::Match {
                    hash_bits: 20,
                    min_len: 4,
                    max_len: 255,
                    base_mix: 0.02,
                    confidence_scale: 1.0,
                },
                -1,
            )
            .unwrap(),
        );
        // Neural backends are configured with train=none and an infer-only
        // policy (see the method strings), and are gated behind their features.
        #[cfg(feature = "backend-rwkv")]
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::Rwkv7Method {
                    method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
                },
                -1,
            )
            .unwrap(),
        );
        #[cfg(feature = "backend-mamba")]
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::MambaMethod {
                    method: "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
                },
                -1,
            )
            .unwrap(),
        );
    }
2328
2329    #[test]
2330    fn raw_size_not_larger_than_framed_size() {
2331        let data = b"raw/framed size check payload";
2332        let backend = RateBackend::RosaPlus;
2333        let raw = compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Raw).unwrap();
2334        let framed =
2335            compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Framed).unwrap();
2336        assert!(framed >= raw);
2337    }
2338
2339    #[cfg(feature = "backend-rwkv")]
2340    #[test]
2341    fn roundtrip_rate_rwkv_method_cfg() {
2342        let data = b"rwkv cfg method backend";
2343        let backend = RateBackend::Rwkv7Method {
2344            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2345        };
2346        let enc =
2347            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2348        let dec =
2349            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2350        assert_eq!(dec, data);
2351    }
2352
2353    #[cfg(feature = "backend-rwkv")]
2354    #[test]
2355    fn rwkv_rate_predictor_preserves_backend_pdf_exactly() {
2356        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
2357        let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
2358        let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
2359        let mut direct = vec![0.0; backend.vocab_size()];
2360
2361        let predicted = predictor.pdf_next().to_vec();
2362        backend.forward_to_pdf(0, &mut direct);
2363        assert_pdf_close(&predicted, &direct, 1e-18);
2364
2365        predictor.update(b'x').expect("predictor update");
2366        backend
2367            .online_update_from_pdf(b'x', &direct)
2368            .expect("backend update");
2369        backend.forward_to_pdf(u32::from(b'x'), &mut direct);
2370        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
2371    }
2372
2373    #[cfg(feature = "backend-rwkv")]
2374    #[test]
2375    fn rwkv_rate_predictor_matches_backend_after_partial_tbptt_stream() {
2376        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=29,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
2377        let data = b"abcdefghij";
2378        let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
2379        let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
2380        let mut direct = vec![0.0; backend.vocab_size()];
2381
2382        predictor
2383            .begin_stream(data.len())
2384            .expect("begin predictor stream");
2385        backend
2386            .begin_online_policy_stream(Some(data.len() as u64))
2387            .expect("begin backend stream");
2388        backend.reset_and_prime();
2389
2390        for &byte in data {
2391            let predicted = predictor.pdf_next().to_vec();
2392            backend.copy_current_pdf_to(&mut direct);
2393            assert_pdf_close(&predicted, &direct, 1e-18);
2394
2395            predictor.update(byte).expect("predictor update");
2396            backend
2397                .observe_symbol_from_current_pdf(byte)
2398                .expect("backend update");
2399        }
2400
2401        predictor.finish_stream().expect("finish predictor stream");
2402        backend
2403            .finish_online_policy_stream()
2404            .expect("finish backend stream");
2405        backend.copy_current_pdf_to(&mut direct);
2406        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
2407    }
2408
2409    #[cfg(feature = "backend-rwkv")]
2410    #[test]
2411    fn roundtrip_rate_rwkv_two_json_method_2m() {
2412        let two_json: serde_json::Value =
2413            serde_json::from_str(include_str!("../../examples/two.json")).unwrap();
2414        let method = two_json["experts"]
2415            .as_array()
2416            .unwrap()
2417            .iter()
2418            .find(|expert| expert["name"].as_str() == Some("rwkv"))
2419            .and_then(|expert| expert["method"].as_str())
2420            .unwrap()
2421            .to_string();
2422
2423        let backend = RateBackend::Rwkv7Method { method };
2424        let seed = include_bytes!("../../README.md");
2425        let target_len = 2_097_152usize;
2426        let mut data = Vec::with_capacity(target_len);
2427        while data.len() < target_len {
2428            let remaining = target_len - data.len();
2429            data.extend_from_slice(&seed[..seed.len().min(remaining)]);
2430        }
2431
2432        let enc =
2433            compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2434        let dec =
2435            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2436        assert_eq!(dec, data);
2437    }
2438
2439    #[cfg(feature = "backend-mamba")]
2440    #[test]
2441    fn mamba_rate_predictor_preserves_backend_pdf_exactly() {
2442        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
2443        let mut predictor = MambaPredictor::from_method(method).expect("mamba predictor");
2444        let mut backend = mambazip::Compressor::new_from_method(method).expect("mamba backend");
2445        let mut direct = vec![0.0; backend.vocab_size()];
2446
2447        let predicted = predictor.pdf_next().to_vec();
2448        backend.forward_to_pdf(0, &mut direct);
2449        assert_pdf_close(&predicted, &direct, 1e-18);
2450
2451        predictor.update(b'x').expect("predictor update");
2452        backend
2453            .online_update_from_pdf(b'x', &direct)
2454            .expect("backend update");
2455        backend.forward_to_pdf(u32::from(b'x'), &mut direct);
2456        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
2457    }
2458
2459    #[test]
2460    fn roundtrip_rate_ac_particle() {
2461        let spec = crate::ParticleSpec {
2462            num_particles: 4,
2463            num_cells: 4,
2464            cell_dim: 8,
2465            num_rules: 2,
2466            selector_hidden: 16,
2467            rule_hidden: 16,
2468            context_window: 8,
2469            unroll_steps: 1,
2470            ..crate::ParticleSpec::default()
2471        };
2472        let data = b"particle ac roundtrip payload";
2473        let backend = RateBackend::Particle {
2474            spec: Arc::new(spec),
2475        };
2476        let enc =
2477            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2478        let dec =
2479            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2480        assert_eq!(dec, data);
2481    }
2482
2483    #[test]
2484    fn roundtrip_rate_rans_particle() {
2485        let spec = crate::ParticleSpec {
2486            num_particles: 4,
2487            num_cells: 4,
2488            cell_dim: 8,
2489            num_rules: 2,
2490            selector_hidden: 16,
2491            rule_hidden: 16,
2492            context_window: 8,
2493            unroll_steps: 1,
2494            ..crate::ParticleSpec::default()
2495        };
2496        let data = b"particle rans roundtrip payload";
2497        let backend = RateBackend::Particle {
2498            spec: Arc::new(spec),
2499        };
2500        let enc =
2501            compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
2502        let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
2503            .unwrap();
2504        assert_eq!(dec, data);
2505    }
2506
2507    #[test]
2508    fn mixture_with_particle_expert_roundtrip() {
2509        let particle_spec = crate::ParticleSpec {
2510            num_particles: 4,
2511            num_cells: 4,
2512            cell_dim: 8,
2513            num_rules: 2,
2514            selector_hidden: 16,
2515            rule_hidden: 16,
2516            context_window: 8,
2517            unroll_steps: 1,
2518            ..crate::ParticleSpec::default()
2519        };
2520        let spec = MixtureSpec::new(
2521            MixtureKind::Bayes,
2522            vec![
2523                crate::MixtureExpertSpec {
2524                    name: Some("particle".to_string()),
2525                    log_prior: 0.0,
2526                    max_order: -1,
2527                    backend: RateBackend::Particle {
2528                        spec: Arc::new(particle_spec),
2529                    },
2530                },
2531                crate::MixtureExpertSpec {
2532                    name: Some("ctw".to_string()),
2533                    log_prior: 0.0,
2534                    max_order: -1,
2535                    backend: RateBackend::Ctw { depth: 6 },
2536                },
2537            ],
2538        );
2539        let backend = RateBackend::Mixture {
2540            spec: Arc::new(spec),
2541        };
2542        let data = b"mixture with particle expert roundtrip";
2543        let enc =
2544            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2545        let dec =
2546            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2547        assert_eq!(dec, data);
2548    }
2549}