use crate::ctw::FacContextTree;
use crate::zpaq_rate::ZpaqRateModel;
use crate::{MixtureKind, MixtureSpec, RateBackend};
use rosaplus::RosaPlus;
use rwkvzip::coders::softmax_pdf_floor_inplace;
use std::sync::Arc;

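/// Default probability floor (exactly 2^-24): `clamp_prob` never lets a
/// backend report less than this, so log-probabilities stay finite.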
pub const DEFAULT_MIN_PROB: f64 = 5.960_464_477_539_063e-8;

#[inline]
fn clamp_prob(p: f64, min_prob: f64) -> f64 {
    if p.is_finite() {
        p.max(min_prob)
    } else {
        min_prob
    }
}

// Numerically stable log(sum_i exp(x_i)): shift by the maximum before exponentiating.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let mut max_v = f64::NEG_INFINITY;
    for &v in xs {
        if v > max_v {
            max_v = v;
        }
    }
    if !max_v.is_finite() {
        return max_v;
    }
    let mut sum = 0.0;
    for &v in xs {
        sum += (v - max_v).exp();
    }
    max_v + sum.ln()
}

#[inline]
fn logsumexp2(a: f64, b: f64) -> f64 {
    let m = if a > b { a } else { b };
    if !m.is_finite() {
        return m;
    }
    m + ((a - m).exp() + (b - m).exp()).ln()
}

#[inline]
fn logsumexp_weights(experts: &[ExpertState]) -> f64 {
    let mut max_v = f64::NEG_INFINITY;
    for e in experts {
        if e.log_weight > max_v {
            max_v = e.log_weight;
        }
    }
    if !max_v.is_finite() {
        return max_v;
    }
    let mut sum = 0.0;
    for e in experts {
        sum += (e.log_weight - max_v).exp();
    }
    max_v + sum.ln()
}

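/// A byte-level model that is queried and trained online.
///
/// The expected protocol is `log_prob(sym)` followed by `update(sym)` for the
/// same observed byte; the `Mixture` variant of `RateBackendPredictor` relies
/// on this ordering to avoid stepping its runtime twice.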
pub trait OnlineBytePredictor: Send {
    /// Natural-log probability of `symbol` under the current model state.
    fn log_prob(&mut self, symbol: u8) -> f64;

    /// Advance the model state as if `symbol` had been observed.
    fn update(&mut self, symbol: u8);
}

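/// An `OnlineBytePredictor` constructed from a single `RateBackend` value.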
pub enum RateBackendPredictor {
    Rosa {
        model: RosaPlus,
        min_prob: f64,
    },
    Ctw {
        tree: FacContextTree,
        min_prob: f64,
    },
    FacCtw {
        tree: FacContextTree,
        bits_per_symbol: usize,
        min_prob: f64,
    },
    Rwkv7 {
        compressor: rwkvzip::Compressor,
        primed: bool,
        min_prob: f64,
    },
    Zpaq {
        model: ZpaqRateModel,
    },
    Mixture {
        runtime: MixtureRuntime,
        pending_symbol: Option<u8>,
        pending_logp: f64,
    },
}

impl RateBackendPredictor {
    pub fn from_backend(backend: RateBackend, max_order: i64, min_prob: f64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                Self::Rosa { model, min_prob }
            }
            RateBackend::Ctw { depth } => {
                let tree = FacContextTree::new(depth, 8);
                Self::Ctw { tree, min_prob }
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                let bits_per_symbol = encoding_bits.clamp(1, 8);
                let tree = FacContextTree::new(base_depth, bits_per_symbol);
                Self::FacCtw {
                    tree,
                    bits_per_symbol,
                    min_prob,
                }
            }
            RateBackend::Rwkv7 { model } => {
                let mut compressor = rwkvzip::Compressor::new_from_model(model);
                let vocab_size = compressor.vocab_size();
                // Prime the state with token 0 so `pdf_buffer` holds a valid
                // distribution before the first `log_prob` call.
                let logits = compressor
                    .model
                    .forward(&mut compressor.scratch, 0, &mut compressor.state);
                softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                Self::Rwkv7 {
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            RateBackend::Zpaq { method } => {
                let model = ZpaqRateModel::new(method, min_prob);
                Self::Zpaq { model }
            }
            RateBackend::Mixture { spec } => {
                let experts = spec.build_experts();
                let runtime = build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                Self::Mixture {
                    runtime,
                    pending_symbol: None,
                    pending_logp: 0.0,
                }
            }
        }
    }

    pub fn default_name(backend: &RateBackend, max_order: i64) -> String {
        match backend {
            RateBackend::RosaPlus => format!("rosa(mo={})", max_order),
            RateBackend::Ctw { depth } => format!("ctw(d={})", depth),
            RateBackend::FacCtw {
                base_depth,
                encoding_bits,
                ..
            } => format!("fac-ctw(d={},b={})", base_depth, encoding_bits),
            RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
            RateBackend::Zpaq { method } => format!("zpaq(m={})", method),
            RateBackend::Mixture { spec } => {
                let kind = match spec.kind {
                    MixtureKind::Bayes => "bayes",
                    MixtureKind::FadingBayes => "fading",
                    MixtureKind::Switching => "switch",
                    MixtureKind::Mdl => "mdl",
                };
                format!("mix({})", kind)
            }
        }
    }
}

impl OnlineBytePredictor for RateBackendPredictor {
    fn log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            RateBackendPredictor::Rosa { model, min_prob } => {
                let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
                p.ln()
            }
            RateBackendPredictor::Ctw { tree, min_prob } => {
                // Score the byte MSB-first by updating the tree, measuring the
                // change in log block probability, then reverting the updates.
                let log_before = tree.get_log_block_probability();
                for bit_idx in 0..8 {
                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                    tree.update(bit, bit_idx);
                }
                let log_after = tree.get_log_block_probability();
                for bit_idx in (0..8).rev() {
                    tree.revert(bit_idx);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                min_prob,
            } => {
                // Same update/revert trick as `Ctw`, but LSB-first and over
                // only the low `bits_per_symbol` bits of the byte.
                let log_before = tree.get_log_block_probability();
                for i in 0..*bits_per_symbol {
                    let bit = ((symbol >> i) & 1) == 1;
                    tree.update(bit, i);
                }
                let log_after = tree.get_log_block_probability();
                for i in (0..*bits_per_symbol).rev() {
                    tree.revert(i);
                }
                let logp = log_after - log_before;
                if logp.is_finite() {
                    logp.max(min_prob.ln())
                } else {
                    min_prob.ln()
                }
            }
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                min_prob,
            } => {
                if !*primed {
                    let vocab_size = compressor.vocab_size();
                    let logits = compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                    *primed = true;
                }
                let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
                p.ln()
            }
            RateBackendPredictor::Zpaq { model } => model.log_prob(symbol),
            RateBackendPredictor::Mixture {
                runtime,
                pending_symbol,
                pending_logp,
            } => {
                // `MixtureRuntime::step` scores *and* trains in one call, so
                // the result is cached here and the matching `update` call for
                // the same symbol becomes a no-op.
                if let Some(pending) = *pending_symbol {
                    if pending == symbol {
                        return *pending_logp;
                    }
                    *pending_symbol = None;
                }
                let logp = runtime.step(symbol);
                *pending_symbol = Some(symbol);
                *pending_logp = logp;
                logp
            }
        }
    }

    fn update(&mut self, symbol: u8) {
        match self {
            RateBackendPredictor::Rosa { model, .. } => {
                let mut tx = model.begin_tx();
                model.train_sequence_tx(&mut tx, &[symbol]);
            }
            RateBackendPredictor::Ctw { tree, .. } => {
                for bit_idx in 0..8 {
                    let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                    tree.update(bit, bit_idx);
                }
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                ..
            } => {
                for i in 0..*bits_per_symbol {
                    let bit = ((symbol >> i) & 1) == 1;
                    tree.update(bit, i);
                }
            }
            RateBackendPredictor::Rwkv7 {
                compressor,
                primed,
                ..
            } => {
                if !*primed {
                    let vocab_size = compressor.vocab_size();
                    let logits = compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                    *primed = true;
                }
                let vocab_size = compressor.vocab_size();
                let logits = compressor.model.forward(
                    &mut compressor.scratch,
                    symbol as u32,
                    &mut compressor.state,
                );
                softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
            }
            RateBackendPredictor::Zpaq { model } => {
                model.update(symbol);
            }
            RateBackendPredictor::Mixture {
                runtime,
                pending_symbol,
                ..
            } => {
                // If `log_prob` already stepped the runtime on this symbol,
                // that step also trained the experts; nothing is left to do.
                if let Some(pending) = *pending_symbol {
                    if pending == symbol {
                        *pending_symbol = None;
                        return;
                    }
                }
                *pending_symbol = None;
                let _ = runtime.step(symbol);
            }
        }
    }
}

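/// A named expert for a mixture: a prior log-weight plus a factory that
/// builds a fresh predictor instance on demand.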
#[derive(Clone)]
pub struct ExpertConfig {
    pub name: String,
    pub log_prior: f64,
    builder: Arc<dyn Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync>,
}

impl ExpertConfig {
    pub fn new(
        name: impl Into<String>,
        log_prior: f64,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self {
            name: name.into(),
            log_prior,
            builder: Arc::new(builder),
        }
    }

    pub fn uniform(
        name: impl Into<String>,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self::new(name, 0.0, builder)
    }

    pub fn from_rate_backend(
        name: Option<String>,
        log_prior: f64,
        backend: RateBackend,
        max_order: i64,
    ) -> Self {
        let name = name.unwrap_or_else(|| RateBackendPredictor::default_name(&backend, max_order));
        Self::new(name, log_prior, move || {
            Box::new(RateBackendPredictor::from_backend(
                backend.clone(),
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    pub fn rosa(name: impl Into<String>, max_order: i64) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::RosaPlus,
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    pub fn ctw(name: impl Into<String>, depth: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Ctw { depth },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    pub fn fac_ctw(name: impl Into<String>, base_depth: usize, encoding_bits: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::FacCtw {
                    base_depth,
                    num_percept_bits: encoding_bits,
                    encoding_bits,
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    pub fn rwkv(name: impl Into<String>, model: Arc<rwkvzip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Rwkv7 {
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    pub fn zpaq(name: impl Into<String>, method: impl Into<String>) -> Self {
        let name = name.into();
        let method = method.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Zpaq {
                    method: method.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn log_prior(&self) -> f64 {
        self.log_prior
    }

    pub fn build_predictor(&self) -> Box<dyn OnlineBytePredictor> {
        (self.builder)()
    }

    fn build(&self) -> ExpertState {
        ExpertState {
            name: self.name.clone(),
            log_weight: self.log_prior,
            log_prior: self.log_prior,
            predictor: (self.builder)(),
            cum_log_loss: 0.0,
        }
    }
}

struct ExpertState {
    name: String,
    log_weight: f64,
    log_prior: f64,
    predictor: Box<dyn OnlineBytePredictor>,
    cum_log_loss: f64,
}

impl ExpertState {
    #[inline]
    fn log_prob(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob(symbol)
    }

    #[inline]
    fn update(&mut self, symbol: u8) {
        self.predictor.update(symbol);
    }
}

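/// Classic Bayesian mixture over a fixed set of experts. Each step reports
/// `log sum_i w_i p_i(x)` and reweights
/// `w_i <- w_i p_i(x) / sum_j w_j p_j(x)`, all in log space.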
pub struct BayesMixture {
    experts: Vec<ExpertState>,
    scratch_logps: Vec<f64>,
    scratch_mix: Vec<f64>,
    total_log_loss: f64,
}

impl BayesMixture {
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            total_log_loss: 0.0,
        }
    }

    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.log_weight = expert.log_weight + self.scratch_logps[i] - log_mix;
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= log_mix;
        log_mix
    }

    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

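/// Bayes mixture with exponential forgetting: each log-weight is scaled by
/// `decay` (clamped to `[0, 1]`) before the usual Bayesian update, so old
/// evidence fades and the mixture can track regime changes. `decay = 1.0`
/// reduces to `BayesMixture`.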
pub struct FadingBayesMixture {
    experts: Vec<ExpertState>,
    decay: f64,
    scratch_logps: Vec<f64>,
    scratch_mix: Vec<f64>,
    total_log_loss: f64,
}

impl FadingBayesMixture {
    pub fn new(configs: &[ExpertConfig], decay: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let decay = decay.clamp(0.0, 1.0);
        Self {
            experts,
            decay,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            total_log_loss: 0.0,
        }
    }

    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            let decayed = self.decay * expert.log_weight;
            self.scratch_mix[i] = decayed + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        for (i, expert) in self.experts.iter_mut().enumerate() {
            let decayed = self.decay * expert.log_weight;
            expert.log_weight = decayed + self.scratch_logps[i] - log_mix;
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= log_mix;
        log_mix
    }

    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

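/// Fixed-share style switching mixture: before each Bayesian update a fraction
/// `alpha` of the prior is mixed back into the posterior,
/// `w_i <- (1 - alpha) w_i + alpha prior_i`, so every expert keeps at least
/// `alpha * prior_i` of pre-update weight and the mixture can switch quickly
/// to a newly-good expert.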
pub struct SwitchingMixture {
    experts: Vec<ExpertState>,
    log_prior: Vec<f64>,
    log_alpha: f64,
    log_1m_alpha: f64,
    scratch_logps: Vec<f64>,
    scratch_switch: Vec<f64>,
    total_log_loss: f64,
}

impl SwitchingMixture {
    pub fn new(configs: &[ExpertConfig], alpha: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let log_prior: Vec<f64> = experts.iter().map(|e| e.log_prior - norm).collect();
        let alpha = alpha.clamp(1e-12, 1.0 - 1e-12);
        Self {
            experts,
            log_prior,
            log_alpha: alpha.ln(),
            log_1m_alpha: (1.0 - alpha).ln(),
            scratch_logps: vec![0.0; configs.len()],
            scratch_switch: vec![0.0; configs.len()],
            total_log_loss: 0.0,
        }
    }

    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
        }

        for i in 0..self.experts.len() {
            // Fixed-share transition in log space:
            // w_i' = (1 - alpha) * w_i + alpha * prior_i.
            let log_switch = logsumexp2(
                self.log_1m_alpha + self.experts[i].log_weight,
                self.log_alpha + self.log_prior[i],
            );
            self.scratch_switch[i] = self.scratch_logps[i] + log_switch;
        }
        let log_mix = logsumexp(&self.scratch_switch);
        for i in 0..self.experts.len() {
            let expert = &mut self.experts[i];
            expert.log_weight = self.scratch_switch[i] - log_mix;
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= log_mix;
        log_mix
    }

    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

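/// Two-part (MDL-style) selection: each step predicts with the single expert
/// whose cumulative log loss is currently lowest, then trains every expert on
/// the observed byte.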
pub struct MdlSelector {
    experts: Vec<ExpertState>,
    scratch_logps: Vec<f64>,
    total_log_loss: f64,
    last_best: usize,
}

impl MdlSelector {
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let last_best = 0usize;
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            total_log_loss: 0.0,
            last_best,
        }
    }

    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
        }
        // Pick the expert with the lowest loss so far, before accounting for
        // this symbol, so the choice never peeks at the byte being predicted.
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, expert) in self.experts.iter().enumerate() {
            if expert.cum_log_loss < best_loss {
                best_loss = expert.cum_log_loss;
                best_idx = i;
            }
        }
        let logp = self.scratch_logps[best_idx];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.cum_log_loss -= self.scratch_logps[i];
            expert.update(symbol);
        }
        self.total_log_loss -= logp;
        self.last_best = best_idx;
        logp
    }

    pub fn best_index(&self) -> usize {
        self.last_best
    }

    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }
}

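/// A built mixture of whichever kind the `MixtureSpec` requested. `step`
/// scores `symbol`, trains all experts on it, and returns the mixture
/// log-probability.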
pub enum MixtureRuntime {
    Bayes(BayesMixture),
    Fading(FadingBayesMixture),
    Switching(SwitchingMixture),
    Mdl(MdlSelector),
}

impl MixtureRuntime {
    pub(crate) fn step(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.step(symbol),
            MixtureRuntime::Fading(m) => m.step(symbol),
            MixtureRuntime::Switching(m) => m.step(symbol),
            MixtureRuntime::Mdl(m) => m.step(symbol),
        }
    }
}

pub(crate) fn build_mixture_runtime(
    spec: &MixtureSpec,
    experts: &[ExpertConfig],
) -> Result<MixtureRuntime, String> {
    if experts.is_empty() {
        return Err("mixture spec must include at least one expert".to_string());
    }
    match spec.kind {
        MixtureKind::Bayes => Ok(MixtureRuntime::Bayes(BayesMixture::new(experts))),
        MixtureKind::FadingBayes => {
            let decay = spec
                .decay
                .ok_or_else(|| "fading Bayes mixture requires decay".to_string())?;
            Ok(MixtureRuntime::Fading(FadingBayesMixture::new(
                experts, decay,
            )))
        }
        MixtureKind::Switching => Ok(MixtureRuntime::Switching(SwitchingMixture::new(
            experts, spec.alpha,
        ))),
        MixtureKind::Mdl => Ok(MixtureRuntime::Mdl(MdlSelector::new(experts))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct AlwaysPredict {
        byte: u8,
    }

    impl OnlineBytePredictor for AlwaysPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            if symbol == self.byte {
                0.0
            } else {
                f64::NEG_INFINITY
            }
        }

        fn update(&mut self, _symbol: u8) {}
    }

    #[test]
    fn bayes_mixture_prefers_correct_expert() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = BayesMixture::new(&configs);
        for _ in 0..10 {
            mix.step(0);
        }
        let post = mix.posterior();
        assert!(post[0] > 0.999);
        assert!(post[1] < 1e-6);
    }
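
    // A few extra sanity checks for the helpers and the non-Bayes strategies;
    // they rely only on the update rules implemented in this file.
    #[test]
    fn logsumexp_matches_direct_sum() {
        let xs = [1.0f64.ln(), 3.0f64.ln()];
        assert!((logsumexp(&xs) - 4.0f64.ln()).abs() < 1e-12);
        assert!((logsumexp2(xs[0], xs[1]) - 4.0f64.ln()).abs() < 1e-12);
    }

    #[test]
    fn switching_mixture_recovers_after_source_switch() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = SwitchingMixture::new(&configs, 0.05);
        for _ in 0..10 {
            mix.step(0);
        }
        // The `alpha * prior` share keeps a floor under the "one" expert, so
        // the posterior can move back to it once the source switches.
        for _ in 0..10 {
            mix.step(1);
        }
        let post = mix.posterior();
        assert!(post[1] > 0.999);
    }

    #[test]
    fn mdl_selector_tracks_lowest_loss_expert() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut sel = MdlSelector::new(&configs);
        for _ in 0..5 {
            sel.step(0);
        }
        assert_eq!(sel.best_index(), 0);
        assert!(sel.total_log_loss().abs() < 1e-12);
    }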
}