infotheory/
mixture.rs

1//! Online mixtures of probabilistic predictors (log-loss Hedge / Bayes, switching, MDL).
2//!
3//! This module provides a small, rigorously correct toolkit for sequential model mixing.
4//! Predictors expose per-symbol log-probabilities, which allows principled Bayesian
5//! mixture updates and clean information-theoretic accounting.
6//!
7//! ## Rate-Backend Mixtures
8//!
9//! The mixture primitives here power `RateBackend::Mixture`, enabling Bayes, fading Bayes,
10//! switching, and MDL-style selectors to be used anywhere a rate backend is accepted.
11
12use crate::backends::calibration::CalibratorCore;
13use crate::backends::match_model::MatchModel;
14use crate::backends::ppmd::PpmdModel;
15use crate::backends::sparse_match::SparseMatchModel;
16use crate::backends::text_context::TextContextAnalyzer;
17use crate::ctw::FacContextTree;
18#[cfg(feature = "backend-mamba")]
19use crate::mambazip;
20use crate::neural_mix::{NeuralHistoryState, NeuralMixCore};
21use crate::rosaplus::RosaPlus;
22#[cfg(feature = "backend-rwkv")]
23use crate::rwkvzip;
24use crate::zpaq_rate::ZpaqRateModel;
25use crate::{CalibratedSpec, MixtureKind, MixtureSpec, RateBackend};
26use std::sync::Arc;
27
/// Default minimum probability floor to avoid log(0).
///
/// The literal is the shortest decimal that round-trips to 2^-24, i.e. one
/// step of a 24-bit probability quantization.
pub const DEFAULT_MIN_PROB: f64 = 5.960_464_477_539_063e-8;
30
#[inline]
fn clamp_prob(p: f64, min_prob: f64) -> f64 {
    // Non-finite inputs (NaN, ±inf) collapse to the floor; finite inputs are
    // simply floored at `min_prob` so a downstream `ln()` never sees zero.
    if !p.is_finite() {
        return min_prob;
    }
    p.max(min_prob)
}
39
#[inline]
fn clamp_unit_prob(p: f64, min_prob: f64) -> f64 {
    // Clamp into the open unit interval [min_prob, 1 - min_prob] so both
    // `ln(p)` and `ln(1 - p)` stay finite. The floor logic is inlined from
    // `clamp_prob`: non-finite inputs collapse to the floor.
    let lo = min_prob;
    let hi = 1.0 - min_prob;
    let floored = if p.is_finite() { p.max(lo) } else { lo };
    floored.min(hi)
}
44
45#[inline]
46fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
47    CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
48}
49
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    // Numerically stable log(Σ exp(x_i)): shift by the running maximum before
    // exponentiating. NaN entries never win the `>` comparison, exactly like
    // the original element-wise scan.
    let max_v = xs
        .iter()
        .fold(f64::NEG_INFINITY, |m, &v| if v > m { v } else { m });
    if !max_v.is_finite() {
        // Empty slice, all -inf, or a +inf maximum: propagate it unchanged.
        return max_v;
    }
    let sum: f64 = xs.iter().map(|&v| (v - max_v).exp()).sum();
    max_v + sum.ln()
}
67
#[inline]
fn logsumexp2(a: f64, b: f64) -> f64 {
    // Two-element log-sum-exp. The pivot is selected exactly as
    // `if a > b { a } else { b }` would be, so a NaN in `a` falls through to
    // `b` (preserving the original's NaN behavior, unlike `f64::max`).
    let m = match a.partial_cmp(&b) {
        Some(std::cmp::Ordering::Greater) => a,
        _ => b,
    };
    if !m.is_finite() {
        return m;
    }
    let shifted = (a - m).exp() + (b - m).exp();
    m + shifted.ln()
}
76
77#[inline]
78fn logsumexp_weights(experts: &[ExpertState]) -> f64 {
79    let mut max_v = f64::NEG_INFINITY;
80    for e in experts {
81        if e.log_weight > max_v {
82            max_v = e.log_weight;
83        }
84    }
85    if !max_v.is_finite() {
86        return max_v;
87    }
88    let mut sum = 0.0;
89    for e in experts {
90        sum += (e.log_weight - max_v).exp();
91    }
92    max_v + sum.ln()
93}
94
/// Object-safe cloning support for boxed [`OnlineBytePredictor`] trait objects.
pub trait OnlineBytePredictorClone {
    /// Clone this predictor as a trait object.
    ///
    /// This supports `Clone` for `Box<dyn OnlineBytePredictor>` via type erasure,
    /// so mixture experts can be duplicated without knowing their concrete type.
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor>;
}
103
/// Blanket implementation: every `'static + Clone` predictor gets `clone_box`
/// for free, so concrete predictors only need to derive/impl `Clone`.
impl<T> OnlineBytePredictorClone for T
where
    T: 'static + OnlineBytePredictor + Clone,
{
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor> {
        Box::new(self.clone())
    }
}
112
/// `Clone` for boxed predictors, delegating to the object-safe `clone_box`.
impl Clone for Box<dyn OnlineBytePredictor> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
118
/// Trait for online byte-level predictors that expose per-symbol log-probabilities.
pub trait OnlineBytePredictor: Send + OnlineBytePredictorClone {
    /// Optional stream-start hook.
    ///
    /// Predictors that require total symbol count (for example percent-based
    /// policy schedules) can initialize runtime state here.
    fn begin_stream(&mut self, _total_symbols: Option<u64>) -> Result<(), String> {
        Ok(())
    }

    /// Optional stream-finalization hook.
    fn finish_stream(&mut self) -> Result<(), String> {
        Ok(())
    }

    /// Log-probability (natural log) of `symbol` given the current history.
    fn log_prob(&mut self, symbol: u8) -> f64;

    /// Bulk 256-way log-probabilities for the next byte.
    ///
    /// The default issues 256 separate `log_prob` calls; implementors with a
    /// cheaper batch path should override it.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        for (sym, slot) in out.iter_mut().enumerate() {
            *slot = self.log_prob(sym as u8);
        }
    }

    /// Log-probability (natural log) of `symbol`, then update the predictor.
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        let logp = self.log_prob(symbol);
        self.update(symbol);
        logp
    }

    /// Update the predictor with the observed `symbol`.
    fn update(&mut self, symbol: u8);

    /// Reset only dynamic conditioning state while preserving fitted parameters/statistics.
    ///
    /// Predictors with latent/posterior state may also preserve their learned
    /// parameter posterior here; "frozen" means no new parameter fitting during
    /// the score pass, not necessarily a static hidden-state belief.
    ///
    /// The default finalizes the current stream and begins a fresh one.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.finish_stream()?;
        self.begin_stream(total_symbols)
    }

    /// Advance conditioning state without fitting or adapting parameters.
    ///
    /// For state-space or latent-variable models this may still update internal
    /// filtering/posterior state needed for correct sequential predictions.
    ///
    /// The default falls back to a full `update`; implementors that distinguish
    /// frozen scoring from training should override it.
    fn update_frozen(&mut self, symbol: u8) {
        self.update(symbol);
    }
}
172
#[cfg(feature = "backend-rwkv")]
#[inline]
fn ensure_rwkv_primed(compressor: &mut rwkvzip::Compressor, primed: &mut bool) {
    // Idempotent lazy priming: the first caller pays for `reset_and_prime`;
    // every subsequent call is just a flag check.
    if *primed {
        return;
    }
    compressor.reset_and_prime();
    *primed = true;
}
181
182#[inline]
183fn ctw_log_prob_update_msb(tree: &mut FacContextTree, symbol: u8, min_prob: f64) -> f64 {
184    let mut logp = 0.0;
185    for bit_idx in 0..8 {
186        let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
187        let p = tree.predict(bit, bit_idx);
188        if p.is_finite() && p > 0.0 {
189            logp += p.ln();
190        } else {
191            logp = f64::NEG_INFINITY;
192        }
193        tree.update_predicted(bit, bit_idx);
194    }
195    if logp.is_finite() {
196        logp.max(min_prob.ln())
197    } else {
198        min_prob.ln()
199    }
200}
201
202#[inline]
203fn ctw_log_prob_update_lsb(
204    tree: &mut FacContextTree,
205    symbol: u8,
206    bits_per_symbol: usize,
207    min_prob: f64,
208) -> f64 {
209    let mut logp = 0.0;
210    for bit_idx in 0..bits_per_symbol {
211        let bit = ((symbol >> bit_idx) & 1) == 1;
212        let p = tree.predict(bit, bit_idx);
213        if p.is_finite() && p > 0.0 {
214            logp += p.ln();
215        } else {
216            logp = f64::NEG_INFINITY;
217        }
218        tree.update_predicted(bit, bit_idx);
219    }
220    if logp.is_finite() {
221        logp.max(min_prob.ln())
222    } else {
223        min_prob.ln()
224    }
225}
226
/// Compute 256-way next-byte log-probabilities from a factorized CTW tree
/// without permanently mutating it.
///
/// Enumerates all `2^bits` bit patterns depth-first, using `tree.update` /
/// `tree.revert` as a push/pop pair so the tree is restored to its entry
/// state before returning. Each pattern's log-probability is the delta of
/// `get_log_block_probability` against the pre-enumeration baseline.
///
/// When `bits_per_symbol < 8`, each pattern aliases `2^(8-bits)` byte values;
/// the pattern's mass is split uniformly across aliases (subtracting
/// `ln(aliases)` in log space) so `out` remains a normalized 256-way PDF.
fn fill_fac_tree_log_probs(
    tree: &mut FacContextTree,
    bits_per_symbol: usize,
    msb_first: bool,
    min_logp: f64,
    out: &mut [f64; 256],
) {
    // Parameters threaded through the recursion (avoids a long argument list).
    struct RecParams {
        bits: usize,
        msb_first: bool,
        // Baseline log block probability before any speculative updates.
        log_before: f64,
        min_logp: f64,
    }

    let bits = bits_per_symbol.clamp(1, 8);
    let patterns = 1usize << bits;
    let mut pattern_logps = [f64::NEG_INFINITY; 256];
    let params = RecParams {
        bits,
        msb_first,
        log_before: tree.get_log_block_probability(),
        min_logp,
    };

    // Depth-first walk: each level speculatively applies one bit, recurses,
    // then reverts it, so every leaf sees exactly its pattern applied.
    fn rec(
        tree: &mut FacContextTree,
        depth: usize,
        params: &RecParams,
        symbol_acc: u8,
        pattern_logps: &mut [f64; 256],
    ) {
        if depth == params.bits {
            let pat = symbol_acc as usize;
            // Delta against the baseline = log P(pattern | history), floored.
            let logp = (tree.get_log_block_probability() - params.log_before).max(params.min_logp);
            pattern_logps[pat] = logp;
            return;
        }

        for bit in [false, true] {
            tree.update(bit, depth);
            let mut next_symbol = symbol_acc;
            if params.msb_first {
                // MSB-first: depth 0 sets bit 7, depth 1 sets bit 6, etc.
                let shift = 7usize.saturating_sub(depth);
                if bit {
                    next_symbol |= 1u8 << shift;
                }
            } else if bit {
                next_symbol |= 1u8 << depth;
            }
            rec(tree, depth + 1, params, next_symbol, pattern_logps);
            tree.revert(depth);
        }
    }

    rec(tree, 0, &params, 0, &mut pattern_logps);

    if bits == 8 {
        out.copy_from_slice(&pattern_logps);
    } else {
        // Fewer than 8 modeled bits: spread each pattern's mass uniformly
        // over its aliased byte values so the 256-way PDF stays normalized.
        let aliases = 1usize << (8 - bits);
        let alias_ln = (aliases as f64).ln();
        let mask = patterns - 1;
        for byte in 0..256usize {
            out[byte] = pattern_logps[byte & mask] - alias_ln;
        }
    }
}
294
/// A concrete online predictor backed by a `RateBackend` configuration.
///
/// One variant per backend family; each carries the state needed to produce
/// per-symbol log-probabilities plus (usually) a `min_prob` floor.
// Large variants are deliberately kept inline rather than boxed: predictors
// are few and long-lived, and the per-symbol hot path avoids an indirection.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum RateBackendPredictor {
    /// ROSA-Plus online suffix automaton.
    Rosa {
        /// ROSA model state.
        model: RosaPlus,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Local contiguous match predictor.
    Match {
        /// Match model state.
        model: MatchModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Sparse/gapped local match predictor.
    SparseMatch {
        /// Sparse-match model state.
        model: SparseMatchModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Bounded-memory PPMD-style predictor.
    Ppmd {
        /// PPMD model state.
        model: PpmdModel,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Byte-wise CTW implemented as 8 factorized bit trees (MSB-first).
    Ctw {
        /// FAC-CTW tree stack (8 bits per byte).
        tree: FacContextTree,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Factorized CTW with configurable bit-encoding (LSB-first).
    FacCtw {
        /// FAC-CTW tree stack for configured bit width.
        tree: FacContextTree,
        /// Active bit-width per symbol.
        bits_per_symbol: usize,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// RWKV-7 neural predictor.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        /// RWKV compressor/runtime state.
        compressor: rwkvzip::Compressor,
        /// Whether the first-token distribution has been primed.
        primed: bool,
        /// Scratch copy used for update API that borrows immutable PDF.
        pdf_scratch: Vec<f64>,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// Mamba-1 neural predictor.
    #[cfg(feature = "backend-mamba")]
    Mamba {
        /// Mamba compressor/runtime state.
        compressor: mambazip::Compressor,
        /// Whether the first-token distribution has been primed.
        primed: bool,
        /// Scratch copy used for update API that borrows immutable PDF.
        pdf_scratch: Vec<f64>,
        /// Probability floor for numeric stability.
        min_prob: f64,
    },
    /// ZPAQ streaming rate model.
    Zpaq {
        /// ZPAQ rate model state.
        model: ZpaqRateModel,
    },
    /// Online mixture over experts (Bayes, fading Bayes, switching, MDL).
    Mixture {
        /// Active mixture runtime.
        runtime: MixtureRuntime,
    },
    /// Particle-latent filter ensemble.
    Particle {
        /// Particle runtime.
        runtime: crate::particle::ParticleRuntime,
    },
    /// Calibrated wrapper around another predictor.
    Calibrated {
        /// Wrapped predictor whose PDF is calibrated.
        base: Box<RateBackendPredictor>,
        /// Online calibrator state and context features.
        core: CalibratorCore,
        /// Cached calibrated PDF.
        pdf: [f64; 256],
        /// Whether `pdf` currently matches wrapped state.
        valid: bool,
        /// Probability floor used for numerical stability.
        min_prob: f64,
    },
}
396
397impl RateBackendPredictor {
398    /// Create a new online predictor from a rate backend configuration.
399    pub fn from_backend(backend: RateBackend, max_order: i64, min_prob: f64) -> Self {
400        match backend {
401            RateBackend::RosaPlus => {
402                let mut model = RosaPlus::new(max_order, false, 0, 42);
403                model.build_lm_full_bytes_no_finalize_endpos();
404                Self::Rosa { model, min_prob }
405            }
406            RateBackend::Match {
407                hash_bits,
408                min_len,
409                max_len,
410                base_mix,
411                confidence_scale,
412            } => Self::Match {
413                model: MatchModel::new_contiguous(
414                    hash_bits,
415                    min_len,
416                    max_len,
417                    base_mix,
418                    confidence_scale,
419                ),
420                min_prob,
421            },
422            RateBackend::SparseMatch {
423                hash_bits,
424                min_len,
425                max_len,
426                gap_min,
427                gap_max,
428                base_mix,
429                confidence_scale,
430            } => Self::SparseMatch {
431                model: SparseMatchModel::new(
432                    hash_bits,
433                    min_len,
434                    max_len,
435                    gap_min,
436                    gap_max,
437                    base_mix,
438                    confidence_scale,
439                ),
440                min_prob,
441            },
442            RateBackend::Ppmd { order, memory_mb } => Self::Ppmd {
443                model: PpmdModel::new(order, memory_mb),
444                min_prob,
445            },
446            RateBackend::Ctw { depth } => {
447                let tree = FacContextTree::new(depth, 8);
448                Self::Ctw { tree, min_prob }
449            }
450            RateBackend::FacCtw {
451                base_depth,
452                num_percept_bits: _,
453                encoding_bits,
454            } => {
455                let bits_per_symbol = encoding_bits.clamp(1, 8);
456                let tree = FacContextTree::new(base_depth, bits_per_symbol);
457                Self::FacCtw {
458                    tree,
459                    bits_per_symbol,
460                    min_prob,
461                }
462            }
463            #[cfg(feature = "backend-rwkv")]
464            RateBackend::Rwkv7 { model } => {
465                let mut compressor = rwkvzip::Compressor::new_from_model(model);
466                compressor.reset_and_prime();
467                Self::Rwkv7 {
468                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
469                    compressor,
470                    primed: true,
471                    min_prob,
472                }
473            }
474            #[cfg(feature = "backend-rwkv")]
475            RateBackend::Rwkv7Method { method } => {
476                let mut compressor = rwkvzip::Compressor::new_from_method(&method)
477                    .unwrap_or_else(|e| panic!("invalid rwkv method '{method}': {e}"));
478                compressor.reset_and_prime();
479                Self::Rwkv7 {
480                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
481                    compressor,
482                    primed: true,
483                    min_prob,
484                }
485            }
486            #[cfg(feature = "backend-mamba")]
487            RateBackend::Mamba { model } => {
488                let mut compressor = mambazip::Compressor::new_from_model(model);
489                let bias = compressor.online_bias_snapshot();
490                let logits =
491                    compressor
492                        .model
493                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
494                mambazip::Compressor::logits_to_pdf(
495                    logits,
496                    bias.as_deref(),
497                    &mut compressor.pdf_buffer,
498                );
499                Self::Mamba {
500                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
501                    compressor,
502                    primed: true,
503                    min_prob,
504                }
505            }
506            #[cfg(feature = "backend-mamba")]
507            RateBackend::MambaMethod { method } => {
508                let mut compressor = mambazip::Compressor::new_from_method(&method)
509                    .unwrap_or_else(|e| panic!("invalid mamba method '{method}': {e}"));
510                let bias = compressor.online_bias_snapshot();
511                let logits =
512                    compressor
513                        .model
514                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
515                mambazip::Compressor::logits_to_pdf(
516                    logits,
517                    bias.as_deref(),
518                    &mut compressor.pdf_buffer,
519                );
520                Self::Mamba {
521                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
522                    compressor,
523                    primed: true,
524                    min_prob,
525                }
526            }
527            RateBackend::Zpaq { method } => {
528                let model = ZpaqRateModel::new(method, min_prob);
529                Self::Zpaq { model }
530            }
531            RateBackend::Mixture { spec } => {
532                let experts = spec.build_experts();
533                let runtime = build_mixture_runtime(spec.as_ref(), &experts)
534                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
535                Self::Mixture { runtime }
536            }
537            RateBackend::Particle { spec } => {
538                let runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
539                Self::Particle { runtime }
540            }
541            RateBackend::Calibrated { spec } => Self::Calibrated {
542                base: Box::new(Self::from_backend(spec.base.clone(), max_order, min_prob)),
543                core: build_calibrator(spec.as_ref()),
544                pdf: [1.0 / 256.0; 256],
545                valid: false,
546                min_prob,
547            },
548        }
549    }
550
551    /// Human-readable default name for a backend + config.
552    pub fn default_name(backend: &RateBackend, max_order: i64) -> String {
553        match backend {
554            RateBackend::RosaPlus => format!("rosa(mo={})", max_order),
555            RateBackend::Match { .. } => "match".to_string(),
556            RateBackend::SparseMatch { .. } => "sparse-match".to_string(),
557            RateBackend::Ppmd { order, memory_mb } => {
558                format!("ppmd(o={},m={}MiB)", order, memory_mb)
559            }
560            RateBackend::Ctw { depth } => format!("ctw(d={})", depth),
561            RateBackend::FacCtw {
562                base_depth,
563                encoding_bits,
564                ..
565            } => format!("fac-ctw(d={},b={})", base_depth, encoding_bits),
566            #[cfg(feature = "backend-rwkv")]
567            RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
568            #[cfg(feature = "backend-rwkv")]
569            RateBackend::Rwkv7Method { method } => format!("rwkv7({method})"),
570            #[cfg(feature = "backend-mamba")]
571            RateBackend::Mamba { .. } => "mamba".to_string(),
572            #[cfg(feature = "backend-mamba")]
573            RateBackend::MambaMethod { method } => format!("mamba({method})"),
574            RateBackend::Zpaq { method } => format!("zpaq(m={})", method),
575            RateBackend::Mixture { spec } => {
576                let kind = match spec.kind {
577                    MixtureKind::Bayes => "bayes",
578                    MixtureKind::FadingBayes => "fading",
579                    MixtureKind::Switching => "switch",
580                    MixtureKind::Mdl => "mdl",
581                    MixtureKind::Neural => "neural",
582                };
583                format!("mix({})", kind)
584            }
585            RateBackend::Particle { spec } => {
586                format!("particle(n={},c={})", spec.num_particles, spec.num_cells)
587            }
588            RateBackend::Calibrated { spec } => {
589                format!("calibrated({})", Self::default_name(&spec.base, max_order))
590            }
591        }
592    }
593}
594
595impl OnlineBytePredictor for RateBackendPredictor {
    /// Stream-start hook: finalize any in-flight stream, then perform
    /// backend-specific initialization (capacity reservation, online-policy
    /// setup, recursion into wrapped predictors).
    fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        // Make begin idempotent across back-to-back streams.
        self.finish_stream()?;
        match self {
            RateBackendPredictor::Rosa { model, .. } => {
                if let Some(total) = total_symbols {
                    // Clamp oversized hints so the reservation cannot overflow.
                    let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
                    model.reserve_for_stream(reserve);
                }
                Ok(())
            }
            // These backends carry no per-stream setup state.
            RateBackendPredictor::Match { .. }
            | RateBackendPredictor::SparseMatch { .. }
            | RateBackendPredictor::Ppmd { .. } => Ok(()),
            RateBackendPredictor::Ctw { .. }
            | RateBackendPredictor::FacCtw { .. }
            | RateBackendPredictor::Zpaq { .. }
            | RateBackendPredictor::Particle { .. } => Ok(()),
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
                .begin_online_policy_stream(total_symbols)
                .map_err(|e| e.to_string()),
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba { compressor, .. } => compressor
                .begin_online_policy_stream(total_symbols)
                .map_err(|e| e.to_string()),
            RateBackendPredictor::Mixture { runtime } => runtime.begin_stream(total_symbols),
            RateBackendPredictor::Calibrated { base, .. } => base.begin_stream(total_symbols),
        }
    }
625
    /// Stream-finalization hook: flushes backends with online-policy state;
    /// a no-op for the purely incremental models.
    fn finish_stream(&mut self) -> Result<(), String> {
        match self {
            RateBackendPredictor::Rosa { .. }
            | RateBackendPredictor::Match { .. }
            | RateBackendPredictor::SparseMatch { .. }
            | RateBackendPredictor::Ppmd { .. }
            | RateBackendPredictor::Ctw { .. }
            | RateBackendPredictor::FacCtw { .. }
            | RateBackendPredictor::Zpaq { .. }
            | RateBackendPredictor::Particle { .. } => Ok(()),
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
                .finish_online_policy_stream()
                .map_err(|e| e.to_string()),
            // NOTE(review): Mamba has a begin hook but no finish hook here —
            // presumably nothing to flush; confirm against the mambazip
            // online-policy API.
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba { .. } => Ok(()),
            RateBackendPredictor::Mixture { runtime } => runtime.finish_stream(),
            RateBackendPredictor::Calibrated { base, .. } => base.finish_stream(),
        }
    }
646
    /// Log-probability of `symbol` under the current conditioning state.
    ///
    /// This is a *peek*: no variant permanently advances its model here. The
    /// CTW variants speculatively update then revert; the neural variants
    /// read an already-primed cached PDF; Mixture/Particle use explicit
    /// peek APIs.
    fn log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            RateBackendPredictor::Rosa { model, min_prob } => {
                let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
                p.ln()
            }
            RateBackendPredictor::Match { model, min_prob } => model.log_prob(symbol, *min_prob),
            RateBackendPredictor::SparseMatch { model, min_prob } => {
                model.log_prob(symbol, *min_prob)
            }
            RateBackendPredictor::Ppmd { model, min_prob } => model.log_prob(symbol, *min_prob),
            RateBackendPredictor::Ctw { tree, min_prob } => {
                // Speculatively push all 8 bits (MSB-first), measure the log
                // block probability delta, then revert in reverse order.
                let log_before = tree.get_log_block_probability();
                for bit_idx in 0..8 {
                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                    tree.update(bit, bit_idx);
                }
                let log_after = tree.get_log_block_probability();
                for bit_idx in (0..8).rev() {
                    tree.revert(bit_idx);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                min_prob,
            } => {
                // Same speculative update/revert as Ctw, but LSB-first over
                // the configured bit width.
                let log_before = tree.get_log_block_probability();
                for i in 0..*bits_per_symbol {
                    let bit = ((symbol >> i) & 1) == 1;
                    tree.update(bit, i);
                }
                let log_after = tree.get_log_block_probability();
                for i in (0..*bits_per_symbol).rev() {
                    tree.revert(i);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                // Lazily prime the first-token PDF, then read the cached entry.
                ensure_rwkv_primed(compressor, primed);
                let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
                p.ln()
            }
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                // Lazily prime: one forward pass on token 0 fills pdf_buffer
                // with the first-symbol distribution.
                if !*primed {
                    let bias = compressor.online_bias_snapshot();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    mambazip::Compressor::logits_to_pdf(
                        logits,
                        bias.as_deref(),
                        &mut compressor.pdf_buffer,
                    );
                    *primed = true;
                }
                let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
                p.ln()
            }
            RateBackendPredictor::Zpaq { model } => model.log_prob(symbol),
            RateBackendPredictor::Mixture { runtime } => runtime.peek_log_prob(symbol),
            RateBackendPredictor::Particle { runtime } => runtime.peek_log_prob(symbol),
            RateBackendPredictor::Calibrated {
                base,
                core,
                pdf,
                valid,
                min_prob,
            } => {
                // Refresh the cached calibrated PDF only when the wrapped
                // predictor's state changed since the last calibration.
                if !*valid {
                    let mut base_logps = [0.0; 256];
                    base.fill_log_probs(&mut base_logps);
                    let mut base_pdf = [0.0; 256];
                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
                        *dst = clamp_prob(lp.exp(), *min_prob);
                    }
                    core.apply_pdf(&base_pdf, pdf);
                    *valid = true;
                }
                pdf[symbol as usize].max(*min_prob).ln()
            }
        }
    }
754
    /// Bulk 256-way log-probabilities for the next byte.
    ///
    /// Every entry is clamped at the variant's `min_prob` floor before
    /// taking `ln`, so `out` never contains `-inf` from a zero probability
    /// (ZPAQ/Mixture/Particle apply their own flooring internally).
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        match self {
            RateBackendPredictor::Rosa { model, min_prob } => {
                // `out` is first filled with raw probabilities, then clamped
                // and converted to log space in place.
                model.fill_probs_for_last_bytes(out);
                for slot in out.iter_mut() {
                    *slot = clamp_prob(*slot, *min_prob).ln();
                }
            }
            RateBackendPredictor::Match { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::SparseMatch { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::Ppmd { model, min_prob } => {
                let mut pdf = [0.0; 256];
                model.fill_pdf(&mut pdf);
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
            RateBackendPredictor::Ctw { tree, min_prob } => {
                // Byte-wise CTW: full 8-bit enumeration, MSB-first.
                fill_fac_tree_log_probs(tree, 8, true, min_prob.ln(), out);
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                min_prob,
            } => {
                fill_fac_tree_log_probs(tree, *bits_per_symbol, false, min_prob.ln(), out);
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                ensure_rwkv_primed(compressor, primed);
                // `.take(256)` guards against a pdf_buffer larger than the
                // byte alphabet (e.g. extra vocabulary entries).
                for (slot, &p_raw) in out
                    .iter_mut()
                    .take(256)
                    .zip(compressor.pdf_buffer.iter().take(256))
                {
                    let p = clamp_prob(p_raw, *min_prob);
                    *slot = p.ln();
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba {
                compressor,
                primed,
                min_prob,
                ..
            } => {
                // Lazily prime the first-symbol PDF (same as `log_prob`).
                if !*primed {
                    let bias = compressor.online_bias_snapshot();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    mambazip::Compressor::logits_to_pdf(
                        logits,
                        bias.as_deref(),
                        &mut compressor.pdf_buffer,
                    );
                    *primed = true;
                }
                for (slot, &p_raw) in out
                    .iter_mut()
                    .take(256)
                    .zip(compressor.pdf_buffer.iter().take(256))
                {
                    let p = clamp_prob(p_raw, *min_prob);
                    *slot = p.ln();
                }
            }
            RateBackendPredictor::Zpaq { model } => {
                model.fill_log_probs(out);
            }
            RateBackendPredictor::Mixture { runtime } => {
                runtime.fill_log_probs(out);
            }
            RateBackendPredictor::Particle { runtime } => {
                runtime.fill_log_probs_cached(out);
            }
            RateBackendPredictor::Calibrated {
                base,
                core,
                pdf,
                valid,
                min_prob,
            } => {
                // Same cached-PDF refresh as in `log_prob`.
                if !*valid {
                    let mut base_logps = [0.0; 256];
                    base.fill_log_probs(&mut base_logps);
                    let mut base_pdf = [0.0; 256];
                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
                        *dst = clamp_prob(lp.exp(), *min_prob);
                    }
                    core.apply_pdf(&base_pdf, pdf);
                    *valid = true;
                }
                for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
                    *slot = clamp_prob(p, *min_prob).ln();
                }
            }
        }
    }
872
873    fn update(&mut self, symbol: u8) {
874        match self {
875            RateBackendPredictor::Rosa { model, .. } => {
876                model.train_byte(symbol);
877            }
878            RateBackendPredictor::Match { model, .. } => {
879                model.update(symbol);
880            }
881            RateBackendPredictor::SparseMatch { model, .. } => {
882                model.update(symbol);
883            }
884            RateBackendPredictor::Ppmd { model, .. } => {
885                model.update(symbol);
886            }
887            RateBackendPredictor::Ctw { tree, .. } => {
888                for bit_idx in 0..8 {
889                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
890                    tree.update(bit, bit_idx);
891                }
892            }
893            RateBackendPredictor::FacCtw {
894                tree,
895                bits_per_symbol,
896                ..
897            } => {
898                for i in 0..*bits_per_symbol {
899                    let bit = ((symbol >> i) & 1) == 1;
900                    tree.update(bit, i);
901                }
902            }
903            #[cfg(feature = "backend-rwkv")]
904            RateBackendPredictor::Rwkv7 {
905                compressor, primed, ..
906            } => {
907                ensure_rwkv_primed(compressor, primed);
908                compressor
909                    .observe_symbol_from_current_pdf(symbol)
910                    .unwrap_or_else(|e| panic!("rwkv online update failed: {e}"));
911            }
912            #[cfg(feature = "backend-mamba")]
913            RateBackendPredictor::Mamba {
914                compressor,
915                primed,
916                pdf_scratch,
917                ..
918            } => {
919                if !*primed {
920                    let bias = compressor.online_bias_snapshot();
921                    let logits =
922                        compressor
923                            .model
924                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
925                    mambazip::Compressor::logits_to_pdf(
926                        logits,
927                        bias.as_deref(),
928                        &mut compressor.pdf_buffer,
929                    );
930                    *primed = true;
931                }
932                if pdf_scratch.len() != compressor.pdf_buffer.len() {
933                    pdf_scratch.resize(compressor.pdf_buffer.len(), 0.0);
934                }
935                pdf_scratch.copy_from_slice(&compressor.pdf_buffer);
936                compressor
937                    .online_update_from_pdf(symbol, pdf_scratch)
938                    .unwrap_or_else(|e| panic!("mamba online update failed: {e}"));
939                let bias = compressor.online_bias_snapshot();
940                let logits = compressor.model.forward(
941                    &mut compressor.scratch,
942                    symbol as u32,
943                    &mut compressor.state,
944                );
945                mambazip::Compressor::logits_to_pdf(
946                    logits,
947                    bias.as_deref(),
948                    &mut compressor.pdf_buffer,
949                );
950            }
951            RateBackendPredictor::Zpaq { model } => {
952                model.update(symbol);
953            }
954            RateBackendPredictor::Mixture { runtime } => {
955                let _ = runtime.step(symbol);
956            }
957            RateBackendPredictor::Particle { runtime } => {
958                runtime.step(symbol);
959            }
960            RateBackendPredictor::Calibrated {
961                base,
962                core,
963                pdf,
964                valid,
965                ..
966            } => {
967                if !*valid {
968                    let mut base_logps = [0.0; 256];
969                    base.fill_log_probs(&mut base_logps);
970                    let mut base_pdf = [0.0; 256];
971                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
972                        *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
973                    }
974                    core.apply_pdf(&base_pdf, pdf);
975                }
976                core.update(symbol, pdf);
977                base.update(symbol);
978                *valid = false;
979            }
980        }
981    }
982
983    fn log_prob_update(&mut self, symbol: u8) -> f64 {
984        match self {
985            RateBackendPredictor::Rosa { model, min_prob } => {
986                let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
987                model.train_byte(symbol);
988                p.ln()
989            }
990            RateBackendPredictor::Ctw { tree, min_prob } => {
991                ctw_log_prob_update_msb(tree, symbol, *min_prob)
992            }
993            RateBackendPredictor::FacCtw {
994                tree,
995                bits_per_symbol,
996                min_prob,
997            } => ctw_log_prob_update_lsb(tree, symbol, *bits_per_symbol, *min_prob),
998            _ => {
999                let logp = self.log_prob(symbol);
1000                self.update(symbol);
1001                logp
1002            }
1003        }
1004    }
1005
1006    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
1007        self.finish_stream()?;
1008        match self {
1009            RateBackendPredictor::Rosa { model, .. } => {
1010                if let Some(total) = total_symbols {
1011                    let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
1012                    model.reserve_for_stream(reserve);
1013                }
1014                model.build_lm_full_bytes_no_finalize_endpos();
1015                model.reset_conditioning_cursor();
1016                Ok(())
1017            }
1018            RateBackendPredictor::Match { model, .. } => {
1019                model.reset_history();
1020                Ok(())
1021            }
1022            RateBackendPredictor::SparseMatch { model, .. } => {
1023                model.reset_history();
1024                Ok(())
1025            }
1026            RateBackendPredictor::Ppmd { model, .. } => {
1027                model.reset_history();
1028                Ok(())
1029            }
1030            RateBackendPredictor::Ctw { tree, .. } => {
1031                tree.reset_history_only();
1032                Ok(())
1033            }
1034            RateBackendPredictor::FacCtw { tree, .. } => {
1035                tree.reset_history_only();
1036                Ok(())
1037            }
1038            #[cfg(feature = "backend-rwkv")]
1039            RateBackendPredictor::Rwkv7 {
1040                compressor, primed, ..
1041            } => {
1042                compressor.reset_and_prime();
1043                *primed = true;
1044                Ok(())
1045            }
1046            #[cfg(feature = "backend-mamba")]
1047            RateBackendPredictor::Mamba {
1048                compressor, primed, ..
1049            } => {
1050                compressor.reset_and_prime();
1051                *primed = true;
1052                Ok(())
1053            }
1054            RateBackendPredictor::Zpaq { .. } => {
1055                Err("plugin entropy is not supported for zpaq rate backends in 1.1.0".to_string())
1056            }
1057            RateBackendPredictor::Mixture { runtime } => runtime.reset_frozen(total_symbols),
1058            RateBackendPredictor::Particle { runtime } => {
1059                runtime.reset_frozen_state();
1060                Ok(())
1061            }
1062            RateBackendPredictor::Calibrated {
1063                base,
1064                core,
1065                pdf,
1066                valid,
1067                ..
1068            } => {
1069                base.reset_frozen(total_symbols)?;
1070                core.reset_context();
1071                pdf.fill(1.0 / 256.0);
1072                *valid = false;
1073                Ok(())
1074            }
1075        }
1076    }
1077
1078    fn update_frozen(&mut self, symbol: u8) {
1079        match self {
1080            RateBackendPredictor::Rosa { model, .. } => {
1081                model.advance_conditioning_byte(symbol);
1082            }
1083            RateBackendPredictor::Match { model, .. } => {
1084                model.update_history_only(symbol);
1085            }
1086            RateBackendPredictor::SparseMatch { model, .. } => {
1087                model.update_history_only(symbol);
1088            }
1089            RateBackendPredictor::Ppmd { model, .. } => {
1090                model.update_history_only(symbol);
1091            }
1092            RateBackendPredictor::Ctw { tree, .. } => {
1093                let mut bits = [false; 8];
1094                for (bit_idx, slot) in bits.iter_mut().enumerate() {
1095                    *slot = ((symbol >> (7 - bit_idx)) & 1) == 1;
1096                }
1097                tree.update_history(&bits);
1098            }
1099            RateBackendPredictor::FacCtw {
1100                tree,
1101                bits_per_symbol,
1102                ..
1103            } => {
1104                let bits = (*bits_per_symbol).clamp(1, 8);
1105                let mut history_bits = [false; 8];
1106                for (idx, slot) in history_bits.iter_mut().enumerate().take(bits) {
1107                    *slot = ((symbol >> idx) & 1) == 1;
1108                }
1109                tree.update_history(&history_bits[..bits]);
1110            }
1111            #[cfg(feature = "backend-rwkv")]
1112            RateBackendPredictor::Rwkv7 {
1113                compressor, primed, ..
1114            } => {
1115                if !*primed {
1116                    compressor.reset_and_prime();
1117                    *primed = true;
1118                }
1119                compressor.forward_to_internal_pdf(symbol as u32);
1120            }
1121            #[cfg(feature = "backend-mamba")]
1122            RateBackendPredictor::Mamba {
1123                compressor, primed, ..
1124            } => {
1125                if !*primed {
1126                    compressor.reset_and_prime();
1127                    *primed = true;
1128                }
1129                let bias = compressor.online_bias_snapshot();
1130                let logits = compressor.model.forward(
1131                    &mut compressor.scratch,
1132                    symbol as u32,
1133                    &mut compressor.state,
1134                );
1135                mambazip::Compressor::logits_to_pdf(
1136                    logits,
1137                    bias.as_deref(),
1138                    &mut compressor.pdf_buffer,
1139                );
1140            }
1141            RateBackendPredictor::Zpaq { model } => {
1142                model.update(symbol);
1143            }
1144            RateBackendPredictor::Mixture { runtime } => {
1145                runtime.update_frozen(symbol);
1146            }
1147            RateBackendPredictor::Particle { runtime } => {
1148                runtime.update_frozen(symbol);
1149            }
1150            RateBackendPredictor::Calibrated {
1151                base,
1152                core,
1153                pdf,
1154                valid,
1155                ..
1156            } => {
1157                if !*valid {
1158                    let mut base_logps = [0.0; 256];
1159                    base.fill_log_probs(&mut base_logps);
1160                    let mut base_pdf = [0.0; 256];
1161                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
1162                        *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
1163                    }
1164                    core.apply_pdf(&base_pdf, pdf);
1165                    *valid = true;
1166                }
1167                base.update_frozen(symbol);
1168                core.update_context_only(symbol);
1169                *valid = false;
1170            }
1171        }
1172    }
1173}
1174
/// Configuration for a mixture expert.
#[derive(Clone)]
pub struct ExpertConfig {
    /// Human-readable expert identifier.
    pub name: String,
    /// Log prior weight (natural log). Uniform priors can be `0.0`.
    pub log_prior: f64,
    // Factory producing a fresh predictor instance per mixture build; shared
    // via `Arc` so the config stays cheaply cloneable.
    builder: Arc<dyn Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync>,
}
1184
1185impl ExpertConfig {
1186    /// Create a new expert config from a builder closure.
1187    pub fn new(
1188        name: impl Into<String>,
1189        log_prior: f64,
1190        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
1191    ) -> Self {
1192        Self {
1193            name: name.into(),
1194            log_prior,
1195            builder: Arc::new(builder),
1196        }
1197    }
1198
1199    /// Uniform prior helper.
1200    pub fn uniform(
1201        name: impl Into<String>,
1202        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
1203    ) -> Self {
1204        Self::new(name, 0.0, builder)
1205    }
1206
1207    /// Expert from a `RateBackend` configuration. `max_order` applies to ROSA.
1208    pub fn from_rate_backend(
1209        name: Option<String>,
1210        log_prior: f64,
1211        backend: RateBackend,
1212        max_order: i64,
1213    ) -> Self {
1214        let name = name.unwrap_or_else(|| RateBackendPredictor::default_name(&backend, max_order));
1215        Self::new(name, log_prior, move || {
1216            Box::new(RateBackendPredictor::from_backend(
1217                backend.clone(),
1218                max_order,
1219                DEFAULT_MIN_PROB,
1220            ))
1221        })
1222    }
1223
1224    /// ROSA expert (uniform prior).
1225    pub fn rosa(name: impl Into<String>, max_order: i64) -> Self {
1226        let name = name.into();
1227        Self::uniform(name, move || {
1228            Box::new(RateBackendPredictor::from_backend(
1229                RateBackend::RosaPlus,
1230                max_order,
1231                DEFAULT_MIN_PROB,
1232            ))
1233        })
1234    }
1235
1236    /// CTW expert (uniform prior).
1237    pub fn ctw(name: impl Into<String>, depth: usize) -> Self {
1238        let name = name.into();
1239        Self::uniform(name, move || {
1240            Box::new(RateBackendPredictor::from_backend(
1241                RateBackend::Ctw { depth },
1242                -1,
1243                DEFAULT_MIN_PROB,
1244            ))
1245        })
1246    }
1247
1248    /// FAC-CTW expert (uniform prior).
1249    pub fn fac_ctw(name: impl Into<String>, base_depth: usize, encoding_bits: usize) -> Self {
1250        let name = name.into();
1251        Self::uniform(name, move || {
1252            Box::new(RateBackendPredictor::from_backend(
1253                RateBackend::FacCtw {
1254                    base_depth,
1255                    num_percept_bits: encoding_bits,
1256                    encoding_bits,
1257                },
1258                -1,
1259                DEFAULT_MIN_PROB,
1260            ))
1261        })
1262    }
1263
1264    /// RWKV-7 expert (uniform prior).
1265    #[cfg(feature = "backend-rwkv")]
1266    pub fn rwkv(name: impl Into<String>, model: Arc<rwkvzip::Model>) -> Self {
1267        let name = name.into();
1268        Self::uniform(name, move || {
1269            Box::new(RateBackendPredictor::from_backend(
1270                RateBackend::Rwkv7 {
1271                    model: model.clone(),
1272                },
1273                -1,
1274                DEFAULT_MIN_PROB,
1275            ))
1276        })
1277    }
1278
1279    /// Mamba expert (uniform prior).
1280    #[cfg(feature = "backend-mamba")]
1281    pub fn mamba(name: impl Into<String>, model: Arc<mambazip::Model>) -> Self {
1282        let name = name.into();
1283        Self::uniform(name, move || {
1284            Box::new(RateBackendPredictor::from_backend(
1285                RateBackend::Mamba {
1286                    model: model.clone(),
1287                },
1288                -1,
1289                DEFAULT_MIN_PROB,
1290            ))
1291        })
1292    }
1293
1294    /// ZPAQ expert (uniform prior).
1295    pub fn zpaq(name: impl Into<String>, method: impl Into<String>) -> Self {
1296        let name = name.into();
1297        let method = method.into();
1298        Self::uniform(name, move || {
1299            Box::new(RateBackendPredictor::from_backend(
1300                RateBackend::Zpaq {
1301                    method: method.clone(),
1302                },
1303                -1,
1304                DEFAULT_MIN_PROB,
1305            ))
1306        })
1307    }
1308
1309    /// Expert name.
1310    pub fn name(&self) -> &str {
1311        &self.name
1312    }
1313
1314    /// Log prior weight (unnormalized).
1315    pub fn log_prior(&self) -> f64 {
1316        self.log_prior
1317    }
1318
1319    /// Build a fresh predictor instance for evaluation or analysis.
1320    pub fn build_predictor(&self) -> Box<dyn OnlineBytePredictor> {
1321        (self.builder)()
1322    }
1323
1324    fn build(&self) -> ExpertState {
1325        ExpertState {
1326            name: self.name.clone(),
1327            log_weight: self.log_prior,
1328            log_prior: self.log_prior,
1329            predictor: (self.builder)(),
1330            cum_log_loss: 0.0,
1331        }
1332    }
1333}
1334
#[derive(Clone)]
struct ExpertState {
    // Display name copied from the config.
    name: String,
    // Current unnormalized log posterior weight (mixtures normalize it).
    log_weight: f64,
    // Configured log prior, as given (normalization happens on `log_weight`).
    log_prior: f64,
    // Live predictor instance built from the config's builder.
    predictor: Box<dyn OnlineBytePredictor>,
    // Cumulative log-loss in nats (sum of -ln p over observed symbols).
    cum_log_loss: f64,
}
1343
impl ExpertState {
    // Thin delegation layer: every method forwards to the boxed predictor.
    // `#[inline]` removes the wrapper itself; the inner call remains dynamic
    // dispatch through the trait object.
    #[inline]
    fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.begin_stream(total_symbols)
    }

    #[inline]
    fn finish_stream(&mut self) -> Result<(), String> {
        self.predictor.finish_stream()
    }

    #[inline]
    fn log_prob(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob(symbol)
    }

    #[inline]
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob_update(symbol)
    }

    #[inline]
    fn update(&mut self, symbol: u8) {
        self.predictor.update(symbol);
    }

    #[inline]
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.reset_frozen(total_symbols)
    }

    #[inline]
    fn update_frozen(&mut self, symbol: u8) {
        self.predictor.update_frozen(symbol);
    }
}
1380
/// Exponential-weights Bayes mixture (log-loss Hedge).
#[derive(Clone)]
pub struct BayesMixture {
    // Experts with their current log-weights, predictors, and losses.
    experts: Vec<ExpertState>,
    // Per-step buffer of each expert's ln p(symbol).
    scratch_logps: Vec<f64>,
    // Per-step buffer of log-weight + log-prob terms fed to logsumexp.
    scratch_mix: Vec<f64>,
    // Symbol the cached prediction below refers to.
    cached_symbol: u8,
    // Mixture log-prob cached by `predict_log_prob` for reuse in `step`.
    cached_log_mix: f64,
    // True when the cache and scratch buffers are valid for `cached_symbol`.
    cache_valid: bool,
    // Total mixture log-loss accumulated across `step` calls (nats).
    total_log_loss: f64,
}
1392
impl BayesMixture {
    /// Construct a normalized Bayes mixture from expert configs.
    ///
    /// Prior log-weights are normalized in log space so the initial weights
    /// sum to one; scratch buffers are sized once up front.
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the mixture for `symbol`, then update.
    ///
    /// If `predict_log_prob(symbol)` was just called for the same symbol, the
    /// cached mixture value and scratch buffers are reused so each expert is
    /// only queried once per symbol.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Cached path: scratch_logps was filled by predict_log_prob, so
            // only the loss accounting and expert state updates remain.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
                // Log-loss accumulates as -ln p.
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            logsumexp(&self.scratch_mix)
        };
        // Bayes posterior update in log space: w_i <- w_i * p_i / p_mix,
        // which keeps the weights normalized.
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.log_weight = expert.log_weight + self.scratch_logps[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    // Predict without updating; caches the result so a following `step` with
    // the same symbol skips a second pass over the experts.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    // Full 256-symbol mixture distribution (natural-log) without updating.
    // Weights are re-normalized here so the output is a proper distribution.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let norm = logsumexp_weights(&self.experts);
        let mut row = [0.0f64; 256];
        for expert in &mut self.experts {
            expert.predictor.fill_log_probs(&mut row);
            let lw = expert.log_weight - norm;
            for b in 0..256 {
                // Accumulate sum_i w_i * p_i(b) in log space.
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and posterior mass of the most likely expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert cumulative log-losses (nats) and names.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    // Reset every expert for frozen replay; weights are kept as-is, only the
    // prediction cache and accumulated mixture loss are cleared.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    // Advance every expert's history without learning; invalidates the cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1543
/// Exponential-weights Bayes mixture with exponential forgetting on weights.
///
/// This is a non-stationary control: weights are discounted each step by `decay`.
#[derive(Clone)]
pub struct FadingBayesMixture {
    // Experts with their current log-weights, predictors, and losses.
    experts: Vec<ExpertState>,
    // Forgetting factor in [0, 1]; log-weights are scaled by it each step.
    decay: f64,
    // Per-step buffer of each expert's ln p(symbol).
    scratch_logps: Vec<f64>,
    // Per-step buffer of decayed log-weight + log-prob terms.
    scratch_mix: Vec<f64>,
    // Symbol the cached prediction below refers to.
    cached_symbol: u8,
    // Mixture log-prob cached by `predict_log_prob` for reuse in `step`.
    cached_log_mix: f64,
    // True when the cache and scratch buffers are valid for `cached_symbol`.
    cache_valid: bool,
    // Total mixture log-loss accumulated across `step` calls (nats).
    total_log_loss: f64,
}
1558
impl FadingBayesMixture {
    /// Construct a fading Bayes mixture with decay in `[0, 1]`.
    ///
    /// Priors are normalized in log space; `decay` is clamped into `[0, 1]`.
    pub fn new(configs: &[ExpertConfig], decay: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let decay = decay.clamp(0.0, 1.0);
        Self {
            experts,
            decay,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the fading mixture for `symbol`, then update.
    ///
    /// Like `BayesMixture::step`, but each expert's log-weight is discounted
    /// by `decay` both when mixing and in the posterior update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Cached path: scratch_logps was filled by predict_log_prob.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                // Exponential forgetting: discount the weight before mixing.
                let decayed = self.decay * expert.log_weight;
                self.scratch_mix[i] = decayed + self.scratch_logps[i];
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            logsumexp(&self.scratch_mix)
        };
        for (i, expert) in self.experts.iter_mut().enumerate() {
            // Posterior update on the decayed weight keeps normalization.
            let decayed = self.decay * expert.log_weight;
            expert.log_weight = decayed + self.scratch_logps[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    // Predict without updating; caches the result for reuse by `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = self.decay * expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    // Full 256-symbol mixture distribution (natural-log) without updating,
    // using the decayed (and re-normalized) weights.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let mut decayed = Vec::with_capacity(self.experts.len());
        for expert in &self.experts {
            decayed.push(self.decay * expert.log_weight);
        }
        let norm = logsumexp(&decayed);
        let mut row = [0.0f64; 256];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let lw = decayed[i] - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Posterior weights (normalized) over experts.
    // Uses the stored (post-update) weights, i.e. before the next step's decay.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert (non-discounted loss).
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    // Reset every expert for frozen replay; clears the cache and total loss.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    // Advance every expert's history without learning; invalidates the cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1694
/// Switching mixture: allows occasional switches between experts.
#[derive(Clone)]
pub struct SwitchingMixture {
    // Per-expert predictor state and loss bookkeeping.
    experts: Vec<ExpertState>,
    // Normalized log prior over experts (targets of a switch re-draw).
    log_prior: Vec<f64>,
    // ln(alpha): log-probability of switching at each step.
    log_alpha: f64,
    // ln(1 - alpha): log-probability of staying with the current posterior.
    log_1m_alpha: f64,
    // Scratch: per-expert log P(symbol) for the step in progress.
    scratch_logps: Vec<f64>,
    // Scratch: per-expert transition-weight + likelihood terms (joint, in log-space).
    scratch_switch: Vec<f64>,
    // Symbol scored by the most recent `predict_log_prob` call.
    cached_symbol: u8,
    // Mixture log-probability cached by `predict_log_prob`.
    cached_log_mix: f64,
    // Whether the cached prediction (and scratch contents) are reusable by `step`.
    cache_valid: bool,
    // Cumulative mixture log-loss in nats.
    total_log_loss: f64,
}
1709
impl SwitchingMixture {
    /// Construct a switching mixture with switch probability `alpha`.
    ///
    /// Initial weights and the stored prior are both normalized in log-space.
    pub fn new(configs: &[ExpertConfig], alpha: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let log_prior: Vec<f64> = experts.iter().map(|e| e.log_prior - norm).collect();
        // Clamp alpha strictly inside (0, 1) so both logarithms below are finite.
        let alpha = alpha.clamp(1e-12, 1.0 - 1e-12);
        Self {
            experts,
            log_prior,
            log_alpha: alpha.ln(),
            log_1m_alpha: (1.0 - alpha).ln(),
            scratch_logps: vec![0.0; configs.len()],
            scratch_switch: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the switching mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // `predict_log_prob` already filled `scratch_logps`/`scratch_switch`
            // for this symbol; only charge losses and commit the deferred
            // per-expert model updates.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            // Transition: stay with mass (1 - alpha), or switch to a fresh
            // draw from the prior with mass alpha; then multiply in the
            // likelihood (all in log-space).
            for i in 0..self.experts.len() {
                let log_switch = logsumexp2(
                    self.log_1m_alpha + self.experts[i].log_weight,
                    self.log_alpha + self.log_prior[i],
                );
                self.scratch_switch[i] = self.scratch_logps[i] + log_switch;
            }
            logsumexp(&self.scratch_switch)
        };
        // Posterior update: renormalize the joint terms by the mixture evidence.
        for i in 0..self.experts.len() {
            let expert = &mut self.experts[i];
            expert.log_weight = self.scratch_switch[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Score `symbol` without updating experts; caches the result so a
    /// following `step(symbol)` can reuse it (see the cached branch above).
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for i in 0..self.experts.len() {
            let lp = self.experts[i].log_prob(symbol);
            self.scratch_logps[i] = lp;
            let log_switch = logsumexp2(
                self.log_1m_alpha + self.experts[i].log_weight,
                self.log_alpha + self.log_prior[i],
            );
            self.scratch_switch[i] = lp + log_switch;
        }
        let log_mix = logsumexp(&self.scratch_switch);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Full 256-symbol log-probability row of the mixture at the current state.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        // Deliberately a fresh Vec, not `scratch_switch`: a pending
        // `predict_log_prob` cache must not be clobbered here.
        let mut log_switch = vec![0.0f64; self.experts.len()];
        for (i, expert) in self.experts.iter().enumerate() {
            log_switch[i] = logsumexp2(
                self.log_1m_alpha + expert.log_weight,
                self.log_alpha + self.log_prior[i],
            );
        }
        let norm = logsumexp(&log_switch);
        let mut row = [0.0f64; 256];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let lw = log_switch[i] - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and posterior mass of the most likely expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert cumulative log-losses (nats) and names.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Reset every expert to its frozen snapshot and clear mixture bookkeeping.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feed `symbol` to every frozen expert and invalidate the prediction cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1884
/// MDL-style selector: predicts with the current best expert (by cumulative loss).
#[derive(Clone)]
pub struct MdlSelector {
    // Per-expert predictor state and cumulative losses.
    experts: Vec<ExpertState>,
    // Scratch: per-expert log P(symbol) for the step in progress.
    scratch_logps: Vec<f64>,
    // Cumulative selector log-loss in nats.
    total_log_loss: f64,
    // Index of the expert used by the most recent `step`.
    last_best: usize,
    // Symbol scored by the most recent `predict_log_prob` call.
    cached_symbol: u8,
    // Best-expert index chosen by the most recent `predict_log_prob` call.
    cached_best_idx: usize,
    // That best expert's log P(cached_symbol).
    cached_best_logp: f64,
    // Whether the cached prediction is still usable by `step`.
    cache_valid: bool,
}
1897
/// Bytewise neural mixer inspired by fx2-cmix online adaptation.
///
/// This model is a context-conditioned two-stage gating network trained online
/// from per-symbol expert likelihoods:
/// 1) context-local first-stage expert gates,
/// 2) context-local second-stage meta-gate over stage-1 outputs,
/// 3) per-symbol SGD updates with optional tiny-error skip.
#[derive(Clone)]
pub struct NeuralMixture {
    // Per-expert predictor state and loss bookkeeping.
    experts: Vec<ExpertState>,
    // Online gating network producing per-expert mixing weights.
    neural: NeuralMixCore,
    // Text-context tracker whose state conditions the gating network.
    analyzer: TextContextAnalyzer,
    // Probability floor used to avoid log(0); defaults to DEFAULT_MIN_PROB.
    min_prob: f64,
    // Scratch: per-expert log P(symbol) for the step in progress.
    scratch_expert_logps: Vec<f64>,
    // Scratch: per-expert mixing weights from the gating network.
    scratch_mix_weights: Vec<f64>,
    // Single-symbol cache: valid when the last evaluation matches
    // (eval_cache_history, eval_cache_symbol).
    eval_cache_valid: bool,
    // Full-row cache: valid when eval_cache_mix_logps/eval_cache_expert_logps
    // hold a complete 256-symbol evaluation for eval_cache_history.
    eval_cache_full_valid: bool,
    // History state the caches were computed under.
    eval_cache_history: NeuralHistoryState,
    // Symbol and mixture log-probability of the single-symbol cache.
    eval_cache_symbol: u8,
    eval_cache_logp: f64,
    // Cached mixture log-probability row (full-row cache).
    eval_cache_mix_logps: [f64; 256],
    // Cached per-expert log-probability rows (full-row cache).
    eval_cache_expert_logps: Vec<[f64; 256]>,
    // Cumulative mixture log-loss in nats.
    total_log_loss: f64,
}
1922
impl NeuralMixture {
    /// Construct a neural mixture. `learning_rate` is taken from `MixtureSpec.alpha`.
    pub fn new(configs: &[ExpertConfig], learning_rate: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let n = experts.len();

        // Seed the gating network with the normalized expert prior.
        let mut prior_weights = vec![0.0; n];
        if n > 0 {
            let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
            let norm = logsumexp(&log_priors);
            // NOTE(review): `iter_mut` is never used for mutation here; `iter`
            // would suffice (and would let `experts` be non-mut).
            for (i, e) in experts.iter_mut().enumerate() {
                let p = (e.log_prior - norm).exp();
                prior_weights[i] = p;
            }
        }

        // Non-finite learning rates fall back to 0.03; otherwise the magnitude
        // is clamped into [1e-6, 1].
        let base_lr = if learning_rate.is_finite() {
            learning_rate.abs().clamp(1e-6, 1.0)
        } else {
            0.03
        };
        // Scaled up 25x for the neural core (stage-2 uses the full rate,
        // stage-1 half of it — see the NeuralMixCore::new call below).
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        let mut neural =
            NeuralMixCore::new(n, &prior_weights, effective_lr * 0.5, effective_lr, 1e-5);
        neural.set_context_state(analyzer.state());
        let eval_cache_history = neural.history_state();

        Self {
            experts,
            neural,
            analyzer,
            min_prob: DEFAULT_MIN_PROB,
            scratch_expert_logps: vec![0.0; n],
            scratch_mix_weights: vec![0.0; n],
            eval_cache_valid: false,
            eval_cache_full_valid: false,
            eval_cache_history,
            eval_cache_symbol: 0,
            eval_cache_logp: f64::NEG_INFINITY,
            eval_cache_mix_logps: [f64::NEG_INFINITY; 256],
            eval_cache_expert_logps: vec![[f64::NEG_INFINITY; 256]; n],
            total_log_loss: 0.0,
        }
    }

    /// Drop both the single-symbol and the full-row evaluation caches.
    #[inline]
    fn invalidate_eval_cache(&mut self) {
        self.eval_cache_valid = false;
        self.eval_cache_full_valid = false;
    }

    /// Push the analyzer's current history state into the neural core and
    /// invalidate caches computed under a different history. Returns the
    /// current history state.
    fn sync_history_state(&mut self) -> NeuralHistoryState {
        let history = self.analyzer.state();
        if self.neural.history_state() != history {
            self.neural.set_context_state(history);
        }
        if self.eval_cache_history != history {
            self.invalidate_eval_cache();
            self.eval_cache_history = history;
        }
        history
    }

    /// Compute and cache the full 256-symbol mixture row (and the per-expert
    /// rows) for the current history state, if not already cached.
    fn ensure_full_evaluation(&mut self) {
        self.sync_history_state();
        if self.eval_cache_full_valid {
            return;
        }

        self.neural.evaluate_expert_weights();
        self.scratch_mix_weights
            .copy_from_slice(self.neural.expert_weights());
        // Mix in probability space: weighted sum of floored expert pdfs.
        let mut mix_pdf = [0.0f64; 256];
        for i in 0..self.experts.len() {
            let row = &mut self.eval_cache_expert_logps[i];
            self.experts[i].predictor.fill_log_probs(row);
            let w = self.scratch_mix_weights[i];
            for (dst, &lp) in mix_pdf.iter_mut().zip(row.iter()) {
                *dst += w * clamp_prob(lp.exp(), self.min_prob);
            }
        }

        // Renormalize; degenerate sums fall back to the uniform distribution.
        let sum: f64 = mix_pdf.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let uniform = (1.0f64 / 256.0).ln();
            self.eval_cache_mix_logps.fill(uniform);
        } else {
            let inv = 1.0 / sum;
            for (dst, &p_raw) in self.eval_cache_mix_logps.iter_mut().zip(mix_pdf.iter()) {
                let p = clamp_unit_prob(p_raw * inv, self.min_prob);
                *dst = p.ln();
            }
        }

        self.eval_cache_full_valid = true;
    }

    /// Mixture log-probability for a single symbol, served from the caches
    /// when possible. Also records per-expert log-probs into scratch for a
    /// later `step` to reuse.
    fn evaluate_symbol(&mut self, symbol: u8) -> f64 {
        let history = self.sync_history_state();
        // Fast path 1: exact (history, symbol) hit.
        if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            return self.eval_cache_logp;
        }

        // Fast path 2: full-row cache hit for this history — slice out the
        // symbol column and promote it to the single-symbol cache.
        if self.eval_cache_full_valid && self.eval_cache_history == history {
            for (dst, row) in self
                .scratch_expert_logps
                .iter_mut()
                .zip(self.eval_cache_expert_logps.iter())
            {
                *dst = row[symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            self.eval_cache_valid = true;
            self.eval_cache_symbol = symbol;
            self.eval_cache_logp = logp;
            return logp;
        }

        // Slow path: score just this symbol through the gating network.
        let expert_count = self.experts.len();
        for i in 0..expert_count {
            self.scratch_expert_logps[i] = self.experts[i].log_prob(symbol);
        }
        let p = self
            .neural
            .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
        let logp = clamp_unit_prob(p, self.min_prob).ln();
        self.eval_cache_valid = true;
        self.eval_cache_history = history;
        self.eval_cache_symbol = symbol;
        self.eval_cache_logp = logp;
        logp
    }

    /// Non-mutating prediction; a single expert bypasses the mixer entirely.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        if self.experts.len() == 1 {
            return self.experts[0].log_prob(symbol);
        }
        self.evaluate_symbol(symbol)
    }

    /// Full 256-symbol mixture log-probability row at the current state.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        if self.experts.len() == 1 {
            self.experts[0].predictor.fill_log_probs(out);
            return;
        }
        self.ensure_full_evaluation();
        out.copy_from_slice(&self.eval_cache_mix_logps);
    }

    /// Log-probability (natural log) of the neural mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }

        // Single expert: no gating, just pass through and keep context fresh.
        if self.experts.len() == 1 {
            let expert = &mut self.experts[0];
            let logp = expert.log_prob_update(symbol);
            expert.cum_log_loss -= logp;
            self.total_log_loss -= logp;
            self.analyzer.update(symbol);
            self.neural.set_context_state(self.analyzer.state());
            self.invalidate_eval_cache();
            return logp;
        }

        let history = self.sync_history_state();
        let logp = if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            // Single-symbol cache hit: scratch_expert_logps was filled by the
            // evaluation; commit the deferred expert updates.
            let logp = self.eval_cache_logp;
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else if self.eval_cache_full_valid && self.eval_cache_history == history {
            // Full-row cache hit: read the symbol column from the cached rows.
            for i in 0..self.experts.len() {
                self.scratch_expert_logps[i] = self.eval_cache_expert_logps[i][symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else {
            // No cache: predict-and-update each expert, then mix.
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                self.scratch_expert_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_expert_logps[i];
            }
            let p = self
                .neural
                .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
            clamp_unit_prob(p, self.min_prob).ln()
        };
        // Train the gating network on this symbol's expert likelihoods, then
        // advance the context and drop caches keyed on the old history.
        self.neural
            .update_weights_symbol(&self.scratch_expert_logps, self.min_prob);
        self.total_log_loss -= logp;
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        logp
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Reset experts to frozen snapshots and restart the context analyzer.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.analyzer = TextContextAnalyzer::new();
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        self.eval_cache_history = self.neural.history_state();
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feed `symbol` to frozen experts and advance the context state.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        self.eval_cache_history = self.neural.history_state();
    }
}
2170
2171impl MdlSelector {
2172    /// Construct an MDL-style expert selector.
2173    pub fn new(configs: &[ExpertConfig]) -> Self {
2174        let experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
2175        let last_best = 0usize;
2176        Self {
2177            experts,
2178            scratch_logps: vec![0.0; configs.len()],
2179            total_log_loss: 0.0,
2180            last_best,
2181            cached_symbol: 0,
2182            cached_best_idx: 0,
2183            cached_best_logp: f64::NEG_INFINITY,
2184            cache_valid: false,
2185        }
2186    }
2187
2188    /// Log-probability (natural log) of the MDL selector for `symbol`, then update.
2189    pub fn step(&mut self, symbol: u8) -> f64 {
2190        if self.experts.is_empty() {
2191            return f64::NEG_INFINITY;
2192        }
2193        let used_cache = self.cache_valid && self.cached_symbol == symbol;
2194        let best_idx = if used_cache {
2195            self.scratch_logps[self.cached_best_idx] = self.cached_best_logp;
2196            for (i, expert) in self.experts.iter_mut().enumerate() {
2197                if i == self.cached_best_idx {
2198                    continue;
2199                }
2200                self.scratch_logps[i] = expert.log_prob(symbol);
2201            }
2202            self.cached_best_idx
2203        } else {
2204            for (i, expert) in self.experts.iter_mut().enumerate() {
2205                self.scratch_logps[i] = expert.log_prob_update(symbol);
2206            }
2207            let mut best_idx = 0usize;
2208            let mut best_loss = f64::INFINITY;
2209            for (i, expert) in self.experts.iter().enumerate() {
2210                if expert.cum_log_loss < best_loss {
2211                    best_loss = expert.cum_log_loss;
2212                    best_idx = i;
2213                }
2214            }
2215            best_idx
2216        };
2217        let logp = self.scratch_logps[best_idx];
2218        self.cache_valid = false;
2219        for (i, expert) in self.experts.iter_mut().enumerate() {
2220            expert.cum_log_loss -= self.scratch_logps[i];
2221            if used_cache {
2222                expert.update(symbol);
2223            }
2224        }
2225        self.total_log_loss -= logp;
2226        self.last_best = best_idx;
2227        logp
2228    }
2229
2230    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
2231        if self.experts.is_empty() {
2232            return f64::NEG_INFINITY;
2233        }
2234        let mut best_idx = 0usize;
2235        let mut best_loss = f64::INFINITY;
2236        for (i, expert) in self.experts.iter().enumerate() {
2237            if expert.cum_log_loss < best_loss {
2238                best_loss = expert.cum_log_loss;
2239                best_idx = i;
2240            }
2241        }
2242        let logp = self.experts[best_idx].log_prob(symbol);
2243        self.cached_symbol = symbol;
2244        self.cached_best_idx = best_idx;
2245        self.cached_best_logp = logp;
2246        self.cache_valid = true;
2247        logp
2248    }
2249
2250    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
2251        if self.experts.is_empty() {
2252            out.fill(f64::NEG_INFINITY);
2253            return;
2254        }
2255        let mut best_idx = 0usize;
2256        let mut best_loss = f64::INFINITY;
2257        for (i, expert) in self.experts.iter().enumerate() {
2258            if expert.cum_log_loss < best_loss {
2259                best_loss = expert.cum_log_loss;
2260                best_idx = i;
2261            }
2262        }
2263        self.experts[best_idx].predictor.fill_log_probs(out);
2264    }
2265
2266    /// Index of the current best expert.
2267    pub fn best_index(&self) -> usize {
2268        self.last_best
2269    }
2270
2271    /// Index and log-loss (nats) of the current best expert.
2272    pub fn min_expert_log_loss(&self) -> (usize, f64) {
2273        let mut best_idx = 0usize;
2274        let mut best_loss = f64::INFINITY;
2275        for (i, e) in self.experts.iter().enumerate() {
2276            if e.cum_log_loss < best_loss {
2277                best_loss = e.cum_log_loss;
2278                best_idx = i;
2279            }
2280        }
2281        (best_idx, best_loss)
2282    }
2283
2284    /// Total log-loss of the selector so far (nats).
2285    pub fn total_log_loss(&self) -> f64 {
2286        self.total_log_loss
2287    }
2288
2289    /// Expert cumulative log-losses (nats) and names.
2290    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
2291        self.experts
2292            .iter()
2293            .map(|e| (e.name.clone(), e.cum_log_loss))
2294            .collect()
2295    }
2296
2297    /// Expert names in order.
2298    pub fn expert_names(&self) -> Vec<String> {
2299        self.experts.iter().map(|e| e.name.clone()).collect()
2300    }
2301
2302    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2303        for expert in &mut self.experts {
2304            expert.reset_frozen(total_symbols)?;
2305        }
2306        self.cache_valid = false;
2307        self.total_log_loss = 0.0;
2308        Ok(())
2309    }
2310
2311    fn update_frozen(&mut self, symbol: u8) {
2312        for expert in &mut self.experts {
2313            expert.update_frozen(symbol);
2314        }
2315        self.cache_valid = false;
2316    }
2317}
2318
2319// =============================================================================
2320// Mixture Runtime Helper (for RateBackend::Mixture)
2321// =============================================================================
2322
/// Runtime wrapper over concrete mixture strategies.
///
/// Each variant owns a fully constructed mixture; methods dispatch to the
/// underlying strategy.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum MixtureRuntime {
    /// Bayes mixture.
    Bayes(BayesMixture),
    /// Fading Bayes mixture.
    Fading(FadingBayesMixture),
    /// Switching mixture.
    Switching(SwitchingMixture),
    /// MDL selector.
    Mdl(MdlSelector),
    /// Bytewise neural logistic mixer.
    Neural(NeuralMixture),
}
2338
impl MixtureRuntime {
    /// Forward `begin_stream` to every expert of the wrapped mixture.
    pub(crate) fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Fading(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Switching(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Mdl(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Neural(m) => begin_expert_stream(&mut m.experts, total_symbols),
        }
    }

    /// Forward `finish_stream` to every expert of the wrapped mixture.
    pub(crate) fn finish_stream(&mut self) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Fading(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Switching(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Mdl(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Neural(m) => finish_expert_stream(&mut m.experts),
        }
    }

    /// Reset the wrapped mixture to its frozen snapshot.
    pub(crate) fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Fading(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Switching(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Mdl(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Neural(m) => m.reset_frozen(total_symbols),
        }
    }

    /// Non-mutating log-probability (nats) for `symbol` at current state.
    pub(crate) fn peek_log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Fading(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Switching(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Mdl(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Neural(m) => m.predict_log_prob(symbol),
        }
    }

    /// Step the mixture and return log-probability (nats).
    pub(crate) fn step(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.step(symbol),
            MixtureRuntime::Fading(m) => m.step(symbol),
            MixtureRuntime::Switching(m) => m.step(symbol),
            MixtureRuntime::Mdl(m) => m.step(symbol),
            MixtureRuntime::Neural(m) => m.step(symbol),
        }
    }

    /// Feed `symbol` to the frozen mixture (no weight adaptation accounting).
    pub(crate) fn update_frozen(&mut self, symbol: u8) {
        match self {
            MixtureRuntime::Bayes(m) => m.update_frozen(symbol),
            MixtureRuntime::Fading(m) => m.update_frozen(symbol),
            MixtureRuntime::Switching(m) => m.update_frozen(symbol),
            MixtureRuntime::Mdl(m) => m.update_frozen(symbol),
            MixtureRuntime::Neural(m) => m.update_frozen(symbol),
        }
    }

    /// Fill `out` with the wrapped mixture's 256-symbol log-probability row.
    pub(crate) fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        match self {
            MixtureRuntime::Bayes(m) => m.fill_log_probs(out),
            MixtureRuntime::Fading(m) => m.fill_log_probs(out),
            MixtureRuntime::Switching(m) => m.fill_log_probs(out),
            MixtureRuntime::Mdl(m) => m.fill_log_probs(out),
            MixtureRuntime::Neural(m) => m.fill_log_probs(out),
        }
    }
}
2412
2413fn begin_expert_stream(
2414    experts: &mut [ExpertState],
2415    total_symbols: Option<u64>,
2416) -> Result<(), String> {
2417    for expert in experts {
2418        expert.begin_stream(total_symbols)?;
2419    }
2420    Ok(())
2421}
2422
2423fn finish_expert_stream(experts: &mut [ExpertState]) -> Result<(), String> {
2424    for expert in experts {
2425        expert.finish_stream()?;
2426    }
2427    Ok(())
2428}
2429
2430pub(crate) fn build_mixture_runtime(
2431    spec: &MixtureSpec,
2432    experts: &[ExpertConfig],
2433) -> Result<MixtureRuntime, String> {
2434    if experts.is_empty() {
2435        return Err("mixture spec must include at least one expert".to_string());
2436    }
2437    match spec.kind {
2438        MixtureKind::Bayes => Ok(MixtureRuntime::Bayes(BayesMixture::new(experts))),
2439        MixtureKind::FadingBayes => {
2440            let decay = spec
2441                .decay
2442                .ok_or_else(|| "fading Bayes mixture requires decay".to_string())?;
2443            Ok(MixtureRuntime::Fading(FadingBayesMixture::new(
2444                experts, decay,
2445            )))
2446        }
2447        MixtureKind::Switching => Ok(MixtureRuntime::Switching(SwitchingMixture::new(
2448            experts, spec.alpha,
2449        ))),
2450        MixtureKind::Mdl => Ok(MixtureRuntime::Mdl(MdlSelector::new(experts))),
2451        MixtureKind::Neural => Ok(MixtureRuntime::Neural(NeuralMixture::new(
2452            experts, spec.alpha,
2453        ))),
2454    }
2455}
2456
2457#[cfg(test)]
2458mod tests {
2459    use super::*;
2460    use std::sync::{
2461        Arc,
2462        atomic::{AtomicU64, AtomicUsize, Ordering},
2463    };
2464
    // Test predictor that is certain about exactly one byte.
    #[derive(Clone)]
    struct AlwaysPredict {
        // The single byte this predictor assigns probability 1.
        byte: u8,
    }
2469
2470    impl OnlineBytePredictor for AlwaysPredict {
2471        fn log_prob(&mut self, symbol: u8) -> f64 {
2472            if symbol == self.byte {
2473                0.0
2474            } else {
2475                f64::NEG_INFINITY
2476            }
2477        }
2478
2479        fn update(&mut self, _symbol: u8) {}
2480    }
2481
2482    #[test]
2483    fn bayes_mixture_prefers_correct_expert() {
2484        let configs = vec![
2485            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
2486            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
2487        ];
2488        let mut mix = BayesMixture::new(&configs);
2489        for _ in 0..10 {
2490            mix.step(0);
2491        }
2492        let post = mix.posterior();
2493        assert!(post[0] > 0.999);
2494        assert!(post[1] < 1e-6);
2495    }
2496
2497    fn counting_cfg(name: &'static str, calls: Arc<AtomicUsize>) -> ExpertConfig {
2498        ExpertConfig::uniform(name, move || {
2499            Box::new(CountingPredict {
2500                calls: calls.clone(),
2501            })
2502        })
2503    }
2504
2505    #[test]
2506    fn bayes_predict_then_step_reuses_cached_log_probs() {
2507        let c0 = Arc::new(AtomicUsize::new(0));
2508        let c1 = Arc::new(AtomicUsize::new(0));
2509        let mut mix = BayesMixture::new(&[
2510            counting_cfg("c0", c0.clone()),
2511            counting_cfg("c1", c1.clone()),
2512        ]);
2513        let _ = mix.predict_log_prob(0);
2514        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2515        assert_eq!(after_predict, 2);
2516        let _ = mix.step(0);
2517        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2518        assert_eq!(after_step, after_predict);
2519    }
2520
2521    #[test]
2522    fn fading_predict_then_step_reuses_cached_log_probs() {
2523        let c0 = Arc::new(AtomicUsize::new(0));
2524        let c1 = Arc::new(AtomicUsize::new(0));
2525        let mut mix = FadingBayesMixture::new(
2526            &[
2527                counting_cfg("c0", c0.clone()),
2528                counting_cfg("c1", c1.clone()),
2529            ],
2530            0.95,
2531        );
2532        let _ = mix.predict_log_prob(0);
2533        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2534        assert_eq!(after_predict, 2);
2535        let _ = mix.step(0);
2536        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2537        assert_eq!(after_step, after_predict);
2538    }
2539
2540    #[test]
2541    fn switching_predict_then_step_reuses_cached_log_probs() {
2542        let c0 = Arc::new(AtomicUsize::new(0));
2543        let c1 = Arc::new(AtomicUsize::new(0));
2544        let mut mix = SwitchingMixture::new(
2545            &[
2546                counting_cfg("c0", c0.clone()),
2547                counting_cfg("c1", c1.clone()),
2548            ],
2549            0.05,
2550        );
2551        let _ = mix.predict_log_prob(0);
2552        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2553        assert_eq!(after_predict, 2);
2554        let _ = mix.step(0);
2555        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2556        assert_eq!(after_step, after_predict);
2557    }
2558
2559    #[test]
2560    fn mdl_predict_then_step_reuses_best_expert_log_prob() {
2561        let c0 = Arc::new(AtomicUsize::new(0));
2562        let c1 = Arc::new(AtomicUsize::new(0));
2563        let mut mdl = MdlSelector::new(&[
2564            counting_cfg("c0", c0.clone()),
2565            counting_cfg("c1", c1.clone()),
2566        ]);
2567        let _ = mdl.predict_log_prob(0);
2568        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2569        assert_eq!(after_predict, 1);
2570        let _ = mdl.step(0);
2571        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2572        assert_eq!(after_step, 2);
2573    }
2574
2575    #[test]
2576    fn neural_mixture_adapts_to_correct_symbol() {
2577        let configs = vec![
2578            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
2579            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
2580        ];
2581        let mut mix = NeuralMixture::new(&configs, 0.05);
2582
2583        let mut early = 0.0;
2584        let mut late = 0.0;
2585        for t in 0..200 {
2586            let lp = mix.step(0);
2587            if t < 20 {
2588                early -= lp;
2589            }
2590            if t >= 180 {
2591                late -= lp;
2592            }
2593        }
2594
2595        let early_avg = early / 20.0;
2596        let late_avg = late / 20.0;
2597        assert!(
2598            late_avg < early_avg,
2599            "late_avg={late_avg} early_avg={early_avg}"
2600        );
2601        assert!(late_avg < 0.35, "late_avg={late_avg}");
2602    }
2603
2604    #[derive(Clone)]
2605    struct CountingPredict {
2606        calls: Arc<AtomicUsize>,
2607    }
2608
2609    impl OnlineBytePredictor for CountingPredict {
2610        fn log_prob(&mut self, symbol: u8) -> f64 {
2611            self.calls.fetch_add(1, Ordering::Relaxed);
2612            if symbol == 0 { 0.0 } else { -20.0 }
2613        }
2614
2615        fn update(&mut self, _symbol: u8) {}
2616    }
2617
2618    #[derive(Clone)]
2619    struct CountingFillPredict {
2620        log_calls: Arc<AtomicUsize>,
2621        fill_calls: Arc<AtomicUsize>,
2622    }
2623
2624    impl OnlineBytePredictor for CountingFillPredict {
2625        fn log_prob(&mut self, symbol: u8) -> f64 {
2626            self.log_calls.fetch_add(1, Ordering::Relaxed);
2627            if symbol == 0 { 0.0 } else { -20.0 }
2628        }
2629
2630        fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
2631            self.fill_calls.fetch_add(1, Ordering::Relaxed);
2632            out.fill(-20.0);
2633            out[0] = 0.0;
2634        }
2635
2636        fn update(&mut self, _symbol: u8) {}
2637    }
2638
2639    #[derive(Clone)]
2640    struct BeginAwarePredict {
2641        seen_total: Arc<AtomicU64>,
2642        began: bool,
2643    }
2644
2645    impl OnlineBytePredictor for BeginAwarePredict {
2646        fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2647            let total = total_symbols.ok_or_else(|| "missing total symbols".to_string())?;
2648            self.seen_total.store(total, Ordering::Relaxed);
2649            self.began = true;
2650            Ok(())
2651        }
2652
2653        fn log_prob(&mut self, _symbol: u8) -> f64 {
2654            if self.began { 0.0 } else { f64::NEG_INFINITY }
2655        }
2656
2657        fn update(&mut self, _symbol: u8) {}
2658    }
2659
2660    fn assert_log_prob_update_matches_separate(label: &str, backend: RateBackend) {
2661        let mut separate =
2662            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
2663        let mut combined = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
2664        let data = b"combined step check data";
2665
2666        for &b in data {
2667            let logp_separate = separate.log_prob(b);
2668            separate.update(b);
2669            let logp_combined = combined.log_prob_update(b);
2670            let diff = (logp_separate - logp_combined).abs();
2671            assert!(
2672                diff <= 1e-12,
2673                "[{label}] symbol={b} separate={logp_separate} combined={logp_combined} diff={diff}"
2674            );
2675
2676            let mut sep_row = [0.0; 256];
2677            let mut combo_row = [0.0; 256];
2678            separate.fill_log_probs(&mut sep_row);
2679            combined.fill_log_probs(&mut combo_row);
2680            for i in 0..256 {
2681                let diff = (sep_row[i] - combo_row[i]).abs();
2682                assert!(
2683                    diff <= 1e-12,
2684                    "row mismatch at {i}: {} vs {}",
2685                    sep_row[i],
2686                    combo_row[i]
2687                );
2688            }
2689        }
2690    }
2691
2692    fn assert_fill_matches_symbol_queries(label: &str, backend: RateBackend) {
2693        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
2694        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
2695        let data = b"continuation consistency prompt";
2696
2697        bulk.begin_stream(Some(data.len() as u64))
2698            .expect("bulk begin");
2699        queried
2700            .begin_stream(Some(data.len() as u64))
2701            .expect("query begin");
2702        for &b in data {
2703            bulk.update(b);
2704            queried.update(b);
2705        }
2706
2707        let mut bulk_row = [0.0; 256];
2708        bulk.fill_log_probs(&mut bulk_row);
2709        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
2710            let queried_logp = queried.log_prob(sym as u8);
2711            let diff = (bulk_logp - queried_logp).abs();
2712            assert!(
2713                diff <= 1e-12,
2714                "[{label}] sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
2715            );
2716        }
2717    }
2718
    /// Verifies bulk-row vs per-symbol agreement after a freeze-and-recondition
    /// cycle: fit on one corpus, `reset_frozen`, then condition with
    /// `update_frozen` on a second prompt before comparing predictions.
    fn assert_fill_matches_symbol_queries_after_frozen_conditioning(
        label: &str,
        backend: RateBackend,
    ) {
        let fit = b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\n";
        let condition = b"If a cat is red, toads are \n";
        let total = (fit.len() + condition.len()) as u64;

        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);

        // Phase 1: announce the combined length, then fit both twins identically.
        bulk.begin_stream(Some(total)).expect("bulk begin");
        queried.begin_stream(Some(total)).expect("query begin");
        for &b in fit {
            bulk.update(b);
            queried.update(b);
        }
        // Phase 2: freeze model state, then feed only the conditioning prompt.
        bulk.reset_frozen(Some(condition.len() as u64))
            .expect("bulk reset frozen");
        queried
            .reset_frozen(Some(condition.len() as u64))
            .expect("query reset frozen");
        for &b in condition {
            bulk.update_frozen(b);
            queried.update_frozen(b);
        }

        // Phase 3: the bulk row and 256 individual queries must agree exactly.
        let mut bulk_row = [0.0; 256];
        bulk.fill_log_probs(&mut bulk_row);
        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
            let queried_logp = queried.log_prob(sym as u8);
            let diff = (bulk_logp - queried_logp).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] frozen sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
            );
        }
    }
2757
2758    #[test]
2759    fn predictor_log_prob_update_matches_separate_update_for_rosa_backend() {
2760        assert_log_prob_update_matches_separate("rosa", RateBackend::RosaPlus);
2761    }
2762
2763    #[test]
2764    fn predictor_log_prob_update_matches_separate_update_for_ctw_backend() {
2765        assert_log_prob_update_matches_separate("ctw", RateBackend::Ctw { depth: 6 });
2766    }
2767
2768    #[test]
2769    fn predictor_log_prob_update_matches_separate_update_for_fac_ctw_backend() {
2770        assert_log_prob_update_matches_separate(
2771            "fac-ctw",
2772            RateBackend::FacCtw {
2773                base_depth: 6,
2774                num_percept_bits: 8,
2775                encoding_bits: 8,
2776            },
2777        );
2778    }
2779
2780    #[test]
2781    fn predictor_fill_matches_symbol_queries_for_rosa_backend() {
2782        assert_fill_matches_symbol_queries("rosa", RateBackend::RosaPlus);
2783    }
2784
2785    #[test]
2786    fn predictor_fill_matches_symbol_queries_for_ctw_backend() {
2787        assert_fill_matches_symbol_queries("ctw", RateBackend::Ctw { depth: 6 });
2788    }
2789
2790    #[test]
2791    fn predictor_fill_matches_symbol_queries_for_match_backend() {
2792        assert_fill_matches_symbol_queries(
2793            "match",
2794            RateBackend::Match {
2795                hash_bits: 18,
2796                min_len: 4,
2797                max_len: 64,
2798                base_mix: 0.02,
2799                confidence_scale: 1.0,
2800            },
2801        );
2802    }
2803
2804    #[test]
2805    fn predictor_fill_matches_symbol_queries_for_ppmd_backend() {
2806        assert_fill_matches_symbol_queries(
2807            "ppmd",
2808            RateBackend::Ppmd {
2809                order: 8,
2810                memory_mb: 8,
2811            },
2812        );
2813    }
2814
    /// RWKV-7 backend (feature-gated): bulk fill must match per-symbol queries.
    /// The method string pins a tiny, seeded, train=none configuration so the
    /// test stays fast and deterministic.
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rwkv_backend() {
        assert_fill_matches_symbol_queries(
            "rwkv7",
            RateBackend::Rwkv7Method {
                method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
            },
        );
    }
2825
2826    #[test]
2827    fn predictor_fill_matches_symbol_queries_for_rosa_backend_after_frozen_conditioning() {
2828        assert_fill_matches_symbol_queries_after_frozen_conditioning("rosa", RateBackend::RosaPlus);
2829    }
2830
2831    #[test]
2832    fn predictor_frozen_conditioning_reuses_match_fit_corpus() {
2833        let mut predictor = RateBackendPredictor::from_backend(
2834            RateBackend::Match {
2835                hash_bits: 20,
2836                min_len: 3,
2837                max_len: 32,
2838                base_mix: 0.02,
2839                confidence_scale: 1.0,
2840            },
2841            -1,
2842            DEFAULT_MIN_PROB,
2843        );
2844
2845        for &b in b"abcabcX" {
2846            predictor.update(b);
2847        }
2848        predictor
2849            .reset_frozen(Some(6))
2850            .expect("reset frozen for match backend");
2851        for &b in b"abcabc" {
2852            predictor.update_frozen(b);
2853        }
2854        let p_x = predictor.log_prob(b'X').exp();
2855        assert!(
2856            p_x > 0.01,
2857            "frozen conditioning should preserve fit corpus for match backend; p_x={p_x}"
2858        );
2859    }
2860
2861    #[test]
2862    fn predictor_frozen_conditioning_reuses_sparse_match_fit_corpus() {
2863        let mut predictor = RateBackendPredictor::from_backend(
2864            RateBackend::SparseMatch {
2865                hash_bits: 20,
2866                min_len: 3,
2867                max_len: 32,
2868                gap_min: 0,
2869                gap_max: 2,
2870                base_mix: 0.02,
2871                confidence_scale: 1.0,
2872            },
2873            -1,
2874            DEFAULT_MIN_PROB,
2875        );
2876
2877        for &b in b"abcabcX" {
2878            predictor.update(b);
2879        }
2880        predictor
2881            .reset_frozen(Some(6))
2882            .expect("reset frozen for sparse-match backend");
2883        for &b in b"abcabc" {
2884            predictor.update_frozen(b);
2885        }
2886        let p_x = predictor.log_prob(b'X').exp();
2887        assert!(
2888            p_x > 0.01,
2889            "frozen conditioning should preserve fit corpus for sparse-match backend; p_x={p_x}"
2890        );
2891    }
2892
2893    #[test]
2894    fn neural_predict_then_step_reuses_evaluation_cache() {
2895        let c0 = Arc::new(AtomicUsize::new(0));
2896        let c1 = Arc::new(AtomicUsize::new(0));
2897        let cfg0 = {
2898            let c = c0.clone();
2899            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
2900        };
2901        let cfg1 = {
2902            let c = c1.clone();
2903            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
2904        };
2905        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);
2906
2907        let _ = mix.predict_log_prob(0);
2908        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2909        assert_eq!(after_predict, 2);
2910
2911        let _ = mix.step(0);
2912        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2913        assert_eq!(after_step, after_predict);
2914    }
2915
2916    #[test]
2917    fn neural_predict_multiple_symbols_reuses_single_evaluation() {
2918        let c0 = Arc::new(AtomicUsize::new(0));
2919        let c1 = Arc::new(AtomicUsize::new(0));
2920        let cfg0 = {
2921            let c = c0.clone();
2922            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
2923        };
2924        let cfg1 = {
2925            let c = c1.clone();
2926            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
2927        };
2928        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);
2929
2930        let _ = mix.predict_log_prob(0);
2931        let after_first = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2932        assert_eq!(after_first, 2);
2933
2934        let _ = mix.predict_log_prob(1);
2935        let after_second = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
2936        assert_eq!(after_second, after_first + 2);
2937    }
2938
2939    #[test]
2940    fn neural_fill_then_step_reuses_cached_full_rows() {
2941        let log0 = Arc::new(AtomicUsize::new(0));
2942        let log1 = Arc::new(AtomicUsize::new(0));
2943        let fill0 = Arc::new(AtomicUsize::new(0));
2944        let fill1 = Arc::new(AtomicUsize::new(0));
2945        let cfg0 = {
2946            let log_calls = log0.clone();
2947            let fill_calls = fill0.clone();
2948            ExpertConfig::uniform("c0", move || {
2949                Box::new(CountingFillPredict {
2950                    log_calls: log_calls.clone(),
2951                    fill_calls: fill_calls.clone(),
2952                })
2953            })
2954        };
2955        let cfg1 = {
2956            let log_calls = log1.clone();
2957            let fill_calls = fill1.clone();
2958            ExpertConfig::uniform("c1", move || {
2959                Box::new(CountingFillPredict {
2960                    log_calls: log_calls.clone(),
2961                    fill_calls: fill_calls.clone(),
2962                })
2963            })
2964        };
2965        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);
2966
2967        let mut row = [0.0; 256];
2968        mix.fill_log_probs(&mut row);
2969        assert_eq!(fill0.load(Ordering::Relaxed), 1);
2970        assert_eq!(fill1.load(Ordering::Relaxed), 1);
2971        assert_eq!(log0.load(Ordering::Relaxed), 0);
2972        assert_eq!(log1.load(Ordering::Relaxed), 0);
2973
2974        let _ = mix.step(0);
2975        assert_eq!(fill0.load(Ordering::Relaxed), 1);
2976        assert_eq!(fill1.load(Ordering::Relaxed), 1);
2977        assert_eq!(log0.load(Ordering::Relaxed), 0);
2978        assert_eq!(log1.load(Ordering::Relaxed), 0);
2979    }
2980
2981    #[test]
2982    fn runtime_begin_stream_propagates_to_experts() {
2983        let seen_total = Arc::new(AtomicU64::new(0));
2984        let cfg = {
2985            let seen_total = seen_total.clone();
2986            ExpertConfig::uniform("begin-aware", move || {
2987                Box::new(BeginAwarePredict {
2988                    seen_total: seen_total.clone(),
2989                    began: false,
2990                })
2991            })
2992        };
2993
2994        let spec = MixtureSpec::new(MixtureKind::Bayes, vec![]);
2995        let mut runtime = build_mixture_runtime(&spec, &[cfg]).expect("runtime");
2996        runtime.begin_stream(Some(123)).expect("begin stream");
2997        let _ = runtime.step(0);
2998        assert_eq!(seen_total.load(Ordering::Relaxed), 123);
2999    }
3000
    /// Probing a full row via `fill_log_probs` must be side-effect free for the
    /// ZPAQ backend: a probed predictor and an untouched baseline must keep
    /// agreeing on subsequent predictions.
    #[test]
    fn zpaq_fill_log_probs_does_not_drift_history() {
        let backend = RateBackend::Zpaq {
            method: "1".to_string(),
        };
        let mut baseline =
            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut probe = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);

        // Condition both predictors on identical history.
        let history = b"history for zpaq predictor";
        for &b in history {
            baseline.update(b);
            probe.update(b);
        }

        // Only `probe` performs the bulk row query.
        let mut row = [0.0f64; 256];
        probe.fill_log_probs(&mut row);

        // Immediately after probing, both predictors and the row itself agree.
        let sym = b'k';
        let lp_base = baseline.log_prob(sym);
        let lp_probe = probe.log_prob(sym);
        assert!((lp_base - lp_probe).abs() < 1e-9);
        assert!((row[sym as usize] - lp_base).abs() < 1e-9);

        // They must still agree one update later: no hidden state drifted.
        baseline.update(sym);
        probe.update(sym);
        let next = b'q';
        let next_base = baseline.log_prob(next);
        let next_probe = probe.log_prob(next);
        assert!((next_base - next_probe).abs() < 1e-9);
    }
3032
3033    fn assert_predictor_log_probs_normalize_to_one(backend: RateBackend) {
3034        let mut predictor = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3035        for &b in b"normalization corpus for ctw/fac predictor checks" {
3036            predictor.update(b);
3037        }
3038        let mut sum = 0.0f64;
3039        for sym in 0u8..=255u8 {
3040            sum += predictor.log_prob(sym).exp();
3041        }
3042        assert!(
3043            (sum - 1.0).abs() <= 1e-10,
3044            "probability mass drift: sum={sum}"
3045        );
3046    }
3047
3048    #[test]
3049    fn ctw_predictor_symbol_probs_normalize() {
3050        assert_predictor_log_probs_normalize_to_one(RateBackend::Ctw { depth: 7 });
3051    }
3052
3053    #[test]
3054    fn fac_ctw_predictor_symbol_probs_normalize() {
3055        assert_predictor_log_probs_normalize_to_one(RateBackend::FacCtw {
3056            base_depth: 7,
3057            num_percept_bits: 8,
3058            encoding_bits: 8,
3059        });
3060    }
3061}