
infotheory/aixi/mcts.rs

//! Monte Carlo Tree Search (MCTS) for AIXI.
//!
//! This module implements the planning component of MC-AIXI. It uses upper
//! confidence bounds applied to trees (UCT) to select actions by simulating
//! future interactions with a world model.
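//!
//! A minimal usage sketch (the `agent`, `prev_obs`, `prev_rew`, and
//! `prev_act` values are hypothetical stand-ins for an [`AgentSimulator`]
//! implementation and the previous interaction):
//!
//! ```ignore
//! let mut tree = SearchTree::new();
//! // One planning step: reuse the subtree reached by the previous
//! // (action, observation, reward) triple, then run 500 rollouts.
//! let action = tree.search(&mut agent, &prev_obs, prev_rew, prev_act, 500);
//! ```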

use crate::aixi::common::{
    Action, ObservationKeyMode, PerceptVal, Reward, observation_repr_from_stream,
};
use rayon::prelude::*;
use std::collections::HashMap;

/// Hash key for a sampled percept outcome at a chance node.
///
/// Both the observation representation and the immediate reward are required
/// to identify the correct continuation subtree for generic environments.
/// Some environments can emit the same observation alongside different
/// rewards, so observation-only keys would incorrectly merge distinct
/// successor states during search-tree reuse.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct PerceptOutcome {
    /// Observation symbols used for chance-node branching.
    observations: Box<[PerceptVal]>,
    /// Immediate reward observed on the sampled edge.
    reward: Reward,
}

impl PerceptOutcome {
    /// Creates a compact percept key from an observation stream and reward.
    fn new(observations: Vec<PerceptVal>, reward: Reward) -> Self {
        Self {
            observations: observations.into_boxed_slice(),
            reward,
        }
    }
}

/// Interface for an agent that can be simulated during MCTS.
///
/// This trait allows the MCTS algorithm to interact with an agent
/// (like `Agent` in `agent.rs`) to perform "imagined" actions and
/// receive "imagined" percepts during planning.
pub trait AgentSimulator: Send {
    /// Returns the number of possible actions the agent can perform.
    fn get_num_actions(&self) -> usize;

    /// Returns the bit-width used to encode observations.
    fn get_num_observation_bits(&self) -> usize;

    /// Returns the number of observation symbols per action.
    fn observation_stream_len(&self) -> usize {
        1
    }

    /// Returns the observation key mode for search-tree branching.
    fn observation_key_mode(&self) -> ObservationKeyMode {
        ObservationKeyMode::FullStream
    }

    /// Returns the observation representation used for tree branching.
    fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
        observation_repr_from_stream(
            self.observation_key_mode(),
            observations,
            self.get_num_observation_bits(),
        )
    }

    /// Returns the bit-width used to encode rewards.
    fn get_num_reward_bits(&self) -> usize;

    /// Returns the planning horizon (depth of simulations).
    fn horizon(&self) -> usize;

    /// Returns the maximum possible reward value.
    fn max_reward(&self) -> Reward;

    /// Returns the minimum possible reward value.
    fn min_reward(&self) -> Reward;

    /// Returns the reward offset used to ensure encoded rewards are non-negative.
    ///
    /// Paper-compatible encoding uses unsigned reward bits and shifts rewards by an offset.
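    ///
    /// As an illustrative example (values assumed, not taken from any
    /// particular environment): with 8 reward bits and rewards in
    /// `[-128, 127]`, an offset of `128` shifts the encoded range to
    /// `[0, 255]`.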
    fn reward_offset(&self) -> i64 {
        0
    }

    /// Returns the exploration-exploitation constant (often denoted as C).
    fn get_explore_exploit_ratio(&self) -> f64 {
        1.0
    }

    /// Returns the discount factor for future rewards.
    fn discount_gamma(&self) -> f64 {
        1.0
    }

    /// Updates the internal model state with a simulated action.
    fn model_update_action(&mut self, action: Action);

    /// Generates a simulated percept and updates the model state.
    fn gen_percept_and_update(&mut self, bits: usize) -> u64;

    /// Marks the start of a new simulation rollout.
    fn begin_simulation(&mut self) {}

    /// Reverts the model state to a previous point in the simulation.
    fn model_revert(&mut self, steps: usize);

    /// Generates a random value in `[0, end)`.
    fn gen_range(&mut self, end: usize) -> usize;

    /// Generates a random `f64` in `[0, 1)`.
    fn gen_f64(&mut self) -> f64;

    /// Creates a boxed clone of this simulator for parallel search.
    fn boxed_clone(&self) -> Box<dyn AgentSimulator> {
        self.boxed_clone_with_seed(0)
    }

    /// Creates a boxed clone of this simulator, re-seeding any RNG state.
    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator>;

    /// Normalizes a reward value to [0, 1] based on the agent's range and horizon.
    ///
    /// For discounted rewards, the cumulative range is `sum_{t=0}^{h-1} gamma^t * (max - min)`.
    /// Similarly, the minimum cumulative reward is `sum_{t=0}^{h-1} gamma^t * min`.
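    ///
    /// A worked example with assumed values: for `min = -1`, `max = 1`,
    /// `h = 3`, and `gamma = 0.5`, the discount sum is `1 + 0.5 + 0.25 = 1.75`,
    /// so the range is `2 * 1.75 = 3.5`, the minimum cumulative reward is
    /// `-1.75`, and a cumulative reward of `0.0` normalizes to `0.5`.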
    fn norm_reward(&self, reward: f64) -> f64 {
        let min = self.min_reward() as f64;
        let max = self.max_reward() as f64;
        let h = self.horizon() as f64;
        let gamma = self.discount_gamma().clamp(0.0, 1.0);

        // Discounted sum factor: sum_{t=0}^{h-1} gamma^t = (1 - gamma^h) / (1 - gamma) for gamma != 1
        let discount_sum = if (gamma - 1.0).abs() < 1e-9 {
            h
        } else {
            (1.0 - gamma.powi(h as i32)) / (1.0 - gamma)
        };

        let range = (max - min) * discount_sum;
        let min_cumulative = min * discount_sum;

        if range.abs() < 1e-9 {
            0.5
        } else {
            (reward - min_cumulative) / range
        }
    }

    /// Helper to generate a percept stream, update the model, and return a search key + reward.
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.get_num_observation_bits();
        let obs_len = self.observation_stream_len().max(1);
        let mut observations = Vec::with_capacity(obs_len);
        for _ in 0..obs_len {
            observations.push(self.gen_percept_and_update(obs_bits));
        }

        let obs_key = self.observation_repr_from_stream(&observations);
        let rew_bits = self.get_num_reward_bits();
        let rew_u = self.gen_percept_and_update(rew_bits);
        let rew = (rew_u as i64) - self.reward_offset();
        (obs_key, rew)
    }
}

/// A node in the MCTS search tree.
///
/// Nodes can be either OR-nodes (representing an agent choice) or
/// chance nodes (representing an environment response).
#[derive(Clone)]
pub struct SearchNode {
    /// Number of times this node has been visited during search.
    visits: u32,
    /// The current mean reward estimated for this node.
    mean: f64,
    /// Whether this is a chance node (observation/reward) rather than an action node.
    is_chance_node: bool,
    /// Children indexed by action (action nodes only).
    action_children: Vec<Option<SearchNode>>,
    /// Children indexed by percept outcome (chance nodes only).
    percept_children: HashMap<PerceptOutcome, SearchNode>,
}

impl SearchNode {
    /// Creates a new `SearchNode`.
    pub fn new(is_chance_node: bool) -> Self {
        Self {
            visits: 0,
            mean: 0.0,
            is_chance_node,
            action_children: Vec::new(),
            percept_children: HashMap::new(),
        }
    }

    /// Selects the best action from this node based on accumulated mean rewards.
    pub fn best_action(&self, agent: &mut dyn AgentSimulator) -> Action {
        let mut best_actions = Vec::new();
        let mut best_mean = -f64::INFINITY;

        for (action, child) in self.action_children.iter().enumerate() {
            let Some(child) = child.as_ref() else {
                continue;
            };
            let mean = child.mean;
            if mean > best_mean {
                best_mean = mean;
                best_actions.clear();
                best_actions.push(action as u64);
            } else if (mean - best_mean).abs() < 1e-9 {
                best_actions.push(action as u64);
            }
        }

        if best_actions.is_empty() {
            return 0;
        }

        let idx = agent.gen_range(best_actions.len());
        best_actions[idx] as Action
    }

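    /// Returns the current mean return estimate for this node.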
    fn expectation(&self) -> f64 {
        self.mean
    }

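    /// Merges a worker's search statistics into this node for parallel search.
    ///
    /// `base` is the snapshot the worker started from and `updated` is the
    /// worker's grown copy. The difference in visit counts and reward sums
    /// between them is folded into `self`, then the merge recurses into
    /// children, importing any subtrees the snapshot did not contain.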
    fn apply_delta(&mut self, base: &SearchNode, updated: &SearchNode) {
        if self.is_chance_node != base.is_chance_node
            || self.is_chance_node != updated.is_chance_node
        {
            return;
        }

        let base_visits = base.visits as f64;
        let updated_visits = updated.visits as f64;
        if updated_visits < base_visits {
            return;
        }

        let delta_visits = updated.visits - base.visits;
        if delta_visits > 0 {
            let base_sum = base.mean * base_visits;
            let updated_sum = updated.mean * updated_visits;
            let delta_sum = updated_sum - base_sum;
            let total_visits = self.visits + delta_visits;
            let total_sum = self.mean * (self.visits as f64) + delta_sum;
            self.visits = total_visits;
            self.mean = if total_visits > 0 {
                total_sum / (total_visits as f64)
            } else {
                0.0
            };
        }

        if self.is_chance_node {
            for (key, updated_child) in &updated.percept_children {
                if let Some(base_child) = base.percept_children.get(key) {
                    if let Some(self_child) = self.percept_children.get_mut(key) {
                        self_child.apply_delta(base_child, updated_child);
                    } else {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        self.percept_children.insert(key.clone(), child);
                    }
                } else if let Some(self_child) = self.percept_children.get_mut(key) {
                    let empty = SearchNode::new(updated_child.is_chance_node);
                    self_child.apply_delta(&empty, updated_child);
                } else {
                    let mut child = SearchNode::new(updated_child.is_chance_node);
                    child.apply_delta(
                        &SearchNode::new(updated_child.is_chance_node),
                        updated_child,
                    );
                    self.percept_children.insert(key.clone(), child);
                }
            }
        } else {
            let max_len = base
                .action_children
                .len()
                .max(updated.action_children.len());
            if self.action_children.len() < max_len {
                self.action_children.resize_with(max_len, || None);
            }
            for idx in 0..max_len {
                let base_child = base.action_children.get(idx).and_then(|c| c.as_ref());
                let updated_child = updated.action_children.get(idx).and_then(|c| c.as_ref());
                let Some(updated_child) = updated_child else {
                    continue;
                };
                match (base_child, self.action_children.get_mut(idx)) {
                    (Some(base_child), Some(Some(self_child))) => {
                        self_child.apply_delta(base_child, updated_child);
                    }
                    (Some(base_child), Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(base_child, updated_child);
                        *slot = Some(child);
                    }
                    (None, Some(Some(self_child))) => {
                        let empty = SearchNode::new(updated_child.is_chance_node);
                        self_child.apply_delta(&empty, updated_child);
                    }
                    (None, Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        *slot = Some(child);
                    }
                    _ => {}
                }
            }
        }
    }

    /// Selects an action to explore, potentially creating a new child node.
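    ///
    /// Unvisited actions are expanded first (one is chosen uniformly at
    /// random); once every action has a child, the action maximizing the UCB
    /// score `E[return] + horizon * max_reward * sqrt(C * ln(N) / n)` is
    /// taken, with exact ties broken uniformly at random.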
    fn select_action(&mut self, agent: &mut dyn AgentSimulator) -> (&mut SearchNode, Action) {
        let num_actions = agent.get_num_actions();

        if self.action_children.len() < num_actions {
            self.action_children.resize_with(num_actions, || None);
        }

        let mut unvisited = Vec::new();
        for a in 0..num_actions {
            if self.action_children[a].is_none() {
                unvisited.push(a as u64);
            }
        }

        let action = if !unvisited.is_empty() {
            let idx = agent.gen_range(unvisited.len());
            let action = unvisited[idx];
            self.action_children[action as usize] = Some(SearchNode::new(true));
            action
        } else {
            // Match reference MC-AIXI UCB scaling:
            // priority = E[return] + horizon*max_reward*sqrt(C*log(N)/n)
            let c = agent.get_explore_exploit_ratio().max(0.0);
            let explore_bias = (agent.horizon() as f64) * (agent.max_reward() as f64).max(0.0);
            let mut best_val = -f64::INFINITY;
            let mut best_action = None;
            let mut num_maximal_actions = 0usize;
            let log_visits = (self.visits as f64).ln().max(0.0);
            for (a, child) in self.action_children.iter().enumerate() {
                let Some(child) = child.as_ref() else {
                    continue;
                };
                let nvisits = child.visits as f64;
                let val = child.expectation() + explore_bias * ((c * log_visits) / nvisits).sqrt();
                debug_assert!(
                    val.is_finite(),
                    "UCB score must be finite for visited MC-AIXI action children"
                );
                match val.total_cmp(&best_val) {
                    std::cmp::Ordering::Greater => {
                        best_val = val;
                        best_action = Some(a as u64);
                        num_maximal_actions = 1;
                    }
                    std::cmp::Ordering::Equal => {
                        num_maximal_actions += 1;
                        // Tie-break from "A Monte-Carlo AIXI Approximation":
                        // choose uniformly among maximal actions.
                        // Reservoir sampling keeps this O(1) in memory without a tie list.
                        if agent.gen_range(num_maximal_actions) == 0 {
                            best_action = Some(a as u64);
                        }
                    }
                    std::cmp::Ordering::Less => {}
                }
            }
            best_action.expect("visited MC-AIXI node must have a maximal action")
        };

        agent.model_update_action(action as Action);
        (
            self.action_children[action as usize]
                .as_mut()
                .expect("missing action child"),
            action as Action,
        )
    }

    /// Performs a single simulation (sample) from this node.
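    ///
    /// At a chance node, a percept is sampled from the model and recursion
    /// continues in the matching child; an unvisited decision node falls back
    /// to a random playout, while a visited decision node descends through
    /// UCB action selection.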
    pub fn sample(
        &mut self,
        agent: &mut dyn AgentSimulator,
        horizon: usize,
        total_horizon: usize,
    ) -> f64 {
        if horizon == 0 {
            agent.model_revert(total_horizon);
            return 0.0;
        }

        let reward;
        if self.is_chance_node {
            let (obs, rew) = agent.gen_percepts_and_update();
            let key = PerceptOutcome::new(obs, rew);
            let child = self
                .percept_children
                .entry(key)
                .or_insert_with(|| SearchNode::new(false));
            reward = (rew as f64)
                + agent.discount_gamma() * child.sample(agent, horizon - 1, total_horizon);
        } else if self.visits == 0 {
            reward = Self::playout(agent, horizon, total_horizon);
        } else {
            let (child, _act) = self.select_action(agent);
            reward = child.sample(agent, horizon, total_horizon);
        }

        // Fold this rollout's return into the running mean, then count the visit.
        self.mean = (reward + (self.visits as f64) * self.mean) / ((self.visits + 1) as f64);
        self.visits += 1;

        reward
    }

    /// Performs a randomized simulation until the horizon is reached.
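    ///
    /// Actions are drawn uniformly at random, the discounted sum of sampled
    /// rewards is accumulated, and the model is reverted to its pre-rollout
    /// state before the return value is produced.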
    fn playout(agent: &mut dyn AgentSimulator, horizon: usize, total_horizon: usize) -> f64 {
        let mut total_rew = 0.0;
        let num_actions = agent.get_num_actions();
        let gamma = agent.discount_gamma().clamp(0.0, 1.0);
        let mut discount = 1.0;

        for _ in 0..horizon {
            let act = agent.gen_range(num_actions);
            agent.model_update_action(act as Action);
            let (_key, rew) = agent.gen_percepts_and_update();
            total_rew += discount * (rew as f64);
            discount *= gamma;
        }

        agent.model_revert(total_horizon);
        total_rew
    }
}

/// Manages the MCTS tree and provides the `search` entry point.
pub struct SearchTree {
    root: Option<SearchNode>,
}

impl SearchTree {
    /// Creates a new `SearchTree`.
    pub fn new() -> Self {
        Self {
            root: Some(SearchNode::new(false)),
        }
    }

    /// Performs several MCTS simulations to find the best next action.
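    ///
    /// When multiple rayon threads are available, the search is
    /// root-parallelized: each worker clones the agent and a snapshot of the
    /// current tree, runs its share of rollouts locally, and its statistics
    /// are merged back into the root by replaying the worker's delta against
    /// the pre-search snapshot.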
    pub fn search(
        &mut self,
        agent: &mut dyn AgentSimulator,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: u64,
        samples: usize,
    ) -> Action {
        self.prune_tree(agent, prev_obs_stream, prev_rew, prev_act);

        let root = self.root.as_mut().unwrap();
        let h = agent.horizon();
        let threads = rayon::current_num_threads().max(1);
        if samples < 2 || threads < 2 {
            for _ in 0..samples {
                agent.begin_simulation();
                root.sample(agent, h, h);
            }
            return root.best_action(agent);
        }

        let workers = threads.min(samples);
        let base = samples / workers;
        let extra = samples % workers;
        let snapshot = root.clone();

        let mut agents = Vec::with_capacity(workers);
        for i in 0..workers {
            let seed = agent.gen_f64().to_bits() ^ (i as u64);
            agents.push(agent.boxed_clone_with_seed(seed));
        }

        let results: Vec<SearchNode> = agents
            .into_par_iter()
            .enumerate()
            .map(|(i, mut local_agent)| {
                let mut local_root = snapshot.clone();
                let iterations = base + usize::from(i < extra);
                for _ in 0..iterations {
                    local_agent.begin_simulation();
                    local_root.sample(local_agent.as_mut(), h, h);
                }
                local_root
            })
            .collect();

        for local in &results {
            root.apply_delta(&snapshot, local);
        }

        root.best_action(agent)
    }

    /// Prunes the tree, keeping only relevant subtrees based on the previous interaction.
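    ///
    /// The subtree reached by following `prev_act` and then the percept
    /// outcome `(prev_obs_stream, prev_rew)` becomes the new root; when no
    /// such subtree exists, the tree is reset to a fresh root node.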
    fn prune_tree(
        &mut self,
        agent: &mut dyn AgentSimulator,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: u64,
    ) {
        if self.root.is_none() {
            self.root = Some(SearchNode::new(false));
            return;
        }

        let mut old_root = self.root.take().unwrap();

        // Detach the chance child reached by the previously taken action.
        let action_child_opt = if old_root.action_children.len() > prev_act as usize {
            old_root.action_children[prev_act as usize].take()
        } else {
            None
        };

        if let Some(mut chance_child) = action_child_opt {
            let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
            let key = PerceptOutcome::new(obs_repr, prev_rew);

            if let Some(action_child) = chance_child.percept_children.remove(&key) {
                self.root = Some(action_child);
            } else {
                self.root = Some(SearchNode::new(false));
            }
        } else {
            self.root = Some(SearchNode::new(false));
        }
    }
}

impl Default for SearchTree {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::aixi::common::ObservationKeyMode;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    #[derive(Clone)]
    struct DummyAgent {
        obs_bits: usize,
        rew_bits: usize,
        horizon: usize,
        min_reward: Reward,
        max_reward: Reward,
        key_mode: ObservationKeyMode,
    }

    impl DummyAgent {
        fn new(obs_bits: usize, key_mode: ObservationKeyMode) -> Self {
            Self {
                obs_bits,
                rew_bits: 8,
                horizon: 5,
                min_reward: -1,
                max_reward: 1,
                key_mode,
            }
        }
    }

    impl AgentSimulator for DummyAgent {
        fn get_num_actions(&self) -> usize {
            4
        }

        fn get_num_observation_bits(&self) -> usize {
            self.obs_bits
        }

        fn observation_key_mode(&self) -> ObservationKeyMode {
            self.key_mode
        }

        fn get_num_reward_bits(&self) -> usize {
            self.rew_bits
        }

        fn horizon(&self) -> usize {
            self.horizon
        }

        fn max_reward(&self) -> Reward {
            self.max_reward
        }

        fn min_reward(&self) -> Reward {
            self.min_reward
        }

        fn model_update_action(&mut self, _action: Action) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_range(&mut self, _end: usize) -> usize {
            0
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

    fn build_tree_with_key(
        agent: &DummyAgent,
        prev_act: u64,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        kept_mean: f64,
        kept_visits: u32,
    ) -> SearchTree {
        let mut old_root = SearchNode::new(false);
        old_root.action_children.resize(prev_act as usize + 1, None);

        let mut chance_child = SearchNode::new(true);
        let mut kept = SearchNode::new(false);
        kept.mean = kept_mean;
        kept.visits = kept_visits;

        let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
        let key = PerceptOutcome::new(obs_repr, prev_rew);
        chance_child.percept_children.insert(key, kept);

        old_root.action_children[prev_act as usize] = Some(chance_child);
        SearchTree {
            root: Some(old_root),
        }
    }

    #[test]
    fn prune_tree_keeps_matching_subtree() {
        let prev_act = 2u64;
        let prev_obs_stream = vec![9u64, 2u64, 7u64];
        let prev_rew: Reward = 3;

        let mut agent = DummyAgent::new(3, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew, 123.0, 7);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().expect("root should exist");
        assert!(!root.is_chance_node);
        assert_eq!(root.mean, 123.0);
        assert_eq!(root.visits, 7);
    }

    #[test]
    fn prune_tree_resets_when_action_missing() {
        let prev_act = 10u64;
        let prev_obs_stream = vec![1u64];
        let prev_rew: Reward = 0;

        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        let mut tree = SearchTree::new();

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    #[test]
    fn prune_tree_resets_when_percept_key_missing() {
        let prev_act = 0u64;
        let prev_obs_stream = vec![1u64, 2u64];
        let prev_rew: Reward = 1;

        let mut agent = DummyAgent::new(4, ObservationKeyMode::Last);

        // Build tree keyed on a different observation key so pruning misses it.
        let mut tree = build_tree_with_key(&agent, prev_act, &[9u64], prev_rew, 9.0, 2);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    #[test]
    fn prune_tree_resets_when_reward_mismatch_shares_observation_key() {
        let prev_act = 1u64;
        let prev_obs_stream = vec![4u64, 5u64];
        let kept_rew: Reward = -2;
        let requested_rew: Reward = 2;

        let mut agent = DummyAgent::new(6, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, kept_rew, 77.0, 11);

        tree.prune_tree(&mut agent, &prev_obs_stream, requested_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    #[derive(Clone)]
    struct BeginCountingAgent {
        begins: Arc<AtomicUsize>,
    }

    impl AgentSimulator for BeginCountingAgent {
        fn get_num_actions(&self) -> usize {
            2
        }

        fn get_num_observation_bits(&self) -> usize {
            1
        }

        fn get_num_reward_bits(&self) -> usize {
            1
        }

        fn horizon(&self) -> usize {
            1
        }

        fn max_reward(&self) -> Reward {
            1
        }

        fn min_reward(&self) -> Reward {
            0
        }

        fn model_update_action(&mut self, _action: Action) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn begin_simulation(&mut self) {
            self.begins.fetch_add(1, Ordering::Relaxed);
        }

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_range(&mut self, _end: usize) -> usize {
            0
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn search_calls_begin_simulation_for_each_rollout() {
        let begins = Arc::new(AtomicUsize::new(0));
        let mut agent = BeginCountingAgent {
            begins: begins.clone(),
        };
        let mut tree = SearchTree::new();

        let _ = tree.search(&mut agent, &[0], 0, 0, 5);
        assert_eq!(begins.load(Ordering::Relaxed), 5);
    }

    #[derive(Clone)]
    struct TieBreakAgent {
        next_range: Arc<AtomicUsize>,
    }

    impl AgentSimulator for TieBreakAgent {
        fn get_num_actions(&self) -> usize {
            4
        }

        fn get_num_observation_bits(&self) -> usize {
            1
        }

        fn get_num_reward_bits(&self) -> usize {
            1
        }

        fn horizon(&self) -> usize {
            1
        }

        fn max_reward(&self) -> Reward {
            1
        }

        fn min_reward(&self) -> Reward {
            0
        }

        fn get_explore_exploit_ratio(&self) -> f64 {
            0.0
        }

        fn model_update_action(&mut self, _action: Action) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_range(&mut self, end: usize) -> usize {
            self.next_range
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
                    Some(value.saturating_sub(1))
                })
                .expect("range source should be initialized")
                % end
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn select_action_uses_uniform_tie_break_for_maximal_ucb_actions() {
        let mut node = SearchNode::new(false);
        node.visits = 16;
        node.action_children = vec![
            Some(SearchNode {
                visits: 5,
                mean: 0.1,
                is_chance_node: true,
                action_children: Vec::new(),
                percept_children: HashMap::new(),
            }),
            Some(SearchNode {
                visits: 5,
                mean: 0.9,
                is_chance_node: true,
                action_children: Vec::new(),
                percept_children: HashMap::new(),
            }),
            Some(SearchNode {
                visits: 5,
                mean: 0.2,
                is_chance_node: true,
                action_children: Vec::new(),
                percept_children: HashMap::new(),
            }),
            Some(SearchNode {
                visits: 5,
                mean: 0.9,
                is_chance_node: true,
                action_children: Vec::new(),
                percept_children: HashMap::new(),
            }),
        ];

        let mut agent = TieBreakAgent {
            next_range: Arc::new(AtomicUsize::new(0)),
        };

        let (_child, action) = node.select_action(&mut agent);
        assert_eq!(
            action, 3,
            "exactly tied maximal UCB actions should be chosen uniformly; scripted RNG selected the later maximal action"
        );
    }
}
909}