infotheory/mixture.rs

//! Online mixtures of probabilistic predictors (log-loss Hedge / Bayes, switching, MDL).
//!
//! This module provides a small, numerically careful toolkit for sequential model mixing.
//! Predictors expose per-symbol log-probabilities, which allows principled Bayesian
//! mixture updates and clean information-theoretic accounting.
//!
//! ## Rate-Backend Mixtures
//!
//! The mixture primitives here power `RateBackend::Mixture`, enabling Bayes, fading Bayes,
//! switching, and MDL-style selectors to be used anywhere a rate backend is accepted.
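//!
//! ## Example
//!
//! A minimal sketch of direct use (the crate path and the parameter values are
//! illustrative assumptions; the `ExpertConfig` helpers are defined below):
//!
//! ```ignore
//! use infotheory::mixture::{BayesMixture, ExpertConfig};
//!
//! // Two experts under a uniform prior: a depth-8 CTW and an order-16 ROSA model.
//! let configs = vec![
//!     ExpertConfig::ctw("ctw8", 8),
//!     ExpertConfig::rosa("rosa16", 16),
//! ];
//! let mut mix = BayesMixture::new(&configs);
//!
//! let mut nats = 0.0;
//! for &b in b"hello world" {
//!     nats -= mix.step(b); // step() returns ln p(b | history) and updates all experts
//! }
//! println!("code length: {:.2} bits", nats / std::f64::consts::LN_2);
//! println!("posterior over experts: {:?}", mix.posterior());
//! ```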

use crate::ctw::FacContextTree;
use crate::zpaq_rate::ZpaqRateModel;
use crate::{MixtureKind, MixtureSpec, RateBackend};
use rosaplus::RosaPlus;
use rwkvzip::coders::softmax_pdf_floor_inplace;
use std::sync::Arc;

/// Default minimum probability floor to avoid log(0). Equals 2⁻²⁴.
pub const DEFAULT_MIN_PROB: f64 = 5.960_464_477_539_063e-8;

#[inline]
fn clamp_prob(p: f64, min_prob: f64) -> f64 {
    if p.is_finite() {
        p.max(min_prob)
    } else {
        min_prob
    }
}

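// Numerically stable log(sum(exp(x_i))): factor out the running max so no
// exponential can overflow; a non-finite max (e.g. all inputs -inf) is
// returned as-is.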
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let mut max_v = f64::NEG_INFINITY;
    for &v in xs {
        if v > max_v {
            max_v = v;
        }
    }
    if !max_v.is_finite() {
        return max_v;
    }
    let mut sum = 0.0;
    for &v in xs {
        sum += (v - max_v).exp();
    }
    max_v + sum.ln()
}

#[inline]
fn logsumexp2(a: f64, b: f64) -> f64 {
    let m = if a > b { a } else { b };
    if !m.is_finite() {
        return m;
    }
    m + ((a - m).exp() + (b - m).exp()).ln()
}

#[inline]
fn logsumexp_weights(experts: &[ExpertState]) -> f64 {
    let mut max_v = f64::NEG_INFINITY;
    for e in experts {
        if e.log_weight > max_v {
            max_v = e.log_weight;
        }
    }
    if !max_v.is_finite() {
        return max_v;
    }
    let mut sum = 0.0;
    for e in experts {
        sum += (e.log_weight - max_v).exp();
    }
    max_v + sum.ln()
}

/// Trait for online byte-level predictors that expose per-symbol log-probabilities.
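///
/// A minimal sketch of an implementor (a hypothetical memoryless uniform
/// predictor, not part of this module):
///
/// ```ignore
/// struct Uniform;
///
/// impl OnlineBytePredictor for Uniform {
///     fn log_prob(&mut self, _symbol: u8) -> f64 {
///         (1.0f64 / 256.0).ln() // every byte equally likely
///     }
///     fn update(&mut self, _symbol: u8) {} // memoryless: nothing to learn
/// }
/// ```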
pub trait OnlineBytePredictor: Send {
    /// Log-probability (natural log) of `symbol` given the current history.
    fn log_prob(&mut self, symbol: u8) -> f64;

    /// Update the predictor with the observed `symbol`.
    fn update(&mut self, symbol: u8);
}

/// A concrete online predictor backed by a `RateBackend` configuration.
pub enum RateBackendPredictor {
    /// ROSA-Plus online suffix automaton.
    Rosa {
        model: RosaPlus,
        min_prob: f64,
    },
    /// Byte-wise CTW implemented as 8 factorized bit trees (MSB-first).
    Ctw {
        tree: FacContextTree,
        min_prob: f64,
    },
    /// Factorized CTW with configurable bit-encoding (LSB-first).
    FacCtw {
        tree: FacContextTree,
        bits_per_symbol: usize,
        min_prob: f64,
    },
    /// RWKV-7 neural predictor.
    Rwkv7 {
        compressor: rwkvzip::Compressor,
        primed: bool,
        min_prob: f64,
    },
    /// ZPAQ streaming rate model.
    Zpaq {
        model: ZpaqRateModel,
    },
    /// Online mixture over experts (Bayes, fading Bayes, switching, MDL).
    ///
    /// `MixtureRuntime::step` both predicts and updates, so the result for the
    /// most recently scored symbol is cached in `pending_symbol`/`pending_logp`
    /// until the matching `update` call consumes it.
    Mixture {
        runtime: MixtureRuntime,
        pending_symbol: Option<u8>,
        pending_logp: f64,
    },
}

impl RateBackendPredictor {
    /// Create a new online predictor from a rate backend configuration.
    pub fn from_backend(backend: RateBackend, max_order: i64, min_prob: f64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                Self::Rosa { model, min_prob }
            }
            RateBackend::Ctw { depth } => {
                let tree = FacContextTree::new(depth, 8);
                Self::Ctw { tree, min_prob }
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                let bits_per_symbol = encoding_bits.clamp(1, 8);
                let tree = FacContextTree::new(base_depth, bits_per_symbol);
                Self::FacCtw {
                    tree,
                    bits_per_symbol,
                    min_prob,
                }
            }
            RateBackend::Rwkv7 { model } => {
                let mut compressor = rwkvzip::Compressor::new_from_model(model);
                let vocab_size = compressor.vocab_size();
                let logits = compressor
                    .model
                    .forward(&mut compressor.scratch, 0, &mut compressor.state);
                softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                Self::Rwkv7 {
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            RateBackend::Zpaq { method } => {
                let model = ZpaqRateModel::new(method, min_prob);
                Self::Zpaq { model }
            }
            RateBackend::Mixture { spec } => {
                let experts = spec.build_experts();
                let runtime = build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                Self::Mixture {
                    runtime,
                    pending_symbol: None,
                    pending_logp: 0.0,
                }
            }
        }
    }

    /// Human-readable default name for a backend + config.
    pub fn default_name(backend: &RateBackend, max_order: i64) -> String {
        match backend {
            RateBackend::RosaPlus => format!("rosa(mo={})", max_order),
            RateBackend::Ctw { depth } => format!("ctw(d={})", depth),
            RateBackend::FacCtw {
                base_depth,
                encoding_bits,
                ..
            } => format!("fac-ctw(d={},b={})", base_depth, encoding_bits),
            RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
            RateBackend::Zpaq { method } => format!("zpaq(m={})", method),
            RateBackend::Mixture { spec } => {
                let kind = match spec.kind {
                    MixtureKind::Bayes => "bayes",
                    MixtureKind::FadingBayes => "fading",
                    MixtureKind::Switching => "switch",
                    MixtureKind::Mdl => "mdl",
                };
                format!("mix({})", kind)
            }
        }
    }
}

impl OnlineBytePredictor for RateBackendPredictor {
    fn log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            RateBackendPredictor::Rosa { model, min_prob } => {
                let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
                p.ln()
            }
            RateBackendPredictor::Ctw { tree, min_prob } => {
                // Peek: apply the 8 bit updates, read off the block-probability
                // delta, then revert so the tree state is unchanged.
                let log_before = tree.get_log_block_probability();
                for bit_idx in 0..8 {
                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                    tree.update(bit, bit_idx);
                }
                let log_after = tree.get_log_block_probability();
                for bit_idx in (0..8).rev() {
                    tree.revert(bit_idx);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                min_prob,
            } => {
                // Same peek-then-revert pattern as the byte-wise CTW, LSB-first.
                let log_before = tree.get_log_block_probability();
                for i in 0..*bits_per_symbol {
                    let bit = ((symbol >> i) & 1) == 1;
                    tree.update(bit, i);
                }
                let log_after = tree.get_log_block_probability();
                for i in (0..*bits_per_symbol).rev() {
                    tree.revert(i);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                min_prob,
            } => {
                if !*primed {
                    let vocab_size = compressor.vocab_size();
                    let logits = compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                    *primed = true;
                }
                let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
                p.ln()
            }
            RateBackendPredictor::Zpaq { model } => model.log_prob(symbol),
            RateBackendPredictor::Mixture {
                runtime,
                pending_symbol,
                pending_logp,
            } => {
                // `runtime.step` predicts *and* updates, so cache the result and
                // let the matching `update` call consume it rather than stepping twice.
                if let Some(pending) = *pending_symbol {
                    if pending == symbol {
                        return *pending_logp;
                    }
                    *pending_symbol = None;
                }
                let logp = runtime.step(symbol);
                *pending_symbol = Some(symbol);
                *pending_logp = logp;
                logp
            }
        }
    }

    fn update(&mut self, symbol: u8) {
        match self {
            RateBackendPredictor::Rosa { model, .. } => {
                let mut tx = model.begin_tx();
                model.train_sequence_tx(&mut tx, &[symbol]);
            }
            RateBackendPredictor::Ctw { tree, .. } => {
                for bit_idx in 0..8 {
                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                    tree.update(bit, bit_idx);
                }
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                ..
            } => {
                for i in 0..*bits_per_symbol {
                    let bit = ((symbol >> i) & 1) == 1;
                    tree.update(bit, i);
                }
            }
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                ..
            } => {
                if !*primed {
                    let vocab_size = compressor.vocab_size();
                    let logits = compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                    *primed = true;
                }
                let vocab_size = compressor.vocab_size();
                let logits = compressor.model.forward(
                    &mut compressor.scratch,
                    symbol as u32,
                    &mut compressor.state,
                );
                softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
            }
            RateBackendPredictor::Zpaq { model } => {
                model.update(symbol);
            }
            RateBackendPredictor::Mixture {
                runtime,
                pending_symbol,
                ..
            } => {
                // If `log_prob` already stepped the runtime for this symbol,
                // the update is complete; just clear the cache.
                if let Some(pending) = *pending_symbol {
                    if pending == symbol {
                        *pending_symbol = None;
                        return;
                    }
                }
                *pending_symbol = None;
                let _ = runtime.step(symbol);
            }
        }
    }
}

/// Configuration for a mixture expert.
#[derive(Clone)]
pub struct ExpertConfig {
    pub name: String,
    /// Log prior weight (natural log). Uniform priors can be `0.0`.
    pub log_prior: f64,
    builder: Arc<dyn Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync>,
}

impl ExpertConfig {
    /// Create a new expert config from a builder closure.
    pub fn new(
        name: impl Into<String>,
        log_prior: f64,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self {
            name: name.into(),
            log_prior,
            builder: Arc::new(builder),
        }
    }

    /// Uniform prior helper.
    pub fn uniform(
        name: impl Into<String>,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self::new(name, 0.0, builder)
    }

    /// Expert from a `RateBackend` configuration. `max_order` applies to ROSA.
    pub fn from_rate_backend(
        name: Option<String>,
        log_prior: f64,
        backend: RateBackend,
        max_order: i64,
    ) -> Self {
        let name = name.unwrap_or_else(|| RateBackendPredictor::default_name(&backend, max_order));
        Self::new(name, log_prior, move || {
            Box::new(RateBackendPredictor::from_backend(
                backend.clone(),
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// ROSA expert (uniform prior).
    pub fn rosa(name: impl Into<String>, max_order: i64) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::RosaPlus,
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// CTW expert (uniform prior).
    pub fn ctw(name: impl Into<String>, depth: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Ctw { depth },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// FAC-CTW expert (uniform prior).
    pub fn fac_ctw(name: impl Into<String>, base_depth: usize, encoding_bits: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::FacCtw {
                    base_depth,
                    num_percept_bits: encoding_bits,
                    encoding_bits,
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// RWKV-7 expert (uniform prior).
    pub fn rwkv(name: impl Into<String>, model: Arc<rwkvzip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Rwkv7 { model: model.clone() },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// ZPAQ expert (uniform prior).
    pub fn zpaq(name: impl Into<String>, method: impl Into<String>) -> Self {
        let name = name.into();
        let method = method.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Zpaq {
                    method: method.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Expert name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Log prior weight (unnormalized).
    pub fn log_prior(&self) -> f64 {
        self.log_prior
    }

    /// Build a fresh predictor instance for evaluation or analysis.
    pub fn build_predictor(&self) -> Box<dyn OnlineBytePredictor> {
        (self.builder)()
    }

    fn build(&self) -> ExpertState {
        ExpertState {
            name: self.name.clone(),
            log_weight: self.log_prior,
            log_prior: self.log_prior,
            predictor: (self.builder)(),
            cum_log_loss: 0.0,
        }
    }
}

struct ExpertState {
    name: String,
    log_weight: f64,
    log_prior: f64,
    predictor: Box<dyn OnlineBytePredictor>,
    cum_log_loss: f64,
}

impl ExpertState {
    #[inline]
    fn log_prob(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob(symbol)
    }

    #[inline]
    fn update(&mut self, symbol: u8) {
        self.predictor.update(symbol);
    }
}

/// Exponential-weights Bayes mixture (log-loss Hedge).
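///
/// Each step performs the standard Bayesian posterior update in log space,
/// `w_i ← w_i · p_i(x) / Σ_j w_j · p_j(x)`, and predicts with the weighted
/// mixture `Σ_i w_i · p_i(x)`. With a uniform prior over `N` experts, the
/// mixture's cumulative log-loss exceeds the best expert's by at most `ln N`.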
pub struct BayesMixture {
    experts: Vec<ExpertState>,
    scratch_logps: Vec<f64>,
    scratch_mix: Vec<f64>,
    total_log_loss: f64,
}

impl BayesMixture {
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.log_weight += self.scratch_logps[i] - log_mix;
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and posterior mass of the most likely expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert cumulative log-losses (nats) and names.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

/// Exponential-weights Bayes mixture with exponential forgetting on weights.
///
/// A control for non-stationary sources: each step, log-weights are scaled by
/// `decay` (equivalently, weights are raised to the power `decay`), so old
/// evidence is gradually forgotten. `decay = 1.0` recovers `BayesMixture`.
pub struct FadingBayesMixture {
    experts: Vec<ExpertState>,
    decay: f64,
    scratch_logps: Vec<f64>,
    scratch_mix: Vec<f64>,
    total_log_loss: f64,
}

impl FadingBayesMixture {
    pub fn new(configs: &[ExpertConfig], decay: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let decay = decay.clamp(0.0, 1.0);
        Self {
            experts,
            decay,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the fading mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            let decayed = self.decay * expert.log_weight;
            self.scratch_mix[i] = decayed + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        for (i, expert) in self.experts.iter_mut().enumerate() {
            let decayed = self.decay * expert.log_weight;
            expert.log_weight = decayed + self.scratch_logps[i] - log_mix;
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert (non-discounted loss).
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

/// Switching mixture: allows occasional switches between experts.
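///
/// A fixed-share-to-prior scheme: before each prediction, every expert keeps a
/// `(1 - alpha)` share of its posterior weight and receives an `alpha` share of
/// its prior, i.e. `w_i ← p_i(x) · [(1 − α)·w_i + α·π_i] / Z`. This tracks the
/// best *sequence* of experts rather than a single best expert, paying roughly
/// `-ln α` extra nats per switch.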
pub struct SwitchingMixture {
    experts: Vec<ExpertState>,
    log_prior: Vec<f64>,
    log_alpha: f64,
    log_1m_alpha: f64,
    scratch_logps: Vec<f64>,
    scratch_switch: Vec<f64>,
    total_log_loss: f64,
}

impl SwitchingMixture {
    pub fn new(configs: &[ExpertConfig], alpha: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let log_prior: Vec<f64> = experts.iter().map(|e| e.log_prior - norm).collect();
        let alpha = alpha.clamp(1e-12, 1.0 - 1e-12);
        Self {
            experts,
            log_prior,
            log_alpha: alpha.ln(),
            log_1m_alpha: (1.0 - alpha).ln(),
            scratch_logps: vec![0.0; configs.len()],
            scratch_switch: vec![0.0; configs.len()],
            total_log_loss: 0.0,
        }
    }

    /// Log-probability (natural log) of the switching mixture for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
        }

        for i in 0..self.experts.len() {
            let log_switch = logsumexp2(
                self.log_1m_alpha + self.experts[i].log_weight,
                self.log_alpha + self.log_prior[i],
            );
            self.scratch_switch[i] = self.scratch_logps[i] + log_switch;
        }
        let log_mix = logsumexp(&self.scratch_switch);
        for i in 0..self.experts.len() {
            let expert = &mut self.experts[i];
            expert.log_weight = self.scratch_switch[i] - log_mix;
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Posterior weights (normalized) over experts.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and log-loss (nats) of the current best expert.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and posterior mass of the most likely expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total log-loss of the mixture so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert cumulative log-losses (nats) and names.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

/// MDL-style selector: predicts with the current best expert (by cumulative loss).
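///
/// A follow-the-leader scheme on code length: each symbol is coded with the
/// expert whose cumulative log-loss over the *past* symbols is smallest (ties
/// go to the lowest index), while all experts keep training. Unlike the Bayes
/// variants this is a hard selection, so it carries no mixture regret bound;
/// it serves as a simple MDL-style baseline.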
pub struct MdlSelector {
    experts: Vec<ExpertState>,
    scratch_logps: Vec<f64>,
    total_log_loss: f64,
    last_best: usize,
}

impl MdlSelector {
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            total_log_loss: 0.0,
            last_best: 0,
        }
    }

    /// Log-probability (natural log) of the MDL selector for `symbol`, then update.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
        }
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, expert) in self.experts.iter().enumerate() {
            if expert.cum_log_loss < best_loss {
                best_loss = expert.cum_log_loss;
                best_idx = i;
            }
        }
        let logp = self.scratch_logps[best_idx];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= logp;
        self.last_best = best_idx;
        logp
    }

    /// Index of the current best expert.
    pub fn best_index(&self) -> usize {
        self.last_best
    }

    /// Index and log-loss (nats) of the current best expert.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Total log-loss of the selector so far (nats).
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert cumulative log-losses (nats) and names.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

// =============================================================================
// Mixture Runtime Helper (for RateBackend::Mixture)
// =============================================================================

/// Runtime dispatch over the supported mixture kinds.
pub enum MixtureRuntime {
    Bayes(BayesMixture),
    Fading(FadingBayesMixture),
    Switching(SwitchingMixture),
    Mdl(MdlSelector),
}

impl MixtureRuntime {
    /// Step the mixture and return log-probability (nats).
    pub(crate) fn step(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.step(symbol),
            MixtureRuntime::Fading(m) => m.step(symbol),
            MixtureRuntime::Switching(m) => m.step(symbol),
            MixtureRuntime::Mdl(m) => m.step(symbol),
        }
    }
}

/// Build a `MixtureRuntime` from a `MixtureSpec` and its built experts.
/// Fails if the expert list is empty or a required parameter is missing.
pub(crate) fn build_mixture_runtime(
    spec: &MixtureSpec,
    experts: &[ExpertConfig],
) -> Result<MixtureRuntime, String> {
    if experts.is_empty() {
        return Err("mixture spec must include at least one expert".to_string());
    }
    match spec.kind {
        MixtureKind::Bayes => Ok(MixtureRuntime::Bayes(BayesMixture::new(experts))),
        MixtureKind::FadingBayes => {
            let decay = spec
                .decay
                .ok_or_else(|| "fading Bayes mixture requires decay".to_string())?;
            Ok(MixtureRuntime::Fading(FadingBayesMixture::new(
                experts, decay,
            )))
        }
        MixtureKind::Switching => Ok(MixtureRuntime::Switching(SwitchingMixture::new(
            experts, spec.alpha,
        ))),
        MixtureKind::Mdl => Ok(MixtureRuntime::Mdl(MdlSelector::new(experts))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct AlwaysPredict {
        byte: u8,
    }

    impl OnlineBytePredictor for AlwaysPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            if symbol == self.byte {
                0.0
            } else {
                f64::NEG_INFINITY
            }
        }

        fn update(&mut self, _symbol: u8) {}
    }

    #[test]
    fn bayes_mixture_prefers_correct_expert() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = BayesMixture::new(&configs);
        for _ in 0..10 {
            mix.step(0);
        }
        let post = mix.posterior();
        assert!(post[0] > 0.999);
        assert!(post[1] < 1e-6);
    }
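
    // Added regime-change check (a sketch; the alpha value and threshold are
    // illustrative): after the source flips from 0s to 1s, the fixed-share
    // update should move posterior mass onto the matching expert.
    #[test]
    fn switching_mixture_tracks_regime_change() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = SwitchingMixture::new(&configs, 0.1);
        for _ in 0..10 {
            mix.step(0);
        }
        for _ in 0..10 {
            mix.step(1);
        }
        let post = mix.posterior();
        assert!(post[1] > 0.99);
    }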
}