1use crate::backends::calibration::CalibratorCore;
13use crate::backends::match_model::MatchModel;
14use crate::backends::ppmd::PpmdModel;
15use crate::backends::sparse_match::SparseMatchModel;
16use crate::backends::text_context::TextContextAnalyzer;
17use crate::ctw::FacContextTree;
18#[cfg(feature = "backend-mamba")]
19use crate::mambazip;
20use crate::neural_mix::{NeuralHistoryState, NeuralMixCore};
21use crate::rosaplus::RosaPlus;
22#[cfg(feature = "backend-rwkv")]
23use crate::rwkvzip;
24use crate::zpaq_rate::ZpaqRateModel;
25use crate::{CalibratedSpec, MixtureKind, MixtureSpec, RateBackend};
26use std::sync::Arc;
27
/// Default floor applied to backend probabilities before taking logs.
/// The value is exactly 2^-24 (presumably matched to the precision of the
/// downstream entropy coder — TODO confirm).
pub const DEFAULT_MIN_PROB: f64 = 5.960_464_477_539_063e-8;
30
#[inline]
/// Floors `p` at `min_prob`; any non-finite input (NaN, +/-infinity)
/// collapses to the floor so callers can safely take its logarithm.
fn clamp_prob(p: f64, min_prob: f64) -> f64 {
    match p.is_finite() {
        true => p.max(min_prob),
        false => min_prob,
    }
}
39
#[inline]
/// Clamps `p` into the open unit interval `[min_prob, 1 - min_prob]`,
/// treating non-finite inputs as the lower floor. Keeping both outcomes of a
/// binary event at least `min_prob` away from 0 and 1 keeps logs finite.
fn clamp_unit_prob(p: f64, min_prob: f64) -> f64 {
    let floored = if p.is_finite() { p.max(min_prob) } else { min_prob };
    floored.min(1.0 - min_prob)
}
44
45#[inline]
46fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
47 CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
48}
49
#[inline]
/// Numerically stable `ln(sum(exp(x)))` over a slice.
///
/// Subtracting the maximum before exponentiating avoids overflow; an empty
/// slice or all-`-inf` input yields `-inf`, and a `+inf` element propagates.
/// NaN elements are skipped when locating the peak, matching `f64::max`.
fn logsumexp(xs: &[f64]) -> f64 {
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        return peak;
    }
    let total: f64 = xs.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
67
#[inline]
/// Two-argument log-sum-exp: `ln(exp(a) + exp(b))`, computed stably by
/// factoring out the larger operand. Returns `-inf` for two `-inf` inputs
/// and propagates `+inf`. The explicit `a > b` comparison (rather than
/// `f64::max`) preserves the original NaN behavior.
fn logsumexp2(a: f64, b: f64) -> f64 {
    let top = if a > b { a } else { b };
    if top.is_finite() {
        top + ((a - top).exp() + (b - top).exp()).ln()
    } else {
        top
    }
}
76
77#[inline]
78fn logsumexp_weights(experts: &[ExpertState]) -> f64 {
79 let mut max_v = f64::NEG_INFINITY;
80 for e in experts {
81 if e.log_weight > max_v {
82 max_v = e.log_weight;
83 }
84 }
85 if !max_v.is_finite() {
86 return max_v;
87 }
88 let mut sum = 0.0;
89 for e in experts {
90 sum += (e.log_weight - max_v).exp();
91 }
92 max_v + sum.ln()
93}
94
/// Object-safe cloning hook that lets `Box<dyn OnlineBytePredictor>`
/// implement `Clone`; every `'static + Clone` predictor gets this for free
/// via the blanket impl below.
pub trait OnlineBytePredictorClone {
    /// Returns a boxed deep copy of this predictor.
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor>;
}
103
// Blanket impl: any clonable concrete predictor knows how to produce a
// boxed clone of itself.
impl<T> OnlineBytePredictorClone for T
where
    T: 'static + OnlineBytePredictor + Clone,
{
    fn clone_box(&self) -> Box<dyn OnlineBytePredictor> {
        Box::new(self.clone())
    }
}
112
// Makes boxed predictor trait objects clonable by delegating to `clone_box`.
impl Clone for Box<dyn OnlineBytePredictor> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
118
/// Online byte-level predictive model: callers interleave probability
/// queries (`log_prob` / `fill_log_probs`) with observation steps
/// (`update`) while consuming a stream. Implementors must be `Send` and
/// boxed-clonable (via `OnlineBytePredictorClone`).
pub trait OnlineBytePredictor: Send + OnlineBytePredictorClone {
    /// Called once before a stream; `_total_symbols` is an optional size
    /// hint (e.g. for buffer reservation). Default: no-op.
    fn begin_stream(&mut self, _total_symbols: Option<u64>) -> Result<(), String> {
        Ok(())
    }

    /// Called once after a stream ends. Default: no-op.
    fn finish_stream(&mut self) -> Result<(), String> {
        Ok(())
    }

    /// Natural-log probability of `symbol` under the current state. Takes
    /// `&mut self` so implementations may use internal scratch, but must not
    /// advance the model (see `log_prob_update` for query-then-advance).
    fn log_prob(&mut self, symbol: u8) -> f64;

    /// Fills `out[s]` with `log_prob(s)` for all 256 byte values. Default is
    /// the naive 256-query loop; implementors override when a full PDF is
    /// available more cheaply.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        for (sym, slot) in out.iter_mut().enumerate() {
            *slot = self.log_prob(sym as u8);
        }
    }

    /// Scores `symbol` and then advances the model with it. Implementors may
    /// override when the two steps can share work.
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        let logp = self.log_prob(symbol);
        self.update(symbol);
        logp
    }

    /// Advances the model with the observed `symbol`.
    fn update(&mut self, symbol: u8);

    /// Resets per-stream state for "frozen" reuse of an already-trained
    /// model. Default: finish the current stream and begin a new one.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.finish_stream()?;
        self.begin_stream(total_symbols)
    }

    /// Advances conditioning context only, for backends that distinguish a
    /// context step from a learning step. Default: same as `update`.
    fn update_frozen(&mut self, symbol: u8) {
        self.update(symbol);
    }
}
172
#[cfg(feature = "backend-rwkv")]
#[inline]
// Lazily resets-and-primes the rwkv compressor exactly once so that its
// `pdf_buffer` holds a valid next-symbol distribution before first use.
fn ensure_rwkv_primed(compressor: &mut rwkvzip::Compressor, primed: &mut bool) {
    if !*primed {
        compressor.reset_and_prime();
        *primed = true;
    }
}
181
182#[inline]
183fn ctw_log_prob_update_msb(tree: &mut FacContextTree, symbol: u8, min_prob: f64) -> f64 {
184 let mut logp = 0.0;
185 for bit_idx in 0..8 {
186 let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
187 let p = tree.predict(bit, bit_idx);
188 if p.is_finite() && p > 0.0 {
189 logp += p.ln();
190 } else {
191 logp = f64::NEG_INFINITY;
192 }
193 tree.update_predicted(bit, bit_idx);
194 }
195 if logp.is_finite() {
196 logp.max(min_prob.ln())
197 } else {
198 min_prob.ln()
199 }
200}
201
202#[inline]
203fn ctw_log_prob_update_lsb(
204 tree: &mut FacContextTree,
205 symbol: u8,
206 bits_per_symbol: usize,
207 min_prob: f64,
208) -> f64 {
209 let mut logp = 0.0;
210 for bit_idx in 0..bits_per_symbol {
211 let bit = ((symbol >> bit_idx) & 1) == 1;
212 let p = tree.predict(bit, bit_idx);
213 if p.is_finite() && p > 0.0 {
214 logp += p.ln();
215 } else {
216 logp = f64::NEG_INFINITY;
217 }
218 tree.update_predicted(bit, bit_idx);
219 }
220 if logp.is_finite() {
221 logp.max(min_prob.ln())
222 } else {
223 min_prob.ln()
224 }
225}
226
227fn fill_fac_tree_log_probs(
228 tree: &mut FacContextTree,
229 bits_per_symbol: usize,
230 msb_first: bool,
231 min_logp: f64,
232 out: &mut [f64; 256],
233) {
234 struct RecParams {
235 bits: usize,
236 msb_first: bool,
237 log_before: f64,
238 min_logp: f64,
239 }
240
241 let bits = bits_per_symbol.clamp(1, 8);
242 let patterns = 1usize << bits;
243 let mut pattern_logps = [f64::NEG_INFINITY; 256];
244 let params = RecParams {
245 bits,
246 msb_first,
247 log_before: tree.get_log_block_probability(),
248 min_logp,
249 };
250
251 fn rec(
252 tree: &mut FacContextTree,
253 depth: usize,
254 params: &RecParams,
255 symbol_acc: u8,
256 pattern_logps: &mut [f64; 256],
257 ) {
258 if depth == params.bits {
259 let pat = symbol_acc as usize;
260 let logp = (tree.get_log_block_probability() - params.log_before).max(params.min_logp);
261 pattern_logps[pat] = logp;
262 return;
263 }
264
265 for bit in [false, true] {
266 tree.update(bit, depth);
267 let mut next_symbol = symbol_acc;
268 if params.msb_first {
269 let shift = 7usize.saturating_sub(depth);
270 if bit {
271 next_symbol |= 1u8 << shift;
272 }
273 } else if bit {
274 next_symbol |= 1u8 << depth;
275 }
276 rec(tree, depth + 1, params, next_symbol, pattern_logps);
277 tree.revert(depth);
278 }
279 }
280
281 rec(tree, 0, ¶ms, 0, &mut pattern_logps);
282
283 if bits == 8 {
284 out.copy_from_slice(&pattern_logps);
285 } else {
286 let aliases = 1usize << (8 - bits);
287 let alias_ln = (aliases as f64).ln();
288 let mask = patterns - 1;
289 for byte in 0..256usize {
290 out[byte] = pattern_logps[byte & mask] - alias_ln;
291 }
292 }
293}
294
/// One concrete online predictor per `RateBackend` configuration.
///
/// Most variants pair a backend model with the `min_prob` floor used when
/// converting its probabilities to logs. `Calibrated` wraps another
/// predictor and reshapes its PDF through a `CalibratorCore`.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum RateBackendPredictor {
    /// ROSA+ suffix-structure language model.
    Rosa {
        model: RosaPlus,
        min_prob: f64,
    },
    /// Contiguous match model (LZ-style hash of recent history).
    Match {
        model: MatchModel,
        min_prob: f64,
    },
    /// Match model over gapped (sparse) contexts.
    SparseMatch {
        model: SparseMatchModel,
        min_prob: f64,
    },
    /// PPMd context model.
    Ppmd {
        model: PpmdModel,
        min_prob: f64,
    },
    /// Context-tree weighting over MSB-first byte bits (fixed 8 bits).
    Ctw {
        tree: FacContextTree,
        min_prob: f64,
    },
    /// Factored CTW over the low `bits_per_symbol` bits, LSB-first.
    FacCtw {
        tree: FacContextTree,
        bits_per_symbol: usize,
        min_prob: f64,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV-7 neural predictor; `primed` tracks whether `pdf_buffer` holds a
    /// valid next-symbol distribution.
    Rwkv7 {
        compressor: rwkvzip::Compressor,
        primed: bool,
        // Scratch copy of the PDF used during online updates.
        pdf_scratch: Vec<f64>,
        min_prob: f64,
    },
    #[cfg(feature = "backend-mamba")]
    /// Mamba neural predictor; same priming/scratch scheme as `Rwkv7`.
    Mamba {
        compressor: mambazip::Compressor,
        primed: bool,
        pdf_scratch: Vec<f64>,
        min_prob: f64,
    },
    /// ZPAQ model (handles its own probability flooring internally).
    Zpaq {
        model: ZpaqRateModel,
    },
    /// Mixture-of-experts runtime (Bayes/fading/switching/MDL/neural).
    Mixture {
        runtime: MixtureRuntime,
    },
    /// Particle-filter runtime.
    Particle {
        runtime: crate::particle::ParticleRuntime,
    },
    /// Wraps a base predictor and recalibrates its PDF.
    Calibrated {
        base: Box<RateBackendPredictor>,
        core: CalibratorCore,
        // Cached calibrated PDF for the current context.
        pdf: [f64; 256],
        // Whether `pdf` is current; cleared after each observed symbol.
        valid: bool,
        min_prob: f64,
    },
}
396
impl RateBackendPredictor {
    /// Builds a predictor for `backend`.
    ///
    /// `max_order` is forwarded to the ROSA backend (and to the nested base
    /// of a `Calibrated` backend); `min_prob` is the probability floor used
    /// by most variants before taking logs.
    ///
    /// # Panics
    /// Panics on an invalid rwkv/mamba method string or an invalid
    /// `MixtureSpec` — construction happens at configuration time, so the
    /// callers here fail fast rather than propagating errors.
    pub fn from_backend(backend: RateBackend, max_order: i64, min_prob: f64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                Self::Rosa { model, min_prob }
            }
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
                min_prob,
            },
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
                min_prob,
            },
            RateBackend::Ppmd { order, memory_mb } => Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
                min_prob,
            },
            RateBackend::Ctw { depth } => {
                // Plain CTW always works on full bytes (8 bits per symbol).
                let tree = FacContextTree::new(depth, 8);
                Self::Ctw { tree, min_prob }
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                let bits_per_symbol = encoding_bits.clamp(1, 8);
                let tree = FacContextTree::new(base_depth, bits_per_symbol);
                Self::FacCtw {
                    tree,
                    bits_per_symbol,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => {
                // Prime eagerly so `pdf_buffer` is valid from the first query.
                let mut compressor = rwkvzip::Compressor::new_from_model(model);
                compressor.reset_and_prime();
                Self::Rwkv7 {
                    // Note: `pdf_scratch` is initialized before `compressor`
                    // is moved into the struct.
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                let mut compressor = rwkvzip::Compressor::new_from_method(&method)
                    .unwrap_or_else(|e| panic!("invalid rwkv method '{method}': {e}"));
                compressor.reset_and_prime();
                Self::Rwkv7 {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => {
                // Mamba has no `reset_and_prime`-style helper here: run one
                // forward pass on token 0 to populate the initial PDF.
                let mut compressor = mambazip::Compressor::new_from_model(model);
                let bias = compressor.online_bias_snapshot();
                let logits =
                    compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
                Self::Mamba {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                let mut compressor = mambazip::Compressor::new_from_method(&method)
                    .unwrap_or_else(|e| panic!("invalid mamba method '{method}': {e}"));
                let bias = compressor.online_bias_snapshot();
                let logits =
                    compressor
                        .model
                        .forward(&mut compressor.scratch, 0, &mut compressor.state);
                mambazip::Compressor::logits_to_pdf(
                    logits,
                    bias.as_deref(),
                    &mut compressor.pdf_buffer,
                );
                Self::Mamba {
                    pdf_scratch: vec![0.0; compressor.pdf_buffer.len()],
                    compressor,
                    primed: true,
                    min_prob,
                }
            }
            RateBackend::Zpaq { method } => {
                // Zpaq takes the floor at construction and applies it itself.
                let model = ZpaqRateModel::new(method, min_prob);
                Self::Zpaq { model }
            }
            RateBackend::Mixture { spec } => {
                let experts = spec.build_experts();
                let runtime = build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                Self::Mixture { runtime }
            }
            RateBackend::Particle { spec } => {
                let runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                Self::Particle { runtime }
            }
            RateBackend::Calibrated { spec } => Self::Calibrated {
                base: Box::new(Self::from_backend(spec.base.clone(), max_order, min_prob)),
                core: build_calibrator(spec.as_ref()),
                // Start from the uniform PDF until the first real fill.
                pdf: [1.0 / 256.0; 256],
                valid: false,
                min_prob,
            },
        }
    }

    /// Short human-readable label for `backend`, used when an expert or
    /// diagnostic name was not supplied explicitly.
    pub fn default_name(backend: &RateBackend, max_order: i64) -> String {
        match backend {
            RateBackend::RosaPlus => format!("rosa(mo={})", max_order),
            RateBackend::Match { .. } => "match".to_string(),
            RateBackend::SparseMatch { .. } => "sparse-match".to_string(),
            RateBackend::Ppmd { order, memory_mb } => {
                format!("ppmd(o={},m={}MiB)", order, memory_mb)
            }
            RateBackend::Ctw { depth } => format!("ctw(d={})", depth),
            RateBackend::FacCtw {
                base_depth,
                encoding_bits,
                ..
            } => format!("fac-ctw(d={},b={})", base_depth, encoding_bits),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => format!("rwkv7({method})"),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { .. } => "mamba".to_string(),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => format!("mamba({method})"),
            RateBackend::Zpaq { method } => format!("zpaq(m={})", method),
            RateBackend::Mixture { spec } => {
                let kind = match spec.kind {
                    MixtureKind::Bayes => "bayes",
                    MixtureKind::FadingBayes => "fading",
                    MixtureKind::Switching => "switch",
                    MixtureKind::Mdl => "mdl",
                    MixtureKind::Neural => "neural",
                };
                format!("mix({})", kind)
            }
            RateBackend::Particle { spec } => {
                format!("particle(n={},c={})", spec.num_particles, spec.num_cells)
            }
            RateBackend::Calibrated { spec } => {
                format!("calibrated({})", Self::default_name(&spec.base, max_order))
            }
        }
    }
}
594
595impl OnlineBytePredictor for RateBackendPredictor {
596 fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
597 self.finish_stream()?;
598 match self {
599 RateBackendPredictor::Rosa { model, .. } => {
600 if let Some(total) = total_symbols {
601 let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
602 model.reserve_for_stream(reserve);
603 }
604 Ok(())
605 }
606 RateBackendPredictor::Match { .. }
607 | RateBackendPredictor::SparseMatch { .. }
608 | RateBackendPredictor::Ppmd { .. } => Ok(()),
609 RateBackendPredictor::Ctw { .. }
610 | RateBackendPredictor::FacCtw { .. }
611 | RateBackendPredictor::Zpaq { .. }
612 | RateBackendPredictor::Particle { .. } => Ok(()),
613 #[cfg(feature = "backend-rwkv")]
614 RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
615 .begin_online_policy_stream(total_symbols)
616 .map_err(|e| e.to_string()),
617 #[cfg(feature = "backend-mamba")]
618 RateBackendPredictor::Mamba { compressor, .. } => compressor
619 .begin_online_policy_stream(total_symbols)
620 .map_err(|e| e.to_string()),
621 RateBackendPredictor::Mixture { runtime } => runtime.begin_stream(total_symbols),
622 RateBackendPredictor::Calibrated { base, .. } => base.begin_stream(total_symbols),
623 }
624 }
625
626 fn finish_stream(&mut self) -> Result<(), String> {
627 match self {
628 RateBackendPredictor::Rosa { .. }
629 | RateBackendPredictor::Match { .. }
630 | RateBackendPredictor::SparseMatch { .. }
631 | RateBackendPredictor::Ppmd { .. }
632 | RateBackendPredictor::Ctw { .. }
633 | RateBackendPredictor::FacCtw { .. }
634 | RateBackendPredictor::Zpaq { .. }
635 | RateBackendPredictor::Particle { .. } => Ok(()),
636 #[cfg(feature = "backend-rwkv")]
637 RateBackendPredictor::Rwkv7 { compressor, .. } => compressor
638 .finish_online_policy_stream()
639 .map_err(|e| e.to_string()),
640 #[cfg(feature = "backend-mamba")]
641 RateBackendPredictor::Mamba { .. } => Ok(()),
642 RateBackendPredictor::Mixture { runtime } => runtime.finish_stream(),
643 RateBackendPredictor::Calibrated { base, .. } => base.finish_stream(),
644 }
645 }
646
647 fn log_prob(&mut self, symbol: u8) -> f64 {
648 match self {
649 RateBackendPredictor::Rosa { model, min_prob } => {
650 let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
651 p.ln()
652 }
653 RateBackendPredictor::Match { model, min_prob } => model.log_prob(symbol, *min_prob),
654 RateBackendPredictor::SparseMatch { model, min_prob } => {
655 model.log_prob(symbol, *min_prob)
656 }
657 RateBackendPredictor::Ppmd { model, min_prob } => model.log_prob(symbol, *min_prob),
658 RateBackendPredictor::Ctw { tree, min_prob } => {
659 let log_before = tree.get_log_block_probability();
660 for bit_idx in 0..8 {
661 let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
662 tree.update(bit, bit_idx);
663 }
664 let log_after = tree.get_log_block_probability();
665 for bit_idx in (0..8).rev() {
666 tree.revert(bit_idx);
667 }
668 let logp = log_after - log_before;
669 if logp.is_finite() {
670 logp.max(min_prob.ln())
671 } else {
672 min_prob.ln()
673 }
674 }
675 RateBackendPredictor::FacCtw {
676 tree,
677 bits_per_symbol,
678 min_prob,
679 } => {
680 let log_before = tree.get_log_block_probability();
681 for i in 0..*bits_per_symbol {
682 let bit = ((symbol >> i) & 1) == 1;
683 tree.update(bit, i);
684 }
685 let log_after = tree.get_log_block_probability();
686 for i in (0..*bits_per_symbol).rev() {
687 tree.revert(i);
688 }
689 let logp = log_after - log_before;
690 if logp.is_finite() {
691 logp.max(min_prob.ln())
692 } else {
693 min_prob.ln()
694 }
695 }
696 #[cfg(feature = "backend-rwkv")]
697 RateBackendPredictor::Rwkv7 {
698 compressor,
699 primed,
700 min_prob,
701 ..
702 } => {
703 ensure_rwkv_primed(compressor, primed);
704 let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
705 p.ln()
706 }
707 #[cfg(feature = "backend-mamba")]
708 RateBackendPredictor::Mamba {
709 compressor,
710 primed,
711 min_prob,
712 ..
713 } => {
714 if !*primed {
715 let bias = compressor.online_bias_snapshot();
716 let logits =
717 compressor
718 .model
719 .forward(&mut compressor.scratch, 0, &mut compressor.state);
720 mambazip::Compressor::logits_to_pdf(
721 logits,
722 bias.as_deref(),
723 &mut compressor.pdf_buffer,
724 );
725 *primed = true;
726 }
727 let p = clamp_prob(compressor.pdf_buffer[symbol as usize], *min_prob);
728 p.ln()
729 }
730 RateBackendPredictor::Zpaq { model } => model.log_prob(symbol),
731 RateBackendPredictor::Mixture { runtime } => runtime.peek_log_prob(symbol),
732 RateBackendPredictor::Particle { runtime } => runtime.peek_log_prob(symbol),
733 RateBackendPredictor::Calibrated {
734 base,
735 core,
736 pdf,
737 valid,
738 min_prob,
739 } => {
740 if !*valid {
741 let mut base_logps = [0.0; 256];
742 base.fill_log_probs(&mut base_logps);
743 let mut base_pdf = [0.0; 256];
744 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
745 *dst = clamp_prob(lp.exp(), *min_prob);
746 }
747 core.apply_pdf(&base_pdf, pdf);
748 *valid = true;
749 }
750 pdf[symbol as usize].max(*min_prob).ln()
751 }
752 }
753 }
754
755 fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
756 match self {
757 RateBackendPredictor::Rosa { model, min_prob } => {
758 model.fill_probs_for_last_bytes(out);
759 for slot in out.iter_mut() {
760 *slot = clamp_prob(*slot, *min_prob).ln();
761 }
762 }
763 RateBackendPredictor::Match { model, min_prob } => {
764 let mut pdf = [0.0; 256];
765 model.fill_pdf(&mut pdf);
766 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
767 *slot = clamp_prob(p, *min_prob).ln();
768 }
769 }
770 RateBackendPredictor::SparseMatch { model, min_prob } => {
771 let mut pdf = [0.0; 256];
772 model.fill_pdf(&mut pdf);
773 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
774 *slot = clamp_prob(p, *min_prob).ln();
775 }
776 }
777 RateBackendPredictor::Ppmd { model, min_prob } => {
778 let mut pdf = [0.0; 256];
779 model.fill_pdf(&mut pdf);
780 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
781 *slot = clamp_prob(p, *min_prob).ln();
782 }
783 }
784 RateBackendPredictor::Ctw { tree, min_prob } => {
785 fill_fac_tree_log_probs(tree, 8, true, min_prob.ln(), out);
786 }
787 RateBackendPredictor::FacCtw {
788 tree,
789 bits_per_symbol,
790 min_prob,
791 } => {
792 fill_fac_tree_log_probs(tree, *bits_per_symbol, false, min_prob.ln(), out);
793 }
794 #[cfg(feature = "backend-rwkv")]
795 RateBackendPredictor::Rwkv7 {
796 compressor,
797 primed,
798 min_prob,
799 ..
800 } => {
801 ensure_rwkv_primed(compressor, primed);
802 for (slot, &p_raw) in out
803 .iter_mut()
804 .take(256)
805 .zip(compressor.pdf_buffer.iter().take(256))
806 {
807 let p = clamp_prob(p_raw, *min_prob);
808 *slot = p.ln();
809 }
810 }
811 #[cfg(feature = "backend-mamba")]
812 RateBackendPredictor::Mamba {
813 compressor,
814 primed,
815 min_prob,
816 ..
817 } => {
818 if !*primed {
819 let bias = compressor.online_bias_snapshot();
820 let logits =
821 compressor
822 .model
823 .forward(&mut compressor.scratch, 0, &mut compressor.state);
824 mambazip::Compressor::logits_to_pdf(
825 logits,
826 bias.as_deref(),
827 &mut compressor.pdf_buffer,
828 );
829 *primed = true;
830 }
831 for (slot, &p_raw) in out
832 .iter_mut()
833 .take(256)
834 .zip(compressor.pdf_buffer.iter().take(256))
835 {
836 let p = clamp_prob(p_raw, *min_prob);
837 *slot = p.ln();
838 }
839 }
840 RateBackendPredictor::Zpaq { model } => {
841 model.fill_log_probs(out);
842 }
843 RateBackendPredictor::Mixture { runtime } => {
844 runtime.fill_log_probs(out);
845 }
846 RateBackendPredictor::Particle { runtime } => {
847 runtime.fill_log_probs_cached(out);
848 }
849 RateBackendPredictor::Calibrated {
850 base,
851 core,
852 pdf,
853 valid,
854 min_prob,
855 } => {
856 if !*valid {
857 let mut base_logps = [0.0; 256];
858 base.fill_log_probs(&mut base_logps);
859 let mut base_pdf = [0.0; 256];
860 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
861 *dst = clamp_prob(lp.exp(), *min_prob);
862 }
863 core.apply_pdf(&base_pdf, pdf);
864 *valid = true;
865 }
866 for (slot, &p) in out.iter_mut().zip(pdf.iter()) {
867 *slot = clamp_prob(p, *min_prob).ln();
868 }
869 }
870 }
871 }
872
873 fn update(&mut self, symbol: u8) {
874 match self {
875 RateBackendPredictor::Rosa { model, .. } => {
876 model.train_byte(symbol);
877 }
878 RateBackendPredictor::Match { model, .. } => {
879 model.update(symbol);
880 }
881 RateBackendPredictor::SparseMatch { model, .. } => {
882 model.update(symbol);
883 }
884 RateBackendPredictor::Ppmd { model, .. } => {
885 model.update(symbol);
886 }
887 RateBackendPredictor::Ctw { tree, .. } => {
888 for bit_idx in 0..8 {
889 let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
890 tree.update(bit, bit_idx);
891 }
892 }
893 RateBackendPredictor::FacCtw {
894 tree,
895 bits_per_symbol,
896 ..
897 } => {
898 for i in 0..*bits_per_symbol {
899 let bit = ((symbol >> i) & 1) == 1;
900 tree.update(bit, i);
901 }
902 }
903 #[cfg(feature = "backend-rwkv")]
904 RateBackendPredictor::Rwkv7 {
905 compressor, primed, ..
906 } => {
907 ensure_rwkv_primed(compressor, primed);
908 compressor
909 .observe_symbol_from_current_pdf(symbol)
910 .unwrap_or_else(|e| panic!("rwkv online update failed: {e}"));
911 }
912 #[cfg(feature = "backend-mamba")]
913 RateBackendPredictor::Mamba {
914 compressor,
915 primed,
916 pdf_scratch,
917 ..
918 } => {
919 if !*primed {
920 let bias = compressor.online_bias_snapshot();
921 let logits =
922 compressor
923 .model
924 .forward(&mut compressor.scratch, 0, &mut compressor.state);
925 mambazip::Compressor::logits_to_pdf(
926 logits,
927 bias.as_deref(),
928 &mut compressor.pdf_buffer,
929 );
930 *primed = true;
931 }
932 if pdf_scratch.len() != compressor.pdf_buffer.len() {
933 pdf_scratch.resize(compressor.pdf_buffer.len(), 0.0);
934 }
935 pdf_scratch.copy_from_slice(&compressor.pdf_buffer);
936 compressor
937 .online_update_from_pdf(symbol, pdf_scratch)
938 .unwrap_or_else(|e| panic!("mamba online update failed: {e}"));
939 let bias = compressor.online_bias_snapshot();
940 let logits = compressor.model.forward(
941 &mut compressor.scratch,
942 symbol as u32,
943 &mut compressor.state,
944 );
945 mambazip::Compressor::logits_to_pdf(
946 logits,
947 bias.as_deref(),
948 &mut compressor.pdf_buffer,
949 );
950 }
951 RateBackendPredictor::Zpaq { model } => {
952 model.update(symbol);
953 }
954 RateBackendPredictor::Mixture { runtime } => {
955 let _ = runtime.step(symbol);
956 }
957 RateBackendPredictor::Particle { runtime } => {
958 runtime.step(symbol);
959 }
960 RateBackendPredictor::Calibrated {
961 base,
962 core,
963 pdf,
964 valid,
965 ..
966 } => {
967 if !*valid {
968 let mut base_logps = [0.0; 256];
969 base.fill_log_probs(&mut base_logps);
970 let mut base_pdf = [0.0; 256];
971 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
972 *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
973 }
974 core.apply_pdf(&base_pdf, pdf);
975 }
976 core.update(symbol, pdf);
977 base.update(symbol);
978 *valid = false;
979 }
980 }
981 }
982
983 fn log_prob_update(&mut self, symbol: u8) -> f64 {
984 match self {
985 RateBackendPredictor::Rosa { model, min_prob } => {
986 let p = clamp_prob(model.prob_for_last(symbol as u32), *min_prob);
987 model.train_byte(symbol);
988 p.ln()
989 }
990 RateBackendPredictor::Ctw { tree, min_prob } => {
991 ctw_log_prob_update_msb(tree, symbol, *min_prob)
992 }
993 RateBackendPredictor::FacCtw {
994 tree,
995 bits_per_symbol,
996 min_prob,
997 } => ctw_log_prob_update_lsb(tree, symbol, *bits_per_symbol, *min_prob),
998 _ => {
999 let logp = self.log_prob(symbol);
1000 self.update(symbol);
1001 logp
1002 }
1003 }
1004 }
1005
1006 fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
1007 self.finish_stream()?;
1008 match self {
1009 RateBackendPredictor::Rosa { model, .. } => {
1010 if let Some(total) = total_symbols {
1011 let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
1012 model.reserve_for_stream(reserve);
1013 }
1014 model.build_lm_full_bytes_no_finalize_endpos();
1015 model.reset_conditioning_cursor();
1016 Ok(())
1017 }
1018 RateBackendPredictor::Match { model, .. } => {
1019 model.reset_history();
1020 Ok(())
1021 }
1022 RateBackendPredictor::SparseMatch { model, .. } => {
1023 model.reset_history();
1024 Ok(())
1025 }
1026 RateBackendPredictor::Ppmd { model, .. } => {
1027 model.reset_history();
1028 Ok(())
1029 }
1030 RateBackendPredictor::Ctw { tree, .. } => {
1031 tree.reset_history_only();
1032 Ok(())
1033 }
1034 RateBackendPredictor::FacCtw { tree, .. } => {
1035 tree.reset_history_only();
1036 Ok(())
1037 }
1038 #[cfg(feature = "backend-rwkv")]
1039 RateBackendPredictor::Rwkv7 {
1040 compressor, primed, ..
1041 } => {
1042 compressor.reset_and_prime();
1043 *primed = true;
1044 Ok(())
1045 }
1046 #[cfg(feature = "backend-mamba")]
1047 RateBackendPredictor::Mamba {
1048 compressor, primed, ..
1049 } => {
1050 compressor.reset_and_prime();
1051 *primed = true;
1052 Ok(())
1053 }
1054 RateBackendPredictor::Zpaq { .. } => {
1055 Err("plugin entropy is not supported for zpaq rate backends in 1.1.0".to_string())
1056 }
1057 RateBackendPredictor::Mixture { runtime } => runtime.reset_frozen(total_symbols),
1058 RateBackendPredictor::Particle { runtime } => {
1059 runtime.reset_frozen_state();
1060 Ok(())
1061 }
1062 RateBackendPredictor::Calibrated {
1063 base,
1064 core,
1065 pdf,
1066 valid,
1067 ..
1068 } => {
1069 base.reset_frozen(total_symbols)?;
1070 core.reset_context();
1071 pdf.fill(1.0 / 256.0);
1072 *valid = false;
1073 Ok(())
1074 }
1075 }
1076 }
1077
1078 fn update_frozen(&mut self, symbol: u8) {
1079 match self {
1080 RateBackendPredictor::Rosa { model, .. } => {
1081 model.advance_conditioning_byte(symbol);
1082 }
1083 RateBackendPredictor::Match { model, .. } => {
1084 model.update_history_only(symbol);
1085 }
1086 RateBackendPredictor::SparseMatch { model, .. } => {
1087 model.update_history_only(symbol);
1088 }
1089 RateBackendPredictor::Ppmd { model, .. } => {
1090 model.update_history_only(symbol);
1091 }
1092 RateBackendPredictor::Ctw { tree, .. } => {
1093 let mut bits = [false; 8];
1094 for (bit_idx, slot) in bits.iter_mut().enumerate() {
1095 *slot = ((symbol >> (7 - bit_idx)) & 1) == 1;
1096 }
1097 tree.update_history(&bits);
1098 }
1099 RateBackendPredictor::FacCtw {
1100 tree,
1101 bits_per_symbol,
1102 ..
1103 } => {
1104 let bits = (*bits_per_symbol).clamp(1, 8);
1105 let mut history_bits = [false; 8];
1106 for (idx, slot) in history_bits.iter_mut().enumerate().take(bits) {
1107 *slot = ((symbol >> idx) & 1) == 1;
1108 }
1109 tree.update_history(&history_bits[..bits]);
1110 }
1111 #[cfg(feature = "backend-rwkv")]
1112 RateBackendPredictor::Rwkv7 {
1113 compressor, primed, ..
1114 } => {
1115 if !*primed {
1116 compressor.reset_and_prime();
1117 *primed = true;
1118 }
1119 compressor.forward_to_internal_pdf(symbol as u32);
1120 }
1121 #[cfg(feature = "backend-mamba")]
1122 RateBackendPredictor::Mamba {
1123 compressor, primed, ..
1124 } => {
1125 if !*primed {
1126 compressor.reset_and_prime();
1127 *primed = true;
1128 }
1129 let bias = compressor.online_bias_snapshot();
1130 let logits = compressor.model.forward(
1131 &mut compressor.scratch,
1132 symbol as u32,
1133 &mut compressor.state,
1134 );
1135 mambazip::Compressor::logits_to_pdf(
1136 logits,
1137 bias.as_deref(),
1138 &mut compressor.pdf_buffer,
1139 );
1140 }
1141 RateBackendPredictor::Zpaq { model } => {
1142 model.update(symbol);
1143 }
1144 RateBackendPredictor::Mixture { runtime } => {
1145 runtime.update_frozen(symbol);
1146 }
1147 RateBackendPredictor::Particle { runtime } => {
1148 runtime.update_frozen(symbol);
1149 }
1150 RateBackendPredictor::Calibrated {
1151 base,
1152 core,
1153 pdf,
1154 valid,
1155 ..
1156 } => {
1157 if !*valid {
1158 let mut base_logps = [0.0; 256];
1159 base.fill_log_probs(&mut base_logps);
1160 let mut base_pdf = [0.0; 256];
1161 for (dst, &lp) in base_pdf.iter_mut().zip(base_logps.iter()) {
1162 *dst = clamp_prob(lp.exp(), DEFAULT_MIN_PROB);
1163 }
1164 core.apply_pdf(&base_pdf, pdf);
1165 *valid = true;
1166 }
1167 base.update_frozen(symbol);
1168 core.update_context_only(symbol);
1169 *valid = false;
1170 }
1171 }
1172 }
1173}
1174
/// Factory description of one mixture expert: a display name, a log-domain
/// prior weight, and a builder closure that constructs a fresh predictor.
#[derive(Clone)]
pub struct ExpertConfig {
    /// Human-readable expert label used in names and diagnostics.
    pub name: String,
    /// Log-domain prior weight assigned at mixture start.
    pub log_prior: f64,
    // Arc so configs clone cheaply while sharing one builder closure.
    builder: Arc<dyn Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync>,
}
1184
impl ExpertConfig {
    /// Creates an expert config with an explicit log-prior weight.
    pub fn new(
        name: impl Into<String>,
        log_prior: f64,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self {
            name: name.into(),
            log_prior,
            builder: Arc::new(builder),
        }
    }

    /// Creates an expert config with a uniform (zero) log-prior.
    pub fn uniform(
        name: impl Into<String>,
        builder: impl Fn() -> Box<dyn OnlineBytePredictor> + Send + Sync + 'static,
    ) -> Self {
        Self::new(name, 0.0, builder)
    }

    /// Builds an expert from any `RateBackend`; when `name` is `None` a
    /// default label is derived from the backend configuration.
    pub fn from_rate_backend(
        name: Option<String>,
        log_prior: f64,
        backend: RateBackend,
        max_order: i64,
    ) -> Self {
        let name = name.unwrap_or_else(|| RateBackendPredictor::default_name(&backend, max_order));
        Self::new(name, log_prior, move || {
            Box::new(RateBackendPredictor::from_backend(
                backend.clone(),
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Convenience: ROSA+ expert with uniform prior.
    pub fn rosa(name: impl Into<String>, max_order: i64) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::RosaPlus,
                max_order,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Convenience: byte-level CTW expert with uniform prior.
    /// `max_order` is irrelevant for CTW, hence the -1 sentinel.
    pub fn ctw(name: impl Into<String>, depth: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Ctw { depth },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Convenience: factored CTW expert; `encoding_bits` is used for both
    /// the percept and encoding widths.
    pub fn fac_ctw(name: impl Into<String>, base_depth: usize, encoding_bits: usize) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::FacCtw {
                    base_depth,
                    num_percept_bits: encoding_bits,
                    encoding_bits,
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    #[cfg(feature = "backend-rwkv")]
    /// Convenience: RWKV-7 expert sharing `model` across rebuilds.
    pub fn rwkv(name: impl Into<String>, model: Arc<rwkvzip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Rwkv7 {
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    #[cfg(feature = "backend-mamba")]
    /// Convenience: Mamba expert sharing `model` across rebuilds.
    pub fn mamba(name: impl Into<String>, model: Arc<mambazip::Model>) -> Self {
        let name = name.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Mamba {
                    model: model.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// Convenience: ZPAQ expert for the given method string.
    pub fn zpaq(name: impl Into<String>, method: impl Into<String>) -> Self {
        let name = name.into();
        let method = method.into();
        Self::uniform(name, move || {
            Box::new(RateBackendPredictor::from_backend(
                RateBackend::Zpaq {
                    method: method.clone(),
                },
                -1,
                DEFAULT_MIN_PROB,
            ))
        })
    }

    /// The expert's display name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// The expert's log-prior weight.
    pub fn log_prior(&self) -> f64 {
        self.log_prior
    }

    /// Constructs a fresh predictor instance from the builder closure.
    pub fn build_predictor(&self) -> Box<dyn OnlineBytePredictor> {
        (self.builder)()
    }

    /// Instantiates the runtime state for this expert, seeding its mixture
    /// weight with the prior and zeroing the cumulative loss.
    fn build(&self) -> ExpertState {
        ExpertState {
            name: self.name.clone(),
            log_weight: self.log_prior,
            log_prior: self.log_prior,
            predictor: (self.builder)(),
            cum_log_loss: 0.0,
        }
    }
}
1334
/// Live per-expert state owned by a mixture runtime.
#[derive(Clone)]
struct ExpertState {
    // Display name copied from the originating `ExpertConfig`.
    name: String,
    // Current mixture log-weight (normalized by the owning mixture).
    log_weight: f64,
    // Configured log-prior; mixtures read this when (re)normalizing weights.
    log_prior: f64,
    // The underlying sequential byte model.
    predictor: Box<dyn OnlineBytePredictor>,
    // Cumulative negative log-likelihood (nats) this expert has incurred.
    cum_log_loss: f64,
}
1343
impl ExpertState {
    /// Forwards stream-start notification to the underlying predictor.
    #[inline]
    fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.begin_stream(total_symbols)
    }

    /// Forwards stream-end notification to the underlying predictor.
    #[inline]
    fn finish_stream(&mut self) -> Result<(), String> {
        self.predictor.finish_stream()
    }

    /// Log-probability of `symbol` without advancing the model
    /// (mixtures pair this with a later `update` on the same symbol).
    #[inline]
    fn log_prob(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob(symbol)
    }

    /// Log-probability of `symbol`, advancing the model in the same call.
    #[inline]
    fn log_prob_update(&mut self, symbol: u8) -> f64 {
        self.predictor.log_prob_update(symbol)
    }

    /// Advances the model on `symbol` without returning a probability.
    #[inline]
    fn update(&mut self, symbol: u8) {
        self.predictor.update(symbol);
    }

    /// Forwards `reset_frozen` to the predictor (semantics are backend-defined).
    #[inline]
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        self.predictor.reset_frozen(total_symbols)
    }

    /// Forwards `update_frozen` to the predictor (semantics are backend-defined).
    #[inline]
    fn update_frozen(&mut self, symbol: u8) {
        self.predictor.update_frozen(symbol);
    }
}
1380
/// Bayesian model averaging over a fixed set of experts.
///
/// Keeps a one-symbol prediction cache so `predict_log_prob` immediately
/// followed by `step` on the same symbol does not query the experts twice.
#[derive(Clone)]
pub struct BayesMixture {
    experts: Vec<ExpertState>,
    // Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    // Per-expert log-weight + log-prob terms fed to logsumexp.
    scratch_mix: Vec<f64>,
    // Symbol the cached prediction below refers to.
    cached_symbol: u8,
    // Mixture log-probability computed for `cached_symbol`.
    cached_log_mix: f64,
    cache_valid: bool,
    // Accumulated -log P(mixture) over all stepped symbols (nats).
    total_log_loss: f64,
}
1392
impl BayesMixture {
    /// Builds the mixture, normalizing expert priors with log-sum-exp so the
    /// initial weights form a proper distribution even if the configured
    /// priors do not sum to 1.
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Consumes `symbol`: returns the mixture log-probability, advances every
    /// expert, charges per-expert losses, and applies the Bayesian posterior
    /// update `log w_i += log p_i(symbol) - log p_mix(symbol)`.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Cache hit: `scratch_logps`/`cached_log_mix` were filled by
            // `predict_log_prob`; only charge losses and advance the models.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            logsumexp(&self.scratch_mix)
        };
        // Subtracting log_mix keeps the posterior weights normalized.
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.log_weight = expert.log_weight + self.scratch_logps[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Evaluates the mixture log-probability of `symbol` without advancing
    /// any model, caching the per-expert log-probs for a following `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Fills `out` with the mixture log-probability of every byte value,
    /// combining expert rows under the current (renormalized) weights.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        // Start from log(0) and accumulate experts one at a time.
        out.fill(f64::NEG_INFINITY);
        let norm = logsumexp_weights(&self.experts);
        let mut row = [0.0f64; 256];
        for expert in &mut self.experts {
            expert.predictor.fill_log_probs(&mut row);
            let lw = expert.log_weight - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Current posterior weights as linear probabilities, in expert order.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and cumulative log-loss of the best (lowest-loss) expert so far.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and probability of the expert with the highest posterior weight.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total mixture log-loss (nats) accumulated by `step`.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// `(name, cumulative log-loss)` for every expert, in order.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in configuration order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Resets all experts to their frozen state and clears the mixture's
    /// running loss and prediction cache (weights are left as-is).
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feeds `symbol` to all experts in frozen mode; invalidates the cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1543
/// Bayesian mixture with exponentially decayed ("fading") weights: old
/// log-weights are multiplied by `decay` before each update, so recent
/// performance dominates. `decay == 1.0` recovers plain Bayes.
#[derive(Clone)]
pub struct FadingBayesMixture {
    experts: Vec<ExpertState>,
    // Decay factor applied to log-weights each step, clamped to [0, 1].
    decay: f64,
    // Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    // Per-expert decayed log-weight + log-prob terms fed to logsumexp.
    scratch_mix: Vec<f64>,
    // Symbol the cached prediction below refers to.
    cached_symbol: u8,
    cached_log_mix: f64,
    cache_valid: bool,
    // Accumulated -log P(mixture) over all stepped symbols (nats).
    total_log_loss: f64,
}
1558
impl FadingBayesMixture {
    /// Builds the mixture with priors normalized via log-sum-exp and the
    /// decay factor clamped into `[0, 1]`.
    pub fn new(configs: &[ExpertConfig], decay: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let decay = decay.clamp(0.0, 1.0);
        Self {
            experts,
            decay,
            scratch_logps: vec![0.0; configs.len()],
            scratch_mix: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Consumes `symbol`: decays the weights, returns the mixture
    /// log-probability, advances all experts, and renormalizes via
    /// `log w_i = decay * log w_i + log p_i(symbol) - log p_mix(symbol)`.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Cache hit: predict_log_prob already computed the decayed mix.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                let decayed = self.decay * expert.log_weight;
                self.scratch_mix[i] = decayed + self.scratch_logps[i];
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            logsumexp(&self.scratch_mix)
        };
        for (i, expert) in self.experts.iter_mut().enumerate() {
            // Re-deriving the decayed weight here matches what the cached
            // branch's predict pass used, keeping both paths consistent.
            let decayed = self.decay * expert.log_weight;
            expert.log_weight = decayed + self.scratch_logps[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Evaluates the mixture log-probability of `symbol` (with decayed
    /// weights) without advancing any model; caches the result for `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for (i, expert) in self.experts.iter_mut().enumerate() {
            self.scratch_logps[i] = expert.log_prob(symbol);
            self.scratch_mix[i] = self.decay * expert.log_weight + self.scratch_logps[i];
        }
        let log_mix = logsumexp(&self.scratch_mix);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Fills `out` with mixture log-probabilities for every byte, using
    /// decayed, renormalized weights.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        // NOTE(review): allocates a fresh Vec per call; `scratch_mix` could be
        // reused here — confirm no caller relies on scratch_mix surviving.
        let mut decayed = Vec::with_capacity(self.experts.len());
        for expert in &self.experts {
            decayed.push(self.decay * expert.log_weight);
        }
        let norm = logsumexp(&decayed);
        let mut row = [0.0f64; 256];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let lw = decayed[i] - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Current (undecayed) posterior weights as linear probabilities.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and cumulative log-loss of the best (lowest-loss) expert so far.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Total mixture log-loss (nats) accumulated by `step`.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Expert names in configuration order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Resets all experts to their frozen state; clears the cache and the
    /// running loss (weights are left as-is).
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feeds `symbol` to all experts in frozen mode; invalidates the cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1694
/// Switching mixture: at each step, with probability `alpha` the weight of an
/// expert is re-seeded from the prior instead of its posterior, letting the
/// mixture recover quickly when the best expert changes mid-stream.
#[derive(Clone)]
pub struct SwitchingMixture {
    experts: Vec<ExpertState>,
    // Normalized log-priors, frozen at construction for re-injection.
    log_prior: Vec<f64>,
    // ln(alpha), alpha clamped to [1e-12, 1 - 1e-12] at construction.
    log_alpha: f64,
    // ln(1 - alpha).
    log_1m_alpha: f64,
    // Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    // Per-expert switch-adjusted weight + log-prob terms fed to logsumexp.
    scratch_switch: Vec<f64>,
    // Symbol the cached prediction below refers to.
    cached_symbol: u8,
    cached_log_mix: f64,
    cache_valid: bool,
    // Accumulated -log P(mixture) over all stepped symbols (nats).
    total_log_loss: f64,
}
1709
impl SwitchingMixture {
    /// Builds the mixture; priors are normalized and stored separately so
    /// every step can blend them back in, and `alpha` is clamped away from
    /// 0 and 1 so its logs stay finite.
    pub fn new(configs: &[ExpertConfig], alpha: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
        let norm = logsumexp(&log_priors);
        for e in &mut experts {
            e.log_weight -= norm;
        }
        let log_prior: Vec<f64> = experts.iter().map(|e| e.log_prior - norm).collect();
        let alpha = alpha.clamp(1e-12, 1.0 - 1e-12);
        Self {
            experts,
            log_prior,
            log_alpha: alpha.ln(),
            log_1m_alpha: (1.0 - alpha).ln(),
            scratch_logps: vec![0.0; configs.len()],
            scratch_switch: vec![0.0; configs.len()],
            cached_symbol: 0,
            cached_log_mix: f64::NEG_INFINITY,
            cache_valid: false,
            total_log_loss: 0.0,
        }
    }

    /// Consumes `symbol`: mixes each expert's weight with the prior via the
    /// switching rate, returns the mixture log-probability, advances all
    /// experts, and renormalizes the weights.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let log_mix = if self.cache_valid && self.cached_symbol == symbol {
            // Cache hit: scratch_logps/scratch_switch were filled by
            // predict_log_prob; only charge losses and advance models.
            for (i, expert) in self.experts.iter_mut().enumerate() {
                expert.cum_log_loss -= self.scratch_logps[i];
                expert.update(symbol);
            }
            self.cached_log_mix
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_logps[i];
            }
            // log[(1-a) * w_i + a * prior_i] computed stably in log space.
            for i in 0..self.experts.len() {
                let log_switch = logsumexp2(
                    self.log_1m_alpha + self.experts[i].log_weight,
                    self.log_alpha + self.log_prior[i],
                );
                self.scratch_switch[i] = self.scratch_logps[i] + log_switch;
            }
            logsumexp(&self.scratch_switch)
        };
        for i in 0..self.experts.len() {
            let expert = &mut self.experts[i];
            expert.log_weight = self.scratch_switch[i] - log_mix;
        }
        self.cache_valid = false;
        self.total_log_loss -= log_mix;
        log_mix
    }

    /// Evaluates the switching-mixture log-probability of `symbol` without
    /// advancing any model; caches the per-expert terms for `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        for i in 0..self.experts.len() {
            let lp = self.experts[i].log_prob(symbol);
            self.scratch_logps[i] = lp;
            let log_switch = logsumexp2(
                self.log_1m_alpha + self.experts[i].log_weight,
                self.log_alpha + self.log_prior[i],
            );
            self.scratch_switch[i] = lp + log_switch;
        }
        let log_mix = logsumexp(&self.scratch_switch);
        self.cached_symbol = symbol;
        self.cached_log_mix = log_mix;
        self.cache_valid = true;
        log_mix
    }

    /// Fills `out` with mixture log-probabilities for every byte, using the
    /// switch-adjusted, renormalized weights.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        out.fill(f64::NEG_INFINITY);
        let mut log_switch = vec![0.0f64; self.experts.len()];
        for (i, expert) in self.experts.iter().enumerate() {
            log_switch[i] = logsumexp2(
                self.log_1m_alpha + expert.log_weight,
                self.log_alpha + self.log_prior[i],
            );
        }
        let norm = logsumexp(&log_switch);
        let mut row = [0.0f64; 256];
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.predictor.fill_log_probs(&mut row);
            let lw = log_switch[i] - norm;
            for b in 0..256 {
                out[b] = logsumexp2(out[b], lw + row[b]);
            }
        }
    }

    /// Current posterior weights as linear probabilities, in expert order.
    pub fn posterior(&self) -> Vec<f64> {
        let norm = logsumexp_weights(&self.experts);
        self.experts
            .iter()
            .map(|e| (e.log_weight - norm).exp())
            .collect()
    }

    /// Index and cumulative log-loss of the best (lowest-loss) expert so far.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Index and probability of the expert with the highest posterior weight.
    pub fn max_posterior(&self) -> (usize, f64) {
        let norm = logsumexp_weights(&self.experts);
        let mut best_idx = 0usize;
        let mut best_p = 0.0;
        for (i, e) in self.experts.iter().enumerate() {
            let p = (e.log_weight - norm).exp();
            if p > best_p {
                best_p = p;
                best_idx = i;
            }
        }
        (best_idx, best_p)
    }

    /// Total mixture log-loss (nats) accumulated by `step`.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// `(name, cumulative log-loss)` for every expert, in order.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in configuration order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Resets all experts to their frozen state; clears the cache and the
    /// running loss (weights are left as-is).
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feeds `symbol` to all experts in frozen mode; invalidates the cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
1884
/// Hard (MDL-style) selector: codes each symbol with the single expert whose
/// cumulative log-loss is currently lowest, while still updating all experts.
#[derive(Clone)]
pub struct MdlSelector {
    experts: Vec<ExpertState>,
    // Per-expert log P(symbol) from the most recent evaluation.
    scratch_logps: Vec<f64>,
    // Accumulated -log P(selected expert) over all stepped symbols (nats).
    total_log_loss: f64,
    // Expert index used for the most recent `step`.
    last_best: usize,
    // One-symbol cache filled by `predict_log_prob` for reuse in `step`.
    cached_symbol: u8,
    cached_best_idx: usize,
    cached_best_logp: f64,
    cache_valid: bool,
}
1897
/// Mixture gated by a learned (neural) combiner instead of Bayesian weights.
///
/// Caches are keyed on the text-context history state: a per-symbol cache
/// (`eval_cache_symbol`/`eval_cache_logp`) and a full 256-way cache
/// (`eval_cache_mix_logps` plus per-expert rows), both invalidated whenever
/// the analyzer's history advances.
#[derive(Clone)]
pub struct NeuralMixture {
    experts: Vec<ExpertState>,
    // Learned gating core producing per-expert mix weights.
    neural: NeuralMixCore,
    // Tracks textual context; its state keys the caches below.
    analyzer: TextContextAnalyzer,
    // Probability floor used when exponentiating/normalizing.
    min_prob: f64,
    // Per-expert log P(symbol) from the most recent evaluation.
    scratch_expert_logps: Vec<f64>,
    // Linear mixing weights copied from the core during full evaluation.
    scratch_mix_weights: Vec<f64>,
    // Single-symbol cache validity flag.
    eval_cache_valid: bool,
    // Full 256-entry cache validity flag.
    eval_cache_full_valid: bool,
    // History state the caches were computed under.
    eval_cache_history: NeuralHistoryState,
    eval_cache_symbol: u8,
    eval_cache_logp: f64,
    // Cached mixture log-probs for all 256 byte values.
    eval_cache_mix_logps: [f64; 256],
    // Cached per-expert log-prob rows backing the full cache.
    eval_cache_expert_logps: Vec<[f64; 256]>,
    // Accumulated -log P(mixture) over all stepped symbols (nats).
    total_log_loss: f64,
}
1922
impl NeuralMixture {
    /// Builds the neural mixture.
    ///
    /// The configured `learning_rate` is sanitized (non-finite falls back to
    /// 0.03), clamped to `[1e-6, 1.0]`, then scaled by 25 and re-clamped
    /// before seeding the core (which also gets a half-rate first parameter).
    pub fn new(configs: &[ExpertConfig], learning_rate: f64) -> Self {
        let mut experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let n = experts.len();

        // Convert normalized log-priors to linear weights for the core.
        let mut prior_weights = vec![0.0; n];
        if n > 0 {
            let log_priors: Vec<f64> = experts.iter().map(|e| e.log_prior).collect();
            let norm = logsumexp(&log_priors);
            for (i, e) in experts.iter_mut().enumerate() {
                let p = (e.log_prior - norm).exp();
                prior_weights[i] = p;
            }
        }

        let base_lr = if learning_rate.is_finite() {
            learning_rate.abs().clamp(1e-6, 1.0)
        } else {
            0.03
        };
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        let mut neural =
            NeuralMixCore::new(n, &prior_weights, effective_lr * 0.5, effective_lr, 1e-5);
        neural.set_context_state(analyzer.state());
        let eval_cache_history = neural.history_state();

        Self {
            experts,
            neural,
            analyzer,
            min_prob: DEFAULT_MIN_PROB,
            scratch_expert_logps: vec![0.0; n],
            scratch_mix_weights: vec![0.0; n],
            eval_cache_valid: false,
            eval_cache_full_valid: false,
            eval_cache_history,
            eval_cache_symbol: 0,
            eval_cache_logp: f64::NEG_INFINITY,
            eval_cache_mix_logps: [f64::NEG_INFINITY; 256],
            eval_cache_expert_logps: vec![[f64::NEG_INFINITY; 256]; n],
            total_log_loss: 0.0,
        }
    }

    /// Drops both the per-symbol and the full 256-way caches.
    #[inline]
    fn invalidate_eval_cache(&mut self) {
        self.eval_cache_valid = false;
        self.eval_cache_full_valid = false;
    }

    /// Pushes the analyzer's current history into the core and invalidates
    /// the caches if history moved since they were built. Returns the
    /// up-to-date history state.
    fn sync_history_state(&mut self) -> NeuralHistoryState {
        let history = self.analyzer.state();
        if self.neural.history_state() != history {
            self.neural.set_context_state(history);
        }
        if self.eval_cache_history != history {
            self.invalidate_eval_cache();
            self.eval_cache_history = history;
        }
        history
    }

    /// Computes (or revalidates) the full 256-way mixture distribution:
    /// weighted sum of expert PDFs, renormalized and floored by `min_prob`.
    fn ensure_full_evaluation(&mut self) {
        self.sync_history_state();
        if self.eval_cache_full_valid {
            return;
        }

        self.neural.evaluate_expert_weights();
        self.scratch_mix_weights
            .copy_from_slice(self.neural.expert_weights());
        let mut mix_pdf = [0.0f64; 256];
        for i in 0..self.experts.len() {
            let row = &mut self.eval_cache_expert_logps[i];
            self.experts[i].predictor.fill_log_probs(row);
            let w = self.scratch_mix_weights[i];
            for (dst, &lp) in mix_pdf.iter_mut().zip(row.iter()) {
                *dst += w * clamp_prob(lp.exp(), self.min_prob);
            }
        }

        let sum: f64 = mix_pdf.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            // Degenerate mix: fall back to a uniform distribution.
            let uniform = (1.0f64 / 256.0).ln();
            self.eval_cache_mix_logps.fill(uniform);
        } else {
            let inv = 1.0 / sum;
            for (dst, &p_raw) in self.eval_cache_mix_logps.iter_mut().zip(mix_pdf.iter()) {
                let p = clamp_unit_prob(p_raw * inv, self.min_prob);
                *dst = p.ln();
            }
        }

        self.eval_cache_full_valid = true;
    }

    /// Log-probability of one symbol under the current gate, filling
    /// `scratch_expert_logps` as a side effect and caching the result.
    /// Prefers the per-symbol cache, then the full cache, then a direct
    /// core evaluation.
    fn evaluate_symbol(&mut self, symbol: u8) -> f64 {
        let history = self.sync_history_state();
        if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            return self.eval_cache_logp;
        }

        if self.eval_cache_full_valid && self.eval_cache_history == history {
            // Reuse the cached expert rows so scratch stays consistent
            // with the cached mixture value.
            for (dst, row) in self
                .scratch_expert_logps
                .iter_mut()
                .zip(self.eval_cache_expert_logps.iter())
            {
                *dst = row[symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            self.eval_cache_valid = true;
            self.eval_cache_symbol = symbol;
            self.eval_cache_logp = logp;
            return logp;
        }

        let expert_count = self.experts.len();
        for i in 0..expert_count {
            self.scratch_expert_logps[i] = self.experts[i].log_prob(symbol);
        }
        let p = self
            .neural
            .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
        let logp = clamp_unit_prob(p, self.min_prob).ln();
        self.eval_cache_valid = true;
        self.eval_cache_history = history;
        self.eval_cache_symbol = symbol;
        self.eval_cache_logp = logp;
        logp
    }

    /// Mixture log-probability of `symbol` without advancing any model.
    /// A single expert bypasses the gate entirely.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        if self.experts.len() == 1 {
            return self.experts[0].log_prob(symbol);
        }
        self.evaluate_symbol(symbol)
    }

    /// Fills `out` with the full gated distribution (single expert passes
    /// its own distribution through unchanged).
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        if self.experts.len() == 1 {
            self.experts[0].predictor.fill_log_probs(out);
            return;
        }
        self.ensure_full_evaluation();
        out.copy_from_slice(&self.eval_cache_mix_logps);
    }

    /// Consumes `symbol`: returns the gated log-probability, advances all
    /// experts and the gate, updates the context analyzer, and charges the
    /// losses. Reuses whichever cache (symbol or full) is still valid for
    /// the current history.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }

        if self.experts.len() == 1 {
            // Single expert: no gating, but history/caches still advance.
            let expert = &mut self.experts[0];
            let logp = expert.log_prob_update(symbol);
            expert.cum_log_loss -= logp;
            self.total_log_loss -= logp;
            self.analyzer.update(symbol);
            self.neural.set_context_state(self.analyzer.state());
            self.invalidate_eval_cache();
            return logp;
        }

        let history = self.sync_history_state();
        let logp = if self.eval_cache_valid
            && self.eval_cache_history == history
            && self.eval_cache_symbol == symbol
        {
            // Per-symbol cache hit: scratch_expert_logps already current.
            let logp = self.eval_cache_logp;
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else if self.eval_cache_full_valid && self.eval_cache_history == history {
            // Full cache hit: pull the symbol column from the cached rows.
            for i in 0..self.experts.len() {
                self.scratch_expert_logps[i] = self.eval_cache_expert_logps[i][symbol as usize];
            }
            let logp = self.eval_cache_mix_logps[symbol as usize];
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                expert.cum_log_loss -= self.scratch_expert_logps[i];
                expert.update(symbol);
            }
            logp
        } else {
            for i in 0..self.experts.len() {
                let expert = &mut self.experts[i];
                self.scratch_expert_logps[i] = expert.log_prob_update(symbol);
                expert.cum_log_loss -= self.scratch_expert_logps[i];
            }
            let p = self
                .neural
                .evaluate_symbol(&self.scratch_expert_logps, self.min_prob);
            clamp_unit_prob(p, self.min_prob).ln()
        };
        // Train the gate on this symbol's expert log-probs, then advance
        // the textual context before the next prediction.
        self.neural
            .update_weights_symbol(&self.scratch_expert_logps, self.min_prob);
        self.total_log_loss -= logp;
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        logp
    }

    /// Total mixture log-loss (nats) accumulated by `step`.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// Resets experts to their frozen state and rebuilds the context
    /// analyzer from scratch; gate weights are kept.
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.analyzer = TextContextAnalyzer::new();
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        self.eval_cache_history = self.neural.history_state();
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feeds `symbol` to all experts in frozen mode and advances the
    /// context; invalidates the caches.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.analyzer.update(symbol);
        self.neural.set_context_state(self.analyzer.state());
        self.invalidate_eval_cache();
        self.eval_cache_history = self.neural.history_state();
    }
}
2170
impl MdlSelector {
    /// Builds the selector; unlike the soft mixtures, priors are not used
    /// (selection is purely by cumulative loss), so no normalization occurs.
    pub fn new(configs: &[ExpertConfig]) -> Self {
        let experts: Vec<ExpertState> = configs.iter().map(|c| c.build()).collect();
        let last_best = 0usize;
        Self {
            experts,
            scratch_logps: vec![0.0; configs.len()],
            total_log_loss: 0.0,
            last_best,
            cached_symbol: 0,
            cached_best_idx: 0,
            cached_best_logp: f64::NEG_INFINITY,
            cache_valid: false,
        }
    }

    /// Consumes `symbol`: codes it with the currently best expert (lowest
    /// cumulative loss *before* this symbol), then advances and charges all
    /// experts. The cached path reuses the best choice made at predict time;
    /// both paths select on pre-symbol losses, so they agree.
    pub fn step(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let used_cache = self.cache_valid && self.cached_symbol == symbol;
        let best_idx = if used_cache {
            // Best expert's log-prob came from predict_log_prob; only the
            // others still need evaluation (without advancing: update below).
            self.scratch_logps[self.cached_best_idx] = self.cached_best_logp;
            for (i, expert) in self.experts.iter_mut().enumerate() {
                if i == self.cached_best_idx {
                    continue;
                }
                self.scratch_logps[i] = expert.log_prob(symbol);
            }
            self.cached_best_idx
        } else {
            for (i, expert) in self.experts.iter_mut().enumerate() {
                self.scratch_logps[i] = expert.log_prob_update(symbol);
            }
            let mut best_idx = 0usize;
            let mut best_loss = f64::INFINITY;
            for (i, expert) in self.experts.iter().enumerate() {
                if expert.cum_log_loss < best_loss {
                    best_loss = expert.cum_log_loss;
                    best_idx = i;
                }
            }
            best_idx
        };
        let logp = self.scratch_logps[best_idx];
        self.cache_valid = false;
        for (i, expert) in self.experts.iter_mut().enumerate() {
            expert.cum_log_loss -= self.scratch_logps[i];
            // The cached path used non-updating log_prob calls above, so the
            // model advance happens here; the other path already updated.
            if used_cache {
                expert.update(symbol);
            }
        }
        self.total_log_loss -= logp;
        self.last_best = best_idx;
        logp
    }

    /// Log-probability of `symbol` under the current best expert, without
    /// advancing any model; caches the choice for a following `step`.
    fn predict_log_prob(&mut self, symbol: u8) -> f64 {
        if self.experts.is_empty() {
            return f64::NEG_INFINITY;
        }
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, expert) in self.experts.iter().enumerate() {
            if expert.cum_log_loss < best_loss {
                best_loss = expert.cum_log_loss;
                best_idx = i;
            }
        }
        let logp = self.experts[best_idx].log_prob(symbol);
        self.cached_symbol = symbol;
        self.cached_best_idx = best_idx;
        self.cached_best_logp = logp;
        self.cache_valid = true;
        logp
    }

    /// Fills `out` with the full distribution of the current best expert.
    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        if self.experts.is_empty() {
            out.fill(f64::NEG_INFINITY);
            return;
        }
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, expert) in self.experts.iter().enumerate() {
            if expert.cum_log_loss < best_loss {
                best_loss = expert.cum_log_loss;
                best_idx = i;
            }
        }
        self.experts[best_idx].predictor.fill_log_probs(out);
    }

    /// Expert index used by the most recent `step`.
    pub fn best_index(&self) -> usize {
        self.last_best
    }

    /// Index and cumulative log-loss of the best (lowest-loss) expert so far.
    pub fn min_expert_log_loss(&self) -> (usize, f64) {
        let mut best_idx = 0usize;
        let mut best_loss = f64::INFINITY;
        for (i, e) in self.experts.iter().enumerate() {
            if e.cum_log_loss < best_loss {
                best_loss = e.cum_log_loss;
                best_idx = i;
            }
        }
        (best_idx, best_loss)
    }

    /// Total selector log-loss (nats) accumulated by `step`.
    pub fn total_log_loss(&self) -> f64 {
        self.total_log_loss
    }

    /// `(name, cumulative log-loss)` for every expert, in order.
    pub fn expert_log_losses(&self) -> Vec<(String, f64)> {
        self.experts
            .iter()
            .map(|e| (e.name.clone(), e.cum_log_loss))
            .collect()
    }

    /// Expert names in configuration order.
    pub fn expert_names(&self) -> Vec<String> {
        self.experts.iter().map(|e| e.name.clone()).collect()
    }

    /// Resets all experts to their frozen state; clears the cache and the
    /// running loss (per-expert loss tallies are left as-is).
    fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        for expert in &mut self.experts {
            expert.reset_frozen(total_symbols)?;
        }
        self.cache_valid = false;
        self.total_log_loss = 0.0;
        Ok(())
    }

    /// Feeds `symbol` to all experts in frozen mode; invalidates the cache.
    fn update_frozen(&mut self, symbol: u8) {
        for expert in &mut self.experts {
            expert.update_frozen(symbol);
        }
        self.cache_valid = false;
    }
}
2318
/// Runtime dispatch over the supported mixture strategies.
// The allow silences warnings about the size disparity between variants;
// the enum is stored once per stream, so boxing was not considered worth it.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum MixtureRuntime {
    /// Standard Bayesian model averaging.
    Bayes(BayesMixture),
    /// Bayes with exponentially decayed (fading) weights.
    Fading(FadingBayesMixture),
    /// Switching mixture that can fall back to the prior at rate `alpha`.
    Switching(SwitchingMixture),
    /// Hard selection of the single lowest-loss expert.
    Mdl(MdlSelector),
    /// Neurally gated combination of expert predictions.
    Neural(NeuralMixture),
}
2338
impl MixtureRuntime {
    /// Notifies all experts that a stream of `total_symbols` is starting.
    pub(crate) fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Fading(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Switching(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Mdl(m) => begin_expert_stream(&mut m.experts, total_symbols),
            MixtureRuntime::Neural(m) => begin_expert_stream(&mut m.experts, total_symbols),
        }
    }

    /// Notifies all experts that the current stream has ended.
    pub(crate) fn finish_stream(&mut self) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Fading(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Switching(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Mdl(m) => finish_expert_stream(&mut m.experts),
            MixtureRuntime::Neural(m) => finish_expert_stream(&mut m.experts),
        }
    }

    /// Resets the underlying mixture to its frozen state for a new stream.
    pub(crate) fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        match self {
            MixtureRuntime::Bayes(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Fading(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Switching(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Mdl(m) => m.reset_frozen(total_symbols),
            MixtureRuntime::Neural(m) => m.reset_frozen(total_symbols),
        }
    }

    /// Log-probability of `symbol` without advancing any model; primes the
    /// mixture's one-symbol cache for an immediately following `step`.
    pub(crate) fn peek_log_prob(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Fading(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Switching(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Mdl(m) => m.predict_log_prob(symbol),
            MixtureRuntime::Neural(m) => m.predict_log_prob(symbol),
        }
    }

    /// Consumes `symbol`: returns its log-probability and advances the
    /// mixture (experts, weights, and loss accounting).
    pub(crate) fn step(&mut self, symbol: u8) -> f64 {
        match self {
            MixtureRuntime::Bayes(m) => m.step(symbol),
            MixtureRuntime::Fading(m) => m.step(symbol),
            MixtureRuntime::Switching(m) => m.step(symbol),
            MixtureRuntime::Mdl(m) => m.step(symbol),
            MixtureRuntime::Neural(m) => m.step(symbol),
        }
    }

    /// Feeds `symbol` through the mixture in frozen mode.
    pub(crate) fn update_frozen(&mut self, symbol: u8) {
        match self {
            MixtureRuntime::Bayes(m) => m.update_frozen(symbol),
            MixtureRuntime::Fading(m) => m.update_frozen(symbol),
            MixtureRuntime::Switching(m) => m.update_frozen(symbol),
            MixtureRuntime::Mdl(m) => m.update_frozen(symbol),
            MixtureRuntime::Neural(m) => m.update_frozen(symbol),
        }
    }

    /// Fills `out` with log-probabilities for all 256 byte values.
    pub(crate) fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        match self {
            MixtureRuntime::Bayes(m) => m.fill_log_probs(out),
            MixtureRuntime::Fading(m) => m.fill_log_probs(out),
            MixtureRuntime::Switching(m) => m.fill_log_probs(out),
            MixtureRuntime::Mdl(m) => m.fill_log_probs(out),
            MixtureRuntime::Neural(m) => m.fill_log_probs(out),
        }
    }
}
2412
2413fn begin_expert_stream(
2414 experts: &mut [ExpertState],
2415 total_symbols: Option<u64>,
2416) -> Result<(), String> {
2417 for expert in experts {
2418 expert.begin_stream(total_symbols)?;
2419 }
2420 Ok(())
2421}
2422
2423fn finish_expert_stream(experts: &mut [ExpertState]) -> Result<(), String> {
2424 for expert in experts {
2425 expert.finish_stream()?;
2426 }
2427 Ok(())
2428}
2429
2430pub(crate) fn build_mixture_runtime(
2431 spec: &MixtureSpec,
2432 experts: &[ExpertConfig],
2433) -> Result<MixtureRuntime, String> {
2434 if experts.is_empty() {
2435 return Err("mixture spec must include at least one expert".to_string());
2436 }
2437 match spec.kind {
2438 MixtureKind::Bayes => Ok(MixtureRuntime::Bayes(BayesMixture::new(experts))),
2439 MixtureKind::FadingBayes => {
2440 let decay = spec
2441 .decay
2442 .ok_or_else(|| "fading Bayes mixture requires decay".to_string())?;
2443 Ok(MixtureRuntime::Fading(FadingBayesMixture::new(
2444 experts, decay,
2445 )))
2446 }
2447 MixtureKind::Switching => Ok(MixtureRuntime::Switching(SwitchingMixture::new(
2448 experts, spec.alpha,
2449 ))),
2450 MixtureKind::Mdl => Ok(MixtureRuntime::Mdl(MdlSelector::new(experts))),
2451 MixtureKind::Neural => Ok(MixtureRuntime::Neural(NeuralMixture::new(
2452 experts, spec.alpha,
2453 ))),
2454 }
2455}
2456
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU64, AtomicUsize, Ordering},
    };

    // Deterministic expert that puts all probability mass on a single byte:
    // log p = 0 for the favored byte, -inf for everything else.
    #[derive(Clone)]
    struct AlwaysPredict {
        byte: u8,
    }

    impl OnlineBytePredictor for AlwaysPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            if symbol == self.byte {
                0.0
            } else {
                f64::NEG_INFINITY
            }
        }

        fn update(&mut self, _symbol: u8) {}
    }

    #[test]
    fn bayes_mixture_prefers_correct_expert() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = BayesMixture::new(&configs);
        for _ in 0..10 {
            mix.step(0);
        }
        // After ten observations of byte 0, nearly all posterior mass should
        // sit on the expert that always predicts 0.
        let post = mix.posterior();
        assert!(post[0] > 0.999);
        assert!(post[1] < 1e-6);
    }

    // Builds an ExpertConfig whose predictor counts log_prob invocations;
    // used by the caching tests below to verify how often experts are queried.
    fn counting_cfg(name: &'static str, calls: Arc<AtomicUsize>) -> ExpertConfig {
        ExpertConfig::uniform(name, move || {
            Box::new(CountingPredict {
                calls: calls.clone(),
            })
        })
    }

    #[test]
    fn bayes_predict_then_step_reuses_cached_log_probs() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mix = BayesMixture::new(&[
            counting_cfg("c0", c0.clone()),
            counting_cfg("c1", c1.clone()),
        ]);
        // predict evaluates each of the two experts exactly once...
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);
        // ...and a subsequent step on the same symbol must reuse that cache.
        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }

    #[test]
    fn fading_predict_then_step_reuses_cached_log_probs() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mix = FadingBayesMixture::new(
            &[
                counting_cfg("c0", c0.clone()),
                counting_cfg("c1", c1.clone()),
            ],
            0.95,
        );
        // Same caching contract as the plain Bayes mixture: one evaluation
        // per expert at predict time, zero extra evaluations at step time.
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);
        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }

    #[test]
    fn switching_predict_then_step_reuses_cached_log_probs() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mix = SwitchingMixture::new(
            &[
                counting_cfg("c0", c0.clone()),
                counting_cfg("c1", c1.clone()),
            ],
            0.05,
        );
        // Same caching contract for the switching mixture.
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);
        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }

    #[test]
    fn mdl_predict_then_step_reuses_best_expert_log_prob() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let mut mdl = MdlSelector::new(&[
            counting_cfg("c0", c0.clone()),
            counting_cfg("c1", c1.clone()),
        ]);
        // MDL predicts with only the currently-best expert (one call)...
        let _ = mdl.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 1);
        // ...while step still needs the other expert's score to update its
        // code-length bookkeeping (one additional call, two total).
        let _ = mdl.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, 2);
    }

    #[test]
    fn neural_mixture_adapts_to_correct_symbol() {
        let configs = vec![
            ExpertConfig::uniform("zero", || Box::new(AlwaysPredict { byte: 0 })),
            ExpertConfig::uniform("one", || Box::new(AlwaysPredict { byte: 1 })),
        ];
        let mut mix = NeuralMixture::new(&configs, 0.05);

        // Compare average cost (negative log-prob) over the first 20 steps
        // against the last 20 of a 200-step run on a constant symbol.
        let mut early = 0.0;
        let mut late = 0.0;
        for t in 0..200 {
            let lp = mix.step(0);
            if t < 20 {
                early -= lp;
            }
            if t >= 180 {
                late -= lp;
            }
        }

        // The mixer should have learned to favor the correct expert, so late
        // cost must be strictly lower than early cost and small in absolute
        // terms (< 0.35 nats per symbol).
        let early_avg = early / 20.0;
        let late_avg = late / 20.0;
        assert!(
            late_avg < early_avg,
            "late_avg={late_avg} early_avg={early_avg}"
        );
        assert!(late_avg < 0.35, "late_avg={late_avg}");
    }

    // Mock predictor that counts log_prob calls; always favors byte 0.
    #[derive(Clone)]
    struct CountingPredict {
        calls: Arc<AtomicUsize>,
    }

    impl OnlineBytePredictor for CountingPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            self.calls.fetch_add(1, Ordering::Relaxed);
            if symbol == 0 { 0.0 } else { -20.0 }
        }

        fn update(&mut self, _symbol: u8) {}
    }

    // Mock predictor that counts per-symbol queries and full-row fills
    // separately, so tests can tell which code path was exercised.
    #[derive(Clone)]
    struct CountingFillPredict {
        log_calls: Arc<AtomicUsize>,
        fill_calls: Arc<AtomicUsize>,
    }

    impl OnlineBytePredictor for CountingFillPredict {
        fn log_prob(&mut self, symbol: u8) -> f64 {
            self.log_calls.fetch_add(1, Ordering::Relaxed);
            if symbol == 0 { 0.0 } else { -20.0 }
        }

        fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
            self.fill_calls.fetch_add(1, Ordering::Relaxed);
            out.fill(-20.0);
            out[0] = 0.0;
        }

        fn update(&mut self, _symbol: u8) {}
    }

    // Mock predictor that records the total passed to begin_stream and only
    // yields finite log-probs once begin_stream has been called.
    #[derive(Clone)]
    struct BeginAwarePredict {
        seen_total: Arc<AtomicU64>,
        began: bool,
    }

    impl OnlineBytePredictor for BeginAwarePredict {
        fn begin_stream(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
            let total = total_symbols.ok_or_else(|| "missing total symbols".to_string())?;
            self.seen_total.store(total, Ordering::Relaxed);
            self.began = true;
            Ok(())
        }

        fn log_prob(&mut self, _symbol: u8) -> f64 {
            if self.began { 0.0 } else { f64::NEG_INFINITY }
        }

        fn update(&mut self, _symbol: u8) {}
    }

    // Drives two identically-configured predictors over the same data and
    // checks that the fused log_prob_update path matches separate
    // log_prob + update calls, both per-symbol and for full 256-entry rows.
    fn assert_log_prob_update_matches_separate(label: &str, backend: RateBackend) {
        let mut separate =
            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut combined = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
        let data = b"combined step check data";

        for &b in data {
            let logp_separate = separate.log_prob(b);
            separate.update(b);
            let logp_combined = combined.log_prob_update(b);
            let diff = (logp_separate - logp_combined).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] symbol={b} separate={logp_separate} combined={logp_combined} diff={diff}"
            );

            let mut sep_row = [0.0; 256];
            let mut combo_row = [0.0; 256];
            separate.fill_log_probs(&mut sep_row);
            combined.fill_log_probs(&mut combo_row);
            for i in 0..256 {
                let diff = (sep_row[i] - combo_row[i]).abs();
                assert!(
                    diff <= 1e-12,
                    "row mismatch at {i}: {} vs {}",
                    sep_row[i],
                    combo_row[i]
                );
            }
        }
    }

    // After conditioning two predictors on the same history, the bulk
    // fill_log_probs row must agree with 256 individual log_prob queries.
    fn assert_fill_matches_symbol_queries(label: &str, backend: RateBackend) {
        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
        let data = b"continuation consistency prompt";

        bulk.begin_stream(Some(data.len() as u64))
            .expect("bulk begin");
        queried
            .begin_stream(Some(data.len() as u64))
            .expect("query begin");
        for &b in data {
            bulk.update(b);
            queried.update(b);
        }

        let mut bulk_row = [0.0; 256];
        bulk.fill_log_probs(&mut bulk_row);
        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
            let queried_logp = queried.log_prob(sym as u8);
            let diff = (bulk_logp - queried_logp).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
            );
        }
    }

    // Same fill-vs-query consistency check, but after a reset_frozen plus
    // frozen conditioning pass (fit corpus first, then frozen context).
    fn assert_fill_matches_symbol_queries_after_frozen_conditioning(
        label: &str,
        backend: RateBackend,
    ) {
        let fit = b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\n";
        let condition = b"If a cat is red, toads are \n";
        let total = (fit.len() + condition.len()) as u64;

        let mut bulk = RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut queried = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);

        bulk.begin_stream(Some(total)).expect("bulk begin");
        queried.begin_stream(Some(total)).expect("query begin");
        for &b in fit {
            bulk.update(b);
            queried.update(b);
        }
        bulk.reset_frozen(Some(condition.len() as u64))
            .expect("bulk reset frozen");
        queried
            .reset_frozen(Some(condition.len() as u64))
            .expect("query reset frozen");
        for &b in condition {
            bulk.update_frozen(b);
            queried.update_frozen(b);
        }

        let mut bulk_row = [0.0; 256];
        bulk.fill_log_probs(&mut bulk_row);
        for (sym, &bulk_logp) in bulk_row.iter().enumerate() {
            let queried_logp = queried.log_prob(sym as u8);
            let diff = (bulk_logp - queried_logp).abs();
            assert!(
                diff <= 1e-12,
                "[{label}] frozen sym={sym} bulk={bulk_logp} queried={queried_logp} diff={diff}"
            );
        }
    }

    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_rosa_backend() {
        assert_log_prob_update_matches_separate("rosa", RateBackend::RosaPlus);
    }

    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_ctw_backend() {
        assert_log_prob_update_matches_separate("ctw", RateBackend::Ctw { depth: 6 });
    }

    #[test]
    fn predictor_log_prob_update_matches_separate_update_for_fac_ctw_backend() {
        assert_log_prob_update_matches_separate(
            "fac-ctw",
            RateBackend::FacCtw {
                base_depth: 6,
                num_percept_bits: 8,
                encoding_bits: 8,
            },
        );
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_rosa_backend() {
        assert_fill_matches_symbol_queries("rosa", RateBackend::RosaPlus);
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_ctw_backend() {
        assert_fill_matches_symbol_queries("ctw", RateBackend::Ctw { depth: 6 });
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_match_backend() {
        assert_fill_matches_symbol_queries(
            "match",
            RateBackend::Match {
                hash_bits: 18,
                min_len: 4,
                max_len: 64,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
        );
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_ppmd_backend() {
        assert_fill_matches_symbol_queries(
            "ppmd",
            RateBackend::Ppmd {
                order: 8,
                memory_mb: 8,
            },
        );
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn predictor_fill_matches_symbol_queries_for_rwkv_backend() {
        assert_fill_matches_symbol_queries(
            "rwkv7",
            RateBackend::Rwkv7Method {
                method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
            },
        );
    }

    #[test]
    fn predictor_fill_matches_symbol_queries_for_rosa_backend_after_frozen_conditioning() {
        assert_fill_matches_symbol_queries_after_frozen_conditioning("rosa", RateBackend::RosaPlus);
    }

    #[test]
    fn predictor_frozen_conditioning_reuses_match_fit_corpus() {
        let mut predictor = RateBackendPredictor::from_backend(
            RateBackend::Match {
                hash_bits: 20,
                min_len: 3,
                max_len: 32,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
            -1,
            DEFAULT_MIN_PROB,
        );

        // Fit on "abcabcX", then freeze and replay "abcabc" as conditioning;
        // the match model should still find the fit-corpus continuation 'X'.
        for &b in b"abcabcX" {
            predictor.update(b);
        }
        predictor
            .reset_frozen(Some(6))
            .expect("reset frozen for match backend");
        for &b in b"abcabc" {
            predictor.update_frozen(b);
        }
        let p_x = predictor.log_prob(b'X').exp();
        assert!(
            p_x > 0.01,
            "frozen conditioning should preserve fit corpus for match backend; p_x={p_x}"
        );
    }

    #[test]
    fn predictor_frozen_conditioning_reuses_sparse_match_fit_corpus() {
        let mut predictor = RateBackendPredictor::from_backend(
            RateBackend::SparseMatch {
                hash_bits: 20,
                min_len: 3,
                max_len: 32,
                gap_min: 0,
                gap_max: 2,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
            -1,
            DEFAULT_MIN_PROB,
        );

        // Same scenario as the match-backend test above, for sparse-match.
        for &b in b"abcabcX" {
            predictor.update(b);
        }
        predictor
            .reset_frozen(Some(6))
            .expect("reset frozen for sparse-match backend");
        for &b in b"abcabc" {
            predictor.update_frozen(b);
        }
        let p_x = predictor.log_prob(b'X').exp();
        assert!(
            p_x > 0.01,
            "frozen conditioning should preserve fit corpus for sparse-match backend; p_x={p_x}"
        );
    }

    #[test]
    fn neural_predict_then_step_reuses_evaluation_cache() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let cfg0 = {
            let c = c0.clone();
            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let cfg1 = {
            let c = c1.clone();
            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);

        // predict evaluates both experts once; step on the same symbol must
        // not trigger any further evaluations.
        let _ = mix.predict_log_prob(0);
        let after_predict = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_predict, 2);

        let _ = mix.step(0);
        let after_step = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_step, after_predict);
    }

    #[test]
    fn neural_predict_multiple_symbols_reuses_single_evaluation() {
        let c0 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::new(AtomicUsize::new(0));
        let cfg0 = {
            let c = c0.clone();
            ExpertConfig::uniform("c0", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let cfg1 = {
            let c = c1.clone();
            ExpertConfig::uniform("c1", move || Box::new(CountingPredict { calls: c.clone() }))
        };
        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);

        // Each distinct symbol queried at the same position costs one
        // log_prob call per expert (no full-row caching in this path).
        let _ = mix.predict_log_prob(0);
        let after_first = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_first, 2);

        let _ = mix.predict_log_prob(1);
        let after_second = c0.load(Ordering::Relaxed) + c1.load(Ordering::Relaxed);
        assert_eq!(after_second, after_first + 2);
    }

    #[test]
    fn neural_fill_then_step_reuses_cached_full_rows() {
        let log0 = Arc::new(AtomicUsize::new(0));
        let log1 = Arc::new(AtomicUsize::new(0));
        let fill0 = Arc::new(AtomicUsize::new(0));
        let fill1 = Arc::new(AtomicUsize::new(0));
        let cfg0 = {
            let log_calls = log0.clone();
            let fill_calls = fill0.clone();
            ExpertConfig::uniform("c0", move || {
                Box::new(CountingFillPredict {
                    log_calls: log_calls.clone(),
                    fill_calls: fill_calls.clone(),
                })
            })
        };
        let cfg1 = {
            let log_calls = log1.clone();
            let fill_calls = fill1.clone();
            ExpertConfig::uniform("c1", move || {
                Box::new(CountingFillPredict {
                    log_calls: log_calls.clone(),
                    fill_calls: fill_calls.clone(),
                })
            })
        };
        let mut mix = NeuralMixture::new(&[cfg0, cfg1], 0.03);

        // fill_log_probs uses each expert's bulk fill path exactly once and
        // never falls back to per-symbol log_prob queries...
        let mut row = [0.0; 256];
        mix.fill_log_probs(&mut row);
        assert_eq!(fill0.load(Ordering::Relaxed), 1);
        assert_eq!(fill1.load(Ordering::Relaxed), 1);
        assert_eq!(log0.load(Ordering::Relaxed), 0);
        assert_eq!(log1.load(Ordering::Relaxed), 0);

        // ...and a subsequent step reuses the cached rows entirely.
        let _ = mix.step(0);
        assert_eq!(fill0.load(Ordering::Relaxed), 1);
        assert_eq!(fill1.load(Ordering::Relaxed), 1);
        assert_eq!(log0.load(Ordering::Relaxed), 0);
        assert_eq!(log1.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn runtime_begin_stream_propagates_to_experts() {
        let seen_total = Arc::new(AtomicU64::new(0));
        let cfg = {
            let seen_total = seen_total.clone();
            ExpertConfig::uniform("begin-aware", move || {
                Box::new(BeginAwarePredict {
                    seen_total: seen_total.clone(),
                    began: false,
                })
            })
        };

        let spec = MixtureSpec::new(MixtureKind::Bayes, vec![]);
        let mut runtime = build_mixture_runtime(&spec, &[cfg]).expect("runtime");
        // begin_stream on the runtime must reach the expert with the total.
        runtime.begin_stream(Some(123)).expect("begin stream");
        let _ = runtime.step(0);
        assert_eq!(seen_total.load(Ordering::Relaxed), 123);
    }

    #[test]
    fn zpaq_fill_log_probs_does_not_drift_history() {
        let backend = RateBackend::Zpaq {
            method: "1".to_string(),
        };
        let mut baseline =
            RateBackendPredictor::from_backend(backend.clone(), -1, DEFAULT_MIN_PROB);
        let mut probe = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);

        let history = b"history for zpaq predictor";
        for &b in history {
            baseline.update(b);
            probe.update(b);
        }

        // Calling fill_log_probs on `probe` must be a read-only operation:
        // it must not advance the predictor's internal history.
        let mut row = [0.0f64; 256];
        probe.fill_log_probs(&mut row);

        let sym = b'k';
        let lp_base = baseline.log_prob(sym);
        let lp_probe = probe.log_prob(sym);
        assert!((lp_base - lp_probe).abs() < 1e-9);
        assert!((row[sym as usize] - lp_base).abs() < 1e-9);

        // Both predictors must also stay in lockstep on the next symbol.
        baseline.update(sym);
        probe.update(sym);
        let next = b'q';
        let next_base = baseline.log_prob(next);
        let next_probe = probe.log_prob(next);
        assert!((next_base - next_probe).abs() < 1e-9);
    }

    // Sanity check: the per-symbol probabilities of a conditioned predictor
    // must sum to 1 across all 256 byte values (within 1e-10).
    fn assert_predictor_log_probs_normalize_to_one(backend: RateBackend) {
        let mut predictor = RateBackendPredictor::from_backend(backend, -1, DEFAULT_MIN_PROB);
        for &b in b"normalization corpus for ctw/fac predictor checks" {
            predictor.update(b);
        }
        let mut sum = 0.0f64;
        for sym in 0u8..=255u8 {
            sum += predictor.log_prob(sym).exp();
        }
        assert!(
            (sum - 1.0).abs() <= 1e-10,
            "probability mass drift: sum={sum}"
        );
    }

    #[test]
    fn ctw_predictor_symbol_probs_normalize() {
        assert_predictor_log_probs_normalize_to_one(RateBackend::Ctw { depth: 7 });
    }

    #[test]
    fn fac_ctw_predictor_symbol_probs_normalize() {
        assert_predictor_log_probs_normalize_to_one(RateBackend::FacCtw {
            base_depth: 7,
            num_percept_bits: 8,
            encoding_bits: 8,
        });
    }
}