infotheory/aixi/
mcts.rs

//! Monte Carlo Tree Search (MCTS) for AIXI.
//!
//! This module implements the planning component of MC-AIXI. It uses the
//! Upper Confidence bounds applied to Trees (UCT) algorithm to select actions
//! by simulating future interactions with a world model.
6
7use crate::aixi::common::{
8    Action, ObservationKeyMode, PerceptVal, Reward, observation_repr_from_stream,
9};
10use rayon::prelude::*;
11use std::collections::HashMap;
12
13#[derive(Clone, Debug, Eq, PartialEq, Hash)]
14struct PerceptOutcome {
15    observations: Vec<PerceptVal>,
16    reward: Reward,
17}
18
19/// Interface for an agent that can be simulated during MCTS.
20///
21/// This trait allows the MCTS algorithm to interact with an agent
22/// (like `Agent` in `agent.rs`) to perform "imagined" actions and
23/// receive "imagined" percepts during planning.
24pub trait AgentSimulator: Send + Sync {
25    /// Returns the number of possible actions the agent can perform.
26    fn get_num_actions(&self) -> usize;
27
28    /// Returns the bit-width used to encode observations.
29    fn get_num_observation_bits(&self) -> usize;
30
31    /// Returns the number of observation symbols per action.
32    fn observation_stream_len(&self) -> usize {
33        1
34    }
35
36    /// Returns the observation key mode for search-tree branching.
37    fn observation_key_mode(&self) -> ObservationKeyMode {
38        ObservationKeyMode::FullStream
39    }
40
41    /// Returns the observation representation used for tree branching.
42    fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
43        observation_repr_from_stream(
44            self.observation_key_mode(),
45            observations,
46            self.get_num_observation_bits(),
47        )
48    }
49
50    /// Returns the bit-width used to encode rewards.
51    fn get_num_reward_bits(&self) -> usize;
52
53    /// Returns the planning horizon (depth of simulations).
54    fn horizon(&self) -> usize;
55
56    /// Returns the maximum possible reward value.
57    fn max_reward(&self) -> Reward;
58
59    /// Returns the minimum possible reward value.
60    fn min_reward(&self) -> Reward;
61
62    /// Returns the reward offset used to ensure encoded rewards are non-negative.
63    ///
64    /// Paper-compatible encoding uses unsigned reward bits and shifts rewards by an offset.
65    fn reward_offset(&self) -> i64 {
66        0
67    }
68
69    /// Returns the exploration-exploitation constant (often denoted as C).
70    fn get_explore_exploit_ratio(&self) -> f64 {
71        1.0
72    }
73
74    /// Returns the discount factor for future rewards.
75    fn discount_gamma(&self) -> f64 {
76        1.0
77    }
78
79    /// Updates the internal model state with a simulated action.
80    fn model_update_action(&mut self, action: Action);
81
82    /// Generates a simulated percept and updates the model state.
83    fn gen_percept_and_update(&mut self, bits: usize) -> u64;
84
85    /// Reverts the model state to a previous point in the simulation.
86    fn model_revert(&mut self, steps: usize);
87
88    /// Generates a random value in `[0, end)`.
89    fn gen_range(&mut self, end: usize) -> usize;
90
91    /// Generates a random `f64` in `[0, 1)`.
92    fn gen_f64(&mut self) -> f64;
93
94    /// Creates a boxed clone of this simulator for parallel search.
95    fn boxed_clone(&self) -> Box<dyn AgentSimulator> {
96        self.boxed_clone_with_seed(0)
97    }
98
99    /// Creates a boxed clone of this simulator, re-seeding any RNG state.
100    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator>;
101
102    /// Normalizes a reward value to [0, 1] based on the agent's range and horizon.
103    ///
104    /// For discounted rewards, the cumulative range is `sum_{t=0}^{h-1} gamma^t * (max - min)`.
105    /// Similarly, the minimum cumulative reward is `sum_{t=0}^{h-1} gamma^t * min`.
106    fn norm_reward(&self, reward: f64) -> f64 {
107        let min = self.min_reward() as f64;
108        let max = self.max_reward() as f64;
109        let h = self.horizon() as f64;
110        let gamma = self.discount_gamma().clamp(0.0, 1.0);
111
112        // Discounted sum factor: sum_{t=0}^{h-1} gamma^t = (1 - gamma^h) / (1 - gamma) for gamma != 1
113        let discount_sum = if (gamma - 1.0).abs() < 1e-9 {
114            h
115        } else {
116            (1.0 - gamma.powi(h as i32)) / (1.0 - gamma)
117        };
118
119        let range = (max - min) * discount_sum;
120        let min_cumulative = min * discount_sum;
121
122        if range.abs() < 1e-9 {
123            0.5
124        } else {
125            (reward - min_cumulative) / range
126        }
127    }
128
129    /// Helper to generate a percept stream, update the model, and return a search key + reward.
130    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
131        let obs_bits = self.get_num_observation_bits();
132        let obs_len = self.observation_stream_len().max(1);
133        let mut observations = Vec::with_capacity(obs_len);
134        for _ in 0..obs_len {
135            observations.push(self.gen_percept_and_update(obs_bits));
136        }
137
138        let obs_key = self.observation_repr_from_stream(&observations);
139        let rew_bits = self.get_num_reward_bits();
140        let rew_u = self.gen_percept_and_update(rew_bits);
141        let rew = (rew_u as i64) - self.reward_offset();
142        (obs_key, rew)
143    }
144}
145
146/// A node in the MCTS search tree.
147///
148/// Nodes can be either OR-nodes (representing an agent choice) or
149/// chance nodes (representing an environment response).
150#[derive(Clone)]
151pub struct SearchNode {
152    /// Number of times this node has been visited during search.
153    visits: u32,
154    /// The current mean reward estimated for this node.
155    mean: f64,
156    /// Whether this is a chance node (observation/reward) rather than an action node.
157    is_chance_node: bool,
158    /// Children indexed by action (action nodes only).
159    action_children: Vec<Option<SearchNode>>,
160    /// Children indexed by percept outcome (chance nodes only).
161    percept_children: HashMap<PerceptOutcome, SearchNode>,
162}
163
164impl SearchNode {
165    /// Creates a new `SearchNode`.
166    pub fn new(is_chance_node: bool) -> Self {
167        Self {
168            visits: 0,
169            mean: 0.0,
170            is_chance_node,
171            action_children: Vec::new(),
172            percept_children: HashMap::new(),
173        }
174    }
175
176    /// Selects the best action from this node based on accumulated mean rewards.
177    pub fn best_action(&self, agent: &mut dyn AgentSimulator) -> Action {
178        let mut best_actions = Vec::new();
179        let mut best_mean = -f64::INFINITY;
180
181        for (action, child) in self.action_children.iter().enumerate() {
182            let Some(child) = child.as_ref() else {
183                continue;
184            };
185            let mean = child.mean;
186            if mean > best_mean {
187                best_mean = mean;
188                best_actions.clear();
189                best_actions.push(action as u64);
190            } else if (mean - best_mean).abs() < 1e-9 {
191                best_actions.push(action as u64);
192            }
193        }
194
195        if best_actions.is_empty() {
196            return 0;
197        }
198
199        let idx = agent.gen_range(best_actions.len());
200        best_actions[idx] as Action
201    }
202
203    fn expectation(&self) -> f64 {
204        self.mean
205    }
206
207    fn apply_delta(&mut self, base: &SearchNode, updated: &SearchNode) {
208        if self.is_chance_node != base.is_chance_node
209            || self.is_chance_node != updated.is_chance_node
210        {
211            return;
212        }
213
214        let base_visits = base.visits as f64;
215        let updated_visits = updated.visits as f64;
216        if updated_visits < base_visits {
217            return;
218        }
219
220        let delta_visits = updated.visits - base.visits;
221        if delta_visits > 0 {
222            let base_sum = base.mean * base_visits;
223            let updated_sum = updated.mean * updated_visits;
224            let delta_sum = updated_sum - base_sum;
225            let total_visits = self.visits + delta_visits;
226            let total_sum = self.mean * (self.visits as f64) + delta_sum;
227            self.visits = total_visits;
228            self.mean = if total_visits > 0 {
229                total_sum / (total_visits as f64)
230            } else {
231                0.0
232            };
233        }
234
235        if self.is_chance_node {
236            for (key, updated_child) in &updated.percept_children {
237                if let Some(base_child) = base.percept_children.get(key) {
238                    if let Some(self_child) = self.percept_children.get_mut(key) {
239                        self_child.apply_delta(base_child, updated_child);
240                    } else {
241                        let mut child = SearchNode::new(updated_child.is_chance_node);
242                        child.apply_delta(
243                            &SearchNode::new(updated_child.is_chance_node),
244                            updated_child,
245                        );
246                        self.percept_children.insert(key.clone(), child);
247                    }
248                } else if let Some(self_child) = self.percept_children.get_mut(key) {
249                    let empty = SearchNode::new(updated_child.is_chance_node);
250                    self_child.apply_delta(&empty, updated_child);
251                } else {
252                    let mut child = SearchNode::new(updated_child.is_chance_node);
253                    child.apply_delta(
254                        &SearchNode::new(updated_child.is_chance_node),
255                        updated_child,
256                    );
257                    self.percept_children.insert(key.clone(), child);
258                }
259            }
260        } else {
261            let max_len = base
262                .action_children
263                .len()
264                .max(updated.action_children.len());
265            if self.action_children.len() < max_len {
266                self.action_children.resize_with(max_len, || None);
267            }
268            for idx in 0..max_len {
269                let base_child = base.action_children.get(idx).and_then(|c| c.as_ref());
270                let updated_child = updated.action_children.get(idx).and_then(|c| c.as_ref());
271                let Some(updated_child) = updated_child else {
272                    continue;
273                };
274                match (base_child, self.action_children.get_mut(idx)) {
275                    (Some(base_child), Some(Some(self_child))) => {
276                        self_child.apply_delta(base_child, updated_child);
277                    }
278                    (Some(base_child), Some(slot @ None)) => {
279                        let mut child = SearchNode::new(updated_child.is_chance_node);
280                        child.apply_delta(base_child, updated_child);
281                        *slot = Some(child);
282                    }
283                    (None, Some(Some(self_child))) => {
284                        let empty = SearchNode::new(updated_child.is_chance_node);
285                        self_child.apply_delta(&empty, updated_child);
286                    }
287                    (None, Some(slot @ None)) => {
288                        let mut child = SearchNode::new(updated_child.is_chance_node);
289                        child.apply_delta(
290                            &SearchNode::new(updated_child.is_chance_node),
291                            updated_child,
292                        );
293                        *slot = Some(child);
294                    }
295                    _ => {}
296                }
297            }
298        }
299    }
300
301    /// Selects an action to explore, potentially creating a new child node.
302    fn select_action(
303        &mut self,
304        agent: &mut dyn AgentSimulator,
305        _horizon: usize,
306    ) -> (&mut SearchNode, Action) {
307        let num_actions = agent.get_num_actions();
308
309        if self.action_children.len() < num_actions {
310            self.action_children.resize_with(num_actions, || None);
311        }
312
313        let mut unvisited = Vec::new();
314        for a in 0..num_actions {
315            if self.action_children[a].is_none() {
316                unvisited.push(a as u64);
317            }
318        }
319
320        let action;
321        if !unvisited.is_empty() {
322            let idx = agent.gen_range(unvisited.len());
323            action = unvisited[idx];
324            self.action_children[action as usize] = Some(SearchNode::new(true));
325        } else {
326            // UCT Formula: exploit + explore
327            let c = agent.get_explore_exploit_ratio();
328            let mut best_val = -f64::INFINITY;
329            let mut best_action = 0;
330            let log_visits = (self.visits as f64).ln().max(0.0);
331            for (a, child) in self.action_children.iter().enumerate() {
332                let Some(child) = child.as_ref() else {
333                    continue;
334                };
335                let exploit = agent.norm_reward(child.expectation());
336                let explore = if child.visits > 0 {
337                    c * (log_visits / child.visits as f64).sqrt()
338                } else {
339                    f64::INFINITY
340                };
341                let val = exploit + explore;
342                if val > best_val {
343                    best_val = val;
344                    best_action = a as u64;
345                }
346            }
347            action = best_action;
348        }
349
350        agent.model_update_action(action as Action);
351        (
352            self.action_children[action as usize]
353                .as_mut()
354                .expect("missing action child"),
355            action as Action,
356        )
357    }
358
359    /// Performs a single simulation (sample) from this node.
360    pub fn sample(
361        &mut self,
362        agent: &mut dyn AgentSimulator,
363        horizon: usize,
364        total_horizon: usize,
365    ) -> f64 {
366        if horizon == 0 {
367            agent.model_revert(total_horizon);
368            return 0.0;
369        }
370
371        let reward;
372        if self.is_chance_node {
373            let (obs, rew) = agent.gen_percepts_and_update();
374            let key = PerceptOutcome {
375                observations: obs,
376                reward: rew,
377            };
378            let child = self
379                .percept_children
380                .entry(key)
381                .or_insert_with(|| SearchNode::new(false));
382            reward = (rew as f64)
383                + agent.discount_gamma() * child.sample(agent, horizon - 1, total_horizon);
384        } else if self.visits == 0 {
385            reward = Self::playout(agent, horizon, total_horizon);
386        } else {
387            let (child, _act) = self.select_action(agent, horizon);
388            reward = child.sample(agent, horizon, total_horizon);
389        }
390
391        // Update mean logic:
392        self.mean = (reward + (self.visits as f64) * self.mean) / ((self.visits + 1) as f64);
393        self.visits += 1;
394
395        reward
396    }
397
398    /// Performs a randomized simulation until the horizon is reached.
399    fn playout(agent: &mut dyn AgentSimulator, horizon: usize, total_horizon: usize) -> f64 {
400        let mut total_rew = 0.0;
401        let num_actions = agent.get_num_actions();
402        let gamma = agent.discount_gamma().clamp(0.0, 1.0);
403        let mut discount = 1.0;
404
405        for _ in 0..horizon {
406            let act = agent.gen_range(num_actions);
407            agent.model_update_action(act as Action);
408            let (_key, rew) = agent.gen_percepts_and_update();
409            total_rew += discount * (rew as f64);
410            discount *= gamma;
411        }
412
413        agent.model_revert(total_horizon);
414        total_rew
415    }
416}
417
418/// Manages the MCTS tree and provides the `search` entry point.
419pub struct SearchTree {
420    root: Option<SearchNode>,
421}
422
423impl SearchTree {
424    /// Creates a new `SearchTree`.
425    pub fn new() -> Self {
426        Self {
427            root: Some(SearchNode::new(false)),
428        }
429    }
430
431    /// Performs several MCTS simulations to find the best next action.
432    pub fn search(
433        &mut self,
434        agent: &mut dyn AgentSimulator,
435        prev_obs_stream: &[PerceptVal],
436        prev_rew: Reward,
437        prev_act: u64,
438        samples: usize,
439    ) -> Action {
440        self.prune_tree(agent, prev_obs_stream, prev_rew, prev_act);
441
442        let root = self.root.as_mut().unwrap();
443        let h = agent.horizon();
444        let threads = rayon::current_num_threads().max(1);
445        if samples < 2 || threads < 2 {
446            for _ in 0..samples {
447                root.sample(agent, h, h);
448            }
449            return root.best_action(agent);
450        }
451
452        let workers = threads.min(samples);
453        let base = samples / workers;
454        let extra = samples % workers;
455        let snapshot = root.clone();
456
457        let mut agents = Vec::with_capacity(workers);
458        for i in 0..workers {
459            let seed = agent.gen_f64().to_bits() ^ (i as u64);
460            agents.push(agent.boxed_clone_with_seed(seed));
461        }
462
463        let results: Vec<SearchNode> = agents
464            .into_par_iter()
465            .enumerate()
466            .map(|(i, mut local_agent)| {
467                let mut local_root = snapshot.clone();
468                let iterations = base + usize::from(i < extra);
469                for _ in 0..iterations {
470                    local_root.sample(local_agent.as_mut(), h, h);
471                }
472                local_root
473            })
474            .collect();
475
476        for local in &results {
477            root.apply_delta(&snapshot, local);
478        }
479
480        root.best_action(agent)
481    }
482
483    /// Prunes the tree, keeping only relevant subtrees based on the previous interaction.
484    fn prune_tree(
485        &mut self,
486        agent: &mut dyn AgentSimulator,
487        prev_obs_stream: &[PerceptVal],
488        prev_rew: Reward,
489        prev_act: u64,
490    ) {
491        if self.root.is_none() {
492            self.root = Some(SearchNode::new(false));
493            return;
494        }
495
496        let mut old_root = self.root.take().unwrap();
497
498        // Find chance child (prev_act)
499        let action_child_opt = if old_root.action_children.len() > prev_act as usize {
500            old_root.action_children[prev_act as usize].take()
501        } else {
502            None
503        };
504
505        if let Some(mut chance_child) = action_child_opt {
506            let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
507            let key = PerceptOutcome {
508                observations: obs_repr,
509                reward: prev_rew,
510            };
511
512            if let Some(action_child) = chance_child.percept_children.remove(&key) {
513                self.root = Some(action_child);
514            } else {
515                self.root = Some(SearchNode::new(false));
516            }
517        } else {
518            self.root = Some(SearchNode::new(false));
519        }
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aixi::common::ObservationKeyMode;

    /// Deterministic simulator stub.
    ///
    /// Every percept and random draw is `0`, and model updates and reverts are no-ops.
    #[derive(Clone)]
    struct DummyAgent {
        obs_bits: usize,
        rew_bits: usize,
        horizon: usize,
        min_reward: Reward,
        max_reward: Reward,
        key_mode: ObservationKeyMode,
    }

    impl DummyAgent {
        /// Stub with 8 reward bits, horizon 5, rewards in `[-1, 1]` and the given
        /// observation width and key mode.
        fn new(obs_bits: usize, key_mode: ObservationKeyMode) -> Self {
            DummyAgent {
                obs_bits,
                key_mode,
                rew_bits: 8,
                horizon: 5,
                min_reward: -1,
                max_reward: 1,
            }
        }
    }

    impl AgentSimulator for DummyAgent {
        fn get_num_actions(&self) -> usize {
            4
        }

        fn get_num_observation_bits(&self) -> usize {
            self.obs_bits
        }

        fn observation_key_mode(&self) -> ObservationKeyMode {
            self.key_mode
        }

        fn get_num_reward_bits(&self) -> usize {
            self.rew_bits
        }

        fn horizon(&self) -> usize {
            self.horizon
        }

        fn min_reward(&self) -> Reward {
            self.min_reward
        }

        fn max_reward(&self) -> Reward {
            self.max_reward
        }

        fn model_update_action(&mut self, _action: Action) {}

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn gen_range(&mut self, _end: usize) -> usize {
            0
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

    /// Builds a tree whose root has one expanded action, `prev_act`.
    ///
    /// That action's chance node holds a single action-node child, keyed by
    /// (`prev_obs_stream`, `prev_rew`), which carries the given mean and visit count.
    fn build_tree_with_key(
        agent: &DummyAgent,
        prev_act: u64,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        kept_mean: f64,
        kept_visits: u32,
    ) -> SearchTree {
        let slot = prev_act as usize;

        let mut kept = SearchNode::new(false);
        kept.visits = kept_visits;
        kept.mean = kept_mean;

        let mut chance_child = SearchNode::new(true);
        chance_child.percept_children.insert(
            PerceptOutcome {
                observations: agent.observation_repr_from_stream(prev_obs_stream),
                reward: prev_rew,
            },
            kept,
        );

        let mut old_root = SearchNode::new(false);
        old_root.action_children.resize_with(slot + 1, || None);
        old_root.action_children[slot] = Some(chance_child);

        SearchTree {
            root: Some(old_root),
        }
    }

    /// Asserts that pruning left a brand-new, unvisited action-node root.
    fn assert_fresh_root(tree: &SearchTree) {
        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    #[test]
    fn prune_tree_keeps_matching_subtree() {
        let (prev_act, prev_rew): (u64, Reward) = (2, 3);
        let prev_obs_stream = [9u64, 2, 7];
        let mut agent = DummyAgent::new(3, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew, 123.0, 7);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().expect("root should exist");
        assert!(!root.is_chance_node);
        assert_eq!(root.mean, 123.0);
        assert_eq!(root.visits, 7);
    }

    #[test]
    fn prune_tree_resets_when_action_missing() {
        let (prev_act, prev_rew): (u64, Reward) = (10, 0);
        let prev_obs_stream = [1u64];
        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        let mut tree = SearchTree::new();

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        assert_fresh_root(&tree);
    }

    #[test]
    fn prune_tree_resets_when_percept_key_missing() {
        let (prev_act, prev_rew): (u64, Reward) = (0, 1);
        let prev_obs_stream = [1u64, 2];
        let mut agent = DummyAgent::new(4, ObservationKeyMode::Last);

        // The stored subtree is keyed on a different reward, so the lookup must miss.
        let mut tree =
            build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew + 1, 9.0, 2);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        assert_fresh_root(&tree);
    }
}