// infotheory/mixture.rs
//! Online mixtures of probabilistic predictors (log-loss Hedge / Bayes, switching, MDL).
//!
//! This module provides a small, rigorously correct toolkit for sequential model mixing.
//! Predictors expose per-symbol log-probabilities, which allows principled Bayesian
//! mixture updates and clean information-theoretic accounting.
//!
//! ## Rate-Backend Mixtures
//!
//! The mixture primitives here power `RateBackend::Mixture`, enabling Bayes, fading Bayes,
//! switching, and MDL-style selectors to be used anywhere a rate backend is accepted.

12use crate::backends::calibration::CalibratorCore;
13use crate::backends::match_model::MatchModel;
14use crate::backends::ppmd::PpmdModel;
15use crate::backends::sequitur::{SequiturCheckpoint, SequiturModel};
16use crate::backends::sparse_match::SparseMatchModel;
17use crate::backends::text_context::TextContextAnalyzer;
18use crate::ctw::FacContextTree;
19#[cfg(feature = "backend-mamba")]
20use crate::mambazip;
21use crate::neural_mix::{NeuralHistoryState, NeuralMixCore};
22use crate::rosaplus::RosaPlus;
23#[cfg(feature = "backend-rwkv")]
24use crate::rwkvzip;
25use crate::zpaq_rate::ZpaqRateModel;
26use crate::{CalibratedSpec, MixtureKind, MixtureScheduleMode, MixtureSpec, RateBackend};
27use std::sync::Arc;
28
/// Default minimum probability floor to avoid log(0).
///
/// The literal rounds to exactly 2⁻²⁴ in `f64`, i.e. one count of a 24-bit
/// arithmetic-coder precision budget.
pub const DEFAULT_MIN_PROB: f64 = 5.960_464_477_539_063e-8;
31
#[inline]
fn clamp_prob(p: f64, min_prob: f64) -> f64 {
    // Non-finite inputs (NaN / ±inf) collapse to the floor; finite values are
    // merely floored at `min_prob` (no upper cap here — see `clamp_unit_prob`).
    match p.is_finite() {
        true => p.max(min_prob),
        false => min_prob,
    }
}
40
#[inline]
fn clamp_unit_prob(p: f64, min_prob: f64) -> f64 {
    // Floor at `min_prob` (treating non-finite input as the floor), then cap
    // at `1 - min_prob` so the result stays strictly inside the unit interval.
    let floored = if p.is_finite() { p.max(min_prob) } else { min_prob };
    floored.min(1.0 - min_prob)
}
45
46#[inline]
47fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
48    CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
49}
50
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    // Numerically stable log(Σ exp(x_i)): factor out the maximum first.
    // NaN entries are skipped by `f64::max`, matching a `>` comparison scan.
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        // Empty slice (-inf) or an infinite peak: the peak itself is the answer.
        return peak;
    }
    let total: f64 = xs.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
68
#[inline]
fn logsumexp2(a: f64, b: f64) -> f64 {
    // Two-argument stable log(exp(a) + exp(b)): factor out the larger value.
    let peak = a.max(b);
    if !peak.is_finite() {
        return peak;
    }
    peak + ((a - peak).exp() + (b - peak).exp()).ln()
}
77
78#[inline]
79fn logsumexp_weights(experts: &[ExpertState]) -> f64 {
80    let mut max_v = f64::NEG_INFINITY;
81    for e in experts {
82        if e.log_weight > max_v {
83            max_v = e.log_weight;
84        }
85    }
86    if !max_v.is_finite() {
87        return max_v;
88    }
89    let mut sum = 0.0;
90    for e in experts {
91        sum += (e.log_weight - max_v).exp();
92    }
93    max_v + sum.ln()
94}
95
fn normalize_simplex_weights(weights: &mut [f64]) {
    if weights.is_empty() {
        return;
    }
    // Sanitize in place: non-finite or negative entries contribute zero mass.
    let mut total = 0.0;
    for w in weights.iter_mut() {
        if !(w.is_finite() && *w >= 0.0) {
            *w = 0.0;
        }
        total += *w;
    }
    // Degenerate total mass: fall back to the uniform distribution.
    if !(total.is_finite() && total > 0.0) {
        weights.fill(1.0 / (weights.len() as f64));
        return;
    }
    // Divide (rather than multiply by a reciprocal) for bit-exact results.
    for w in weights.iter_mut() {
        *w /= total;
    }
}
116
117pub(crate) fn project_simplex_with_scratch(weights: &mut [f64], scratch: &mut Vec<f64>) {
118    if weights.is_empty() {
119        return;
120    }
121
122    scratch.clear();
123    scratch.extend(
124        weights
125            .iter()
126            .map(|&weight| if weight.is_finite() { weight } else { 0.0 }),
127    );
128    let sorted = scratch.as_mut_slice();
129    sorted.sort_by(|a, b| b.total_cmp(a));
130
131    let mut cumulative = 0.0;
132    let mut rho = None;
133    for (index, value) in sorted.iter().enumerate() {
134        cumulative += *value;
135        let theta = (cumulative - 1.0) / ((index + 1) as f64);
136        if *value > theta {
137            rho = Some(index);
138        }
139    }
140
141    let Some(rho_index) = rho else {
142        let uniform = 1.0 / (weights.len() as f64);
143        weights.fill(uniform);
144        return;
145    };
146
147    let theta = (sorted.iter().take(rho_index + 1).sum::<f64>() - 1.0) / ((rho_index + 1) as f64);
148    for weight in weights.iter_mut() {
149        *weight = (*weight - theta).max(0.0);
150    }
151    normalize_simplex_weights(weights);
152}
153
154#[inline]
155pub(crate) fn switching_alpha_for_update(
156    schedule: MixtureScheduleMode,
157    alpha: f64,
158    processed_symbols: u64,
159) -> f64 {
160    match schedule {
161        MixtureScheduleMode::Default => alpha.clamp(0.0, 1.0),
162        MixtureScheduleMode::Theorem => 1.0 / ((processed_symbols + 2) as f64),
163    }
164}
165
166#[inline]
167pub(crate) fn convex_step_size_for_update(
168    schedule: MixtureScheduleMode,
169    alpha: f64,
170    update_index: u64,
171) -> f64 {
172    let t = update_index.max(1) as f64;
173    match schedule {
174        MixtureScheduleMode::Default => alpha.max(1e-12) / t.sqrt(),
175        MixtureScheduleMode::Theorem => DEFAULT_MIN_PROB / t.sqrt(),
176    }
177}
178
179fn normalized_prior_weights(configs: &[ExpertConfig]) -> Vec<f64> {
180    if configs.is_empty() {
181        return Vec::new();
182    }
183    let max_log = configs
184        .iter()
185        .map(|cfg| cfg.log_prior)
186        .fold(f64::NEG_INFINITY, f64::max);
187    let mut weights = configs
188        .iter()
189        .map(|cfg| {
190            if max_log.is_finite() {
191                (cfg.log_prior - max_log).exp()
192            } else {
193                0.0
194            }
195        })
196        .collect::<Vec<_>>();
197    normalize_simplex_weights(&mut weights);
198    weights
199}
200
201fn set_log_weights_from_linear(experts: &mut [ExpertState], weights: &[f64]) {
202    for (expert, &weight) in experts.iter_mut().zip(weights.iter()) {
203        expert.log_weight = if weight > 0.0 {
204            weight.ln()
205        } else {
206            f64::NEG_INFINITY
207        };
208    }
209}
210
/// Object-safe clone helper for [`OnlineBytePredictor`] trait objects.
pub trait OnlineBytePredictorClone {
    /// Clone this predictor as a trait object.
    ///
    /// This supports `Clone` for `Box<dyn OnlineBytePredictor>` via type erasure,
    /// so mixture experts can be duplicated without knowing their concrete type.
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor>;
}
219
// Blanket impl: every cloneable, 'static predictor gets `clone_box` for free.
impl<T> OnlineBytePredictorClone for T
where
    T: 'static + OnlineBytePredictor + Clone,
{
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor> {
        // Delegate to the concrete `Clone` impl, then re-erase the type.
        Box::new(self.clone())
    }
}
228
// `Box<dyn OnlineBytePredictor>` is cloneable via the `clone_box` escape
// hatch, which dispatches to the concrete predictor's `Clone` impl.
impl Clone for Box<dyn OnlineBytePredictor> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
234
235/// Trait for online byte-level predictors that expose per-symbol log-probabilities.
236pub trait OnlineBytePredictor: Send + OnlineBytePredictorClone {
237    /// Optional stream-start hook.
238    ///
239    /// Predictors that require total symbol count (for example percent-based
240    /// policy schedules) can initialize runtime state here.
241    fn begin_stream(&mut self, _total_symbols: Option<u64>) -> Result<(), String> {
242        Ok(())
243    }
244
245    /// Optional stream-finalization hook.
246    fn finish_stream(&mut self) -> Result<(), String> {
247        Ok(())
248    }
249
250    /// Log-probability (natural log) of `symbol` given the current history.
251    fn log_prob(&mut self, symbol: u8) -> f64;
252
253    /// Bulk 256-way log-probabilities for the next byte.
254    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
255        for (sym, slot) in out.iter_mut().enumerate() {
256            *slot = self.log_prob(sym as u8);
257        }
258    }
259
260    /// Log-probability (natural log) of `symbol`, then update the predictor.
261    fn log_prob_update(&mut self, symbol: u8) -> f64 {
262        let logp = self.log_prob(symbol);
263        self.update(symbol);
264        logp
265    }
266
267    /// Update the predictor with the observed `symbol`.
268    fn update(&mut self, symbol: u8);
269
270    /// Reset only dynamic conditioning state while preserving fitted parameters/statistics.
271    ///
272    /// Predictors with latent/posterior state may also preserve their learned
273    /// parameter posterior here; "frozen" means no new parameter fitting during
274    /// the score pass, not necessarily a static hidden-state belief.
275    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
276        self.finish_stream()?;
277        self.begin_stream(total_symbols)
278    }
279
280    /// Advance conditioning state without fitting or adapting parameters.
281    ///
282    /// For state-space or latent-variable models this may still update internal
283    /// filtering/posterior state needed for correct sequential predictions.
284    fn update_frozen(&mut self, symbol: u8) {
285        self.update(symbol);
286    }
287}
288
/// Ensure the RWKV compressor has been reset-and-primed exactly once before
/// its PDF buffer is consulted; `primed` is the idempotence flag.
#[cfg(feature = "backend-rwkv")]
#[inline]
fn ensure_rwkv_primed(compressor: &mut rwkvzip::Compressor, primed: &mut bool) {
    if !*primed {
        compressor.reset_and_prime();
        *primed = true;
    }
}
297
298#[inline]
299fn ctw_log_prob_update_msb(tree: &mut FacContextTree, symbol: u8, min_prob: f64) -> f64 {
300    let mut logp = 0.0;
301    for bit_idx in 0..8 {
302        let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
303        let p = tree.predict(bit, bit_idx);
304        if p.is_finite() && p > 0.0 {
305            logp += p.ln();
306        } else {
307            logp = f64::NEG_INFINITY;
308        }
309        tree.update_predicted(bit, bit_idx);
310    }
311    if logp.is_finite() {
312        logp.max(min_prob.ln())
313    } else {
314        min_prob.ln()
315    }
316}
317
318#[inline]
319fn ctw_log_prob_update_lsb(
320    tree: &mut FacContextTree,
321    symbol: u8,
322    bits_per_symbol: usize,
323    min_prob: f64,
324) -> f64 {
325    let mut logp = 0.0;
326    for bit_idx in 0..bits_per_symbol {
327        let bit = ((symbol >> bit_idx) & 1) == 1;
328        let p = tree.predict(bit, bit_idx);
329        if p.is_finite() && p > 0.0 {
330            logp += p.ln();
331        } else {
332            logp = f64::NEG_INFINITY;
333        }
334        tree.update_predicted(bit, bit_idx);
335    }
336    if logp.is_finite() {
337        logp.max(min_prob.ln())
338    } else {
339        min_prob.ln()
340    }
341}
342
343fn fill_fac_tree_log_probs(
344    tree: &mut FacContextTree,
345    bits_per_symbol: usize,
346    msb_first: bool,
347    min_logp: f64,
348    out: &mut [f64; 256],
349) {
350    struct RecParams {
351        bits: usize,
352        msb_first: bool,
353        log_before: f64,
354        min_logp: f64,
355    }
356
357    let bits = bits_per_symbol.clamp(1, 8);
358    let patterns = 1usize << bits;
359    let mut pattern_logps = [f64::NEG_INFINITY; 256];
360    let params = RecParams {
361        bits,
362        msb_first,
363        log_before: tree.get_log_block_probability(),
364        min_logp,
365    };
366
367    fn rec(
368        tree: &mut FacContextTree,
369        depth: usize,
370        params: &RecParams,
371        symbol_acc: u8,
372        pattern_logps: &mut [f64; 256],
373    ) {
374        if depth == params.bits {
375            let pat = symbol_acc as usize;
376            let logp = (tree.get_log_block_probability() - params.log_before).max(params.min_logp);
377            pattern_logps[pat] = logp;
378            return;
379        }
380
381        for bit in [false, true] {
382            tree.update(bit, depth);
383            let mut next_symbol = symbol_acc;
384            if params.msb_first {
385                let shift = 7usize.saturating_sub(depth);
386                if bit {
387                    next_symbol |= 1u8 << shift;
388                }
389            } else if bit {
390                next_symbol |= 1u8 << depth;
391            }
392            rec(tree, depth + 1, params, next_symbol, pattern_logps);
393            tree.revert(depth);
394        }
395    }
396
397    rec(tree, 0, &params, 0, &mut pattern_logps);
398
399    if bits == 8 {
400        out.copy_from_slice(&pattern_logps);
401    } else {
402        let aliases = 1usize << (8 - bits);
403        let alias_ln = (aliases as f64).ln();
404        let mask = patterns - 1;
405        for byte in 0..256usize {
406            out[byte] = pattern_logps[byte & mask] - alias_ln;
407        }
408    }
409}
410
/// A concrete online predictor backed by a `RateBackend` configuration.
///
/// Variants mirror `RateBackend` one-to-one; neural variants are gated behind
/// their cargo features.
// NOTE(review): large variants are kept inline rather than boxed; confirm the
// enum-size vs. per-expert-allocation trade-off is intentional.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum RateBackendPredictor {
    /// ROSA-Plus online suffix automaton.
    Rosa {
        /// ROSA model state.
        model: RosaPlus,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Local contiguous match predictor.
    Match {
        /// Match model state.
        model: MatchModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Sparse/gapped local match predictor.
    SparseMatch {
        /// Sparse-match model state.
        model: SparseMatchModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Bounded-memory PPMD-style predictor.
    Ppmd {
        /// PPMD model state.
        model: PpmdModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Exact online Sequitur grammar backend with predictive suffix contexts.
    Sequitur {
        /// Sequitur model state.
        model: SequiturModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Byte-wise CTW implemented as 8 factorized bit trees (MSB-first).
    Ctw {
        /// FAC-CTW tree stack (8 bits per byte).
        tree: FacContextTree,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Factorized CTW with configurable bit-encoding (LSB-first).
    FacCtw {
        /// FAC-CTW tree stack for configured bit width.
        tree: FacContextTree,
        /// Active bit-width per symbol (clamped to 1..=8 at construction).
        bits_per_symbol: usize,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// RWKV-7 neural predictor.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        /// RWKV compressor/runtime state.
        compressor: rwkvzip::Compressor,
        /// Whether the first-token distribution has been primed.
        primed: bool,
        /// Scratch copy used for update API that borrows immutable PDF.
        pdf_scratch: Vec<f64>,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Mamba-1 neural predictor.
    #[cfg(feature = "backend-mamba")]
    Mamba {
        /// Mamba compressor/runtime state.
        compressor: mambazip::Compressor,
        /// Whether the first-token distribution has been primed.
        primed: bool,
        /// Scratch copy used for update API that borrows immutable PDF.
        pdf_scratch: Vec<f64>,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// ZPAQ streaming rate model.
    Zpaq {
        /// ZPAQ rate model state (owns its own probability floor).
        model: ZpaqRateModel,
    },
    /// Online mixture over experts (Bayes, fading Bayes, switching, MDL).
    Mixture {
        /// Active mixture runtime.
        runtime: MixtureRuntime,
    },
    /// Particle-latent filter ensemble.
    Particle {
        /// Particle runtime.
        runtime: crate::particle::ParticleRuntime,
    },
    /// Calibrated wrapper around another predictor.
    Calibrated {
        /// Wrapped predictor whose PDF is calibrated.
        base: Box<RateBackendPredictor>,
        /// Online calibrator state and context features.
        core: CalibratorCore,
        /// Cached calibrated PDF.
        pdf: [f64; 256],
        /// Whether `pdf` currently matches wrapped state.
        valid: bool,
        /// Probability floor used for numerical stability.
        min_prob: f64,
    },
}
519
#[derive(Clone)]
/// Checkpoint snapshot used for temporary predictor rollback.
///
/// Most backends use a full cloned predictor snapshot. Sequitur uses a compact
/// internal checkpoint to avoid cloning its full state. Restoring a variant
/// onto a mismatched predictor panics (see `restore_checkpoint`).
pub enum RateBackendPredictorCheckpoint {
    /// Full predictor clone for backends without specialized checkpointing.
    Full(RateBackendPredictor),
    /// Compact Sequitur undo marker for [`RateBackendPredictor::Sequitur`].
    Sequitur(SequiturCheckpoint),
}
531
impl RateBackendPredictor {
    /// Create a new online predictor from a rate backend configuration.
    ///
    /// `max_order` is consumed by backends that have an order notion
    /// (ROSA-Plus) and forwarded when building a calibrated base; `min_prob`
    /// becomes each backend's probability floor.
    ///
    /// # Panics
    ///
    /// Panics when an embedded RWKV/Mamba method string or `MixtureSpec` is
    /// invalid — construction happens at configuration time, so this is
    /// treated as a programmer error rather than a recoverable failure.
    pub fn from_backend(backend: RateBackend, max_order: i64, min_prob: f64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                Self::Rosa { model, min_prob }
            }
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
                min_prob,
            },
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
                min_prob,
            },
            RateBackend::Ppmd { order, memory_mb } => Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
                min_prob,
            },
            RateBackend::Sequitur { context_bytes } => Self::Sequitur {
                model: SequiturModel::new(context_bytes),
                min_prob,
            },
            RateBackend::Ctw { depth } => {
                // Byte-wise CTW: one factorized tree stack, 8 bits per byte.
                let tree = FacContextTree::new(depth, 8);
                Self::Ctw { tree, min_prob }
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                // Clamp the configured bit width into the valid 1..=8 range.
                let bits_per_symbol = encoding_bits.clamp(1, 8);
                let tree = FacContextTree::new(base_depth, bits_per_symbol);
                Self::FacCtw {
                    tree,
                    bits_per_symbol,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => {
                let mut compressor = rwkvzip::Compressor::new_from_model(model);
                // Prime eagerly so the PDF buffer is valid on first read.
                compressor.reset_and_prime();
                Self::Rwkv7 {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                let mut compressor = rwkvzip::Compressor::new_from_method(&method)
                    .unwrap_or_else(|e| panic!("invalid rwkv method '{method}': {e}"));
                compressor.reset_and_prime();
                Self::Rwkv7 {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => {
                let mut compressor = mambazip::Compressor::new_from_model(model);
                // Prime by running one forward pass on token 0 and converting
                // the logits into the initial PDF buffer.
                let bias = compressor.online_bias_snapshot();
                let logits =
                    compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
                Self::Mamba {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                let mut compressor = mambazip::Compressor::new_from_method(&method)
                    .unwrap_or_else(|e| panic!("invalid mamba method '{method}': {e}"));
                // Same token-0 priming as the `Mamba` arm above.
                let bias = compressor.online_bias_snapshot();
                let logits =
                    compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
                Self::Mamba {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            RateBackend::Zpaq { method } => {
                let model = ZpaqRateModel::new(method, min_prob);
                Self::Zpaq { model }
            }
            RateBackend::Mixture { spec } => {
                let experts = spec.build_experts();
                let runtime = build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                Self::Mixture { runtime }
            }
            RateBackend::Particle { spec } => {
                let runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                Self::Particle { runtime }
            }
            RateBackend::Calibrated { spec } => Self::Calibrated {
                // The wrapped base is built recursively from the same knobs.
                base: Box::new(Self::from_backend(spec.base.clone(), max_order, min_prob)),
                core: build_calibrator(spec.as_ref()),
                pdf: [1.0 / 256.0; 256],
                valid: false,
                min_prob,
            },
        }
    }

    /// Human-readable default name for a backend + config.
    pub fn default_name(backend: &RateBackend, max_order: i64) -> String {
        match backend {
            RateBackend::RosaPlus => format!("rosa(mo={})", max_order),
            RateBackend::Match { .. } => "match".to_string(),
            RateBackend::SparseMatch { .. } => "sparse-match".to_string(),
            RateBackend::Ppmd { order, memory_mb } => {
                format!("ppmd(o={},m={}MiB)", order, memory_mb)
            }
            RateBackend::Sequitur { context_bytes } => {
                format!("sequitur(ctx={context_bytes})")
            }
            RateBackend::Ctw { depth } => format!("ctw(d={})", depth),
            RateBackend::FacCtw {
                base_depth,
                encoding_bits,
                ..
            } => format!("fac-ctw(d={},b={})", base_depth, encoding_bits),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => format!("rwkv7({method})"),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { .. } => "mamba".to_string(),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => format!("mamba({method})"),
            RateBackend::Zpaq { method } => format!("zpaq(m={})", method),
            RateBackend::Mixture { spec } => {
                let kind = match spec.kind {
                    MixtureKind::Bayes => "bayes",
                    MixtureKind::FadingBayes => "fading",
                    MixtureKind::Switching => "switch",
                    MixtureKind::Convex => "convex",
                    MixtureKind::Mdl => "mdl",
                    MixtureKind::Neural => "neural",
                };
                format!("mix({})", kind)
            }
            RateBackend::Particle { spec } => {
                format!("particle(n={},c={})", spec.num_particles, spec.num_cells)
            }
            RateBackend::Calibrated { spec } => {
                format!("calibrated({})", Self::default_name(&spec.base, max_order))
            }
        }
    }

    /// Capture a rollback point for this predictor.
    ///
    /// Sequitur uses its compact internal checkpoint; every other backend
    /// falls back to a full clone of the predictor state.
    pub(crate) fn checkpoint(&mut self) -> RateBackendPredictorCheckpoint {
        match self {
            RateBackendPredictor::Sequitur { model, .. } => {
                RateBackendPredictorCheckpoint::Sequitur(model.checkpoint())
            }
            _ => RateBackendPredictorCheckpoint::Full(self.clone()),
        }
    }

    /// Roll this predictor back to a previously captured checkpoint.
    ///
    /// # Panics
    ///
    /// Panics if a Sequitur checkpoint is applied to a non-Sequitur predictor
    /// (checkpoints are only valid for the predictor that produced them).
    pub(crate) fn restore_checkpoint(&mut self, checkpoint: &RateBackendPredictorCheckpoint) {
        match (self, checkpoint) {
            (
                RateBackendPredictor::Sequitur { model, .. },
                RateBackendPredictorCheckpoint::Sequitur(ck),
            ) => {
                model.restore(ck);
            }
            (slot, RateBackendPredictorCheckpoint::Full(state)) => {
                // Full snapshot: overwrite the predictor wholesale.
                *slot = state.clone();
            }
            (_, RateBackendPredictorCheckpoint::Sequitur(_)) => {
                panic!("mismatched RateBackendPredictor checkpoint variant")
            }
        }
    }

    /// Drop any retained checkpoint bookkeeping, for backends that keep some
    /// (currently only Sequitur); a no-op for all other variants.
    pub(crate) fn clear_checkpoints_if_supported(&mut self) {
        if let RateBackendPredictor::Sequitur { model, .. } = self {
            model.clear_checkpoints();
        }
    }
}
769
770impl OnlineBytePredictor for RateBackendPredictor {
771    fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
772        self.finish_stream()?;
773        match self {
774            RateBackendPredictor::Rosa { model, .. } => {
775                if let Some(total) = total_symbols {
776                    let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
777                    model.reserve_for_stream(reserve);
778                }
779                Ok(())
780            }
781            RateBackendPredictor::Match { .. }
782            | RateBackendPredictor::SparseMatch { .. }
783            | RateBackendPredictor::Ppmd { .. } => Ok(()),
784            RateBackendPredictor::Sequitur { model, .. } => {
785                model.begin_stream(total_symbols);
786                Ok(())
787            }
788            RateBackendPredictor::Ctw { .. }
789            | RateBackendPredictor::FacCtw { .. }
790            | RateBackendPredictor::Zpaq { .. }
791            | RateBackendPredictor::Particle { .. } => Ok(()),
792            #[cfg(feature = "backend-rwkv")]
793            RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
794                .begin_online_policy_stream(total_symbols)
795                .map_err(|e| e.to_string()),
796            #[cfg(feature = "backend-mamba")]
797            RateBackendPredictor::Mamba { compressor, .. } => compressor
798                .begin_online_policy_stream(total_symbols)
799                .map_err(|e| e.to_string()),
800            RateBackendPredictor::Mixture { runtime } => runtime.begin_stream(total_symbols),
801            RateBackendPredictor::Calibrated { base, .. } => base.begin_stream(total_symbols),
802        }
803    }
804
805    fn finish_stream(&mut self) -> Result<(), String> {
806        match self {
807            RateBackendPredictor::Rosa { .. }
808            | RateBackendPredictor::Match { .. }
809            | RateBackendPredictor::SparseMatch { .. }
810            | RateBackendPredictor::Ppmd { .. }
811            | RateBackendPredictor::Ctw { .. }
812            | RateBackendPredictor::FacCtw { .. }
813            | RateBackendPredictor::Zpaq { .. }
814            | RateBackendPredictor::Particle { .. } => Ok(()),
815            RateBackendPredictor::Sequitur { model, .. } => {
816                model.finish_stream();
817                Ok(())
818            }
819            #[cfg(feature = "backend-rwkv")]
820            RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
821                .finish_online_policy_stream()
822                .map_err(|e| e.to_string()),
823            #[cfg(feature = "backend-mamba")]
824            RateBackendPredictor::Mamba { .. } => Ok(()),
825            RateBackendPredictor::Mixture { runtime } => runtime.finish_stream(),
826            RateBackendPredictor::Calibrated { base, .. } => base.finish_stream(),
827        }
828    }
829
    // Natural-log probability of `symbol` under the current model state, WITHOUT
    // permanently advancing any backend: CTW variants apply the byte's bits and
    // then revert them; neural variants lazily prime a cached pdf on first use.
    fn log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            RateBackendPredictor::Rosa { model, min_prob } => {
                // Floor the raw probability so ln() can never produce -inf/NaN.
                let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
                p.ln()
            }
            RateBackendPredictor::Match { model, min_prob } => model.log_prob(symbol, *min_prob),
            RateBackendPredictor::SparseMatch { model, min_prob } => {
                model.log_prob(symbol, *min_prob)
            }
            RateBackendPredictor::Ppmd { model, min_prob } => model.log_prob(symbol, *min_prob),
            RateBackendPredictor::Sequitur { model, min_prob } => model.log_prob(symbol, *min_prob),
            RateBackendPredictor::Ctw { tree, min_prob } => {
                // Conditional log-prob of the byte = difference in block log-prob
                // after feeding its 8 bits MSB-first.
                let log_before = tree.get_log_block_probability();
                for bit_idx in 0..8 {
                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                    tree.update(bit, bit_idx);
                }
                let log_after = tree.get_log_block_probability();
                // Undo in reverse bit order so the tree returns to its prior state.
                for bit_idx in (0..8).rev() {
                    tree.revert(bit_idx);
                }
                let logp = log_after - log_before;
                // Clamp non-finite or vanishing results to the configured floor.
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                min_prob,
            } => {
                // Same update/revert trick as `Ctw`, but LSB-first over a reduced
                // `bits_per_symbol`-bit alphabet.
                let log_before = tree.get_log_block_probability();
                for i in 0..*bits_per_symbol {
                    let bit = ((symbol >> i) & 1) == 1;
                    tree.update(bit, i);
                }
                let log_after = tree.get_log_block_probability();
                for i in (0..*bits_per_symbol).rev() {
                    tree.revert(i);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                // Make sure the pdf buffer reflects the current context.
                ensure_rwkv_primed(compressor, primed);
                let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
                p.ln()
            }
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                // Lazily run one forward pass to populate the pdf cache on first use.
                if !*primed {
                    let bias = compressor.online_bias_snapshot();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    mambazip::Compressor::logits_to_pdf(
                        logits,
                        bias.as_deref(),
                        &mut compressor.pdf_buffer,
                    );
                    *primed = true;
                }
                let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
                p.ln()
            }
            RateBackendPredictor::Zpaq { model } => model.log_prob(symbol),
            RateBackendPredictor::Mixture { runtime } => runtime.peek_log_prob(symbol),
            RateBackendPredictor::Particle { runtime } => runtime.peek_log_prob(symbol),
            RateBackendPredictor::Calibrated {
                base,
                core,
                pdf,
                valid,
                min_prob,
            } => {
                // Refresh the calibrated pdf from the base predictor when stale.
                if !*valid {
                    let mut base_logps = [0.0; 256];
                    base.fill_log_probs(&mut base_logps);
                    let mut base_pdf = [0.0; 256];
                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
                        *dst = clamp_prob(lp.exp(), *min_prob);
                    }
                    core.apply_pdf(&base_pdf, pdf);
                    *valid = true;
                }
                pdf[symbol as usize].max(*min_prob).ln()
            }
        }
    }
938
    // Fill `out` with natural-log probabilities for all 256 byte values under the
    // current model state, without permanently advancing any backend.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        match self {
            RateBackendPredictor::Rosa { model, min_prob } => {
                // `out` is first used as a probability buffer, then converted to
                // clamped log-probs in place.
                model.fill_probs_for_last_bytes(out);
                for slot in out.iter_mut() {
                    *slot = clamp_prob(*slot, *min_prob).ln();
                }
            }
            RateBackendPredictor::Match { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::SparseMatch { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::Ppmd { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::Sequitur { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::Ctw { tree, min_prob } => {
                // MSB-first bit order over the full 8-bit alphabet.
                fill_fac_tree_log_probs(tree, 8, true, min_prob.ln(), out);
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                min_prob,
            } => {
                // LSB-first bit order over the reduced alphabet.
                fill_fac_tree_log_probs(tree, *bits_per_symbol, false, min_prob.ln(), out);
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                ensure_rwkv_primed(compressor, primed);
                // Only the byte range of the model's pdf buffer is copied out; the
                // buffer may hold more entries (e.g. special tokens).
                for (slot, &p_raw) in out
                    .iter_mut()
                    .take(256)
                    .zip(compressor.pdf_buffer.iter().take(256))
                {
                    let p = clamp_prob(p_raw, *min_prob);
                    *slot = p.ln();
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                // Lazily prime the pdf cache with one forward pass on first use.
                if !*primed {
                    let bias = compressor.online_bias_snapshot();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    mambazip::Compressor::logits_to_pdf(
                        logits,
                        bias.as_deref(),
                        &mut compressor.pdf_buffer,
                    );
                    *primed = true;
                }
                for (slot, &p_raw) in out
                    .iter_mut()
                    .take(256)
                    .zip(compressor.pdf_buffer.iter().take(256))
                {
                    let p = clamp_prob(p_raw, *min_prob);
                    *slot = p.ln();
                }
            }
            RateBackendPredictor::Zpaq { model } => {
                model.fill_log_probs(out);
            }
            RateBackendPredictor::Mixture { runtime } => {
                runtime.fill_log_probs(out);
            }
            RateBackendPredictor::Particle { runtime } => {
                runtime.fill_log_probs_cached(out);
            }
            RateBackendPredictor::Calibrated {
                base,
                core,
                pdf,
                valid,
                min_prob,
            } => {
                // Refresh the calibrated pdf from the base predictor when stale.
                if !*valid {
                    let mut base_logps = [0.0; 256];
                    base.fill_log_probs(&mut base_logps);
                    let mut base_pdf = [0.0; 256];
                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
                        *dst = clamp_prob(lp.exp(), *min_prob);
                    }
                    core.apply_pdf(&base_pdf, pdf);
                    *valid = true;
                }
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
        }
    }
1063
1064    fn update(&mut self, symbol: u8) {
1065        match self {
1066            RateBackendPredictor::Rosa { model, .. } => {
1067                model.train_byte(symbol);
1068            }
1069            RateBackendPredictor::Match { model, .. } => {
1070                model.update(symbol);
1071            }
1072            RateBackendPredictor::SparseMatch { model, .. } => {
1073                model.update(symbol);
1074            }
1075            RateBackendPredictor::Ppmd { model, .. } => {
1076                model.update(symbol);
1077            }
1078            RateBackendPredictor::Sequitur { model, .. } => {
1079                model.update(symbol);
1080            }
1081            RateBackendPredictor::Ctw { tree, .. } => {
1082                for bit_idx in 0..8 {
1083                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
1084                    tree.update(bit, bit_idx);
1085                }
1086            }
1087            RateBackendPredictor::FacCtw {
1088                tree,
1089                bits_per_symbol,
1090                ..
1091            } => {
1092                for i in 0..*bits_per_symbol {
1093                    let bit = ((symbol >> i) & 1) == 1;
1094                    tree.update(bit, i);
1095                }
1096            }
1097            #[cfg(feature = "backend-rwkv")]
1098            RateBackendPredictor::Rwkv7 {
1099                compressor, primed, ..
1100            } => {
1101                ensure_rwkv_primed(compressor, primed);
1102                compressor
1103                    .observe_symbol_from_current_pdf(symbol)
1104                    .unwrap_or_else(|e| panic!("rwkv online update failed: {e}"));
1105            }
1106            #[cfg(feature = "backend-mamba")]
1107            RateBackendPredictor::Mamba {
1108                compressor,
1109                primed,
1110                pdf_scratch,
1111                ..
1112            } => {
1113                if !*primed {
1114                    let bias = compressor.online_bias_snapshot();
1115                    let logits =
1116                        compressor
1117                            .model
1118                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
1119                    mambazip::Compressor::logits_to_pdf(
1120                        logits,
1121                        bias.as_deref(),
1122                        &mut compressor.pdf_buffer,
1123                    );
1124                    *primed = true;
1125                }
1126                if pdf_scratch.len() != compressor.pdf_buffer.len() {
1127                    pdf_scratch.resize(compressor.pdf_buffer.len(), 0.0);
1128                }
1129                pdf_scratch.copy_from_slice(&compressor.pdf_buffer);
1130                compressor
1131                    .online_update_from_pdf(symbol, pdf_scratch)
1132                    .unwrap_or_else(|e| panic!("mamba online update failed: {e}"));
1133                let bias = compressor.online_bias_snapshot();
1134                let logits = compressor.model.forward(
1135                    &mut compressor.scratch,
1136                    symbol as u32,
1137                    &mut compressor.state,
1138                );
1139                mambazip::Compressor::logits_to_pdf(
1140                    logits,
1141                    bias.as_deref(),
1142                    &mut compressor.pdf_buffer,
1143                );
1144            }
1145            RateBackendPredictor::Zpaq { model } => {
1146                model.update(symbol);
1147            }
1148            RateBackendPredictor::Mixture { runtime } => {
1149                let _ = runtime.step(symbol);
1150            }
1151            RateBackendPredictor::Particle { runtime } => {
1152                runtime.step(symbol);
1153            }
1154            RateBackendPredictor::Calibrated {
1155                base,
1156                core,
1157                pdf,
1158                valid,
1159                ..
1160            } => {
1161                if !*valid {
1162                    let mut base_logps = [0.0; 256];
1163                    base.fill_log_probs(&mut base_logps);
1164                    let mut base_pdf = [0.0; 256];
1165                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
1166                        *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
1167                    }
1168                    core.apply_pdf(&base_pdf, pdf);
1169                }
1170                core.update(symbol, pdf);
1171                base.update(symbol);
1172                *valid = false;
1173            }
1174        }
1175    }
1176
1177    fn log_prob_update(&mut self, symbol: u8) -> f64 {
1178        match self {
1179            RateBackendPredictor::Rosa { model, min_prob } => {
1180                let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
1181                model.train_byte(symbol);
1182                p.ln()
1183            }
1184            RateBackendPredictor::Ctw { tree, min_prob } => {
1185                ctw_log_prob_update_msb(tree, symbol, *min_prob)
1186            }
1187            RateBackendPredictor::FacCtw {
1188                tree,
1189                bits_per_symbol,
1190                min_prob,
1191            } => ctw_log_prob_update_lsb(tree, symbol, *bits_per_symbol, *min_prob),
1192            _ => {
1193                let logp = self.log_prob(symbol);
1194                self.update(symbol);
1195                logp
1196            }
1197        }
1198    }
1199
    // Finish any in-flight stream, then reset per-stream state while keeping the
    // learned model frozen (so a later pass predicts without further adaptation).
    // `total_symbols`, when known, lets backends pre-allocate.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.finish_stream()?;
        match self {
            RateBackendPredictor::Rosa { model, .. } => {
                // Reserve capacity for the upcoming stream when its length is known;
                // saturate rather than fail on absurd sizes.
                if let Some(total) = total_symbols {
                    let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
                    model.reserve_for_stream(reserve);
                }
                model.build_lm_full_bytes_no_finalize_endpos();
                model.reset_conditioning_cursor();
                Ok(())
            }
            RateBackendPredictor::Match { model, .. } => {
                model.reset_history();
                Ok(())
            }
            RateBackendPredictor::SparseMatch { model, .. } => {
                model.reset_history();
                Ok(())
            }
            RateBackendPredictor::Ppmd { model, .. } => {
                model.reset_history();
                Ok(())
            }
            RateBackendPredictor::Sequitur { model, .. } => {
                model.reset_frozen();
                Ok(())
            }
            RateBackendPredictor::Ctw { tree, .. } => {
                // Clears the observed history without touching learned statistics.
                tree.reset_history_only();
                Ok(())
            }
            RateBackendPredictor::FacCtw { tree, .. } => {
                tree.reset_history_only();
                Ok(())
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 {
                compressor, primed, ..
            } => {
                // Reset recurrent state and recompute the initial pdf.
                compressor.reset_and_prime();
                *primed = true;
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba {
                compressor, primed, ..
            } => {
                compressor.reset_and_prime();
                *primed = true;
                Ok(())
            }
            RateBackendPredictor::Zpaq { .. } => {
                Err("plugin entropy is not supported for zpaq rate backends in 1.1.1".to_string())
            }
            RateBackendPredictor::Mixture { runtime } => runtime.reset_frozen(total_symbols),
            RateBackendPredictor::Particle { runtime } => {
                runtime.reset_frozen_state();
                Ok(())
            }
            RateBackendPredictor::Calibrated {
                base,
                core,
                pdf,
                valid,
                ..
            } => {
                base.reset_frozen(total_symbols)?;
                core.reset_context();
                // Start from a uniform pdf; it is recomputed on first use.
                pdf.fill(1.0 / 256.0);
                *valid = false;
                Ok(())
            }
        }
    }
1275
1276    fn update_frozen(&mut self, symbol: u8) {
1277        match self {
1278            RateBackendPredictor::Rosa { model, .. } => {
1279                model.advance_conditioning_byte(symbol);
1280            }
1281            RateBackendPredictor::Match { model, .. } => {
1282                model.update_history_only(symbol);
1283            }
1284            RateBackendPredictor::SparseMatch { model, .. } => {
1285                model.update_history_only(symbol);
1286            }
1287            RateBackendPredictor::Ppmd { model, .. } => {
1288                model.update_history_only(symbol);
1289            }
1290            RateBackendPredictor::Sequitur { model, .. } => {
1291                model.update_frozen(symbol);
1292            }
1293            RateBackendPredictor::Ctw { tree, .. } => {
1294                let mut bits = [false; 8];
1295                for (bit_idx, slot) in bits.iter_mut().enumerate() {
1296                    *slot = ((symbol >> (7 - bit_idx)) & 1) == 1;
1297                }
1298                tree.update_history(&bits);
1299            }
1300            RateBackendPredictor::FacCtw {
1301                tree,
1302                bits_per_symbol,
1303                ..
1304            } => {
1305                let bits = (*bits_per_symbol).clamp(1, 8);
1306                let mut history_bits = [false; 8];
1307                for (idx, slot) in history_bits.iter_mut().enumerate().take(bits) {
1308                    *slot = ((symbol >> idx) & 1) == 1;
1309                }
1310                tree.update_history(&history_bits[..bits]);
1311            }
1312            #[cfg(feature = "backend-rwkv")]
1313            RateBackendPredictor::Rwkv7 {
1314                compressor, primed, ..
1315            } => {
1316                if !*primed {
1317                    compressor.reset_and_prime();
1318                    *primed = true;
1319                }
1320                compressor.forward_to_internal_pdf(symbol as u32);
1321            }
1322            #[cfg(feature = "backend-mamba")]
1323            RateBackendPredictor::Mamba {
1324                compressor, primed, ..
1325            } => {
1326                if !*primed {
1327                    compressor.reset_and_prime();
1328                    *primed = true;
1329                }
1330                let bias = compressor.online_bias_snapshot();
1331                let logits = compressor.model.forward(
1332                    &mut compressor.scratch,
1333                    symbol as u32,
1334                    &mut compressor.state,
1335                );
1336                mambazip::Compressor::logits_to_pdf(
1337                    logits,
1338                    bias.as_deref(),
1339                    &mut compressor.pdf_buffer,
1340                );
1341            }
1342            RateBackendPredictor::Zpaq { model } => {
1343                model.update(symbol);
1344            }
1345            RateBackendPredictor::Mixture { runtime } => {
1346                runtime.update_frozen(symbol);
1347            }
1348            RateBackendPredictor::Particle { runtime } => {
1349                runtime.update_frozen(symbol);
1350            }
1351            RateBackendPredictor::Calibrated {
1352                base,
1353                core,
1354                pdf,
1355                valid,
1356                ..
1357            } => {
1358                if !*valid {
1359                    let mut base_logps = [0.0; 256];
1360                    base.fill_log_probs(&mut base_logps);
1361                    let mut base_pdf = [0.0; 256];
1362                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
1363                        *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
1364                    }
1365                    core.apply_pdf(&base_pdf, pdf);
1366                    *valid = true;
1367                }
1368                base.update_frozen(symbol);
1369                core.update_context_only(symbol);
1370                *valid = false;
1371            }
1372        }
1373    }
1374}
1375
/// Configuration for a mixture expert.
///
/// Pairs a log-space prior weight with a factory that builds fresh predictor
/// instances, so each mixture stream gets independent expert state.
#[derive(Clone)]
pub struct ExpertConfig {
    /// Human-readable expert identifier.
    pub name: String,
    /// Log prior weight (natural log). Uniform priors can be `0.0`.
    pub log_prior: f64,
    // Factory producing a fresh predictor; `Arc` keeps the config cheaply
    // cloneable and the closure shareable across threads.
    builder: Arc<dyn Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync>,
}
1385
impl ExpertConfig {
    /// Create a new expert config from a builder closure.
    ///
    /// `log_prior` is an unnormalized natural-log prior weight; mixtures
    /// normalize priors at construction.
    pub fn new(
        name: impl Into<String>,
        log_prior: f64,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self {
            name: name.into(),
            log_prior,
            builder: Arc::new(builder),
        }
    }

    /// Uniform prior helper.
    ///
    /// Giving every expert `log_prior = 0.0` yields a uniform prior once the
    /// mixture normalizes weights.
    pub fn uniform(
        name: impl Into<String>,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self::new(name, 0.0, builder)
    }

    /// Expert from a `RateBackend` configuration. `max_order` applies to ROSA.
    pub fn from_rate_backend(
        name: Option<String>,
        log_prior: f64,
        backend: RateBackend,
        max_order: i64,
    ) -> Self {
        // Fall back to a backend-derived name when the caller supplies none.
        let name = name.unwrap_or_else(|| RateBackendPredictor::default_name(&backend, max_order));
        Self::new(name, log_prior, move || {
            Box::new(RateBackendPredictor::from_backend(
                backend.clone(),
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// ROSA expert (uniform prior).
    pub fn rosa(name: impl Into<String>, max_order: i64) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::RosaPlus,
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// CTW expert (uniform prior).
    pub fn ctw(name: impl Into<String>, depth: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Ctw { depth },
                // max_order is ROSA-specific; -1 marks it unused here.
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// FAC-CTW expert (uniform prior).
    pub fn fac_ctw(name: impl Into<String>, base_depth: usize, encoding_bits: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::FacCtw {
                    base_depth,
                    // NOTE(review): percept bits mirror `encoding_bits` here —
                    // confirm that is the intended configuration.
                    num_percept_bits: encoding_bits,
                    encoding_bits,
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// RWKV-7 expert (uniform prior).
    #[cfg(feature = "backend-rwkv")]
    pub fn rwkv(name: impl Into<String>, model: Arc<rwkvzip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Rwkv7 {
                    // Shared weights; each predictor keeps its own recurrent state.
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Mamba expert (uniform prior).
    #[cfg(feature = "backend-mamba")]
    pub fn mamba(name: impl Into<String>, model: Arc<mambazip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Mamba {
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// ZPAQ expert (uniform prior).
    pub fn zpaq(name: impl Into<String>, method: impl Into<String>) -> Self {
        let name = name.into();
        let method = method.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Zpaq {
                    method: method.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Expert name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Log prior weight (unnormalized).
    pub fn log_prior(&self) -> f64 {
        self.log_prior
    }

    /// Build a fresh predictor instance for evaluation or analysis.
    pub fn build_predictor(&self) -> Box<dyn OnlineBytePredictor> {
        (self.builder)()
    }

    // Instantiate live expert state; the weight starts at the (unnormalized)
    // prior and is normalized by the mixture constructor.
    fn build(&self) -> ExpertState {
        ExpertState {
            name: self.name.clone(),
            log_weight: self.log_prior,
            log_prior: self.log_prior,
            predictor: (self.builder)(),
            cum_log_loss: 0.0,
        }
    }
}
1535
// Live per-expert state inside a mixture: posterior weight, prior, predictor
// instance, and cumulative log-loss bookkeeping.
//
// NOTE(review): the derived `Clone` requires `Box<dyn OnlineBytePredictor>` to
// be clonable; the enabling impl is assumed to live elsewhere in the crate.
#[derive(Clone)]
struct ExpertState {
    // Identifier copied from the originating `ExpertConfig`.
    name: String,
    // Current log posterior weight (normalized at mixture construction,
    // maintained by subsequent updates).
    log_weight: f64,
    // Log prior as configured (unnormalized), kept for resets/diagnostics.
    log_prior: f64,
    // Owned predictor instance driving this expert's predictions.
    predictor: Box<dyn OnlineBytePredictor>,
    // Accumulated log-loss in nats (negated sum of per-symbol log-probs).
    cum_log_loss: f64,
}
1544
// Thin forwarding layer: every call delegates to the boxed predictor so the
// mixture implementations can treat experts uniformly.
impl ExpertState {
    // Begin a stream; `total_symbols`, when known, lets predictors pre-allocate.
    #[inline]
    fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.begin_stream(total_symbols)
    }

    // Flush/close any per-stream predictor state.
    #[inline]
    fn finish_stream(&mut self) -> Result<(), String> {
        self.predictor.finish_stream()
    }

    // Natural-log probability of `symbol` without advancing the predictor.
    #[inline]
    fn log_prob(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob(symbol)
    }

    // Fused score-then-update (fast path where the predictor supports it).
    #[inline]
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob_update(symbol)
    }

    // Advance the predictor on an observed symbol.
    #[inline]
    fn update(&mut self, symbol: u8) {
        self.predictor.update(symbol);
    }

    // Reset per-stream state while keeping learned statistics frozen.
    #[inline]
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.reset_frozen(total_symbols)
    }

    // Advance only conditioning/history state (no learning).
    #[inline]
    fn update_frozen(&mut self, symbol: u8) {
        self.predictor.update_frozen(symbol);
    }
}
1581
/// Exponential-weights Bayes mixture (log-loss Hedge).
#[derive(Clone)]
pub struct BayesMixture {
    // Experts with their posterior weights and predictor state.
    experts: Vec<ExpertState>,
    // Per-step scratch: each expert's log-prob of the current symbol.
    scratch_logps: Vec<f64>,
    // Per-step scratch: log_weight + log-prob terms fed to logsumexp.
    scratch_mix: Vec<f64>,
    // Symbol for which `cached_log_mix`/`scratch_logps` were computed.
    cached_symbol: u8,
    // Mixture log-prob cached by `predict_log_prob` for reuse in `step`.
    cached_log_mix: f64,
    // Whether the cached values above are current.
    cache_valid: bool,
    // Cumulative mixture log-loss in nats.
    total_log_loss: f64,
}
1593
1594impl BayesMixture {
1595    /// Construct a normalized Bayes mixture from expert configs.
1596    pub fn new(configs: &[ExpertConfig]) -> Self {
1597        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
1598        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
1599        let norm = logsumexp(&log_priors);
1600        for e in &mut experts {
1601            e.log_weight -= norm;
1602        }
1603        Self {
1604            experts,
1605            scratch_logps: vec![0.0; configs.len()],
1606            scratch_mix: vec![0.0; configs.len()],
1607            cached_symbol: 0,
1608            cached_log_mix: f64::NEG_INFINITY,
1609            cache_valid: false,
1610            total_log_loss: 0.0,
1611        }
1612    }
1613
1614    /// Log-probability (natural log) of the mixture for `symbol`, then update.
1615    pub fn step(&mut self, symbol: u8) -> f64 {
1616        if self.experts.is_empty() {
1617            return f64::NEG_INFINITY;
1618        }
1619        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
1620            for (i, expert) in self.experts.iter_mut().enumerate() {
1621                expert.cum_log_loss -= self.scratch_logps[i];
1622                expert.update(symbol);
1623            }
1624            self.cached_log_mix
1625        } else {
1626            for (i, expert) in self.experts.iter_mut().enumerate() {
1627                self.scratch_logps[i] = expert.log_prob_update(symbol);
1628                self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
1629                expert.cum_log_loss -= self.scratch_logps[i];
1630            }
1631            logsumexp(&self.scratch_mix)
1632        };
1633        for (i, expert) in self.experts.iter_mut().enumerate() {
1634            expert.log_weight = expert.log_weight + self.scratch_logps[i] - log_mix;
1635        }
1636        self.cache_valid = false;
1637        self.total_log_loss -= log_mix;
1638        log_mix
1639    }
1640
1641    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
1642        if self.experts.is_empty() {
1643            return f64::NEG_INFINITY;
1644        }
1645        for (i, expert) in self.experts.iter_mut().enumerate() {
1646            self.scratch_logps[i] = expert.log_prob(symbol);
1647            self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
1648        }
1649        let log_mix = logsumexp(&self.scratch_mix);
1650        self.cached_symbol = symbol;
1651        self.cached_log_mix = log_mix;
1652        self.cache_valid = true;
1653        log_mix
1654    }
1655
1656    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
1657        if self.experts.is_empty() {
1658            out.fill(f64::NEG_INFINITY);
1659            return;
1660        }
1661        out.fill(f64::NEG_INFINITY);
1662        let norm = logsumexp_weights(&self.experts);
1663        let mut row = [0.0f64; 256];
1664        for expert in &mut self.experts {
1665            expert.predictor.fill_log_probs(&mut row);
1666            let lw = expert.log_weight - norm;
1667            for b in 0..256 {
1668                out[b] = logsumexp2(out[b], lw + row[b]);
1669            }
1670        }
1671    }
1672
1673    /// Posterior weights (normalized) over experts.
1674    pub fn posterior(&self) -> Vec<f64> {
1675        let norm = logsumexp_weights(&self.experts);
1676        self.experts
1677            .iter()
1678            .map(|e| (e.log_weight - norm).exp())
1679            .collect()
1680    }
1681
1682    /// Index and log-loss (nats) of the current best expert.
1683    pub fn min_expert_log_loss(&self) -> (usize, f64) {
1684        let mut best_idx = 0usize;
1685        let mut best_loss = f64::INFINITY;
1686        for (i, e) in self.experts.iter().enumerate() {
1687            if e.cum_log_loss < best_loss {
1688                best_loss = e.cum_log_loss;
1689                best_idx = i;
1690            }
1691        }
1692        (best_idx, best_loss)
1693    }
1694
1695    /// Index and posterior mass of the most likely expert.
1696    pub fn max_posterior(&self) -> (usize, f64) {
1697        let norm = logsumexp_weights(&self.experts);
1698        let mut best_idx = 0usize;
1699        let mut best_p = 0.0;
1700        for (i, e) in self.experts.iter().enumerate() {
1701            let p = (e.log_weight - norm).exp();
1702            if p > best_p {
1703                best_p = p;
1704                best_idx = i;
1705            }
1706        }
1707        (best_idx, best_p)
1708    }
1709
    /// Total log-loss of the mixture so far (nats).
    ///
    /// Accumulated by `step` as the negated mixture log-probability of each
    /// observed symbol; cleared by `reset_frozen`.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }
1714
1715    /// Expert cumulative log-losses (nats) and names.
1716    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
1717        self.experts
1718            .iter()
1719            .map(|e| (e.name.clone(), e.cum_log_loss))
1720            .collect()
1721    }
1722
1723    /// Expert names in order.
1724    pub fn expert_names(&self) -> Vec<String> {
1725        self.experts.iter().map(|e| e.name.clone()).collect()
1726    }
1727
1728    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
1729        for expert in &mut self.experts {
1730            expert.reset_frozen(total_symbols)?;
1731        }
1732        self.cache_valid = false;
1733        self.total_log_loss = 0.0;
1734        Ok(())
1735    }
1736
1737    fn update_frozen(&mut self, symbol: u8) {
1738        for expert in &mut self.experts {
1739            expert.update_frozen(symbol);
1740        }
1741        self.cache_valid = false;
1742    }
1743}
1744
/// Exponential-weights Bayes mixture with exponential forgetting on weights.
///
/// This is a non-stationary control: weights are discounted each step by `decay`.
#[derive(Clone)]
pub struct FadingBayesMixture {
    // Expert predictors plus their log-weights and cumulative log-losses.
    experts: Vec<ExpertState>,
    // Per-step multiplier on expert log-weights, clamped to [0, 1] in `new`.
    decay: f64,
    // Scratch: per-expert log-probabilities for the current symbol.
    scratch_logps: Vec<f64>,
    // Scratch: decayed log-weights, later incremented by log-probs for logsumexp.
    scratch_mix: Vec<f64>,
    // Symbol the cached predictive values below were computed for.
    cached_symbol: u8,
    // Cached mixture log-probability of `cached_symbol` (evidence minus prior norm).
    cached_log_predictive: f64,
    // Cached logsumexp of decayed log-weights plus log-probs (pre-normalization).
    cached_log_evidence: f64,
    // True when the cached_* fields correspond to a pending `predict_log_prob`.
    cache_valid: bool,
    // Cumulative mixture log-loss in nats (negated sum of predictive log-probs).
    total_log_loss: f64,
}
1760
impl FadingBayesMixture {
    /// Construct a fading Bayes mixture with decay in `[0, 1]`.
    pub fn new(configs: &[ExpertConfig], decay: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        // Normalize the prior log-weights so they form a proper distribution.
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let decay = decay.clamp(0.0, 1.0);
        Self {
            experts,
            decay,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_predictive: f64::NEG_INFINITY,
            cached_log_evidence: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the fading mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let (log_predictive, log_evidence) = if self.cache_valid && self.cached_symbol == symbol {
            // Fast path: `predict_log_prob` already scored this symbol, and
            // `scratch_logps` still holds the per-expert log-probs it computed.
            // Only charge losses and advance each expert.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            (self.cached_log_predictive, self.cached_log_evidence)
        } else {
            // Score the symbol under each expert and stage the decayed
            // log-weights in `scratch_mix`.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                self.scratch_mix[i] = self.decay * expert.log_weight;
            }
            // Decayed weights no longer sum to 1; this norm makes the
            // predictive below a proper probability.
            let log_prior_norm = logsumexp(&self.scratch_mix);
            for (i, expert) in self.experts.iter_mut().enumerate() {
                // scratch_mix now holds the joint: decayed log-weight + log-prob.
                self.scratch_mix[i] += self.scratch_logps[i];
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            let log_evidence = logsumexp(&self.scratch_mix);
            (log_evidence - log_prior_norm, log_evidence)
        };
        // Posterior update: decayed prior times likelihood, normalized by the
        // (unnormalized) evidence so the new log-weights sum to 0 in log-space.
        for (i, expert) in self.experts.iter_mut().enumerate() {
            let decayed = self.decay * expert.log_weight;
            expert.log_weight = decayed + self.scratch_logps[i] - log_evidence;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_predictive;
        log_predictive
    }

    /// Score `symbol` without advancing expert state, memoizing the result so a
    /// following `step` on the same symbol can reuse it.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = self.decay * expert.log_weight;
        }
        let log_prior_norm = logsumexp(&self.scratch_mix);
        for i in 0..self.experts.len() {
            self.scratch_mix[i] += self.scratch_logps[i];
        }
        let log_evidence = logsumexp(&self.scratch_mix);
        let log_predictive = log_evidence - log_prior_norm;
        self.cached_symbol = symbol;
        self.cached_log_predictive = log_predictive;
        self.cached_log_evidence = log_evidence;
        self.cache_valid = true;
        log_predictive
    }

    /// Full 256-symbol mixture log-distribution under the *decayed* weights,
    /// i.e. the same weighting `step` would apply to the next symbol.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let mut decayed = Vec::with_capacity(self.experts.len());
        for expert in &self.experts {
            decayed.push(self.decay * expert.log_weight);
        }
        let norm = logsumexp(&decayed);
        let mut row = [0.0f64; 256];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let lw = decayed[i] - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert (non-discounted loss).
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Reset all experts to their frozen snapshots and clear bookkeeping.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Advance expert contexts on `symbol` without adapting statistics.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1907
/// Switching mixture: allows occasional switches between experts.
#[derive(Clone)]
pub struct SwitchingMixture {
    // Expert predictors plus their log-weights and cumulative log-losses.
    experts: Vec<ExpertState>,
    // Normalized prior weights; also the switch-target distribution in `step`.
    prior: Vec<f64>,
    // Base switching rate; modulated per step by `schedule`.
    alpha: f64,
    // Controls how the effective alpha evolves with `update_count`.
    schedule: MixtureScheduleMode,
    // Scratch: per-expert log-probs of the current symbol.
    scratch_logps: Vec<f64>,
    // Scratch: log-weight + log-prob per expert; later reused for new weights.
    scratch_joint: Vec<f64>,
    // Scratch: posterior weights in linear (probability) domain.
    scratch_weights: Vec<f64>,
    // Symbol the memoized prediction below refers to.
    cached_symbol: u8,
    // Memoized mixture log-probability of `cached_symbol`.
    cached_log_mix: f64,
    // True when the cached values match a pending `predict_log_prob` call.
    cache_valid: bool,
    // Cumulative mixture log-loss in nats.
    total_log_loss: f64,
    // Number of `step` updates performed; drives the alpha schedule.
    update_count: u64,
}
1924
impl SwitchingMixture {
    /// Construct a switching mixture.
    pub fn new(configs: &[ExpertConfig], alpha: f64, schedule: MixtureScheduleMode) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let prior = normalized_prior_weights(configs);
        // Initial log-weights equal the normalized prior.
        set_log_weights_from_linear(&mut experts, &prior);
        Self {
            experts,
            prior,
            alpha,
            schedule,
            scratch_logps: vec![0.0; configs.len()],
            scratch_joint: vec![0.0; configs.len()],
            scratch_weights: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
            update_count: 0,
        }
    }

    /// Log-probability (natural log) of the switching mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Fast path: `predict_log_prob` already populated `scratch_logps`
            // and `scratch_joint` for this symbol; charge losses and advance.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_logps[i];
                // Joint: log-weight + log-prob.
                self.scratch_joint[i] = expert.log_weight + self.scratch_logps[i];
            }
            logsumexp(&self.scratch_joint)
        };

        // Bayes posterior in the linear domain: w_i = exp(joint_i - log_mix).
        for i in 0..self.experts.len() {
            self.scratch_weights[i] = (self.scratch_joint[i] - log_mix).exp();
        }

        let alpha = switching_alpha_for_update(self.schedule, self.alpha, self.update_count);
        self.update_count = self.update_count.saturating_add(1);

        // With a single expert or no switching mass, the new weights are just
        // the posterior.
        if self.experts.len() == 1 || alpha <= 0.0 {
            set_log_weights_from_linear(&mut self.experts, &self.scratch_weights);
        } else {
            let mut switch_out_sum = 0.0;
            // Count experts whose prior leaves room to switch (prior < 1); with
            // fewer than two such targets the transition is degenerate.
            let mut num_switch_targets = 0usize;
            for &prior in &self.prior {
                if prior < 1.0 {
                    num_switch_targets += 1;
                }
            }

            if num_switch_targets <= 1 {
                set_log_weights_from_linear(&mut self.experts, &self.scratch_weights);
            } else {
                // switch_out_sum = sum_j w_j / (1 - prior_j); subtracting
                // expert i's own term below leaves the mass available to
                // switch *into* i (no self-transitions).
                for i in 0..self.experts.len() {
                    let denom = 1.0 - self.prior[i];
                    if denom > 0.0 {
                        switch_out_sum += self.scratch_weights[i] / denom;
                    }
                }

                // Fixed-share-style transition: each expert keeps (1 - alpha)
                // of its posterior mass and receives alpha * prior_i of the
                // mass switched out by the others.
                for i in 0..self.experts.len() {
                    let stay = (1.0 - alpha) * self.scratch_weights[i];
                    let switch_in = if self.prior[i] > 0.0 {
                        let denom = 1.0 - self.prior[i];
                        let switchable_mass = if denom > 0.0 {
                            switch_out_sum - self.scratch_weights[i] / denom
                        } else {
                            0.0
                        };
                        alpha * self.prior[i] * switchable_mass
                    } else {
                        0.0
                    };
                    self.scratch_joint[i] = stay + switch_in;
                }
                // Renormalize for numerical safety (mass is preserved exactly
                // in exact arithmetic).
                normalize_simplex_weights(&mut self.scratch_joint);
                set_log_weights_from_linear(&mut self.experts, &self.scratch_joint);
            }
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Score `symbol` without advancing expert state, memoizing the per-expert
    /// log-probs and joints so a following `step` can reuse them.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for i in 0..self.experts.len() {
            let lp = self.experts[i].log_prob(symbol);
            self.scratch_logps[i] = lp;
            self.scratch_joint[i] = self.experts[i].log_weight + lp;
        }
        let log_mix = logsumexp(&self.scratch_joint);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Full 256-symbol mixture log-distribution under the current weights.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let norm = logsumexp_weights(&self.experts);
        let mut row = [0.0f64; 256];
        for expert in &mut self.experts {
            expert.predictor.fill_log_probs(&mut row);
            let lw = expert.log_weight - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and posterior mass of the most likely expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert cumulative log-losses (nats) and names.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Reset all experts to their frozen snapshots and clear bookkeeping,
    /// including the alpha-schedule step counter.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        self.update_count = 0;
        Ok(())
    }

    /// Advance expert contexts on `symbol` without adapting statistics.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
2124
/// Convex mixture with projected-simplex online updates.
#[derive(Clone)]
pub struct ConvexMixture {
    // Expert predictors plus their bookkeeping.
    experts: Vec<ExpertState>,
    // Base step size for the online gradient update.
    alpha: f64,
    // Controls how the effective step size evolves with `update_count`.
    schedule: MixtureScheduleMode,
    // Mixing weights on the probability simplex, one per expert.
    lambda: Vec<f64>,
    // Scratch: per-expert log-probs of the current symbol.
    scratch_logps: Vec<f64>,
    // Reusable buffer for the simplex projection.
    projection_scratch: Vec<f64>,
    // Symbol the memoized prediction below refers to.
    cached_symbol: u8,
    // Memoized mixture log-probability of `cached_symbol`.
    cached_log_mix: f64,
    // True when the cached values match a pending `predict_log_prob` call.
    cache_valid: bool,
    // Cumulative mixture log-loss in nats.
    total_log_loss: f64,
    // Number of `step` updates performed; drives the step-size schedule.
    update_count: u64,
}
2140
impl ConvexMixture {
    /// Construct a convex mixture with prior-derived initial weights.
    pub fn new(configs: &[ExpertConfig], alpha: f64, schedule: MixtureScheduleMode) -> Self {
        Self {
            experts: configs.iter().map(|c| c.build()).collect(),
            alpha,
            schedule,
            lambda: normalized_prior_weights(configs),
            scratch_logps: vec![0.0; configs.len()],
            projection_scratch: Vec::with_capacity(configs.len()),
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
            update_count: 0,
        }
    }

    /// Linear (probability-domain) mixture of expert probabilities under
    /// `lambda`, floored at `DEFAULT_MIN_PROB` so the log stays finite.
    fn mix_log_prob(&self, logps: &[f64]) -> f64 {
        let mut mix = 0.0;
        for (weight, &logp) in self.lambda.iter().zip(logps.iter()) {
            if *weight > 0.0 {
                mix += *weight * logp.exp();
            }
        }
        clamp_prob(mix, DEFAULT_MIN_PROB).ln()
    }

    /// Log-probability (natural log) of the convex mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }

        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Fast path: `predict_log_prob` already populated `scratch_logps`.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            self.mix_log_prob(&self.scratch_logps)
        };

        self.update_count = self.update_count.saturating_add(1);
        let step_size = convex_step_size_for_update(self.schedule, self.alpha, self.update_count);
        // Gradient of the instantaneous log-loss -ln(p_mix) w.r.t. lambda_i is
        // -p_i / p_mix; take a descent step, then project back onto the simplex.
        for (weight, &logp) in self.lambda.iter_mut().zip(self.scratch_logps.iter()) {
            let grad = -(logp - log_mix).exp();
            *weight -= step_size * grad;
        }
        project_simplex_with_scratch(&mut self.lambda, &mut self.projection_scratch);
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Score `symbol` without advancing expert state, memoizing the result.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
        }
        let log_mix = self.mix_log_prob(&self.scratch_logps);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Full 256-symbol mixture log-distribution under `lambda`.
    ///
    /// Note: each expert's row is refreshed before the zero-weight skip, so
    /// predictor-internal caching (if any) still runs for skipped experts.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let mut row = [0.0f64; 256];
        for (index, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let weight = self.lambda.get(index).copied().unwrap_or(0.0);
            if weight <= 0.0 {
                continue;
            }
            let log_weight = weight.ln();
            for byte in 0..256 {
                out[byte] = logsumexp2(out[byte], log_weight + row[byte]);
            }
        }
    }

    /// Reset all experts to their frozen snapshots and clear bookkeeping.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        self.update_count = 0;
        Ok(())
    }

    /// Advance expert contexts on `symbol` without adapting statistics.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
2252
/// MDL-style selector: predicts with the current best expert (by cumulative loss).
#[derive(Clone)]
pub struct MdlSelector {
    // Candidate experts; the one with minimum cumulative log-loss predicts.
    experts: Vec<ExpertState>,
    // Scratch: per-expert log-probabilities for the current symbol.
    scratch_logps: Vec<f64>,
    // Cumulative log-loss (nats) of the selector's own predictions.
    total_log_loss: f64,
    // Index of the most recently selected expert (maintained by `step`).
    last_best: usize,
    // Symbol the memoized best-expert prediction below refers to.
    cached_symbol: u8,
    // Index of the expert that was best when the cache was filled.
    cached_best_idx: usize,
    // That expert's memoized log-probability for `cached_symbol`.
    cached_best_logp: f64,
    // True when the cached_* fields are current.
    cache_valid: bool,
}
2265
/// Bytewise neural mixer inspired by fx2-cmix online adaptation.
///
/// This model is a context-conditioned two-stage gating network trained online
/// from per-symbol expert likelihoods:
/// 1) context-local first-stage expert gates,
/// 2) context-local second-stage meta-gate over stage-1 outputs,
/// 3) per-symbol SGD updates with optional tiny-error skip.
#[derive(Clone)]
pub struct NeuralMixture {
    // Expert predictors plus their bookkeeping (weights, cumulative losses).
    experts: Vec<ExpertState>,
    // Online gating core that maps expert log-probs to a mixed probability.
    neural: NeuralMixCore,
    // Text-context tracker whose state conditions the gating core.
    analyzer: TextContextAnalyzer,
    // Probability floor used to keep logs finite.
    min_prob: f64,
    // Scratch: per-expert log-probabilities for the current symbol.
    scratch_expert_logps: Vec<f64>,
    // Scratch: per-expert mixing weights produced by the gating core.
    scratch_mix_weights: Vec<f64>,
    // True when the single-symbol evaluation cache below is current.
    eval_cache_valid: bool,
    // True when the full 256-symbol evaluation cache below is current.
    eval_cache_full_valid: bool,
    // History state the caches were computed under; mismatch invalidates them.
    eval_cache_history: NeuralHistoryState,
    // Symbol the single-symbol cache refers to.
    eval_cache_symbol: u8,
    // Cached mixture log-probability of `eval_cache_symbol`.
    eval_cache_logp: f64,
    // Cached mixture log-distribution over all 256 symbols.
    eval_cache_mix_logps: [f64; 256],
    // Cached per-expert log-distributions over all 256 symbols.
    eval_cache_expert_logps: Vec<[f64; 256]>,
    // Cumulative mixture log-loss in nats.
    total_log_loss: f64,
}
2290
impl NeuralMixture {
    /// Construct a neural mixture. `learning_rate` is taken from `MixtureSpec.alpha`.
    pub fn new(configs: &[ExpertConfig], learning_rate: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let n = experts.len();

        // Normalized prior distribution used to seed the gating core.
        let mut prior_weights = vec![0.0; n];
        if n > 0 {
            let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
            let norm = logsumexp(&log_priors);
            for (i, e) in experts.iter_mut().enumerate() {
                let p = (e.log_prior - norm).exp();
                prior_weights[i] = p;
            }
        }

        // Sanitize the requested rate (fall back to 0.03 for non-finite input),
        // then boost it for the gating SGD. The core is handed two rates, the
        // first at half the second — presumably stage-1 vs stage-2 gates.
        let base_lr = if learning_rate.is_finite() {
            learning_rate.abs().clamp(1e-6, 1.0)
        } else {
            0.03
        };
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        let mut neural =
            NeuralMixCore::new(n, &prior_weights, effective_lr * 0.5, effective_lr, 1e-5);
        neural.set_context_state(analyzer.state());
        let eval_cache_history = neural.history_state();

        Self {
            experts,
            neural,
            analyzer,
            min_prob: DEFAULT_MIN_PROB,
            scratch_expert_logps: vec![0.0; n],
            scratch_mix_weights: vec![0.0; n],
            eval_cache_valid: false,
            eval_cache_full_valid: false,
            eval_cache_history,
            eval_cache_symbol: 0,
            eval_cache_logp: f64::NEG_INFINITY,
            eval_cache_mix_logps: [f64::NEG_INFINITY; 256],
            eval_cache_expert_logps: vec![[f64::NEG_INFINITY; 256]; n],
            total_log_loss: 0.0,
        }
    }

    /// Drop both the single-symbol and the full-distribution caches.
    #[inline]
    fn invalidate_eval_cache(&mut self) {
        self.eval_cache_valid = false;
        self.eval_cache_full_valid = false;
    }

    /// Propagate the analyzer's context state into the gating core and
    /// invalidate the caches if the history changed since they were filled.
    fn sync_history_state(&mut self) -> NeuralHistoryState {
        let history = self.analyzer.state();
        if self.neural.history_state() != history {
            self.neural.set_context_state(history);
        }
        if self.eval_cache_history != history {
            self.invalidate_eval_cache();
            self.eval_cache_history = history;
        }
        history
    }

    /// Populate the full 256-symbol caches (per-expert rows and the mixed,
    /// renormalized log-distribution) if they are not already current.
    fn ensure_full_evaluation(&mut self) {
        self.sync_history_state();
        if self.eval_cache_full_valid {
            return;
        }

        // Ask the gating core for the current per-expert mixing weights.
        self.neural.evaluate_expert_weights();
        self.scratch_mix_weights
            .copy_from_slice(self.neural.expert_weights());
        // Mix in the probability domain, flooring each expert probability.
        let mut mix_pdf = [0.0f64; 256];
        for i in 0..self.experts.len() {
            let row = &mut self.eval_cache_expert_logps[i];
            self.experts[i].predictor.fill_log_probs(row);
            let w = self.scratch_mix_weights[i];
            for (dst, &lp) in mix_pdf.iter_mut().zip(row.iter()) {
                *dst += w * clamp_prob(lp.exp(), self.min_prob);
            }
        }

        // Renormalize; fall back to uniform on degenerate mass.
        let sum: f64 = mix_pdf.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let uniform = (1.0f64 / 256.0).ln();
            self.eval_cache_mix_logps.fill(uniform);
        } else {
            let inv = 1.0 / sum;
            for (dst, &p_raw) in self.eval_cache_mix_logps.iter_mut().zip(mix_pdf.iter()) {
                let p = clamp_unit_prob(p_raw * inv, self.min_prob);
                *dst = p.ln();
            }
        }

        self.eval_cache_full_valid = true;
    }

    /// Mixture log-probability of `symbol`, served from (and refreshing) the
    /// evaluation caches. Three paths: single-symbol cache hit, derivation
    /// from the full-distribution cache, or a fresh gating-core evaluation.
    fn evaluate_symbol(&mut self, symbol: u8) -> f64 {
        let history = self.sync_history_state();
        if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            return self.eval_cache_logp;
        }

        if self.eval_cache_full_valid && self.eval_cache_history == history {
            // Pull the per-expert log-probs for this symbol out of the cached
            // rows so a following `step` can reuse them.
            for (dst, row) in self
                .scratch_expert_logps
                .iter_mut()
                .zip(self.eval_cache_expert_logps.iter())
            {
                *dst = row[symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            self.eval_cache_valid = true;
            self.eval_cache_symbol = symbol;
            self.eval_cache_logp = logp;
            return logp;
        }

        let expert_count = self.experts.len();
        for i in 0..expert_count {
            self.scratch_expert_logps[i] = self.experts[i].log_prob(symbol);
        }
        let p = self
            .neural
            .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
        let logp = clamp_unit_prob(p, self.min_prob).ln();
        self.eval_cache_valid = true;
        self.eval_cache_history = history;
        self.eval_cache_symbol = symbol;
        self.eval_cache_logp = logp;
        logp
    }

    /// Score `symbol` without advancing state; a single expert bypasses the
    /// gating network entirely.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        if self.experts.len() == 1 {
            return self.experts[0].log_prob(symbol);
        }
        self.evaluate_symbol(symbol)
    }

    /// Full 256-symbol mixed log-distribution (via the full-evaluation cache).
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        if self.experts.len() == 1 {
            self.experts[0].predictor.fill_log_probs(out);
            return;
        }
        self.ensure_full_evaluation();
        out.copy_from_slice(&self.eval_cache_mix_logps);
    }

    /// Log-probability (natural log) of the neural mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }

        // Single expert: no gating to train; just advance the expert and the
        // context state.
        if self.experts.len() == 1 {
            let expert = &mut self.experts[0];
            let logp = expert.log_prob_update(symbol);
            expert.cum_log_loss -= logp;
            self.total_log_loss -= logp;
            self.analyzer.update(symbol);
            self.neural.set_context_state(self.analyzer.state());
            self.invalidate_eval_cache();
            return logp;
        }

        let history = self.sync_history_state();
        let logp = if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            // Single-symbol cache hit: `scratch_expert_logps` is still current.
            let logp = self.eval_cache_logp;
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else if self.eval_cache_full_valid && self.eval_cache_history == history {
            // Full-distribution cache hit: read this symbol's column.
            for i in 0..self.experts.len() {
                self.scratch_expert_logps[i] = self.eval_cache_expert_logps[i][symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else {
            // Cold path: query each expert and the gating core directly.
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                self.scratch_expert_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_expert_logps[i];
            }
            let p = self
                .neural
                .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
            clamp_unit_prob(p, self.min_prob).ln()
        };
        // Train the gating network on the observed symbol's expert likelihoods.
        self.neural
            .update_weights_symbol(&self.scratch_expert_logps, self.min_prob);
        self.total_log_loss -= logp;
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        logp
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Reset experts to frozen snapshots and re-seed the analyzer/gating state.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.analyzer = TextContextAnalyzer::new();
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        self.eval_cache_history = self.neural.history_state();
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Advance expert contexts and history state on `symbol` without adapting
    /// model statistics.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        self.eval_cache_history = self.neural.history_state();
    }
}
2538
2539impl MdlSelector {
2540    /// Construct an MDL-style expert selector.
2541    pub fn new(configs: &[ExpertConfig]) -> Self {
2542        let experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
2543        let last_best = 0usize;
2544        Self {
2545            experts,
2546            scratch_logps: vec![0.0; configs.len()],
2547            total_log_loss: 0.0,
2548            last_best,
2549            cached_symbol: 0,
2550            cached_best_idx: 0,
2551            cached_best_logp: f64::NEG_INFINITY,
2552            cache_valid: false,
2553        }
2554    }
2555
2556    /// Log-probability (natural log) of the MDL selector for `symbol`, then update.
2557    pub fn step(&mut self, symbol: u8) -> f64 {
2558        if self.experts.is_empty() {
2559            return f64::NEG_INFINITY;
2560        }
2561        let used_cache = self.cache_valid && self.cached_symbol == symbol;
2562        let best_idx = if used_cache {
2563            self.scratch_logps[self.cached_best_idx] = self.cached_best_logp;
2564            for (i, expert) in self.experts.iter_mut().enumerate() {
2565                if i == self.cached_best_idx {
2566                    continue;
2567                }
2568                self.scratch_logps[i] = expert.log_prob(symbol);
2569            }
2570            self.cached_best_idx
2571        } else {
2572            for (i, expert) in self.experts.iter_mut().enumerate() {
2573                self.scratch_logps[i] = expert.log_prob_update(symbol);
2574            }
2575            let mut best_idx = 0usize;
2576            let mut best_loss = f64::INFINITY;
2577            for (i, expert) in self.experts.iter().enumerate() {
2578                if expert.cum_log_loss < best_loss {
2579                    best_loss = expert.cum_log_loss;
2580                    best_idx = i;
2581                }
2582            }
2583            best_idx
2584        };
2585        let logp = self.scratch_logps[best_idx];
2586        self.cache_valid = false;
2587        for (i, expert) in self.experts.iter_mut().enumerate() {
2588            expert.cum_log_loss -= self.scratch_logps[i];
2589            if used_cache {
2590                expert.update(symbol);
2591            }
2592        }
2593        self.total_log_loss -= logp;
2594        self.last_best = best_idx;
2595        logp
2596    }
2597
2598    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
2599        if self.experts.is_empty() {
2600            return f64::NEG_INFINITY;
2601        }
2602        let mut best_idx = 0usize;
2603        let mut best_loss = f64::INFINITY;
2604        for (i, expert) in self.experts.iter().enumerate() {
2605            if expert.cum_log_loss < best_loss {
2606                best_loss = expert.cum_log_loss;
2607                best_idx = i;
2608            }
2609        }
2610        let logp = self.experts[best_idx].log_prob(symbol);
2611        self.cached_symbol = symbol;
2612        self.cached_best_idx = best_idx;
2613        self.cached_best_logp = logp;
2614        self.cache_valid = true;
2615        logp
2616    }
2617
2618    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
2619        if self.experts.is_empty() {
2620            out.fill(f64::NEG_INFINITY);
2621            return;
2622        }
2623        let mut best_idx = 0usize;
2624        let mut best_loss = f64::INFINITY;
2625        for (i, expert) in self.experts.iter().enumerate() {
2626            if expert.cum_log_loss < best_loss {
2627                best_loss = expert.cum_log_loss;
2628                best_idx = i;
2629            }
2630        }
2631        self.experts[best_idx].predictor.fill_log_probs(out);
2632    }
2633
2634    /// Index of the current best expert.
2635    pub fn best_index(&self) -> usize {
2636        self.last_best
2637    }
2638
2639    /// Index and log-loss (nats) of the current best expert.
2640    pub fn min_expert_log_loss(&self) -> (usize, f64) {
2641        let mut best_idx = 0usize;
2642        let mut best_loss = f64::INFINITY;
2643        for (i, e) in self.experts.iter().enumerate() {
2644            if e.cum_log_loss < best_loss {
2645                best_loss = e.cum_log_loss;
2646                best_idx = i;
2647            }
2648        }
2649        (best_idx, best_loss)
2650    }
2651
2652    /// Total log-loss of the selector so far (nats).
2653    pub fn total_log_loss(&self) -> f64 {
2654        self.total_log_loss
2655    }
2656
2657    /// Expert cumulative log-losses (nats) and names.
2658    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
2659        self.experts
2660            .iter()
2661            .map(|e| (e.name.clone(), e.cum_log_loss))
2662            .collect()
2663    }
2664
2665    /// Expert names in order.
2666    pub fn expert_names(&self) -> Vec<String> {
2667        self.experts.iter().map(|e| e.name.clone()).collect()
2668    }
2669
2670    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2671        for expert in &mut self.experts {
2672            expert.reset_frozen(total_symbols)?;
2673        }
2674        self.cache_valid = false;
2675        self.total_log_loss = 0.0;
2676        Ok(())
2677    }
2678
2679    fn update_frozen(&mut self, symbol: u8) {
2680        for expert in &mut self.experts {
2681            expert.update_frozen(symbol);
2682        }
2683        self.cache_valid = false;
2684    }
2685}
2686
2687// =============================================================================
2688// Mixture Runtime Helper (for RateBackend::Mixture)
2689// =============================================================================
2690
2691/// Runtime wrapper over concrete mixture strategies.
2692#[allow(clippy::large_enum_variant)]
2693#[derive(Clone)]
2694pub enum MixtureRuntime {
2695    /// Bayes mixture.
2696    Bayes(BayesMixture),
2697    /// Fading Bayes mixture.
2698    Fading(FadingBayesMixture),
2699    /// Switching mixture.
2700    Switching(SwitchingMixture),
2701    /// Convex mixture.
2702    Convex(ConvexMixture),
2703    /// MDL selector.
2704    Mdl(MdlSelector),
2705    /// Bytewise neural logistic mixer.
2706    Neural(NeuralMixture),
2707}
2708
2709impl MixtureRuntime {
2710    pub(crate) fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2711        match self {
2712            MixtureRuntime::Bayes(m) => begin_expert_stream(&mut m.experts, total_symbols),
2713            MixtureRuntime::Fading(m) => begin_expert_stream(&mut m.experts, total_symbols),
2714            MixtureRuntime::Switching(m) => begin_expert_stream(&mut m.experts, total_symbols),
2715            MixtureRuntime::Convex(m) => begin_expert_stream(&mut m.experts, total_symbols),
2716            MixtureRuntime::Mdl(m) => begin_expert_stream(&mut m.experts, total_symbols),
2717            MixtureRuntime::Neural(m) => begin_expert_stream(&mut m.experts, total_symbols),
2718        }
2719    }
2720
2721    pub(crate) fn finish_stream(&mut self) -> Result<(), String> {
2722        match self {
2723            MixtureRuntime::Bayes(m) => finish_expert_stream(&mut m.experts),
2724            MixtureRuntime::Fading(m) => finish_expert_stream(&mut m.experts),
2725            MixtureRuntime::Switching(m) => finish_expert_stream(&mut m.experts),
2726            MixtureRuntime::Convex(m) => finish_expert_stream(&mut m.experts),
2727            MixtureRuntime::Mdl(m) => finish_expert_stream(&mut m.experts),
2728            MixtureRuntime::Neural(m) => finish_expert_stream(&mut m.experts),
2729        }
2730    }
2731
2732    pub(crate) fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2733        match self {
2734            MixtureRuntime::Bayes(m) => m.reset_frozen(total_symbols),
2735            MixtureRuntime::Fading(m) => m.reset_frozen(total_symbols),
2736            MixtureRuntime::Switching(m) => m.reset_frozen(total_symbols),
2737            MixtureRuntime::Convex(m) => m.reset_frozen(total_symbols),
2738            MixtureRuntime::Mdl(m) => m.reset_frozen(total_symbols),
2739            MixtureRuntime::Neural(m) => m.reset_frozen(total_symbols),
2740        }
2741    }
2742
2743    /// Non-mutating log-probability (nats) for `symbol` at current state.
2744    pub(crate) fn peek_log_prob(&mut self, symbol: u8) -> f64 {
2745        match self {
2746            MixtureRuntime::Bayes(m) => m.predict_log_prob(symbol),
2747            MixtureRuntime::Fading(m) => m.predict_log_prob(symbol),
2748            MixtureRuntime::Switching(m) => m.predict_log_prob(symbol),
2749            MixtureRuntime::Convex(m) => m.predict_log_prob(symbol),
2750            MixtureRuntime::Mdl(m) => m.predict_log_prob(symbol),
2751            MixtureRuntime::Neural(m) => m.predict_log_prob(symbol),
2752        }
2753    }
2754
2755    /// Step the mixture and return log-probability (nats).
2756    pub(crate) fn step(&mut self, symbol: u8) -> f64 {
2757        match self {
2758            MixtureRuntime::Bayes(m) => m.step(symbol),
2759            MixtureRuntime::Fading(m) => m.step(symbol),
2760            MixtureRuntime::Switching(m) => m.step(symbol),
2761            MixtureRuntime::Convex(m) => m.step(symbol),
2762            MixtureRuntime::Mdl(m) => m.step(symbol),
2763            MixtureRuntime::Neural(m) => m.step(symbol),
2764        }
2765    }
2766
2767    pub(crate) fn update_frozen(&mut self, symbol: u8) {
2768        match self {
2769            MixtureRuntime::Bayes(m) => m.update_frozen(symbol),
2770            MixtureRuntime::Fading(m) => m.update_frozen(symbol),
2771            MixtureRuntime::Switching(m) => m.update_frozen(symbol),
2772            MixtureRuntime::Convex(m) => m.update_frozen(symbol),
2773            MixtureRuntime::Mdl(m) => m.update_frozen(symbol),
2774            MixtureRuntime::Neural(m) => m.update_frozen(symbol),
2775        }
2776    }
2777
2778    pub(crate) fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
2779        match self {
2780            MixtureRuntime::Bayes(m) => m.fill_log_probs(out),
2781            MixtureRuntime::Fading(m) => m.fill_log_probs(out),
2782            MixtureRuntime::Switching(m) => m.fill_log_probs(out),
2783            MixtureRuntime::Convex(m) => m.fill_log_probs(out),
2784            MixtureRuntime::Mdl(m) => m.fill_log_probs(out),
2785            MixtureRuntime::Neural(m) => m.fill_log_probs(out),
2786        }
2787    }
2788}
2789
2790fn begin_expert_stream(
2791    experts: &mut [ExpertState],
2792    total_symbols: Option<u64>,
2793) -> Result<(), String> {
2794    for expert in experts {
2795        expert.begin_stream(total_symbols)?;
2796    }
2797    Ok(())
2798}
2799
2800fn finish_expert_stream(experts: &mut [ExpertState]) -> Result<(), String> {
2801    for expert in experts {
2802        expert.finish_stream()?;
2803    }
2804    Ok(())
2805}
2806
2807pub(crate) fn build_mixture_runtime(
2808    spec: &MixtureSpec,
2809    experts: &[ExpertConfig],
2810) -> Result<MixtureRuntime, String> {
2811    spec.validate()?;
2812    match spec.kind {
2813        MixtureKind::Bayes => Ok(MixtureRuntime::Bayes(BayesMixture::new(experts))),
2814        MixtureKind::FadingBayes => {
2815            let decay = spec
2816                .decay
2817                .ok_or_else(|| "fading Bayes mixture requires decay".to_string())?;
2818            Ok(MixtureRuntime::Fading(FadingBayesMixture::new(
2819                experts, decay,
2820            )))
2821        }
2822        MixtureKind::Switching => Ok(MixtureRuntime::Switching(SwitchingMixture::new(
2823            experts,
2824            spec.alpha,
2825            spec.schedule,
2826        ))),
2827        MixtureKind::Convex => Ok(MixtureRuntime::Convex(ConvexMixture::new(
2828            experts,
2829            spec.alpha,
2830            spec.schedule,
2831        ))),
2832        MixtureKind::Mdl => Ok(MixtureRuntime::Mdl(MdlSelector::new(experts))),
2833        MixtureKind::Neural => Ok(MixtureRuntime::Neural(NeuralMixture::new(
2834            experts, spec.alpha,
2835        ))),
2836    }
2837}
2838
2839#[cfg(test)]
2840mod tests {
2841    use super::*;
2842    use std::sync::{
2843        Arc,
2844        atomic::{AtomicU64, AtomicUsize, Ordering},
2845    };
2846
2847    #[derive(Clone)]
2848    struct AlwaysPredict {
2849        byte: u8,
2850    }
2851
2852    impl OnlineBytePredictor for AlwaysPredict {
2853        fn log_prob(&mut self, symbol: u8) -> f64 {
2854            if symbol == self.byte {
2855                0.0
2856            } else {
2857                f64::NEG_INFINITY
2858            }
2859        }
2860
2861        fn update(&mut self, _symbol: u8) {}
2862    }
2863
2864    #[derive(Clone)]
2865    struct FixedProbPredict {
2866        prob_zero: f64,
2867    }
2868
2869    impl OnlineBytePredictor for FixedProbPredict {
2870        fn log_prob(&mut self, symbol: u8) -> f64 {
2871            let p = if symbol == 0 {
2872                self.prob_zero
2873            } else {
2874                1.0 - self.prob_zero
2875            };
2876            p.ln()
2877        }
2878
2879        fn update(&mut self, _symbol: u8) {}
2880    }
2881
2882    fn weighted_cfg(name: &'static str, weight: f64, prob_zero: f64) -> ExpertConfig {
2883        ExpertConfig::new(name, weight.ln(), move || {
2884            Box::new(FixedProbPredict { prob_zero })
2885        })
2886    }
2887
2888    #[test]
2889    fn bayes_mixture_prefers_correct_expert() {
2890        let configs = vec![
2891            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
2892            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
2893        ];
2894        let mut mix = BayesMixture::new(&configs);
2895        for _ in 0..10 {
2896            mix.step(0);
2897        }
2898        let post = mix.posterior();
2899        assert!(post[0] > 0.999);
2900        assert!(post[1] < 1e-6);
2901    }
2902
2903    fn counting_cfg(name: &'static str, calls: Arc<AtomicUsize>) -> ExpertConfig {
2904        ExpertConfig::uniform(name, move || {
2905            Box::new(CountingPredict {
2906                calls: calls.clone(),
2907            })
2908        })
2909    }
2910
2911    #[test]
2912    fn bayes_predict_then_step_reuses_cached_log_probs() {
2913        let c0 = Arc::new(AtomicUsize::new(0));
2914        let c1 = Arc::new(AtomicUsize::new(0));
2915        let mut mix = BayesMixture::new(&[
2916            counting_cfg("c0", c0.clone()),
2917            counting_cfg("c1", c1.clone()),
2918        ]);
2919        let _ = mix.predict_log_prob(0);
2920        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2921        assert_eq!(after_predict, 2);
2922        let _ = mix.step(0);
2923        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2924        assert_eq!(after_step, after_predict);
2925    }
2926
2927    #[test]
2928    fn fading_predict_then_step_reuses_cached_log_probs() {
2929        let c0 = Arc::new(AtomicUsize::new(0));
2930        let c1 = Arc::new(AtomicUsize::new(0));
2931        let mut mix = FadingBayesMixture::new(
2932            &[
2933                counting_cfg("c0", c0.clone()),
2934                counting_cfg("c1", c1.clone()),
2935            ],
2936            0.95,
2937        );
2938        let _ = mix.predict_log_prob(0);
2939        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2940        assert_eq!(after_predict, 2);
2941        let _ = mix.step(0);
2942        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2943        assert_eq!(after_step, after_predict);
2944    }
2945
2946    #[test]
2947    fn switching_predict_then_step_reuses_cached_log_probs() {
2948        let c0 = Arc::new(AtomicUsize::new(0));
2949        let c1 = Arc::new(AtomicUsize::new(0));
2950        let mut mix = SwitchingMixture::new(
2951            &[
2952                counting_cfg("c0", c0.clone()),
2953                counting_cfg("c1", c1.clone()),
2954            ],
2955            0.05,
2956            MixtureScheduleMode::Default,
2957        );
2958        let _ = mix.predict_log_prob(0);
2959        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2960        assert_eq!(after_predict, 2);
2961        let _ = mix.step(0);
2962        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2963        assert_eq!(after_step, after_predict);
2964    }
2965
2966    #[test]
2967    fn switching_mixture_matches_fixed_share_update_for_uniform_prior() {
2968        let configs = vec![weighted_cfg("a", 0.5, 0.8), weighted_cfg("b", 0.5, 0.3)];
2969        let alpha = 0.2;
2970        let mut mix = SwitchingMixture::new(&configs, alpha, MixtureScheduleMode::Default);
2971
2972        let predicted = mix.predict_log_prob(0).exp();
2973        assert!((predicted - 0.55).abs() < 1e-12, "predicted={predicted}");
2974
2975        let observed = mix.step(0).exp();
2976        assert!((observed - 0.55).abs() < 1e-12, "observed={observed}");
2977
2978        let post = mix.posterior();
2979        let posterior_a = 0.5 * 0.8 / 0.55;
2980        let posterior_b = 0.5 * 0.3 / 0.55;
2981        let expected_a = (1.0 - alpha) * posterior_a + alpha * posterior_b;
2982        let expected_b = (1.0 - alpha) * posterior_b + alpha * posterior_a;
2983        assert!(
2984            (post[0] - expected_a).abs() < 1e-12 && (post[1] - expected_b).abs() < 1e-12,
2985            "expected [{expected_a}, {expected_b}], got {:?}",
2986            post
2987        );
2988    }
2989
    // Three experts with non-uniform prior: verifies that switched-in mass is
    // distributed proportionally to each expert's prior over the *other*
    // experts' posterior (the posterior[k] / (1 - prior[k]) terms below).
    #[test]
    fn switching_mixture_switches_according_to_prior_over_other_experts() {
        let configs = vec![
            weighted_cfg("a", 0.5, 0.75),
            weighted_cfg("b", 0.3, 0.25),
            weighted_cfg("c", 0.2, 0.60),
        ];
        let alpha = 0.15;
        let mut mix = SwitchingMixture::new(&configs, alpha, MixtureScheduleMode::Default);

        let _ = mix.step(0);
        let post = mix.posterior();

        // Recompute the expected update by hand: Bayes posterior first …
        let current = [0.5_f64, 0.3, 0.2];
        let likelihood = [0.75_f64, 0.25, 0.60];
        let mix_prob = current
            .iter()
            .zip(likelihood.iter())
            .map(|(w, p)| w * p)
            .sum::<f64>();
        let posterior = [
            current[0] * likelihood[0] / mix_prob,
            current[1] * likelihood[1] / mix_prob,
            current[2] * likelihood[2] / mix_prob,
        ];
        let prior = [0.5_f64, 0.3, 0.2];
        let mut expected = [0.0_f64; 3];
        // … then each expert keeps (1 - alpha) of its posterior and receives
        // alpha * prior[j]-weighted mass switched in from every other expert.
        for j in 0..3 {
            let stay = (1.0 - alpha) * posterior[j];
            let switch_in = alpha
                * prior[j]
                * (0..3)
                    .filter(|&k| k != j)
                    .map(|k| posterior[k] / (1.0 - prior[k]))
                    .sum::<f64>();
            expected[j] = stay + switch_in;
        }

        for i in 0..3 {
            assert!(
                (post[i] - expected[i]).abs() < 1e-12,
                "expert {i}: expected {} got {}",
                expected[i],
                post[i]
            );
        }
    }
3037
3038    #[test]
3039    fn switching_theorem_schedule_uses_one_over_t() {
3040        assert!(
3041            (switching_alpha_for_update(MixtureScheduleMode::Theorem, 0.99, 0) - 0.5).abs() < 1e-12
3042        );
3043        assert!(
3044            (switching_alpha_for_update(MixtureScheduleMode::Theorem, 0.99, 1) - (1.0 / 3.0)).abs()
3045                < 1e-12
3046        );
3047
3048        let configs = vec![weighted_cfg("a", 0.5, 0.8), weighted_cfg("b", 0.5, 0.3)];
3049        let mut mix = SwitchingMixture::new(&configs, 0.99, MixtureScheduleMode::Theorem);
3050        let _ = mix.step(0);
3051        let post = mix.posterior();
3052        let posterior_a = 0.5 * 0.8 / 0.55;
3053        let posterior_b = 0.5 * 0.3 / 0.55;
3054        let expected_a = 0.5 * posterior_a + 0.5 * posterior_b;
3055        let expected_b = expected_a;
3056        assert!((post[0] - expected_a).abs() < 1e-12);
3057        assert!((post[1] - expected_b).abs() < 1e-12);
3058    }
3059
3060    #[test]
3061    fn convex_theorem_schedule_uses_paper_step_size() {
3062        let eta = convex_step_size_for_update(MixtureScheduleMode::Theorem, 9.0, 1);
3063        assert!((eta - DEFAULT_MIN_PROB).abs() < 1e-18);
3064
3065        let configs = vec![weighted_cfg("a", 0.5, 0.8), weighted_cfg("b", 0.5, 0.3)];
3066        let mut mix = ConvexMixture::new(&configs, 9.0, MixtureScheduleMode::Theorem);
3067        let observed = mix.step(0).exp();
3068        assert!((observed - 0.55).abs() < 1e-12, "observed={observed}");
3069
3070        let expected = [
3071            0.5 + eta * ((0.8 / 0.55) - 1.0),
3072            0.5 + eta * ((0.3 / 0.55) - 1.0),
3073        ];
3074        assert!((mix.lambda[0] - expected[0]).abs() < 1e-12);
3075        assert!((mix.lambda[1] - expected[1]).abs() < 1e-12);
3076    }
3077
3078    #[test]
3079    fn mdl_predict_then_step_reuses_best_expert_log_prob() {
3080        let c0 = Arc::new(AtomicUsize::new(0));
3081        let c1 = Arc::new(AtomicUsize::new(0));
3082        let mut mdl = MdlSelector::new(&[
3083            counting_cfg("c0", c0.clone()),
3084            counting_cfg("c1", c1.clone()),
3085        ]);
3086        let _ = mdl.predict_log_prob(0);
3087        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
3088        assert_eq!(after_predict, 1);
3089        let _ = mdl.step(0);
3090        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
3091        assert_eq!(after_step, 2);
3092    }
3093
3094    #[test]
3095    fn neural_mixture_adapts_to_correct_symbol() {
3096        let configs = vec![
3097            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
3098            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
3099        ];
3100        let mut mix = NeuralMixture::new(&configs, 0.05);
3101
3102        let mut early = 0.0;
3103        let mut late = 0.0;
3104        for t in 0..200 {
3105            let lp = mix.step(0);
3106            if t < 20 {
3107                early -= lp;
3108            }
3109            if t >= 180 {
3110                late -= lp;
3111            }
3112        }
3113
3114        let early_avg = early / 20.0;
3115        let late_avg = late / 20.0;
3116        assert!(
3117            late_avg < early_avg,
3118            "late_avg={late_avg} early_avg={early_avg}"
3119        );
3120        assert!(late_avg < 0.35, "late_avg={late_avg}");
3121    }
3122
3123    #[derive(Clone)]
3124    struct CountingPredict {
3125        calls: Arc<AtomicUsize>,
3126    }
3127
3128    impl OnlineBytePredictor for CountingPredict {
3129        fn log_prob(&mut self, symbol: u8) -> f64 {
3130            self.calls.fetch_add(1, Ordering::Relaxed);
3131            if symbol == 0 { 0.0 } else { -20.0 }
3132        }
3133
3134        fn update(&mut self, _symbol: u8) {}
3135    }
3136
3137    #[derive(Clone)]
3138    struct CountingFillPredict {
3139        log_calls: Arc<AtomicUsize>,
3140        fill_calls: Arc<AtomicUsize>,
3141    }
3142
3143    impl OnlineBytePredictor for CountingFillPredict {
3144        fn log_prob(&mut self, symbol: u8) -> f64 {
3145            self.log_calls.fetch_add(1, Ordering::Relaxed);
3146            if symbol == 0 { 0.0 } else { -20.0 }
3147        }
3148
3149        fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
3150            self.fill_calls.fetch_add(1, Ordering::Relaxed);
3151            out.fill(-20.0);
3152            out[0] = 0.0;
3153        }
3154
3155        fn update(&mut self, _symbol: u8) {}
3156    }
3157
3158    #[derive(Clone)]
3159    struct BeginAwarePredict {
3160        seen_total: Arc<AtomicU64>,
3161        began: bool,
3162    }
3163
3164    impl OnlineBytePredictor for BeginAwarePredict {
3165        fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
3166            let total = total_symbols.ok_or_else(|| "missing total symbols".to_string())?;
3167            self.seen_total.store(total, Ordering::Relaxed);
3168            self.began = true;
3169            Ok(())
3170        }
3171
3172        fn log_prob(&mut self, _symbol: u8) -> f64 {
3173            if self.began { 0.0 } else { f64::NEG_INFINITY }
3174        }
3175
3176        fn update(&mut self, _symbol: u8) {}
3177    }
3178
3179    fn assert_log_prob_update_matches_separate(label: &str, backend: RateBackend) {
3180        let mut separate =
3181            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
3182        let mut combined = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3183        let data = b"combined step check data";
3184
3185        for &b in data {
3186            let logp_separate = separate.log_prob(b);
3187            separate.update(b);
3188            let logp_combined = combined.log_prob_update(b);
3189            let diff = (logp_separate - logp_combined).abs();
3190            assert!(
3191                diff <= 1e-12,
3192                "[{label}] symbol={b} separate={logp_separate} combined={logp_combined} diff={diff}"
3193            );
3194
3195            let mut sep_row = [0.0; 256];
3196            let mut combo_row = [0.0; 256];
3197            separate.fill_log_probs(&mut sep_row);
3198            combined.fill_log_probs(&mut combo_row);
3199            for i in 0..256 {
3200                let diff = (sep_row[i] - combo_row[i]).abs();
3201                assert!(
3202                    diff <= 1e-12,
3203                    "row mismatch at {i}: {} vs {}",
3204                    sep_row[i],
3205                    combo_row[i]
3206                );
3207            }
3208        }
3209    }
3210
3211    fn assert_fill_matches_symbol_queries(label: &str, backend: RateBackend) {
3212        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
3213        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3214        let data = b"continuation consistency prompt";
3215
3216        bulk.begin_stream(Some(data.len() as u64))
3217            .expect("bulk begin");
3218        queried
3219            .begin_stream(Some(data.len() as u64))
3220            .expect("query begin");
3221        for &b in data {
3222            bulk.update(b);
3223            queried.update(b);
3224        }
3225
3226        let mut bulk_row = [0.0; 256];
3227        bulk.fill_log_probs(&mut bulk_row);
3228        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
3229            let queried_logp = queried.log_prob(sym as u8);
3230            let diff = (bulk_logp - queried_logp).abs();
3231            assert!(
3232                diff <= 1e-12,
3233                "[{label}] sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
3234            );
3235        }
3236    }
3237
3238    fn assert_fill_matches_symbol_queries_after_frozen_conditioning(
3239        label: &str,
3240        backend: RateBackend,
3241    ) {
3242        let fit = b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\n";
3243        let condition = b"If a cat is red, toads are \n";
3244        let total = (fit.len() + condition.len()) as u64;
3245
3246        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
3247        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3248
3249        bulk.begin_stream(Some(total)).expect("bulk begin");
3250        queried.begin_stream(Some(total)).expect("query begin");
3251        for &b in fit {
3252            bulk.update(b);
3253            queried.update(b);
3254        }
3255        bulk.reset_frozen(Some(condition.len() as u64))
3256            .expect("bulk reset frozen");
3257        queried
3258            .reset_frozen(Some(condition.len() as u64))
3259            .expect("query reset frozen");
3260        for &b in condition {
3261            bulk.update_frozen(b);
3262            queried.update_frozen(b);
3263        }
3264
3265        let mut bulk_row = [0.0; 256];
3266        bulk.fill_log_probs(&mut bulk_row);
3267        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
3268            let queried_logp = queried.log_prob(sym as u8);
3269            let diff = (bulk_logp - queried_logp).abs();
3270            assert!(
3271                diff <= 1e-12,
3272                "[{label}] frozen sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
3273            );
3274        }
3275    }
3276
    // Per-backend instantiations of the combined-vs-separate update check.
    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_rosa_backend() {
        assert_log_prob_update_matches_separate("rosa", RateBackend::RosaPlus);
    }

    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_ctw_backend() {
        assert_log_prob_update_matches_separate("ctw", RateBackend::Ctw { depth: 6 });
    }

    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_fac_ctw_backend() {
        assert_log_prob_update_matches_separate(
            "fac-ctw",
            RateBackend::FacCtw {
                base_depth: 6,
                num_percept_bits: 8,
                encoding_bits: 8,
            },
        );
    }
3298
    // Per-backend instantiations of the bulk-row vs per-symbol query check.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rosa_backend() {
        assert_fill_matches_symbol_queries("rosa", RateBackend::RosaPlus);
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_ctw_backend() {
        assert_fill_matches_symbol_queries("ctw", RateBackend::Ctw { depth: 6 });
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_match_backend() {
        assert_fill_matches_symbol_queries(
            "match",
            RateBackend::Match {
                hash_bits: 18,
                min_len: 4,
                max_len: 64,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
        );
    }
3322
    // Bulk-row consistency for the PPMd backend and (feature-gated) RWKV.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_ppmd_backend() {
        assert_fill_matches_symbol_queries(
            "ppmd",
            RateBackend::Ppmd {
                order: 8,
                memory_mb: 8,
            },
        );
    }

    // Tiny untrained RWKV7 config so the test stays fast and deterministic.
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rwkv_backend() {
        assert_fill_matches_symbol_queries(
            "rwkv7",
            RateBackend::Rwkv7Method {
                method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
            },
        );
    }
3344
    // Bulk-row vs per-symbol consistency must survive frozen conditioning too.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rosa_backend_after_frozen_conditioning() {
        assert_fill_matches_symbol_queries_after_frozen_conditioning("rosa", RateBackend::RosaPlus);
    }
3349
3350    #[test]
3351    fn predictor_frozen_conditioning_reuses_match_fit_corpus() {
3352        let mut predictor = RateBackendPredictor::from_backend(
3353            RateBackend::Match {
3354                hash_bits: 20,
3355                min_len: 3,
3356                max_len: 32,
3357                base_mix: 0.02,
3358                confidence_scale: 1.0,
3359            },
3360            -1,
3361            DEFAULT_MIN_PROB,
3362        );
3363
3364        for &b in b"abcabcX" {
3365            predictor.update(b);
3366        }
3367        predictor
3368            .reset_frozen(Some(6))
3369            .expect("reset frozen for match backend");
3370        for &b in b"abcabc" {
3371            predictor.update_frozen(b);
3372        }
3373        let p_x = predictor.log_prob(b'X').exp();
3374        assert!(
3375            p_x > 0.01,
3376            "frozen conditioning should preserve fit corpus for match backend; p_x={p_x}"
3377        );
3378    }
3379
3380    #[test]
3381    fn predictor_frozen_conditioning_reuses_sparse_match_fit_corpus() {
3382        let mut predictor = RateBackendPredictor::from_backend(
3383            RateBackend::SparseMatch {
3384                hash_bits: 20,
3385                min_len: 3,
3386                max_len: 32,
3387                gap_min: 0,
3388                gap_max: 2,
3389                base_mix: 0.02,
3390                confidence_scale: 1.0,
3391            },
3392            -1,
3393            DEFAULT_MIN_PROB,
3394        );
3395
3396        for &b in b"abcabcX" {
3397            predictor.update(b);
3398        }
3399        predictor
3400            .reset_frozen(Some(6))
3401            .expect("reset frozen for sparse-match backend");
3402        for &b in b"abcabc" {
3403            predictor.update_frozen(b);
3404        }
3405        let p_x = predictor.log_prob(b'X').exp();
3406        assert!(
3407            p_x > 0.01,
3408            "frozen conditioning should preserve fit corpus for sparse-match backend; p_x={p_x}"
3409        );
3410    }
3411
3412    #[test]
3413    fn neural_predict_then_step_reuses_evaluation_cache() {
3414        let c0 = Arc::new(AtomicUsize::new(0));
3415        let c1 = Arc::new(AtomicUsize::new(0));
3416        let cfg0 = {
3417            let c = c0.clone();
3418            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
3419        };
3420        let cfg1 = {
3421            let c = c1.clone();
3422            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
3423        };
3424        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);
3425
3426        let _ = mix.predict_log_prob(0);
3427        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
3428        assert_eq!(after_predict, 2);
3429
3430        let _ = mix.step(0);
3431        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
3432        assert_eq!(after_step, after_predict);
3433    }
3434
3435    #[test]
3436    fn neural_predict_multiple_symbols_reuses_single_evaluation() {
3437        let c0 = Arc::new(AtomicUsize::new(0));
3438        let c1 = Arc::new(AtomicUsize::new(0));
3439        let cfg0 = {
3440            let c = c0.clone();
3441            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
3442        };
3443        let cfg1 = {
3444            let c = c1.clone();
3445            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
3446        };
3447        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);
3448
3449        let _ = mix.predict_log_prob(0);
3450        let after_first = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
3451        assert_eq!(after_first, 2);
3452
3453        let _ = mix.predict_log_prob(1);
3454        let after_second = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
3455        assert_eq!(after_second, after_first + 2);
3456    }
3457
3458    #[test]
3459    fn neural_fill_then_step_reuses_cached_full_rows() {
3460        let log0 = Arc::new(AtomicUsize::new(0));
3461        let log1 = Arc::new(AtomicUsize::new(0));
3462        let fill0 = Arc::new(AtomicUsize::new(0));
3463        let fill1 = Arc::new(AtomicUsize::new(0));
3464        let cfg0 = {
3465            let log_calls = log0.clone();
3466            let fill_calls = fill0.clone();
3467            ExpertConfig::uniform("c0", move || {
3468                Box::new(CountingFillPredict {
3469                    log_calls: log_calls.clone(),
3470                    fill_calls: fill_calls.clone(),
3471                })
3472            })
3473        };
3474        let cfg1 = {
3475            let log_calls = log1.clone();
3476            let fill_calls = fill1.clone();
3477            ExpertConfig::uniform("c1", move || {
3478                Box::new(CountingFillPredict {
3479                    log_calls: log_calls.clone(),
3480                    fill_calls: fill_calls.clone(),
3481                })
3482            })
3483        };
3484        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);
3485
3486        let mut row = [0.0; 256];
3487        mix.fill_log_probs(&mut row);
3488        assert_eq!(fill0.load(Ordering::Relaxed), 1);
3489        assert_eq!(fill1.load(Ordering::Relaxed), 1);
3490        assert_eq!(log0.load(Ordering::Relaxed), 0);
3491        assert_eq!(log1.load(Ordering::Relaxed), 0);
3492
3493        let _ = mix.step(0);
3494        assert_eq!(fill0.load(Ordering::Relaxed), 1);
3495        assert_eq!(fill1.load(Ordering::Relaxed), 1);
3496        assert_eq!(log0.load(Ordering::Relaxed), 0);
3497        assert_eq!(log1.load(Ordering::Relaxed), 0);
3498    }
3499
3500    #[test]
3501    fn runtime_begin_stream_propagates_to_experts() {
3502        let seen_total = Arc::new(AtomicU64::new(0));
3503        let cfg = {
3504            let seen_total = seen_total.clone();
3505            ExpertConfig::uniform("begin-aware", move || {
3506                Box::new(BeginAwarePredict {
3507                    seen_total: seen_total.clone(),
3508                    began: false,
3509                })
3510            })
3511        };
3512
3513        let spec = MixtureSpec::new(
3514            MixtureKind::Bayes,
3515            vec![crate::MixtureExpertSpec {
3516                name: Some("begin-aware".to_string()),
3517                log_prior: 0.0,
3518                max_order: -1,
3519                backend: RateBackend::Ctw { depth: 1 },
3520            }],
3521        );
3522        let mut runtime = build_mixture_runtime(&spec, &[cfg]).expect("runtime");
3523        runtime.begin_stream(Some(123)).expect("begin stream");
3524        let _ = runtime.step(0);
3525        assert_eq!(seen_total.load(Ordering::Relaxed), 123);
3526    }
3527
3528    #[test]
3529    fn zpaq_fill_log_probs_does_not_drift_history() {
3530        let backend = RateBackend::Zpaq {
3531            method: "1".to_string(),
3532        };
3533        let mut baseline =
3534            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
3535        let mut probe = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3536
3537        let history = b"history for zpaq predictor";
3538        for &b in history {
3539            baseline.update(b);
3540            probe.update(b);
3541        }
3542
3543        let mut row = [0.0f64; 256];
3544        probe.fill_log_probs(&mut row);
3545
3546        let sym = b'k';
3547        let lp_base = baseline.log_prob(sym);
3548        let lp_probe = probe.log_prob(sym);
3549        assert!((lp_base - lp_probe).abs() < 1e-9);
3550        assert!((row[sym as usize] - lp_base).abs() < 1e-9);
3551
3552        baseline.update(sym);
3553        probe.update(sym);
3554        let next = b'q';
3555        let next_base = baseline.log_prob(next);
3556        let next_probe = probe.log_prob(next);
3557        assert!((next_base - next_probe).abs() < 1e-9);
3558    }
3559
3560    fn assert_predictor_log_probs_normalize_to_one(backend: RateBackend) {
3561        let mut predictor = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3562        for &b in b"normalization corpus for ctw/fac predictor checks" {
3563            predictor.update(b);
3564        }
3565        let mut sum = 0.0f64;
3566        for sym in 0u8..=255u8 {
3567            sum += predictor.log_prob(sym).exp();
3568        }
3569        assert!(
3570            (sum - 1.0).abs() <= 1e-10,
3571            "probability mass drift: sum={sum}"
3572        );
3573    }
3574
3575    #[test]
3576    fn ctw_predictor_symbol_probs_normalize() {
3577        assert_predictor_log_probs_normalize_to_one(RateBackend::Ctw { depth: 7 });
3578    }
3579
3580    #[test]
3581    fn fac_ctw_predictor_symbol_probs_normalize() {
3582        assert_predictor_log_probs_normalize_to_one(RateBackend::FacCtw {
3583            base_depth: 7,
3584            num_percept_bits: 8,
3585            encoding_bits: 8,
3586        });
3587    }
3588}