1use crate::RateBackend;
8use crate::aixi::rate_backend::rate_backend_contains_zpaq;
9use crate::ctw::{ContextTree, FacContextTree};
10#[cfg(feature = "backend-mamba")]
11use crate::mambazip::{Compressor as MambaCompressor, Model as MambaModel, State as MambaState};
12use crate::mixture::{
13 DEFAULT_MIN_PROB, OnlineBytePredictor, RateBackendPredictor, RateBackendPredictorCheckpoint,
14};
15use crate::rosaplus::{RosaPlus, RosaTx};
16#[cfg(feature = "backend-rwkv")]
17use crate::rwkvzip::{Compressor as RwkvCompressor, Model as RwkvModel, State as RwkvState};
18use crate::zpaq_rate::ZpaqRateModel;
19#[cfg(any(feature = "backend-mamba", feature = "backend-rwkv"))]
20use std::sync::Arc;
21
/// Reversible binary (bit-level) sequence predictor.
///
/// Implementations maintain mutable model state that is advanced one bit at
/// a time. Updates come in reversible form (`update`/`update_history`, undone
/// by `revert`/`pop_history`) and committed form (`commit_*`, which callers
/// promise never to roll back, letting implementations skip snapshotting).
pub trait Predictor: Send {
    /// Feeds `sym` to the model as the next observed bit; reversible via
    /// `revert`.
    fn update(&mut self, sym: bool);

    /// Irreversible variant of `update`: the caller promises never to roll
    /// this bit back. Default simply delegates to `update`.
    fn commit_update(&mut self, sym: bool) {
        self.update(sym);
    }

    /// Appends `sym` to the model's history/context; reversible via
    /// `pop_history`. Default delegates to `update`; implementations with a
    /// separate non-training ("frozen") path override this.
    fn update_history(&mut self, sym: bool) {
        self.update(sym);
    }

    /// Irreversible variant of `update_history`.
    fn commit_update_history(&mut self, sym: bool) {
        self.update_history(sym);
    }

    /// Undoes the most recent reversible `update`.
    fn revert(&mut self);

    /// Undoes the most recent reversible `update_history`; default mirrors
    /// the `update_history` -> `update` default by delegating to `revert`.
    fn pop_history(&mut self) {
        self.revert();
    }

    /// Opens a bulk-rollback scope; a later `rollback_scope` restores the
    /// state captured here. Default is a no-op (scopes unsupported).
    fn begin_rollback_scope(&mut self) {}

    /// Restores the state captured by the matching `begin_rollback_scope`.
    /// Returns `false` when scopes are unsupported (the default) or no scope
    /// is active.
    fn rollback_scope(&mut self) -> bool {
        false
    }

    /// Probability the model currently assigns to the next bit being `sym`.
    fn predict_prob(&mut self, sym: bool) -> f64;

    /// Convenience: probability that the next bit is `true`.
    fn predict_one(&mut self) -> f64 {
        self.predict_prob(true)
    }

    /// Human-readable model identifier used in diagnostics/labels.
    fn model_name(&self) -> String;

    /// Deep-copies the predictor (including mutable model state) behind a
    /// trait object.
    fn boxed_clone(&self) -> Box<dyn Predictor>;
}
85
/// Clamps a configured minimum probability into a usable per-symbol floor.
///
/// Non-finite inputs (NaN / ±inf) fall back to the smallest supported floor,
/// so downstream clamping always receives a sane bound strictly below 0.5.
#[inline]
fn binary_prob_floor(min_prob: f64) -> f64 {
    const MIN_FLOOR: f64 = 1e-12;
    const MAX_FLOOR: f64 = 0.499_999_999_999;
    if !min_prob.is_finite() {
        return MIN_FLOOR;
    }
    min_prob.clamp(MIN_FLOOR, MAX_FLOOR)
}
94
95#[inline]
96fn normalized_binary_prob_pair_from_probs(p0: f64, p1: f64, min_prob: f64) -> (f64, f64) {
97 let p0 = if p0.is_finite() && p0 > 0.0 { p0 } else { 0.0 };
98 let p1 = if p1.is_finite() && p1 > 0.0 { p1 } else { 0.0 };
99 let sum = p0 + p1;
100 if !sum.is_finite() || sum <= 0.0 {
101 return (0.5, 0.5);
102 }
103 let floor = binary_prob_floor(min_prob);
104 let q1 = (p1 / sum).clamp(floor, 1.0 - floor);
105 (1.0 - q1, q1)
106}
107
108#[inline]
109fn normalized_binary_prob_pair_from_log_probs(logp0: f64, logp1: f64, min_prob: f64) -> (f64, f64) {
110 let max_log = logp0.max(logp1);
111 if !max_log.is_finite() {
112 return (0.5, 0.5);
113 }
114 let p0 = if logp0.is_finite() {
115 (logp0 - max_log).exp()
116 } else {
117 0.0
118 };
119 let p1 = if logp1.is_finite() {
120 (logp1 - max_log).exp()
121 } else {
122 0.0
123 };
124 normalized_binary_prob_pair_from_probs(p0, p1, min_prob)
125}
126
127pub struct CtwPredictor {
132 tree: ContextTree,
133}
134
135impl CtwPredictor {
136 pub fn new(depth: usize) -> Self {
138 Self {
139 tree: ContextTree::new(depth),
140 }
141 }
142}
143
144impl Predictor for CtwPredictor {
145 fn update(&mut self, sym: bool) {
146 self.tree.update(sym);
147 }
148 fn update_history(&mut self, sym: bool) {
149 self.tree.update_history(&[sym]);
150 }
151
152 fn revert(&mut self) {
153 self.tree.revert();
154 }
155 fn pop_history(&mut self) {
156 self.tree.revert_history();
157 }
158
159 fn predict_prob(&mut self, sym: bool) -> f64 {
160 self.tree.predict(sym)
161 }
162
163 fn model_name(&self) -> String {
164 format!("AC-CTW(d={})", self.tree.depth())
165 }
166
167 fn boxed_clone(&self) -> Box<dyn Predictor> {
168 Box::new(Self {
169 tree: self.tree.clone(),
170 })
171 }
172}
173
174pub struct FacCtwPredictor {
182 tree: FacContextTree,
183 current_bit: usize,
185 num_bits: usize,
187}
188
189impl FacCtwPredictor {
190 pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
195 Self {
196 tree: FacContextTree::new(base_depth, num_percept_bits),
197 current_bit: 0,
198 num_bits: num_percept_bits,
199 }
200 }
201}
202
impl Predictor for FacCtwPredictor {
    /// Trains the model slot for the current bit position, then advances the
    /// position cyclically.
    fn update(&mut self, sym: bool) {
        self.tree.update(sym, self.current_bit);
        self.current_bit = (self.current_bit + 1) % self.num_bits;
    }

    /// Pushes `sym` through the tree's history path.
    /// NOTE(review): unlike `update`, this does not advance `current_bit` —
    /// presumably history symbols are position-independent; confirm against
    /// `FacContextTree::update_history`.
    fn update_history(&mut self, sym: bool) {
        self.tree.update_history(&[sym]);
    }

    /// Rolls back the last `update`: steps `current_bit` back one position
    /// (wrapping at zero) and reverts that slot.
    fn revert(&mut self) {
        self.current_bit = if self.current_bit == 0 {
            self.num_bits - 1
        } else {
            self.current_bit - 1
        };
        self.tree.revert(self.current_bit);
    }

    /// Rolls back exactly one history symbol.
    fn pop_history(&mut self) {
        self.tree.revert_history(1);
    }

    /// Probability of `sym` from the slot for the current bit position.
    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.tree.predict(sym, self.current_bit)
    }

    fn model_name(&self) -> String {
        format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            tree: self.tree.clone(),
            current_bit: self.current_bit,
            num_bits: self.num_bits,
        })
    }
}
243
244pub struct RosaPredictor {
249 model: RosaPlus,
250 history: Vec<RosaTx>,
251}
252
253impl RosaPredictor {
254 pub fn new(max_order: i64) -> Self {
257 let mut model = RosaPlus::new(max_order, false, 0, 42);
259 model.build_lm_full_bytes_no_finalize_endpos();
261 Self {
262 model,
263 history: Vec::new(),
264 }
265 }
266}
267
268impl Predictor for RosaPredictor {
269 fn update(&mut self, sym: bool) {
270 let mut tx = self.model.begin_tx();
271 let byte = if sym { 1u8 } else { 0u8 };
273
274 self.model.train_sequence_tx(&mut tx, &[byte]);
276 self.history.push(tx);
277 }
278
279 fn revert(&mut self) {
280 if let Some(tx) = self.history.pop() {
281 self.model.rollback_tx(tx);
282 }
283 }
284
285 fn predict_prob(&mut self, sym: bool) -> f64 {
286 let (p0, p1) = normalized_binary_prob_pair_from_probs(
287 self.model.prob_for_last(0),
288 self.model.prob_for_last(1),
289 DEFAULT_MIN_PROB,
290 );
291 if sym { p1 } else { p0 }
292 }
293
294 fn model_name(&self) -> String {
295 "ROSA".to_string()
296 }
297
298 fn boxed_clone(&self) -> Box<dyn Predictor> {
299 Box::new(Self {
300 model: self.model.clone(),
301 history: self.history.clone(),
302 })
303 }
304}
305
/// Bit predictor backed by a ZPAQ rate model.
///
/// The wrapped model is not directly reversible here, so this keeps the full
/// committed byte `history` as ground truth and rebuilds the model from it
/// whenever a revert (or a mispredicted speculative scoring) invalidates the
/// live model's state.
pub struct ZpaqPredictor {
    // ZPAQ method string (e.g. "1"); also used to build throwaway models.
    method: String,
    // Probability floor forwarded to the model and the normalizer.
    min_prob: f64,
    model: ZpaqRateModel,
    // All committed bits as 0/1 bytes; ground truth for rebuilds.
    history: Vec<u8>,
    // Cached speculative scoring: symbol and its log-prob. Scoring appears to
    // advance the live model's internal state (see `update`, which rebuilds
    // when a different symbol than `pending` is committed).
    pending: Option<(u8, f64)>,
}

impl ZpaqPredictor {
    /// Creates a fresh predictor for the given ZPAQ method and probability
    /// floor.
    pub fn new(method: String, min_prob: f64) -> Self {
        let model = ZpaqRateModel::new(method.clone(), min_prob);
        Self {
            method,
            min_prob,
            model,
            history: Vec::new(),
            pending: None,
        }
    }

    /// Resets the live model and replays the committed history into it.
    /// O(history).
    fn rebuild_from_history(&mut self) {
        self.model.reset();
        if !self.history.is_empty() {
            self.model.update_and_score(&self.history);
        }
    }

    /// Scores `symbol` on a throwaway model replayed from history, leaving
    /// `self.model` untouched. O(history) per call; only hit for the
    /// non-cached side of `binary_log_prob_pair`.
    fn log_prob_from_history(&self, symbol: u8) -> f64 {
        let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
        if !self.history.is_empty() {
            tmp.update_and_score(&self.history);
        }
        tmp.log_prob(symbol)
    }

    /// Returns (logp0, logp1). `preferred_symbol` is scored on the live model
    /// (and cached in `pending`, so an immediately following `update` with the
    /// same symbol is free); the complementary symbol is scored on a replay.
    fn binary_log_prob_pair(&mut self, preferred_symbol: u8) -> (f64, f64) {
        let other_symbol = preferred_symbol ^ 1;
        let preferred_logp = match self.pending {
            // Already scored this symbol speculatively; reuse the cache.
            Some((pending, logp)) if pending == preferred_symbol => logp,
            // A different symbol is pending: don't disturb the live model.
            Some(_) => self.log_prob_from_history(preferred_symbol),
            None => {
                let logp = self.model.log_prob(preferred_symbol);
                self.pending = Some((preferred_symbol, logp));
                logp
            }
        };
        let other_logp = match self.pending {
            Some((pending, logp)) if pending == other_symbol => logp,
            _ => self.log_prob_from_history(other_symbol),
        };
        if preferred_symbol == 0 {
            (preferred_logp, other_logp)
        } else {
            (other_logp, preferred_logp)
        }
    }
}
368
impl Predictor for ZpaqPredictor {
    /// Commits `sym`. If the same symbol was speculatively scored, the live
    /// model is already positioned correctly; otherwise the model is rebuilt
    /// from history before applying the update.
    fn update(&mut self, sym: bool) {
        let byte = if sym { 1u8 } else { 0u8 };
        if let Some((pending, _)) = self.pending {
            if pending == byte {
                self.model.update(byte);
                self.pending = None;
                self.history.push(byte);
                return;
            }
            // The pending symbol differs from the one committed: the live
            // model's speculative state is invalid, start over from history.
            self.pending = None;
            self.rebuild_from_history();
        }
        self.model.update(byte);
        self.history.push(byte);
    }

    /// Drops the last committed bit and replays the rest. O(history).
    fn revert(&mut self) {
        if self.history.pop().is_some() {
            self.pending = None;
            self.rebuild_from_history();
        }
    }

    /// Scores the queried symbol on the live model (cached via `pending`) and
    /// its complement on a replayed copy, then normalizes the pair.
    fn predict_prob(&mut self, sym: bool) -> f64 {
        let preferred_symbol = if sym { 1u8 } else { 0u8 };
        let (logp0, logp1) = self.binary_log_prob_pair(preferred_symbol);
        let (p0, p1) = normalized_binary_prob_pair_from_log_probs(logp0, logp1, self.min_prob);
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        format!("ZPAQ({})", self.method)
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            method: self.method.clone(),
            min_prob: self.min_prob,
            model: self.model.clone(),
            history: self.history.clone(),
            pending: self.pending,
        })
    }
}
414
/// Bit predictor over a generic `RateBackend` byte predictor.
///
/// Reversibility is implemented with backend checkpoints: per-bit snapshots
/// go into `journal` (for `revert`/`pop_history`), and bulk snapshots go into
/// `rollback_scopes`; while a scope is active, per-bit journaling is
/// suspended so long simulations don't accumulate one snapshot per bit.
pub struct RateBackendBitPredictor {
    backend: RateBackend,
    max_order: i64,
    min_prob: f64,
    predictor: RateBackendPredictor,
    // Per-bit checkpoints, most recent last.
    journal: Vec<RateBackendJournalEntry>,
    // Active bulk-rollback scopes, innermost last.
    rollback_scopes: Vec<RateBackendRollbackScope>,
}

/// Which reversible operation produced a journal entry; verified on rollback
/// so `revert` and `pop_history` cannot be mismatched.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RateBackendJournalKind {
    Update,
    FrozenUpdate,
}

/// One per-bit snapshot: backend checkpoint taken just before the update.
#[derive(Clone)]
struct RateBackendJournalEntry {
    kind: RateBackendJournalKind,
    checkpoint: RateBackendPredictorCheckpoint,
}

/// One bulk snapshot: backend checkpoint plus the journal length to truncate
/// back to when the scope is rolled back.
#[derive(Clone)]
struct RateBackendRollbackScope {
    checkpoint: RateBackendPredictorCheckpoint,
    journal_len: usize,
}
446
impl RateBackendBitPredictor {
    /// Builds a bit predictor with the default probability floor.
    pub fn new(backend: RateBackend, max_order: i64) -> Result<Self, String> {
        Self::new_with_min_prob(backend, max_order, DEFAULT_MIN_PROB)
    }

    /// Builds a bit predictor with an explicit probability floor.
    ///
    /// # Errors
    /// Fails for zpaq-containing backends (unsupported here) or when the
    /// backend prediction stream cannot be started.
    pub fn new_with_min_prob(
        backend: RateBackend,
        max_order: i64,
        min_prob: f64,
    ) -> Result<Self, String> {
        if rate_backend_contains_zpaq(&backend) {
            return Err(
                "RateBackendBitPredictor does not support zpaq backends; use a non-zpaq rate_backend"
                    .to_string(),
            );
        }
        let mut predictor =
            RateBackendPredictor::from_backend(backend.clone(), max_order, min_prob);
        predictor
            .begin_stream(None)
            .map_err(|err| format!("failed to start RateBackend predictor stream: {err}"))?;
        Ok(Self {
            backend,
            max_order,
            min_prob,
            predictor,
            journal: Vec::new(),
            rollback_scopes: Vec::new(),
        })
    }

    /// Encodes a bit as the byte symbol (0 or 1) fed to the backend.
    #[inline(always)]
    fn bit_to_byte(sym: bool) -> u8 {
        if sym { 1u8 } else { 0u8 }
    }

    /// Deep-copies the predictor, including journal and scope stacks.
    fn clone_state(&self) -> Self {
        Self {
            backend: self.backend.clone(),
            max_order: self.max_order,
            min_prob: self.min_prob,
            predictor: self.predictor.clone(),
            journal: self.journal.clone(),
            rollback_scopes: self.rollback_scopes.clone(),
        }
    }

    /// Captures the current backend state as a journal entry of `kind`.
    fn checkpoint(&mut self, kind: RateBackendJournalKind) -> RateBackendJournalEntry {
        RateBackendJournalEntry {
            kind,
            checkpoint: self.predictor.checkpoint(),
        }
    }

    /// Pops the newest journal entry (asserting its kind matches the caller's
    /// expectation) and restores the backend to that snapshot.
    ///
    /// # Panics
    /// On journal underflow, kind mismatch, or when called inside an active
    /// rollback scope (per-bit journaling is suspended there).
    fn restore_last(&mut self, expected_kind: RateBackendJournalKind) {
        assert!(
            self.rollback_scopes.is_empty(),
            "RateBackendBitPredictor per-symbol rollback inside active scope is unsupported"
        );
        let entry = self
            .journal
            .pop()
            .expect("RateBackendBitPredictor rollback underflow");
        assert_eq!(
            entry.kind, expected_kind,
            "RateBackendBitPredictor rollback kind mismatch: expected {expected_kind:?}, got {:?}",
            entry.kind
        );
        self.predictor.restore_checkpoint(&entry.checkpoint);
        // With no snapshots outstanding, let the backend drop any retained
        // checkpoint bookkeeping.
        if self.rollback_scopes.is_empty() && self.journal.is_empty() {
            self.predictor.clear_checkpoints_if_supported();
        }
    }
}
523
impl Predictor for RateBackendBitPredictor {
    /// Reversible update: snapshots the backend first, unless inside a scope
    /// (the scope's single snapshot already covers all bits within it).
    fn update(&mut self, sym: bool) {
        if self.rollback_scopes.is_empty() {
            let checkpoint = self.checkpoint(RateBackendJournalKind::Update);
            self.journal.push(checkpoint);
        }
        self.predictor.update(Self::bit_to_byte(sym));
    }

    /// Committed update: no snapshot, cannot be reverted.
    fn commit_update(&mut self, sym: bool) {
        self.predictor.update(Self::bit_to_byte(sym));
    }

    /// Reversible frozen (non-training) update; same journaling rule as
    /// `update` but tagged `FrozenUpdate` so rollbacks can't be mismatched.
    fn update_history(&mut self, sym: bool) {
        if self.rollback_scopes.is_empty() {
            let checkpoint = self.checkpoint(RateBackendJournalKind::FrozenUpdate);
            self.journal.push(checkpoint);
        }
        self.predictor.update_frozen(Self::bit_to_byte(sym));
    }

    /// Committed frozen update: no snapshot, cannot be popped.
    fn commit_update_history(&mut self, sym: bool) {
        self.predictor.update_frozen(Self::bit_to_byte(sym));
    }

    fn revert(&mut self) {
        self.restore_last(RateBackendJournalKind::Update);
    }

    fn pop_history(&mut self) {
        self.restore_last(RateBackendJournalKind::FrozenUpdate);
    }

    /// Opens a bulk scope: one backend checkpoint plus the current journal
    /// length, so `rollback_scope` can truncate any entries made since.
    fn begin_rollback_scope(&mut self) {
        let checkpoint = self.predictor.checkpoint();
        self.rollback_scopes.push(RateBackendRollbackScope {
            checkpoint,
            journal_len: self.journal.len(),
        });
    }

    /// Restores the innermost scope's snapshot; returns false if none active.
    fn rollback_scope(&mut self) -> bool {
        let Some(scope) = self.rollback_scopes.pop() else {
            return false;
        };
        self.predictor.restore_checkpoint(&scope.checkpoint);
        self.journal.truncate(scope.journal_len);
        // Same cleanup rule as `restore_last`: no snapshots outstanding.
        if self.rollback_scopes.is_empty() && self.journal.is_empty() {
            self.predictor.clear_checkpoints_if_supported();
        }
        true
    }

    /// Normalizes the backend's log-probs for byte symbols 0/1 into a proper
    /// binary distribution and returns the side asked for.
    fn predict_prob(&mut self, sym: bool) -> f64 {
        let (p0, p1) = normalized_binary_prob_pair_from_log_probs(
            self.predictor.log_prob(0),
            self.predictor.log_prob(1),
            self.min_prob,
        );
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        format!(
            "RateBackendBits({})",
            RateBackendPredictor::default_name(&self.backend, self.max_order)
        )
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(self.clone_state())
    }
}
597
598#[cfg(feature = "backend-rwkv")]
599use crate::coders::softmax_pdf_floor_inplace;
600
#[cfg(feature = "backend-rwkv")]
/// Bit predictor over an RWKV model. Reversibility is by brute-force
/// snapshotting: each `update` saves the full (state, pdf) pair.
pub struct RwkvPredictor {
    compressor: RwkvCompressor,
    // Snapshots of (recurrent state, pdf buffer) taken before each update,
    // most recent last; popped on `revert`.
    history: Vec<(RwkvState, Vec<f64>)>,
}

#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Wraps a shared RWKV model and primes the pdf buffer with one forward
    /// pass on token 0 so `predict_prob` is valid before any update.
    pub fn new(model: Arc<RwkvModel>) -> Self {
        let mut compressor = RwkvCompressor::new_from_model(model);
        let vocab_size = compressor.vocab_size();
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state);
        softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);

        Self {
            compressor,
            history: Vec::new(),
        }
    }

    /// Builds the model from a method string; primes the pdf the same way as
    /// `new` (initial forward pass on token 0).
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            RwkvCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        compressor.forward_to_internal_pdf(0);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
639
#[cfg(feature = "backend-rwkv")]
impl Predictor for RwkvPredictor {
    /// Snapshots (state, pdf), then runs one forward pass on the bit (as
    /// token 0/1) and refreshes the floored-softmax pdf buffer.
    fn update(&mut self, sym: bool) {
        self.history.push((
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        ));

        let byte = if sym { 1u32 } else { 0u32 };
        let vocab_size = self.compressor.vocab_size();

        let logits = self.compressor.model.forward(
            &mut self.compressor.scratch,
            byte,
            &mut self.compressor.state,
        );
        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.compressor.pdf_buffer);
    }

    /// Restores the snapshot taken by the most recent `update`, if any.
    fn revert(&mut self) {
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }

    /// Normalizes the pdf entries for tokens 0 and 1 into a binary
    /// distribution (mass on other tokens is discarded by renormalization).
    fn predict_prob(&mut self, sym: bool) -> f64 {
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.compressor.pdf_buffer[0],
            self.compressor.pdf_buffer[1],
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        "RWKV".to_string()
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
687
#[cfg(feature = "backend-mamba")]
/// Bit predictor over a Mamba model; same snapshot-based reversibility
/// scheme as `RwkvPredictor`.
pub struct MambaPredictor {
    compressor: MambaCompressor,
    // Snapshots of (recurrent state, pdf buffer) taken before each update.
    history: Vec<(MambaState, Vec<f64>)>,
}

#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Wraps a shared Mamba model and primes the pdf buffer with one forward
    /// pass on token 0 (applying the current online-bias snapshot).
    pub fn new(model: Arc<MambaModel>) -> Self {
        let mut compressor = MambaCompressor::new_from_model(model);
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state)
            .to_vec();
        let bias = compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut compressor.pdf_buffer);

        Self {
            compressor,
            history: Vec::new(),
        }
    }

    /// Builds the model from a method string and primes `pdf_buffer` from an
    /// initial forward pass on token 0.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            MambaCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        let mut pdf = vec![0.0f64; compressor.vocab_size()];
        compressor.forward_to_pdf(0, &mut pdf);
        compressor.pdf_buffer.clone_from(&pdf);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
726
#[cfg(feature = "backend-mamba")]
impl Predictor for MambaPredictor {
    /// Snapshots (state, pdf), then runs one forward pass on the bit (as
    /// token 0/1) and rebuilds the pdf with the current online-bias snapshot.
    fn update(&mut self, sym: bool) {
        self.history.push((
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        ));

        let byte = if sym { 1u32 } else { 0u32 };
        let logits = self
            .compressor
            .model
            .forward(
                &mut self.compressor.scratch,
                byte,
                &mut self.compressor.state,
            )
            .to_vec();
        let bias = self.compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut self.compressor.pdf_buffer);
    }

    /// Restores the snapshot taken by the most recent `update`, if any.
    fn revert(&mut self) {
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }

    /// Normalizes the pdf entries for tokens 0 and 1 into a binary
    /// distribution (mass on other tokens is discarded by renormalization).
    fn predict_prob(&mut self, sym: bool) -> f64 {
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.compressor.pdf_buffer[0],
            self.compressor.pdf_buffer[1],
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        "Mamba".to_string()
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
776
#[cfg(test)]
mod tests {
    //! Unit tests: binary-mass normalization across predictors, and the
    //! journal/scope bookkeeping of `RateBackendBitPredictor` (reversibility,
    //! clone hygiene, no snapshot growth on committed paths).
    use super::*;

    // Asserts two probabilities are equal to within 1e-12.
    fn approx_eq(a: f64, b: f64) {
        let diff = (a - b).abs();
        assert!(
            diff <= 1e-12,
            "expected probabilities to match exactly enough: left={a} right={b} diff={diff}"
        );
    }

    // Drives a predictor over a fixed bit sequence, checking at each step
    // that p(false) + p(true) == 1 and both lie in [0, 1].
    fn assert_binary_predictor_normalizes(mut predictor: Box<dyn Predictor>, label: &str) {
        for (step, &bit) in [false, true, true, false, true, false].iter().enumerate() {
            let p0 = predictor.predict_prob(false);
            let p1 = predictor.predict_prob(true);
            let sum = p0 + p1;
            assert!(
                (sum - 1.0).abs() < 1e-12,
                "{label}: probabilities must sum to 1 at step {step}, got p0={p0}, p1={p1}, sum={sum}",
            );
            assert!(
                (0.0..=1.0).contains(&p0) && (0.0..=1.0).contains(&p1),
                "{label}: probabilities must stay in [0,1] at step {step}, got p0={p0}, p1={p1}",
            );
            predictor.commit_update(bit);
        }
    }

    // Consumes a predictor and records its (p0, p1) pair at every step of
    // `probe`; used to compare two predictors for behavioral equality.
    fn predictor_signature(
        mut predictor: RateBackendBitPredictor,
        probe: &[bool],
    ) -> Vec<(f64, f64)> {
        let mut signature = Vec::with_capacity(probe.len());
        for &bit in probe {
            signature.push((predictor.predict_prob(false), predictor.predict_prob(true)));
            predictor.commit_update(bit);
        }
        signature
    }

    // Committed updates promise no rollback, so they must not snapshot.
    #[test]
    fn committed_rate_backend_updates_do_not_grow_journal() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");

        for idx in 0..512usize {
            predictor.commit_update((idx & 1) == 0);
            predictor.commit_update_history((idx % 3) == 0);
        }

        assert!(
            predictor.journal.is_empty(),
            "committed history should not retain rollback snapshots"
        );
    }

    // update/revert and update_history/pop_history must restore the exact
    // predictive state (verified via full probability signatures).
    #[test]
    fn reversible_rate_backend_update_paths_round_trip_exactly() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for &bit in &[true, false, true, true, false, false, true] {
            predictor.commit_update(bit);
        }

        let baseline_after_train = predictor.clone_state();
        predictor.update(true);
        predictor.update(false);
        predictor.revert();
        predictor.revert();
        assert_eq!(predictor.journal.len(), baseline_after_train.journal.len());

        let train_probe = [true, false, false, true, true, false];
        let got = predictor_signature(predictor.clone_state(), &train_probe);
        let want = predictor_signature(baseline_after_train.clone_state(), &train_probe);
        for ((got0, got1), (want0, want1)) in got.into_iter().zip(want.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }

        let baseline_after_history = baseline_after_train.clone_state();
        predictor.update_history(false);
        predictor.update_history(true);
        predictor.pop_history();
        predictor.pop_history();
        assert_eq!(
            predictor.journal.len(),
            baseline_after_history.journal.len()
        );

        let history_probe = [false, true, true, false, false, true];
        let got = predictor_signature(predictor.clone_state(), &history_probe);
        let want = predictor_signature(baseline_after_history, &history_probe);
        for ((got0, got1), (want0, want1)) in got.into_iter().zip(want.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }

    // Clones made after long committed runs must carry no stale snapshots,
    // and reversible round trips on the clone must not perturb predictions.
    #[test]
    fn long_committed_history_does_not_contaminate_clone_rollback_state() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");

        for idx in 0..2048usize {
            predictor.commit_update((idx & 7) < 3);
            predictor.commit_update_history((idx % 5) < 2);
        }
        assert!(predictor.journal.is_empty());

        let mut cloned = predictor.clone_state();
        assert!(
            cloned.journal.is_empty(),
            "clone state should only carry active reversible rollback depth"
        );

        let baseline = predictor_signature(predictor.clone_state(), &[true, false, true, false]);
        cloned.update(true);
        cloned.revert();
        cloned.update_history(false);
        cloned.pop_history();
        assert!(cloned.journal.is_empty());

        let after_round_trip = predictor_signature(cloned, &[true, false, true, false]);
        for ((got0, got1), (want0, want1)) in after_round_trip.into_iter().zip(baseline.into_iter())
        {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }

    // A rollback scope must cover many reversible updates with a single
    // snapshot and restore the pre-scope predictive state exactly.
    #[test]
    fn rollback_scope_restores_simulation_state_without_growing_journal() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for &bit in &[true, false, true, false, true] {
            predictor.commit_update(bit);
        }

        let baseline = predictor_signature(predictor.clone_state(), &[true, true, false, false]);
        predictor.begin_rollback_scope();
        for idx in 0..512usize {
            predictor.update((idx & 1) == 0);
            predictor.update_history((idx % 3) == 0);
        }
        assert!(
            predictor.journal.is_empty(),
            "scoped reversible updates should not retain per-bit snapshots"
        );
        assert!(predictor.rollback_scope(), "scope rollback should succeed");
        assert!(predictor.journal.is_empty());

        let after = predictor_signature(predictor, &[true, true, false, false]);
        for ((got0, got1), (want0, want1)) in after.into_iter().zip(baseline.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }

    // Clones taken inside an active scope keep the scope stack but no
    // per-bit journal entries.
    #[test]
    fn cloned_predictor_carries_only_active_scope_snapshots() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for idx in 0..1024usize {
            predictor.commit_update((idx & 3) == 0);
        }

        predictor.begin_rollback_scope();
        for idx in 0..256usize {
            predictor.update((idx & 1) == 0);
        }
        let cloned = predictor.clone_state();
        assert!(
            cloned.journal.is_empty(),
            "scoped reversible updates should not leak per-bit journal state into clones"
        );
        assert_eq!(cloned.rollback_scopes.len(), 1);
    }

    // Normalization invariant across the non-feature-gated generic backends.
    #[test]
    fn generic_rate_backend_bit_predictors_normalize_binary_mass() {
        assert_binary_predictor_normalizes(
            Box::new(
                RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
                    .expect("generic rosa predictor"),
            ),
            "generic-rosa",
        );
        assert_binary_predictor_normalizes(
            Box::new(
                RateBackendBitPredictor::new(
                    RateBackend::Ppmd {
                        order: 4,
                        memory_mb: 8,
                    },
                    8,
                )
                .expect("generic ppmd predictor"),
            ),
            "generic-ppmd",
        );
        assert_binary_predictor_normalizes(
            Box::new(
                RateBackendBitPredictor::new(
                    RateBackend::Match {
                        hash_bits: 16,
                        min_len: 2,
                        max_len: 32,
                        base_mix: 0.05,
                        confidence_scale: 1.0,
                    },
                    8,
                )
                .expect("generic match predictor"),
            ),
            "generic-match",
        );
    }

    #[cfg(feature = "backend-zpaq")]
    #[test]
    fn zpaq_predictor_normalizes_binary_mass() {
        assert_binary_predictor_normalizes(
            Box::new(ZpaqPredictor::new("1".to_string(), DEFAULT_MIN_PROB)),
            "zpaq",
        );
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_predictor_normalizes_binary_mass() {
        // Tiny untrained config kept inference-only so the test stays fast.
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
        let predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
        assert_binary_predictor_normalizes(Box::new(predictor), "rwkv");
    }

    #[cfg(feature = "backend-mamba")]
    #[test]
    fn mamba_predictor_normalizes_binary_mass() {
        // Tiny untrained config kept inference-only so the test stays fast.
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
        let predictor = MambaPredictor::from_method(method).expect("mamba predictor");
        assert_binary_predictor_normalizes(Box::new(predictor), "mamba");
    }
}