// infotheory/aixi/agent.rs

//! The core AIXI agent implementation.
//!
//! This module defines the `Agent` struct, which ties together a world model
//! (Predictor) and a planner (SearchTree) to form a complete autonomous entity.

6use crate::RateBackend;
7use crate::aixi::common::{
8    Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
9    observation_repr_from_stream,
10};
11use crate::aixi::mcts::{AgentSimulator, SearchTree};
12#[cfg(feature = "backend-mamba")]
13use crate::aixi::model::MambaPredictor;
14#[cfg(feature = "backend-rwkv")]
15use crate::aixi::model::RwkvPredictor;
16use crate::aixi::model::{
17    CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor, RosaPredictor, ZpaqPredictor,
18};
19use crate::aixi::rate_backend::{adapt_rate_backend_for_bit_tokens, rate_backend_contains_zpaq};
20#[cfg(feature = "backend-mamba")]
21use crate::load_mamba_model_from_path;
22#[cfg(feature = "backend-rwkv")]
23use crate::load_rwkv7_model_from_path;
24use crate::{validate_rate_backend, validate_zpaq_rate_method};
25
/// Configuration parameters for an AIXI agent.
///
/// Constraints between fields (reward range vs. `reward_bits`, backend
/// feature availability, etc.) are checked by [`AgentConfig::validate`],
/// which `Agent::try_new` calls before building a model.
#[derive(Clone)]
pub struct AgentConfig {
    /// The predictive algorithm to use ("ctw", "rosa", "rwkv", "mamba", "zpaq").
    pub algorithm: String,
    /// Context depth for the CTW model.
    pub ct_depth: usize,
    /// Planning horizon for MCTS.
    pub agent_horizon: usize,
    /// Number of bits used to encode observations.
    pub observation_bits: usize,
    /// Number of observation symbols per action (stream length).
    pub observation_stream_len: usize,
    /// Strategy for mapping observation streams into search keys.
    pub observation_key_mode: ObservationKeyMode,
    /// Number of bits used to encode rewards.
    pub reward_bits: usize,
    /// Number of possible actions.
    pub agent_actions: usize,
    /// Number of MCTS simulations per planning step.
    pub num_simulations: usize,
    /// Constant governing exploration vs exploitation in UCT.
    pub exploration_exploitation_ratio: f64,
    /// Discount factor for future rewards (1.0 = undiscounted).
    pub discount_gamma: f64,
    /// Minimum possible instantaneous reward in the environment.
    pub min_reward: Reward,
    /// Maximum possible instantaneous reward in the environment.
    pub max_reward: Reward,
    /// Reward offset applied before encoding rewards as unsigned bits.
    ///
    /// Paper-compatible encoding shifts rewards by an offset so all encoded values are non-negative.
    pub reward_offset: Reward,
    /// Optional deterministic RNG seed for planning/simulation behavior.
    ///
    /// When `None`, a fresh runtime-derived seed is used.
    pub random_seed: Option<u64>,
    /// Optional generic rate backend override.
    ///
    /// When set, this takes precedence over `algorithm` and routes MC-AIXI
    /// through the shared `RateBackend` abstraction.
    pub rate_backend: Option<RateBackend>,
    /// Max-order hint for `rate_backend` constructors that use it (for example ROSA).
    pub rate_backend_max_order: i64,
    /// Path to the RWKV model weights (if using "rwkv").
    pub rwkv_model_path: Option<String>,
    /// Optional RWKV method string for hosted/browser-safe construction.
    ///
    /// Takes precedence over `rwkv_model_path` when both are set.
    pub rwkv_method: Option<String>,
    /// Path to the Mamba model weights (if using "mamba").
    pub mamba_model_path: Option<String>,
    /// Optional Mamba method string for hosted/browser-safe construction.
    ///
    /// Takes precedence over `mamba_model_path` when both are set.
    pub mamba_method: Option<String>,
    /// Maximum Markov order for the ROSA model (if using "rosa").
    /// Defaults to 20 when `None`.
    pub rosa_max_order: Option<i64>,
    /// ZPAQ method string for the rate model (if using "zpaq").
    /// Defaults to "1" when `None`.
    pub zpaq_method: Option<String>,
}
83
84impl AgentConfig {
85    /// Validate configuration constraints for MC-AIXI.
86    pub fn validate(&self) -> Result<(), String> {
87        if self.agent_actions == 0 {
88            return Err("agent_actions must be >= 1".to_string());
89        }
90        if self.agent_horizon == 0 {
91            return Err("agent_horizon must be >= 1".to_string());
92        }
93        if self.num_simulations == 0 {
94            return Err("num_simulations must be >= 1".to_string());
95        }
96        if self.exploration_exploitation_ratio <= 0.0 {
97            return Err("exploration_exploitation_ratio must be > 0".to_string());
98        }
99        if !(0.0..=1.0).contains(&self.discount_gamma) {
100            return Err(format!(
101                "discount_gamma must be in [0, 1] for MC-AIXI, got {}",
102                self.discount_gamma
103            ));
104        }
105        if self.max_reward < self.min_reward {
106            return Err(format!(
107                "max_reward must be >= min_reward (got {} < {})",
108                self.max_reward, self.min_reward
109            ));
110        }
111
112        let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
113        let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
114        if min_shifted < 0 {
115            return Err(format!(
116                "reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
117                min_shifted
118            ));
119        }
120        if self.reward_bits < 64 {
121            let max_enc = (1u128 << self.reward_bits) - 1;
122            if (max_shifted as u128) > max_enc {
123                return Err(format!(
124                    "reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
125                    max_shifted, max_enc
126                ));
127            }
128        }
129
130        if let Some(rate_backend) = &self.rate_backend {
131            validate_rate_backend(rate_backend)
132                .map_err(|err| format!("invalid rate_backend: {err}"))?;
133            if rate_backend_contains_zpaq(rate_backend) {
134                return Err(
135                    "MC-AIXI strict generic rate_backend support requires reversible action conditioning; configured rate_backend contains zpaq which does not provide the reversible action conditioning required by \"A Monte-Carlo AIXI Approximation\""
136                        .to_string(),
137                );
138            }
139            return Ok(());
140        }
141
142        match self.algorithm.as_str() {
143            "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
144            #[cfg(feature = "backend-rwkv")]
145            "rwkv" => {
146                let has_method = self
147                    .rwkv_method
148                    .as_deref()
149                    .map(str::trim)
150                    .is_some_and(|v| !v.is_empty());
151                let has_path = self
152                    .rwkv_model_path
153                    .as_deref()
154                    .map(str::trim)
155                    .is_some_and(|v| !v.is_empty());
156                if !(has_method || has_path) {
157                    return Err(
158                        "algorithm=rwkv requires rwkv_model_path or rwkv_method when no rate_backend override is configured"
159                            .to_string(),
160                    );
161                }
162            }
163            #[cfg(not(feature = "backend-rwkv"))]
164            "rwkv" => return Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
165            #[cfg(feature = "backend-mamba")]
166            "mamba" => {
167                let has_method = self
168                    .mamba_method
169                    .as_deref()
170                    .map(str::trim)
171                    .is_some_and(|v| !v.is_empty());
172                let has_path = self
173                    .mamba_model_path
174                    .as_deref()
175                    .map(str::trim)
176                    .is_some_and(|v| !v.is_empty());
177                if !(has_method || has_path) {
178                    return Err(
179                        "algorithm=mamba requires mamba_model_path or mamba_method when no rate_backend override is configured"
180                            .to_string(),
181                    );
182                }
183            }
184            #[cfg(not(feature = "backend-mamba"))]
185            "mamba" => return Err("algorithm=mamba requires backend-mamba feature".to_string()),
186            "zpaq" => {
187                let method = self.zpaq_method.as_deref().unwrap_or("1");
188                if let Err(err) = validate_zpaq_rate_method(method) {
189                    return Err(format!("Invalid zpaq method for AIXI: {err}"));
190                }
191            }
192            other => return Err(format!("Unknown algorithm: {other}")),
193        }
194
195        Ok(())
196    }
197}
198
/// A complete MC-AIXI agent.
///
/// The agent maintains an internal world model and a planning tree. It can
/// be used for both live interaction with an environment and for
/// "imaginary" simulations during planning.
pub struct Agent {
    /// The world model used for prediction.
    model: Box<dyn Predictor>,
    /// The MCTS planner, temporarily taken during search.
    ///
    /// `None` only while a search is in flight (the planner borrows the
    /// agent as its simulator) and for simulation clones, which never plan.
    planner: Option<SearchTree>,
    /// Configuration settings.
    config: AgentConfig,

    /// Total number of interaction cycles.
    age: u64,
    /// Accumulated reward.
    total_reward: f64,

    /// Pre-calculated bit depth for actions based on `agent_actions`.
    action_bits: usize,

    /// Internal PRNG for simulations.
    rng: RandomGenerator,

    /// Recycled buffer for observation generation during planning.
    obs_buffer: Vec<u64>,
    /// Recycled buffer for symbol processing.
    sym_buffer: Vec<bool>,
}
228
229impl Agent {
230    /// Creates a new `Agent` with the given configuration.
231    pub fn new(config: AgentConfig) -> Self {
232        Self::try_new(config).unwrap_or_else(|err| panic!("Invalid MC-AIXI config: {err}"))
233    }
234
235    /// Creates a new `Agent` with the given configuration, returning a validation error on failure.
236    pub fn try_new(config: AgentConfig) -> Result<Self, String> {
237        config.validate()?;
238
239        let mut action_bits = 0;
240        let mut c = 1;
241        let mut i = 1;
242        while i < config.agent_actions {
243            i *= 2;
244            action_bits = c;
245            c += 1;
246        }
247        if config.agent_actions == 1 {
248            action_bits = 1;
249        }
250
251        let model = build_model(&config)?;
252
253        let rng = if let Some(seed) = config.random_seed {
254            RandomGenerator::from_seed(seed)
255        } else {
256            RandomGenerator::new()
257        };
258
259        Ok(Self {
260            model,
261            planner: Some(SearchTree::new()),
262            config,
263            age: 0,
264            total_reward: 0.0,
265            action_bits,
266            rng,
267            obs_buffer: Vec::with_capacity(128),
268            sym_buffer: Vec::with_capacity(64),
269        })
270    }
271
272    fn clone_for_simulation(&self, seed: u64) -> Self {
273        Self {
274            model: self.model.boxed_clone(),
275            planner: None,
276            config: self.config.clone(),
277            age: self.age,
278            total_reward: self.total_reward,
279            action_bits: self.action_bits,
280            rng: self.rng.fork_with(seed),
281            obs_buffer: Vec::with_capacity(128),
282            sym_buffer: Vec::with_capacity(64),
283        }
284    }
285
286    /// Resets the agent's interaction statistics.
287    pub fn reset(&mut self) {
288        self.age = 0;
289        self.total_reward = 0.0;
290    }
291
292    /// Primary interface for decision making.
293    ///
294    /// Uses MCTS to find the action that maximizes expected future reward.
295    pub fn get_planned_action(
296        &mut self,
297        prev_obs_stream: &[PerceptVal],
298        prev_rew: Reward,
299        prev_act: Action,
300    ) -> Action {
301        let mut planner = self.planner.take().expect("Planner missing");
302        let num_sim = self.config.num_simulations;
303        let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
304        self.planner = Some(planner);
305        action
306    }
307
308    /// Updates the world model with real-world percepts.
309    pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
310        self.model_update_percept_stream(&[observation], reward);
311    }
312
313    /// Updates the world model with an observation stream and a terminal reward.
314    pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
315        debug_assert!(
316            !observations.is_empty() || self.config.observation_bits == 0,
317            "percept update missing observation stream"
318        );
319        let mut percept_syms = Vec::new();
320        for &obs in observations {
321            encode(&mut percept_syms, obs, self.config.observation_bits);
322        }
323        crate::aixi::common::encode_reward_offset(
324            &mut percept_syms,
325            reward,
326            self.config.reward_bits,
327            self.config.reward_offset,
328        );
329
330        for &sym in &percept_syms {
331            self.model.commit_update(sym);
332        }
333
334        self.total_reward += reward as f64;
335    }
336
337    /// Computes the observation key used for search-tree branching.
338    pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
339        observation_repr_from_stream(
340            self.config.observation_key_mode,
341            observations,
342            self.config.observation_bits,
343        )
344    }
345
346    /// Explicitly updates the world model with an action.
347    pub fn model_update_action_external(&mut self, action: Action) {
348        self.sym_buffer.clear();
349        encode(&mut self.sym_buffer, action, self.action_bits);
350
351        for &sym in &self.sym_buffer {
352            self.model.commit_update_history(sym);
353        }
354    }
355}
356
357fn build_model(config: &AgentConfig) -> Result<Box<dyn Predictor>, String> {
358    if let Some(rate_backend) = config.rate_backend.clone() {
359        let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
360        let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
361        return Ok(Box::new(predictor));
362    }
363
364    match config.algorithm.as_str() {
365        // FAC-CTW is the default and recommended CTW variant in
366        // "A Monte-Carlo AIXI Approximation".
367        "ctw" | "fac-ctw" => {
368            let obs_len = config.observation_stream_len.max(1);
369            let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
370            Ok(Box::new(FacCtwPredictor::new(
371                config.ct_depth,
372                percept_bits,
373            )))
374        }
375        // AC-CTW is the legacy single-tree variant
376        "ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
377        "rosa" => {
378            let max_order = config.rosa_max_order.unwrap_or(20);
379            Ok(Box::new(RosaPredictor::new(max_order)))
380        }
381        #[cfg(feature = "backend-rwkv")]
382        "rwkv" => {
383            if let Some(method) = config
384                .rwkv_method
385                .as_deref()
386                .map(str::trim)
387                .filter(|v| !v.is_empty())
388            {
389                let predictor = RwkvPredictor::from_method(method)
390                    .map_err(|err| format!("Invalid RWKV method for AIXI: {err}"))?;
391                Ok(Box::new(predictor))
392            } else {
393                let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
394                    "RWKV model path required when rwkv_method is not configured".to_string()
395                })?;
396                let model_arc = load_rwkv7_model_from_path(path);
397                Ok(Box::new(RwkvPredictor::new(model_arc)))
398            }
399        }
400        #[cfg(not(feature = "backend-rwkv"))]
401        "rwkv" => Err("RWKV backend disabled at compile time".to_string()),
402        #[cfg(feature = "backend-mamba")]
403        "mamba" => {
404            if let Some(method) = config
405                .mamba_method
406                .as_deref()
407                .map(str::trim)
408                .filter(|v| !v.is_empty())
409            {
410                let predictor = MambaPredictor::from_method(method)
411                    .map_err(|err| format!("Invalid Mamba method for AIXI: {err}"))?;
412                Ok(Box::new(predictor))
413            } else {
414                let path = config.mamba_model_path.as_ref().ok_or_else(|| {
415                    "Mamba model path required when mamba_method is not configured".to_string()
416                })?;
417                let model_arc = load_mamba_model_from_path(path);
418                Ok(Box::new(MambaPredictor::new(model_arc)))
419            }
420        }
421        #[cfg(not(feature = "backend-mamba"))]
422        "mamba" => Err("Mamba backend disabled at compile time".to_string()),
423        "zpaq" => {
424            let method = config
425                .zpaq_method
426                .clone()
427                .unwrap_or_else(|| "1".to_string());
428            if let Err(err) = validate_zpaq_rate_method(&method) {
429                return Err(format!("Invalid zpaq method for AIXI: {err}"));
430            }
431            Ok(Box::new(ZpaqPredictor::new(method, 2f64.powi(-24))))
432        }
433        _ => Err(format!("Unknown algorithm: {}", config.algorithm)),
434    }
435}
436
// Simulator interface used by the MCTS planner. The accessor methods simply
// expose configuration; the mutating methods drive the world model along
// imagined trajectories and undo those updates between rollouts.
impl AgentSimulator for Agent {
    fn get_num_actions(&self) -> usize {
        self.config.agent_actions
    }

    fn get_num_observation_bits(&self) -> usize {
        self.config.observation_bits
    }

    fn observation_stream_len(&self) -> usize {
        // Clamped to at least 1, matching build_model and the percept loop.
        self.config.observation_stream_len.max(1)
    }

    fn observation_key_mode(&self) -> ObservationKeyMode {
        self.config.observation_key_mode
    }

    fn get_num_reward_bits(&self) -> usize {
        self.config.reward_bits
    }

    fn horizon(&self) -> usize {
        self.config.agent_horizon
    }

    fn max_reward(&self) -> Reward {
        self.config.max_reward
    }

    fn min_reward(&self) -> Reward {
        self.config.min_reward
    }

    fn reward_offset(&self) -> i64 {
        self.config.reward_offset
    }

    fn get_explore_exploit_ratio(&self) -> f64 {
        self.config.exploration_exploitation_ratio
    }

    fn discount_gamma(&self) -> f64 {
        self.config.discount_gamma
    }

    /// Feed a simulated action into the model's history.
    ///
    /// Uses the revertible `update_history` path (contrast with
    /// `model_update_action_external`, which commits).
    fn model_update_action(&mut self, action: Action) {
        self.sym_buffer.clear();
        encode(&mut self.sym_buffer, action, self.action_bits);

        for &sym in &self.sym_buffer {
            self.model.update_history(sym);
        }
    }

    /// Sample `bits` percept symbols from the model, feeding each sampled
    /// bit back in via the revertible `update` path, and return the decoded
    /// unsigned value.
    fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
        self.sym_buffer.clear();
        for _ in 0..bits {
            let prob_1 = self.model.predict_one();
            let sym = self.rng.gen_bool(prob_1);
            self.model.update(sym);
            self.sym_buffer.push(sym);
        }
        decode(&self.sym_buffer, bits)
    }

    fn begin_simulation(&mut self) {
        // Open a rollback scope so model_revert can undo the whole rollout
        // in one shot if the predictor supports scoped rollback.
        self.model.begin_rollback_scope();
    }

    /// Sample a full percept (observation stream, then reward) from the
    /// model and return the search-key representation plus the decoded,
    /// offset-corrected reward.
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.config.observation_bits;
        let obs_len = self.config.observation_stream_len.max(1);

        self.obs_buffer.clear();
        for _ in 0..obs_len {
            let p = self.gen_percept_and_update(obs_bits);
            self.obs_buffer.push(p);
        }

        let obs_repr = observation_repr_from_stream(
            self.config.observation_key_mode,
            &self.obs_buffer,
            obs_bits,
        );
        let rew_bits = self.config.reward_bits;
        let rew_u = self.gen_percept_and_update(rew_bits);
        // Undo the offset applied by the paper-compatible reward encoding.
        let rew = (rew_u as i64) - self.config.reward_offset;

        // Reward bits come after the observation stream, matching the
        // encoding order in model_update_percept_stream.

        (obs_repr, rew)
    }

    fn gen_range(&mut self, end: usize) -> usize {
        self.rng.gen_range(end)
    }

    fn gen_f64(&mut self) -> f64 {
        self.rng.gen_f64()
    }

    /// Undo `steps` simulated interaction cycles.
    ///
    /// Fast path: if the predictor can rewind its whole rollback scope
    /// (opened in `begin_simulation`), one call restores it and `steps` is
    /// irrelevant. Otherwise each cycle's percept bits are reverted and its
    /// action bits popped from history, symbol by symbol.
    fn model_revert(&mut self, steps: usize) {
        if self.model.rollback_scope() {
            return;
        }
        let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
        let percept_bits = obs_bits + self.config.reward_bits;

        for _ in 0..steps {
            for _ in 0..percept_bits {
                self.model.revert();
            }
            for _ in 0..self.action_bits {
                self.model.pop_history();
            }
        }
    }

    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
        Box::new(self.clone_for_simulation(seed))
    }
}
559
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Per-method call tallies recorded by `InstrumentedPredictor`.
    #[derive(Clone, Default)]
    struct CallCounts {
        update: usize,
        commit_update: usize,
        update_history: usize,
        commit_update_history: usize,
        begin_scope: usize,
        rollback_scope: usize,
        revert: usize,
        pop_history: usize,
    }

    /// Predictor stub that only counts which trait methods were invoked,
    /// so tests can assert the agent routes updates through the intended
    /// committed/revertible paths.
    #[derive(Clone)]
    struct InstrumentedPredictor {
        counts: Arc<Mutex<CallCounts>>,
    }

    impl InstrumentedPredictor {
        fn new(counts: Arc<Mutex<CallCounts>>) -> Self {
            Self { counts }
        }
    }

    impl Predictor for InstrumentedPredictor {
        fn update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update += 1;
        }

        fn commit_update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update_history += 1;
        }

        fn commit_update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update_history += 1;
        }

        fn revert(&mut self) {
            self.counts.lock().unwrap().revert += 1;
        }

        fn pop_history(&mut self) {
            self.counts.lock().unwrap().pop_history += 1;
        }

        fn begin_rollback_scope(&mut self) {
            self.counts.lock().unwrap().begin_scope += 1;
        }

        // Always reports scoped rollback as handled, so per-symbol
        // revert/pop_history must not be called by the agent.
        fn rollback_scope(&mut self) -> bool {
            self.counts.lock().unwrap().rollback_scope += 1;
            true
        }

        // Fixed, symbol-dependent probability; value is arbitrary for
        // these tests.
        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "InstrumentedPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }

    /// Minimal valid config: 2 obs symbols x 2 bits + 3 reward bits per
    /// percept, 4 actions (=> 2 action bits), deterministic seed.
    fn basic_config() -> AgentConfig {
        AgentConfig {
            algorithm: "ac-ctw".to_string(),
            ct_depth: 8,
            agent_horizon: 2,
            observation_bits: 2,
            observation_stream_len: 2,
            observation_key_mode: ObservationKeyMode::FullStream,
            reward_bits: 3,
            agent_actions: 4,
            num_simulations: 2,
            exploration_exploitation_ratio: 1.0,
            discount_gamma: 0.95,
            min_reward: -2,
            max_reward: 3,
            reward_offset: 2,
            random_seed: Some(7),
            rate_backend: None,
            rate_backend_max_order: 8,
            rwkv_model_path: None,
            rwkv_method: None,
            mamba_model_path: None,
            mamba_method: None,
            rosa_max_order: None,
            zpaq_method: None,
        }
    }

    #[test]
    fn external_history_updates_use_committed_predictor_paths() {
        let mut agent = Agent::try_new(basic_config()).expect("valid agent config");
        let counts = Arc::new(Mutex::new(CallCounts::default()));
        agent.model = Box::new(InstrumentedPredictor::new(counts.clone()));

        agent.model_update_percept_stream(&[1, 2], 1);
        agent.model_update_action_external(3);

        let snapshot = counts.lock().unwrap().clone();
        // 2 observations x 2 bits + 3 reward bits = 7 committed symbols.
        assert_eq!(snapshot.commit_update, 7);
        // 4 actions => 2 action bits, committed to history.
        assert_eq!(snapshot.commit_update_history, 2);
        // Real-world updates must never use the revertible paths.
        assert_eq!(snapshot.update, 0);
        assert_eq!(snapshot.update_history, 0);
    }

    #[test]
    fn simulation_revert_prefers_predictor_scope_when_available() {
        let mut agent = Agent::try_new(basic_config()).expect("valid agent config");
        let counts = Arc::new(Mutex::new(CallCounts::default()));
        agent.model = Box::new(InstrumentedPredictor::new(counts.clone()));

        AgentSimulator::begin_simulation(&mut agent);
        agent.model_revert(3);

        let snapshot = counts.lock().unwrap().clone();
        assert_eq!(snapshot.begin_scope, 1);
        assert_eq!(snapshot.rollback_scope, 1);
        // Scoped rollback succeeded, so no per-symbol unwinding happened.
        assert_eq!(snapshot.revert, 0);
        assert_eq!(snapshot.pop_history, 0);
    }
}
694}