infotheory/aixi/
mcts.rs

//! Monte Carlo Tree Search (MCTS) for AIXI.
//!
//! This module implements the planning component of MC-AIXI. It uses an upper
//! confidence bounds applied to trees (UCT) approach to select actions
//! by simulating future interactions with a world model.

7use crate::aixi::common::{
8    Action, ObservationKeyMode, PerceptVal, Reward, observation_repr_from_stream,
9};
10use rayon::prelude::*;
11use std::collections::HashMap;
12
/// Hash key for a sampled percept outcome at a chance node.
///
/// Both the observation representation and the immediate reward are required
/// to identify the correct continuation subtree for generic environments.
/// Some environments can emit the same observation alongside different
/// rewards, so observation-only keys would incorrectly merge distinct
/// successor states during search-tree reuse.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct PerceptOutcome {
    /// Observation symbols used for chance-node branching.
    ///
    /// Stored as `Box<[_]>` rather than `Vec<_>`: the length is fixed once the
    /// key is built, and the boxed slice drops the spare-capacity word.
    observations: Box<[PerceptVal]>,
    /// Immediate reward observed on the sampled edge.
    reward: Reward,
}
27
28impl PerceptOutcome {
29    /// Creates a compact percept key from an observation stream and reward.
30    fn new(observations: Vec<PerceptVal>, reward: Reward) -> Self {
31        Self {
32            observations: observations.into_boxed_slice(),
33            reward,
34        }
35    }
36}
37
38/// Interface for an agent that can be simulated during MCTS.
39///
40/// This trait allows the MCTS algorithm to interact with an agent
41/// (like `Agent` in `agent.rs`) to perform "imagined" actions and
42/// receive "imagined" percepts during planning.
43pub trait AgentSimulator: Send {
44    /// Returns the number of possible actions the agent can perform.
45    fn get_num_actions(&self) -> usize;
46
47    /// Returns the bit-width used to encode observations.
48    fn get_num_observation_bits(&self) -> usize;
49
50    /// Returns the number of observation symbols per action.
51    fn observation_stream_len(&self) -> usize {
52        1
53    }
54
55    /// Returns the observation key mode for search-tree branching.
56    fn observation_key_mode(&self) -> ObservationKeyMode {
57        ObservationKeyMode::FullStream
58    }
59
60    /// Returns the observation representation used for tree branching.
61    fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
62        observation_repr_from_stream(
63            self.observation_key_mode(),
64            observations,
65            self.get_num_observation_bits(),
66        )
67    }
68
69    /// Returns the bit-width used to encode rewards.
70    fn get_num_reward_bits(&self) -> usize;
71
72    /// Returns the planning horizon (depth of simulations).
73    fn horizon(&self) -> usize;
74
75    /// Returns the maximum possible reward value.
76    fn max_reward(&self) -> Reward;
77
78    /// Returns the minimum possible reward value.
79    fn min_reward(&self) -> Reward;
80
81    /// Returns the reward offset used to ensure encoded rewards are non-negative.
82    ///
83    /// Paper-compatible encoding uses unsigned reward bits and shifts rewards by an offset.
84    fn reward_offset(&self) -> i64 {
85        0
86    }
87
88    /// Returns the exploration-exploitation constant (often denoted as C).
89    fn get_explore_exploit_ratio(&self) -> f64 {
90        1.0
91    }
92
93    /// Returns the discount factor for future rewards.
94    fn discount_gamma(&self) -> f64 {
95        1.0
96    }
97
98    /// Updates the internal model state with a simulated action.
99    fn model_update_action(&mut self, action: Action);
100
101    /// Generates a simulated percept and updates the model state.
102    fn gen_percept_and_update(&mut self, bits: usize) -> u64;
103
104    /// Reverts the model state to a previous point in the simulation.
105    fn model_revert(&mut self, steps: usize);
106
107    /// Generates a random value in `[0, end)`.
108    fn gen_range(&mut self, end: usize) -> usize;
109
110    /// Generates a random `f64` in `[0, 1)`.
111    fn gen_f64(&mut self) -> f64;
112
113    /// Creates a boxed clone of this simulator for parallel search.
114    fn boxed_clone(&self) -> Box<dyn AgentSimulator> {
115        self.boxed_clone_with_seed(0)
116    }
117
118    /// Creates a boxed clone of this simulator, re-seeding any RNG state.
119    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator>;
120
121    /// Normalizes a reward value to [0, 1] based on the agent's range and horizon.
122    ///
123    /// For discounted rewards, the cumulative range is `sum_{t=0}^{h-1} gamma^t * (max - min)`.
124    /// Similarly, the minimum cumulative reward is `sum_{t=0}^{h-1} gamma^t * min`.
125    fn norm_reward(&self, reward: f64) -> f64 {
126        let min = self.min_reward() as f64;
127        let max = self.max_reward() as f64;
128        let h = self.horizon() as f64;
129        let gamma = self.discount_gamma().clamp(0.0, 1.0);
130
131        // Discounted sum factor: sum_{t=0}^{h-1} gamma^t = (1 - gamma^h) / (1 - gamma) for gamma != 1
132        let discount_sum = if (gamma - 1.0).abs() < 1e-9 {
133            h
134        } else {
135            (1.0 - gamma.powi(h as i32)) / (1.0 - gamma)
136        };
137
138        let range = (max - min) * discount_sum;
139        let min_cumulative = min * discount_sum;
140
141        if range.abs() < 1e-9 {
142            0.5
143        } else {
144            (reward - min_cumulative) / range
145        }
146    }
147
148    /// Helper to generate a percept stream, update the model, and return a search key + reward.
149    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
150        let obs_bits = self.get_num_observation_bits();
151        let obs_len = self.observation_stream_len().max(1);
152        let mut observations = Vec::with_capacity(obs_len);
153        for _ in 0..obs_len {
154            observations.push(self.gen_percept_and_update(obs_bits));
155        }
156
157        let obs_key = self.observation_repr_from_stream(&observations);
158        let rew_bits = self.get_num_reward_bits();
159        let rew_u = self.gen_percept_and_update(rew_bits);
160        let rew = (rew_u as i64) - self.reward_offset();
161        (obs_key, rew)
162    }
163}
164
/// A node in the MCTS search tree.
///
/// Nodes alternate between OR-nodes (representing an agent choice) and
/// chance nodes (representing an environment response).
#[derive(Clone)]
pub struct SearchNode {
    /// Number of times this node has been visited during search.
    visits: u32,
    /// Running mean of the sampled returns observed through this node.
    mean: f64,
    /// Whether this is a chance node (observation/reward) rather than an action node.
    is_chance_node: bool,
    /// Children indexed by action (action nodes only); `None` = not yet expanded.
    action_children: Vec<Option<SearchNode>>,
    /// Children indexed by percept outcome (chance nodes only).
    percept_children: HashMap<PerceptOutcome, SearchNode>,
}
182
impl SearchNode {
    /// Creates a new, unvisited `SearchNode` of the given kind.
    pub fn new(is_chance_node: bool) -> Self {
        Self {
            visits: 0,
            mean: 0.0,
            is_chance_node,
            action_children: Vec::new(),
            percept_children: HashMap::new(),
        }
    }

    /// Selects the best action from this node based on accumulated mean rewards.
    ///
    /// Actions whose means are within 1e-9 of the best are treated as tied and
    /// broken uniformly at random via the agent's RNG. With no expanded
    /// children (e.g. zero samples were run) action 0 is returned.
    pub fn best_action(&self, agent: &mut dyn AgentSimulator) -> Action {
        let mut best_actions = Vec::new();
        let mut best_mean = -f64::INFINITY;

        for (action, child) in self.action_children.iter().enumerate() {
            let Some(child) = child.as_ref() else {
                continue;
            };
            let mean = child.mean;
            if mean > best_mean {
                // Strictly better: restart the tie list.
                best_mean = mean;
                best_actions.clear();
                best_actions.push(action as u64);
            } else if (mean - best_mean).abs() < 1e-9 {
                // Within tolerance of the current best: candidate for tie-break.
                best_actions.push(action as u64);
            }
        }

        if best_actions.is_empty() {
            return 0;
        }

        let idx = agent.gen_range(best_actions.len());
        best_actions[idx] as Action
    }

    // Value estimate fed into the UCB priority in `select_action`.
    fn expectation(&self) -> f64 {
        self.mean
    }

    // Folds the statistics gained between `base` and `updated` into `self`.
    //
    // Used to merge parallel search results: each worker samples into a
    // private clone of the shared snapshot (`base`), producing `updated`;
    // only the *delta* (new visits and their reward sum) is added to `self`,
    // so multiple workers' contributions accumulate without double-counting
    // the snapshot's own statistics.
    fn apply_delta(&mut self, base: &SearchNode, updated: &SearchNode) {
        // Node kinds must agree across all three trees; a mismatch means the
        // structures diverged and merging would corrupt statistics.
        if self.is_chance_node != base.is_chance_node
            || self.is_chance_node != updated.is_chance_node
        {
            return;
        }

        let base_visits = base.visits as f64;
        let updated_visits = updated.visits as f64;
        if updated_visits < base_visits {
            // A worker can only gain visits; bail out on inconsistent input
            // (also keeps the u32 subtraction below from underflowing).
            return;
        }

        let delta_visits = updated.visits - base.visits;
        if delta_visits > 0 {
            // Merge via the underlying reward sums so the result is the exact
            // pooled mean, not an average of averages.
            let base_sum = base.mean * base_visits;
            let updated_sum = updated.mean * updated_visits;
            let delta_sum = updated_sum - base_sum;
            let total_visits = self.visits + delta_visits;
            let total_sum = self.mean * (self.visits as f64) + delta_sum;
            self.visits = total_visits;
            self.mean = if total_visits > 0 {
                total_sum / (total_visits as f64)
            } else {
                0.0
            };
        }

        if self.is_chance_node {
            // Recurse into every percept child the worker's tree contains.
            for (key, updated_child) in &updated.percept_children {
                if let Some(base_child) = base.percept_children.get(key) {
                    if let Some(self_child) = self.percept_children.get_mut(key) {
                        // Common case: all three trees have the child.
                        self_child.apply_delta(base_child, updated_child);
                    } else {
                        // Defensive: snapshot had the child but self lost it;
                        // rebuild it from the worker's full statistics.
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        self.percept_children.insert(key.clone(), child);
                    }
                } else if let Some(self_child) = self.percept_children.get_mut(key) {
                    // Child is new relative to the snapshot but another worker
                    // already created it here: merge against an empty baseline.
                    let empty = SearchNode::new(updated_child.is_chance_node);
                    self_child.apply_delta(&empty, updated_child);
                } else {
                    // Entirely new child: adopt the worker's statistics wholesale.
                    let mut child = SearchNode::new(updated_child.is_chance_node);
                    child.apply_delta(
                        &SearchNode::new(updated_child.is_chance_node),
                        updated_child,
                    );
                    self.percept_children.insert(key.clone(), child);
                }
            }
        } else {
            // Action nodes: same merge, but children live in a dense vector.
            let max_len = base
                .action_children
                .len()
                .max(updated.action_children.len());
            if self.action_children.len() < max_len {
                self.action_children.resize_with(max_len, || None);
            }
            for idx in 0..max_len {
                let base_child = base.action_children.get(idx).and_then(|c| c.as_ref());
                let updated_child = updated.action_children.get(idx).and_then(|c| c.as_ref());
                let Some(updated_child) = updated_child else {
                    // Worker never expanded this action: nothing to merge.
                    continue;
                };
                // The four cases mirror the chance-node branches above.
                match (base_child, self.action_children.get_mut(idx)) {
                    (Some(base_child), Some(Some(self_child))) => {
                        self_child.apply_delta(base_child, updated_child);
                    }
                    (Some(base_child), Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(base_child, updated_child);
                        *slot = Some(child);
                    }
                    (None, Some(Some(self_child))) => {
                        let empty = SearchNode::new(updated_child.is_chance_node);
                        self_child.apply_delta(&empty, updated_child);
                    }
                    (None, Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        *slot = Some(child);
                    }
                    _ => {}
                }
            }
        }
    }

    /// Selects an action to explore, potentially creating a new child node.
    ///
    /// Unvisited actions are tried first (uniformly at random); once all
    /// actions have children, UCB selection takes over. Also advances the
    /// agent's model with the chosen action.
    fn select_action(&mut self, agent: &mut dyn AgentSimulator) -> (&mut SearchNode, Action) {
        let num_actions = agent.get_num_actions();

        if self.action_children.len() < num_actions {
            self.action_children.resize_with(num_actions, || None);
        }

        let mut unvisited = Vec::new();
        for a in 0..num_actions {
            if self.action_children[a].is_none() {
                unvisited.push(a as u64);
            }
        }

        let action;
        if !unvisited.is_empty() {
            // Expansion phase: pick one of the untried actions at random and
            // create its chance-node child.
            let idx = agent.gen_range(unvisited.len());
            action = unvisited[idx];
            self.action_children[action as usize] = Some(SearchNode::new(true));
        } else {
            // Match reference MC-AIXI UCB scaling:
            // priority = E[return] + horizon*max_reward*sqrt(C*log(N)/n)
            let c = agent.get_explore_exploit_ratio().max(0.0);
            let explore_bias = (agent.horizon() as f64) * (agent.max_reward() as f64).max(0.0);
            let mut best_val = -f64::INFINITY;
            let mut best_action = 0;
            let log_visits = (self.visits as f64).ln().max(0.0);
            for (a, child) in self.action_children.iter().enumerate() {
                let Some(child) = child.as_ref() else {
                    continue;
                };
                // Every existing child has been visited at least once, so
                // `nvisits` is never zero here.
                let nvisits = child.visits as f64;
                let val = child.expectation() + explore_bias * ((c * log_visits) / nvisits).sqrt();
                // Keep random tie-break behavior from reference implementations.
                if val > best_val + agent.gen_f64() * 0.001 {
                    best_val = val;
                    best_action = a as u64;
                }
            }
            action = best_action;
        }

        agent.model_update_action(action as Action);
        (
            self.action_children[action as usize]
                .as_mut()
                .expect("missing action child"),
            action as Action,
        )
    }

    /// Performs a single simulation (sample) from this node and returns the
    /// (discounted) return observed below it.
    ///
    /// `horizon` is the number of remaining action/percept cycles; it is
    /// decremented once per chance-node level. When it hits 0 — or after a
    /// playout — the model is rewound by `total_horizon` cycles, undoing the
    /// whole simulated trajectory in one call.
    pub fn sample(
        &mut self,
        agent: &mut dyn AgentSimulator,
        horizon: usize,
        total_horizon: usize,
    ) -> f64 {
        if horizon == 0 {
            agent.model_revert(total_horizon);
            return 0.0;
        }

        let reward;
        if self.is_chance_node {
            // Sample a percept from the model and descend into (or create)
            // the matching (observation, reward)-keyed subtree.
            let (obs, rew) = agent.gen_percepts_and_update();
            let key = PerceptOutcome::new(obs, rew);
            let child = self
                .percept_children
                .entry(key)
                .or_insert_with(|| SearchNode::new(false));
            reward = (rew as f64)
                + agent.discount_gamma() * child.sample(agent, horizon - 1, total_horizon);
        } else if self.visits == 0 {
            // First visit to an action node: estimate with a random rollout.
            reward = Self::playout(agent, horizon, total_horizon);
        } else {
            // Otherwise pick an action (UCB / expansion) and recurse; the
            // horizon is decremented at the chance level, not here.
            let (child, _act) = self.select_action(agent);
            reward = child.sample(agent, horizon, total_horizon);
        }

        // Incremental mean update, then count this visit.
        self.mean = (reward + (self.visits as f64) * self.mean) / ((self.visits + 1) as f64);
        self.visits += 1;

        reward
    }

    /// Performs a uniformly random rollout until the horizon is reached,
    /// returning the discounted sum of sampled rewards. Rewinds the model by
    /// `total_horizon` cycles before returning.
    fn playout(agent: &mut dyn AgentSimulator, horizon: usize, total_horizon: usize) -> f64 {
        let mut total_rew = 0.0;
        let num_actions = agent.get_num_actions();
        let gamma = agent.discount_gamma().clamp(0.0, 1.0);
        let mut discount = 1.0;

        for _ in 0..horizon {
            let act = agent.gen_range(num_actions);
            agent.model_update_action(act as Action);
            let (_key, rew) = agent.gen_percepts_and_update();
            total_rew += discount * (rew as f64);
            discount *= gamma;
        }

        agent.model_revert(total_horizon);
        total_rew
    }
}
427
/// Manages the MCTS tree and provides the `search` entry point.
pub struct SearchTree {
    /// Current root of the search tree; `None` only transiently while pruning.
    root: Option<SearchNode>,
}
432
433impl SearchTree {
434    /// Creates a new `SearchTree`.
435    pub fn new() -> Self {
436        Self {
437            root: Some(SearchNode::new(false)),
438        }
439    }
440
441    /// Performs several MCTS simulations to find the best next action.
442    pub fn search(
443        &mut self,
444        agent: &mut dyn AgentSimulator,
445        prev_obs_stream: &[PerceptVal],
446        prev_rew: Reward,
447        prev_act: u64,
448        samples: usize,
449    ) -> Action {
450        self.prune_tree(agent, prev_obs_stream, prev_rew, prev_act);
451
452        let root = self.root.as_mut().unwrap();
453        let h = agent.horizon();
454        let threads = rayon::current_num_threads().max(1);
455        if samples < 2 || threads < 2 {
456            for _ in 0..samples {
457                root.sample(agent, h, h);
458            }
459            return root.best_action(agent);
460        }
461
462        let workers = threads.min(samples);
463        let base = samples / workers;
464        let extra = samples % workers;
465        let snapshot = root.clone();
466
467        let mut agents = Vec::with_capacity(workers);
468        for i in 0..workers {
469            let seed = agent.gen_f64().to_bits() ^ (i as u64);
470            agents.push(agent.boxed_clone_with_seed(seed));
471        }
472
473        let results: Vec<SearchNode> = agents
474            .into_par_iter()
475            .enumerate()
476            .map(|(i, mut local_agent)| {
477                let mut local_root = snapshot.clone();
478                let iterations = base + usize::from(i < extra);
479                for _ in 0..iterations {
480                    local_root.sample(local_agent.as_mut(), h, h);
481                }
482                local_root
483            })
484            .collect();
485
486        for local in &results {
487            root.apply_delta(&snapshot, local);
488        }
489
490        root.best_action(agent)
491    }
492
493    /// Prunes the tree, keeping only relevant subtrees based on the previous interaction.
494    fn prune_tree(
495        &mut self,
496        agent: &mut dyn AgentSimulator,
497        prev_obs_stream: &[PerceptVal],
498        prev_rew: Reward,
499        prev_act: u64,
500    ) {
501        if self.root.is_none() {
502            self.root = Some(SearchNode::new(false));
503            return;
504        }
505
506        let mut old_root = self.root.take().unwrap();
507
508        // Find chance child (prev_act)
509        let action_child_opt = if old_root.action_children.len() > prev_act as usize {
510            old_root.action_children[prev_act as usize].take()
511        } else {
512            None
513        };
514
515        if let Some(mut chance_child) = action_child_opt {
516            let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
517            let key = PerceptOutcome::new(obs_repr, prev_rew);
518
519            if let Some(action_child) = chance_child.percept_children.remove(&key) {
520                self.root = Some(action_child);
521            } else {
522                self.root = Some(SearchNode::new(false));
523            }
524        } else {
525            self.root = Some(SearchNode::new(false));
526        }
527    }
528}
529
530impl Default for SearchTree {
531    fn default() -> Self {
532        Self::new()
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aixi::common::ObservationKeyMode;

    /// Minimal deterministic simulator: every percept symbol and every RNG
    /// draw is 0, so tree operations are fully reproducible.
    #[derive(Clone)]
    struct DummyAgent {
        obs_bits: usize,
        rew_bits: usize,
        horizon: usize,
        min_reward: Reward,
        max_reward: Reward,
        key_mode: ObservationKeyMode,
    }

    impl DummyAgent {
        /// Builds a dummy agent with fixed reward width/horizon/range and the
        /// given observation bit-width and key mode.
        fn new(obs_bits: usize, key_mode: ObservationKeyMode) -> Self {
            Self {
                obs_bits,
                rew_bits: 8,
                horizon: 5,
                min_reward: -1,
                max_reward: 1,
                key_mode,
            }
        }
    }

    impl AgentSimulator for DummyAgent {
        fn get_num_actions(&self) -> usize {
            4
        }

        fn get_num_observation_bits(&self) -> usize {
            self.obs_bits
        }

        fn observation_key_mode(&self) -> ObservationKeyMode {
            self.key_mode
        }

        fn get_num_reward_bits(&self) -> usize {
            self.rew_bits
        }

        fn horizon(&self) -> usize {
            self.horizon
        }

        fn max_reward(&self) -> Reward {
            self.max_reward
        }

        fn min_reward(&self) -> Reward {
            self.min_reward
        }

        // Model updates and reverts are no-ops: these tests exercise only the
        // tree structure, not the world model.
        fn model_update_action(&mut self, _action: Action) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_range(&mut self, _end: usize) -> usize {
            0
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

    /// Builds a tree of shape root -> chance(prev_act) -> kept(key), where
    /// `key` is derived from `prev_obs_stream`/`prev_rew` using the agent's
    /// key mode, and `kept` carries recognizable statistics.
    fn build_tree_with_key(
        agent: &DummyAgent,
        prev_act: u64,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        kept_mean: f64,
        kept_visits: u32,
    ) -> SearchTree {
        let mut old_root = SearchNode::new(false);
        old_root.action_children.resize(prev_act as usize + 1, None);

        let mut chance_child = SearchNode::new(true);
        let mut kept = SearchNode::new(false);
        kept.mean = kept_mean;
        kept.visits = kept_visits;

        let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
        let key = PerceptOutcome::new(obs_repr, prev_rew);
        chance_child.percept_children.insert(key, kept);

        old_root.action_children[prev_act as usize] = Some(chance_child);
        SearchTree {
            root: Some(old_root),
        }
    }

    // Pruning with the exact (action, observation, reward) that was stored
    // must promote the kept subtree to root, statistics intact.
    #[test]
    fn prune_tree_keeps_matching_subtree() {
        let prev_act = 2u64;
        let prev_obs_stream = vec![9u64, 2u64, 7u64];
        let prev_rew: Reward = 3;

        let mut agent = DummyAgent::new(3, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew, 123.0, 7);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().expect("root should exist");
        assert!(!root.is_chance_node);
        assert_eq!(root.mean, 123.0);
        assert_eq!(root.visits, 7);
    }

    // An action index beyond the stored children must reset to a fresh root.
    #[test]
    fn prune_tree_resets_when_action_missing() {
        let prev_act = 10u64;
        let prev_obs_stream = vec![1u64];
        let prev_rew: Reward = 0;

        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        let mut tree = SearchTree::new();

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    // An observation key the chance node never stored must reset to a fresh root.
    #[test]
    fn prune_tree_resets_when_percept_key_missing() {
        let prev_act = 0u64;
        let prev_obs_stream = vec![1u64, 2u64];
        let prev_rew: Reward = 1;

        let mut agent = DummyAgent::new(4, ObservationKeyMode::Last);

        // Build tree keyed on a different observation key so pruning misses it.
        let mut tree = build_tree_with_key(&agent, prev_act, &[9u64], prev_rew, 9.0, 2);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    // Same observation but a different reward must NOT match: the percept key
    // includes the reward (see `PerceptOutcome`), so pruning resets the tree.
    #[test]
    fn prune_tree_resets_when_reward_mismatch_shares_observation_key() {
        let prev_act = 1u64;
        let prev_obs_stream = vec![4u64, 5u64];
        let kept_rew: Reward = -2;
        let requested_rew: Reward = 2;

        let mut agent = DummyAgent::new(6, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, kept_rew, 77.0, 11);

        tree.prune_tree(&mut agent, &prev_obs_stream, requested_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }
}
710}