infotheory/aixi/agent.rs

//! The core AIXI agent implementation.
//!
//! This module defines the `Agent` struct, which ties together a world model
//! (Predictor) and a planner (SearchTree) to form a complete autonomous entity.

use crate::aixi::common::{
    Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
    observation_repr_from_stream,
};
use crate::aixi::mcts::{AgentSimulator, SearchTree};
use crate::aixi::model::{
    CtwPredictor, FacCtwPredictor, Predictor, RosaPredictor, RwkvPredictor, ZpaqPredictor,
};
use crate::{load_rwkv7_model_from_path, validate_zpaq_rate_method};

/// Configuration parameters for an AIXI agent.
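///
/// # Example
///
/// A minimal sketch of a CTW-backed configuration. The concrete values are
/// illustrative only and should be tuned for the target environment; the
/// `observation_key_mode` value assumes a `Default` impl and may need to be
/// replaced with a concrete variant.
///
/// ```ignore
/// let config = AgentConfig {
///     algorithm: "ctw".to_string(),
///     ct_depth: 16,
///     agent_horizon: 6,
///     observation_bits: 4,
///     observation_stream_len: 1,
///     observation_key_mode: ObservationKeyMode::default(),
///     reward_bits: 5,
///     agent_actions: 4,
///     num_simulations: 500,
///     exploration_exploitation_ratio: 1.4,
///     discount_gamma: 1.0,
///     min_reward: 0,
///     max_reward: 10,
///     reward_offset: 0,
///     rwkv_model_path: None,
///     rosa_max_order: None,
///     zpaq_method: None,
/// };
/// ```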
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// The predictive algorithm to use: "ctw"/"fac-ctw", "ac-ctw"/"ctw-context-tree",
    /// "rosa", "rwkv", or "zpaq".
    pub algorithm: String,
    /// Context depth for the CTW model.
    pub ct_depth: usize,
    /// Planning horizon for MCTS.
    pub agent_horizon: usize,
    /// Number of bits used to encode observations.
    pub observation_bits: usize,
    /// Number of observation symbols per action (stream length).
    pub observation_stream_len: usize,
    /// Strategy for mapping observation streams into search keys.
    pub observation_key_mode: ObservationKeyMode,
    /// Number of bits used to encode rewards.
    pub reward_bits: usize,
    /// Number of possible actions.
    pub agent_actions: usize,
    /// Number of MCTS simulations per planning step.
    pub num_simulations: usize,
    /// Constant governing exploration vs exploitation in UCT.
    pub exploration_exploitation_ratio: f64,
    /// Discount factor for future rewards (1.0 = undiscounted).
    pub discount_gamma: f64,
    /// Minimum possible instantaneous reward in the environment.
    pub min_reward: Reward,
    /// Maximum possible instantaneous reward in the environment.
    pub max_reward: Reward,
    /// Reward offset applied before encoding rewards as unsigned bits.
    ///
    /// Paper-compatible encoding shifts rewards by an offset so all encoded values are non-negative.
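    ///
    /// For example (illustrative values): with `min_reward = -5` and
    /// `reward_offset = 5`, a reward of `-3` is encoded as the unsigned value `2`,
    /// and decoding subtracts the offset again.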
    pub reward_offset: Reward,
    /// Path to the RWKV model weights (if using "rwkv").
    pub rwkv_model_path: Option<String>,
    /// Maximum Markov order for the ROSA model (if using "rosa").
    pub rosa_max_order: Option<i64>,
    /// ZPAQ method string for the rate model (if using "zpaq").
    pub zpaq_method: Option<String>,
}

/// A complete MC-AIXI agent.
///
/// The agent maintains an internal world model and a planning tree. It can
/// be used for both live interaction with an environment and for
/// "imaginary" simulations during planning.
pub struct Agent {
    /// The world model used for prediction.
    model: Box<dyn Predictor>,
    /// The MCTS planner, temporarily taken during search.
    planner: Option<SearchTree>,
    /// Configuration settings.
    config: AgentConfig,

    /// Total number of interaction cycles.
    age: u64,
    /// Accumulated reward.
    total_reward: f64,

    /// Pre-calculated bit depth for actions based on `agent_actions`.
    action_bits: usize,

    /// Internal PRNG for simulations.
    rng: RandomGenerator,

    /// Recycled buffer for observation generation during planning.
    obs_buffer: Vec<u64>,
    /// Recycled buffer for symbol processing.
    sym_buffer: Vec<bool>,
}

impl Agent {
    /// Creates a new `Agent` with the given configuration.
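    ///
    /// # Panics
    ///
    /// Panics if `algorithm` is not recognised, if "rwkv" is selected without a
    /// `rwkv_model_path`, or if the configured ZPAQ method fails validation.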
    pub fn new(config: AgentConfig) -> Self {
        // Number of bits needed to encode an action: ceil(log2(agent_actions)),
        // forced to one bit for the degenerate single-action case.
        let mut action_bits = 0;
        let mut c = 1;
        let mut i = 1;
        while i < config.agent_actions {
            i *= 2;
            action_bits = c;
            c += 1;
        }
        if config.agent_actions == 1 {
            action_bits = 1;
        }

        let model: Box<dyn Predictor> = match config.algorithm.as_str() {
            // FAC-CTW is the default and recommended CTW variant per the paper
            "ctw" | "fac-ctw" => {
                let obs_len = config.observation_stream_len.max(1);
                let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
                Box::new(FacCtwPredictor::new(config.ct_depth, percept_bits))
            }
            // AC-CTW is the legacy single-tree variant
            "ac-ctw" | "ctw-context-tree" => Box::new(CtwPredictor::new(config.ct_depth)),
            "rosa" => {
                let max_order = config.rosa_max_order.unwrap_or(20);
                Box::new(RosaPredictor::new(max_order))
            }
            "rwkv" => {
                let path = config
                    .rwkv_model_path
                    .as_ref()
                    .expect("RWKV model path required");
                let model_arc = load_rwkv7_model_from_path(path);
                Box::new(RwkvPredictor::new(model_arc))
            }
            "zpaq" => {
                let method = config
                    .zpaq_method
                    .clone()
                    .unwrap_or_else(|| "1".to_string());
                if let Err(err) = validate_zpaq_rate_method(&method) {
                    panic!("Invalid zpaq method for AIXI: {err}");
                }
                Box::new(ZpaqPredictor::new(method, 2f64.powi(-24)))
            }
            _ => panic!("Unknown algorithm: {}", config.algorithm),
        };

        Self {
            model,
            planner: Some(SearchTree::new()),
            config,
            age: 0,
            total_reward: 0.0,
            action_bits,
            rng: RandomGenerator::new(),
            obs_buffer: Vec::with_capacity(128),
            sym_buffer: Vec::with_capacity(64),
        }
    }

    fn clone_for_simulation(&self, seed: u64) -> Self {
        Self {
            model: self.model.boxed_clone(),
            planner: None,
            config: self.config.clone(),
            age: self.age,
            total_reward: self.total_reward,
            action_bits: self.action_bits,
            rng: self.rng.fork_with(seed),
            obs_buffer: Vec::with_capacity(128),
            sym_buffer: Vec::with_capacity(64),
        }
    }

    /// Resets the agent's interaction statistics.
    pub fn reset(&mut self) {
        self.age = 0;
        self.total_reward = 0.0;
    }

    /// Primary interface for decision making.
    ///
    /// Uses MCTS to find the action that maximizes expected future reward.
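    ///
    /// ```ignore
    /// // `prev_obs`, `prev_rew`, and `prev_act` are illustrative placeholders
    /// // for the percept and action from the previous interaction cycle.
    /// let action = agent.get_planned_action(&prev_obs, prev_rew, prev_act);
    /// ```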
    pub fn get_planned_action(
        &mut self,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: Action,
    ) -> Action {
        let mut planner = self.planner.take().expect("Planner missing");
        let num_sim = self.config.num_simulations;
        let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
        self.planner = Some(planner);
        action
    }

    /// Updates the world model with a single real-world observation and its reward.
    pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
        self.model_update_percept_stream(&[observation], reward);
    }

    /// Updates the world model with an observation stream and its accompanying reward.
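    ///
    /// Each observation is encoded with `observation_bits` bits, followed by the
    /// reward encoded with `reward_bits` bits after shifting by `reward_offset`;
    /// the resulting symbols are then fed to the model in order.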
    pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
        debug_assert!(
            !observations.is_empty() || self.config.observation_bits == 0,
            "percept update missing observation stream"
        );
        let mut percept_syms = Vec::new();
        for &obs in observations {
            encode(&mut percept_syms, obs, self.config.observation_bits);
        }
        crate::aixi::common::encode_reward_offset(
            &mut percept_syms,
            reward,
            self.config.reward_bits,
            self.config.reward_offset,
        );

        for &sym in &percept_syms {
            self.model.update(sym);
        }

        self.total_reward += reward as f64;
    }

    /// Computes the observation key used for search-tree branching.
    pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
        observation_repr_from_stream(
            self.config.observation_key_mode,
            observations,
            self.config.observation_bits,
        )
    }

    /// Explicitly updates the world model with an action.
    pub fn model_update_action_external(&mut self, action: Action) {
        self.model_update_action(action);
    }
}

impl AgentSimulator for Agent {
    fn get_num_actions(&self) -> usize {
        self.config.agent_actions
    }

    fn get_num_observation_bits(&self) -> usize {
        self.config.observation_bits
    }

    fn observation_stream_len(&self) -> usize {
        self.config.observation_stream_len.max(1)
    }

    fn observation_key_mode(&self) -> ObservationKeyMode {
        self.config.observation_key_mode
    }

    fn get_num_reward_bits(&self) -> usize {
        self.config.reward_bits
    }

    fn horizon(&self) -> usize {
        self.config.agent_horizon
    }

    fn max_reward(&self) -> Reward {
        self.config.max_reward
    }

    fn min_reward(&self) -> Reward {
        self.config.min_reward
    }

    fn reward_offset(&self) -> i64 {
        self.config.reward_offset
    }

    fn get_explore_exploit_ratio(&self) -> f64 {
        self.config.exploration_exploitation_ratio
    }

    fn discount_gamma(&self) -> f64 {
        self.config.discount_gamma
    }

    fn model_update_action(&mut self, action: Action) {
        // Encode the action and append the bits to the model history via
        // `update_history` (percept bits, by contrast, go through `update`).
        self.sym_buffer.clear();
        encode(&mut self.sym_buffer, action, self.action_bits);

        for &sym in &self.sym_buffer {
            self.model.update_history(sym);
        }
    }

    fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
        // Sample `bits` symbols from the model's predictive distribution, feeding
        // each sampled bit back into the model, then decode them into an integer.
        self.sym_buffer.clear();
        for _ in 0..bits {
            let prob_1 = self.model.predict_one();
            let sym = self.rng.gen_bool(prob_1);
            self.model.update(sym);
            self.sym_buffer.push(sym);
        }
        decode(&self.sym_buffer, bits)
    }

    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.config.observation_bits;
        let obs_len = self.config.observation_stream_len.max(1);

        self.obs_buffer.clear();
        for _ in 0..obs_len {
            let p = self.gen_percept_and_update(obs_bits);
            self.obs_buffer.push(p);
        }

        let obs_repr = observation_repr_from_stream(
            self.config.observation_key_mode,
            &self.obs_buffer,
            obs_bits,
        );
        let rew_bits = self.config.reward_bits;
        let rew_u = self.gen_percept_and_update(rew_bits);
        let rew = (rew_u as i64) - self.config.reward_offset;

        // A full percept (observation stream + reward) has now been generated and
        // applied to the model, so the simulation is ready for the next action.
        (obs_repr, rew)
    }

    fn gen_range(&mut self, end: usize) -> usize {
        self.rng.gen_range(end)
    }

    fn gen_f64(&mut self) -> f64 {
        self.rng.gen_f64()
    }

    fn model_revert(&mut self, steps: usize) {
        // Undo `steps` full interaction cycles: revert the percept bits from the
        // model, then pop the corresponding action bits from the history.
        let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
        let percept_bits = obs_bits + self.config.reward_bits;

        for _ in 0..steps {
            for _ in 0..percept_bits {
                self.model.revert();
            }
            for _ in 0..self.action_bits {
                self.model.pop_history();
            }
        }
    }

    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
        Box::new(self.clone_for_simulation(seed))
    }
}