// infotheory/aixi/agent.rs

//! The core AIXI agent implementation.
//!
//! This module defines the `Agent` struct, which ties together a world model
//! (Predictor) and a planner (SearchTree) to form a complete autonomous entity.

use crate::aixi::common::{
    Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
    observation_repr_from_stream,
};
use crate::aixi::mcts::{AgentSimulator, SearchTree};
#[cfg(feature = "backend-mamba")]
use crate::aixi::model::MambaPredictor;
#[cfg(feature = "backend-rwkv")]
use crate::aixi::model::RwkvPredictor;
use crate::aixi::model::{CtwPredictor, FacCtwPredictor, Predictor, RosaPredictor, ZpaqPredictor};
#[cfg(feature = "backend-mamba")]
use crate::load_mamba_model_from_path;
#[cfg(feature = "backend-rwkv")]
use crate::load_rwkv7_model_from_path;
use crate::validate_zpaq_rate_method;

/// Configuration parameters for an AIXI agent.
///
/// Consumed by `Agent::new`, which selects the world model from `algorithm`
/// and derives the action bit-width from `agent_actions`.
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// The predictive algorithm to use. Accepted values (see `Agent::new`):
    /// "ctw"/"fac-ctw", "ac-ctw"/"ctw-context-tree", "rosa", "rwkv",
    /// "mamba", "zpaq". Any other value panics at construction.
    pub algorithm: String,
    /// Context depth for the CTW-family models ("ctw"/"fac-ctw"/"ac-ctw").
    pub ct_depth: usize,
    /// Planning horizon for MCTS.
    pub agent_horizon: usize,
    /// Number of bits used to encode each observation symbol.
    pub observation_bits: usize,
    /// Number of observation symbols per action (stream length); treated as
    /// at least 1 wherever it is consumed.
    pub observation_stream_len: usize,
    /// Strategy for mapping observation streams into search keys.
    pub observation_key_mode: ObservationKeyMode,
    /// Number of bits used to encode rewards.
    pub reward_bits: usize,
    /// Number of possible actions.
    pub agent_actions: usize,
    /// Number of MCTS simulations per planning step.
    pub num_simulations: usize,
    /// Constant governing exploration vs exploitation in UCT.
    pub exploration_exploitation_ratio: f64,
    /// Discount factor for future rewards (1.0 = undiscounted).
    pub discount_gamma: f64,
    /// Minimum possible instantaneous reward in the environment.
    pub min_reward: Reward,
    /// Maximum possible instantaneous reward in the environment.
    pub max_reward: Reward,
    /// Reward offset applied before encoding rewards as unsigned bits.
    ///
    /// Paper-compatible encoding shifts rewards by an offset so all encoded values are non-negative.
    pub reward_offset: Reward,
    /// Optional deterministic RNG seed for planning/simulation behavior.
    ///
    /// When `None`, a fresh runtime-derived seed is used.
    pub random_seed: Option<u64>,
    /// Path to the RWKV model weights (if using "rwkv" without `rwkv_method`).
    pub rwkv_model_path: Option<String>,
    /// Optional RWKV method string for hosted/browser-safe construction;
    /// when set it takes precedence over `rwkv_model_path`.
    pub rwkv_method: Option<String>,
    /// Path to the Mamba model weights (if using "mamba" without `mamba_method`).
    pub mamba_model_path: Option<String>,
    /// Optional Mamba method string for hosted/browser-safe construction;
    /// when set it takes precedence over `mamba_model_path`.
    pub mamba_method: Option<String>,
    /// Maximum Markov order for the ROSA model (defaults to 20 when `None`).
    pub rosa_max_order: Option<i64>,
    /// ZPAQ method string for the rate model (defaults to "1" when `None`).
    pub zpaq_method: Option<String>,
}
72
/// A complete MC-AIXI agent.
///
/// The agent maintains an internal world model and a planning tree. It can
/// be used for both live interaction with an environment and for
/// "imaginary" simulations during planning.
pub struct Agent {
    /// The world model used for prediction.
    model: Box<dyn Predictor>,
    /// The MCTS planner. Wrapped in `Option` because `get_planned_action`
    /// temporarily takes it out while the tree searches against `self`;
    /// simulation clones carry `None` (they never plan).
    planner: Option<SearchTree>,
    /// Configuration settings.
    config: AgentConfig,

    /// Total number of interaction cycles.
    age: u64,
    /// Accumulated reward across all percept updates.
    total_reward: f64,

    /// Pre-calculated bit depth for actions based on `agent_actions`
    /// (computed once in `Agent::new`).
    action_bits: usize,

    /// Internal PRNG for simulations.
    rng: RandomGenerator,

    /// Recycled buffer for observation generation during planning.
    obs_buffer: Vec<u64>,
    /// Recycled buffer for symbol (bit) encoding/decoding.
    sym_buffer: Vec<bool>,
}
102
103impl Agent {
104    /// Creates a new `Agent` with the given configuration.
105    pub fn new(config: AgentConfig) -> Self {
106        let mut action_bits = 0;
107        let mut c = 1;
108        let mut i = 1;
109        while i < config.agent_actions {
110            i *= 2;
111            action_bits = c;
112            c += 1;
113        }
114        if config.agent_actions == 1 {
115            action_bits = 1;
116        }
117
118        let model: Box<dyn Predictor> = match config.algorithm.as_str() {
119            // FAC-CTW is the default and recommended CTW variant per the paper
120            "ctw" | "fac-ctw" => {
121                let obs_len = config.observation_stream_len.max(1);
122                let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
123                Box::new(FacCtwPredictor::new(config.ct_depth, percept_bits))
124            }
125            // AC-CTW is the legacy single-tree variant
126            "ac-ctw" | "ctw-context-tree" => Box::new(CtwPredictor::new(config.ct_depth)),
127            "rosa" => {
128                let max_order = config.rosa_max_order.unwrap_or(20);
129                Box::new(RosaPredictor::new(max_order))
130            }
131            #[cfg(feature = "backend-rwkv")]
132            "rwkv" => {
133                if let Some(method) = config.rwkv_method.as_deref() {
134                    let predictor = RwkvPredictor::from_method(method)
135                        .unwrap_or_else(|err| panic!("Invalid RWKV method for AIXI: {err}"));
136                    Box::new(predictor)
137                } else {
138                    let path = config
139                        .rwkv_model_path
140                        .as_ref()
141                        .expect("RWKV model path required");
142                    let model_arc = load_rwkv7_model_from_path(path);
143                    Box::new(RwkvPredictor::new(model_arc))
144                }
145            }
146            #[cfg(not(feature = "backend-rwkv"))]
147            "rwkv" => panic!("RWKV backend disabled at compile time"),
148            #[cfg(feature = "backend-mamba")]
149            "mamba" => {
150                if let Some(method) = config.mamba_method.as_deref() {
151                    let predictor = MambaPredictor::from_method(method)
152                        .unwrap_or_else(|err| panic!("Invalid Mamba method for AIXI: {err}"));
153                    Box::new(predictor)
154                } else {
155                    let path = config
156                        .mamba_model_path
157                        .as_ref()
158                        .expect("Mamba model path required");
159                    let model_arc = load_mamba_model_from_path(path);
160                    Box::new(MambaPredictor::new(model_arc))
161                }
162            }
163            #[cfg(not(feature = "backend-mamba"))]
164            "mamba" => panic!("Mamba backend disabled at compile time"),
165            "zpaq" => {
166                let method = config
167                    .zpaq_method
168                    .clone()
169                    .unwrap_or_else(|| "1".to_string());
170                if let Err(err) = validate_zpaq_rate_method(&method) {
171                    panic!("Invalid zpaq method for AIXI: {err}");
172                }
173                Box::new(ZpaqPredictor::new(method, 2f64.powi(-24)))
174            }
175            _ => panic!("Unknown algorithm: {}", config.algorithm),
176        };
177
178        let rng = if let Some(seed) = config.random_seed {
179            RandomGenerator::from_seed(seed)
180        } else {
181            RandomGenerator::new()
182        };
183
184        Self {
185            model,
186            planner: Some(SearchTree::new()),
187            config,
188            age: 0,
189            total_reward: 0.0,
190            action_bits,
191            rng,
192            obs_buffer: Vec::with_capacity(128),
193            sym_buffer: Vec::with_capacity(64),
194        }
195    }
196
197    fn clone_for_simulation(&self, seed: u64) -> Self {
198        Self {
199            model: self.model.boxed_clone(),
200            planner: None,
201            config: self.config.clone(),
202            age: self.age,
203            total_reward: self.total_reward,
204            action_bits: self.action_bits,
205            rng: self.rng.fork_with(seed),
206            obs_buffer: Vec::with_capacity(128),
207            sym_buffer: Vec::with_capacity(64),
208        }
209    }
210
211    /// Resets the agent's interaction statistics.
212    pub fn reset(&mut self) {
213        self.age = 0;
214        self.total_reward = 0.0;
215    }
216
217    /// Primary interface for decision making.
218    ///
219    /// Uses MCTS to find the action that maximizes expected future reward.
220    pub fn get_planned_action(
221        &mut self,
222        prev_obs_stream: &[PerceptVal],
223        prev_rew: Reward,
224        prev_act: Action,
225    ) -> Action {
226        let mut planner = self.planner.take().expect("Planner missing");
227        let num_sim = self.config.num_simulations;
228        let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
229        self.planner = Some(planner);
230        action
231    }
232
233    /// Updates the world model with real-world percepts.
234    pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
235        self.model_update_percept_stream(&[observation], reward);
236    }
237
238    /// Updates the world model with an observation stream and a terminal reward.
239    pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
240        debug_assert!(
241            !observations.is_empty() || self.config.observation_bits == 0,
242            "percept update missing observation stream"
243        );
244        let mut percept_syms = Vec::new();
245        for &obs in observations {
246            encode(&mut percept_syms, obs, self.config.observation_bits);
247        }
248        crate::aixi::common::encode_reward_offset(
249            &mut percept_syms,
250            reward,
251            self.config.reward_bits,
252            self.config.reward_offset,
253        );
254
255        for &sym in &percept_syms {
256            self.model.update(sym);
257        }
258
259        self.total_reward += reward as f64;
260    }
261
262    /// Computes the observation key used for search-tree branching.
263    pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
264        observation_repr_from_stream(
265            self.config.observation_key_mode,
266            observations,
267            self.config.observation_bits,
268        )
269    }
270
271    /// Explicitly updates the world model with an action.
272    pub fn model_update_action_external(&mut self, action: Action) {
273        self.model_update_action(action);
274    }
275}
276
277impl AgentSimulator for Agent {
278    fn get_num_actions(&self) -> usize {
279        self.config.agent_actions
280    }
281
282    fn get_num_observation_bits(&self) -> usize {
283        self.config.observation_bits
284    }
285
286    fn observation_stream_len(&self) -> usize {
287        self.config.observation_stream_len.max(1)
288    }
289
290    fn observation_key_mode(&self) -> ObservationKeyMode {
291        self.config.observation_key_mode
292    }
293
294    fn get_num_reward_bits(&self) -> usize {
295        self.config.reward_bits
296    }
297
298    fn horizon(&self) -> usize {
299        self.config.agent_horizon
300    }
301
302    fn max_reward(&self) -> Reward {
303        self.config.max_reward
304    }
305
306    fn min_reward(&self) -> Reward {
307        self.config.min_reward
308    }
309
310    fn reward_offset(&self) -> i64 {
311        self.config.reward_offset
312    }
313
314    fn get_explore_exploit_ratio(&self) -> f64 {
315        self.config.exploration_exploitation_ratio
316    }
317
318    fn discount_gamma(&self) -> f64 {
319        self.config.discount_gamma
320    }
321
322    fn model_update_action(&mut self, action: Action) {
323        self.sym_buffer.clear();
324        encode(&mut self.sym_buffer, action, self.action_bits);
325
326        for &sym in &self.sym_buffer {
327            self.model.update_history(sym);
328        }
329    }
330
331    fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
332        self.sym_buffer.clear();
333        for _ in 0..bits {
334            let prob_1 = self.model.predict_one();
335            let sym = self.rng.gen_bool(prob_1);
336            self.model.update(sym);
337            self.sym_buffer.push(sym);
338        }
339        decode(&self.sym_buffer, bits)
340    }
341
342    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
343        let obs_bits = self.config.observation_bits;
344        let obs_len = self.config.observation_stream_len.max(1);
345
346        self.obs_buffer.clear();
347        for _ in 0..obs_len {
348            let p = self.gen_percept_and_update(obs_bits);
349            self.obs_buffer.push(p);
350        }
351
352        let obs_repr = observation_repr_from_stream(
353            self.config.observation_key_mode,
354            &self.obs_buffer,
355            obs_bits,
356        );
357        let rew_bits = self.config.reward_bits;
358        let rew_u = self.gen_percept_and_update(rew_bits);
359        let rew = (rew_u as i64) - self.config.reward_offset;
360
361        // Mark that we've completed a percept cycle (ready for next action)
362
363        (obs_repr, rew)
364    }
365
366    fn gen_range(&mut self, end: usize) -> usize {
367        self.rng.gen_range(end)
368    }
369
370    fn gen_f64(&mut self) -> f64 {
371        self.rng.gen_f64()
372    }
373
374    fn model_revert(&mut self, steps: usize) {
375        let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
376        let percept_bits = obs_bits + self.config.reward_bits;
377
378        for _ in 0..steps {
379            for _ in 0..percept_bits {
380                self.model.revert();
381            }
382            for _ in 0..self.action_bits {
383                self.model.pop_history();
384            }
385        }
386    }
387
388    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
389        Box::new(self.clone_for_simulation(seed))
390    }
391}