// infotheory/compression/mod.rs
//! Rate-coded compression helpers (AC/rANS) with optional framing.
//!
//! The functions in this module implement lossless byte compression by combining:
//! - a predictive rate model (`RateBackend`) that emits per-symbol PDFs,
//! - an entropy coder (`AC` or `rANS`),
//! - optional framing metadata for robust decompression.

8use anyhow::{Result, bail};
9
10use crate::backends::calibration::CalibratorCore;
11use crate::backends::match_model::MatchModel;
12use crate::backends::ppmd::PpmdModel;
13use crate::backends::sequitur::SequiturModel;
14use crate::backends::sparse_match::SparseMatchModel;
15use crate::backends::text_context::TextContextAnalyzer;
16use crate::coders::{
17    ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
18    CDF_TOTAL, Cdf, CoderType, crc32, quantize_pdf_to_rans_cdf_with_buffer,
19};
20use crate::ctw::FacContextTree;
21#[cfg(feature = "backend-mamba")]
22use crate::mambazip;
23use crate::mixture::{
24    DEFAULT_MIN_PROB, convex_step_size_for_update, project_simplex_with_scratch,
25    switching_alpha_for_update,
26};
27use crate::neural_mix::NeuralMixCore;
28use crate::rosaplus::RosaPlus;
29#[cfg(feature = "backend-rwkv")]
30use crate::rwkvzip;
31use crate::zpaq_rate::ZpaqRateModel;
32use crate::{CalibratedSpec, MixtureKind, MixtureScheduleMode, MixtureSpec, RateBackend};
33use rayon::{ThreadPool, prelude::*};
34
35const FRAMED_MAGIC: u32 = 0x4354_4946; // "FITC"
36const FRAMED_VERSION: u8 = 1;
37const PDF_MIN: f64 = DEFAULT_MIN_PROB;
38const DIAGNOSTIC_PARALLEL_THRESHOLD: usize = 4;
39
40#[inline]
41fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
42    CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
43}
44
/// Wire format mode for rate-coded payloads.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FramingMode {
    /// Emit only coder payload bytes (no integrity/length header).
    Raw,
    /// Emit framed payload with magic/version/length/checksum header.
    #[default]
    Framed,
}
54
#[derive(Clone, Copy, Debug)]
/// Fixed-size header prepended to framed payloads (`FramedHeader::SIZE` bytes).
struct FramedHeader {
    // Must equal `FRAMED_MAGIC`.
    magic: u32,
    // Must equal `FRAMED_VERSION`.
    version: u8,
    // Entropy-coder id: 0 = arithmetic coder, 1 = rANS (see `coder_type`).
    coder: u8,
    // Uncompressed payload length in bytes.
    original_len: u64,
    // CRC32 checksum carried for integrity verification (computed by the caller).
    crc32: u32,
}
63
64impl FramedHeader {
65    const SIZE: usize = 4 + 1 + 1 + 8 + 4;
66
67    fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
68        Self {
69            magic: FRAMED_MAGIC,
70            version: FRAMED_VERSION,
71            coder: match coder {
72                CoderType::AC => 0,
73                CoderType::RANS => 1,
74            },
75            original_len,
76            crc32,
77        }
78    }
79
80    fn write(&self, out: &mut Vec<u8>) {
81        out.extend_from_slice(&self.magic.to_le_bytes());
82        out.push(self.version);
83        out.push(self.coder);
84        out.extend_from_slice(&self.original_len.to_le_bytes());
85        out.extend_from_slice(&self.crc32.to_le_bytes());
86    }
87
88    fn read(input: &[u8]) -> Result<Self> {
89        if input.len() < Self::SIZE {
90            bail!("framed payload too short");
91        }
92        let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
93        if magic != FRAMED_MAGIC {
94            bail!("invalid framed magic: expected 0x{FRAMED_MAGIC:08X}, got 0x{magic:08X}");
95        }
96        let version = input[4];
97        if version != FRAMED_VERSION {
98            bail!("unsupported framed version: {version}");
99        }
100        let coder = input[5];
101        let original_len = u64::from_le_bytes([
102            input[6], input[7], input[8], input[9], input[10], input[11], input[12], input[13],
103        ]);
104        let crc32 = u32::from_le_bytes([input[14], input[15], input[16], input[17]]);
105        Ok(Self {
106            magic,
107            version,
108            coder,
109            original_len,
110            crc32,
111        })
112    }
113
114    fn coder_type(&self) -> CoderType {
115        match self.coder {
116            0 => CoderType::AC,
117            _ => CoderType::RANS,
118        }
119    }
120}
121
#[derive(Clone)]
/// Byte-level predictor backed by a (fractional) context-tree weighting model.
///
/// Symbols are fed to `tree` bit by bit; `pdf`/`pattern_logps` cache the
/// per-byte distribution derived from the tree between updates.
struct CtwPredictor {
    // Underlying bitwise CTW/FAC-CTW model.
    tree: FacContextTree,
    // Number of coded bits per symbol (clamped to 1..=8 at use sites).
    bits_per_symbol: usize,
    // Bit order within a byte: true = MSB-first (classic CTW), false = LSB-first (FAC).
    msb_first: bool,
    // Cached 256-entry byte distribution; meaningful only when `valid` is true.
    pdf: Vec<f64>,
    // Scratch: log-probability per bit pattern, filled by `fill_pattern_log_probs`.
    pattern_logps: Vec<f64>,
    // Whether `pdf` reflects the current tree state.
    valid: bool,
}
131
impl CtwPredictor {
    /// Classic byte-oriented CTW: 8 bits per symbol, MSB-first bit order.
    fn new_ctw(depth: usize) -> Self {
        Self {
            tree: FacContextTree::new(depth, 8),
            bits_per_symbol: 8,
            msb_first: true,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// FAC variant: `bits_per_symbol` coded bits per symbol, LSB-first bit order.
    fn new_fac(base_depth: usize, bits_per_symbol: usize) -> Self {
        Self {
            tree: FacContextTree::new(base_depth, bits_per_symbol),
            bits_per_symbol,
            msb_first: false,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// Fills `pattern_logps[p]` with the conditional log-probability of each
    /// `bits`-bit pattern `p` under the current tree state, returning the
    /// number of patterns written.
    ///
    /// The DFS speculatively `update`s the tree one bit at a time and
    /// `revert`s on the way back up, so the tree state is unchanged on
    /// return. The update -> recurse -> revert order is load-bearing.
    fn fill_pattern_log_probs(&mut self) -> usize {
        fn rec(
            tree: &mut FacContextTree,
            bits: usize,
            msb_first: bool,
            depth: usize,
            pattern: usize,
            log_before: f64,
            out: &mut [f64],
        ) {
            if depth == bits {
                // Leaf: log P(pattern | history) = log P(block after) - log P(block before).
                out[pattern] = tree.get_log_block_probability() - log_before;
                return;
            }
            for bit in [false, true] {
                tree.update(bit, depth);
                // Pattern index accumulates bits in the model's bit order.
                let next_pattern = if msb_first {
                    (pattern << 1) | (bit as usize)
                } else {
                    pattern | ((bit as usize) << depth)
                };
                rec(
                    tree,
                    bits,
                    msb_first,
                    depth + 1,
                    next_pattern,
                    log_before,
                    out,
                );
                tree.revert(depth);
            }
        }

        let bits = self.bits_per_symbol.clamp(1, 8);
        let patterns = 1usize << bits;
        let log_before = self.tree.get_log_block_probability();
        self.pattern_logps[..patterns].fill(f64::NEG_INFINITY);
        rec(
            &mut self.tree,
            bits,
            self.msb_first,
            0,
            0,
            log_before,
            &mut self.pattern_logps[..patterns],
        );
        patterns
    }

    /// Test-only reference: log-prob of one symbol computed by direct
    /// update/revert, used to cross-check `fill_pattern_log_probs`.
    #[cfg(test)]
    fn log_prob_symbol_bruteforce(&mut self, symbol: u8) -> f64 {
        let bits = self.bits_per_symbol.clamp(1, 8);
        let before = self.tree.get_log_block_probability();
        if self.msb_first {
            for bit_idx in 0..bits {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            // Revert in reverse order to restore the tree exactly.
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        } else {
            for bit_idx in 0..bits {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        }
    }

    /// Floors every entry at `PDF_MIN` (non-finite entries become the floor)
    /// and renormalizes to sum 1; falls back to the uniform distribution when
    /// the sum is non-positive or non-finite.
    fn normalize_pdf(pdf: &mut [f64]) {
        let mut sum = 0.0f64;
        for p in pdf.iter_mut() {
            let v = if p.is_finite() { *p } else { 0.0 };
            *p = v.max(PDF_MIN);
            sum += *p;
        }
        if sum <= 0.0 || !sum.is_finite() {
            let u = 1.0 / (pdf.len() as f64);
            for p in pdf.iter_mut() {
                *p = u;
            }
            return;
        }
        let inv = 1.0 / sum;
        for p in pdf.iter_mut() {
            *p *= inv;
        }
    }

    /// Returns the cached 256-entry byte distribution, rebuilding it from the
    /// tree when a preceding `update` invalidated it.
    fn pdf_next(&mut self) -> &[f64] {
        if !self.valid {
            let bits = self.bits_per_symbol.clamp(1, 8);
            let patterns = self.fill_pattern_log_probs();
            if bits == 8 {
                for sym in 0..256usize {
                    self.pdf[sym] = self.pattern_logps[sym].exp();
                }
            } else {
                // Fewer than 8 coded bits: several byte values share one
                // pattern; spread each pattern's mass uniformly over them.
                let aliases = 1usize << (8 - bits);
                for byte in 0..256usize {
                    let pat = if self.msb_first {
                        byte >> (8 - bits)
                    } else {
                        byte & (patterns - 1)
                    };
                    self.pdf[byte] = self.pattern_logps[pat].exp() / (aliases as f64);
                }
            }
            Self::normalize_pdf(&mut self.pdf);
            self.valid = true;
        }
        &self.pdf
    }

    /// Feeds one observed symbol into the tree (bit by bit, in the model's
    /// bit order) and invalidates the cached pdf.
    fn update(&mut self, symbol: u8) {
        if self.msb_first {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        } else {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        }
        self.valid = false;
    }

    /// The fast bitwise arithmetic-coding path is only valid for the classic
    /// 8-bit MSB-first configuration.
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        self.bits_per_symbol == 8 && self.msb_first
    }

    /// P(bit = 1) at `bit_idx`, clamped away from 0/1 for coder stability.
    #[inline]
    fn bit_prob_one_msb(&mut self, bit_idx: usize) -> f64 {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.predict_one(bit_idx).clamp(PDF_MIN, 1.0 - PDF_MIN)
    }

    /// Commits one observed bit along the fast path and invalidates the cache.
    #[inline]
    fn update_bit_msb(&mut self, bit_idx: usize, bit: bool) {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.update_predicted(bit, bit_idx);
        self.valid = false;
    }
}
310
#[derive(Clone)]
/// Predictor wrapping the `RosaPlus` model with cached PDF/CDF rows.
struct RosaPredictor {
    model: RosaPlus,
    // Cached 256-entry byte distribution; meaningful when `valid` is set.
    pdf: Vec<f64>,
    // Cached cumulative row over `pdf`; meaningful when `cdf_valid` is set.
    cdf: [f64; 257],
    valid: bool,
    cdf_valid: bool,
}
319
320impl RosaPredictor {
321    fn new(max_order: i64) -> Self {
322        let mut model = RosaPlus::new(max_order, false, 0, 42);
323        model.build_lm_full_bytes_no_finalize_endpos();
324        Self {
325            model,
326            pdf: vec![0.0; 256],
327            cdf: uniform_cdf_row(),
328            valid: false,
329            cdf_valid: false,
330        }
331    }
332
333    fn pdf_next(&mut self) -> &[f64] {
334        self.ensure_pdf(false);
335        &self.pdf
336    }
337
338    fn cdf_next(&mut self) -> &[f64; 257] {
339        self.ensure_pdf(true);
340        &self.cdf
341    }
342
343    fn ensure_pdf(&mut self, want_cdf: bool) {
344        if self.valid {
345            if want_cdf && !self.cdf_valid {
346                build_cdf_row_from_pdf_slice(&self.pdf, &mut self.cdf);
347                self.cdf_valid = true;
348            }
349            return;
350        }
351        self.model.fill_probs_for_last_bytes(&mut self.pdf);
352        normalize_pdf_vec_and_maybe_build_cdf(
353            &mut self.pdf,
354            if want_cdf { Some(&mut self.cdf) } else { None },
355        );
356        self.valid = true;
357        self.cdf_valid = want_cdf;
358    }
359
360    fn update(&mut self, symbol: u8) {
361        self.model.train_byte(symbol);
362        self.valid = false;
363        self.cdf_valid = false;
364    }
365
366    fn begin_stream(&mut self, total_len: usize) {
367        self.model.reserve_for_stream(total_len);
368    }
369}
370
#[cfg(feature = "backend-mamba")]
#[derive(Clone)]
/// Predictor backed by the Mamba sequence model (feature-gated).
struct MambaPredictor {
    compressor: mambazip::Compressor,
    // True once the model has consumed the initial priming token (token 0).
    primed: bool,
    // Model-vocab-sized distribution produced by `forward_to_pdf`.
    pdf: Vec<f64>,
    // Cached cumulative row over the first 256 pdf entries.
    cdf: [f64; 257],
    valid: bool,
    cdf_valid: bool,
}
381
#[cfg(feature = "backend-rwkv")]
#[derive(Clone)]
/// Predictor backed by the RWKV sequence model (feature-gated). The pdf lives
/// inside the compressor (`pdf_buffer`); only the cdf is cached here.
struct RwkvPredictor {
    compressor: rwkvzip::Compressor,
    // True once the model has been reset and primed.
    primed: bool,
    // Cached cumulative row over the first 256 entries of the compressor's pdf.
    cdf: [f64; 257],
    cdf_valid: bool,
}
390
#[derive(Clone)]
/// Predictor that scores each candidate byte with a freshly built `ZpaqRateModel`.
struct ZpaqPredictor {
    // ZPAQ method string used to construct each scratch rate model.
    method: String,
    // All bytes observed so far; replayed into each scratch model.
    history: Vec<u8>,
    // Cached 256-entry distribution; meaningful when `valid` is set.
    pdf: Vec<f64>,
    valid: bool,
}
398
399impl ZpaqPredictor {
400    fn new(method: String) -> Self {
401        Self {
402            method,
403            history: Vec::new(),
404            pdf: vec![0.0; 256],
405            valid: false,
406        }
407    }
408
409    fn pdf_next(&mut self) -> &[f64] {
410        if !self.valid {
411            for sym in 0..256usize {
412                let mut model = ZpaqRateModel::new(self.method.clone(), PDF_MIN);
413                if !self.history.is_empty() {
414                    let _ = model.update_and_score(&self.history);
415                }
416                let logp = model.log_prob(sym as u8);
417                self.pdf[sym] = logp.exp().max(PDF_MIN);
418            }
419            normalize_pdf(&mut self.pdf);
420            self.valid = true;
421        }
422        &self.pdf
423    }
424
425    fn update(&mut self, symbol: u8) {
426        self.history.push(symbol);
427        self.valid = false;
428    }
429}
430
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Wraps an already-loaded shared model in fresh compressor state.
    fn from_model(model: std::sync::Arc<mambazip::Model>) -> Self {
        let compressor = mambazip::Compressor::new_from_model(model);
        let vocab = compressor.vocab_size();
        Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        }
    }

    /// Builds a compressor (and model) from a method string.
    ///
    /// # Errors
    /// Fails when the method string cannot be resolved to a model.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = mambazip::Compressor::new_from_method(method)?;
        let vocab = compressor.vocab_size();
        Ok(Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        })
    }

    /// Ensures `pdf` (and, when `want_cdf`, `cdf`) reflect the model's
    /// current prediction. On the very first call the model is primed with
    /// token 0. The CDF rebuild is a single shared tail step instead of being
    /// duplicated per branch.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.valid {
            if !self.primed {
                self.compressor.forward_to_pdf(0, &mut self.pdf);
                self.primed = true;
            }
            self.valid = true;
            self.cdf_valid = false;
        }
        if want_cdf && !self.cdf_valid {
            debug_assert!(self.pdf.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.pdf[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Current model pdf over the vocabulary.
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.pdf
    }

    /// Current cumulative row over the first 256 pdf entries.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Trains online on `symbol` and advances the model one step, leaving
    /// `pdf` holding the next-symbol prediction.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.online_update_from_pdf(symbol, &self.pdf)?;
        self.compressor.forward_to_pdf(symbol as u32, &mut self.pdf);
        self.valid = true;
        self.cdf_valid = false;
        Ok(())
    }

    /// Announces the total stream length to the online policy.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }
}
513
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Wraps an already-loaded shared model in fresh compressor state.
    fn from_model(model: std::sync::Arc<rwkvzip::Model>) -> Self {
        Self {
            compressor: rwkvzip::Compressor::new_from_model(model),
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        }
    }

    /// Builds a compressor (and model) from a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = rwkvzip::Compressor::new_from_method(method)?;
        Ok(Self {
            compressor,
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        })
    }

    /// Primes the model on first use and, when `want_cdf`, refreshes the
    /// cached cumulative row from the compressor's pdf buffer.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.primed {
            self.compressor.reset_and_prime();
            self.primed = true;
            self.cdf_valid = false;
        }
        let refresh_cdf = want_cdf && !self.cdf_valid;
        if refresh_cdf {
            debug_assert!(self.compressor.pdf_buffer.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.compressor.pdf_buffer[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Current model pdf (owned by the compressor).
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.compressor.pdf_buffer
    }

    /// Current cumulative row over the first 256 pdf entries.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Feeds one observed symbol to the model and invalidates the cdf cache.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.observe_symbol_from_current_pdf(symbol)?;
        self.cdf_valid = false;
        Ok(())
    }

    /// Announces the total stream length to the online policy.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }

    /// Finalizes the online policy stream.
    fn finish_stream(&mut self) -> Result<()> {
        self.compressor.finish_online_policy_stream()
    }
}
575
#[derive(Clone)]
/// One expert inside a `MixturePredictor`: its predictor plus mixing state.
struct MixExpert {
    predictor: Box<RatePdfPredictor>,
    // Current log-domain mixture weight.
    log_weight: f64,
    // Prior log-weight the expert started from.
    log_prior: f64,
    // Running sum of -ln p(symbol) under this expert (code length in nats).
    cum_log_loss: f64,
}
583
#[derive(Clone, Copy, Debug, Default)]
/// One node's values in an AC log-loss diagnostic snapshot.
pub(crate) struct AcLogLossNodeValue {
    // Probability this node assigned to the observed symbol.
    pub(crate) prob: f64,
    // Weight of this node within its immediate parent mixture.
    pub(crate) local_weight: f64,
    // Product of local weights along the path from the root (root prefix = 1.0).
    pub(crate) effective_weight: f64,
}
590
#[derive(Clone, Debug, Default)]
/// Snapshot of a mixture subtree: its mixed probability for the observed
/// symbol plus one row per node (subtree root first, then descendants).
pub(crate) struct AcLogLossSubtreeSnapshot {
    pub(crate) prob: f64,
    pub(crate) rows: Vec<AcLogLossNodeValue>,
}
596
#[derive(Clone, Copy, Debug, Default)]
/// Root-level diagnostic summary for one coded symbol.
pub(crate) struct AcLogLossRootSnapshot {
    // Mixture probability assigned to the observed symbol (floored at PDF_MIN).
    pub(crate) mix_prob: f64,
    // Shannon entropy (bits) of the root weight distribution.
    pub(crate) root_weight_entropy_bits: f64,
    // Index/weight of the largest root weight, if any experts exist.
    pub(crate) root_top1_child_index: Option<usize>,
    pub(crate) root_top1_weight: f64,
    // Index/weight of the second-largest root weight, if present.
    pub(crate) root_top2_child_index: Option<usize>,
    pub(crate) root_top2_weight: f64,
}
606
#[derive(Clone)]
/// Adaptive mixture over several `RatePdfPredictor` experts.
struct MixturePredictor {
    // Mixing rule (Bayes, switching, convex, MDL, neural, ...).
    kind: MixtureKind,
    // Schedule controlling how update step sizes evolve over the stream.
    schedule: MixtureScheduleMode,
    // Base learning/switching rate taken from the spec.
    alpha: f64,
    // Per-step decay for FadingBayes, clamped to [0, 1] (1.0 = no fading).
    decay: f64,
    experts: Vec<MixExpert>,
    // Normalized prior weights derived from the experts' log-priors.
    prior_weights: Vec<f64>,
    // Context-conditioned neural mixer (used when `kind` is Neural).
    neural: NeuralMixCore,
    // Text-context feature extractor feeding `neural`.
    analyzer: TextContextAnalyzer,
    // Per-expert buffers for the neural update path.
    // NOTE(review): exact roles of the neural_* buffers depend on code outside
    // this view — confirm against the neural update implementation.
    neural_logps: Vec<f64>,
    neural_bit_modes: Vec<u8>,
    neural_lo: Vec<usize>,
    neural_hi: Vec<usize>,
    neural_pdf_cdf_rows: Vec<Vec<f64>>,
    // General-purpose scratch vectors reused across updates.
    scratch: Vec<f64>,
    scratch2: Vec<f64>,
    projection_scratch: Vec<f64>,
    // Cached mixed 256-entry pdf; meaningful only when `valid` is true.
    pdf: Vec<f64>,
    valid: bool,
    // Counters for schedule-dependent step sizes.
    switch_updates: u64,
    convex_updates: u64,
}
630
631impl MixturePredictor {
    /// Builds a mixture from `spec`: instantiates each expert from its
    /// backend, normalizes the prior log-weights, and sets up the neural mixer.
    ///
    /// # Errors
    /// Fails when the spec is invalid or an expert backend cannot be built.
    fn new(spec: &MixtureSpec) -> Result<Self> {
        spec.validate().map_err(anyhow::Error::msg)?;
        let mut experts = Vec::with_capacity(spec.experts.len());
        for e in &spec.experts {
            experts.push(MixExpert {
                predictor: Box::new(RatePdfPredictor::from_rate_backend(
                    e.backend.clone(),
                    e.max_order,
                )?),
                log_weight: e.log_prior,
                log_prior: e.log_prior,
                cum_log_loss: 0.0,
            });
        }
        // Shift log-weights by their logsumexp so they form a normalized
        // log-distribution.
        let m = logsumexp_expert_weights(&experts);
        for e in &mut experts {
            e.log_weight -= m;
        }

        let mut prior_weights = vec![0.0; experts.len()];
        normalized_mix_expert_prior_weights(&experts, &mut prior_weights);
        // The neural mixer is seeded with priors clamped strictly inside (0, 1).
        let mut neural_prior_weights = prior_weights.clone();
        for weight in &mut neural_prior_weights {
            *weight = weight.clamp(PDF_MIN, 1.0 - PDF_MIN);
        }

        // NOTE(review): the x25 learning-rate scaling looks empirically tuned;
        // confirm against the NeuralMixCore training setup.
        let base_lr = spec.alpha.abs().clamp(1e-6, 1.0);
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        let mut neural = NeuralMixCore::new(
            experts.len(),
            &neural_prior_weights,
            effective_lr * 0.5,
            effective_lr,
            1e-5,
        );
        neural.set_context_state(analyzer.state());
        Ok(Self {
            kind: spec.kind,
            schedule: spec.schedule,
            alpha: spec.alpha,
            decay: spec.decay.unwrap_or(1.0).clamp(0.0, 1.0),
            experts,
            prior_weights,
            neural,
            analyzer,
            neural_logps: vec![0.0; spec.experts.len()],
            neural_bit_modes: vec![0; spec.experts.len()],
            neural_lo: vec![0; spec.experts.len()],
            neural_hi: vec![256; spec.experts.len()],
            neural_pdf_cdf_rows: vec![vec![0.0; 257]; spec.experts.len()],
            scratch: Vec::new(),
            scratch2: Vec::new(),
            projection_scratch: Vec::new(),
            pdf: vec![0.0; 256],
            valid: false,
            switch_updates: 0,
            convex_updates: 0,
        })
    }
692
693    fn best_expert_index(&self) -> Option<usize> {
694        let mut best_idx = None;
695        let mut best_loss = f64::INFINITY;
696        for (index, expert) in self.experts.iter().enumerate() {
697            if expert.cum_log_loss < best_loss {
698                best_loss = expert.cum_log_loss;
699                best_idx = Some(index);
700            }
701        }
702        best_idx
703    }
704
    /// Computes the current simplex weights over experts according to `kind`.
    ///
    /// Returns one weight per expert (empty for zero experts).
    fn predictive_weights(&mut self) -> Vec<f64> {
        if self.experts.is_empty() {
            return Vec::new();
        }

        match self.kind {
            MixtureKind::Neural => {
                // Single expert: nothing to mix, skip the neural evaluation.
                if self.experts.len() == 1 {
                    return vec![1.0];
                }
                self.neural.set_context_state(self.analyzer.state());
                self.neural.evaluate_expert_weights();
                let mut weights = self.neural.expert_weights().to_vec();
                normalize_simplex_weights(&mut weights);
                weights
            }
            MixtureKind::Mdl => {
                // Winner-take-all: all mass on the expert with the lowest
                // cumulative log loss so far.
                let mut weights = vec![0.0; self.experts.len()];
                if let Some(best_idx) = self.best_expert_index() {
                    weights[best_idx] = 1.0;
                }
                weights
            }
            MixtureKind::FadingBayes => {
                // Softmax over decayed log-weights, max-subtracted for
                // numerical stability.
                let max_log = self
                    .experts
                    .iter()
                    .map(|expert| self.decay * expert.log_weight)
                    .fold(f64::NEG_INFINITY, f64::max);
                let mut weights = self
                    .experts
                    .iter()
                    .map(|expert| {
                        if max_log.is_finite() {
                            (self.decay * expert.log_weight - max_log).exp()
                        } else {
                            0.0
                        }
                    })
                    .collect::<Vec<_>>();
                normalize_simplex_weights(&mut weights);
                weights
            }
            MixtureKind::Convex => {
                // NOTE(review): exponentiates raw log-weights without
                // max-subtraction — presumably the convex update keeps them
                // bounded; confirm against the update path.
                let mut weights = self
                    .experts
                    .iter()
                    .map(|expert| expert.log_weight.exp())
                    .collect::<Vec<_>>();
                normalize_simplex_weights(&mut weights);
                weights
            }
            MixtureKind::Bayes | MixtureKind::Switching => {
                // Posterior weights: softmax of log-weights, max-subtracted.
                let max_log = self
                    .experts
                    .iter()
                    .map(|expert| expert.log_weight)
                    .fold(f64::NEG_INFINITY, f64::max);
                let mut weights = self
                    .experts
                    .iter()
                    .map(|expert| {
                        if max_log.is_finite() {
                            (expert.log_weight - max_log).exp()
                        } else {
                            0.0
                        }
                    })
                    .collect::<Vec<_>>();
                normalize_simplex_weights(&mut weights);
                weights
            }
        }
    }
779
780    fn ensure_pdf(&mut self) -> Result<&[f64]> {
781        if self.valid {
782            return Ok(&self.pdf);
783        }
784        let weights = self.predictive_weights();
785        if weights.len() == 1 && matches!(self.kind, MixtureKind::Mdl | MixtureKind::Neural) {
786            self.pdf.fill(0.0);
787        } else {
788            self.pdf.fill(0.0);
789        }
790        for (index, expert) in self.experts.iter_mut().enumerate() {
791            let weight = weights.get(index).copied().unwrap_or(0.0);
792            if weight <= 0.0 {
793                continue;
794            }
795            let epdf = expert.predictor.pdf_next()?;
796            for (slot, &p) in self.pdf.iter_mut().zip(epdf.iter()) {
797                *slot += weight * p;
798            }
799        }
800
801        normalize_pdf(&mut self.pdf);
802        self.valid = true;
803        Ok(&self.pdf)
804    }
805
806    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
807        for expert in &mut self.experts {
808            match &mut *expert.predictor {
809                // Direct CTW benefits from pre-reserving, but inside mixtures that extra
810                // headroom can dominate peak RSS without a proportional runtime gain.
811                RatePdfPredictor::Ctw(_) | RatePdfPredictor::FacCtw(_) => {}
812                _ => expert.predictor.begin_stream(total_len)?,
813            }
814        }
815        Ok(())
816    }
817
818    fn diagnostic_collect_children(
819        &mut self,
820        symbol: u8,
821        weights: &[f64],
822        effective_prefix: f64,
823        pool: Option<&ThreadPool>,
824    ) -> Result<Vec<AcLogLossSubtreeSnapshot>> {
825        let use_parallel = pool.is_some() && self.experts.len() >= DIAGNOSTIC_PARALLEL_THRESHOLD;
826        if use_parallel {
827            let pool = pool.expect("checked is_some");
828            pool.install(|| {
829                self.experts
830                    .par_iter_mut()
831                    .enumerate()
832                    .map(|(index, expert)| {
833                        let local_weight = weights.get(index).copied().unwrap_or(0.0);
834                        let effective_weight = effective_prefix * local_weight;
835                        expert.predictor.diagnostic_snapshot_subtree(
836                            symbol,
837                            local_weight,
838                            effective_weight,
839                            None,
840                        )
841                    })
842                    .collect()
843            })
844        } else {
845            let mut children = Vec::with_capacity(self.experts.len());
846            for (index, expert) in self.experts.iter_mut().enumerate() {
847                let local_weight = weights.get(index).copied().unwrap_or(0.0);
848                let effective_weight = effective_prefix * local_weight;
849                children.push(expert.predictor.diagnostic_snapshot_subtree(
850                    symbol,
851                    local_weight,
852                    effective_weight,
853                    pool,
854                )?);
855            }
856            Ok(children)
857        }
858    }
859
860    fn diagnostic_subtree_snapshot(
861        &mut self,
862        symbol: u8,
863        local_weight: f64,
864        effective_weight: f64,
865        pool: Option<&ThreadPool>,
866    ) -> Result<AcLogLossSubtreeSnapshot> {
867        let weights = self.predictive_weights();
868        let children =
869            self.diagnostic_collect_children(symbol, &weights, effective_weight, pool)?;
870        let mix_prob = children
871            .iter()
872            .enumerate()
873            .map(|(index, child)| weights.get(index).copied().unwrap_or(0.0) * child.prob)
874            .sum::<f64>()
875            .max(PDF_MIN);
876        let total_rows = 1 + children.iter().map(|child| child.rows.len()).sum::<usize>();
877        let mut rows = Vec::with_capacity(total_rows);
878        rows.push(AcLogLossNodeValue {
879            prob: mix_prob,
880            local_weight,
881            effective_weight,
882        });
883        for child in children {
884            rows.extend(child.rows);
885        }
886        Ok(AcLogLossSubtreeSnapshot {
887            prob: mix_prob,
888            rows,
889        })
890    }
891
    /// Snapshot of the mixture root for AC log-loss diagnostics.
    ///
    /// Fills `out` (cleared first) with every expert subtree's rows, then
    /// reports the mixed probability of `symbol`, the entropy of the root
    /// weights, and the two largest root weights.
    fn diagnostic_root_snapshot(
        &mut self,
        symbol: u8,
        pool: Option<&ThreadPool>,
        out: &mut Vec<AcLogLossNodeValue>,
    ) -> Result<AcLogLossRootSnapshot> {
        let weights = self.predictive_weights();
        let children = self.diagnostic_collect_children(symbol, &weights, 1.0, pool)?;
        out.clear();
        out.reserve(children.iter().map(|child| child.rows.len()).sum::<usize>());
        for child in &children {
            out.extend_from_slice(&child.rows);
        }

        // Weight-averaged probability of the observed symbol, floored at PDF_MIN.
        let mix_prob = children
            .iter()
            .enumerate()
            .map(|(index, child)| weights.get(index).copied().unwrap_or(0.0) * child.prob)
            .sum::<f64>()
            .max(PDF_MIN);

        // Single-pass top-2 selection over the root weights (earlier index
        // wins ties because comparisons are strict).
        let mut top1 = None;
        let mut top2 = None;
        for (index, &weight) in weights.iter().enumerate() {
            match top1 {
                None => top1 = Some((index, weight)),
                Some((best_idx, best_weight)) if weight > best_weight => {
                    // New leader: previous leader becomes the runner-up.
                    top2 = Some((best_idx, best_weight));
                    top1 = Some((index, weight));
                }
                _ => match top2 {
                    None => top2 = Some((index, weight)),
                    Some((_, second_weight)) if weight > second_weight => {
                        top2 = Some((index, weight));
                    }
                    _ => {}
                },
            }
        }

        // Shannon entropy (bits) of the root weight distribution.
        let root_weight_entropy_bits = weights
            .iter()
            .copied()
            .filter(|weight| *weight > 0.0)
            .map(|weight| -weight * weight.log2())
            .sum::<f64>();

        Ok(AcLogLossRootSnapshot {
            mix_prob,
            root_weight_entropy_bits,
            root_top1_child_index: top1.map(|(index, _)| index),
            root_top1_weight: top1.map(|(_, weight)| weight).unwrap_or(0.0),
            root_top2_child_index: top2.map(|(index, _)| index),
            root_top2_weight: top2.map(|(_, weight)| weight).unwrap_or(0.0),
        })
    }
948
949    fn update(&mut self, symbol: u8) -> Result<()> {
950        let _ = self.ensure_pdf()?;
951
952        match self.kind {
953            MixtureKind::Bayes => {
954                let n = self.experts.len();
955                self.scratch.resize(n, 0.0);
956                self.scratch2.resize(n, 0.0);
957                for (i, e) in self.experts.iter_mut().enumerate() {
958                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
959                    let lp = p.ln();
960                    self.scratch[i] = lp;
961                    self.scratch2[i] = e.log_weight + lp;
962                }
963                let log_mix = logsumexp_slice(&self.scratch2[..n]);
964                for (i, e) in self.experts.iter_mut().enumerate() {
965                    e.log_weight = e.log_weight + self.scratch[i] - log_mix;
966                    e.cum_log_loss -= self.scratch[i];
967                    e.predictor.update(symbol)?;
968                }
969            }
970            MixtureKind::FadingBayes => {
971                let n = self.experts.len();
972                self.scratch.resize(n, 0.0);
973                self.scratch2.resize(n, 0.0);
974                for (i, e) in self.experts.iter_mut().enumerate() {
975                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
976                    let lp = p.ln();
977                    self.scratch[i] = lp;
978                    self.scratch2[i] = e.log_weight + lp;
979                }
980                for (i, e) in self.experts.iter_mut().enumerate() {
981                    self.scratch2[i] = self.decay * e.log_weight + self.scratch[i];
982                }
983                let log_mix = logsumexp_slice(&self.scratch2[..n]);
984                for (i, e) in self.experts.iter_mut().enumerate() {
985                    e.log_weight = self.decay * e.log_weight + self.scratch[i] - log_mix;
986                    e.cum_log_loss -= self.scratch[i];
987                    e.predictor.update(symbol)?;
988                }
989            }
990            MixtureKind::Switching => {
991                let n = self.experts.len();
992                self.scratch.resize(n, 0.0);
993                self.scratch2.resize(n, 0.0);
994                for (i, e) in self.experts.iter_mut().enumerate() {
995                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
996                    let lp = p.ln();
997                    self.scratch[i] = lp;
998                    self.scratch2[i] = e.log_weight + lp;
999                }
1000                let log_mix = logsumexp_slice(&self.scratch2[..n]);
1001                for (i, e) in self.experts.iter_mut().enumerate() {
1002                    self.scratch2[i] = (self.scratch2[i] - log_mix).exp();
1003                    e.cum_log_loss -= self.scratch[i];
1004                    e.predictor.update(symbol)?;
1005                }
1006                let alpha =
1007                    switching_alpha_for_update(self.schedule, self.alpha, self.switch_updates);
1008                self.switch_updates = self.switch_updates.saturating_add(1);
1009                apply_switching_weights(
1010                    &mut self.experts,
1011                    &self.prior_weights[..n],
1012                    alpha,
1013                    &mut self.scratch2[..n],
1014                    &mut self.scratch[..n],
1015                );
1016            }
1017            MixtureKind::Convex => {
1018                let n = self.experts.len();
1019                self.scratch.resize(n, 0.0);
1020                self.scratch2.resize(n, 0.0);
1021                for (i, e) in self.experts.iter_mut().enumerate() {
1022                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
1023                    let lp = p.ln();
1024                    self.scratch[i] = lp;
1025                    self.scratch2[i] = e.log_weight.exp();
1026                    e.cum_log_loss -= lp;
1027                    e.predictor.update(symbol)?;
1028                }
1029                let mix_prob = self
1030                    .scratch
1031                    .iter()
1032                    .zip(self.scratch2.iter())
1033                    .map(|(&lp, &w)| w * lp.exp())
1034                    .sum::<f64>()
1035                    .max(PDF_MIN);
1036                let log_mix = mix_prob.ln();
1037                self.convex_updates = self.convex_updates.saturating_add(1);
1038                let eta =
1039                    convex_step_size_for_update(self.schedule, self.alpha, self.convex_updates);
1040                for i in 0..n {
1041                    let grad = -(self.scratch[i] - log_mix).exp();
1042                    self.scratch2[i] -= eta * grad;
1043                }
1044                project_simplex_with_scratch(&mut self.scratch2[..n], &mut self.projection_scratch);
1045                for i in 0..n {
1046                    self.experts[i].log_weight = self.scratch2[i].max(PDF_MIN).ln();
1047                }
1048            }
1049            MixtureKind::Mdl => {
1050                let n = self.experts.len();
1051                self.scratch.resize(n, 0.0);
1052                for (i, e) in self.experts.iter_mut().enumerate() {
1053                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
1054                    let lp = p.ln();
1055                    self.scratch[i] = lp;
1056                }
1057                for (i, e) in self.experts.iter_mut().enumerate() {
1058                    e.cum_log_loss -= self.scratch[i];
1059                    e.predictor.update(symbol)?;
1060                }
1061            }
1062            MixtureKind::Neural => {
1063                let y = symbol as usize;
1064                if self.experts.len() == 1 {
1065                    let lp = self.experts[0].predictor.pdf_next()?[y].max(PDF_MIN).ln();
1066                    self.experts[0].cum_log_loss -= lp;
1067                    self.experts[0].predictor.update(symbol)?;
1068                    self.analyzer.update(symbol);
1069                    self.neural.set_context_state(self.analyzer.state());
1070                    self.valid = false;
1071                    return Ok(());
1072                }
1073                let n = self.experts.len();
1074                self.neural.set_context_state(self.analyzer.state());
1075                self.neural_logps.resize(n, 0.0);
1076                for i in 0..n {
1077                    let p = self.experts[i].predictor.pdf_next()?[y].max(PDF_MIN);
1078                    let lp = p.ln();
1079                    self.neural_logps[i] = lp;
1080                    self.experts[i].cum_log_loss -= lp;
1081                }
1082                self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
1083                self.neural
1084                    .update_weights_symbol(&self.neural_logps, PDF_MIN);
1085                for e in &mut self.experts {
1086                    e.predictor.update(symbol)?;
1087                }
1088                self.analyzer.update(symbol);
1089                self.neural.set_context_state(self.analyzer.state());
1090            }
1091        }
1092
1093        self.valid = false;
1094        Ok(())
1095    }
1096
1097    fn finish_stream(&mut self) -> Result<()> {
1098        for expert in &mut self.experts {
1099            expert.predictor.finish_stream()?;
1100        }
1101        Ok(())
1102    }
1103
1104    #[inline]
1105    fn can_fast_ac_bitwise(&self) -> bool {
1106        self.experts.iter().any(|e| {
1107            if let RatePdfPredictor::Ctw(ctw) = &*e.predictor {
1108                ctw.can_fast_ac_bitwise()
1109            } else {
1110                false
1111            }
1112        })
1113    }
1114
    /// Drive one byte through the bitwise (MSB-first) arithmetic-coding path.
    ///
    /// `choose_bit(bit_idx, p1_mix)` receives the mixed probability that bit
    /// `bit_idx` is 1 and must return the actual bit (0/1): the encoder takes
    /// it from the symbol being compressed, the decoder from the coder state.
    /// Returns the reassembled byte. Expert weights are refreshed afterwards
    /// from the per-symbol log-probabilities accumulated across the 8 bits.
    fn ac_step_bitwise<F>(&mut self, mut choose_bit: F) -> Result<u8>
    where
        F: FnMut(usize, f64) -> Result<u8>,
    {
        let n = self.experts.len();
        self.scratch.resize(n, 0.0);
        // Stage 1: snapshot the current mixing weights into `scratch`.
        match self.kind {
            MixtureKind::Neural if n > 1 => {
                self.neural.set_context_state(self.analyzer.state());
                self.neural.evaluate_expert_weights();
                self.scratch.copy_from_slice(self.neural.expert_weights());
            }
            // NOTE(review): the three arms below are identical; presumably
            // kept separate for symmetry with the kinds handled elsewhere —
            // they could be collapsed into the `_` arm.
            MixtureKind::FadingBayes => {
                let weights = self.predictive_weights();
                self.scratch.copy_from_slice(&weights);
            }
            MixtureKind::Mdl => {
                let weights = self.predictive_weights();
                self.scratch.copy_from_slice(&weights);
            }
            _ => {
                let weights = self.predictive_weights();
                self.scratch.copy_from_slice(&weights);
            }
        }
        // Stage 2: reset per-expert state. `scratch2` holds each expert's
        // running product of per-bit probabilities (its probability of the
        // byte so far); `neural_lo`/`neural_hi` bound the symbol interval
        // still consistent with the bits chosen so far.
        self.scratch2.resize(n, 1.0);
        self.scratch2.fill(1.0);
        self.neural_logps.resize(n, 0.0);
        self.neural_bit_modes.resize(n, 0);
        self.neural_lo.resize(n, 0);
        self.neural_hi.resize(n, 256);
        if self.neural_pdf_cdf_rows.len() < n {
            self.neural_pdf_cdf_rows.resize_with(n, || vec![0.0; 257]);
        }

        // Per-expert bit mode: 0 = native CTW bit model, 2 = backend cached
        // CDF, 1 = dense float CDF built from the expert's PDF (fallback).
        for i in 0..n {
            self.neural_bit_modes[i] = 1;
            self.neural_lo[i] = 0;
            self.neural_hi[i] = 256;

            let mut handled_ctw = false;
            if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor
                && ctw.can_fast_ac_bitwise()
            {
                self.neural_bit_modes[i] = 0;
                handled_ctw = true;
            }
            if handled_ctw {
                continue;
            }

            if self.experts[i]
                .predictor
                .prepare_cached_cdf_fast_bitwise()?
            {
                self.neural_bit_modes[i] = 2;
                continue;
            }

            // Fallback: materialize a cumulative float CDF over all 256 symbols.
            let pdf = self.experts[i].predictor.pdf_next()?;
            let row = &mut self.neural_pdf_cdf_rows[i];
            if row.len() != 257 {
                row.resize(257, 0.0);
            }
            row[0] = 0.0;
            for b in 0..256usize {
                row[b + 1] = row[b] + pdf[b].max(PDF_MIN);
            }
            // Degenerate PDF (NaN/inf or zero mass): use a uniform CDF.
            if !row[256].is_finite() || row[256] <= 0.0 {
                for (j, v) in row.iter_mut().enumerate() {
                    *v = (j as f64) / 256.0;
                }
            }
        }

        // Stage 3: walk the 8 bits MSB-first, mixing every expert's P(bit=1).
        let mut symbol = 0u8;
        for bit_idx in 0..8usize {
            let mut denom = 0.0;
            let mut numer1 = 0.0;

            for i in 0..n {
                let p1 = if self.neural_bit_modes[i] == 0 {
                    match &mut *self.experts[i].predictor {
                        RatePdfPredictor::Ctw(ctw) => ctw.bit_prob_one_msb(bit_idx),
                        _ => 0.5,
                    }
                } else if self.neural_bit_modes[i] == 2 {
                    self.experts[i]
                        .predictor
                        .cached_cdf_bit_prob_one_msb(self.neural_lo[i], self.neural_hi[i])
                        .unwrap_or(0.5)
                } else {
                    // Probability mass in the upper half of the current interval.
                    let lo = self.neural_lo[i];
                    let hi = self.neural_hi[i];
                    let mid = (lo + hi) >> 1;
                    let row = &self.neural_pdf_cdf_rows[i];
                    let total = (row[hi] - row[lo]).max(PDF_MIN);
                    let one = (row[hi] - row[mid]).max(0.0);
                    (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
                };
                self.neural_logps[i] = p1; // reused as temporary storage for p1
                let wp = self.scratch[i] * self.scratch2[i];
                denom += wp;
                numer1 += wp * p1;
            }

            // Weighted average of the experts' P(bit=1); 0.5 on degeneracy.
            let p1_mix = if denom.is_finite() && denom > 0.0 {
                (numer1 / denom).clamp(PDF_MIN, 1.0 - PDF_MIN)
            } else {
                0.5
            };
            let bit = choose_bit(bit_idx, p1_mix)? & 1;
            symbol |= bit << (7 - bit_idx);

            // Fold the observed bit into every expert's running probability
            // and narrow its interval (or push the bit into the CTW model).
            for i in 0..n {
                let p1 = self.neural_logps[i];
                let pb = if bit == 1 { p1 } else { 1.0 - p1 };
                self.scratch2[i] = (self.scratch2[i] * pb).max(PDF_MIN);

                if self.neural_bit_modes[i] == 0 {
                    if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor {
                        ctw.update_bit_msb(bit_idx, bit == 1);
                    }
                } else {
                    let lo = self.neural_lo[i];
                    let hi = self.neural_hi[i];
                    let mid = (lo + hi) >> 1;
                    if bit == 1 {
                        self.neural_lo[i] = mid;
                        self.neural_hi[i] = hi;
                    } else {
                        self.neural_lo[i] = lo;
                        self.neural_hi[i] = mid;
                    }
                }
            }
        }

        // Stage 4: convert accumulated byte probabilities to log space, charge
        // log loss, and advance non-CTW experts by the full symbol (mode-0
        // experts were already updated bit-by-bit above).
        for i in 0..n {
            let lp = self.scratch2[i].max(PDF_MIN).ln();
            self.neural_logps[i] = lp;
            self.experts[i].cum_log_loss -= lp;
            if self.neural_bit_modes[i] != 0 {
                self.experts[i].predictor.update(symbol)?;
            }
        }

        // Stage 5: kind-specific weight update, mirroring `update`.
        match self.kind {
            MixtureKind::Bayes => {
                for i in 0..n {
                    self.scratch[i] = self.experts[i].log_weight + self.neural_logps[i];
                }
                let log_mix = logsumexp_slice(&self.scratch[..n]);
                for i in 0..n {
                    self.experts[i].log_weight += self.neural_logps[i] - log_mix;
                }
            }
            MixtureKind::FadingBayes => {
                for i in 0..n {
                    self.scratch[i] =
                        self.decay * self.experts[i].log_weight + self.neural_logps[i];
                }
                let log_mix = logsumexp_slice(&self.scratch[..n]);
                for i in 0..n {
                    self.experts[i].log_weight = self.scratch[i] - log_mix;
                }
            }
            MixtureKind::Switching => {
                for i in 0..n {
                    self.scratch[i] = self.experts[i].log_weight + self.neural_logps[i];
                }
                let log_mix = logsumexp_slice(&self.scratch[..n]);
                for weight in &mut self.scratch[..n] {
                    *weight = (*weight - log_mix).exp();
                }
                let alpha =
                    switching_alpha_for_update(self.schedule, self.alpha, self.switch_updates);
                self.switch_updates = self.switch_updates.saturating_add(1);
                apply_switching_weights(
                    &mut self.experts,
                    &self.prior_weights[..n],
                    alpha,
                    &mut self.scratch[..n],
                    &mut self.scratch2[..n],
                );
            }
            MixtureKind::Convex => {
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                for i in 0..n {
                    self.scratch2[i] = self.experts[i].log_weight.exp();
                }
                let mix_prob = self
                    .neural_logps
                    .iter()
                    .zip(self.scratch2.iter())
                    .map(|(&lp, &w)| w * lp.exp())
                    .sum::<f64>()
                    .max(PDF_MIN);
                let log_mix = mix_prob.ln();
                self.convex_updates = self.convex_updates.saturating_add(1);
                let eta =
                    convex_step_size_for_update(self.schedule, self.alpha, self.convex_updates);
                // Gradient step followed by projection onto the simplex.
                for i in 0..n {
                    let grad = -(self.neural_logps[i] - log_mix).exp();
                    self.scratch2[i] -= eta * grad;
                }
                project_simplex_with_scratch(&mut self.scratch2[..n], &mut self.projection_scratch);
                for i in 0..n {
                    self.experts[i].log_weight = self.scratch2[i].max(PDF_MIN).ln();
                }
            }
            MixtureKind::Mdl => {}
            MixtureKind::Neural => {
                if n > 1 {
                    self.neural.set_context_state(self.analyzer.state());
                    self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
                    self.neural
                        .update_weights_symbol(&self.neural_logps, PDF_MIN);
                }
                self.analyzer.update(symbol);
                self.neural.set_context_state(self.analyzer.state());
            }
        }
        self.valid = false;
        Ok(symbol)
    }
1342}
1343
/// Thin wrapper around [`RatePdfPredictor`] for the AC log-loss diagnostics
/// pipeline; it exposes only the stream/encode/snapshot operations the
/// diagnostics driver needs.
pub(crate) struct DiagnosticRatePredictor {
    // Wrapped predictor; `diagnostic_root_snapshot` requires it to be a mixture.
    inner: RatePdfPredictor,
}
1347
impl DiagnosticRatePredictor {
    /// Build a predictor for `backend`; `max_order` is forwarded to backends
    /// that consume it (e.g. `RosaPlus`, and recursively for `Calibrated`).
    pub(crate) fn from_rate_backend(backend: RateBackend, max_order: i64) -> Result<Self> {
        Ok(Self {
            inner: RatePdfPredictor::from_rate_backend(backend, max_order)?,
        })
    }

    /// Reset/prepare the predictor for a stream of `total_len` bytes.
    pub(crate) fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.inner.begin_stream(total_len)
    }

    /// Flush any end-of-stream state in the underlying predictor.
    pub(crate) fn finish_stream(&mut self) -> Result<()> {
        self.inner.finish_stream()
    }

    /// Probability distribution over the next byte (test-only accessor).
    #[cfg(test)]
    pub(crate) fn pdf_next(&mut self) -> Result<&[f64]> {
        self.inner.pdf_next()
    }

    /// Advance the predictor by one observed byte (test-only accessor).
    #[cfg(test)]
    pub(crate) fn update(&mut self, symbol: u8) -> Result<()> {
        self.inner.update(symbol)
    }

    /// Per-node log-loss snapshot for `symbol`; errors unless the inner
    /// predictor is a mixture.
    pub(crate) fn diagnostic_root_snapshot(
        &mut self,
        symbol: u8,
        pool: Option<&ThreadPool>,
        out: &mut Vec<AcLogLossNodeValue>,
    ) -> Result<AcLogLossRootSnapshot> {
        self.inner.diagnostic_root_snapshot(symbol, pool, out)
    }

    /// Encode one byte into `encoder` and advance the predictor.
    pub(crate) fn encode_symbol_ac_step<W: std::io::Write>(
        &mut self,
        symbol: u8,
        encoder: &mut ArithmeticEncoder<W>,
    ) -> Result<()> {
        self.inner.encode_symbol_ac_step(symbol, encoder)
    }
}
1390
/// Dispatch enum over every supported per-symbol PDF predictor backend.
#[derive(Clone)]
#[allow(clippy::large_enum_variant)]
enum RatePdfPredictor {
    /// RosaPlus-based predictor.
    Rosa(RosaPredictor),
    /// Contiguous match model.
    Match {
        model: MatchModel,
    },
    /// Gapped (sparse) match model.
    SparseMatch {
        model: SparseMatchModel,
    },
    /// PPMd model.
    Ppmd {
        model: PpmdModel,
    },
    /// Sequitur grammar model.
    Sequitur {
        model: SequiturModel,
    },
    /// Plain context-tree-weighting predictor.
    Ctw(CtwPredictor),
    /// FAC variant of the CTW predictor (built via `CtwPredictor::new_fac`).
    FacCtw(CtwPredictor),
    /// Mamba neural model (feature-gated).
    #[cfg(feature = "backend-mamba")]
    Mamba(MambaPredictor),
    /// RWKV-7 neural model (feature-gated).
    #[cfg(feature = "backend-rwkv")]
    Rwkv(RwkvPredictor),
    /// ZPAQ rate model.
    Zpaq(ZpaqPredictor),
    /// Weighted mixture of sub-predictors.
    Mixture(MixturePredictor),
    /// Particle-filter runtime.
    Particle(crate::particle::ParticleRuntime),
    /// Calibration wrapper around another predictor.
    Calibrated {
        base: Box<RatePdfPredictor>,
        // Calibration state applied to the base PDF.
        core: CalibratorCore,
        // Cached calibrated PDF over the 256 byte values.
        pdf: Vec<f64>,
        // True while `pdf` reflects the base predictor's current state.
        valid: bool,
    },
}
1423
1424impl RatePdfPredictor {
    /// Construct the predictor variant matching a `RateBackend` configuration.
    ///
    /// `max_order` is consumed by the RosaPlus backend and forwarded
    /// recursively when building the base of a `Calibrated` backend.
    fn from_rate_backend(backend: RateBackend, max_order: i64) -> Result<Self> {
        match backend {
            RateBackend::RosaPlus => Ok(Self::Rosa(RosaPredictor::new(max_order))),
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Ok(Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Ok(Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::Ppmd { order, memory_mb } => Ok(Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
            }),
            RateBackend::Sequitur { context_bytes } => Ok(Self::Sequitur {
                model: SequiturModel::new(context_bytes),
            }),
            RateBackend::Ctw { depth } => Ok(Self::Ctw(CtwPredictor::new_ctw(depth))),
            RateBackend::FacCtw {
                base_depth,
                // NOTE(review): num_percept_bits is ignored here — confirm
                // that is intended (only encoding_bits is used below).
                num_percept_bits: _,
                encoding_bits,
            } => {
                // Encoding bits are constrained to the valid 1..=8 range.
                let bits = encoding_bits.clamp(1, 8);
                Ok(Self::FacCtw(CtwPredictor::new_fac(base_depth, bits)))
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => Ok(Self::Mamba(MambaPredictor::from_model(model))),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                Ok(Self::Mamba(MambaPredictor::from_method(&method)?))
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => Ok(Self::Rwkv(RwkvPredictor::from_model(model))),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                Ok(Self::Rwkv(RwkvPredictor::from_method(&method)?))
            }
            RateBackend::Zpaq { method } => Ok(Self::Zpaq(ZpaqPredictor::new(method))),
            RateBackend::Mixture { spec } => {
                Ok(Self::Mixture(MixturePredictor::new(spec.as_ref())?))
            }
            RateBackend::Particle { spec } => Ok(Self::Particle(
                crate::particle::ParticleRuntime::new(spec.as_ref()),
            )),
            RateBackend::Calibrated { spec } => Ok(Self::Calibrated {
                base: Box::new(Self::from_rate_backend(spec.base.clone(), max_order)?),
                core: build_calibrator(spec.as_ref()),
                // Start from a uniform PDF; recomputed lazily (valid=false).
                pdf: vec![1.0 / 256.0; 256],
                valid: false,
            }),
        }
    }
1504
    /// Prepare the predictor for a new stream of `total_len` bytes.
    ///
    /// Always finishes the previous stream first, so `begin_stream` acts as a
    /// full reset regardless of prior state.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.finish_stream()?;
        match self {
            Self::Rosa(m) => {
                m.begin_stream(total_len);
                Ok(())
            }
            // These backends need no per-stream setup.
            Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            Self::Sequitur { model } => {
                model.begin_stream(Some(total_len as u64));
                Ok(())
            }
            Self::Ctw(m) | Self::FacCtw(m) => {
                // Pre-size the context tree for the expected symbol count.
                m.tree.reserve_for_symbols(total_len);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.begin_stream(total_len),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.begin_stream(total_len),
            Self::Mixture(m) => m.begin_stream(total_len),
            // Calibration state persists; only the base predictor resets.
            Self::Calibrated { base, .. } => base.begin_stream(total_len),
        }
    }
1533
    /// Flush end-of-stream state; a no-op for most backends.
    fn finish_stream(&mut self) -> Result<()> {
        match self {
            // Nothing to flush for these backends.
            Self::Rosa(_)
            | Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Sequitur { .. }
            | Self::Ctw(_)
            | Self::FacCtw(_)
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(_) => Ok(()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.finish_stream(),
            Self::Mixture(m) => m.finish_stream(),
            Self::Calibrated { base, .. } => base.finish_stream(),
        }
    }
1553
    /// Probability distribution over the next byte (256 entries).
    ///
    /// For `Calibrated`, the base PDF is pulled, passed through the calibrator,
    /// re-normalized, and cached until the next `update` invalidates it.
    fn pdf_next(&mut self) -> Result<&[f64]> {
        match self {
            Self::Rosa(m) => Ok(m.pdf_next()),
            Self::Match { model } => Ok(model.pdf()),
            Self::Ctw(m) => Ok(m.pdf_next()),
            Self::FacCtw(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Ok(m.pdf_next()),
            Self::Zpaq(m) => Ok(m.pdf_next()),
            Self::Mixture(m) => m.ensure_pdf(),
            Self::Particle(m) => Ok(m.pdf_next()),
            Self::SparseMatch { model } => Ok(model.pdf()),
            Self::Ppmd { model } => Ok(model.pdf()),
            Self::Sequitur { model } => Ok(model.pdf()),
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                if !*valid {
                    // Recompute the calibrated PDF only when stale.
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                    *valid = true;
                }
                Ok(pdf)
            }
        }
    }
1586
    /// Advance the underlying model by one observed byte.
    fn update(&mut self, symbol: u8) -> Result<()> {
        match self {
            Self::Rosa(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::Match { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::SparseMatch { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Ppmd { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Sequitur { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Ctw(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::FacCtw(m) => {
                m.update(symbol);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.update(symbol),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.update(symbol),
            Self::Zpaq(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::Mixture(m) => m.update(symbol),
            Self::Particle(m) => {
                m.step(symbol);
                Ok(())
            }
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                // Make sure `pdf` reflects the distribution the calibrator
                // actually predicted with before charging the update to it.
                if !*valid {
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                }
                core.update(symbol, pdf);
                base.update(symbol)?;
                // Base advanced, so the cached calibrated PDF is stale now.
                *valid = false;
                Ok(())
            }
        }
    }
1648
    /// Warm up a cached CDF for the bitwise AC path.
    ///
    /// Returns `true` for backends whose CDF can then be queried via
    /// `cached_cdf_bit_prob_one_msb`, `false` for all others.
    fn prepare_cached_cdf_fast_bitwise(&mut self) -> Result<bool> {
        match self {
            Self::Rosa(m) => {
                // Computing the CDF once here refreshes the backend's cache.
                let _ = m.cdf_next();
                Ok(true)
            }
            Self::Match { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            Self::SparseMatch { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            Self::Ppmd { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            _ => Ok(false),
        }
    }
1680
    /// P(next bit = 1) within the MSB-first symbol interval `[lo, hi)`, read
    /// from the backend's cached CDF; `None` for backends without one.
    fn cached_cdf_bit_prob_one_msb(&mut self, lo: usize, hi: usize) -> Option<f64> {
        match self {
            // Rosa exposes its cached CDF directly as a field; it was
            // refreshed by `prepare_cached_cdf_fast_bitwise` via `cdf_next`.
            Self::Rosa(m) => Some(cdf_bit_prob_one_msb(&m.cdf, lo, hi)),
            Self::Match { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            Self::SparseMatch { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            Self::Ppmd { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
            _ => None,
        }
    }
1694
1695    #[inline]
1696    fn can_fast_ac_bitwise(&self) -> bool {
1697        match self {
1698            Self::Ctw(m) => m.can_fast_ac_bitwise(),
1699            Self::Mixture(m) => m.can_fast_ac_bitwise(),
1700            _ => false,
1701        }
1702    }
1703
1704    fn ac_step_fast_bitwise<F>(&mut self, choose_bit: F) -> Result<u8>
1705    where
1706        F: FnMut(usize, f64) -> Result<u8>,
1707    {
1708        match self {
1709            Self::Ctw(m) => ctw_ac_step_bitwise(m, choose_bit),
1710            Self::Mixture(m) => m.ac_step_bitwise(choose_bit),
1711            _ => unreachable!("fast bitwise path requested for unsupported predictor"),
1712        }
1713    }
1714
    /// Snapshot per-node probabilities of `symbol` for AC log-loss
    /// diagnostics.
    ///
    /// Mixture backends recurse into their expert subtree (optionally using
    /// `pool` for parallel evaluation); every other backend contributes a
    /// single leaf row whose probability is its own prediction for `symbol`,
    /// floored at `PDF_MIN`.
    fn diagnostic_snapshot_subtree(
        &mut self,
        symbol: u8,
        local_weight: f64,
        effective_weight: f64,
        pool: Option<&ThreadPool>,
    ) -> Result<AcLogLossSubtreeSnapshot> {
        match self {
            Self::Mixture(m) => {
                m.diagnostic_subtree_snapshot(symbol, local_weight, effective_weight, pool)
            }
            _ => {
                // Leaf backend: one row; flooring keeps log(prob) finite.
                let prob = self.pdf_next()?[symbol as usize].max(PDF_MIN);
                Ok(AcLogLossSubtreeSnapshot {
                    prob,
                    rows: vec![AcLogLossNodeValue {
                        prob,
                        local_weight,
                        effective_weight,
                    }],
                })
            }
        }
    }
1739
    /// Root entry point for AC log-loss diagnostics.
    ///
    /// Only a top-level mixture backend can produce a root snapshot; any
    /// other backend is rejected with an error.
    fn diagnostic_root_snapshot(
        &mut self,
        symbol: u8,
        pool: Option<&ThreadPool>,
        out: &mut Vec<AcLogLossNodeValue>,
    ) -> Result<AcLogLossRootSnapshot> {
        match self {
            Self::Mixture(m) => m.diagnostic_root_snapshot(symbol, pool, out),
            _ => anyhow::bail!("AC log-loss diagnostics require a top-level mixture backend"),
        }
    }
1751
1752    fn encode_symbol_ac_step<W: std::io::Write>(
1753        &mut self,
1754        symbol: u8,
1755        encoder: &mut ArithmeticEncoder<W>,
1756    ) -> Result<()> {
1757        if self.can_fast_ac_bitwise() {
1758            self.ac_step_fast_bitwise(|bit_idx, p1_mix| {
1759                let bit = (symbol >> (7 - bit_idx)) & 1;
1760                let split = binary_split_from_prob_one(p1_mix);
1761                if bit == 0 {
1762                    encoder.encode_counts(0, split as u64, CDF_TOTAL as u64)?;
1763                } else {
1764                    encoder.encode_counts(split as u64, CDF_TOTAL as u64, CDF_TOTAL as u64)?;
1765                }
1766                Ok(bit)
1767            })?;
1768            return Ok(());
1769        }
1770
1771        let pdf = self.pdf_next()?;
1772        let mut cdf = vec![0u32; 257];
1773        crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
1774            pdf, CDF_TOTAL, &mut cdf,
1775        );
1776        let sym = symbol as usize;
1777        encoder.encode_counts(cdf[sym] as u64, cdf[sym + 1] as u64, CDF_TOTAL as u64)?;
1778        self.update(symbol)
1779    }
1780}
1781
1782fn ctw_ac_step_bitwise<F>(ctw: &mut CtwPredictor, mut choose_bit: F) -> Result<u8>
1783where
1784    F: FnMut(usize, f64) -> Result<u8>,
1785{
1786    debug_assert!(ctw.can_fast_ac_bitwise());
1787    let mut symbol = 0u8;
1788    for bit_idx in 0..8usize {
1789        let p1 = ctw.bit_prob_one_msb(bit_idx);
1790        let bit = choose_bit(bit_idx, p1)? & 1;
1791        symbol |= bit << (7 - bit_idx);
1792        ctw.update_bit_msb(bit_idx, bit == 1);
1793    }
1794    Ok(symbol)
1795}
1796
1797#[inline]
1798fn binary_split_from_prob_one(p1: f64) -> u32 {
1799    let p1 = p1.clamp(PDF_MIN, 1.0 - PDF_MIN);
1800    let p0 = 1.0 - p1;
1801    let mut split = (p0 * (CDF_TOTAL as f64)) as u32;
1802    if split == 0 {
1803        split = 1;
1804    } else if split >= CDF_TOTAL {
1805        split = CDF_TOTAL - 1;
1806    }
1807    split
1808}
1809
1810fn encode_payload_ac(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
1811    predictor.begin_stream(data.len())?;
1812    let mut out = Vec::new();
1813    {
1814        let mut enc = ArithmeticEncoder::new(&mut out);
1815        for &symbol in data {
1816            predictor.encode_symbol_ac_step(symbol, &mut enc)?;
1817        }
1818        let _ = enc.finish()?;
1819    }
1820    predictor.finish_stream()?;
1821    Ok(out)
1822}
1823
/// Decode `out_len` symbols from an arithmetic-coded `payload`.
///
/// Mirrors `encode_symbol_ac_step`: the bitwise fast path is taken when the
/// predictor supports it (the model updates itself per bit inside
/// `ac_step_fast_bitwise`); otherwise each symbol is decoded against a freshly
/// quantized integer CDF and fed back via `update` so encoder and decoder
/// model states stay in lockstep.
fn decode_payload_ac(
    payload: &[u8],
    out_len: usize,
    predictor: &mut RatePdfPredictor,
) -> Result<Vec<u8>> {
    predictor.begin_stream(out_len)?;
    if predictor.can_fast_ac_bitwise() {
        let mut dec = ArithmeticDecoder::new(payload)?;
        let mut out = Vec::with_capacity(out_len);
        for _ in 0..out_len {
            let symbol = predictor.ac_step_fast_bitwise(|_, p1_mix| {
                // Binary alphabet: CDF is {0, split, CDF_TOTAL} and the
                // decoded "symbol" index is the bit itself.
                let split = binary_split_from_prob_one(p1_mix);
                let cdf = [0u32, split, CDF_TOTAL];
                Ok(dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8)
            })?;
            out.push(symbol);
        }
        predictor.finish_stream()?;
        return Ok(out);
    }

    let mut dec = ArithmeticDecoder::new(payload)?;
    let mut out = Vec::with_capacity(out_len);
    // Scratch CDF buffer reused across all symbols.
    let mut cdf = vec![0u32; 257];
    for _ in 0..out_len {
        let pdf = predictor.pdf_next()?;
        crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
            pdf, CDF_TOTAL, &mut cdf,
        );
        let sym = dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8;
        out.push(sym);
        predictor.update(sym)?;
    }
    predictor.finish_stream()?;
    Ok(out)
}
1860
1861fn encode_payload_rans(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
1862    predictor.begin_stream(data.len())?;
1863    let mut encoder = BlockedRansEncoder::new();
1864    let mut cdf = vec![0u32; 257];
1865    let mut freq = vec![0i64; 256];
1866
1867    for &b in data {
1868        let pdf = predictor.pdf_next()?;
1869        quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
1870        let s = b as usize;
1871        encoder.encode(Cdf::new(cdf[s], cdf[s + 1], ANS_TOTAL));
1872        predictor.update(b)?;
1873    }
1874
1875    let blocks = encoder.finish();
1876    let mut out = Vec::new();
1877    out.extend_from_slice(&(blocks.len() as u32).to_le_bytes());
1878    for block in blocks {
1879        out.extend_from_slice(&(block.len() as u32).to_le_bytes());
1880        out.extend_from_slice(&block);
1881    }
1882    predictor.finish_stream()?;
1883    Ok(out)
1884}
1885
1886fn decode_payload_rans(
1887    payload: &[u8],
1888    out_len: usize,
1889    predictor: &mut RatePdfPredictor,
1890) -> Result<Vec<u8>> {
1891    predictor.begin_stream(out_len)?;
1892    if payload.len() < 4 {
1893        bail!("rANS payload too short");
1894    }
1895    let block_count = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]) as usize;
1896    let mut pos = 4usize;
1897    let mut blocks = Vec::with_capacity(block_count);
1898    for _ in 0..block_count {
1899        if pos + 4 > payload.len() {
1900            bail!("truncated rANS block header");
1901        }
1902        let len = u32::from_le_bytes([
1903            payload[pos],
1904            payload[pos + 1],
1905            payload[pos + 2],
1906            payload[pos + 3],
1907        ]) as usize;
1908        pos += 4;
1909        if pos + len > payload.len() {
1910            bail!("truncated rANS block data");
1911        }
1912        blocks.push(&payload[pos..pos + len]);
1913        pos += len;
1914    }
1915
1916    let mut dec = BlockedRansDecoder::new(blocks, out_len)?;
1917    let mut out = Vec::with_capacity(out_len);
1918    let mut cdf = vec![0u32; 257];
1919    let mut freq = vec![0i64; 256];
1920
1921    for _ in 0..out_len {
1922        let pdf = predictor.pdf_next()?;
1923        quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
1924        let sym = dec.decode(&cdf)? as u8;
1925        out.push(sym);
1926        predictor.update(sym)?;
1927    }
1928    predictor.finish_stream()?;
1929    Ok(out)
1930}
1931
1932/// Compress bytes using a predictive rate backend and entropy coder.
1933///
1934/// When `framing` is [`FramingMode::Framed`], output includes a compact header
1935/// with payload metadata and CRC for safer transport/storage.
1936pub fn compress_rate_bytes(
1937    data: &[u8],
1938    rate_backend: &RateBackend,
1939    max_order: i64,
1940    coder: CoderType,
1941    framing: FramingMode,
1942) -> Result<Vec<u8>> {
1943    let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
1944    let payload = match coder {
1945        CoderType::AC => encode_payload_ac(data, &mut predictor)?,
1946        CoderType::RANS => encode_payload_rans(data, &mut predictor)?,
1947    };
1948
1949    if framing == FramingMode::Raw {
1950        return Ok(payload);
1951    }
1952
1953    let mut out = Vec::with_capacity(FramedHeader::SIZE + payload.len());
1954    let hdr = FramedHeader::new(coder, data.len() as u64, crc32(data));
1955    hdr.write(&mut out);
1956    out.extend_from_slice(&payload);
1957    Ok(out)
1958}
1959
1960/// Return compressed size (in bytes) for `data` using rate coding.
1961pub fn compress_rate_size(
1962    data: &[u8],
1963    rate_backend: &RateBackend,
1964    max_order: i64,
1965    coder: CoderType,
1966    framing: FramingMode,
1967) -> Result<u64> {
1968    let encoded = compress_rate_bytes(data, rate_backend, max_order, coder, framing)?;
1969    Ok(encoded.len() as u64)
1970}
1971
1972/// Return compressed size (in bytes) for concatenated slices under one stream.
1973pub fn compress_rate_size_chain(
1974    parts: &[&[u8]],
1975    rate_backend: &RateBackend,
1976    max_order: i64,
1977    coder: CoderType,
1978    framing: FramingMode,
1979) -> Result<u64> {
1980    let total = parts.iter().map(|p| p.len()).sum();
1981    let mut data = Vec::with_capacity(total);
1982    for p in parts {
1983        data.extend_from_slice(p);
1984    }
1985    compress_rate_size(&data, rate_backend, max_order, coder, framing)
1986}
1987
1988/// Decompress bytes produced by [`compress_rate_bytes`].
1989pub fn decompress_rate_bytes(
1990    input: &[u8],
1991    rate_backend: &RateBackend,
1992    max_order: i64,
1993    _coder: CoderType,
1994    framing: FramingMode,
1995) -> Result<Vec<u8>> {
1996    let (payload, coder, out_len, expected_crc) = if framing == FramingMode::Framed {
1997        let hdr = FramedHeader::read(input)?;
1998        (
1999            &input[FramedHeader::SIZE..],
2000            hdr.coder_type(),
2001            hdr.original_len as usize,
2002            Some(hdr.crc32),
2003        )
2004    } else {
2005        bail!("raw payload decompression requires explicit output length and is not supported");
2006    };
2007
2008    let _ = coder;
2009    let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
2010    let decoded = match coder {
2011        CoderType::AC => decode_payload_ac(payload, out_len, &mut predictor)?,
2012        CoderType::RANS => decode_payload_rans(payload, out_len, &mut predictor)?,
2013    };
2014
2015    if let Some(crc) = expected_crc {
2016        let got = crc32(&decoded);
2017        if got != crc {
2018            bail!("CRC32 mismatch: expected 0x{crc:08X}, got 0x{got:08X}");
2019        }
2020    }
2021
2022    Ok(decoded)
2023}
2024
2025fn normalize_pdf(pdf: &mut [f64]) {
2026    let mut sum = 0.0;
2027    for p in pdf.iter_mut() {
2028        *p = if p.is_finite() {
2029            (*p).max(PDF_MIN)
2030        } else {
2031            PDF_MIN
2032        };
2033        sum += *p;
2034    }
2035    if !(sum.is_finite()) || sum <= 0.0 {
2036        let u = 1.0 / (pdf.len() as f64);
2037        for p in pdf.iter_mut() {
2038            *p = u;
2039        }
2040        return;
2041    }
2042    let inv = 1.0 / sum;
2043    for p in pdf.iter_mut() {
2044        *p *= inv;
2045    }
2046}
2047
/// Cumulative distribution row of the uniform byte distribution: entry `i`
/// is `i / 256`, running from 0.0 up to 1.0 inclusive.
#[inline]
fn uniform_cdf_row() -> [f64; 257] {
    let step = 1.0 / 256.0;
    let mut row = [0.0f64; 257];
    for (i, slot) in row.iter_mut().enumerate() {
        *slot = (i as f64) * step;
    }
    row
}
2057
/// Fill `cdf` with running prefix sums of the first 256 entries of `pdf`:
/// `cdf[0] = 0`, `cdf[i + 1] = pdf[0] + … + pdf[i]`.
#[inline]
fn build_cdf_row_from_pdf_slice(pdf: &[f64], cdf: &mut [f64; 257]) {
    let mut running = 0.0;
    cdf[0] = 0.0;
    for i in 0..256 {
        running += pdf[i];
        cdf[i + 1] = running;
    }
}
2067
/// Sanitize and normalize `pdf` in place, optionally emitting its CDF row.
///
/// Every entry is floored at `PDF_MIN` (non-finite values are replaced by it),
/// then the vector is rescaled to sum to one. If the raw sum is not a positive
/// finite number, `pdf` falls back to the uniform distribution (and `cdf`, if
/// provided, to the uniform CDF row). When `cdf` is supplied, the prefix sums
/// are built in the same pass as the rescale to avoid a second sweep.
fn normalize_pdf_vec_and_maybe_build_cdf(pdf: &mut [f64], mut cdf: Option<&mut [f64; 257]>) {
    let mut sum = 0.0;
    for p in pdf.iter_mut() {
        *p = if p.is_finite() {
            (*p).max(PDF_MIN)
        } else {
            PDF_MIN
        };
        sum += *p;
    }
    if !(sum.is_finite()) || sum <= 0.0 {
        // Degenerate mass: fall back to uniform rather than dividing by junk.
        let u = 1.0 / (pdf.len() as f64);
        pdf.fill(u);
        if let Some(cdf) = cdf.as_deref_mut() {
            *cdf = uniform_cdf_row();
        }
        return;
    }
    let inv = 1.0 / sum;
    if let Some(cdf) = cdf.as_deref_mut() {
        // Fused pass: normalize each entry and accumulate the CDF together.
        cdf[0] = 0.0;
        let mut acc = 0.0;
        for i in 0..256 {
            pdf[i] *= inv;
            acc += pdf[i];
            cdf[i + 1] = acc;
        }
    } else {
        for p in pdf.iter_mut() {
            *p *= inv;
        }
    }
}
2101
/// Probability that the next MSB-first bit is one, given the symbol lies in
/// `[lo, hi)` of the cumulative distribution `cdf`.
///
/// The "one" branch is the upper half `[mid, hi)`; the ratio is clamped to
/// `[PDF_MIN, 1 - PDF_MIN]` so the arithmetic coder never sees exactly 0 or 1.
#[inline]
fn cdf_bit_prob_one_msb(cdf: &[f64; 257], lo: usize, hi: usize) -> f64 {
    let mid = (lo + hi) >> 1;
    let total = (cdf[hi] - cdf[lo]).max(PDF_MIN);
    let one = (cdf[hi] - cdf[mid]).max(0.0);
    (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
}
2109
/// Numerically stable `ln(sum(exp(v)))` over a slice.
///
/// Shifts by the maximum before exponentiating; an empty slice (or one whose
/// maximum is not finite) returns that maximum directly (`-inf` when empty).
#[inline]
fn logsumexp_slice(vals: &[f64]) -> f64 {
    let peak = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        return peak;
    }
    let total: f64 = vals.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
2127
2128#[inline]
2129fn logsumexp_expert_weights(experts: &[MixExpert]) -> f64 {
2130    let mut m = f64::NEG_INFINITY;
2131    for e in experts {
2132        if e.log_weight > m {
2133            m = e.log_weight;
2134        }
2135    }
2136    if !m.is_finite() {
2137        return m;
2138    }
2139    let mut s = 0.0;
2140    for e in experts {
2141        s += (e.log_weight - m).exp();
2142    }
2143    m + s.ln()
2144}
2145
/// Project `weights` onto the probability simplex in place.
///
/// Negative and non-finite entries are zeroed first; if no positive finite
/// mass remains, the result is the uniform distribution. Empty input is a
/// no-op.
fn normalize_simplex_weights(weights: &mut [f64]) {
    if weights.is_empty() {
        return;
    }
    let mut total = 0.0;
    for w in weights.iter_mut() {
        if !(w.is_finite() && *w >= 0.0) {
            *w = 0.0;
        }
        total += *w;
    }
    if !(total.is_finite() && total > 0.0) {
        let share = 1.0 / (weights.len() as f64);
        weights.fill(share);
        return;
    }
    for w in weights.iter_mut() {
        *w /= total;
    }
}
2166
2167fn normalized_mix_expert_prior_weights(experts: &[MixExpert], out: &mut [f64]) {
2168    debug_assert_eq!(experts.len(), out.len());
2169    let max_log = experts
2170        .iter()
2171        .map(|expert| expert.log_prior)
2172        .fold(f64::NEG_INFINITY, f64::max);
2173    for (slot, expert) in out.iter_mut().zip(experts.iter()) {
2174        *slot = if max_log.is_finite() {
2175            (expert.log_prior - max_log).exp()
2176        } else {
2177            0.0
2178        };
2179    }
2180    normalize_simplex_weights(out);
2181}
2182
2183fn set_mix_expert_log_weights_from_linear(experts: &mut [MixExpert], weights: &[f64]) {
2184    for (expert, &weight) in experts.iter_mut().zip(weights.iter()) {
2185        expert.log_weight = if weight > 0.0 {
2186            weight.ln()
2187        } else {
2188            f64::NEG_INFINITY
2189        };
2190    }
2191}
2192
/// Blend a posterior with a switching step and store the result as the
/// experts' log weights.
///
/// Each expert keeps its (normalized) posterior mass with probability
/// `1 - alpha`; with probability `alpha`, mass switches toward experts in
/// proportion to their prior, drawn from the other experts' switchable mass.
/// The switch step is skipped (posterior used directly) when there is one
/// expert, `alpha <= 0`, or fewer than two experts have `prior < 1.0`.
/// `posterior` is normalized in place; `scratch` holds the blended weights
/// before the final renormalization.
fn apply_switching_weights(
    experts: &mut [MixExpert],
    prior_weights: &[f64],
    alpha: f64,
    posterior: &mut [f64],
    scratch: &mut [f64],
) {
    if experts.is_empty() {
        return;
    }
    debug_assert_eq!(experts.len(), prior_weights.len());

    normalize_simplex_weights(posterior);
    if experts.len() == 1 || alpha <= 0.0 {
        set_mix_expert_log_weights_from_linear(experts, posterior);
        return;
    }

    // Experts with prior >= 1.0 can never be switched away from; with fewer
    // than two actual switch targets the step is a no-op.
    let num_switch_targets = prior_weights.iter().filter(|&&prior| prior < 1.0).count();
    if num_switch_targets <= 1 {
        set_mix_expert_log_weights_from_linear(experts, posterior);
        return;
    }

    // Total posterior mass available to switch, with each term scaled by
    // 1/(1 - prior); an expert later subtracts its own contribution so it
    // only receives mass switched in from *other* experts.
    let mut switch_out_sum = 0.0;
    for i in 0..experts.len() {
        let denom = 1.0 - prior_weights[i];
        if denom > 0.0 {
            switch_out_sum += posterior[i] / denom;
        }
    }

    for i in 0..experts.len() {
        let prior = prior_weights[i];
        let stay = (1.0 - alpha) * posterior[i];
        let switch_in = if prior > 0.0 {
            let denom = 1.0 - prior;
            // Exclude this expert's own switchable mass from what it receives.
            let switchable_mass = if denom > 0.0 {
                switch_out_sum - posterior[i] / denom
            } else {
                0.0
            };
            alpha * prior * switchable_mass
        } else {
            0.0
        };
        scratch[i] = stay + switch_in;
    }

    normalize_simplex_weights(scratch);
    set_mix_expert_log_weights_from_linear(experts, scratch);
}
2245
// NOTE(review): appears to exist solely to keep the `ZpaqRateModel` import
// referenced from this module — confirm before removing.
#[allow(dead_code)]
fn _zpaq_marker(_: &ZpaqRateModel) {}
2248
2249#[cfg(test)]
2250mod tests {
2251    use super::*;
2252    use std::sync::Arc;
2253
    // Element-wise comparison of two PDFs within an absolute tolerance.
    fn assert_pdf_close(lhs: &[f64], rhs: &[f64], tol: f64) {
        assert_eq!(lhs.len(), rhs.len());
        for (idx, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
            let delta = (a - b).abs();
            assert!(
                delta <= tol,
                "pdf mismatch at symbol {idx}: lhs={a} rhs={b} delta={delta}"
            );
        }
    }
2264
    // Reference PDF computed symbol-by-symbol from the predictor's exact
    // per-symbol log probabilities. For sub-byte models each pattern's mass is
    // spread uniformly across the bytes that alias to it.
    fn brute_force_pdf(predictor: &mut CtwPredictor) -> Vec<f64> {
        let bits = predictor.bits_per_symbol.clamp(1, 8);
        let mut out = vec![0.0; 256];

        if bits == 8 {
            for (sym, slot) in out.iter_mut().enumerate().take(256usize) {
                *slot = predictor.log_prob_symbol_bruteforce(sym as u8).exp();
            }
        } else {
            let patterns = 1usize << bits;
            let aliases = 1usize << (8 - bits);
            let mut pat_prob = vec![0.0; patterns];
            for (pat, value) in pat_prob.iter_mut().enumerate() {
                // Pattern lives in the high or low bits depending on bit order.
                let symbol = if predictor.msb_first {
                    (pat as u8) << (8 - bits)
                } else {
                    pat as u8
                };
                *value = predictor.log_prob_symbol_bruteforce(symbol).exp();
            }
            for (byte, slot) in out.iter_mut().enumerate().take(256usize) {
                let pat = if predictor.msb_first {
                    byte >> (8 - bits)
                } else {
                    byte & (patterns - 1)
                };
                *slot = pat_prob[pat] / (aliases as f64);
            }
        }

        CtwPredictor::normalize_pdf(&mut out);
        out
    }
2298
    // The fast pdf_next() path must agree with the brute-force reference.
    #[test]
    fn ctw_pdf_fast_matches_bruteforce() {
        let mut predictor = CtwPredictor::new_ctw(6);
        for &b in b"ctw fast-path regression corpus 1234567890" {
            predictor.update(b);
        }

        let fast = predictor.pdf_next().to_vec();
        // Invalidate the cached PDF so the brute-force pass recomputes it.
        predictor.valid = false;
        let brute = brute_force_pdf(&mut predictor);

        for i in 0..256usize {
            let delta = (fast[i] - brute[i]).abs();
            assert!(
                delta < 1e-12,
                "symbol={i} fast={} brute={} delta={delta}",
                fast[i],
                brute[i]
            );
        }
    }
2320
    // Same fast-vs-brute-force agreement, for a sub-byte FAC-CTW model.
    #[test]
    fn fac_pdf_fast_matches_bruteforce_subbyte() {
        let mut predictor = CtwPredictor::new_fac(5, 5);
        for &b in b"fac ctw subbyte regression corpus abcdefghijklmnopqrstuvwxyz" {
            predictor.update(b);
        }

        let fast = predictor.pdf_next().to_vec();
        // Invalidate the cached PDF so the brute-force pass recomputes it.
        predictor.valid = false;
        let brute = brute_force_pdf(&mut predictor);

        for i in 0..256usize {
            let delta = (fast[i] - brute[i]).abs();
            assert!(
                delta < 1e-12,
                "symbol={i} fast={} brute={} delta={delta}",
                fast[i],
                brute[i]
            );
        }
    }
2342
    // pdf_next() must be read-only: per-bit predictions and the block
    // log-probability are unchanged after calling it.
    fn assert_ctw_pdf_next_preserves_state(mut predictor: CtwPredictor) {
        for &b in b"ctw predictor state preservation payload" {
            predictor.update(b);
        }
        let mut before_p0 = [0.0f64; 8];
        let mut before_p1 = [0.0f64; 8];
        for bit_idx in 0..8usize {
            before_p0[bit_idx] = predictor.tree.predict(false, bit_idx);
            before_p1[bit_idx] = predictor.tree.predict(true, bit_idx);
        }
        let log_before = predictor.tree.get_log_block_probability();
        let _ = predictor.pdf_next();
        let log_after = predictor.tree.get_log_block_probability();
        assert!(
            (log_before - log_after).abs() < 1e-12,
            "log drift: before={log_before} after={log_after}"
        );
        for bit_idx in 0..8usize {
            let after_p0 = predictor.tree.predict(false, bit_idx);
            let after_p1 = predictor.tree.predict(true, bit_idx);
            assert!(
                (before_p0[bit_idx] - after_p0).abs() < 1e-12,
                "bit {bit_idx} p0 drift: {} vs {}",
                before_p0[bit_idx],
                after_p0
            );
            assert!(
                (before_p1[bit_idx] - after_p1).abs() < 1e-12,
                "bit {bit_idx} p1 drift: {} vs {}",
                before_p1[bit_idx],
                after_p1
            );
        }
    }
2377
    // State preservation for both model variants (plain CTW and FAC-CTW).
    #[test]
    fn ctw_pdf_next_preserves_state() {
        assert_ctw_pdf_next_preserves_state(CtwPredictor::new_ctw(7));
    }

    #[test]
    fn fac_pdf_next_preserves_state() {
        assert_ctw_pdf_next_preserves_state(CtwPredictor::new_fac(7, 8));
    }
2387
    // fill_pattern_log_probs() must not perturb the per-symbol log
    // probabilities the model would otherwise report.
    fn assert_fill_pattern_preserves_symbol_log_probs(mut predictor: CtwPredictor) {
        for &b in b"fill-pattern preservation regression payload" {
            predictor.update(b);
        }
        let mut baseline = [0.0f64; 256];
        for (sym, slot) in baseline.iter_mut().enumerate() {
            *slot = predictor.log_prob_symbol_bruteforce(sym as u8);
        }
        let _ = predictor.fill_pattern_log_probs();
        for (sym, &expected) in baseline.iter().enumerate() {
            let got = predictor.log_prob_symbol_bruteforce(sym as u8);
            let diff = (expected - got).abs();
            assert!(
                diff < 1e-12,
                "symbol={sym} expected={expected} got={got} diff={diff}"
            );
        }
    }
2406
    // Fill-pattern preservation for both model variants.
    #[test]
    fn ctw_fill_pattern_preserves_symbol_log_probs() {
        assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_ctw(7));
    }

    #[test]
    fn fac_fill_pattern_preserves_symbol_log_probs() {
        assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_fac(7, 8));
    }
2416
    // Calling pdf_next() before update() must leave the model in exactly the
    // same state as calling update() alone.
    fn assert_pdf_then_update_matches_plain_update(mut base: CtwPredictor) {
        for &b in b"pdf then update parity payload" {
            base.update(b);
        }
        let observed = b'n';
        let mut with_pdf = base.clone();
        let mut plain = base;

        let _ = with_pdf.pdf_next();
        with_pdf.update(observed);
        plain.update(observed);

        for sym in 0u8..=255u8 {
            let lp_with_pdf = with_pdf.log_prob_symbol_bruteforce(sym);
            let lp_plain = plain.log_prob_symbol_bruteforce(sym);
            let diff = (lp_with_pdf - lp_plain).abs();
            assert!(
                diff < 1e-12,
                "symbol={sym} with_pdf={lp_with_pdf} plain={lp_plain} diff={diff}"
            );
        }
    }
2439
    // pdf_next()-then-update() parity for both model variants.
    #[test]
    fn ctw_pdf_then_update_matches_plain_update() {
        assert_pdf_then_update_matches_plain_update(CtwPredictor::new_ctw(7));
    }

    #[test]
    fn fac_pdf_then_update_matches_plain_update() {
        assert_pdf_then_update_matches_plain_update(CtwPredictor::new_fac(7, 8));
    }
2449
    // End-to-end framed AC round trip with the plain CTW backend.
    #[test]
    fn roundtrip_rate_ac_ctw() {
        let data = b"ctw backend roundtrip payload";
        let backend = RateBackend::Ctw { depth: 8 };
        let enc =
            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        let dec =
            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        assert_eq!(dec, data);
    }
2460
    // Framed AC round trips across the match, sparse-match, and PPMd backends.
    #[test]
    fn roundtrip_rate_ac_match_family_and_ppmd() {
        let data = b"repeat repeat repeat sparse sparse repeat payload";
        for backend in [
            RateBackend::Match {
                hash_bits: 20,
                min_len: 4,
                max_len: 255,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
            RateBackend::SparseMatch {
                hash_bits: 19,
                min_len: 3,
                max_len: 64,
                gap_min: 1,
                gap_max: 2,
                base_mix: 0.05,
                confidence_scale: 1.0,
            },
            RateBackend::Ppmd {
                order: 8,
                memory_mb: 8,
            },
        ] {
            let enc = compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed)
                .unwrap();
            let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
                .unwrap();
            assert_eq!(dec, data);
        }
    }
2493
    // PPMd at high order over a 4 KiB repeated-text payload (README-seeded).
    #[test]
    fn roundtrip_rate_ac_ppmd_high_order_text_payload() {
        let seed = include_bytes!("../../README.md");
        let mut data = Vec::with_capacity(4096);
        while data.len() < 4096 {
            data.extend_from_slice(seed);
        }
        data.truncate(4096);

        let backend = RateBackend::Ppmd {
            order: 12,
            memory_mb: 256,
        };
        let enc = compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed)
            .expect("ppmd high-order compression");
        let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
            .expect("ppmd high-order decompression");
        assert_eq!(dec, data);
    }
2513
    // Framed AC round trip through the calibration wrapper over a CTW base.
    #[test]
    fn roundtrip_rate_ac_calibrated_backend() {
        let data = b"calibration wrapper payload calibration wrapper payload";
        let backend = RateBackend::Calibrated {
            spec: Arc::new(crate::CalibratedSpec {
                base: RateBackend::Ctw { depth: 8 },
                context: crate::CalibrationContextKind::Text,
                bins: 33,
                learning_rate: 0.02,
                bias_clip: 4.0,
            }),
        };
        let enc =
            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        let dec =
            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        assert_eq!(dec, data);
    }
2532
    // Single-expert neural mixture over CTW: exercises the mixture fast path.
    #[test]
    fn roundtrip_rate_ac_single_expert_ctw_neural_mixture() {
        let data = b"single expert neural ctw fast path payload";
        let spec = MixtureSpec::new(
            MixtureKind::Neural,
            vec![crate::MixtureExpertSpec {
                name: Some("ctw".to_string()),
                log_prior: 0.0,
                max_order: -1,
                backend: RateBackend::Ctw { depth: 8 },
            }],
        )
        .with_alpha(0.03);
        let backend = RateBackend::Mixture {
            spec: Arc::new(spec),
        };
        let enc =
            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        let dec =
            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        assert_eq!(dec, data);
    }
2555
    // Single-expert Bayes mixture over CTW: exercises the mixture fast path.
    #[test]
    fn roundtrip_rate_ac_single_expert_ctw_bayes_mixture() {
        let data = b"single expert bayes ctw fast path payload";
        let spec = MixtureSpec::new(
            MixtureKind::Bayes,
            vec![crate::MixtureExpertSpec {
                name: Some("ctw".to_string()),
                log_prior: 0.0,
                max_order: -1,
                backend: RateBackend::Ctw { depth: 8 },
            }],
        )
        .with_alpha(0.03);
        let backend = RateBackend::Mixture {
            spec: Arc::new(spec),
        };
        let enc =
            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        let dec =
            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        assert_eq!(dec, data);
    }
2578
    // rANS round trip through a switching mixture whose first expert is itself
    // a nested Bayes mixture (CTW + FAC-CTW) and second is a ZPAQ backend.
    #[test]
    fn roundtrip_rate_rans_recursive_mixture() {
        let data = b"recursive mixture payload";
        let nested = MixtureSpec::new(
            MixtureKind::Bayes,
            vec![
                crate::MixtureExpertSpec {
                    name: Some("ctw".to_string()),
                    log_prior: 0.0,
                    max_order: -1,
                    backend: RateBackend::Ctw { depth: 6 },
                },
                crate::MixtureExpertSpec {
                    name: Some("fac".to_string()),
                    log_prior: 0.0,
                    max_order: -1,
                    backend: RateBackend::FacCtw {
                        base_depth: 6,
                        num_percept_bits: 8,
                        encoding_bits: 8,
                    },
                },
            ],
        );
        let root = MixtureSpec::new(
            MixtureKind::Switching,
            vec![
                crate::MixtureExpertSpec {
                    name: Some("nested".to_string()),
                    log_prior: 0.0,
                    max_order: -1,
                    backend: RateBackend::Mixture {
                        spec: Arc::new(nested),
                    },
                },
                crate::MixtureExpertSpec {
                    name: Some("zpaq".to_string()),
                    log_prior: 0.0,
                    max_order: -1,
                    backend: RateBackend::Zpaq {
                        method: "1".to_string(),
                    },
                },
            ],
        )
        .with_alpha(0.05);

        let backend = RateBackend::Mixture {
            spec: Arc::new(root),
        };
        let enc =
            compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
        let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
            .unwrap();
        assert_eq!(dec, data);
    }
2635
2636    #[test]
2637    fn roundtrip_rate_ac_recursive_neural_mixture() {
2638        let data = b"neural recursive mixture payload for ac coder";
2639        let inner = MixtureSpec::new(
2640            MixtureKind::Bayes,
2641            vec![
2642                crate::MixtureExpertSpec {
2643                    name: Some("ctw".to_string()),
2644                    log_prior: 0.0,
2645                    max_order: -1,
2646                    backend: RateBackend::Ctw { depth: 6 },
2647                },
2648                crate::MixtureExpertSpec {
2649                    name: Some("fac".to_string()),
2650                    log_prior: 0.0,
2651                    max_order: -1,
2652                    backend: RateBackend::FacCtw {
2653                        base_depth: 6,
2654                        num_percept_bits: 8,
2655                        encoding_bits: 8,
2656                    },
2657                },
2658            ],
2659        );
2660        let root = MixtureSpec::new(
2661            MixtureKind::Neural,
2662            vec![
2663                crate::MixtureExpertSpec {
2664                    name: Some("nested".to_string()),
2665                    log_prior: 0.0,
2666                    max_order: -1,
2667                    backend: RateBackend::Mixture {
2668                        spec: Arc::new(inner),
2669                    },
2670                },
2671                crate::MixtureExpertSpec {
2672                    name: Some("zpaq".to_string()),
2673                    log_prior: 0.0,
2674                    max_order: -1,
2675                    backend: RateBackend::Zpaq {
2676                        method: "1".to_string(),
2677                    },
2678                },
2679            ],
2680        )
2681        .with_alpha(0.03);
2682
2683        let backend = RateBackend::Mixture {
2684            spec: Arc::new(root),
2685        };
2686        let enc =
2687            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2688        let dec =
2689            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2690        assert_eq!(dec, data);
2691    }
2692
2693    fn assert_runtime_and_compression_predictor_align(spec: MixtureSpec, data: &[u8], tol: f64) {
2694        let backend = RateBackend::Mixture {
2695            spec: Arc::new(spec.clone()),
2696        };
2697        let mut predictor = RatePdfPredictor::from_rate_backend(backend, -1).unwrap();
2698        let experts = spec.build_experts();
2699        let mut runtime = crate::mixture::build_mixture_runtime(&spec, &experts).unwrap();
2700
2701        for (t, &symbol) in data.iter().enumerate() {
2702            let pdf = predictor.pdf_next().unwrap();
2703            let p_comp = pdf[symbol as usize];
2704            let p_runtime = runtime.peek_log_prob(symbol).exp();
2705            assert!(
2706                (p_comp - p_runtime).abs() < tol,
2707                "t={t} p_comp={p_comp} p_runtime={p_runtime} symbol={symbol}"
2708            );
2709            predictor.update(symbol).unwrap();
2710            runtime.step(symbol);
2711        }
2712    }
2713
2714    fn alignment_experts() -> Vec<crate::MixtureExpertSpec> {
2715        vec![
2716            crate::MixtureExpertSpec {
2717                name: Some("ctw".to_string()),
2718                log_prior: 0.0,
2719                max_order: -1,
2720                backend: RateBackend::Ctw { depth: 7 },
2721            },
2722            crate::MixtureExpertSpec {
2723                name: Some("fac".to_string()),
2724                log_prior: -0.7,
2725                max_order: -1,
2726                backend: RateBackend::FacCtw {
2727                    base_depth: 7,
2728                    num_percept_bits: 8,
2729                    encoding_bits: 8,
2730                },
2731            },
2732        ]
2733    }
2734
2735    #[test]
2736    fn bayes_runtime_and_compression_predictor_align() {
2737        let spec = MixtureSpec::new(MixtureKind::Bayes, alignment_experts());
2738        assert_runtime_and_compression_predictor_align(
2739            spec,
2740            b"bayes predictor alignment check sequence",
2741            1e-8,
2742        );
2743    }
2744
2745    #[test]
2746    fn fading_runtime_and_compression_predictor_align() {
2747        let spec = MixtureSpec::new(MixtureKind::FadingBayes, alignment_experts()).with_decay(0.97);
2748        assert_runtime_and_compression_predictor_align(
2749            spec,
2750            b"fading predictor alignment check sequence",
2751            1e-8,
2752        );
2753    }
2754
2755    #[test]
2756    fn switching_runtime_and_compression_predictor_align() {
2757        let spec = MixtureSpec::new(MixtureKind::Switching, alignment_experts()).with_alpha(0.17);
2758        assert_runtime_and_compression_predictor_align(
2759            spec,
2760            b"switching predictor alignment check sequence",
2761            1e-8,
2762        );
2763    }
2764
2765    #[test]
2766    fn switching_theorem_runtime_and_compression_predictor_align() {
2767        let spec = MixtureSpec::new(MixtureKind::Switching, alignment_experts())
2768            .with_schedule(MixtureScheduleMode::Theorem)
2769            .with_alpha(0.91);
2770        assert_runtime_and_compression_predictor_align(
2771            spec,
2772            b"switching theorem predictor alignment check sequence",
2773            1e-8,
2774        );
2775    }
2776
2777    #[test]
2778    fn convex_runtime_and_compression_predictor_align_for_alpha_above_one() {
2779        let spec = MixtureSpec::new(MixtureKind::Convex, alignment_experts()).with_alpha(1.25);
2780        assert_runtime_and_compression_predictor_align(
2781            spec,
2782            b"convex predictor alignment check sequence",
2783            1e-8,
2784        );
2785    }
2786
2787    #[test]
2788    fn convex_theorem_runtime_and_compression_predictor_align() {
2789        let spec = MixtureSpec::new(MixtureKind::Convex, alignment_experts())
2790            .with_schedule(MixtureScheduleMode::Theorem)
2791            .with_alpha(7.5);
2792        assert_runtime_and_compression_predictor_align(
2793            spec,
2794            b"convex theorem predictor alignment check sequence",
2795            1e-8,
2796        );
2797    }
2798
2799    #[test]
2800    fn neural_runtime_and_compression_predictor_align() {
2801        let spec = MixtureSpec::new(MixtureKind::Neural, alignment_experts()).with_alpha(0.03);
2802        assert_runtime_and_compression_predictor_align(
2803            spec,
2804            b"neural alignment check sequence",
2805            1e-8,
2806        );
2807    }
2808
2809    #[test]
2810    fn mdl_runtime_and_compression_predictor_align() {
2811        let spec = MixtureSpec::new(MixtureKind::Mdl, alignment_experts());
2812        assert_runtime_and_compression_predictor_align(spec, b"mdl alignment check sequence", 1e-8);
2813    }
2814
2815    #[test]
2816    fn nested_runtime_and_compression_predictor_align() {
2817        let nested = MixtureSpec::new(MixtureKind::Bayes, alignment_experts());
2818        let spec = MixtureSpec::new(
2819            MixtureKind::Switching,
2820            vec![
2821                crate::MixtureExpertSpec {
2822                    name: Some("nested".to_string()),
2823                    log_prior: 0.0,
2824                    max_order: -1,
2825                    backend: RateBackend::Mixture {
2826                        spec: Arc::new(nested),
2827                    },
2828                },
2829                crate::MixtureExpertSpec {
2830                    name: Some("ppmd".to_string()),
2831                    log_prior: -0.2,
2832                    max_order: -1,
2833                    backend: RateBackend::Ppmd {
2834                        order: 5,
2835                        memory_mb: 8,
2836                    },
2837                },
2838            ],
2839        )
2840        .with_alpha(0.13);
2841        assert_runtime_and_compression_predictor_align(
2842            spec,
2843            b"nested mixture predictor alignment check sequence",
2844            1e-8,
2845        );
2846    }
2847
2848    fn assert_cached_cdf_fast_bitwise_matches_pdf_rows(mut predictor: RatePdfPredictor) {
2849        let data = b"cached cdf parity check payload";
2850        for &symbol in data {
2851            let pdf = predictor.pdf_next().unwrap().to_vec();
2852            assert!(predictor.prepare_cached_cdf_fast_bitwise().unwrap());
2853
2854            let mut row = [0.0; 257];
2855            row[0] = 0.0;
2856            for i in 0..256 {
2857                row[i + 1] = row[i] + pdf[i].max(PDF_MIN);
2858            }
2859
2860            let mut stack = vec![(0usize, 256usize)];
2861            while let Some((lo, hi)) = stack.pop() {
2862                if hi - lo <= 1 {
2863                    continue;
2864                }
2865                let expected = cdf_bit_prob_one_msb(&row, lo, hi);
2866                let got = predictor
2867                    .cached_cdf_bit_prob_one_msb(lo, hi)
2868                    .expect("cached cdf branch probability");
2869                let diff = (expected - got).abs();
2870                assert!(
2871                    diff <= 1e-12,
2872                    "lo={lo} hi={hi} expected={expected} got={got} diff={diff}"
2873                );
2874                let mid = (lo + hi) >> 1;
2875                stack.push((lo, mid));
2876                stack.push((mid, hi));
2877            }
2878
2879            predictor.update(symbol).unwrap();
2880        }
2881    }
2882
    #[test]
    fn cached_cdf_fast_bitwise_matches_pdf_rows_for_specialized_predictors() {
        // Exercises the cached-CDF bitwise parity helper across several
        // predictor backends. Neural backends (RWKV-7 / Mamba) are only
        // included when their cargo features are enabled, and use tiny
        // inference-only configs (train=none, lr=0.0) to keep the test fast.
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(RateBackend::RosaPlus, -1).unwrap(),
        );
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::Ppmd {
                    order: 6,
                    memory_mb: 8,
                },
                -1,
            )
            .unwrap(),
        );
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::Match {
                    hash_bits: 20,
                    min_len: 4,
                    max_len: 255,
                    base_mix: 0.02,
                    confidence_scale: 1.0,
                },
                -1,
            )
            .unwrap(),
        );
        #[cfg(feature = "backend-rwkv")]
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::Rwkv7Method {
                    method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
                },
                -1,
            )
            .unwrap(),
        );
        #[cfg(feature = "backend-mamba")]
        assert_cached_cdf_fast_bitwise_matches_pdf_rows(
            RatePdfPredictor::from_rate_backend(
                RateBackend::MambaMethod {
                    method: "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
                },
                -1,
            )
            .unwrap(),
        );
    }
2932
2933    #[test]
2934    fn raw_size_not_larger_than_framed_size() {
2935        let data = b"raw/framed size check payload";
2936        let backend = RateBackend::RosaPlus;
2937        let raw = compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Raw).unwrap();
2938        let framed =
2939            compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Framed).unwrap();
2940        assert!(framed >= raw);
2941    }
2942
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn roundtrip_rate_rwkv_method_cfg() {
        // AC-coder roundtrip with an RWKV-7 backend built entirely from a
        // `cfg:` method string: tiny model, inference-only policy
        // (train=none, lr=0.0), so the run is deterministic and fast.
        let data = b"rwkv cfg method backend";
        let backend = RateBackend::Rwkv7Method {
            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
        };
        let enc =
            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        let dec =
            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
        assert_eq!(dec, data);
    }
2956
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_rate_predictor_preserves_backend_pdf_exactly() {
        // Lockstep comparison: the rate-coding `RwkvPredictor` wrapper must
        // reproduce the raw rwkvzip backend PDF to within 1e-18, both for the
        // initial prediction (symbol id 0 primes the backend forward pass)
        // and again after both sides observe one symbol ('x').
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
        let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
        let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
        let mut direct = vec![0.0; backend.vocab_size()];

        // Initial PDFs must match before any update.
        let predicted = predictor.pdf_next().to_vec();
        backend.forward_to_pdf(0, &mut direct);
        assert_pdf_close(&predicted, &direct, 1e-18);

        // Advance both sides on the same symbol, then compare again.
        predictor.update(b'x').expect("predictor update");
        backend
            .online_update_from_pdf(b'x', &direct)
            .expect("backend update");
        backend.forward_to_pdf(u32::from(b'x'), &mut direct);
        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
    }
2976
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_rate_predictor_matches_backend_after_partial_tbptt_stream() {
        // Streams 10 bytes through both the predictor and the raw backend
        // under an online training policy with bptt=8, so the final training
        // window is partial. At every step, and again after stream
        // finalization, the two sides must produce PDFs matching to 1e-18.
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=29,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
        let data = b"abcdefghij";
        let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
        let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
        let mut direct = vec![0.0; backend.vocab_size()];

        // Open both streams with the known payload length before feeding data.
        predictor
            .begin_stream(data.len())
            .expect("begin predictor stream");
        backend
            .begin_online_policy_stream(Some(data.len() as u64))
            .expect("begin backend stream");
        backend.reset_and_prime();

        for &byte in data {
            // Compare the next-symbol PDFs before either side observes `byte`.
            let predicted = predictor.pdf_next().to_vec();
            backend.copy_current_pdf_to(&mut direct);
            assert_pdf_close(&predicted, &direct, 1e-18);

            predictor.update(byte).expect("predictor update");
            backend
                .observe_symbol_from_current_pdf(byte)
                .expect("backend update");
        }

        // Finalization flushes the partial bptt window; PDFs must still match.
        predictor.finish_stream().expect("finish predictor stream");
        backend
            .finish_online_policy_stream()
            .expect("finish backend stream");
        backend.copy_current_pdf_to(&mut direct);
        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
    }
3012
3013    #[cfg(feature = "backend-rwkv")]
3014    #[test]
3015    fn roundtrip_rate_rwkv_two_json_method_2m() {
3016        let two_json: serde_json::Value =
3017            serde_json::from_str(include_str!("../../examples/two.json")).unwrap();
3018        let method = two_json["experts"]
3019            .as_array()
3020            .unwrap()
3021            .iter()
3022            .find(|expert| expert["name"].as_str() == Some("rwkv"))
3023            .and_then(|expert| expert["method"].as_str())
3024            .unwrap()
3025            .to_string();
3026
3027        let backend = RateBackend::Rwkv7Method { method };
3028        let seed = include_bytes!("../../README.md");
3029        let target_len = 2_097_152usize;
3030        let mut data = Vec::with_capacity(target_len);
3031        while data.len() < target_len {
3032            let remaining = target_len - data.len();
3033            data.extend_from_slice(&seed[..seed.len().min(remaining)]);
3034        }
3035
3036        let enc =
3037            compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3038        let dec =
3039            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3040        assert_eq!(dec, data);
3041    }
3042
    #[cfg(feature = "backend-mamba")]
    #[test]
    fn mamba_rate_predictor_preserves_backend_pdf_exactly() {
        // Lockstep comparison (mirror of the RWKV variant): the rate-coding
        // `MambaPredictor` wrapper must reproduce the raw mambazip backend PDF
        // to within 1e-18, both for the initial prediction and after one
        // observed symbol ('x').
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
        let mut predictor = MambaPredictor::from_method(method).expect("mamba predictor");
        let mut backend = mambazip::Compressor::new_from_method(method).expect("mamba backend");
        let mut direct = vec![0.0; backend.vocab_size()];

        // Initial PDFs must match before any update.
        let predicted = predictor.pdf_next().to_vec();
        backend.forward_to_pdf(0, &mut direct);
        assert_pdf_close(&predicted, &direct, 1e-18);

        // Advance both sides on the same symbol, then compare again.
        predictor.update(b'x').expect("predictor update");
        backend
            .online_update_from_pdf(b'x', &direct)
            .expect("backend update");
        backend.forward_to_pdf(u32::from(b'x'), &mut direct);
        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
    }
3062
3063    #[test]
3064    fn roundtrip_rate_ac_particle() {
3065        let spec = crate::ParticleSpec {
3066            num_particles: 4,
3067            num_cells: 4,
3068            cell_dim: 8,
3069            num_rules: 2,
3070            selector_hidden: 16,
3071            rule_hidden: 16,
3072            context_window: 8,
3073            unroll_steps: 1,
3074            ..crate::ParticleSpec::default()
3075        };
3076        let data = b"particle ac roundtrip payload";
3077        let backend = RateBackend::Particle {
3078            spec: Arc::new(spec),
3079        };
3080        let enc =
3081            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3082        let dec =
3083            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3084        assert_eq!(dec, data);
3085    }
3086
3087    #[test]
3088    fn roundtrip_rate_rans_particle() {
3089        let spec = crate::ParticleSpec {
3090            num_particles: 4,
3091            num_cells: 4,
3092            cell_dim: 8,
3093            num_rules: 2,
3094            selector_hidden: 16,
3095            rule_hidden: 16,
3096            context_window: 8,
3097            unroll_steps: 1,
3098            ..crate::ParticleSpec::default()
3099        };
3100        let data = b"particle rans roundtrip payload";
3101        let backend = RateBackend::Particle {
3102            spec: Arc::new(spec),
3103        };
3104        let enc =
3105            compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
3106        let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
3107            .unwrap();
3108        assert_eq!(dec, data);
3109    }
3110
3111    #[test]
3112    fn mixture_with_particle_expert_roundtrip() {
3113        let particle_spec = crate::ParticleSpec {
3114            num_particles: 4,
3115            num_cells: 4,
3116            cell_dim: 8,
3117            num_rules: 2,
3118            selector_hidden: 16,
3119            rule_hidden: 16,
3120            context_window: 8,
3121            unroll_steps: 1,
3122            ..crate::ParticleSpec::default()
3123        };
3124        let spec = MixtureSpec::new(
3125            MixtureKind::Bayes,
3126            vec![
3127                crate::MixtureExpertSpec {
3128                    name: Some("particle".to_string()),
3129                    log_prior: 0.0,
3130                    max_order: -1,
3131                    backend: RateBackend::Particle {
3132                        spec: Arc::new(particle_spec),
3133                    },
3134                },
3135                crate::MixtureExpertSpec {
3136                    name: Some("ctw".to_string()),
3137                    log_prior: 0.0,
3138                    max_order: -1,
3139                    backend: RateBackend::Ctw { depth: 6 },
3140                },
3141            ],
3142        );
3143        let backend = RateBackend::Mixture {
3144            spec: Arc::new(spec),
3145        };
3146        let data = b"mixture with particle expert roundtrip";
3147        let enc =
3148            compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3149        let dec =
3150            decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3151        assert_eq!(dec, data);
3152    }
3153}