1use crate::backends::calibration::CalibratorCore;
13use crate::backends::match_model::MatchModel;
14use crate::backends::ppmd::PpmdModel;
15use crate::backends::sequitur::{SequiturCheckpoint, SequiturModel};
16use crate::backends::sparse_match::SparseMatchModel;
17use crate::backends::text_context::TextContextAnalyzer;
18use crate::ctw::FacContextTree;
19#[cfg(feature = "backend-mamba")]
20use crate::mambazip;
21use crate::neural_mix::{NeuralHistoryState, NeuralMixCore};
22use crate::rosaplus::RosaPlus;
23#[cfg(feature = "backend-rwkv")]
24use crate::rwkvzip;
25use crate::zpaq_rate::ZpaqRateModel;
26use crate::{CalibratedSpec, MixtureKind, MixtureScheduleMode, MixtureSpec, RateBackend};
27use std::sync::Arc;
28
/// Default probability floor (≈ 2^-24) applied to backend outputs so that
/// near-zero probabilities never produce a -inf log-prob.
pub const DEFAULT_MIN_PROB: f64 = 5.960_464_477_539_063e-8;
31
/// Lower-bound `p` at `min_prob`; non-finite inputs (NaN, ±inf) collapse
/// straight to `min_prob`.
#[inline]
fn clamp_prob(p: f64, min_prob: f64) -> f64 {
    match p.is_finite() {
        true => p.max(min_prob),
        false => min_prob,
    }
}
40
41#[inline]
42fn clamp_unit_prob(p: f64, min_prob: f64) -> f64 {
43 clamp_prob(p, min_prob).min(1.0 - min_prob)
44}
45
/// Build a `CalibratorCore` from a calibration spec.
/// NOTE(review): argument order (context, bins, learning_rate, bias_clip)
/// must match `CalibratorCore::new`, declared outside this file — confirm
/// there if the spec fields are ever reordered.
#[inline]
fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
    CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
}
50
/// Numerically stable log(sum(exp(xs))) using the max-shift trick.
///
/// Returns -inf for an empty slice, and returns the maximum directly when it
/// is non-finite (e.g. all entries -inf) so no `exp` is ever evaluated on
/// garbage.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let peak = xs
        .iter()
        .fold(f64::NEG_INFINITY, |acc, &v| if v > acc { v } else { acc });
    if !peak.is_finite() {
        return peak;
    }
    let total: f64 = xs.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
68
/// Two-argument log-sum-exp: log(e^a + e^b), max-shifted for stability.
/// The pivot is chosen with the same comparison semantics as a plain
/// `if a > b` so NaN propagation matches the slice version's behavior.
#[inline]
fn logsumexp2(a: f64, b: f64) -> f64 {
    let pivot = if b < a { a } else { b };
    if pivot.is_finite() {
        pivot + ((a - pivot).exp() + (b - pivot).exp()).ln()
    } else {
        pivot
    }
}
77
78#[inline]
79fn logsumexp_weights(experts: &[ExpertState]) -> f64 {
80 let mut max_v = f64::NEG_INFINITY;
81 for e in experts {
82 if e.log_weight > max_v {
83 max_v = e.log_weight;
84 }
85 }
86 if !max_v.is_finite() {
87 return max_v;
88 }
89 let mut sum = 0.0;
90 for e in experts {
91 sum += (e.log_weight - max_v).exp();
92 }
93 max_v + sum.ln()
94}
95
/// Sanitize and renormalize `weights` in place onto the probability simplex:
/// non-finite or negative entries are zeroed, then all entries are divided by
/// the sum. If the sanitized sum is not strictly positive (or not finite),
/// the result is the uniform distribution. Empty slices are left untouched.
fn normalize_simplex_weights(weights: &mut [f64]) {
    if weights.is_empty() {
        return;
    }
    let mut total = 0.0;
    for w in weights.iter_mut() {
        // Invalid mass (NaN, ±inf, negative) contributes nothing.
        if !(w.is_finite() && *w >= 0.0) {
            *w = 0.0;
        }
        total += *w;
    }
    if total.is_finite() && total > 0.0 {
        for w in weights.iter_mut() {
            *w /= total;
        }
    } else {
        let uniform = 1.0 / (weights.len() as f64);
        weights.fill(uniform);
    }
}
116
/// Euclidean projection of `weights` onto the probability simplex using the
/// sort-based pivot algorithm, reusing `scratch` so repeated calls do not
/// allocate. Non-finite entries are treated as 0 before projecting; when no
/// pivot exists the result falls back to the uniform distribution.
pub(crate) fn project_simplex_with_scratch(weights: &mut [f64], scratch: &mut Vec<f64>) {
    if weights.is_empty() {
        return;
    }

    // Copy sanitized weights into scratch and sort them descending.
    scratch.clear();
    scratch.extend(
        weights
            .iter()
            .map(|&weight| if weight.is_finite() { weight } else { 0.0 }),
    );
    let sorted = scratch.as_mut_slice();
    sorted.sort_by(|a, b| b.total_cmp(a));

    // rho = largest index whose sorted value exceeds the running threshold
    // theta_k = (prefix_sum_k - 1) / k.
    let mut cumulative = 0.0;
    let mut rho = None;
    for (index, value) in sorted.iter().enumerate() {
        cumulative += *value;
        let theta = (cumulative - 1.0) / ((index + 1) as f64);
        if *value > theta {
            rho = Some(index);
        }
    }

    // No pivot found (e.g. all entries zero): uniform fallback.
    let Some(rho_index) = rho else {
        let uniform = 1.0 / (weights.len() as f64);
        weights.fill(uniform);
        return;
    };

    // Shift by the pivot threshold, clip at zero, then renormalize to
    // absorb floating-point rounding error.
    let theta = (sorted.iter().take(rho_index + 1).sum::<f64>() - 1.0) / ((rho_index + 1) as f64);
    for weight in weights.iter_mut() {
        *weight = (*weight - theta).max(0.0);
    }
    normalize_simplex_weights(weights);
}
153
154#[inline]
155pub(crate) fn switching_alpha_for_update(
156 schedule: MixtureScheduleMode,
157 alpha: f64,
158 processed_symbols: u64,
159) -> f64 {
160 match schedule {
161 MixtureScheduleMode::Default => alpha.clamp(0.0, 1.0),
162 MixtureScheduleMode::Theorem => 1.0 / ((processed_symbols + 2) as f64),
163 }
164}
165
166#[inline]
167pub(crate) fn convex_step_size_for_update(
168 schedule: MixtureScheduleMode,
169 alpha: f64,
170 update_index: u64,
171) -> f64 {
172 let t = update_index.max(1) as f64;
173 match schedule {
174 MixtureScheduleMode::Default => alpha.max(1e-12) / t.sqrt(),
175 MixtureScheduleMode::Theorem => DEFAULT_MIN_PROB / t.sqrt(),
176 }
177}
178
179fn normalized_prior_weights(configs: &[ExpertConfig]) -> Vec<f64> {
180 if configs.is_empty() {
181 return Vec::new();
182 }
183 let max_log = configs
184 .iter()
185 .map(|cfg| cfg.log_prior)
186 .fold(f64::NEG_INFINITY, f64::max);
187 let mut weights = configs
188 .iter()
189 .map(|cfg| {
190 if max_log.is_finite() {
191 (cfg.log_prior - max_log).exp()
192 } else {
193 0.0
194 }
195 })
196 .collect::<Vec<_>>();
197 normalize_simplex_weights(&mut weights);
198 weights
199}
200
201fn set_log_weights_from_linear(experts: &mut [ExpertState], weights: &[f64]) {
202 for (expert, &weight) in experts.iter_mut().zip(weights.iter()) {
203 expert.log_weight = if weight > 0.0 {
204 weight.ln()
205 } else {
206 f64::NEG_INFINITY
207 };
208 }
209}
210
/// Object-safe cloning support for `Box<dyn OnlineBytePredictor>`; blanket
/// implemented below for every `'static + Clone` predictor.
pub trait OnlineBytePredictorClone {
    /// Clone `self` into a fresh boxed trait object.
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor>;
}
219
// Blanket impl: any `'static + Clone` predictor gets `clone_box` for free,
// which is what lets `Box<dyn OnlineBytePredictor>` implement `Clone`.
impl<T> OnlineBytePredictorClone for T
where
    T: 'static + OnlineBytePredictor + Clone,
{
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor> {
        Box::new(self.clone())
    }
}
228
// Route `Box<dyn OnlineBytePredictor>::clone` through the object-safe
// `clone_box` hook.
impl Clone for Box<dyn OnlineBytePredictor> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
234
/// A streaming byte-level probability model: scores `log P(symbol)` under
/// its current state and is trained online one byte at a time.
pub trait OnlineBytePredictor: Send + OnlineBytePredictorClone {
    /// Prepare for a new stream. `_total_symbols`, when known, lets
    /// implementations pre-reserve capacity. Default: no-op.
    fn begin_stream(&mut self, _total_symbols: Option<u64>) -> Result<(), String> {
        Ok(())
    }

    /// Flush/teardown at end of stream. Default: no-op.
    fn finish_stream(&mut self) -> Result<(), String> {
        Ok(())
    }

    /// Natural-log probability of `symbol` under the current model state,
    /// WITHOUT training on it. Takes `&mut self` because some backends need
    /// scratch mutation even to score.
    fn log_prob(&mut self, symbol: u8) -> f64;

    /// Fill `out[s]` with `log_prob(s)` for all 256 byte values. The default
    /// calls `log_prob` per symbol; implementations override it when a full
    /// distribution is cheaper to produce in one pass.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        for (sym, slot) in out.iter_mut().enumerate() {
            *slot = self.log_prob(sym as u8);
        }
    }

    /// Convenience: score `symbol`, then train on it.
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        let logp = self.log_prob(symbol);
        self.update(symbol);
        logp
    }

    /// Train the model on the observed `symbol`.
    fn update(&mut self, symbol: u8);

    /// Reset for "frozen" replay. Default: finish the current stream and
    /// begin a new one; implementations may override to keep learned
    /// statistics and clear only history/context.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.finish_stream()?;
        self.begin_stream(total_symbols)
    }

    /// Advance on `symbol` in frozen mode. Default delegates to `update`;
    /// implementations override to advance context without learning.
    fn update_frozen(&mut self, symbol: u8) {
        self.update(symbol);
    }
}
288
/// Lazily (re)prime the RWKV compressor so its `pdf_buffer` holds a valid
/// next-symbol distribution before it is read.
#[cfg(feature = "backend-rwkv")]
#[inline]
fn ensure_rwkv_primed(compressor: &mut rwkvzip::Compressor, primed: &mut bool) {
    if !*primed {
        compressor.reset_and_prime();
        *primed = true;
    }
}
297
298#[inline]
299fn ctw_log_prob_update_msb(tree: &mut FacContextTree, symbol: u8, min_prob: f64) -> f64 {
300 let mut logp = 0.0;
301 for bit_idx in 0..8 {
302 let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
303 let p = tree.predict(bit, bit_idx);
304 if p.is_finite() && p > 0.0 {
305 logp += p.ln();
306 } else {
307 logp = f64::NEG_INFINITY;
308 }
309 tree.update_predicted(bit, bit_idx);
310 }
311 if logp.is_finite() {
312 logp.max(min_prob.ln())
313 } else {
314 min_prob.ln()
315 }
316}
317
318#[inline]
319fn ctw_log_prob_update_lsb(
320 tree: &mut FacContextTree,
321 symbol: u8,
322 bits_per_symbol: usize,
323 min_prob: f64,
324) -> f64 {
325 let mut logp = 0.0;
326 for bit_idx in 0..bits_per_symbol {
327 let bit = ((symbol >> bit_idx) & 1) == 1;
328 let p = tree.predict(bit, bit_idx);
329 if p.is_finite() && p > 0.0 {
330 logp += p.ln();
331 } else {
332 logp = f64::NEG_INFINITY;
333 }
334 tree.update_predicted(bit, bit_idx);
335 }
336 if logp.is_finite() {
337 logp.max(min_prob.ln())
338 } else {
339 min_prob.ln()
340 }
341}
342
343fn fill_fac_tree_log_probs(
344 tree: &mut FacContextTree,
345 bits_per_symbol: usize,
346 msb_first: bool,
347 min_logp: f64,
348 out: &mut [f64; 256],
349) {
350 struct RecParams {
351 bits: usize,
352 msb_first: bool,
353 log_before: f64,
354 min_logp: f64,
355 }
356
357 let bits = bits_per_symbol.clamp(1, 8);
358 let patterns = 1usize << bits;
359 let mut pattern_logps = [f64::NEG_INFINITY; 256];
360 let params = RecParams {
361 bits,
362 msb_first,
363 log_before: tree.get_log_block_probability(),
364 min_logp,
365 };
366
367 fn rec(
368 tree: &mut FacContextTree,
369 depth: usize,
370 params: &RecParams,
371 symbol_acc: u8,
372 pattern_logps: &mut [f64; 256],
373 ) {
374 if depth == params.bits {
375 let pat = symbol_acc as usize;
376 let logp = (tree.get_log_block_probability() - params.log_before).max(params.min_logp);
377 pattern_logps[pat] = logp;
378 return;
379 }
380
381 for bit in [false, true] {
382 tree.update(bit, depth);
383 let mut next_symbol = symbol_acc;
384 if params.msb_first {
385 let shift = 7usize.saturating_sub(depth);
386 if bit {
387 next_symbol |= 1u8 << shift;
388 }
389 } else if bit {
390 next_symbol |= 1u8 << depth;
391 }
392 rec(tree, depth + 1, params, next_symbol, pattern_logps);
393 tree.revert(depth);
394 }
395 }
396
397 rec(tree, 0, ¶ms, 0, &mut pattern_logps);
398
399 if bits == 8 {
400 out.copy_from_slice(&pattern_logps);
401 } else {
402 let aliases = 1usize << (8 - bits);
403 let alias_ln = (aliases as f64).ln();
404 let mask = patterns - 1;
405 for byte in 0..256usize {
406 out[byte] = pattern_logps[byte & mask] - alias_ln;
407 }
408 }
409}
410
/// One predictor per supported rate backend. Each variant owns its model
/// state plus, for most variants, the probability floor `min_prob` applied
/// when converting model output to log-probs.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum RateBackendPredictor {
    /// ROSA+ model.
    Rosa {
        model: RosaPlus,
        min_prob: f64,
    },
    /// Contiguous match model.
    Match {
        model: MatchModel,
        min_prob: f64,
    },
    /// Gapped (sparse) match model.
    SparseMatch {
        model: SparseMatchModel,
        min_prob: f64,
    },
    /// PPMd context model.
    Ppmd {
        model: PpmdModel,
        min_prob: f64,
    },
    /// Sequitur grammar model (supports cheap native checkpoints).
    Sequitur {
        model: SequiturModel,
        min_prob: f64,
    },
    /// Byte-level context-tree weighting: 8 bits per symbol, MSB-first.
    Ctw {
        tree: FacContextTree,
        min_prob: f64,
    },
    /// Factored CTW over the low `bits_per_symbol` bits, LSB-first.
    FacCtw {
        tree: FacContextTree,
        bits_per_symbol: usize,
        min_prob: f64,
    },
    /// RWKV-7 neural model. `primed` tracks whether `pdf_buffer` currently
    /// holds a valid next-symbol distribution; `pdf_scratch` is a reusable
    /// copy buffer for online updates.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        compressor: rwkvzip::Compressor,
        primed: bool,
        pdf_scratch: Vec<f64>,
        min_prob: f64,
    },
    /// Mamba neural model; fields mirror `Rwkv7`.
    #[cfg(feature = "backend-mamba")]
    Mamba {
        compressor: mambazip::Compressor,
        primed: bool,
        pdf_scratch: Vec<f64>,
        min_prob: f64,
    },
    /// ZPAQ model (receives `min_prob` at construction and applies its own
    /// floor internally).
    Zpaq {
        model: ZpaqRateModel,
    },
    /// Weighted mixture over several expert backends.
    Mixture {
        runtime: MixtureRuntime,
    },
    /// Particle-filter runtime.
    Particle {
        runtime: crate::particle::ParticleRuntime,
    },
    /// Wrapper that recalibrates a base predictor's distribution; `pdf` and
    /// `valid` cache the calibrated pdf between scoring and updating.
    Calibrated {
        base: Box<RateBackendPredictor>,
        core: CalibratorCore,
        pdf: [f64; 256],
        valid: bool,
        min_prob: f64,
    },
}
519
/// Snapshot of a predictor's state for later rollback: either a full deep
/// clone (works for every variant) or a lightweight model-specific
/// checkpoint (Sequitur only).
#[derive(Clone)]
pub enum RateBackendPredictorCheckpoint {
    /// Full clone of the predictor.
    Full(RateBackendPredictor),
    /// Cheap grammar checkpoint specific to the Sequitur backend.
    Sequitur(SequiturCheckpoint),
}
531
impl RateBackendPredictor {
    /// Construct a predictor for `backend`. `max_order` only affects the
    /// ROSA backend; `min_prob` is the probability floor threaded into
    /// every variant.
    ///
    /// # Panics
    /// Panics when an rwkv/mamba method string or a `MixtureSpec` is
    /// invalid — these are construction-time configuration errors.
    pub fn from_backend(backend: RateBackend, max_order: i64, min_prob: f64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                Self::Rosa { model, min_prob }
            }
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
                min_prob,
            },
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
                min_prob,
            },
            RateBackend::Ppmd { order, memory_mb } => Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
                min_prob,
            },
            RateBackend::Sequitur { context_bytes } => Self::Sequitur {
                model: SequiturModel::new(context_bytes),
                min_prob,
            },
            RateBackend::Ctw { depth } => {
                // Byte-level CTW always operates on all 8 bits per symbol.
                let tree = FacContextTree::new(depth, 8);
                Self::Ctw { tree, min_prob }
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                let bits_per_symbol = encoding_bits.clamp(1, 8);
                let tree = FacContextTree::new(base_depth, bits_per_symbol);
                Self::FacCtw {
                    tree,
                    bits_per_symbol,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => {
                let mut compressor = rwkvzip::Compressor::new_from_model(model);
                // Prime immediately so pdf_buffer is valid before first use.
                compressor.reset_and_prime();
                Self::Rwkv7 {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                let mut compressor = rwkvzip::Compressor::new_from_method(&method)
                    .unwrap_or_else(|e| panic!("invalid rwkv method '{method}': {e}"));
                compressor.reset_and_prime();
                Self::Rwkv7 {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => {
                let mut compressor = mambazip::Compressor::new_from_model(model);
                // Seed pdf_buffer with a forward pass on token 0 so the
                // predictor starts out primed.
                let bias = compressor.online_bias_snapshot();
                let logits =
                    compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
                Self::Mamba {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                let mut compressor = mambazip::Compressor::new_from_method(&method)
                    .unwrap_or_else(|e| panic!("invalid mamba method '{method}': {e}"));
                // Same initial forward pass as the Mamba arm above.
                let bias = compressor.online_bias_snapshot();
                let logits =
                    compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
                Self::Mamba {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            RateBackend::Zpaq { method } => {
                let model = ZpaqRateModel::new(method, min_prob);
                Self::Zpaq { model }
            }
            RateBackend::Mixture { spec } => {
                let experts = spec.build_experts();
                let runtime = build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                Self::Mixture { runtime }
            }
            RateBackend::Particle { spec } => {
                let runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                Self::Particle { runtime }
            }
            RateBackend::Calibrated { spec } => Self::Calibrated {
                // Recursively build the wrapped base predictor.
                base: Box::new(Self::from_backend(spec.base.clone(), max_order, min_prob)),
                core: build_calibrator(spec.as_ref()),
                pdf: [1.0 / 256.0; 256],
                valid: false,
                min_prob,
            },
        }
    }

    /// Human-readable label for a backend (used in logs/reports),
    /// e.g. "ctw(d=8)" or "calibrated(rosa(mo=12))".
    pub fn default_name(backend: &RateBackend, max_order: i64) -> String {
        match backend {
            RateBackend::RosaPlus => format!("rosa(mo={})", max_order),
            RateBackend::Match { .. } => "match".to_string(),
            RateBackend::SparseMatch { .. } => "sparse-match".to_string(),
            RateBackend::Ppmd { order, memory_mb } => {
                format!("ppmd(o={},m={}MiB)", order, memory_mb)
            }
            RateBackend::Sequitur { context_bytes } => {
                format!("sequitur(ctx={context_bytes})")
            }
            RateBackend::Ctw { depth } => format!("ctw(d={})", depth),
            RateBackend::FacCtw {
                base_depth,
                encoding_bits,
                ..
            } => format!("fac-ctw(d={},b={})", base_depth, encoding_bits),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => format!("rwkv7({method})"),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { .. } => "mamba".to_string(),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => format!("mamba({method})"),
            RateBackend::Zpaq { method } => format!("zpaq(m={})", method),
            RateBackend::Mixture { spec } => {
                let kind = match spec.kind {
                    MixtureKind::Bayes => "bayes",
                    MixtureKind::FadingBayes => "fading",
                    MixtureKind::Switching => "switch",
                    MixtureKind::Convex => "convex",
                    MixtureKind::Mdl => "mdl",
                    MixtureKind::Neural => "neural",
                };
                format!("mix({})", kind)
            }
            RateBackend::Particle { spec } => {
                format!("particle(n={},c={})", spec.num_particles, spec.num_cells)
            }
            RateBackend::Calibrated { spec } => {
                format!("calibrated({})", Self::default_name(&spec.base, max_order))
            }
        }
    }

    /// Snapshot the predictor for later rollback. Sequitur has a cheap
    /// native checkpoint; every other variant falls back to a deep clone.
    pub(crate) fn checkpoint(&mut self) -> RateBackendPredictorCheckpoint {
        match self {
            RateBackendPredictor::Sequitur { model, .. } => {
                RateBackendPredictorCheckpoint::Sequitur(model.checkpoint())
            }
            _ => RateBackendPredictorCheckpoint::Full(self.clone()),
        }
    }

    /// Restore state captured by `checkpoint`. A `Full` snapshot may
    /// restore any variant; a `Sequitur` snapshot only matches a Sequitur
    /// predictor.
    ///
    /// # Panics
    /// Panics when a `Sequitur` checkpoint is applied to a non-Sequitur
    /// predictor.
    pub(crate) fn restore_checkpoint(&mut self, checkpoint: &RateBackendPredictorCheckpoint) {
        match (self, checkpoint) {
            (
                RateBackendPredictor::Sequitur { model, .. },
                RateBackendPredictorCheckpoint::Sequitur(ck),
            ) => {
                model.restore(ck);
            }
            (slot, RateBackendPredictorCheckpoint::Full(state)) => {
                *slot = state.clone();
            }
            (_, RateBackendPredictorCheckpoint::Sequitur(_)) => {
                panic!("mismatched RateBackendPredictor checkpoint variant")
            }
        }
    }

    /// Drop retained checkpoint storage; only Sequitur keeps any.
    pub(crate) fn clear_checkpoints_if_supported(&mut self) {
        if let RateBackendPredictor::Sequitur { model, .. } = self {
            model.clear_checkpoints();
        }
    }
}
769
770impl OnlineBytePredictor for RateBackendPredictor {
771 fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
772 self.finish_stream()?;
773 match self {
774 RateBackendPredictor::Rosa { model, .. } => {
775 if let Some(total) = total_symbols {
776 let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
777 model.reserve_for_stream(reserve);
778 }
779 Ok(())
780 }
781 RateBackendPredictor::Match { .. }
782 | RateBackendPredictor::SparseMatch { .. }
783 | RateBackendPredictor::Ppmd { .. } => Ok(()),
784 RateBackendPredictor::Sequitur { model, .. } => {
785 model.begin_stream(total_symbols);
786 Ok(())
787 }
788 RateBackendPredictor::Ctw { .. }
789 | RateBackendPredictor::FacCtw { .. }
790 | RateBackendPredictor::Zpaq { .. }
791 | RateBackendPredictor::Particle { .. } => Ok(()),
792 #[cfg(feature = "backend-rwkv")]
793 RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
794 .begin_online_policy_stream(total_symbols)
795 .map_err(|e| e.to_string()),
796 #[cfg(feature = "backend-mamba")]
797 RateBackendPredictor::Mamba { compressor, .. } => compressor
798 .begin_online_policy_stream(total_symbols)
799 .map_err(|e| e.to_string()),
800 RateBackendPredictor::Mixture { runtime } => runtime.begin_stream(total_symbols),
801 RateBackendPredictor::Calibrated { base, .. } => base.begin_stream(total_symbols),
802 }
803 }
804
805 fn finish_stream(&mut self) -> Result<(), String> {
806 match self {
807 RateBackendPredictor::Rosa { .. }
808 | RateBackendPredictor::Match { .. }
809 | RateBackendPredictor::SparseMatch { .. }
810 | RateBackendPredictor::Ppmd { .. }
811 | RateBackendPredictor::Ctw { .. }
812 | RateBackendPredictor::FacCtw { .. }
813 | RateBackendPredictor::Zpaq { .. }
814 | RateBackendPredictor::Particle { .. } => Ok(()),
815 RateBackendPredictor::Sequitur { model, .. } => {
816 model.finish_stream();
817 Ok(())
818 }
819 #[cfg(feature = "backend-rwkv")]
820 RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
821 .finish_online_policy_stream()
822 .map_err(|e| e.to_string()),
823 #[cfg(feature = "backend-mamba")]
824 RateBackendPredictor::Mamba { .. } => Ok(()),
825 RateBackendPredictor::Mixture { runtime } => runtime.finish_stream(),
826 RateBackendPredictor::Calibrated { base, .. } => base.finish_stream(),
827 }
828 }
829
830 fn log_prob(&mut self, symbol: u8) -> f64 {
831 match self {
832 RateBackendPredictor::Rosa { model, min_prob } => {
833 let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
834 p.ln()
835 }
836 RateBackendPredictor::Match { model, min_prob } => model.log_prob(symbol, *min_prob),
837 RateBackendPredictor::SparseMatch { model, min_prob } => {
838 model.log_prob(symbol, *min_prob)
839 }
840 RateBackendPredictor::Ppmd { model, min_prob } => model.log_prob(symbol, *min_prob),
841 RateBackendPredictor::Sequitur { model, min_prob } => model.log_prob(symbol, *min_prob),
842 RateBackendPredictor::Ctw { tree, min_prob } => {
843 let log_before = tree.get_log_block_probability();
844 for bit_idx in 0..8 {
845 let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
846 tree.update(bit, bit_idx);
847 }
848 let log_after = tree.get_log_block_probability();
849 for bit_idx in (0..8).rev() {
850 tree.revert(bit_idx);
851 }
852 let logp = log_after - log_before;
853 if logp.is_finite() {
854 logp.max(min_prob.ln())
855 } else {
856 min_prob.ln()
857 }
858 }
859 RateBackendPredictor::FacCtw {
860 tree,
861 bits_per_symbol,
862 min_prob,
863 } => {
864 let log_before = tree.get_log_block_probability();
865 for i in 0..*bits_per_symbol {
866 let bit = ((symbol >> i) & 1) == 1;
867 tree.update(bit, i);
868 }
869 let log_after = tree.get_log_block_probability();
870 for i in (0..*bits_per_symbol).rev() {
871 tree.revert(i);
872 }
873 let logp = log_after - log_before;
874 if logp.is_finite() {
875 logp.max(min_prob.ln())
876 } else {
877 min_prob.ln()
878 }
879 }
880 #[cfg(feature = "backend-rwkv")]
881 RateBackendPredictor::Rwkv7 {
882 compressor,
883 primed,
884 min_prob,
885 ..
886 } => {
887 ensure_rwkv_primed(compressor, primed);
888 let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
889 p.ln()
890 }
891 #[cfg(feature = "backend-mamba")]
892 RateBackendPredictor::Mamba {
893 compressor,
894 primed,
895 min_prob,
896 ..
897 } => {
898 if !*primed {
899 let bias = compressor.online_bias_snapshot();
900 let logits =
901 compressor
902 .model
903 .forward(&mut compressor.scratch, 0, &mut compressor.state);
904 mambazip::Compressor::logits_to_pdf(
905 logits,
906 bias.as_deref(),
907 &mut compressor.pdf_buffer,
908 );
909 *primed = true;
910 }
911 let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
912 p.ln()
913 }
914 RateBackendPredictor::Zpaq { model } => model.log_prob(symbol),
915 RateBackendPredictor::Mixture { runtime } => runtime.peek_log_prob(symbol),
916 RateBackendPredictor::Particle { runtime } => runtime.peek_log_prob(symbol),
917 RateBackendPredictor::Calibrated {
918 base,
919 core,
920 pdf,
921 valid,
922 min_prob,
923 } => {
924 if !*valid {
925 let mut base_logps = [0.0; 256];
926 base.fill_log_probs(&mut base_logps);
927 let mut base_pdf = [0.0; 256];
928 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
929 *dst = clamp_prob(lp.exp(), *min_prob);
930 }
931 core.apply_pdf(&base_pdf, pdf);
932 *valid = true;
933 }
934 pdf[symbol as usize].max(*min_prob).ln()
935 }
936 }
937 }
938
939 fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
940 match self {
941 RateBackendPredictor::Rosa { model, min_prob } => {
942 model.fill_probs_for_last_bytes(out);
943 for slot in out.iter_mut() {
944 *slot = clamp_prob(*slot, *min_prob).ln();
945 }
946 }
947 RateBackendPredictor::Match { model, min_prob } => {
948 let mut pdf = [0.0; 256];
949 model.fill_pdf(&mut pdf);
950 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
951 *slot = clamp_prob(p, *min_prob).ln();
952 }
953 }
954 RateBackendPredictor::SparseMatch { model, min_prob } => {
955 let mut pdf = [0.0; 256];
956 model.fill_pdf(&mut pdf);
957 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
958 *slot = clamp_prob(p, *min_prob).ln();
959 }
960 }
961 RateBackendPredictor::Ppmd { model, min_prob } => {
962 let mut pdf = [0.0; 256];
963 model.fill_pdf(&mut pdf);
964 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
965 *slot = clamp_prob(p, *min_prob).ln();
966 }
967 }
968 RateBackendPredictor::Sequitur { model, min_prob } => {
969 let mut pdf = [0.0; 256];
970 model.fill_pdf(&mut pdf);
971 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
972 *slot = clamp_prob(p, *min_prob).ln();
973 }
974 }
975 RateBackendPredictor::Ctw { tree, min_prob } => {
976 fill_fac_tree_log_probs(tree, 8, true, min_prob.ln(), out);
977 }
978 RateBackendPredictor::FacCtw {
979 tree,
980 bits_per_symbol,
981 min_prob,
982 } => {
983 fill_fac_tree_log_probs(tree, *bits_per_symbol, false, min_prob.ln(), out);
984 }
985 #[cfg(feature = "backend-rwkv")]
986 RateBackendPredictor::Rwkv7 {
987 compressor,
988 primed,
989 min_prob,
990 ..
991 } => {
992 ensure_rwkv_primed(compressor, primed);
993 for (slot, &p_raw) in out
994 .iter_mut()
995 .take(256)
996 .zip(compressor.pdf_buffer.iter().take(256))
997 {
998 let p = clamp_prob(p_raw, *min_prob);
999 *slot = p.ln();
1000 }
1001 }
1002 #[cfg(feature = "backend-mamba")]
1003 RateBackendPredictor::Mamba {
1004 compressor,
1005 primed,
1006 min_prob,
1007 ..
1008 } => {
1009 if !*primed {
1010 let bias = compressor.online_bias_snapshot();
1011 let logits =
1012 compressor
1013 .model
1014 .forward(&mut compressor.scratch, 0, &mut compressor.state);
1015 mambazip::Compressor::logits_to_pdf(
1016 logits,
1017 bias.as_deref(),
1018 &mut compressor.pdf_buffer,
1019 );
1020 *primed = true;
1021 }
1022 for (slot, &p_raw) in out
1023 .iter_mut()
1024 .take(256)
1025 .zip(compressor.pdf_buffer.iter().take(256))
1026 {
1027 let p = clamp_prob(p_raw, *min_prob);
1028 *slot = p.ln();
1029 }
1030 }
1031 RateBackendPredictor::Zpaq { model } => {
1032 model.fill_log_probs(out);
1033 }
1034 RateBackendPredictor::Mixture { runtime } => {
1035 runtime.fill_log_probs(out);
1036 }
1037 RateBackendPredictor::Particle { runtime } => {
1038 runtime.fill_log_probs_cached(out);
1039 }
1040 RateBackendPredictor::Calibrated {
1041 base,
1042 core,
1043 pdf,
1044 valid,
1045 min_prob,
1046 } => {
1047 if !*valid {
1048 let mut base_logps = [0.0; 256];
1049 base.fill_log_probs(&mut base_logps);
1050 let mut base_pdf = [0.0; 256];
1051 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
1052 *dst = clamp_prob(lp.exp(), *min_prob);
1053 }
1054 core.apply_pdf(&base_pdf, pdf);
1055 *valid = true;
1056 }
1057 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
1058 *slot = clamp_prob(p, *min_prob).ln();
1059 }
1060 }
1061 }
1062 }
1063
1064 fn update(&mut self, symbol: u8) {
1065 match self {
1066 RateBackendPredictor::Rosa { model, .. } => {
1067 model.train_byte(symbol);
1068 }
1069 RateBackendPredictor::Match { model, .. } => {
1070 model.update(symbol);
1071 }
1072 RateBackendPredictor::SparseMatch { model, .. } => {
1073 model.update(symbol);
1074 }
1075 RateBackendPredictor::Ppmd { model, .. } => {
1076 model.update(symbol);
1077 }
1078 RateBackendPredictor::Sequitur { model, .. } => {
1079 model.update(symbol);
1080 }
1081 RateBackendPredictor::Ctw { tree, .. } => {
1082 for bit_idx in 0..8 {
1083 let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
1084 tree.update(bit, bit_idx);
1085 }
1086 }
1087 RateBackendPredictor::FacCtw {
1088 tree,
1089 bits_per_symbol,
1090 ..
1091 } => {
1092 for i in 0..*bits_per_symbol {
1093 let bit = ((symbol >> i) & 1) == 1;
1094 tree.update(bit, i);
1095 }
1096 }
1097 #[cfg(feature = "backend-rwkv")]
1098 RateBackendPredictor::Rwkv7 {
1099 compressor, primed, ..
1100 } => {
1101 ensure_rwkv_primed(compressor, primed);
1102 compressor
1103 .observe_symbol_from_current_pdf(symbol)
1104 .unwrap_or_else(|e| panic!("rwkv online update failed: {e}"));
1105 }
1106 #[cfg(feature = "backend-mamba")]
1107 RateBackendPredictor::Mamba {
1108 compressor,
1109 primed,
1110 pdf_scratch,
1111 ..
1112 } => {
1113 if !*primed {
1114 let bias = compressor.online_bias_snapshot();
1115 let logits =
1116 compressor
1117 .model
1118 .forward(&mut compressor.scratch, 0, &mut compressor.state);
1119 mambazip::Compressor::logits_to_pdf(
1120 logits,
1121 bias.as_deref(),
1122 &mut compressor.pdf_buffer,
1123 );
1124 *primed = true;
1125 }
1126 if pdf_scratch.len() != compressor.pdf_buffer.len() {
1127 pdf_scratch.resize(compressor.pdf_buffer.len(), 0.0);
1128 }
1129 pdf_scratch.copy_from_slice(&compressor.pdf_buffer);
1130 compressor
1131 .online_update_from_pdf(symbol, pdf_scratch)
1132 .unwrap_or_else(|e| panic!("mamba online update failed: {e}"));
1133 let bias = compressor.online_bias_snapshot();
1134 let logits = compressor.model.forward(
1135 &mut compressor.scratch,
1136 symbol as u32,
1137 &mut compressor.state,
1138 );
1139 mambazip::Compressor::logits_to_pdf(
1140 logits,
1141 bias.as_deref(),
1142 &mut compressor.pdf_buffer,
1143 );
1144 }
1145 RateBackendPredictor::Zpaq { model } => {
1146 model.update(symbol);
1147 }
1148 RateBackendPredictor::Mixture { runtime } => {
1149 let _ = runtime.step(symbol);
1150 }
1151 RateBackendPredictor::Particle { runtime } => {
1152 runtime.step(symbol);
1153 }
1154 RateBackendPredictor::Calibrated {
1155 base,
1156 core,
1157 pdf,
1158 valid,
1159 ..
1160 } => {
1161 if !*valid {
1162 let mut base_logps = [0.0; 256];
1163 base.fill_log_probs(&mut base_logps);
1164 let mut base_pdf = [0.0; 256];
1165 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
1166 *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
1167 }
1168 core.apply_pdf(&base_pdf, pdf);
1169 }
1170 core.update(symbol, pdf);
1171 base.update(symbol);
1172 *valid = false;
1173 }
1174 }
1175 }
1176
1177 fn log_prob_update(&mut self, symbol: u8) -> f64 {
1178 match self {
1179 RateBackendPredictor::Rosa { model, min_prob } => {
1180 let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
1181 model.train_byte(symbol);
1182 p.ln()
1183 }
1184 RateBackendPredictor::Ctw { tree, min_prob } => {
1185 ctw_log_prob_update_msb(tree, symbol, *min_prob)
1186 }
1187 RateBackendPredictor::FacCtw {
1188 tree,
1189 bits_per_symbol,
1190 min_prob,
1191 } => ctw_log_prob_update_lsb(tree, symbol, *bits_per_symbol, *min_prob),
1192 _ => {
1193 let logp = self.log_prob(symbol);
1194 self.update(symbol);
1195 logp
1196 }
1197 }
1198 }
1199
1200 fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
1201 self.finish_stream()?;
1202 match self {
1203 RateBackendPredictor::Rosa { model, .. } => {
1204 if let Some(total) = total_symbols {
1205 let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
1206 model.reserve_for_stream(reserve);
1207 }
1208 model.build_lm_full_bytes_no_finalize_endpos();
1209 model.reset_conditioning_cursor();
1210 Ok(())
1211 }
1212 RateBackendPredictor::Match { model, .. } => {
1213 model.reset_history();
1214 Ok(())
1215 }
1216 RateBackendPredictor::SparseMatch { model, .. } => {
1217 model.reset_history();
1218 Ok(())
1219 }
1220 RateBackendPredictor::Ppmd { model, .. } => {
1221 model.reset_history();
1222 Ok(())
1223 }
1224 RateBackendPredictor::Sequitur { model, .. } => {
1225 model.reset_frozen();
1226 Ok(())
1227 }
1228 RateBackendPredictor::Ctw { tree, .. } => {
1229 tree.reset_history_only();
1230 Ok(())
1231 }
1232 RateBackendPredictor::FacCtw { tree, .. } => {
1233 tree.reset_history_only();
1234 Ok(())
1235 }
1236 #[cfg(feature = "backend-rwkv")]
1237 RateBackendPredictor::Rwkv7 {
1238 compressor, primed, ..
1239 } => {
1240 compressor.reset_and_prime();
1241 *primed = true;
1242 Ok(())
1243 }
1244 #[cfg(feature = "backend-mamba")]
1245 RateBackendPredictor::Mamba {
1246 compressor, primed, ..
1247 } => {
1248 compressor.reset_and_prime();
1249 *primed = true;
1250 Ok(())
1251 }
1252 RateBackendPredictor::Zpaq { .. } => {
1253 Err("plugin entropy is not supported for zpaq rate backends in 1.1.1".to_string())
1254 }
1255 RateBackendPredictor::Mixture { runtime } => runtime.reset_frozen(total_symbols),
1256 RateBackendPredictor::Particle { runtime } => {
1257 runtime.reset_frozen_state();
1258 Ok(())
1259 }
1260 RateBackendPredictor::Calibrated {
1261 base,
1262 core,
1263 pdf,
1264 valid,
1265 ..
1266 } => {
1267 base.reset_frozen(total_symbols)?;
1268 core.reset_context();
1269 pdf.fill(1.0 / 256.0);
1270 *valid = false;
1271 Ok(())
1272 }
1273 }
1274 }
1275
    /// Advance this backend's conditioning context by one byte in "frozen"
    /// mode: history moves forward, but (for most backends) adaptive model
    /// statistics are not updated. Counterpart of `reset_frozen` above.
    fn update_frozen(&mut self, symbol: u8) {
        match self {
            RateBackendPredictor::Rosa { model, .. } => {
                model.advance_conditioning_byte(symbol);
            }
            RateBackendPredictor::Match { model, .. } => {
                model.update_history_only(symbol);
            }
            RateBackendPredictor::SparseMatch { model, .. } => {
                model.update_history_only(symbol);
            }
            RateBackendPredictor::Ppmd { model, .. } => {
                model.update_history_only(symbol);
            }
            RateBackendPredictor::Sequitur { model, .. } => {
                model.update_frozen(symbol);
            }
            RateBackendPredictor::Ctw { tree, .. } => {
                // CTW consumes the byte MSB-first (bit 7 down to bit 0).
                let mut bits = [false; 8];
                for (bit_idx, slot) in bits.iter_mut().enumerate() {
                    *slot = ((symbol >> (7 - bit_idx)) & 1) == 1;
                }
                tree.update_history(&bits);
            }
            RateBackendPredictor::FacCtw {
                tree,
                bits_per_symbol,
                ..
            } => {
                // FAC-CTW consumes only the low `bits_per_symbol` bits
                // (clamped to 1..=8), LSB-first — note this is the opposite
                // bit order from the `Ctw` arm above.
                let bits = (*bits_per_symbol).clamp(1, 8);
                let mut history_bits = [false; 8];
                for (idx, slot) in history_bits.iter_mut().enumerate().take(bits) {
                    *slot = ((symbol >> idx) & 1) == 1;
                }
                tree.update_history(&history_bits[..bits]);
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackendPredictor::Rwkv7 {
                compressor, primed, ..
            } => {
                // Lazily prime the network on first use, then push the byte
                // through the recurrent state to refresh the internal pdf.
                if !*primed {
                    compressor.reset_and_prime();
                    *primed = true;
                }
                compressor.forward_to_internal_pdf(symbol as u32);
            }
            #[cfg(feature = "backend-mamba")]
            RateBackendPredictor::Mamba {
                compressor, primed, ..
            } => {
                if !*primed {
                    compressor.reset_and_prime();
                    *primed = true;
                }
                // Full forward pass; the logits plus the online bias snapshot
                // are converted into the compressor's pdf buffer.
                let bias = compressor.online_bias_snapshot();
                let logits = compressor.model.forward(
                    &mut compressor.scratch,
                    symbol as u32,
                    &mut compressor.state,
                );
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
            }
            RateBackendPredictor::Zpaq { model } => {
                // NOTE(review): zpaq has no dedicated history-only path here;
                // the regular update is used.
                model.update(symbol);
            }
            RateBackendPredictor::Mixture { runtime } => {
                runtime.update_frozen(symbol);
            }
            RateBackendPredictor::Particle { runtime } => {
                runtime.update_frozen(symbol);
            }
            RateBackendPredictor::Calibrated {
                base,
                core,
                pdf,
                valid,
                ..
            } => {
                // Materialize the calibrated pdf for the *current* context
                // before the context advances (presumably the calibrator must
                // observe the pre-update pdf — TODO confirm).
                if !*valid {
                    let mut base_logps = [0.0; 256];
                    base.fill_log_probs(&mut base_logps);
                    let mut base_pdf = [0.0; 256];
                    for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
                        *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
                    }
                    core.apply_pdf(&base_pdf, pdf);
                    *valid = true;
                }
                base.update_frozen(symbol);
                core.update_context_only(symbol);
                // Cached pdf no longer matches the advanced context.
                *valid = false;
            }
        }
    }
1374}
1375
/// Named factory for one mixture expert plus its (unnormalized) log prior.
#[derive(Clone)]
pub struct ExpertConfig {
    /// Human-readable expert name used in diagnostics and reports.
    pub name: String,
    /// Unnormalized log prior weight; mixtures renormalize across experts.
    pub log_prior: f64,
    /// Factory producing a fresh predictor instance for this expert.
    builder: Arc<dyn Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync>,
}
1385
impl ExpertConfig {
    /// Create an expert from an explicit (unnormalized) log prior and a
    /// predictor factory closure.
    pub fn new(
        name: impl Into<String>,
        log_prior: f64,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self {
            name: name.into(),
            log_prior,
            builder: Arc::new(builder),
        }
    }

    /// Expert with a flat prior (log prior 0.0; mixtures renormalize anyway).
    pub fn uniform(
        name: impl Into<String>,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self::new(name, 0.0, builder)
    }

    /// Expert wrapping an arbitrary `RateBackend`; derives a default name
    /// from the backend when none is supplied.
    pub fn from_rate_backend(
        name: Option<String>,
        log_prior: f64,
        backend: RateBackend,
        max_order: i64,
    ) -> Self {
        let name = name.unwrap_or_else(|| RateBackendPredictor::default_name(&backend, max_order));
        Self::new(name, log_prior, move || {
            Box::new(RateBackendPredictor::from_backend(
                backend.clone(),
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Rosa+ expert with a uniform prior.
    pub fn rosa(name: impl Into<String>, max_order: i64) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::RosaPlus,
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Context-tree-weighting expert of the given depth, uniform prior.
    pub fn ctw(name: impl Into<String>, depth: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Ctw { depth },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// FAC-CTW expert; `encoding_bits` is used both as percept width and
    /// symbol encoding width, uniform prior.
    pub fn fac_ctw(name: impl Into<String>, base_depth: usize, encoding_bits: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::FacCtw {
                    base_depth,
                    num_percept_bits: encoding_bits,
                    encoding_bits,
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// RWKV-7 expert sharing one loaded model across rebuilds, uniform prior.
    #[cfg(feature = "backend-rwkv")]
    pub fn rwkv(name: impl Into<String>, model: Arc<rwkvzip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Rwkv7 {
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Mamba expert sharing one loaded model across rebuilds, uniform prior.
    #[cfg(feature = "backend-mamba")]
    pub fn mamba(name: impl Into<String>, model: Arc<mambazip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Mamba {
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// ZPAQ expert configured by its method string, uniform prior.
    pub fn zpaq(name: impl Into<String>, method: impl Into<String>) -> Self {
        let name = name.into();
        let method = method.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Zpaq {
                    method: method.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Expert name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Unnormalized log prior as configured.
    pub fn log_prior(&self) -> f64 {
        self.log_prior
    }

    /// Build a fresh predictor instance from the stored factory.
    pub fn build_predictor(&self) -> Box<dyn OnlineBytePredictor> {
        (self.builder)()
    }

    /// Instantiate the runtime state for this expert; the starting
    /// `log_weight` equals the raw prior (mixtures normalize afterwards).
    fn build(&self) -> ExpertState {
        ExpertState {
            name: self.name.clone(),
            log_weight: self.log_prior,
            log_prior: self.log_prior,
            predictor: (self.builder)(),
            cum_log_loss: 0.0,
        }
    }
}
1535
/// A live expert inside a mixture: predictor instance plus weight/loss
/// bookkeeping.
// NOTE(review): Clone here requires `Box<dyn OnlineBytePredictor>: Clone`,
// presumably provided elsewhere via a clone-box impl — confirm.
#[derive(Clone)]
struct ExpertState {
    /// Display name copied from the config.
    name: String,
    /// Current log mixture weight (normalized by the owning mixture).
    log_weight: f64,
    /// Raw (unnormalized) log prior from the config.
    log_prior: f64,
    /// The underlying online predictor.
    predictor: Box<dyn OnlineBytePredictor>,
    /// Cumulative code length in nats: minus the sum of per-symbol log-probs.
    cum_log_loss: f64,
}
1544
/// Thin delegation layer: every method forwards to the boxed predictor so
/// mixture code never touches `predictor` directly.
impl ExpertState {
    /// Forward stream start (optionally with the known total length).
    #[inline]
    fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.begin_stream(total_symbols)
    }

    /// Forward stream end.
    #[inline]
    fn finish_stream(&mut self) -> Result<(), String> {
        self.predictor.finish_stream()
    }

    /// Log-probability of `symbol` without mutating the predictor's model.
    #[inline]
    fn log_prob(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob(symbol)
    }

    /// Log-probability of `symbol`, then advance the predictor on it.
    #[inline]
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob_update(symbol)
    }

    /// Advance the predictor on `symbol` (no probability returned).
    #[inline]
    fn update(&mut self, symbol: u8) {
        self.predictor.update(symbol);
    }

    /// Reset the predictor for a frozen re-run of the stream.
    #[inline]
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.reset_frozen(total_symbols)
    }

    /// Frozen history advance (see `update_frozen` on the predictor).
    #[inline]
    fn update_frozen(&mut self, symbol: u8) {
        self.predictor.update_frozen(symbol);
    }
}
1581
/// Classic Bayesian mixture over experts: posterior weights are multiplied
/// by each expert's likelihood and renormalized every step.
#[derive(Clone)]
pub struct BayesMixture {
    experts: Vec<ExpertState>,
    /// Per-expert log P(symbol) from the most recent evaluation; the cached
    /// fast path in `step` relies on this surviving `predict_log_prob`.
    scratch_logps: Vec<f64>,
    /// Per-expert log_weight + log P(symbol) joint terms.
    scratch_mix: Vec<f64>,
    /// Memo of the last `predict_log_prob` call (symbol + result) so a
    /// following `step` on the same symbol skips re-evaluation.
    cached_symbol: u8,
    cached_log_mix: f64,
    cache_valid: bool,
    /// Total code length in nats accumulated by `step`.
    total_log_loss: f64,
}
1593
impl BayesMixture {
    /// Build the mixture; initial log weights are the configs' log priors
    /// normalized to sum to one (via log-sum-exp).
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Code `symbol`: returns the mixture log-probability, advances every
    /// expert, and applies the Bayesian weight update
    /// `w_i <- w_i * p_i / p_mix` (in log domain).
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Fast path: `predict_log_prob` already filled `scratch_logps`
            // for this symbol; only advance the predictors.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            logsumexp(&self.scratch_mix)
        };
        // Posterior update; subtracting log_mix keeps weights normalized.
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.log_weight = expert.log_weight + self.scratch_logps[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Evaluate the mixture log-probability of `symbol` without advancing
    /// any expert; memoizes the result for a following `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Fill `out` with the full next-byte log distribution of the mixture,
    /// with weights renormalized defensively via `logsumexp_weights`.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let norm = logsumexp_weights(&self.experts);
        let mut row = [0.0f64; 256];
        for expert in &mut self.experts {
            expert.predictor.fill_log_probs(&mut row);
            let lw = expert.log_weight - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Current posterior weights in linear domain (sums to ~1).
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and value of the expert with the smallest cumulative log loss.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and (linear) probability of the highest-posterior expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total code length accumulated by `step`, in nats.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Per-expert (name, cumulative log loss) pairs.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Names of all experts, in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Reset every expert for a frozen re-run; weights are left as-is.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Frozen history advance across all experts.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1744
/// Bayesian mixture with exponential discounting: weights are raised to the
/// power `decay` before each update, so old evidence fades and the mixture
/// can track regime changes. `decay == 1.0` reduces to `BayesMixture`.
#[derive(Clone)]
pub struct FadingBayesMixture {
    experts: Vec<ExpertState>,
    /// Discount exponent applied to log weights, clamped to [0, 1] in `new`.
    decay: f64,
    /// Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    /// Work buffer for decayed weights / joint terms.
    scratch_mix: Vec<f64>,
    /// Memo of the last `predict_log_prob` call so `step` can skip work.
    cached_symbol: u8,
    cached_log_predictive: f64,
    cached_log_evidence: f64,
    cache_valid: bool,
    /// Total code length in nats accumulated by `step`.
    total_log_loss: f64,
}
1760
1761impl FadingBayesMixture {
1762 pub fn new(configs: &[ExpertConfig], decay: f64) -> Self {
1764 let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
1765 let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
1766 let norm = logsumexp(&log_priors);
1767 for e in &mut experts {
1768 e.log_weight -= norm;
1769 }
1770 let decay = decay.clamp(0.0, 1.0);
1771 Self {
1772 experts,
1773 decay,
1774 scratch_logps: vec![0.0; configs.len()],
1775 scratch_mix: vec![0.0; configs.len()],
1776 cached_symbol: 0,
1777 cached_log_predictive: f64::NEG_INFINITY,
1778 cached_log_evidence: f64::NEG_INFINITY,
1779 cache_valid: false,
1780 total_log_loss: 0.0,
1781 }
1782 }
1783
1784 pub fn step(&mut self, symbol: u8) -> f64 {
1786 if self.experts.is_empty() {
1787 return f64::NEG_INFINITY;
1788 }
1789 let (log_predictive, log_evidence) = if self.cache_valid && self.cached_symbol == symbol {
1790 for (i, expert) in self.experts.iter_mut().enumerate() {
1791 expert.cum_log_loss -= self.scratch_logps[i];
1792 expert.update(symbol);
1793 }
1794 (self.cached_log_predictive, self.cached_log_evidence)
1795 } else {
1796 for (i, expert) in self.experts.iter_mut().enumerate() {
1797 self.scratch_logps[i] = expert.log_prob_update(symbol);
1798 self.scratch_mix[i] = self.decay * expert.log_weight;
1799 }
1800 let log_prior_norm = logsumexp(&self.scratch_mix);
1801 for (i, expert) in self.experts.iter_mut().enumerate() {
1802 self.scratch_mix[i] += self.scratch_logps[i];
1803 expert.cum_log_loss -= self.scratch_logps[i];
1804 }
1805 let log_evidence = logsumexp(&self.scratch_mix);
1806 (log_evidence - log_prior_norm, log_evidence)
1807 };
1808 for (i, expert) in self.experts.iter_mut().enumerate() {
1809 let decayed = self.decay * expert.log_weight;
1810 expert.log_weight = decayed + self.scratch_logps[i] - log_evidence;
1811 }
1812 self.cache_valid = false;
1813 self.total_log_loss -= log_predictive;
1814 log_predictive
1815 }
1816
1817 fn predict_log_prob(&mut self, symbol: u8) -> f64 {
1818 if self.experts.is_empty() {
1819 return f64::NEG_INFINITY;
1820 }
1821 for (i, expert) in self.experts.iter_mut().enumerate() {
1822 self.scratch_logps[i] = expert.log_prob(symbol);
1823 self.scratch_mix[i] = self.decay * expert.log_weight;
1824 }
1825 let log_prior_norm = logsumexp(&self.scratch_mix);
1826 for i in 0..self.experts.len() {
1827 self.scratch_mix[i] += self.scratch_logps[i];
1828 }
1829 let log_evidence = logsumexp(&self.scratch_mix);
1830 let log_predictive = log_evidence - log_prior_norm;
1831 self.cached_symbol = symbol;
1832 self.cached_log_predictive = log_predictive;
1833 self.cached_log_evidence = log_evidence;
1834 self.cache_valid = true;
1835 log_predictive
1836 }
1837
1838 fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
1839 if self.experts.is_empty() {
1840 out.fill(f64::NEG_INFINITY);
1841 return;
1842 }
1843 out.fill(f64::NEG_INFINITY);
1844 let mut decayed = Vec::with_capacity(self.experts.len());
1845 for expert in &self.experts {
1846 decayed.push(self.decay * expert.log_weight);
1847 }
1848 let norm = logsumexp(&decayed);
1849 let mut row = [0.0f64; 256];
1850 for (i, expert) in self.experts.iter_mut().enumerate() {
1851 expert.predictor.fill_log_probs(&mut row);
1852 let lw = decayed[i] - norm;
1853 for b in 0..256 {
1854 out[b] = logsumexp2(out[b], lw + row[b]);
1855 }
1856 }
1857 }
1858
1859 pub fn posterior(&self) -> Vec<f64> {
1861 let norm = logsumexp_weights(&self.experts);
1862 self.experts
1863 .iter()
1864 .map(|e| (e.log_weight - norm).exp())
1865 .collect()
1866 }
1867
1868 pub fn min_expert_log_loss(&self) -> (usize, f64) {
1870 let mut best_idx = 0usize;
1871 let mut best_loss = f64::INFINITY;
1872 for (i, e) in self.experts.iter().enumerate() {
1873 if e.cum_log_loss < best_loss {
1874 best_loss = e.cum_log_loss;
1875 best_idx = i;
1876 }
1877 }
1878 (best_idx, best_loss)
1879 }
1880
1881 pub fn total_log_loss(&self) -> f64 {
1883 self.total_log_loss
1884 }
1885
1886 pub fn expert_names(&self) -> Vec<String> {
1888 self.experts.iter().map(|e| e.name.clone()).collect()
1889 }
1890
1891 fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
1892 for expert in &mut self.experts {
1893 expert.reset_frozen(total_symbols)?;
1894 }
1895 self.cache_valid = false;
1896 self.total_log_loss = 0.0;
1897 Ok(())
1898 }
1899
1900 fn update_frozen(&mut self, symbol: u8) {
1901 for expert in &mut self.experts {
1902 expert.update_frozen(symbol);
1903 }
1904 self.cache_valid = false;
1905 }
1906}
1907
/// Fixed-share-style switching mixture: after each Bayes update a fraction
/// `alpha` of every expert's mass is redistributed to the other experts in
/// proportion to their priors, letting the mixture switch between experts.
#[derive(Clone)]
pub struct SwitchingMixture {
    experts: Vec<ExpertState>,
    /// Normalized linear prior weights (switch-in targets).
    prior: Vec<f64>,
    /// Base switching rate; may be modulated per-step by `schedule`.
    alpha: f64,
    schedule: MixtureScheduleMode,
    /// Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    /// Per-expert joint terms log_weight + log P(symbol); later reused for
    /// linear post-switch weights.
    scratch_joint: Vec<f64>,
    /// Linear posterior weights derived from `scratch_joint`.
    scratch_weights: Vec<f64>,
    /// Memo of the last `predict_log_prob` call so `step` can skip work.
    cached_symbol: u8,
    cached_log_mix: f64,
    cache_valid: bool,
    /// Total code length in nats accumulated by `step`.
    total_log_loss: f64,
    /// Number of `step` calls so far (drives the alpha schedule).
    update_count: u64,
}
1924
impl SwitchingMixture {
    /// Build the switching mixture; priors come from the configs via
    /// `normalized_prior_weights` and also seed the initial log weights.
    pub fn new(configs: &[ExpertConfig], alpha: f64, schedule: MixtureScheduleMode) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let prior = normalized_prior_weights(configs);
        set_log_weights_from_linear(&mut experts, &prior);
        Self {
            experts,
            prior,
            alpha,
            schedule,
            scratch_logps: vec![0.0; configs.len()],
            scratch_joint: vec![0.0; configs.len()],
            scratch_weights: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
            update_count: 0,
        }
    }

    /// Code `symbol`: Bayes-update the weights, then redistribute a fraction
    /// `alpha` of the posterior mass across experts (fixed-share transition).
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Fast path: `predict_log_prob` already filled `scratch_logps`
            // and `scratch_joint` for this symbol; only advance predictors.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_logps[i];
                self.scratch_joint[i] = expert.log_weight + self.scratch_logps[i];
            }
            logsumexp(&self.scratch_joint)
        };

        // Linear-domain Bayes posterior.
        for i in 0..self.experts.len() {
            self.scratch_weights[i] = (self.scratch_joint[i] - log_mix).exp();
        }

        let alpha = switching_alpha_for_update(self.schedule, self.alpha, self.update_count);
        self.update_count = self.update_count.saturating_add(1);

        if self.experts.len() == 1 || alpha <= 0.0 {
            // No switching possible/requested: keep the pure posterior.
            set_log_weights_from_linear(&mut self.experts, &self.scratch_weights);
        } else {
            let mut switch_out_sum = 0.0;
            // Experts with prior == 1.0 cannot be switched out of (denominator
            // would be zero); count how many real switch targets exist.
            let mut num_switch_targets = 0usize;
            for &prior in &self.prior {
                if prior < 1.0 {
                    num_switch_targets += 1;
                }
            }

            if num_switch_targets <= 1 {
                set_log_weights_from_linear(&mut self.experts, &self.scratch_weights);
            } else {
                // Pre-aggregate w_j / (1 - prior_j); subtracting the own term
                // below yields each expert's incoming switchable mass.
                for i in 0..self.experts.len() {
                    let denom = 1.0 - self.prior[i];
                    if denom > 0.0 {
                        switch_out_sum += self.scratch_weights[i] / denom;
                    }
                }

                for i in 0..self.experts.len() {
                    // Mass that stays with expert i...
                    let stay = (1.0 - alpha) * self.scratch_weights[i];
                    // ...plus mass switching in from all other experts,
                    // allocated proportionally to expert i's prior.
                    let switch_in = if self.prior[i] > 0.0 {
                        let denom = 1.0 - self.prior[i];
                        let switchable_mass = if denom > 0.0 {
                            switch_out_sum - self.scratch_weights[i] / denom
                        } else {
                            0.0
                        };
                        alpha * self.prior[i] * switchable_mass
                    } else {
                        0.0
                    };
                    self.scratch_joint[i] = stay + switch_in;
                }
                normalize_simplex_weights(&mut self.scratch_joint);
                set_log_weights_from_linear(&mut self.experts, &self.scratch_joint);
            }
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Evaluate the mixture log-probability of `symbol` without advancing
    /// any expert; memoizes logps and joint terms for a following `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for i in 0..self.experts.len() {
            let lp = self.experts[i].log_prob(symbol);
            self.scratch_logps[i] = lp;
            self.scratch_joint[i] = self.experts[i].log_weight + lp;
        }
        let log_mix = logsumexp(&self.scratch_joint);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Fill `out` with the full next-byte log distribution of the mixture.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let norm = logsumexp_weights(&self.experts);
        let mut row = [0.0f64; 256];
        for expert in &mut self.experts {
            expert.predictor.fill_log_probs(&mut row);
            let lw = expert.log_weight - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Current posterior weights in linear domain (sums to ~1).
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and value of the expert with the smallest cumulative log loss.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and (linear) probability of the highest-posterior expert.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total code length accumulated by `step`, in nats.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Per-expert (name, cumulative log loss) pairs.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Names of all experts, in order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Reset every expert for a frozen re-run; restarts the alpha schedule.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        self.update_count = 0;
        Ok(())
    }

    /// Frozen history advance across all experts.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
2124
/// Convex (linear-domain) mixture trained by projected gradient descent on
/// the log loss: probabilities are mixed as `sum(lambda_i * p_i)` and the
/// simplex weights `lambda` are updated each step.
#[derive(Clone)]
pub struct ConvexMixture {
    experts: Vec<ExpertState>,
    /// Base learning rate; may be modulated per-step by `schedule`.
    alpha: f64,
    schedule: MixtureScheduleMode,
    /// Simplex mixing weights (non-negative, sum to one).
    lambda: Vec<f64>,
    /// Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    /// Reusable buffer for the simplex projection.
    projection_scratch: Vec<f64>,
    /// Memo of the last `predict_log_prob` call so `step` can skip work.
    cached_symbol: u8,
    cached_log_mix: f64,
    cache_valid: bool,
    /// Total code length in nats accumulated by `step`.
    total_log_loss: f64,
    /// Number of `step` calls so far (drives the step-size schedule).
    update_count: u64,
}
2140
impl ConvexMixture {
    /// Build the convex mixture; `lambda` starts at the normalized priors.
    pub fn new(configs: &[ExpertConfig], alpha: f64, schedule: MixtureScheduleMode) -> Self {
        Self {
            experts: configs.iter().map(|c| c.build()).collect(),
            alpha,
            schedule,
            lambda: normalized_prior_weights(configs),
            scratch_logps: vec![0.0; configs.len()],
            projection_scratch: Vec::with_capacity(configs.len()),
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
            update_count: 0,
        }
    }

    /// Linear-domain mixture: log(sum_i lambda_i * exp(logp_i)), clamped to
    /// the global probability floor before taking the log.
    fn mix_log_prob(&self, logps: &[f64]) -> f64 {
        let mut mix = 0.0;
        for (weight, &logp) in self.lambda.iter().zip(logps.iter()) {
            if *weight > 0.0 {
                mix += *weight * logp.exp();
            }
        }
        clamp_prob(mix, DEFAULT_MIN_PROB).ln()
    }

    /// Code `symbol`: returns the mixture log-probability, advances every
    /// expert, and takes one projected-gradient step on `lambda`.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }

        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Fast path: `predict_log_prob` already filled `scratch_logps`
            // for this symbol; only advance the predictors.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            self.mix_log_prob(&self.scratch_logps)
        };

        self.update_count = self.update_count.saturating_add(1);
        let step_size = convex_step_size_for_update(self.schedule, self.alpha, self.update_count);
        // d(-log p_mix)/d(lambda_i) = -p_i / p_mix; descend, then project
        // back onto the probability simplex.
        for (weight, &logp) in self.lambda.iter_mut().zip(self.scratch_logps.iter()) {
            let grad = -(logp - log_mix).exp();
            *weight -= step_size * grad;
        }
        project_simplex_with_scratch(&mut self.lambda, &mut self.projection_scratch);
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Evaluate the mixture log-probability of `symbol` without advancing
    /// any expert; memoizes the result for a following `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
        }
        let log_mix = self.mix_log_prob(&self.scratch_logps);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Fill `out` with the full next-byte log distribution; zero-weight
    /// experts are skipped entirely.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let mut row = [0.0f64; 256];
        for (index, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let weight = self.lambda.get(index).copied().unwrap_or(0.0);
            if weight <= 0.0 {
                continue;
            }
            let log_weight = weight.ln();
            for byte in 0..256 {
                out[byte] = logsumexp2(out[byte], log_weight + row[byte]);
            }
        }
    }

    /// Reset every expert for a frozen re-run; restarts the schedule but
    /// keeps the learned `lambda`.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        self.update_count = 0;
        Ok(())
    }

    /// Frozen history advance across all experts.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
2252
/// MDL-style selector state: runs all experts and tracks the one with the
/// best (smallest) cumulative code length.
#[derive(Clone)]
pub struct MdlSelector {
    experts: Vec<ExpertState>,
    /// Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    /// Total code length in nats accumulated so far.
    total_log_loss: f64,
    /// Index of the best expert as of the last update.
    last_best: usize,
    /// Memo of the last prediction (symbol, winning expert, its log-prob).
    cached_symbol: u8,
    cached_best_idx: usize,
    cached_best_logp: f64,
    cache_valid: bool,
}
2265
/// Mixture whose per-expert weights come from a learned neural mixer
/// (`NeuralMixCore`) conditioned on a text-context analyzer state.
#[derive(Clone)]
pub struct NeuralMixture {
    experts: Vec<ExpertState>,
    /// Learned mixing network.
    neural: NeuralMixCore,
    /// Produces the context/history state the mixer conditions on.
    analyzer: TextContextAnalyzer,
    /// Probability floor applied when mixing/clamping.
    min_prob: f64,
    /// Per-expert log P(symbol) from the most recent evaluation.
    scratch_expert_logps: Vec<f64>,
    /// Per-expert linear mixing weights from the last full evaluation.
    scratch_mix_weights: Vec<f64>,
    // Two cache tiers, both keyed on `eval_cache_history`:
    /// Single-symbol cache (`eval_cache_symbol` / `eval_cache_logp`).
    eval_cache_valid: bool,
    /// Full 256-way cache (`eval_cache_mix_logps` + per-expert rows).
    eval_cache_full_valid: bool,
    /// History state the caches were computed under.
    eval_cache_history: NeuralHistoryState,
    eval_cache_symbol: u8,
    eval_cache_logp: f64,
    eval_cache_mix_logps: [f64; 256],
    eval_cache_expert_logps: Vec<[f64; 256]>,
    /// Total code length in nats accumulated by `step`.
    total_log_loss: f64,
}
2290
2291impl NeuralMixture {
    /// Build the neural mixture. The configs' log priors seed the mixer's
    /// initial linear weights; `learning_rate` is sanitized (non-finite
    /// falls back to 0.03) and scaled by 25x before being handed to the core.
    pub fn new(configs: &[ExpertConfig], learning_rate: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let n = experts.len();

        // Convert normalized log priors into linear prior weights.
        let mut prior_weights = vec![0.0; n];
        if n > 0 {
            let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
            let norm = logsumexp(&log_priors);
            for (i, e) in experts.iter_mut().enumerate() {
                let p = (e.log_prior - norm).exp();
                prior_weights[i] = p;
            }
        }

        let base_lr = if learning_rate.is_finite() {
            learning_rate.abs().clamp(1e-6, 1.0)
        } else {
            0.03
        };
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        // NOTE(review): the two rate arguments (effective_lr * 0.5,
        // effective_lr) and the 1e-5 constant follow NeuralMixCore::new's
        // parameter order — confirm their meaning against that signature.
        let mut neural =
            NeuralMixCore::new(n, &prior_weights, effective_lr * 0.5, effective_lr, 1e-5);
        neural.set_context_state(analyzer.state());
        let eval_cache_history = neural.history_state();

        Self {
            experts,
            neural,
            analyzer,
            min_prob: DEFAULT_MIN_PROB,
            scratch_expert_logps: vec![0.0; n],
            scratch_mix_weights: vec![0.0; n],
            eval_cache_valid: false,
            eval_cache_full_valid: false,
            eval_cache_history,
            eval_cache_symbol: 0,
            eval_cache_logp: f64::NEG_INFINITY,
            eval_cache_mix_logps: [f64::NEG_INFINITY; 256],
            eval_cache_expert_logps: vec![[f64::NEG_INFINITY; 256]; n],
            total_log_loss: 0.0,
        }
    }
2336
2337 #[inline]
2338 fn invalidate_eval_cache(&mut self) {
2339 self.eval_cache_valid = false;
2340 self.eval_cache_full_valid = false;
2341 }
2342
    /// Propagate the analyzer's current history state into the mixer, and
    /// invalidate the evaluation caches if the state moved since they were
    /// built. Returns the current state.
    fn sync_history_state(&mut self) -> NeuralHistoryState {
        let history = self.analyzer.state();
        if self.neural.history_state() != history {
            self.neural.set_context_state(history);
        }
        if self.eval_cache_history != history {
            self.invalidate_eval_cache();
            self.eval_cache_history = history;
        }
        history
    }
2354
    /// Populate the full 256-way cache: query the mixer's expert weights,
    /// fetch every expert's full log distribution, mix linearly, then
    /// normalize and clamp back into log domain.
    fn ensure_full_evaluation(&mut self) {
        self.sync_history_state();
        if self.eval_cache_full_valid {
            return;
        }

        self.neural.evaluate_expert_weights();
        self.scratch_mix_weights
            .copy_from_slice(self.neural.expert_weights());
        let mut mix_pdf = [0.0f64; 256];
        for i in 0..self.experts.len() {
            let row = &mut self.eval_cache_expert_logps[i];
            self.experts[i].predictor.fill_log_probs(row);
            let w = self.scratch_mix_weights[i];
            for (dst, &lp) in mix_pdf.iter_mut().zip(row.iter()) {
                *dst += w * clamp_prob(lp.exp(), self.min_prob);
            }
        }

        // Renormalize; fall back to a uniform distribution on degenerate
        // (non-finite or non-positive) mass.
        let sum: f64 = mix_pdf.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let uniform = (1.0f64 / 256.0).ln();
            self.eval_cache_mix_logps.fill(uniform);
        } else {
            let inv = 1.0 / sum;
            for (dst, &p_raw) in self.eval_cache_mix_logps.iter_mut().zip(mix_pdf.iter()) {
                let p = clamp_unit_prob(p_raw * inv, self.min_prob);
                *dst = p.ln();
            }
        }

        self.eval_cache_full_valid = true;
    }
2388
    /// Mixture log-probability of `symbol` for the current history state,
    /// served from (in order of preference): the single-symbol cache, the
    /// full-distribution cache, or a fresh mixer evaluation. Always leaves
    /// `scratch_expert_logps` and the single-symbol cache populated.
    fn evaluate_symbol(&mut self, symbol: u8) -> f64 {
        let history = self.sync_history_state();
        // Tier 1: exact repeat of the previous (history, symbol) query.
        if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            return self.eval_cache_logp;
        }

        // Tier 2: slice the answer out of the full 256-way cache.
        if self.eval_cache_full_valid && self.eval_cache_history == history {
            for (dst, row) in self
                .scratch_expert_logps
                .iter_mut()
                .zip(self.eval_cache_expert_logps.iter())
            {
                *dst = row[symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            self.eval_cache_valid = true;
            self.eval_cache_symbol = symbol;
            self.eval_cache_logp = logp;
            return logp;
        }

        // Tier 3: fresh per-symbol evaluation through the mixer.
        let expert_count = self.experts.len();
        for i in 0..expert_count {
            self.scratch_expert_logps[i] = self.experts[i].log_prob(symbol);
        }
        let p = self
            .neural
            .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
        let logp = clamp_unit_prob(p, self.min_prob).ln();
        self.eval_cache_valid = true;
        self.eval_cache_history = history;
        self.eval_cache_symbol = symbol;
        self.eval_cache_logp = logp;
        logp
    }
2427
2428 fn predict_log_prob(&mut self, symbol: u8) -> f64 {
2429 if self.experts.is_empty() {
2430 return f64::NEG_INFINITY;
2431 }
2432 if self.experts.len() == 1 {
2433 return self.experts[0].log_prob(symbol);
2434 }
2435 self.evaluate_symbol(symbol)
2436 }
2437
2438 fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
2439 if self.experts.is_empty() {
2440 out.fill(f64::NEG_INFINITY);
2441 return;
2442 }
2443 if self.experts.len() == 1 {
2444 self.experts[0].predictor.fill_log_probs(out);
2445 return;
2446 }
2447 self.ensure_full_evaluation();
2448 out.copy_from_slice(&self.eval_cache_mix_logps);
2449 }
2450
    /// Consumes `symbol`: returns the mixture's log probability for it and
    /// advances every expert, the mixer weights, and the running totals.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }

        if self.experts.len() == 1 {
            // Single expert: no mixing, but losses and context still advance.
            let expert = &mut self.experts[0];
            let logp = expert.log_prob_update(symbol);
            expert.cum_log_loss -= logp;
            self.total_log_loss -= logp;
            self.analyzer.update(symbol);
            self.neural.set_context_state(self.analyzer.state());
            self.invalidate_eval_cache();
            return logp;
        }

        let history = self.sync_history_state();
        let logp = if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            // Single-symbol cache hit: scratch_expert_logps was filled by the
            // evaluation that populated the cache, so only model updates run.
            let logp = self.eval_cache_logp;
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else if self.eval_cache_full_valid && self.eval_cache_history == history {
            // Full-row cache hit: read each expert's value for this symbol.
            for i in 0..self.experts.len() {
                self.scratch_expert_logps[i] = self.eval_cache_expert_logps[i][symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else {
            // Cache miss: fused query+update per expert, then mix.
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                self.scratch_expert_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_expert_logps[i];
            }
            let p = self
                .neural
                .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
            clamp_unit_prob(p, self.min_prob).ln()
        };
        // Weight learning uses the same per-expert logps that produced logp.
        self.neural
            .update_weights_symbol(&self.scratch_expert_logps, self.min_prob);
        self.total_log_loss -= logp;
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        logp
    }
2510
    /// Cumulative negative log likelihood (natural log) over all stepped symbols.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }
2515
2516 fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2517 for expert in &mut self.experts {
2518 expert.reset_frozen(total_symbols)?;
2519 }
2520 self.analyzer = TextContextAnalyzer::new();
2521 self.neural.set_context_state(self.analyzer.state());
2522 self.invalidate_eval_cache();
2523 self.eval_cache_history = self.neural.history_state();
2524 self.total_log_loss = 0.0;
2525 Ok(())
2526 }
2527
2528 fn update_frozen(&mut self, symbol: u8) {
2529 for expert in &mut self.experts {
2530 expert.update_frozen(symbol);
2531 }
2532 self.analyzer.update(symbol);
2533 self.neural.set_context_state(self.analyzer.state());
2534 self.invalidate_eval_cache();
2535 self.eval_cache_history = self.neural.history_state();
2536 }
2537}
2538
2539impl MdlSelector {
2540 pub fn new(configs: &[ExpertConfig]) -> Self {
2542 let experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
2543 let last_best = 0usize;
2544 Self {
2545 experts,
2546 scratch_logps: vec![0.0; configs.len()],
2547 total_log_loss: 0.0,
2548 last_best,
2549 cached_symbol: 0,
2550 cached_best_idx: 0,
2551 cached_best_logp: f64::NEG_INFINITY,
2552 cache_valid: false,
2553 }
2554 }
2555
2556 pub fn step(&mut self, symbol: u8) -> f64 {
2558 if self.experts.is_empty() {
2559 return f64::NEG_INFINITY;
2560 }
2561 let used_cache = self.cache_valid && self.cached_symbol == symbol;
2562 let best_idx = if used_cache {
2563 self.scratch_logps[self.cached_best_idx] = self.cached_best_logp;
2564 for (i, expert) in self.experts.iter_mut().enumerate() {
2565 if i == self.cached_best_idx {
2566 continue;
2567 }
2568 self.scratch_logps[i] = expert.log_prob(symbol);
2569 }
2570 self.cached_best_idx
2571 } else {
2572 for (i, expert) in self.experts.iter_mut().enumerate() {
2573 self.scratch_logps[i] = expert.log_prob_update(symbol);
2574 }
2575 let mut best_idx = 0usize;
2576 let mut best_loss = f64::INFINITY;
2577 for (i, expert) in self.experts.iter().enumerate() {
2578 if expert.cum_log_loss < best_loss {
2579 best_loss = expert.cum_log_loss;
2580 best_idx = i;
2581 }
2582 }
2583 best_idx
2584 };
2585 let logp = self.scratch_logps[best_idx];
2586 self.cache_valid = false;
2587 for (i, expert) in self.experts.iter_mut().enumerate() {
2588 expert.cum_log_loss -= self.scratch_logps[i];
2589 if used_cache {
2590 expert.update(symbol);
2591 }
2592 }
2593 self.total_log_loss -= logp;
2594 self.last_best = best_idx;
2595 logp
2596 }
2597
2598 fn predict_log_prob(&mut self, symbol: u8) -> f64 {
2599 if self.experts.is_empty() {
2600 return f64::NEG_INFINITY;
2601 }
2602 let mut best_idx = 0usize;
2603 let mut best_loss = f64::INFINITY;
2604 for (i, expert) in self.experts.iter().enumerate() {
2605 if expert.cum_log_loss < best_loss {
2606 best_loss = expert.cum_log_loss;
2607 best_idx = i;
2608 }
2609 }
2610 let logp = self.experts[best_idx].log_prob(symbol);
2611 self.cached_symbol = symbol;
2612 self.cached_best_idx = best_idx;
2613 self.cached_best_logp = logp;
2614 self.cache_valid = true;
2615 logp
2616 }
2617
2618 fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
2619 if self.experts.is_empty() {
2620 out.fill(f64::NEG_INFINITY);
2621 return;
2622 }
2623 let mut best_idx = 0usize;
2624 let mut best_loss = f64::INFINITY;
2625 for (i, expert) in self.experts.iter().enumerate() {
2626 if expert.cum_log_loss < best_loss {
2627 best_loss = expert.cum_log_loss;
2628 best_idx = i;
2629 }
2630 }
2631 self.experts[best_idx].predictor.fill_log_probs(out);
2632 }
2633
2634 pub fn best_index(&self) -> usize {
2636 self.last_best
2637 }
2638
2639 pub fn min_expert_log_loss(&self) -> (usize, f64) {
2641 let mut best_idx = 0usize;
2642 let mut best_loss = f64::INFINITY;
2643 for (i, e) in self.experts.iter().enumerate() {
2644 if e.cum_log_loss < best_loss {
2645 best_loss = e.cum_log_loss;
2646 best_idx = i;
2647 }
2648 }
2649 (best_idx, best_loss)
2650 }
2651
2652 pub fn total_log_loss(&self) -> f64 {
2654 self.total_log_loss
2655 }
2656
2657 pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
2659 self.experts
2660 .iter()
2661 .map(|e| (e.name.clone(), e.cum_log_loss))
2662 .collect()
2663 }
2664
2665 pub fn expert_names(&self) -> Vec<String> {
2667 self.experts.iter().map(|e| e.name.clone()).collect()
2668 }
2669
2670 fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
2671 for expert in &mut self.experts {
2672 expert.reset_frozen(total_symbols)?;
2673 }
2674 self.cache_valid = false;
2675 self.total_log_loss = 0.0;
2676 Ok(())
2677 }
2678
2679 fn update_frozen(&mut self, symbol: u8) {
2680 for expert in &mut self.experts {
2681 expert.update_frozen(symbol);
2682 }
2683 self.cache_valid = false;
2684 }
2685}
2686
/// Runtime dispatch over the available mixture strategies.
///
/// NOTE(review): variants hold their mixture state inline (no boxing), hence
/// the `large_enum_variant` allow — the enum is as large as its biggest
/// variant, which the callers presumably accept for dispatch speed.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum MixtureRuntime {
    /// Bayesian model averaging over the experts.
    Bayes(BayesMixture),
    /// Bayesian mixture with a decay parameter (fading posterior).
    Fading(FadingBayesMixture),
    /// Fixed-share style switching mixture (see tests for the update rule).
    Switching(SwitchingMixture),
    /// Convex-combination weights updated with a step size.
    Convex(ConvexMixture),
    /// MDL winner-take-all selection of the best expert so far.
    Mdl(MdlSelector),
    /// Learned neural mixing of expert log probabilities.
    Neural(NeuralMixture),
}
2708
impl MixtureRuntime {
    /// Propagates stream start (with optional known length) to every expert.
    pub(crate) fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Fading(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Switching(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Convex(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Mdl(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Neural(m) => begin_expert_stream(&mut m.experts, total_symbols),
        }
    }

    /// Propagates end-of-stream to every expert.
    pub(crate) fn finish_stream(&mut self) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Fading(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Switching(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Convex(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Mdl(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Neural(m) => finish_expert_stream(&mut m.experts),
        }
    }

    /// Dispatches `reset_frozen` to the active mixture implementation.
    pub(crate) fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Fading(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Switching(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Convex(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Mdl(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Neural(m) => m.reset_frozen(total_symbols),
        }
    }

    /// Non-consuming log probability for `symbol` (prediction only).
    pub(crate) fn peek_log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Fading(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Switching(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Convex(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Mdl(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Neural(m) => m.predict_log_prob(symbol),
        }
    }

    /// Consumes `symbol` and returns its log probability under the mixture.
    pub(crate) fn step(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.step(symbol),
            MixtureRuntime::Fading(m) => m.step(symbol),
            MixtureRuntime::Switching(m) => m.step(symbol),
            MixtureRuntime::Convex(m) => m.step(symbol),
            MixtureRuntime::Mdl(m) => m.step(symbol),
            MixtureRuntime::Neural(m) => m.step(symbol),
        }
    }

    /// Dispatches a frozen-mode update to the active mixture implementation.
    pub(crate) fn update_frozen(&mut self, symbol: u8) {
        match self {
            MixtureRuntime::Bayes(m) => m.update_frozen(symbol),
            MixtureRuntime::Fading(m) => m.update_frozen(symbol),
            MixtureRuntime::Switching(m) => m.update_frozen(symbol),
            MixtureRuntime::Convex(m) => m.update_frozen(symbol),
            MixtureRuntime::Mdl(m) => m.update_frozen(symbol),
            MixtureRuntime::Neural(m) => m.update_frozen(symbol),
        }
    }

    /// Fills `out` with log probabilities for all 256 symbols.
    pub(crate) fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        match self {
            MixtureRuntime::Bayes(m) => m.fill_log_probs(out),
            MixtureRuntime::Fading(m) => m.fill_log_probs(out),
            MixtureRuntime::Switching(m) => m.fill_log_probs(out),
            MixtureRuntime::Convex(m) => m.fill_log_probs(out),
            MixtureRuntime::Mdl(m) => m.fill_log_probs(out),
            MixtureRuntime::Neural(m) => m.fill_log_probs(out),
        }
    }
}
2789
2790fn begin_expert_stream(
2791 experts: &mut [ExpertState],
2792 total_symbols: Option<u64>,
2793) -> Result<(), String> {
2794 for expert in experts {
2795 expert.begin_stream(total_symbols)?;
2796 }
2797 Ok(())
2798}
2799
2800fn finish_expert_stream(experts: &mut [ExpertState]) -> Result<(), String> {
2801 for expert in experts {
2802 expert.finish_stream()?;
2803 }
2804 Ok(())
2805}
2806
2807pub(crate) fn build_mixture_runtime(
2808 spec: &MixtureSpec,
2809 experts: &[ExpertConfig],
2810) -> Result<MixtureRuntime, String> {
2811 spec.validate()?;
2812 match spec.kind {
2813 MixtureKind::Bayes => Ok(MixtureRuntime::Bayes(BayesMixture::new(experts))),
2814 MixtureKind::FadingBayes => {
2815 let decay = spec
2816 .decay
2817 .ok_or_else(|| "fading Bayes mixture requires decay".to_string())?;
2818 Ok(MixtureRuntime::Fading(FadingBayesMixture::new(
2819 experts, decay,
2820 )))
2821 }
2822 MixtureKind::Switching => Ok(MixtureRuntime::Switching(SwitchingMixture::new(
2823 experts,
2824 spec.alpha,
2825 spec.schedule,
2826 ))),
2827 MixtureKind::Convex => Ok(MixtureRuntime::Convex(ConvexMixture::new(
2828 experts,
2829 spec.alpha,
2830 spec.schedule,
2831 ))),
2832 MixtureKind::Mdl => Ok(MixtureRuntime::Mdl(MdlSelector::new(experts))),
2833 MixtureKind::Neural => Ok(MixtureRuntime::Neural(NeuralMixture::new(
2834 experts, spec.alpha,
2835 ))),
2836 }
2837}
2838
2839#[cfg(test)]
2840mod tests {
2841 use super::*;
2842 use std::sync::{
2843 Arc,
2844 atomic::{AtomicU64, AtomicUsize, Ordering},
2845 };
2846
    /// Test predictor that is certain the next byte is `byte`.
    #[derive(Clone)]
    struct AlwaysPredict {
        byte: u8,
    }

    impl OnlineBytePredictor for AlwaysPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            // ln(1) for the expected byte, ln(0) for anything else.
            if symbol == self.byte {
                0.0
            } else {
                f64::NEG_INFINITY
            }
        }

        fn update(&mut self, _symbol: u8) {}
    }
2863
    /// Test predictor assigning `prob_zero` to byte 0 and `1 - prob_zero` to
    /// every other byte. Not normalized over 256 symbols — the tests using it
    /// only exercise symbols 0 and 1.
    #[derive(Clone)]
    struct FixedProbPredict {
        prob_zero: f64,
    }

    impl OnlineBytePredictor for FixedProbPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            let p = if symbol == 0 {
                self.prob_zero
            } else {
                1.0 - self.prob_zero
            };
            p.ln()
        }

        fn update(&mut self, _symbol: u8) {}
    }
2881
    /// Expert config with prior `weight` (stored as log) wrapping a
    /// `FixedProbPredict` with the given `prob_zero`.
    fn weighted_cfg(name: &'static str, weight: f64, prob_zero: f64) -> ExpertConfig {
        ExpertConfig::new(name, weight.ln(), move || {
            Box::new(FixedProbPredict { prob_zero })
        })
    }
2887
    // After 10 observations of byte 0, the posterior should concentrate on
    // the expert that always predicts 0.
    #[test]
    fn bayes_mixture_prefers_correct_expert() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = BayesMixture::new(&configs);
        for _ in 0..10 {
            mix.step(0);
        }
        let post = mix.posterior();
        assert!(post[0] > 0.999);
        assert!(post[1] < 1e-6);
    }
2902
    /// Uniform-prior expert config wrapping a `CountingPredict` that tallies
    /// `log_prob` calls into `calls`.
    fn counting_cfg(name: &'static str, calls: Arc<AtomicUsize>) -> ExpertConfig {
        ExpertConfig::uniform(name, move || {
            Box::new(CountingPredict {
                calls: calls.clone(),
            })
        })
    }
2910
    // predict_log_prob queries each expert once; a following step for the
    // same symbol must reuse those values (no additional log_prob calls).
    #[test]
    fn bayes_predict_then_step_reuses_cached_log_probs() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mix = BayesMixture::new(&[
            counting_cfg("c0", c0.clone()),
            counting_cfg("c1", c1.clone()),
        ]);
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);
        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }
2926
    // Same cache-reuse contract as the Bayes test, for the fading mixture.
    #[test]
    fn fading_predict_then_step_reuses_cached_log_probs() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mix = FadingBayesMixture::new(
            &[
                counting_cfg("c0", c0.clone()),
                counting_cfg("c1", c1.clone()),
            ],
            0.95,
        );
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);
        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }
2945
    // Same cache-reuse contract, for the switching mixture.
    #[test]
    fn switching_predict_then_step_reuses_cached_log_probs() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mix = SwitchingMixture::new(
            &[
                counting_cfg("c0", c0.clone()),
                counting_cfg("c1", c1.clone()),
            ],
            0.05,
            MixtureScheduleMode::Default,
        );
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);
        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }
2965
    // With a uniform prior the switching posterior must follow the classic
    // fixed-share rule: w' = (1-alpha)*posterior + alpha*(other posterior).
    // Mixture prob of symbol 0 is 0.5*0.8 + 0.5*0.3 = 0.55.
    #[test]
    fn switching_mixture_matches_fixed_share_update_for_uniform_prior() {
        let configs = vec![weighted_cfg("a", 0.5, 0.8), weighted_cfg("b", 0.5, 0.3)];
        let alpha = 0.2;
        let mut mix = SwitchingMixture::new(&configs, alpha, MixtureScheduleMode::Default);

        let predicted = mix.predict_log_prob(0).exp();
        assert!((predicted - 0.55).abs() < 1e-12, "predicted={predicted}");

        let observed = mix.step(0).exp();
        assert!((observed - 0.55).abs() < 1e-12, "observed={observed}");

        let post = mix.posterior();
        let posterior_a = 0.5 * 0.8 / 0.55;
        let posterior_b = 0.5 * 0.3 / 0.55;
        let expected_a = (1.0 - alpha) * posterior_a + alpha * posterior_b;
        let expected_b = (1.0 - alpha) * posterior_b + alpha * posterior_a;
        assert!(
            (post[0] - expected_a).abs() < 1e-12 && (post[1] - expected_b).abs() < 1e-12,
            "expected [{expected_a}, {expected_b}], got {:?}",
            post
        );
    }
2989
    // Non-uniform prior: switched-in mass for expert j is alpha * prior[j]
    // times the sum over k != j of posterior[k] / (1 - prior[k]).
    #[test]
    fn switching_mixture_switches_according_to_prior_over_other_experts() {
        let configs = vec![
            weighted_cfg("a", 0.5, 0.75),
            weighted_cfg("b", 0.3, 0.25),
            weighted_cfg("c", 0.2, 0.60),
        ];
        let alpha = 0.15;
        let mut mix = SwitchingMixture::new(&configs, alpha, MixtureScheduleMode::Default);

        let _ = mix.step(0);
        let post = mix.posterior();

        // Recompute the expected posterior by hand from the same inputs.
        let current = [0.5_f64, 0.3, 0.2];
        let likelihood = [0.75_f64, 0.25, 0.60];
        let mix_prob = current
            .iter()
            .zip(likelihood.iter())
            .map(|(w, p)| w * p)
            .sum::<f64>();
        let posterior = [
            current[0] * likelihood[0] / mix_prob,
            current[1] * likelihood[1] / mix_prob,
            current[2] * likelihood[2] / mix_prob,
        ];
        let prior = [0.5_f64, 0.3, 0.2];
        let mut expected = [0.0_f64; 3];
        for j in 0..3 {
            let stay = (1.0 - alpha) * posterior[j];
            let switch_in = alpha
                * prior[j]
                * (0..3)
                    .filter(|&k| k != j)
                    .map(|k| posterior[k] / (1.0 - prior[k]))
                    .sum::<f64>();
            expected[j] = stay + switch_in;
        }

        for i in 0..3 {
            assert!(
                (post[i] - expected[i]).abs() < 1e-12,
                "expert {i}: expected {} got {}",
                expected[i],
                post[i]
            );
        }
    }
3037
    // Theorem schedule ignores the configured alpha and uses 1/(t+2):
    // 0.5 at t=0, 1/3 at t=1. With alpha=0.5 at t=0 both experts end equal.
    #[test]
    fn switching_theorem_schedule_uses_one_over_t() {
        assert!(
            (switching_alpha_for_update(MixtureScheduleMode::Theorem, 0.99, 0) - 0.5).abs() < 1e-12
        );
        assert!(
            (switching_alpha_for_update(MixtureScheduleMode::Theorem, 0.99, 1) - (1.0 / 3.0)).abs()
                < 1e-12
        );

        let configs = vec![weighted_cfg("a", 0.5, 0.8), weighted_cfg("b", 0.5, 0.3)];
        let mut mix = SwitchingMixture::new(&configs, 0.99, MixtureScheduleMode::Theorem);
        let _ = mix.step(0);
        let post = mix.posterior();
        let posterior_a = 0.5 * 0.8 / 0.55;
        let posterior_b = 0.5 * 0.3 / 0.55;
        let expected_a = 0.5 * posterior_a + 0.5 * posterior_b;
        let expected_b = expected_a;
        assert!((post[0] - expected_a).abs() < 1e-12);
        assert!((post[1] - expected_b).abs() < 1e-12);
    }
3059
    // Convex theorem schedule: step size from the paper's formula, then the
    // lambda weights move by eta * (likelihood-ratio - 1).
    // NOTE(review): the first assert claims eta equals DEFAULT_MIN_PROB for
    // these arguments — verify against convex_step_size_for_update.
    #[test]
    fn convex_theorem_schedule_uses_paper_step_size() {
        let eta = convex_step_size_for_update(MixtureScheduleMode::Theorem, 9.0, 1);
        assert!((eta - DEFAULT_MIN_PROB).abs() < 1e-18);

        let configs = vec![weighted_cfg("a", 0.5, 0.8), weighted_cfg("b", 0.5, 0.3)];
        let mut mix = ConvexMixture::new(&configs, 9.0, MixtureScheduleMode::Theorem);
        let observed = mix.step(0).exp();
        assert!((observed - 0.55).abs() < 1e-12, "observed={observed}");

        let expected = [
            0.5 + eta * ((0.8 / 0.55) - 1.0),
            0.5 + eta * ((0.3 / 0.55) - 1.0),
        ];
        assert!((mix.lambda[0] - expected[0]).abs() < 1e-12);
        assert!((mix.lambda[1] - expected[1]).abs() < 1e-12);
    }
3077
    // MDL predict queries only the leader (1 call); the following step
    // queries the remaining expert (total 2), reusing the cached value.
    #[test]
    fn mdl_predict_then_step_reuses_best_expert_log_prob() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mdl = MdlSelector::new(&[
            counting_cfg("c0", c0.clone()),
            counting_cfg("c1", c1.clone()),
        ]);
        let _ = mdl.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 1);
        let _ = mdl.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, 2);
    }
3093
    // The neural mixer should learn to favor the always-0 expert: average
    // loss over the last 20 of 200 steps must beat the first 20 and be small.
    #[test]
    fn neural_mixture_adapts_to_correct_symbol() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = NeuralMixture::new(&configs, 0.05);

        let mut early = 0.0;
        let mut late = 0.0;
        for t in 0..200 {
            let lp = mix.step(0);
            if t < 20 {
                early -= lp;
            }
            if t >= 180 {
                late -= lp;
            }
        }

        // Both windows cover exactly 20 steps.
        let early_avg = early / 20.0;
        let late_avg = late / 20.0;
        assert!(
            late_avg < early_avg,
            "late_avg={late_avg} early_avg={early_avg}"
        );
        assert!(late_avg < 0.35, "late_avg={late_avg}");
    }
3122
    /// Test predictor counting `log_prob` invocations; strongly prefers 0.
    #[derive(Clone)]
    struct CountingPredict {
        calls: Arc<AtomicUsize>,
    }

    impl OnlineBytePredictor for CountingPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            self.calls.fetch_add(1, Ordering::Relaxed);
            if symbol == 0 { 0.0 } else { -20.0 }
        }

        fn update(&mut self, _symbol: u8) {}
    }
3136
    /// Like `CountingPredict`, but also counts `fill_log_probs` calls so
    /// tests can distinguish bulk-row from per-symbol queries.
    #[derive(Clone)]
    struct CountingFillPredict {
        log_calls: Arc<AtomicUsize>,
        fill_calls: Arc<AtomicUsize>,
    }

    impl OnlineBytePredictor for CountingFillPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            self.log_calls.fetch_add(1, Ordering::Relaxed);
            if symbol == 0 { 0.0 } else { -20.0 }
        }

        fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
            self.fill_calls.fetch_add(1, Ordering::Relaxed);
            out.fill(-20.0);
            out[0] = 0.0;
        }

        fn update(&mut self, _symbol: u8) {}
    }
3157
    /// Test predictor recording the `total_symbols` passed to `begin_stream`;
    /// predicts -inf until `begin_stream` has been called.
    #[derive(Clone)]
    struct BeginAwarePredict {
        seen_total: Arc<AtomicU64>,
        began: bool,
    }

    impl OnlineBytePredictor for BeginAwarePredict {
        fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
            let total = total_symbols.ok_or_else(|| "missing total symbols".to_string())?;
            self.seen_total.store(total, Ordering::Relaxed);
            self.began = true;
            Ok(())
        }

        fn log_prob(&mut self, _symbol: u8) -> f64 {
            if self.began { 0.0 } else { f64::NEG_INFINITY }
        }

        fn update(&mut self, _symbol: u8) {}
    }
3178
    /// For a backend, checks that the fused `log_prob_update` produces the
    /// same log probs (and the same full rows afterwards) as calling
    /// `log_prob` followed by `update` separately.
    fn assert_log_prob_update_matches_separate(label: &str, backend: RateBackend) {
        let mut separate =
            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut combined = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
        let data = b"combined step check data";

        for &b in data {
            let logp_separate = separate.log_prob(b);
            separate.update(b);
            let logp_combined = combined.log_prob_update(b);
            let diff = (logp_separate - logp_combined).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] symbol={b} separate={logp_separate} combined={logp_combined} diff={diff}"
            );

            // Both predictors must also agree on the entire 256-symbol row.
            let mut sep_row = [0.0; 256];
            let mut combo_row = [0.0; 256];
            separate.fill_log_probs(&mut sep_row);
            combined.fill_log_probs(&mut combo_row);
            for i in 0..256 {
                let diff = (sep_row[i] - combo_row[i]).abs();
                assert!(
                    diff <= 1e-12,
                    "row mismatch at {i}: {} vs {}",
                    sep_row[i],
                    combo_row[i]
                );
            }
        }
    }
3210
    /// After identical conditioning, a bulk `fill_log_probs` row must match
    /// 256 individual `log_prob` queries on a second, identical predictor.
    fn assert_fill_matches_symbol_queries(label: &str, backend: RateBackend) {
        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
        let data = b"continuation consistency prompt";

        bulk.begin_stream(Some(data.len() as u64))
            .expect("bulk begin");
        queried
            .begin_stream(Some(data.len() as u64))
            .expect("query begin");
        for &b in data {
            bulk.update(b);
            queried.update(b);
        }

        let mut bulk_row = [0.0; 256];
        bulk.fill_log_probs(&mut bulk_row);
        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
            let queried_logp = queried.log_prob(sym as u8);
            let diff = (bulk_logp - queried_logp).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
            );
        }
    }
3237
    /// Same fill-vs-query agreement check, but after a fit pass followed by
    /// `reset_frozen` + frozen conditioning on a second text.
    fn assert_fill_matches_symbol_queries_after_frozen_conditioning(
        label: &str,
        backend: RateBackend,
    ) {
        let fit = b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\n";
        let condition = b"If a cat is red, toads are \n";
        let total = (fit.len() + condition.len()) as u64;

        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);

        bulk.begin_stream(Some(total)).expect("bulk begin");
        queried.begin_stream(Some(total)).expect("query begin");
        for &b in fit {
            bulk.update(b);
            queried.update(b);
        }
        bulk.reset_frozen(Some(condition.len() as u64))
            .expect("bulk reset frozen");
        queried
            .reset_frozen(Some(condition.len() as u64))
            .expect("query reset frozen");
        for &b in condition {
            bulk.update_frozen(b);
            queried.update_frozen(b);
        }

        let mut bulk_row = [0.0; 256];
        bulk.fill_log_probs(&mut bulk_row);
        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
            let queried_logp = queried.log_prob(sym as u8);
            let diff = (bulk_logp - queried_logp).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] frozen sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
            );
        }
    }
3276
    // Fused-vs-separate update agreement for the ROSA+ backend.
    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_rosa_backend() {
        assert_log_prob_update_matches_separate("rosa", RateBackend::RosaPlus);
    }
3281
    // Fused-vs-separate update agreement for the CTW backend.
    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_ctw_backend() {
        assert_log_prob_update_matches_separate("ctw", RateBackend::Ctw { depth: 6 });
    }
3286
    // Fused-vs-separate update agreement for the FAC-CTW backend.
    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_fac_ctw_backend() {
        assert_log_prob_update_matches_separate(
            "fac-ctw",
            RateBackend::FacCtw {
                base_depth: 6,
                num_percept_bits: 8,
                encoding_bits: 8,
            },
        );
    }
3298
    // Bulk-row vs per-symbol agreement for the ROSA+ backend.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rosa_backend() {
        assert_fill_matches_symbol_queries("rosa", RateBackend::RosaPlus);
    }
3303
    // Bulk-row vs per-symbol agreement for the CTW backend.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_ctw_backend() {
        assert_fill_matches_symbol_queries("ctw", RateBackend::Ctw { depth: 6 });
    }
3308
    // Bulk-row vs per-symbol agreement for the match-model backend.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_match_backend() {
        assert_fill_matches_symbol_queries(
            "match",
            RateBackend::Match {
                hash_bits: 18,
                min_len: 4,
                max_len: 64,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
        );
    }
3322
    // Bulk-row vs per-symbol agreement for the PPMd backend.
    #[test]
    fn predictor_fill_matches_symbol_queries_for_ppmd_backend() {
        assert_fill_matches_symbol_queries(
            "ppmd",
            RateBackend::Ppmd {
                order: 8,
                memory_mb: 8,
            },
        );
    }
3333
    // Bulk-row vs per-symbol agreement for the RWKV7 backend (feature-gated);
    // uses a tiny untrained config so the test stays fast and deterministic.
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rwkv_backend() {
        assert_fill_matches_symbol_queries(
            "rwkv7",
            RateBackend::Rwkv7Method {
                method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
            },
        );
    }
3344
    // Bulk-row vs per-symbol agreement after frozen conditioning (ROSA+).
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rosa_backend_after_frozen_conditioning() {
        assert_fill_matches_symbol_queries_after_frozen_conditioning("rosa", RateBackend::RosaPlus);
    }
3349
    // After reset_frozen, the match backend should still exploit repetition
    // learned during the fit pass: "abcabc" seen twice makes 'X' (which
    // followed that pattern) noticeably probable.
    #[test]
    fn predictor_frozen_conditioning_reuses_match_fit_corpus() {
        let mut predictor = RateBackendPredictor::from_backend(
            RateBackend::Match {
                hash_bits: 20,
                min_len: 3,
                max_len: 32,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
            -1,
            DEFAULT_MIN_PROB,
        );

        for &b in b"abcabcX" {
            predictor.update(b);
        }
        predictor
            .reset_frozen(Some(6))
            .expect("reset frozen for match backend");
        for &b in b"abcabc" {
            predictor.update_frozen(b);
        }
        let p_x = predictor.log_prob(b'X').exp();
        assert!(
            p_x > 0.01,
            "frozen conditioning should preserve fit corpus for match backend; p_x={p_x}"
        );
    }
3379
    // Same frozen-conditioning reuse check for the sparse-match backend.
    #[test]
    fn predictor_frozen_conditioning_reuses_sparse_match_fit_corpus() {
        let mut predictor = RateBackendPredictor::from_backend(
            RateBackend::SparseMatch {
                hash_bits: 20,
                min_len: 3,
                max_len: 32,
                gap_min: 0,
                gap_max: 2,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
            -1,
            DEFAULT_MIN_PROB,
        );

        for &b in b"abcabcX" {
            predictor.update(b);
        }
        predictor
            .reset_frozen(Some(6))
            .expect("reset frozen for sparse-match backend");
        for &b in b"abcabc" {
            predictor.update_frozen(b);
        }
        let p_x = predictor.log_prob(b'X').exp();
        assert!(
            p_x > 0.01,
            "frozen conditioning should preserve fit corpus for sparse-match backend; p_x={p_x}"
        );
    }
3411
    // Neural mixture: step after predict for the same symbol must not issue
    // any additional per-expert log_prob calls (evaluation cache hit).
    #[test]
    fn neural_predict_then_step_reuses_evaluation_cache() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let cfg0 = {
            let c = c0.clone();
            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let cfg1 = {
            let c = c1.clone();
            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);

        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);

        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }
3434
    // Predicting a different symbol is a cache miss: each predict of a new
    // symbol costs exactly one log_prob call per expert.
    #[test]
    fn neural_predict_multiple_symbols_reuses_single_evaluation() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let cfg0 = {
            let c = c0.clone();
            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let cfg1 = {
            let c = c1.clone();
            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);

        let _ = mix.predict_log_prob(0);
        let after_first = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_first, 2);

        let _ = mix.predict_log_prob(1);
        let after_second = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_second, after_first + 2);
    }
3457
    // fill_log_probs populates full per-expert rows once; the following step
    // must be served entirely from the full-row cache (no extra fill or
    // log_prob calls on any expert).
    #[test]
    fn neural_fill_then_step_reuses_cached_full_rows() {
        let log0 = Arc::new(AtomicUsize::new(0));
        let log1 = Arc::new(AtomicUsize::new(0));
        let fill0 = Arc::new(AtomicUsize::new(0));
        let fill1 = Arc::new(AtomicUsize::new(0));
        let cfg0 = {
            let log_calls = log0.clone();
            let fill_calls = fill0.clone();
            ExpertConfig::uniform("c0", move || {
                Box::new(CountingFillPredict {
                    log_calls: log_calls.clone(),
                    fill_calls: fill_calls.clone(),
                })
            })
        };
        let cfg1 = {
            let log_calls = log1.clone();
            let fill_calls = fill1.clone();
            ExpertConfig::uniform("c1", move || {
                Box::new(CountingFillPredict {
                    log_calls: log_calls.clone(),
                    fill_calls: fill_calls.clone(),
                })
            })
        };
        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);

        let mut row = [0.0; 256];
        mix.fill_log_probs(&mut row);
        assert_eq!(fill0.load(Ordering::Relaxed), 1);
        assert_eq!(fill1.load(Ordering::Relaxed), 1);
        assert_eq!(log0.load(Ordering::Relaxed), 0);
        assert_eq!(log1.load(Ordering::Relaxed), 0);

        let _ = mix.step(0);
        assert_eq!(fill0.load(Ordering::Relaxed), 1);
        assert_eq!(fill1.load(Ordering::Relaxed), 1);
        assert_eq!(log0.load(Ordering::Relaxed), 0);
        assert_eq!(log1.load(Ordering::Relaxed), 0);
    }
3499
3500 #[test]
3501 fn runtime_begin_stream_propagates_to_experts() {
3502 let seen_total = Arc::new(AtomicU64::new(0));
3503 let cfg = {
3504 let seen_total = seen_total.clone();
3505 ExpertConfig::uniform("begin-aware", move || {
3506 Box::new(BeginAwarePredict {
3507 seen_total: seen_total.clone(),
3508 began: false,
3509 })
3510 })
3511 };
3512
3513 let spec = MixtureSpec::new(
3514 MixtureKind::Bayes,
3515 vec![crate::MixtureExpertSpec {
3516 name: Some("begin-aware".to_string()),
3517 log_prior: 0.0,
3518 max_order: -1,
3519 backend: RateBackend::Ctw { depth: 1 },
3520 }],
3521 );
3522 let mut runtime = build_mixture_runtime(&spec, &[cfg]).expect("runtime");
3523 runtime.begin_stream(Some(123)).expect("begin stream");
3524 let _ = runtime.step(0);
3525 assert_eq!(seen_total.load(Ordering::Relaxed), 123);
3526 }
3527
3528 #[test]
3529 fn zpaq_fill_log_probs_does_not_drift_history() {
3530 let backend = RateBackend::Zpaq {
3531 method: "1".to_string(),
3532 };
3533 let mut baseline =
3534 RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
3535 let mut probe = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3536
3537 let history = b"history for zpaq predictor";
3538 for &b in history {
3539 baseline.update(b);
3540 probe.update(b);
3541 }
3542
3543 let mut row = [0.0f64; 256];
3544 probe.fill_log_probs(&mut row);
3545
3546 let sym = b'k';
3547 let lp_base = baseline.log_prob(sym);
3548 let lp_probe = probe.log_prob(sym);
3549 assert!((lp_base - lp_probe).abs() < 1e-9);
3550 assert!((row[sym as usize] - lp_base).abs() < 1e-9);
3551
3552 baseline.update(sym);
3553 probe.update(sym);
3554 let next = b'q';
3555 let next_base = baseline.log_prob(next);
3556 let next_probe = probe.log_prob(next);
3557 assert!((next_base - next_probe).abs() < 1e-9);
3558 }
3559
3560 fn assert_predictor_log_probs_normalize_to_one(backend: RateBackend) {
3561 let mut predictor = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
3562 for &b in b"normalization corpus for ctw/fac predictor checks" {
3563 predictor.update(b);
3564 }
3565 let mut sum = 0.0f64;
3566 for sym in 0u8..=255u8 {
3567 sum += predictor.log_prob(sym).exp();
3568 }
3569 assert!(
3570 (sum - 1.0).abs() <= 1e-10,
3571 "probability mass drift: sum={sum}"
3572 );
3573 }
3574
3575 #[test]
3576 fn ctw_predictor_symbol_probs_normalize() {
3577 assert_predictor_log_probs_normalize_to_one(RateBackend::Ctw { depth: 7 });
3578 }
3579
3580 #[test]
3581 fn fac_ctw_predictor_symbol_probs_normalize() {
3582 assert_predictor_log_probs_normalize_to_one(RateBackend::FacCtw {
3583 base_depth: 7,
3584 num_percept_bits: 8,
3585 encoding_bits: 8,
3586 });
3587 }
3588}