use crate::aixi::common::{
    Action, ObservationKeyMode, PerceptVal, Reward, observation_repr_from_stream,
};
use rayon::prelude::*;
use std::collections::HashMap;

/// Key for a chance node's child: one percept's observation representation
/// paired with its decoded reward.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct PerceptOutcome {
    observations: Box<[PerceptVal]>,
    reward: Reward,
}

impl PerceptOutcome {
    fn new(observations: Vec<PerceptVal>, reward: Reward) -> Self {
        Self {
            observations: observations.into_boxed_slice(),
            reward,
        }
    }
}

/// Interface between the search and the agent's environment model: it
/// supplies the action/percept dimensions, simulates percepts from the
/// model, and provides the randomness used during search.
pub trait AgentSimulator: Send {
    /// Number of actions available to the agent.
    fn get_num_actions(&self) -> usize;

    /// Number of bits in one observation component.
    fn get_num_observation_bits(&self) -> usize;

    /// Number of observation components generated per percept.
    fn observation_stream_len(&self) -> usize {
        1
    }

    /// How the observation stream is reduced to a lookup key.
    fn observation_key_mode(&self) -> ObservationKeyMode {
        ObservationKeyMode::FullStream
    }

    /// Observation representation used to key chance-node children.
    fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
        observation_repr_from_stream(
            self.observation_key_mode(),
            observations,
            self.get_num_observation_bits(),
        )
    }

    /// Number of bits in the reward component of a percept.
    fn get_num_reward_bits(&self) -> usize;

    /// Planning horizon: how many action-percept cycles to look ahead.
    fn horizon(&self) -> usize;

    /// Largest achievable single-step reward.
    fn max_reward(&self) -> Reward;

    /// Smallest achievable single-step reward.
    fn min_reward(&self) -> Reward;

    /// Offset subtracted from the raw reward percept to recover a signed
    /// reward.
    fn reward_offset(&self) -> i64 {
        0
    }

    /// Exploration constant `C` in the UCB action-selection rule.
    fn get_explore_exploit_ratio(&self) -> f64 {
        1.0
    }

    /// Geometric discount factor applied to future rewards.
    fn discount_gamma(&self) -> f64 {
        1.0
    }

    /// Conditions the model on the given action.
    fn model_update_action(&mut self, action: Action);

    /// Samples a `bits`-wide percept from the model, updating it in place.
    fn gen_percept_and_update(&mut self, bits: usize) -> u64;

    /// Rolls the model back by `steps` update cycles.
    fn model_revert(&mut self, steps: usize);

    /// Samples a uniform index in `0..end`.
    fn gen_range(&mut self, end: usize) -> usize;

    /// Samples a random `f64`, used for tie-breaking noise and worker seeds.
    fn gen_f64(&mut self) -> f64;

    /// Boxed clone with the default seed.
    fn boxed_clone(&self) -> Box<dyn AgentSimulator> {
        self.boxed_clone_with_seed(0)
    }

    /// Boxed clone whose random number generator is reseeded with `seed`.
    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator>;

    fn norm_reward(&self, reward: f64) -> f64 {
        // Normalizes a cumulative discounted reward into [0, 1]: the
        // achievable range over the horizon is (max - min) scaled by the
        // discounted horizon length 1 + gamma + ... + gamma^(H-1)
        // = (1 - gamma^H) / (1 - gamma).
        let min = self.min_reward() as f64;
        let max = self.max_reward() as f64;
        let h = self.horizon() as f64;
        let gamma = self.discount_gamma().clamp(0.0, 1.0);

        // At gamma == 1 the geometric sum degenerates to the horizon itself.
        let discount_sum = if (gamma - 1.0).abs() < 1e-9 {
            h
        } else {
            (1.0 - gamma.powi(h as i32)) / (1.0 - gamma)
        };

        let range = (max - min) * discount_sum;
        let min_cumulative = min * discount_sum;

        // Degenerate reward range: every return is equivalent.
        if range.abs() < 1e-9 {
            0.5
        } else {
            (reward - min_cumulative) / range
        }
    }

    /// Samples a full percept from the model: the observation stream followed
    /// by the reward, which is decoded via `reward_offset`.
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.get_num_observation_bits();
        let obs_len = self.observation_stream_len().max(1);
        let mut observations = Vec::with_capacity(obs_len);
        for _ in 0..obs_len {
            observations.push(self.gen_percept_and_update(obs_bits));
        }

        let obs_key = self.observation_repr_from_stream(&observations);
        let rew_bits = self.get_num_reward_bits();
        let rew_u = self.gen_percept_and_update(rew_bits);
        let rew = (rew_u as i64) - self.reward_offset();
        (obs_key, rew)
    }
}
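
// `tests::DummyAgent` at the bottom of this file is a minimal implementation
// of `AgentSimulator`. As an illustration of the percept decoding above (the
// numbers are made up, not from any particular environment): with
// `get_num_reward_bits() == 2` and `reward_offset() == 1`, the raw reward
// percepts {0, 1, 2} decode to signed rewards {-1, 0, 1}.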

/// One node of the Monte Carlo search tree.
#[derive(Clone)]
pub struct SearchNode {
    /// Number of simulations that have passed through this node.
    visits: u32,
    /// Running mean of the discounted return observed through this node.
    mean: f64,
    /// Chance nodes branch on percepts; decision nodes branch on actions.
    is_chance_node: bool,
    /// Children of a decision node, indexed by action.
    action_children: Vec<Option<SearchNode>>,
    /// Children of a chance node, keyed by sampled percept outcome.
    percept_children: HashMap<PerceptOutcome, SearchNode>,
}

impl SearchNode {
    /// Creates an empty node of the given kind, with no statistics yet.
    pub fn new(is_chance_node: bool) -> Self {
        Self {
            visits: 0,
            mean: 0.0,
            is_chance_node,
            action_children: Vec::new(),
            percept_children: HashMap::new(),
        }
    }

    /// Returns the action whose child has the highest mean value, breaking
    /// ties uniformly at random; defaults to action 0 if no child exists.
    pub fn best_action(&self, agent: &mut dyn AgentSimulator) -> Action {
        let mut best_actions = Vec::new();
        let mut best_mean = -f64::INFINITY;

        for (action, child) in self.action_children.iter().enumerate() {
            let Some(child) = child.as_ref() else {
                continue;
            };
            let mean = child.mean;
            if mean > best_mean {
                best_mean = mean;
                best_actions.clear();
                best_actions.push(action as u64);
            } else if (mean - best_mean).abs() < 1e-9 {
                best_actions.push(action as u64);
            }
        }

        if best_actions.is_empty() {
            return 0;
        }

        let idx = agent.gen_range(best_actions.len());
        best_actions[idx] as Action
    }

    fn expectation(&self) -> f64 {
        self.mean
    }

    /// Folds the statistics a parallel worker accumulated on its private copy
    /// (`updated`, which grew out of `base`) into this node, recursing into
    /// the matching children.
    fn apply_delta(&mut self, base: &SearchNode, updated: &SearchNode) {
        // All three nodes must agree on their role in the tree.
        if self.is_chance_node != base.is_chance_node
            || self.is_chance_node != updated.is_chance_node
        {
            return;
        }

        let base_visits = base.visits as f64;
        let updated_visits = updated.visits as f64;
        if updated_visits < base_visits {
            return;
        }

        // Merge the visit-count and reward-sum deltas accumulated since `base`.
        let delta_visits = updated.visits - base.visits;
        if delta_visits > 0 {
            let base_sum = base.mean * base_visits;
            let updated_sum = updated.mean * updated_visits;
            let delta_sum = updated_sum - base_sum;
            let total_visits = self.visits + delta_visits;
            let total_sum = self.mean * (self.visits as f64) + delta_sum;
            self.visits = total_visits;
            self.mean = if total_visits > 0 {
                total_sum / (total_visits as f64)
            } else {
                0.0
            };
        }

        if self.is_chance_node {
            for (key, updated_child) in &updated.percept_children {
                if let Some(self_child) = self.percept_children.get_mut(key) {
                    // Existing child: fold in only the statistics accumulated
                    // since the snapshot.
                    let empty = SearchNode::new(updated_child.is_chance_node);
                    let base_child = base.percept_children.get(key).unwrap_or(&empty);
                    self_child.apply_delta(base_child, updated_child);
                } else {
                    // Child discovered by the worker: absorb its statistics
                    // wholesale by applying a delta from an empty node.
                    let mut child = SearchNode::new(updated_child.is_chance_node);
                    child.apply_delta(
                        &SearchNode::new(updated_child.is_chance_node),
                        updated_child,
                    );
                    self.percept_children.insert(key.clone(), child);
                }
            }
        } else {
            let max_len = base
                .action_children
                .len()
                .max(updated.action_children.len());
            if self.action_children.len() < max_len {
                self.action_children.resize_with(max_len, || None);
            }
            for idx in 0..max_len {
                let base_child = base.action_children.get(idx).and_then(|c| c.as_ref());
                let updated_child = updated.action_children.get(idx).and_then(|c| c.as_ref());
                let Some(updated_child) = updated_child else {
                    continue;
                };
                match (base_child, self.action_children.get_mut(idx)) {
                    (Some(base_child), Some(Some(self_child))) => {
                        self_child.apply_delta(base_child, updated_child);
                    }
                    (Some(base_child), Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(base_child, updated_child);
                        *slot = Some(child);
                    }
                    (None, Some(Some(self_child))) => {
                        let empty = SearchNode::new(updated_child.is_chance_node);
                        self_child.apply_delta(&empty, updated_child);
                    }
                    (None, Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        *slot = Some(child);
                    }
                    _ => {}
                }
            }
        }
    }

    /// UCB-style action selection: unvisited actions are tried first (in
    /// random order); afterwards the child maximizing mean value plus an
    /// exploration bonus is chosen. Conditions the model on the chosen action.
    fn select_action(&mut self, agent: &mut dyn AgentSimulator) -> (&mut SearchNode, Action) {
        let num_actions = agent.get_num_actions();

        if self.action_children.len() < num_actions {
            self.action_children.resize_with(num_actions, || None);
        }

        let mut unvisited = Vec::new();
        for a in 0..num_actions {
            if self.action_children[a].is_none() {
                unvisited.push(a as u64);
            }
        }

        let action;
        if !unvisited.is_empty() {
            let idx = agent.gen_range(unvisited.len());
            action = unvisited[idx];
            self.action_children[action as usize] = Some(SearchNode::new(true));
        } else {
            // The exploration term is scaled by the maximum achievable return,
            // since the node means are kept in unnormalized reward units.
            let c = agent.get_explore_exploit_ratio().max(0.0);
            let explore_bias = (agent.horizon() as f64) * (agent.max_reward() as f64).max(0.0);
            let mut best_val = -f64::INFINITY;
            let mut best_action = 0;
            let log_visits = (self.visits as f64).ln().max(0.0);
            for (a, child) in self.action_children.iter().enumerate() {
                let Some(child) = child.as_ref() else {
                    continue;
                };
                let nvisits = child.visits as f64;
                let val =
                    child.expectation() + explore_bias * ((c * log_visits) / nvisits).sqrt();
                // A touch of random jitter breaks ties between near-equal values.
                if val > best_val + agent.gen_f64() * 0.001 {
                    best_val = val;
                    best_action = a as u64;
                }
            }
            action = best_action;
        }

        agent.model_update_action(action as Action);
        (
            self.action_children[action as usize]
                .as_mut()
                .expect("missing action child"),
            action as Action,
        )
    }

    /// Runs one simulation from this node: descends via UCB at decision nodes
    /// and sampled percepts at chance nodes, falls back to a random playout at
    /// the first unvisited decision node, and updates the visit counts and
    /// running means along the way.
    pub fn sample(
        &mut self,
        agent: &mut dyn AgentSimulator,
        horizon: usize,
        total_horizon: usize,
    ) -> f64 {
        if horizon == 0 {
            agent.model_revert(total_horizon);
            return 0.0;
        }

        let reward;
        if self.is_chance_node {
            let (obs, rew) = agent.gen_percepts_and_update();
            let key = PerceptOutcome::new(obs, rew);
            let child = self
                .percept_children
                .entry(key)
                .or_insert_with(|| SearchNode::new(false));
            reward = (rew as f64)
                + agent.discount_gamma() * child.sample(agent, horizon - 1, total_horizon);
        } else if self.visits == 0 {
            reward = Self::playout(agent, horizon, total_horizon);
        } else {
            let (child, _act) = self.select_action(agent);
            reward = child.sample(agent, horizon, total_horizon);
        }

        // Incremental update of the running mean.
        self.mean = (reward + (self.visits as f64) * self.mean) / ((self.visits + 1) as f64);
        self.visits += 1;

        reward
    }

    /// Uniform random rollout to the end of the horizon, returning the
    /// discounted sum of sampled rewards and reverting the model afterwards.
    fn playout(agent: &mut dyn AgentSimulator, horizon: usize, total_horizon: usize) -> f64 {
        let mut total_rew = 0.0;
        let num_actions = agent.get_num_actions();
        let gamma = agent.discount_gamma().clamp(0.0, 1.0);
        let mut discount = 1.0;

        for _ in 0..horizon {
            let act = agent.gen_range(num_actions);
            agent.model_update_action(act as Action);
            let (_key, rew) = agent.gen_percepts_and_update();
            total_rew += discount * (rew as f64);
            discount *= gamma;
        }

        agent.model_revert(total_horizon);
        total_rew
    }
}
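
// A typical decision loop driving the tree, sketched for illustration only
// (the environment plumbing and the percept bookkeeping below are
// assumptions, not part of this module):
//
//     let mut tree = SearchTree::new();
//     loop {
//         let action = tree.search(&mut agent, &prev_obs, prev_rew, prev_act, samples);
//         // ... execute `action`, observe the next percept, update the
//         // agent's model, and record it as (prev_obs, prev_rew, prev_act) ...
//     }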

/// Persistent Monte Carlo search tree, reused across decision steps so that
/// the subtree consistent with what actually happened is kept.
pub struct SearchTree {
    /// Current root; `None` only transiently while the tree is being pruned.
    root: Option<SearchNode>,
}

impl SearchTree {
    pub fn new() -> Self {
        Self {
            root: Some(SearchNode::new(false)),
        }
    }

    /// Prunes the tree to the subtree consistent with the previous action and
    /// percept, runs `samples` simulations from the root (parallelized across
    /// rayon workers when worthwhile), and returns the best root action.
    pub fn search(
        &mut self,
        agent: &mut dyn AgentSimulator,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: u64,
        samples: usize,
    ) -> Action {
        self.prune_tree(agent, prev_obs_stream, prev_rew, prev_act);

        let root = self.root.as_mut().unwrap();
        let h = agent.horizon();
        let threads = rayon::current_num_threads().max(1);
        // Sequential fallback when parallelism cannot pay off.
        if samples < 2 || threads < 2 {
            for _ in 0..samples {
                root.sample(agent, h, h);
            }
            return root.best_action(agent);
        }

        // Root parallelization: every worker samples its own clone of the
        // current tree with an independently seeded agent, and the deltas
        // relative to the shared snapshot are merged back afterwards.
        let workers = threads.min(samples);
        let base = samples / workers;
        let extra = samples % workers;
        let snapshot = root.clone();

        let mut agents = Vec::with_capacity(workers);
        for i in 0..workers {
            let seed = agent.gen_f64().to_bits() ^ (i as u64);
            agents.push(agent.boxed_clone_with_seed(seed));
        }

        let results: Vec<SearchNode> = agents
            .into_par_iter()
            .enumerate()
            .map(|(i, mut local_agent)| {
                let mut local_root = snapshot.clone();
                // Spread the remainder over the first `extra` workers.
                let iterations = base + usize::from(i < extra);
                for _ in 0..iterations {
                    local_root.sample(local_agent.as_mut(), h, h);
                }
                local_root
            })
            .collect();

        for local in &results {
            root.apply_delta(&snapshot, local);
        }

        root.best_action(agent)
    }

    /// Moves the root down to the grandchild reached by taking `prev_act` and
    /// then observing the given percept; if that branch was never expanded,
    /// the search restarts from a fresh root.
    fn prune_tree(
        &mut self,
        agent: &mut dyn AgentSimulator,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: u64,
    ) {
        if self.root.is_none() {
            self.root = Some(SearchNode::new(false));
            return;
        }

        let mut old_root = self.root.take().unwrap();

        // The chance node reached by the previously executed action.
        let chance_child_opt = if old_root.action_children.len() > prev_act as usize {
            old_root.action_children[prev_act as usize].take()
        } else {
            None
        };

        if let Some(mut chance_child) = chance_child_opt {
            let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
            let key = PerceptOutcome::new(obs_repr, prev_rew);

            // The decision node matching the percept actually observed.
            if let Some(decision_child) = chance_child.percept_children.remove(&key) {
                self.root = Some(decision_child);
            } else {
                self.root = Some(SearchNode::new(false));
            }
        } else {
            self.root = Some(SearchNode::new(false));
        }
    }
}

impl Default for SearchTree {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::aixi::common::ObservationKeyMode;

    /// Deterministic stand-in for a real model: fixed percepts and a random
    /// number generator that always returns zero.
    #[derive(Clone)]
    struct DummyAgent {
        obs_bits: usize,
        rew_bits: usize,
        horizon: usize,
        min_reward: Reward,
        max_reward: Reward,
        key_mode: ObservationKeyMode,
    }

    impl DummyAgent {
        fn new(obs_bits: usize, key_mode: ObservationKeyMode) -> Self {
            Self {
                obs_bits,
                rew_bits: 8,
                horizon: 5,
                min_reward: -1,
                max_reward: 1,
                key_mode,
            }
        }
    }

    impl AgentSimulator for DummyAgent {
        fn get_num_actions(&self) -> usize {
            4
        }

        fn get_num_observation_bits(&self) -> usize {
            self.obs_bits
        }

        fn observation_key_mode(&self) -> ObservationKeyMode {
            self.key_mode
        }

        fn get_num_reward_bits(&self) -> usize {
            self.rew_bits
        }

        fn horizon(&self) -> usize {
            self.horizon
        }

        fn max_reward(&self) -> Reward {
            self.max_reward
        }

        fn min_reward(&self) -> Reward {
            self.min_reward
        }

        fn model_update_action(&mut self, _action: Action) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_range(&mut self, _end: usize) -> usize {
            0
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

    /// Builds a tree whose only branch matches (`prev_act`, percept), so that
    /// pruning should surface the `kept` node as the new root.
    fn build_tree_with_key(
        agent: &DummyAgent,
        prev_act: u64,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        kept_mean: f64,
        kept_visits: u32,
    ) -> SearchTree {
        let mut old_root = SearchNode::new(false);
        old_root.action_children.resize(prev_act as usize + 1, None);

        let mut chance_child = SearchNode::new(true);
        let mut kept = SearchNode::new(false);
        kept.mean = kept_mean;
        kept.visits = kept_visits;

        let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
        let key = PerceptOutcome::new(obs_repr, prev_rew);
        chance_child.percept_children.insert(key, kept);

        old_root.action_children[prev_act as usize] = Some(chance_child);
        SearchTree {
            root: Some(old_root),
        }
    }

    #[test]
    fn prune_tree_keeps_matching_subtree() {
        let prev_act = 2u64;
        let prev_obs_stream = vec![9u64, 2u64, 7u64];
        let prev_rew: Reward = 3;

        let mut agent = DummyAgent::new(3, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew, 123.0, 7);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().expect("root should exist");
        assert!(!root.is_chance_node);
        assert_eq!(root.mean, 123.0);
        assert_eq!(root.visits, 7);
    }
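
    // Illustrative check of `norm_reward` (not in the original suite): with
    // DummyAgent's min_reward = -1, max_reward = 1, horizon = 5 and the
    // default gamma = 1.0, the discounted sum is 5, so a cumulative reward of
    // 0.0 maps to (0 + 5) / 10 = 0.5 and the extremes map to 0.0 and 1.0.
    #[test]
    fn norm_reward_maps_cumulative_rewards_into_unit_interval() {
        let agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        assert!((agent.norm_reward(0.0) - 0.5).abs() < 1e-9);
        assert!((agent.norm_reward(-5.0) - 0.0).abs() < 1e-9);
        assert!((agent.norm_reward(5.0) - 1.0).abs() < 1e-9);
    }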

    #[test]
    fn prune_tree_resets_when_action_missing() {
        let prev_act = 10u64;
        let prev_obs_stream = vec![1u64];
        let prev_rew: Reward = 0;

        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        let mut tree = SearchTree::new();

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }
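
    // Sketch of `best_action`'s greedy rule (not in the original suite): the
    // child with the highest mean wins, and `DummyAgent`'s deterministic RNG
    // makes the tie-breaking draw a no-op.
    #[test]
    fn best_action_picks_child_with_highest_mean() {
        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        let mut root = SearchNode::new(false);
        root.action_children.resize(3, None);

        let mut low = SearchNode::new(true);
        low.mean = 0.25;
        let mut high = SearchNode::new(true);
        high.mean = 0.75;
        root.action_children[0] = Some(low);
        root.action_children[2] = Some(high);

        assert_eq!(root.best_action(&mut agent), 2);
    }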

    #[test]
    fn prune_tree_resets_when_percept_key_missing() {
        let prev_act = 0u64;
        let prev_obs_stream = vec![1u64, 2u64];
        let prev_rew: Reward = 1;

        let mut agent = DummyAgent::new(4, ObservationKeyMode::Last);

        // The stored key comes from a different observation stream, so the
        // lookup during pruning must miss.
        let mut tree = build_tree_with_key(&agent, prev_act, &[9u64], prev_rew, 9.0, 2);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    #[test]
    fn prune_tree_resets_when_reward_mismatch_shares_observation_key() {
        let prev_act = 1u64;
        let prev_obs_stream = vec![4u64, 5u64];
        let kept_rew: Reward = -2;
        let requested_rew: Reward = 2;

        let mut agent = DummyAgent::new(6, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, kept_rew, 77.0, 11);

        tree.prune_tree(&mut agent, &prev_obs_stream, requested_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }
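
    // Sketch of the delta-merge arithmetic used by the parallel search (not
    // in the original suite): a worker that adds two visits with total reward
    // 4.0 on top of a snapshot holding one visit of reward 1.0 should leave
    // the shared node with three visits and mean (1.0 + 4.0) / 3.
    #[test]
    fn apply_delta_merges_visit_counts_and_means() {
        let mut base = SearchNode::new(false);
        base.visits = 1;
        base.mean = 1.0;

        let mut updated = base.clone();
        updated.visits = 3;
        updated.mean = (1.0 + 4.0) / 3.0;

        let mut target = base.clone();
        target.apply_delta(&base, &updated);

        assert_eq!(target.visits, 3);
        assert!((target.mean - 5.0 / 3.0).abs() < 1e-9);
    }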
}