infotheory/aixi/
model.rs

1//! Predictive models for AIXI.
2//!
3//! This module defines the `Predictor` trait, which encapsulates the logic
4//! for learning from history and predicting future symbols. Different implementations
5//! provide different complexity vs performance trade-offs.
6
7use crate::ctw::{ContextTree, FacContextTree};
8use crate::zpaq_rate::ZpaqRateModel;
9use rosaplus::{RosaPlus, RosaTx};
10use rwkvzip::{Compressor, Model, State};
11use std::sync::Arc;
12
/// Interface for an AIXI world model.
///
/// A predictor works on a stream of binary symbols and must be able to:
///
/// - learn from an observed symbol ([`Predictor::update`]),
/// - undo its most recent step so Monte Carlo simulations can roll back
///   ([`Predictor::revert`]), and
/// - report the probability of the next symbol ([`Predictor::predict_prob`]).
///
/// Implementors must be `Send + Sync`.
pub trait Predictor: Send + Sync {
    /// Incorporates a new symbol into the model's training history.
    fn update(&mut self, sym: bool);

    /// Reverts the model to its state before the last `update`.
    fn revert(&mut self);

    /// Predicts the probability of the next symbol being `sym`.
    fn predict_prob(&mut self, sym: bool) -> f64;

    /// Returns a human-readable name of the predictive model.
    fn model_name(&self) -> String;

    /// Creates a boxed clone of this predictor.
    fn boxed_clone(&self) -> Box<dyn Predictor>;

    /// Appends a symbol to the model's interaction history without necessarily
    /// updating the training counts immediately (backend dependent).
    ///
    /// The provided implementation simply forwards to [`Predictor::update`].
    fn update_history(&mut self, sym: bool) {
        self.update(sym)
    }

    /// Reverts the model to its state before the last `update_history`.
    ///
    /// The provided implementation simply forwards to [`Predictor::revert`].
    fn pop_history(&mut self) {
        self.revert()
    }

    /// Shorthand for `predict_prob(true)`.
    fn predict_one(&mut self) -> f64 {
        let one = true;
        self.predict_prob(one)
    }
}
49
50/// A predictor using the Action-Conditional CTW algorithm.
51///
52/// AC-CTW uses a single context tree for all bits in sequence.
53/// For better type information exploitation, use `FacCtwPredictor`.
54pub struct CtwPredictor {
55    tree: ContextTree,
56}
57
58impl CtwPredictor {
59    /// Creates a new `CtwPredictor` with the specified context depth.
60    pub fn new(depth: usize) -> Self {
61        Self {
62            tree: ContextTree::new(depth),
63        }
64    }
65}
66
67impl Predictor for CtwPredictor {
68    fn update(&mut self, sym: bool) {
69        self.tree.update(sym);
70    }
71    fn update_history(&mut self, sym: bool) {
72        self.tree.update_history(&[sym]);
73    }
74
75    fn revert(&mut self) {
76        self.tree.revert();
77    }
78    fn pop_history(&mut self) {
79        self.tree.revert_history();
80    }
81
82    fn predict_prob(&mut self, sym: bool) -> f64 {
83        self.tree.predict(sym)
84    }
85
86    fn model_name(&self) -> String {
87        format!("AC-CTW(d={})", self.tree.depth())
88    }
89
90    fn boxed_clone(&self) -> Box<dyn Predictor> {
91        Box::new(Self {
92            tree: self.tree.clone(),
93        })
94    }
95}
96
97/// A predictor using the Factorized Action-Conditional CTW (FAC-CTW) algorithm.
98///
99/// FAC-CTW uses k separate context trees (one per percept bit) with overlapping
100/// context depths D+i-1 for bit i. This enables better exploitation of type
101/// information within percepts, as described in Veness et al. (2011) Section 5.
102///
103/// This is the recommended CTW variant for MC-AIXI agents.
104pub struct FacCtwPredictor {
105    tree: FacContextTree,
106    /// Current bit index within a percept (cycles 0..num_bits).
107    current_bit: usize,
108    /// Total number of percept bits (k).
109    num_bits: usize,
110}
111
112impl FacCtwPredictor {
113    /// Creates a new `FacCtwPredictor`.
114    ///
115    /// - `base_depth`: Context depth D for the first bit's tree
116    /// - `num_percept_bits`: Total bits per percept (observation_bits + reward_bits)
117    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
118        Self {
119            tree: FacContextTree::new(base_depth, num_percept_bits),
120            current_bit: 0,
121            num_bits: num_percept_bits,
122        }
123    }
124}
125
126impl Predictor for FacCtwPredictor {
127    fn update(&mut self, sym: bool) {
128        self.tree.update(sym, self.current_bit);
129        self.current_bit = (self.current_bit + 1) % self.num_bits;
130    }
131
132    fn update_history(&mut self, sym: bool) {
133        self.tree.update_history(&[sym]);
134    }
135
136    fn revert(&mut self) {
137        // Revert to previous bit index
138        self.current_bit = if self.current_bit == 0 {
139            self.num_bits - 1
140        } else {
141            self.current_bit - 1
142        };
143        self.tree.revert(self.current_bit);
144    }
145
146    fn pop_history(&mut self) {
147        self.tree.revert_history(1);
148    }
149
150    fn predict_prob(&mut self, sym: bool) -> f64 {
151        self.tree.predict(sym, self.current_bit)
152    }
153
154    fn model_name(&self) -> String {
155        format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
156    }
157
158    fn boxed_clone(&self) -> Box<dyn Predictor> {
159        Box::new(Self {
160            tree: self.tree.clone(),
161            current_bit: self.current_bit,
162            num_bits: self.num_bits,
163        })
164    }
165}
166
167/// A predictor using the ROSA-Plus (Rapid Online Suffix Automaton + Witten-Bell Smoother) algorithm.
168///
169/// ROSA is a (practically) sub-quadratic suffix automaton based language model that
170/// can handle very long contexts efficiently.
171pub struct RosaPredictor {
172    model: RosaPlus,
173    history: Vec<RosaTx>,
174}
175
176impl RosaPredictor {
177    /// Creates a new `RosaPredictor` with a maximum context length for the fallback LM.
178    /// Note: deterministic ROSA uses the full SAM and is not capped by `max_order`.
179    pub fn new(max_order: i64) -> Self {
180        // Seed and EOT
181        let mut model = RosaPlus::new(max_order, false, 0, 42);
182        // Important: Pre-build the LM with full byte alphabet so we can incrementally update it.
183        model.build_lm_full_bytes_no_finalize_endpos();
184        Self {
185            model,
186            history: Vec::new(),
187        }
188    }
189}
190
191impl Predictor for RosaPredictor {
192    fn update(&mut self, sym: bool) {
193        let mut tx = self.model.begin_tx();
194        // Using 0u8 and 1u8 for bits
195        let byte = if sym { 1u8 } else { 0u8 };
196
197        // train_example_tx updates the tx object and the model
198        self.model.train_sequence_tx(&mut tx, &[byte]);
199        self.history.push(tx);
200    }
201
202    fn revert(&mut self) {
203        if let Some(tx) = self.history.pop() {
204            self.model.rollback_tx(tx);
205        }
206    }
207
208    fn predict_prob(&mut self, sym: bool) -> f64 {
209        let p0 = self.model.prob_for_last(0);
210        let p1 = self.model.prob_for_last(1);
211        let denom = (p0 + p1).max(1e-12);
212        if sym { p1 / denom } else { p0 / denom }
213    }
214
215    fn model_name(&self) -> String {
216        "ROSA".to_string()
217    }
218
219    fn boxed_clone(&self) -> Box<dyn Predictor> {
220        Box::new(Self {
221            model: self.model.clone(),
222            history: self.history.clone(),
223        })
224    }
225}
226
227/// A predictor using ZPAQ as a streaming rate model.
228///
229/// This maintains a full history so it can rebuild state on revert and handle
230/// any misuse where `predict_prob` is called without a matching `update`.
231pub struct ZpaqPredictor {
232    method: String,
233    min_prob: f64,
234    model: ZpaqRateModel,
235    history: Vec<u8>,
236    pending: Option<(u8, f64)>,
237}
238
239unsafe impl Sync for ZpaqPredictor {}
240
241impl ZpaqPredictor {
242    pub fn new(method: String, min_prob: f64) -> Self {
243        let model = ZpaqRateModel::new(method.clone(), min_prob);
244        Self {
245            method,
246            min_prob,
247            model,
248            history: Vec::new(),
249            pending: None,
250        }
251    }
252
253    fn rebuild_from_history(&mut self) {
254        self.model.reset();
255        if !self.history.is_empty() {
256            self.model.update_and_score(&self.history);
257        }
258    }
259
260    fn log_prob_from_history(&self, symbol: u8) -> f64 {
261        let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
262        if !self.history.is_empty() {
263            tmp.update_and_score(&self.history);
264        }
265        tmp.log_prob(symbol)
266    }
267}
268
269impl Predictor for ZpaqPredictor {
270    fn update(&mut self, sym: bool) {
271        let byte = if sym { 1u8 } else { 0u8 };
272        if let Some((pending, _)) = self.pending {
273            if pending == byte {
274                self.model.update(byte);
275                self.pending = None;
276                self.history.push(byte);
277                return;
278            }
279            self.pending = None;
280            self.rebuild_from_history();
281        }
282        self.model.update(byte);
283        self.history.push(byte);
284    }
285
286    fn revert(&mut self) {
287        if self.history.pop().is_some() {
288            self.pending = None;
289            self.rebuild_from_history();
290        }
291    }
292
293    fn predict_prob(&mut self, sym: bool) -> f64 {
294        let byte = if sym { 1u8 } else { 0u8 };
295        if let Some((pending, logp)) = self.pending {
296            if pending == byte {
297                return logp.exp();
298            }
299            return self.log_prob_from_history(byte).exp();
300        }
301        let logp = self.model.log_prob(byte);
302        self.pending = Some((byte, logp));
303        logp.exp()
304    }
305
306    fn model_name(&self) -> String {
307        format!("ZPAQ({})", self.method)
308    }
309
310    fn boxed_clone(&self) -> Box<dyn Predictor> {
311        let mut model = ZpaqRateModel::new(self.method.clone(), self.min_prob);
312        if !self.history.is_empty() {
313            model.update_and_score(&self.history);
314        }
315        Box::new(Self {
316            method: self.method.clone(),
317            min_prob: self.min_prob,
318            model,
319            history: self.history.clone(),
320            pending: None,
321        })
322    }
323}
324
325use rwkvzip::coders::softmax_pdf_floor_inplace;
326
327/// A predictor using the RWKV neural network architecture.
328///
329/// This provides a deep learning based world model for AIXI, allowing
330/// the agent to leverage large pre-trained models for sequence prediction.
331pub struct RwkvPredictor {
332    compressor: Compressor,
333    history: Vec<(State, Vec<f64>)>,
334}
335
336impl RwkvPredictor {
337    /// Creates a new `RwkvPredictor` from an initialized `Model`.
338    pub fn new(model: Arc<Model>) -> Self {
339        let mut compressor = Compressor::new_from_model(model);
340        let vocab_size = compressor.vocab_size();
341        let logits = compressor
342            .model
343            .forward(&mut compressor.scratch, 0, &mut compressor.state);
344        softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
345
346        Self {
347            compressor,
348            history: Vec::new(),
349        }
350    }
351}
352
353impl Predictor for RwkvPredictor {
354    fn update(&mut self, sym: bool) {
355        // Save current state and pdf
356        self.history.push((
357            self.compressor.state.clone(),
358            self.compressor.pdf_buffer.clone(),
359        ));
360
361        let byte = if sym { 1u32 } else { 0u32 };
362        let vocab_size = self.compressor.vocab_size();
363
364        let logits = self.compressor.model.forward(
365            &mut self.compressor.scratch,
366            byte,
367            &mut self.compressor.state,
368        );
369        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.compressor.pdf_buffer);
370    }
371
372    fn revert(&mut self) {
373        if let Some((state, pdf)) = self.history.pop() {
374            self.compressor.state = state;
375            self.compressor.pdf_buffer = pdf;
376        }
377    }
378
379    fn predict_prob(&mut self, sym: bool) -> f64 {
380        let idx = if sym { 1 } else { 0 };
381        // pdf_buffer contains probabilities
382        self.compressor.pdf_buffer[idx]
383    }
384
385    fn model_name(&self) -> String {
386        "RWKV".to_string()
387    }
388
389    fn boxed_clone(&self) -> Box<dyn Predictor> {
390        Box::new(Self {
391            compressor: self.compressor.clone(),
392            history: self.history.clone(),
393        })
394    }
395}