// infotheory/aixi/model.rs
1//! Predictive models for AIXI.
2//!
3//! This module defines the `Predictor` trait, which encapsulates the logic
4//! for learning from history and predicting future symbols. Different implementations
5//! provide different complexity vs performance trade-offs.
6
7use crate::RateBackend;
8use crate::aixi::rate_backend::rate_backend_contains_zpaq;
9use crate::ctw::{ContextTree, FacContextTree};
10#[cfg(feature = "backend-mamba")]
11use crate::mambazip::{Compressor as MambaCompressor, Model as MambaModel, State as MambaState};
12use crate::mixture::{
13    DEFAULT_MIN_PROB, OnlineBytePredictor, RateBackendPredictor, RateBackendPredictorCheckpoint,
14};
15use crate::rosaplus::{RosaPlus, RosaTx};
16#[cfg(feature = "backend-rwkv")]
17use crate::rwkvzip::{Compressor as RwkvCompressor, Model as RwkvModel, State as RwkvState};
18use crate::zpaq_rate::ZpaqRateModel;
19#[cfg(any(feature = "backend-mamba", feature = "backend-rwkv"))]
20use std::sync::Arc;
21
/// Interface for an AIXI world model.
///
/// Predictors are mutated behind `&mut self` and cloned per worker during
/// parallel MCTS. They only need `Send`, not `Sync`, which avoids unsound
/// thread-sharing requirements for backends with thread-confined internals.
///
/// Only `update`, `revert`, `predict_prob`, `model_name`, and `boxed_clone`
/// are required. The remaining methods default to forwarding into the
/// reversible `update`/`revert` pair; backends override them when they have
/// cheaper committed or scoped variants.
pub trait Predictor: Send {
    /// Incorporates a new symbol into the model's training history.
    fn update(&mut self, sym: bool);

    /// Incorporates a new symbol as committed training history without
    /// retaining rollback state when the predictor supports that optimization.
    ///
    /// Default: a plain reversible `update`.
    fn commit_update(&mut self, sym: bool) {
        self.update(sym);
    }

    /// Appends a symbol to the model's interaction history without necessarily
    /// updating the training counts immediately (backend dependent).
    ///
    /// Default: indistinguishable from `update`.
    fn update_history(&mut self, sym: bool) {
        self.update(sym);
    }

    /// Appends a symbol as committed interaction history without retaining
    /// rollback state when the predictor supports that optimization.
    fn commit_update_history(&mut self, sym: bool) {
        self.update_history(sym);
    }

    /// Reverts the model to its state before the last `update`.
    fn revert(&mut self);

    /// Reverts the model to its state before the last `update_history`.
    ///
    /// Default: forwards to `revert`, matching the `update_history` default.
    fn pop_history(&mut self) {
        self.revert();
    }

    /// Begins an optional coarse rollback scope for simulation-heavy callers.
    ///
    /// Predictors that support this can avoid retaining per-symbol rollback state
    /// until the matching `rollback_scope` call. The default is a no-op.
    fn begin_rollback_scope(&mut self) {}

    /// Rolls back to the last scope opened with `begin_rollback_scope`.
    ///
    /// Returns `true` when a scope rollback was performed, allowing callers to skip
    /// per-symbol revert loops. The default reports that scopes are unsupported.
    fn rollback_scope(&mut self) -> bool {
        false
    }

    /// Predicts the probability of the next symbol being `sym`.
    fn predict_prob(&mut self, sym: bool) -> f64;

    /// Shorthand for `predict_prob(true)`.
    fn predict_one(&mut self) -> f64 {
        self.predict_prob(true)
    }

    /// Returns a human-readable name of the predictive model.
    fn model_name(&self) -> String;

    /// Creates a boxed clone of this predictor.
    fn boxed_clone(&self) -> Box<dyn Predictor>;
}
85
/// Clamps a caller-supplied probability floor into a range that is safe for
/// binary normalization: strictly positive and strictly below one half.
/// Non-finite floors (NaN/±inf) degrade to the minimal floor of `1e-12`.
#[inline]
fn binary_prob_floor(min_prob: f64) -> f64 {
    if !min_prob.is_finite() {
        return 1e-12;
    }
    min_prob.clamp(1e-12, 0.499_999_999_999)
}

/// Normalizes two non-negative weights into a binary distribution.
///
/// Non-finite or non-positive weights are treated as zero mass. When all the
/// mass vanishes (or the sum overflows), the result falls back to the uniform
/// pair `(0.5, 0.5)`; otherwise the probability of symbol one is clamped away
/// from 0 and 1 by the floor derived from `min_prob`, and the returned pair
/// sums to exactly 1.
#[inline]
fn normalized_binary_prob_pair_from_probs(p0: f64, p1: f64, min_prob: f64) -> (f64, f64) {
    let sanitize = |p: f64| if p.is_finite() && p > 0.0 { p } else { 0.0 };
    let (w0, w1) = (sanitize(p0), sanitize(p1));
    let total = w0 + w1;
    if !total.is_finite() || total <= 0.0 {
        return (0.5, 0.5);
    }
    let floor = binary_prob_floor(min_prob);
    let q1 = (w1 / total).clamp(floor, 1.0 - floor);
    (1.0 - q1, q1)
}

/// Converts two log-probabilities to a normalized binary distribution.
///
/// The larger log value is subtracted before exponentiation so the linear
/// weights stay in a safe range; if neither log is finite the result is the
/// uniform pair.
#[inline]
fn normalized_binary_prob_pair_from_log_probs(logp0: f64, logp1: f64, min_prob: f64) -> (f64, f64) {
    let max_log = logp0.max(logp1);
    if !max_log.is_finite() {
        return (0.5, 0.5);
    }
    let to_linear = |logp: f64| {
        if logp.is_finite() {
            (logp - max_log).exp()
        } else {
            0.0
        }
    };
    normalized_binary_prob_pair_from_probs(to_linear(logp0), to_linear(logp1), min_prob)
}
126
127/// A predictor using the Action-Conditional CTW algorithm.
128///
129/// AC-CTW uses a single context tree for all bits in sequence.
130/// For better type information exploitation, use `FacCtwPredictor`.
131pub struct CtwPredictor {
132    tree: ContextTree,
133}
134
135impl CtwPredictor {
136    /// Creates a new `CtwPredictor` with the specified context depth.
137    pub fn new(depth: usize) -> Self {
138        Self {
139            tree: ContextTree::new(depth),
140        }
141    }
142}
143
impl Predictor for CtwPredictor {
    // Reversible training step: delegates straight to the context tree.
    fn update(&mut self, sym: bool) {
        self.tree.update(sym);
    }
    // History-only append; `ContextTree::update_history` takes a slice, so the
    // single bit is wrapped in a one-element array.
    fn update_history(&mut self, sym: bool) {
        self.tree.update_history(&[sym]);
    }

    // Undoes the most recent `update`.
    fn revert(&mut self) {
        self.tree.revert();
    }
    // Undoes the most recent `update_history`.
    fn pop_history(&mut self) {
        self.tree.revert_history();
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.tree.predict(sym)
    }

    fn model_name(&self) -> String {
        format!("AC-CTW(d={})", self.tree.depth())
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            tree: self.tree.clone(),
        })
    }
}
173
/// A predictor using the Factorized Action-Conditional CTW (FAC-CTW) algorithm.
///
/// FAC-CTW uses k separate context trees (one per percept bit) with overlapping
/// context depths D+i-1 for bit i. This enables better exploitation of type
/// information within percepts, as described in Veness et al. (2011) Section 5.
///
/// This is the recommended CTW variant for MC-AIXI agents.
pub struct FacCtwPredictor {
    // One context tree per percept bit, managed by `FacContextTree`.
    tree: FacContextTree,
    /// Current bit index within a percept (cycles 0..num_bits).
    current_bit: usize,
    /// Total number of percept bits (k).
    num_bits: usize,
}

impl FacCtwPredictor {
    /// Creates a new `FacCtwPredictor`.
    ///
    /// - `base_depth`: Context depth D for the first bit's tree
    /// - `num_percept_bits`: Total bits per percept (observation_bits + reward_bits)
    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
        Self {
            tree: FacContextTree::new(base_depth, num_percept_bits),
            // Prediction starts at the first bit of a percept.
            current_bit: 0,
            num_bits: num_percept_bits,
        }
    }
}
202
impl Predictor for FacCtwPredictor {
    fn update(&mut self, sym: bool) {
        // Route the bit to the tree for the current percept position, then
        // advance the cyclic bit cursor.
        self.tree.update(sym, self.current_bit);
        self.current_bit = (self.current_bit + 1) % self.num_bits;
    }

    fn update_history(&mut self, sym: bool) {
        // NOTE(review): unlike `update`, this does not advance `current_bit`;
        // presumably `FacContextTree::update_history` tracks position on its
        // own — confirm against its implementation.
        self.tree.update_history(&[sym]);
    }

    fn revert(&mut self) {
        // Revert to previous bit index (wrapping decrement), then undo the
        // update that was applied at that position.
        self.current_bit = if self.current_bit == 0 {
            self.num_bits - 1
        } else {
            self.current_bit - 1
        };
        self.tree.revert(self.current_bit);
    }

    fn pop_history(&mut self) {
        // Drops exactly one appended history symbol.
        self.tree.revert_history(1);
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.tree.predict(sym, self.current_bit)
    }

    fn model_name(&self) -> String {
        format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            tree: self.tree.clone(),
            current_bit: self.current_bit,
            num_bits: self.num_bits,
        })
    }
}
243
/// A predictor using the ROSA-Plus (Rapid Online Suffix Automaton + Witten-Bell Smoother) algorithm.
///
/// ROSA is a (practically) sub-quadratic suffix automaton based language model that
/// can handle very long contexts efficiently.
pub struct RosaPredictor {
    model: RosaPlus,
    /// One transaction per reversible `update`, popped by `revert`.
    history: Vec<RosaTx>,
}

impl RosaPredictor {
    /// Creates a new `RosaPredictor` with a maximum context length for the fallback LM.
    /// Note: deterministic ROSA uses the full SAM and is not capped by `max_order`.
    pub fn new(max_order: i64) -> Self {
        // Seed and EOT
        // NOTE(review): the `false, 0, 42` constructor arguments are
        // undocumented here (the trailing 42 looks like an RNG seed) —
        // confirm against `RosaPlus::new`'s parameter list.
        let mut model = RosaPlus::new(max_order, false, 0, 42);
        // Important: Pre-build the LM with full byte alphabet so we can incrementally update it.
        model.build_lm_full_bytes_no_finalize_endpos();
        Self {
            model,
            history: Vec::new(),
        }
    }
}
267
268impl Predictor for RosaPredictor {
269    fn update(&mut self, sym: bool) {
270        let mut tx = self.model.begin_tx();
271        // Using 0u8 and 1u8 for bits
272        let byte = if sym { 1u8 } else { 0u8 };
273
274        // train_example_tx updates the tx object and the model
275        self.model.train_sequence_tx(&mut tx, &[byte]);
276        self.history.push(tx);
277    }
278
279    fn revert(&mut self) {
280        if let Some(tx) = self.history.pop() {
281            self.model.rollback_tx(tx);
282        }
283    }
284
285    fn predict_prob(&mut self, sym: bool) -> f64 {
286        let (p0, p1) = normalized_binary_prob_pair_from_probs(
287            self.model.prob_for_last(0),
288            self.model.prob_for_last(1),
289            DEFAULT_MIN_PROB,
290        );
291        if sym { p1 } else { p0 }
292    }
293
294    fn model_name(&self) -> String {
295        "ROSA".to_string()
296    }
297
298    fn boxed_clone(&self) -> Box<dyn Predictor> {
299        Box::new(Self {
300            model: self.model.clone(),
301            history: self.history.clone(),
302        })
303    }
304}
305
/// A predictor using ZPAQ as a streaming rate model.
///
/// This maintains a full history so it can rebuild state on revert and handle
/// any misuse where `predict_prob` is called without a matching `update`.
pub struct ZpaqPredictor {
    /// ZPAQ method string, kept so fresh models can be constructed.
    method: String,
    /// Probability floor forwarded to binary normalization.
    min_prob: f64,
    /// The live streaming model (may carry one speculative score; see `pending`).
    model: ZpaqRateModel,
    /// Every committed bit as a byte in `{0, 1}`, replayed to rebuild `model`.
    history: Vec<u8>,
    /// A speculatively scored `(symbol, log_prob)` pair from prediction.
    /// NOTE(review): the rebuild-on-mismatch logic in `update` suggests that
    /// `ZpaqRateModel::log_prob` perturbs model state — confirm against
    /// `ZpaqRateModel` before changing this protocol.
    pending: Option<(u8, f64)>,
}

impl ZpaqPredictor {
    /// Create a ZPAQ-backed predictor from a `method` and probability floor.
    pub fn new(method: String, min_prob: f64) -> Self {
        let model = ZpaqRateModel::new(method.clone(), min_prob);
        Self {
            method,
            min_prob,
            model,
            history: Vec::new(),
            pending: None,
        }
    }

    /// Rebuilds `model` from scratch by replaying the committed history.
    /// Used whenever speculative state must be discarded (revert / mismatch).
    fn rebuild_from_history(&mut self) {
        self.model.reset();
        if !self.history.is_empty() {
            self.model.update_and_score(&self.history);
        }
    }

    /// Scores `symbol` on a throwaway model replayed from history, leaving the
    /// live `model` untouched. Costs O(history); used off the fast path only.
    fn log_prob_from_history(&self, symbol: u8) -> f64 {
        let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
        if !self.history.is_empty() {
            tmp.update_and_score(&self.history);
        }
        tmp.log_prob(symbol)
    }

    /// Returns `(logp0, logp1)` for the next bit.
    ///
    /// The `preferred_symbol` is scored on the live model when possible (and
    /// cached in `pending`); the opposite symbol is scored on a replayed
    /// throwaway model so the live model never carries more than one
    /// speculative step.
    fn binary_log_prob_pair(&mut self, preferred_symbol: u8) -> (f64, f64) {
        let other_symbol = preferred_symbol ^ 1;
        let preferred_logp = match self.pending {
            // Cache hit: this exact symbol was already speculatively scored.
            Some((pending, logp)) if pending == preferred_symbol => logp,
            // Live model is speculated on the other symbol; don't touch it.
            Some(_) => self.log_prob_from_history(preferred_symbol),
            // No speculation yet: score on the live model and remember it.
            None => {
                let logp = self.model.log_prob(preferred_symbol);
                self.pending = Some((preferred_symbol, logp));
                logp
            }
        };
        let other_logp = match self.pending {
            Some((pending, logp)) if pending == other_symbol => logp,
            _ => self.log_prob_from_history(other_symbol),
        };
        // Reorder the pair into (logp0, logp1) regardless of preference.
        if preferred_symbol == 0 {
            (preferred_logp, other_logp)
        } else {
            (other_logp, preferred_logp)
        }
    }
}
368
369impl Predictor for ZpaqPredictor {
370    fn update(&mut self, sym: bool) {
371        let byte = if sym { 1u8 } else { 0u8 };
372        if let Some((pending, _)) = self.pending {
373            if pending == byte {
374                self.model.update(byte);
375                self.pending = None;
376                self.history.push(byte);
377                return;
378            }
379            self.pending = None;
380            self.rebuild_from_history();
381        }
382        self.model.update(byte);
383        self.history.push(byte);
384    }
385
386    fn revert(&mut self) {
387        if self.history.pop().is_some() {
388            self.pending = None;
389            self.rebuild_from_history();
390        }
391    }
392
393    fn predict_prob(&mut self, sym: bool) -> f64 {
394        let preferred_symbol = if sym { 1u8 } else { 0u8 };
395        let (logp0, logp1) = self.binary_log_prob_pair(preferred_symbol);
396        let (p0, p1) = normalized_binary_prob_pair_from_log_probs(logp0, logp1, self.min_prob);
397        if sym { p1 } else { p0 }
398    }
399
400    fn model_name(&self) -> String {
401        format!("ZPAQ({})", self.method)
402    }
403
404    fn boxed_clone(&self) -> Box<dyn Predictor> {
405        Box::new(Self {
406            method: self.method.clone(),
407            min_prob: self.min_prob,
408            model: self.model.clone(),
409            history: self.history.clone(),
410            pending: self.pending,
411        })
412    }
413}
414
/// A generic bit-level predictor backed by any [`RateBackend`].
///
/// This adapter maps boolean symbols to bytes `{0,1}` and forwards them to the
/// workspace-wide rate backend abstraction. It prioritizes correctness and
/// backend coverage over rollback efficiency.
pub struct RateBackendBitPredictor {
    /// Backend spec, retained for cloning and `model_name`.
    backend: RateBackend,
    /// Context-order cap forwarded to the backend predictor.
    max_order: i64,
    /// Probability floor used when normalizing the binary distribution.
    min_prob: f64,
    /// The underlying byte-level predictor stream.
    predictor: RateBackendPredictor,
    /// Per-symbol checkpoints, one per reversible `update`/`update_history`
    /// performed outside a rollback scope.
    journal: Vec<RateBackendJournalEntry>,
    /// Stack of coarse rollback scopes; while non-empty, per-symbol
    /// journaling is suppressed.
    rollback_scopes: Vec<RateBackendRollbackScope>,
}
428
/// Which reversible operation produced a journal entry, so `revert` and
/// `pop_history` can assert they are paired with the matching update call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RateBackendJournalKind {
    /// A reversible training `update`.
    Update,
    /// A reversible history-only (`update_frozen`) append.
    FrozenUpdate,
}

/// A pre-operation predictor checkpoint tagged with the operation kind.
#[derive(Clone)]
struct RateBackendJournalEntry {
    kind: RateBackendJournalKind,
    checkpoint: RateBackendPredictorCheckpoint,
}

/// A coarse checkpoint taken by `begin_rollback_scope`, plus the journal
/// length to truncate back to when the scope is rolled back.
#[derive(Clone)]
struct RateBackendRollbackScope {
    checkpoint: RateBackendPredictorCheckpoint,
    journal_len: usize,
}
446
impl RateBackendBitPredictor {
    /// Create a new bit-level adapter from a rate backend.
    pub fn new(backend: RateBackend, max_order: i64) -> Result<Self, String> {
        Self::new_with_min_prob(backend, max_order, DEFAULT_MIN_PROB)
    }

    /// Create a new bit-level adapter with an explicit probability floor.
    ///
    /// Returns an error for zpaq-containing backends (unsupported by this
    /// adapter) or when the underlying predictor stream fails to start.
    pub fn new_with_min_prob(
        backend: RateBackend,
        max_order: i64,
        min_prob: f64,
    ) -> Result<Self, String> {
        if rate_backend_contains_zpaq(&backend) {
            return Err(
                "RateBackendBitPredictor does not support zpaq backends; use a non-zpaq rate_backend"
                    .to_string(),
            );
        }
        let mut predictor =
            RateBackendPredictor::from_backend(backend.clone(), max_order, min_prob);
        predictor
            .begin_stream(None)
            .map_err(|err| format!("failed to start RateBackend predictor stream: {err}"))?;
        Ok(Self {
            backend,
            max_order,
            min_prob,
            predictor,
            journal: Vec::new(),
            rollback_scopes: Vec::new(),
        })
    }

    /// Maps a bit onto the backend's `{0, 1}` byte alphabet.
    #[inline(always)]
    fn bit_to_byte(sym: bool) -> u8 {
        if sym { 1u8 } else { 0u8 }
    }

    /// Deep-copies the predictor, including any live journal/scope state.
    fn clone_state(&self) -> Self {
        Self {
            backend: self.backend.clone(),
            max_order: self.max_order,
            min_prob: self.min_prob,
            predictor: self.predictor.clone(),
            journal: self.journal.clone(),
            rollback_scopes: self.rollback_scopes.clone(),
        }
    }

    /// Snapshots the predictor and tags the snapshot with the operation kind.
    fn checkpoint(&mut self, kind: RateBackendJournalKind) -> RateBackendJournalEntry {
        RateBackendJournalEntry {
            kind,
            checkpoint: self.predictor.checkpoint(),
        }
    }

    /// Pops the newest journal entry and restores it. Panics on misuse:
    /// per-symbol rollback inside an open scope, an empty journal, or a kind
    /// mismatch (e.g. calling `revert` to undo an `update_history`).
    fn restore_last(&mut self, expected_kind: RateBackendJournalKind) {
        assert!(
            self.rollback_scopes.is_empty(),
            "RateBackendBitPredictor per-symbol rollback inside active scope is unsupported"
        );
        let entry = self
            .journal
            .pop()
            .expect("RateBackendBitPredictor rollback underflow");
        assert_eq!(
            entry.kind, expected_kind,
            "RateBackendBitPredictor rollback kind mismatch: expected {expected_kind:?}, got {:?}",
            entry.kind
        );
        self.predictor.restore_checkpoint(&entry.checkpoint);
        // With no outstanding rollback state, let the backend discard any
        // internal checkpoint bookkeeping. (The scope check is always true
        // here given the assert above; kept for symmetry with `rollback_scope`.)
        if self.rollback_scopes.is_empty() && self.journal.is_empty() {
            self.predictor.clear_checkpoints_if_supported();
        }
    }
}
523
impl Predictor for RateBackendBitPredictor {
    fn update(&mut self, sym: bool) {
        // Per-symbol checkpoints are only needed when no coarse scope is
        // active; inside a scope, `rollback_scope` restores everything at once.
        if self.rollback_scopes.is_empty() {
            let checkpoint = self.checkpoint(RateBackendJournalKind::Update);
            self.journal.push(checkpoint);
        }
        self.predictor.update(Self::bit_to_byte(sym));
    }

    // Committed variant: no checkpoint, so the symbol cannot be reverted.
    fn commit_update(&mut self, sym: bool) {
        self.predictor.update(Self::bit_to_byte(sym));
    }

    fn update_history(&mut self, sym: bool) {
        if self.rollback_scopes.is_empty() {
            let checkpoint = self.checkpoint(RateBackendJournalKind::FrozenUpdate);
            self.journal.push(checkpoint);
        }
        // "Frozen" updates extend the stream without the training side.
        self.predictor.update_frozen(Self::bit_to_byte(sym));
    }

    fn commit_update_history(&mut self, sym: bool) {
        self.predictor.update_frozen(Self::bit_to_byte(sym));
    }

    fn revert(&mut self) {
        self.restore_last(RateBackendJournalKind::Update);
    }

    fn pop_history(&mut self) {
        self.restore_last(RateBackendJournalKind::FrozenUpdate);
    }

    fn begin_rollback_scope(&mut self) {
        // One coarse checkpoint covers every symbol until `rollback_scope`.
        let checkpoint = self.predictor.checkpoint();
        self.rollback_scopes.push(RateBackendRollbackScope {
            checkpoint,
            journal_len: self.journal.len(),
        });
    }

    fn rollback_scope(&mut self) -> bool {
        let Some(scope) = self.rollback_scopes.pop() else {
            return false;
        };
        self.predictor.restore_checkpoint(&scope.checkpoint);
        // Defensive: drop any entries recorded after the scope opened (none
        // are expected, since scoped updates skip journaling).
        self.journal.truncate(scope.journal_len);
        if self.rollback_scopes.is_empty() && self.journal.is_empty() {
            self.predictor.clear_checkpoints_if_supported();
        }
        true
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        let (p0, p1) = normalized_binary_prob_pair_from_log_probs(
            self.predictor.log_prob(0),
            self.predictor.log_prob(1),
            self.min_prob,
        );
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        format!(
            "RateBackendBits({})",
            RateBackendPredictor::default_name(&self.backend, self.max_order)
        )
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(self.clone_state())
    }
}
597
598#[cfg(feature = "backend-rwkv")]
599use crate::coders::softmax_pdf_floor_inplace;
600
/// A predictor using the RWKV neural network architecture.
///
/// This provides a deep learning based world model for AIXI, allowing
/// the agent to leverage large pre-trained models for sequence prediction.
#[cfg(feature = "backend-rwkv")]
pub struct RwkvPredictor {
    /// Streaming RWKV state; the pdf for the next symbol is kept in
    /// `compressor.pdf_buffer`.
    compressor: RwkvCompressor,
    /// Saved `(state, pdf)` snapshots, one per `update`, popped by `revert`.
    history: Vec<(RwkvState, Vec<f64>)>,
}

#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Creates a new `RwkvPredictor` from an initialized `Model`.
    pub fn new(model: Arc<RwkvModel>) -> Self {
        let mut compressor = RwkvCompressor::new_from_model(model);
        let vocab_size = compressor.vocab_size();
        // Prime the pdf buffer with the distribution after feeding token 0.
        // NOTE(review): this hand-rolls what `from_method` does via
        // `forward_to_internal_pdf(0)` — confirm the two paths are equivalent
        // and consider unifying them.
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state);
        softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);

        Self {
            compressor,
            history: Vec::new(),
        }
    }

    /// Creates a new `RwkvPredictor` from a method string.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            RwkvCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        // Prime the internal pdf by feeding token 0.
        compressor.forward_to_internal_pdf(0);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
639
#[cfg(feature = "backend-rwkv")]
impl Predictor for RwkvPredictor {
    fn update(&mut self, sym: bool) {
        // Save current state and pdf so `revert` can restore them without
        // re-running the network.
        self.history.push((
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        ));

        let byte = if sym { 1u32 } else { 0u32 };
        let vocab_size = self.compressor.vocab_size();

        // Advance the recurrent state on the new token and refresh the pdf.
        let logits = self.compressor.model.forward(
            &mut self.compressor.scratch,
            byte,
            &mut self.compressor.state,
        );
        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.compressor.pdf_buffer);
    }

    fn revert(&mut self) {
        // A revert without a matching update is a silent no-op.
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Only vocabulary entries 0 and 1 carry the binary alphabet; their
        // mass is renormalized into a proper two-way distribution.
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.compressor.pdf_buffer[0],
            self.compressor.pdf_buffer[1],
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        "RWKV".to_string()
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
687
/// A predictor using the Mamba neural network architecture.
#[cfg(feature = "backend-mamba")]
pub struct MambaPredictor {
    /// Streaming Mamba state; the pdf for the next symbol is kept in
    /// `compressor.pdf_buffer`.
    compressor: MambaCompressor,
    /// Saved `(state, pdf)` snapshots, one per `update`, popped by `revert`.
    history: Vec<(MambaState, Vec<f64>)>,
}

#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Creates a new `MambaPredictor` from an initialized `Model`.
    pub fn new(model: Arc<MambaModel>) -> Self {
        let mut compressor = MambaCompressor::new_from_model(model);
        // Prime the pdf buffer with the distribution after feeding token 0,
        // applying the compressor's online bias snapshot if one is configured.
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state)
            .to_vec();
        let bias = compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut compressor.pdf_buffer);

        Self {
            compressor,
            history: Vec::new(),
        }
    }

    /// Creates a new `MambaPredictor` from a method string.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            MambaCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        // Prime via a scratch pdf, then copy into the internal buffer
        // (`forward_to_pdf` writes into a caller-supplied slice).
        let mut pdf = vec![0.0f64; compressor.vocab_size()];
        compressor.forward_to_pdf(0, &mut pdf);
        compressor.pdf_buffer.clone_from(&pdf);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
726
#[cfg(feature = "backend-mamba")]
impl Predictor for MambaPredictor {
    fn update(&mut self, sym: bool) {
        // Snapshot (state, pdf) so `revert` can restore them without
        // re-running the network.
        self.history.push((
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        ));

        let byte = if sym { 1u32 } else { 0u32 };
        // Advance the recurrent state on the new token, then refresh the pdf
        // (including the online bias, mirroring construction).
        let logits = self
            .compressor
            .model
            .forward(
                &mut self.compressor.scratch,
                byte,
                &mut self.compressor.state,
            )
            .to_vec();
        let bias = self.compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut self.compressor.pdf_buffer);
    }

    fn revert(&mut self) {
        // A revert without a matching update is a silent no-op.
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Only vocabulary entries 0 and 1 carry the binary alphabet; their
        // mass is renormalized into a proper two-way distribution.
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.compressor.pdf_buffer[0],
            self.compressor.pdf_buffer[1],
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }

    fn model_name(&self) -> String {
        "Mamba".to_string()
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
776
777#[cfg(test)]
778mod tests {
779    use super::*;
780
    /// Asserts that two probabilities differ by at most 1e-12.
    fn approx_eq(a: f64, b: f64) {
        let diff = (a - b).abs();
        assert!(
            diff <= 1e-12,
            "expected probabilities to match exactly enough: left={a} right={b} diff={diff}"
        );
    }
788
    /// Drives `predictor` through a short bit sequence, checking before every
    /// step that its two-way prediction forms a valid probability distribution.
    fn assert_binary_predictor_normalizes(mut predictor: Box<dyn Predictor>, label: &str) {
        for (step, &bit) in [false, true, true, false, true, false].iter().enumerate() {
            let p0 = predictor.predict_prob(false);
            let p1 = predictor.predict_prob(true);
            let sum = p0 + p1;
            assert!(
                (sum - 1.0).abs() < 1e-12,
                "{label}: probabilities must sum to 1 at step {step}, got p0={p0}, p1={p1}, sum={sum}",
            );
            assert!(
                (0.0..=1.0).contains(&p0) && (0.0..=1.0).contains(&p1),
                "{label}: probabilities must stay in [0,1] at step {step}, got p0={p0}, p1={p1}",
            );
            predictor.commit_update(bit);
        }
    }
805
806    fn predictor_signature(
807        mut predictor: RateBackendBitPredictor,
808        probe: &[bool],
809    ) -> Vec<(f64, f64)> {
810        let mut signature = Vec::with_capacity(probe.len());
811        for &bit in probe {
812            signature.push((predictor.predict_prob(false), predictor.predict_prob(true)));
813            predictor.commit_update(bit);
814        }
815        signature
816    }
817
818    #[test]
819    fn committed_rate_backend_updates_do_not_grow_journal() {
820        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
821            .expect("rate backend predictor should initialize");
822
823        for idx in 0..512usize {
824            predictor.commit_update((idx & 1) == 0);
825            predictor.commit_update_history((idx % 3) == 0);
826        }
827
828        assert!(
829            predictor.journal.is_empty(),
830            "committed history should not retain rollback snapshots"
831        );
832    }
833
    #[test]
    fn reversible_rate_backend_update_paths_round_trip_exactly() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for &bit in &[true, false, true, true, false, false, true] {
            predictor.commit_update(bit);
        }

        // update/revert pairs must restore both journal depth and predictions
        // to the pre-update baseline.
        let baseline_after_train = predictor.clone_state();
        predictor.update(true);
        predictor.update(false);
        predictor.revert();
        predictor.revert();
        assert_eq!(predictor.journal.len(), baseline_after_train.journal.len());

        let train_probe = [true, false, false, true, true, false];
        let got = predictor_signature(predictor.clone_state(), &train_probe);
        let want = predictor_signature(baseline_after_train.clone_state(), &train_probe);
        for ((got0, got1), (want0, want1)) in got.into_iter().zip(want.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }

        // Same round-trip contract for the history-only path.
        let baseline_after_history = baseline_after_train.clone_state();
        predictor.update_history(false);
        predictor.update_history(true);
        predictor.pop_history();
        predictor.pop_history();
        assert_eq!(
            predictor.journal.len(),
            baseline_after_history.journal.len()
        );

        let history_probe = [false, true, true, false, false, true];
        let got = predictor_signature(predictor.clone_state(), &history_probe);
        let want = predictor_signature(baseline_after_history, &history_probe);
        for ((got0, got1), (want0, want1)) in got.into_iter().zip(want.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }
875
    #[test]
    fn long_committed_history_does_not_contaminate_clone_rollback_state() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");

        // Build a long, fully committed stream: no journal entries expected.
        for idx in 0..2048usize {
            predictor.commit_update((idx & 7) < 3);
            predictor.commit_update_history((idx % 5) < 2);
        }
        assert!(predictor.journal.is_empty());

        let mut cloned = predictor.clone_state();
        assert!(
            cloned.journal.is_empty(),
            "clone state should only carry active reversible rollback depth"
        );

        // Reversible round-trips on the clone must leave its predictions
        // identical to the original's.
        let baseline = predictor_signature(predictor.clone_state(), &[true, false, true, false]);
        cloned.update(true);
        cloned.revert();
        cloned.update_history(false);
        cloned.pop_history();
        assert!(cloned.journal.is_empty());

        let after_round_trip = predictor_signature(cloned, &[true, false, true, false]);
        for ((got0, got1), (want0, want1)) in after_round_trip.into_iter().zip(baseline.into_iter())
        {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }
907
    #[test]
    fn rollback_scope_restores_simulation_state_without_growing_journal() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for &bit in &[true, false, true, false, true] {
            predictor.commit_update(bit);
        }

        // Inside a scope, reversible updates should journal nothing, and one
        // scope rollback should restore the pre-scope predictions.
        let baseline = predictor_signature(predictor.clone_state(), &[true, true, false, false]);
        predictor.begin_rollback_scope();
        for idx in 0..512usize {
            predictor.update((idx & 1) == 0);
            predictor.update_history((idx % 3) == 0);
        }
        assert!(
            predictor.journal.is_empty(),
            "scoped reversible updates should not retain per-bit snapshots"
        );
        assert!(predictor.rollback_scope(), "scope rollback should succeed");
        assert!(predictor.journal.is_empty());

        let after = predictor_signature(predictor, &[true, true, false, false]);
        for ((got0, got1), (want0, want1)) in after.into_iter().zip(baseline.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }
935
936    #[test]
937    fn cloned_predictor_carries_only_active_scope_snapshots() {
938        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
939            .expect("rate backend predictor should initialize");
940        for idx in 0..1024usize {
941            predictor.commit_update((idx & 3) == 0);
942        }
943
944        predictor.begin_rollback_scope();
945        for idx in 0..256usize {
946            predictor.update((idx & 1) == 0);
947        }
948        let cloned = predictor.clone_state();
949        assert!(
950            cloned.journal.is_empty(),
951            "scoped reversible updates should not leak per-bit journal state into clones"
952        );
953        assert_eq!(cloned.rollback_scopes.len(), 1);
954    }
955
956    #[test]
957    fn generic_rate_backend_bit_predictors_normalize_binary_mass() {
958        assert_binary_predictor_normalizes(
959            Box::new(
960                RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
961                    .expect("generic rosa predictor"),
962            ),
963            "generic-rosa",
964        );
965        assert_binary_predictor_normalizes(
966            Box::new(
967                RateBackendBitPredictor::new(
968                    RateBackend::Ppmd {
969                        order: 4,
970                        memory_mb: 8,
971                    },
972                    8,
973                )
974                .expect("generic ppmd predictor"),
975            ),
976            "generic-ppmd",
977        );
978        assert_binary_predictor_normalizes(
979            Box::new(
980                RateBackendBitPredictor::new(
981                    RateBackend::Match {
982                        hash_bits: 16,
983                        min_len: 2,
984                        max_len: 32,
985                        base_mix: 0.05,
986                        confidence_scale: 1.0,
987                    },
988                    8,
989                )
990                .expect("generic match predictor"),
991            ),
992            "generic-match",
993        );
994    }
995
996    #[cfg(feature = "backend-zpaq")]
997    #[test]
998    fn zpaq_predictor_normalizes_binary_mass() {
999        assert_binary_predictor_normalizes(
1000            Box::new(ZpaqPredictor::new("1".to_string(), DEFAULT_MIN_PROB)),
1001            "zpaq",
1002        );
1003    }
1004
1005    #[cfg(feature = "backend-rwkv")]
1006    #[test]
1007    fn rwkv_predictor_normalizes_binary_mass() {
1008        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
1009        let predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
1010        assert_binary_predictor_normalizes(Box::new(predictor), "rwkv");
1011    }
1012
1013    #[cfg(feature = "backend-mamba")]
1014    #[test]
1015    fn mamba_predictor_normalizes_binary_mass() {
1016        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
1017        let predictor = MambaPredictor::from_method(method).expect("mamba predictor");
1018        assert_binary_predictor_normalizes(Box::new(predictor), "mamba");
1019    }
1020}