// infotheory/aixi/model.rs

1//! Predictive models for AIXI.
2//!
3//! This module defines the `Predictor` trait, which encapsulates the logic
4//! for learning from history and predicting future symbols. Different implementations
5//! provide different complexity vs performance trade-offs.
6
7use crate::RateBackend;
8use crate::ctw::{ContextTree, FacContextTree};
9#[cfg(feature = "backend-mamba")]
10use crate::mambazip::{Compressor as MambaCompressor, Model as MambaModel, State as MambaState};
11use crate::mixture::{DEFAULT_MIN_PROB, OnlineBytePredictor, RateBackendPredictor};
12use crate::rosaplus::{RosaPlus, RosaTx};
13#[cfg(feature = "backend-rwkv")]
14use crate::rwkvzip::{Compressor as RwkvCompressor, Model as RwkvModel, State as RwkvState};
15use crate::zpaq_rate::ZpaqRateModel;
16#[cfg(any(feature = "backend-mamba", feature = "backend-rwkv"))]
17use std::sync::Arc;
18
/// Interface for an AIXI world model.
///
/// Predictors are mutated behind `&mut self` and cloned per worker during
/// parallel MCTS. They only need `Send`, not `Sync`, which avoids unsound
/// thread-sharing requirements for backends with thread-confined internals.
pub trait Predictor: Send {
    /// Incorporates a new symbol into the model's training history.
    fn update(&mut self, sym: bool);

    /// Appends a symbol to the model's interaction history without necessarily
    /// updating the training counts immediately (backend dependent).
    ///
    /// The default forwards to [`Predictor::update`]; backends that
    /// distinguish "history-only" appends override this.
    fn update_history(&mut self, sym: bool) {
        self.update(sym);
    }

    /// Reverts the model to its state before the last `update`.
    fn revert(&mut self);

    /// Reverts the model to its state before the last `update_history`.
    ///
    /// Defaults to [`Predictor::revert`], mirroring the `update_history`
    /// default above.
    fn pop_history(&mut self) {
        self.revert();
    }

    /// Predicts the probability of the next symbol being `sym`.
    ///
    /// Takes `&mut self` because some backends advance or cache internal
    /// state while scoring (see e.g. the ZPAQ-backed predictor's `pending`).
    fn predict_prob(&mut self, sym: bool) -> f64;

    /// Shorthand for `predict_prob(true)`.
    fn predict_one(&mut self) -> f64 {
        self.predict_prob(true)
    }

    /// Returns a human-readable name of the predictive model.
    fn model_name(&self) -> String;

    /// Creates a boxed clone of this predictor.
    ///
    /// Used to hand each MCTS worker an independent copy; `Clone` cannot be
    /// a supertrait here because the trait must remain object-safe.
    fn boxed_clone(&self) -> Box<dyn Predictor>;
}
56
57/// A predictor using the Action-Conditional CTW algorithm.
58///
59/// AC-CTW uses a single context tree for all bits in sequence.
60/// For better type information exploitation, use `FacCtwPredictor`.
61pub struct CtwPredictor {
62    tree: ContextTree,
63}
64
65impl CtwPredictor {
66    /// Creates a new `CtwPredictor` with the specified context depth.
67    pub fn new(depth: usize) -> Self {
68        Self {
69            tree: ContextTree::new(depth),
70        }
71    }
72}
73
74impl Predictor for CtwPredictor {
75    fn update(&mut self, sym: bool) {
76        self.tree.update(sym);
77    }
78    fn update_history(&mut self, sym: bool) {
79        self.tree.update_history(&[sym]);
80    }
81
82    fn revert(&mut self) {
83        self.tree.revert();
84    }
85    fn pop_history(&mut self) {
86        self.tree.revert_history();
87    }
88
89    fn predict_prob(&mut self, sym: bool) -> f64 {
90        self.tree.predict(sym)
91    }
92
93    fn model_name(&self) -> String {
94        format!("AC-CTW(d={})", self.tree.depth())
95    }
96
97    fn boxed_clone(&self) -> Box<dyn Predictor> {
98        Box::new(Self {
99            tree: self.tree.clone(),
100        })
101    }
102}
103
104/// A predictor using the Factorized Action-Conditional CTW (FAC-CTW) algorithm.
105///
106/// FAC-CTW uses k separate context trees (one per percept bit) with overlapping
107/// context depths D+i-1 for bit i. This enables better exploitation of type
108/// information within percepts, as described in Veness et al. (2011) Section 5.
109///
110/// This is the recommended CTW variant for MC-AIXI agents.
111pub struct FacCtwPredictor {
112    tree: FacContextTree,
113    /// Current bit index within a percept (cycles 0..num_bits).
114    current_bit: usize,
115    /// Total number of percept bits (k).
116    num_bits: usize,
117}
118
119impl FacCtwPredictor {
120    /// Creates a new `FacCtwPredictor`.
121    ///
122    /// - `base_depth`: Context depth D for the first bit's tree
123    /// - `num_percept_bits`: Total bits per percept (observation_bits + reward_bits)
124    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
125        Self {
126            tree: FacContextTree::new(base_depth, num_percept_bits),
127            current_bit: 0,
128            num_bits: num_percept_bits,
129        }
130    }
131}
132
133impl Predictor for FacCtwPredictor {
134    fn update(&mut self, sym: bool) {
135        self.tree.update(sym, self.current_bit);
136        self.current_bit = (self.current_bit + 1) % self.num_bits;
137    }
138
139    fn update_history(&mut self, sym: bool) {
140        self.tree.update_history(&[sym]);
141    }
142
143    fn revert(&mut self) {
144        // Revert to previous bit index
145        self.current_bit = if self.current_bit == 0 {
146            self.num_bits - 1
147        } else {
148            self.current_bit - 1
149        };
150        self.tree.revert(self.current_bit);
151    }
152
153    fn pop_history(&mut self) {
154        self.tree.revert_history(1);
155    }
156
157    fn predict_prob(&mut self, sym: bool) -> f64 {
158        self.tree.predict(sym, self.current_bit)
159    }
160
161    fn model_name(&self) -> String {
162        format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
163    }
164
165    fn boxed_clone(&self) -> Box<dyn Predictor> {
166        Box::new(Self {
167            tree: self.tree.clone(),
168            current_bit: self.current_bit,
169            num_bits: self.num_bits,
170        })
171    }
172}
173
174/// A predictor using the ROSA-Plus (Rapid Online Suffix Automaton + Witten-Bell Smoother) algorithm.
175///
176/// ROSA is a (practically) sub-quadratic suffix automaton based language model that
177/// can handle very long contexts efficiently.
178pub struct RosaPredictor {
179    model: RosaPlus,
180    history: Vec<RosaTx>,
181}
182
183impl RosaPredictor {
184    /// Creates a new `RosaPredictor` with a maximum context length for the fallback LM.
185    /// Note: deterministic ROSA uses the full SAM and is not capped by `max_order`.
186    pub fn new(max_order: i64) -> Self {
187        // Seed and EOT
188        let mut model = RosaPlus::new(max_order, false, 0, 42);
189        // Important: Pre-build the LM with full byte alphabet so we can incrementally update it.
190        model.build_lm_full_bytes_no_finalize_endpos();
191        Self {
192            model,
193            history: Vec::new(),
194        }
195    }
196}
197
198impl Predictor for RosaPredictor {
199    fn update(&mut self, sym: bool) {
200        let mut tx = self.model.begin_tx();
201        // Using 0u8 and 1u8 for bits
202        let byte = if sym { 1u8 } else { 0u8 };
203
204        // train_example_tx updates the tx object and the model
205        self.model.train_sequence_tx(&mut tx, &[byte]);
206        self.history.push(tx);
207    }
208
209    fn revert(&mut self) {
210        if let Some(tx) = self.history.pop() {
211            self.model.rollback_tx(tx);
212        }
213    }
214
215    fn predict_prob(&mut self, sym: bool) -> f64 {
216        let p0 = self.model.prob_for_last(0);
217        let p1 = self.model.prob_for_last(1);
218        let denom = (p0 + p1).max(1e-12);
219        if sym { p1 / denom } else { p0 / denom }
220    }
221
222    fn model_name(&self) -> String {
223        "ROSA".to_string()
224    }
225
226    fn boxed_clone(&self) -> Box<dyn Predictor> {
227        Box::new(Self {
228            model: self.model.clone(),
229            history: self.history.clone(),
230        })
231    }
232}
233
234/// A predictor using ZPAQ as a streaming rate model.
235///
236/// This maintains a full history so it can rebuild state on revert and handle
237/// any misuse where `predict_prob` is called without a matching `update`.
238pub struct ZpaqPredictor {
239    method: String,
240    min_prob: f64,
241    model: ZpaqRateModel,
242    history: Vec<u8>,
243    pending: Option<(u8, f64)>,
244}
245
246impl ZpaqPredictor {
247    /// Create a ZPAQ-backed predictor from a `method` and probability floor.
248    pub fn new(method: String, min_prob: f64) -> Self {
249        let model = ZpaqRateModel::new(method.clone(), min_prob);
250        Self {
251            method,
252            min_prob,
253            model,
254            history: Vec::new(),
255            pending: None,
256        }
257    }
258
259    fn rebuild_from_history(&mut self) {
260        self.model.reset();
261        if !self.history.is_empty() {
262            self.model.update_and_score(&self.history);
263        }
264    }
265
266    fn log_prob_from_history(&self, symbol: u8) -> f64 {
267        let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
268        if !self.history.is_empty() {
269            tmp.update_and_score(&self.history);
270        }
271        tmp.log_prob(symbol)
272    }
273}
274
275impl Predictor for ZpaqPredictor {
276    fn update(&mut self, sym: bool) {
277        let byte = if sym { 1u8 } else { 0u8 };
278        if let Some((pending, _)) = self.pending {
279            if pending == byte {
280                self.model.update(byte);
281                self.pending = None;
282                self.history.push(byte);
283                return;
284            }
285            self.pending = None;
286            self.rebuild_from_history();
287        }
288        self.model.update(byte);
289        self.history.push(byte);
290    }
291
292    fn revert(&mut self) {
293        if self.history.pop().is_some() {
294            self.pending = None;
295            self.rebuild_from_history();
296        }
297    }
298
299    fn predict_prob(&mut self, sym: bool) -> f64 {
300        let byte = if sym { 1u8 } else { 0u8 };
301        if let Some((pending, logp)) = self.pending {
302            if pending == byte {
303                return logp.exp();
304            }
305            return self.log_prob_from_history(byte).exp();
306        }
307        let logp = self.model.log_prob(byte);
308        self.pending = Some((byte, logp));
309        logp.exp()
310    }
311
312    fn model_name(&self) -> String {
313        format!("ZPAQ({})", self.method)
314    }
315
316    fn boxed_clone(&self) -> Box<dyn Predictor> {
317        Box::new(Self {
318            method: self.method.clone(),
319            min_prob: self.min_prob,
320            model: self.model.clone(),
321            history: self.history.clone(),
322            pending: self.pending,
323        })
324    }
325}
326
327/// A generic bit-level predictor backed by any [`RateBackend`].
328///
329/// This adapter maps boolean symbols to bytes `{0,1}` and forwards them to the
330/// workspace-wide rate backend abstraction. It prioritizes correctness and
331/// backend coverage over rollback efficiency.
332pub struct RateBackendBitPredictor {
333    backend: RateBackend,
334    max_order: i64,
335    min_prob: f64,
336    predictor: RateBackendPredictor,
337}
338
339impl RateBackendBitPredictor {
340    /// Create a new bit-level adapter from a rate backend.
341    pub fn new(backend: RateBackend, max_order: i64) -> Result<Self, String> {
342        Self::new_with_min_prob(backend, max_order, DEFAULT_MIN_PROB)
343    }
344
345    /// Create a new bit-level adapter with an explicit probability floor.
346    pub fn new_with_min_prob(
347        backend: RateBackend,
348        max_order: i64,
349        min_prob: f64,
350    ) -> Result<Self, String> {
351        if rate_backend_contains_zpaq(&backend) {
352            return Err(
353                "RateBackendBitPredictor does not support zpaq backends; use a non-zpaq rate_backend"
354                    .to_string(),
355            );
356        }
357        let mut predictor =
358            RateBackendPredictor::from_backend(backend.clone(), max_order, min_prob);
359        predictor
360            .begin_stream(None)
361            .map_err(|err| format!("failed to start RateBackend predictor stream: {err}"))?;
362        Ok(Self {
363            backend,
364            max_order,
365            min_prob,
366            predictor,
367        })
368    }
369
370    #[inline(always)]
371    fn bit_to_byte(sym: bool) -> u8 {
372        if sym { 1u8 } else { 0u8 }
373    }
374
375    fn clone_state(&self) -> Self {
376        Self {
377            backend: self.backend.clone(),
378            max_order: self.max_order,
379            min_prob: self.min_prob,
380            predictor: self.predictor.clone(),
381        }
382    }
383}
384
385fn rate_backend_contains_zpaq(backend: &RateBackend) -> bool {
386    match backend {
387        RateBackend::Zpaq { .. } => true,
388        RateBackend::Mixture { spec } => spec
389            .experts
390            .iter()
391            .any(|expert| rate_backend_contains_zpaq(&expert.backend)),
392        RateBackend::Calibrated { spec } => rate_backend_contains_zpaq(&spec.base),
393        _ => false,
394    }
395}
396
397impl Predictor for RateBackendBitPredictor {
398    fn update(&mut self, sym: bool) {
399        self.predictor.update(Self::bit_to_byte(sym));
400    }
401
402    fn update_history(&mut self, sym: bool) {
403        self.predictor.update_frozen(Self::bit_to_byte(sym));
404    }
405
406    fn revert(&mut self) {
407        panic!(
408            "RateBackendBitPredictor does not support generic rollback; callers must use cloned temporary predictors"
409        );
410    }
411
412    fn pop_history(&mut self) {
413        panic!(
414            "RateBackendBitPredictor does not support generic rollback; callers must use cloned temporary predictors"
415        );
416    }
417
418    fn predict_prob(&mut self, sym: bool) -> f64 {
419        let p = self.predictor.log_prob(Self::bit_to_byte(sym)).exp();
420        if p.is_finite() {
421            p.clamp(self.min_prob, 1.0 - self.min_prob)
422        } else {
423            0.5
424        }
425    }
426
427    fn model_name(&self) -> String {
428        format!(
429            "RateBackendBits({})",
430            RateBackendPredictor::default_name(&self.backend, self.max_order)
431        )
432    }
433
434    fn boxed_clone(&self) -> Box<dyn Predictor> {
435        Box::new(self.clone_state())
436    }
437}
438
439#[cfg(feature = "backend-rwkv")]
440use crate::coders::softmax_pdf_floor_inplace;
441
/// A predictor using the RWKV neural network architecture.
///
/// This provides a deep learning based world model for AIXI, allowing
/// the agent to leverage large pre-trained models for sequence prediction.
#[cfg(feature = "backend-rwkv")]
pub struct RwkvPredictor {
    // Model handle plus mutable inference state and the current pdf buffer.
    compressor: RwkvCompressor,
    // Snapshots of (state, pdf) taken before each `update`, consumed by `revert`.
    history: Vec<(RwkvState, Vec<f64>)>,
}
451
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Creates a new `RwkvPredictor` from an initialized `Model`.
    pub fn new(model: Arc<RwkvModel>) -> Self {
        let mut comp = RwkvCompressor::new_from_model(model);
        // Prime the network with token 0 so pdf_buffer holds a valid
        // distribution before the first `predict_prob` call.
        let vocab = comp.vocab_size();
        let logits = comp.model.forward(&mut comp.scratch, 0, &mut comp.state);
        softmax_pdf_floor_inplace(logits, vocab, &mut comp.pdf_buffer);

        Self {
            compressor: comp,
            history: Vec::new(),
        }
    }

    /// Creates a new `RwkvPredictor` from a method string.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut comp = RwkvCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        // Same token-0 priming as `new`, via the compressor helper.
        comp.forward_to_internal_pdf(0);
        Ok(Self {
            compressor: comp,
            history: Vec::new(),
        })
    }
}
480
#[cfg(feature = "backend-rwkv")]
impl Predictor for RwkvPredictor {
    fn update(&mut self, sym: bool) {
        // Snapshot (state, pdf) first so `revert` can restore the
        // pre-update model exactly.
        let snapshot = (
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        );
        self.history.push(snapshot);

        let token = if sym { 1u32 } else { 0u32 };
        let vocab = self.compressor.vocab_size();
        let logits = self.compressor.model.forward(
            &mut self.compressor.scratch,
            token,
            &mut self.compressor.state,
        );
        softmax_pdf_floor_inplace(logits, vocab, &mut self.compressor.pdf_buffer);
    }

    fn revert(&mut self) {
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        // pdf_buffer holds probabilities; tokens 0/1 encode the bit values.
        self.compressor.pdf_buffer[usize::from(sym)]
    }

    fn model_name(&self) -> String {
        "RWKV".to_string()
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(RwkvPredictor {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
525
/// A predictor using the Mamba neural network architecture.
#[cfg(feature = "backend-mamba")]
pub struct MambaPredictor {
    // Model handle plus mutable inference state and the current pdf buffer.
    compressor: MambaCompressor,
    // Snapshots of (state, pdf) taken before each `update`, consumed by `revert`.
    history: Vec<(MambaState, Vec<f64>)>,
}
532
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Creates a new `MambaPredictor` from an initialized `Model`.
    pub fn new(model: Arc<MambaModel>) -> Self {
        let mut comp = MambaCompressor::new_from_model(model);
        // Prime with token 0 so pdf_buffer is valid before any update.
        // Logits are copied out because logits_to_pdf writes into a buffer
        // owned by the same compressor.
        let logits = comp
            .model
            .forward(&mut comp.scratch, 0, &mut comp.state)
            .to_vec();
        let bias = comp.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut comp.pdf_buffer);

        Self {
            compressor: comp,
            history: Vec::new(),
        }
    }

    /// Creates a new `MambaPredictor` from a method string.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut comp = MambaCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        // Prime via the helper, then copy the distribution into pdf_buffer.
        let mut pdf = vec![0.0f64; comp.vocab_size()];
        comp.forward_to_pdf(0, &mut pdf);
        comp.pdf_buffer.clone_from(&pdf);
        Ok(Self {
            compressor: comp,
            history: Vec::new(),
        })
    }
}
564
#[cfg(feature = "backend-mamba")]
impl Predictor for MambaPredictor {
    fn update(&mut self, sym: bool) {
        // Snapshot (state, pdf) first so `revert` can restore the
        // pre-update model exactly.
        let snapshot = (
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        );
        self.history.push(snapshot);

        let token = if sym { 1u32 } else { 0u32 };
        // Copy the logits out: logits_to_pdf writes into the compressor's
        // own pdf_buffer.
        let logits = self
            .compressor
            .model
            .forward(
                &mut self.compressor.scratch,
                token,
                &mut self.compressor.state,
            )
            .to_vec();
        let bias = self.compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut self.compressor.pdf_buffer);
    }

    fn revert(&mut self) {
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }

    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Tokens 0/1 encode the two bit values.
        self.compressor.pdf_buffer[usize::from(sym)]
    }

    fn model_name(&self) -> String {
        "Mamba".to_string()
    }

    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(MambaPredictor {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}