use crate::aixi::common::{
    Action, ObservationKeyMode, PerceptVal, Reward, observation_repr_from_stream,
};
use rayon::prelude::*;
use std::collections::HashMap;

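/// A sampled percept outcome: the observation-stream key paired with its
/// decoded reward. Chance nodes hash on this to index their children.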
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct PerceptOutcome {
    observations: Vec<PerceptVal>,
    reward: Reward,
}

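/// Interface the tree search uses to drive an agent's environment model:
/// action conditioning, percept sampling, reverting, and RNG access.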
pub trait AgentSimulator: Send + Sync {
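    /// Number of actions available to the agent.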
    fn get_num_actions(&self) -> usize;

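    /// Number of bits encoding a single observation percept.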
    fn get_num_observation_bits(&self) -> usize;

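    /// Number of observation percepts produced per step (defaults to one).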
    fn observation_stream_len(&self) -> usize {
        1
    }

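    /// How an observation stream is reduced to a key (defaults to the full stream).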
    fn observation_key_mode(&self) -> ObservationKeyMode {
        ObservationKeyMode::FullStream
    }

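    /// Reduces an observation stream to its key representation using the configured mode.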
    fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
        observation_repr_from_stream(
            self.observation_key_mode(),
            observations,
            self.get_num_observation_bits(),
        )
    }

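    /// Number of bits encoding the reward percept.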
    fn get_num_reward_bits(&self) -> usize;

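    /// Planning horizon: the number of action/percept cycles to look ahead.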
    fn horizon(&self) -> usize;

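    /// Largest possible single-step reward.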
    fn max_reward(&self) -> Reward;

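    /// Smallest possible single-step reward.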
    fn min_reward(&self) -> Reward;

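    /// Offset subtracted from the raw (unsigned) reward percept to recover the
    /// signed reward; defaults to 0.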
    fn reward_offset(&self) -> i64 {
        0
    }

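    /// Exploration constant used in the UCB action-selection term; defaults to 1.0.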
    fn get_explore_exploit_ratio(&self) -> f64 {
        1.0
    }

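    /// Discount factor gamma; defaults to 1.0 (undiscounted).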
    fn discount_gamma(&self) -> f64 {
        1.0
    }

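    /// Conditions the model on the given action.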
    fn model_update_action(&mut self, action: Action);

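    /// Samples a percept of the given bit width from the model and conditions
    /// the model on it.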
    fn gen_percept_and_update(&mut self, bits: usize) -> u64;

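    /// Reverts the model by the given number of steps.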
    fn model_revert(&mut self, steps: usize);

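    /// Samples a uniform index in `0..end`.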
    fn gen_range(&mut self, end: usize) -> usize;

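    /// Samples a random `f64` from the simulator's RNG (used below to derive
    /// worker seeds).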
    fn gen_f64(&mut self) -> f64;

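    /// Clones the simulator behind a trait object with a default seed.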
    fn boxed_clone(&self) -> Box<dyn AgentSimulator> {
        self.boxed_clone_with_seed(0)
    }

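    /// Clones the simulator behind a trait object, reseeding its RNG.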
    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator>;

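    /// Normalizes a cumulative discounted reward into [0, 1] using the best
    /// and worst returns achievable over the horizon.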
    fn norm_reward(&self, reward: f64) -> f64 {
        let min = self.min_reward() as f64;
        let max = self.max_reward() as f64;
        let h = self.horizon() as f64;
        let gamma = self.discount_gamma().clamp(0.0, 1.0);

        // Closed form of the geometric series 1 + gamma + ... + gamma^(h-1).
        let discount_sum = if (gamma - 1.0).abs() < 1e-9 {
            h
        } else {
            (1.0 - gamma.powi(h as i32)) / (1.0 - gamma)
        };

        let range = (max - min) * discount_sum;
        let min_cumulative = min * discount_sum;

        if range.abs() < 1e-9 {
            0.5
        } else {
            (reward - min_cumulative) / range
        }
    }

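    /// Samples one full percept (observation stream, then reward), conditioning
    /// the model on each value, and returns the observation key and signed reward.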
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.get_num_observation_bits();
        let obs_len = self.observation_stream_len().max(1);
        let mut observations = Vec::with_capacity(obs_len);
        for _ in 0..obs_len {
            observations.push(self.gen_percept_and_update(obs_bits));
        }

        let obs_key = self.observation_repr_from_stream(&observations);
        let rew_bits = self.get_num_reward_bits();
        let rew_u = self.gen_percept_and_update(rew_bits);
        let rew = (rew_u as i64) - self.reward_offset();
        (obs_key, rew)
    }
}

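/// A node in the MCTS tree. Decision nodes branch on the agent's actions;
/// chance nodes branch on sampled percept outcomes.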
#[derive(Clone)]
pub struct SearchNode {
    /// Number of samples that have passed through this node.
    visits: u32,
    /// Mean discounted return observed from this node.
    mean: f64,
    /// Chance nodes hold percept children; decision nodes hold action children.
    is_chance_node: bool,
    /// Children of a decision node, indexed by action.
    action_children: Vec<Option<SearchNode>>,
    /// Children of a chance node, keyed by percept outcome.
    percept_children: HashMap<PerceptOutcome, SearchNode>,
}

impl SearchNode {
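    /// Creates an empty node with no visits.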
    pub fn new(is_chance_node: bool) -> Self {
        Self {
            visits: 0,
            mean: 0.0,
            is_chance_node,
            action_children: Vec::new(),
            percept_children: HashMap::new(),
        }
    }

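    /// Returns the child action with the highest mean return, breaking ties
    /// uniformly at random; falls back to action 0 if no child exists.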
    pub fn best_action(&self, agent: &mut dyn AgentSimulator) -> Action {
        let mut best_actions = Vec::new();
        let mut best_mean = -f64::INFINITY;

        for (action, child) in self.action_children.iter().enumerate() {
            let Some(child) = child.as_ref() else {
                continue;
            };
            let mean = child.mean;
            if mean > best_mean {
                best_mean = mean;
                best_actions.clear();
                best_actions.push(action as u64);
            } else if (mean - best_mean).abs() < 1e-9 {
                best_actions.push(action as u64);
            }
        }

        if best_actions.is_empty() {
            return 0;
        }

        let idx = agent.gen_range(best_actions.len());
        best_actions[idx] as Action
    }

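    /// Current value estimate of this node.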
    fn expectation(&self) -> f64 {
        self.mean
    }

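    /// Merges the statistics gained between `base` (a shared snapshot) and
    /// `updated` (a worker's tree grown from that snapshot) into `self`,
    /// recursing over both kinds of children.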
    fn apply_delta(&mut self, base: &SearchNode, updated: &SearchNode) {
        if self.is_chance_node != base.is_chance_node
            || self.is_chance_node != updated.is_chance_node
        {
            return;
        }

        let base_visits = base.visits as f64;
        let updated_visits = updated.visits as f64;
        if updated_visits < base_visits {
            return;
        }

        // Fold the visit-weighted reward sum gained since the snapshot into
        // this node's running mean.
        let delta_visits = updated.visits - base.visits;
        if delta_visits > 0 {
            let base_sum = base.mean * base_visits;
            let updated_sum = updated.mean * updated_visits;
            let delta_sum = updated_sum - base_sum;
            let total_visits = self.visits + delta_visits;
            let total_sum = self.mean * (self.visits as f64) + delta_sum;
            self.visits = total_visits;
            self.mean = if total_visits > 0 {
                total_sum / (total_visits as f64)
            } else {
                0.0
            };
        }

        if self.is_chance_node {
            for (key, updated_child) in &updated.percept_children {
                if let Some(base_child) = base.percept_children.get(key) {
                    if let Some(self_child) = self.percept_children.get_mut(key) {
                        // Known to both the snapshot and self: merge the delta.
                        self_child.apply_delta(base_child, updated_child);
                    } else {
                        // Known to the snapshot but missing locally: materialize
                        // it from the worker's full statistics.
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        self.percept_children.insert(key.clone(), child);
                    }
                } else if let Some(self_child) = self.percept_children.get_mut(key) {
                    // New since the snapshot, but self already has it (e.g. from
                    // an earlier merge): the whole updated subtree is a delta
                    // against an empty base.
                    let empty = SearchNode::new(updated_child.is_chance_node);
                    self_child.apply_delta(&empty, updated_child);
                } else {
                    // Entirely new child: copy the worker's subtree wholesale.
                    let mut child = SearchNode::new(updated_child.is_chance_node);
                    child.apply_delta(
                        &SearchNode::new(updated_child.is_chance_node),
                        updated_child,
                    );
                    self.percept_children.insert(key.clone(), child);
                }
            }
        } else {
            let max_len = base
                .action_children
                .len()
                .max(updated.action_children.len());
            if self.action_children.len() < max_len {
                self.action_children.resize_with(max_len, || None);
            }
            for idx in 0..max_len {
                let base_child = base.action_children.get(idx).and_then(|c| c.as_ref());
                let updated_child = updated.action_children.get(idx).and_then(|c| c.as_ref());
                let Some(updated_child) = updated_child else {
                    continue;
                };
                // Same four cases as above, over the action-indexed slots.
                match (base_child, self.action_children.get_mut(idx)) {
                    (Some(base_child), Some(Some(self_child))) => {
                        self_child.apply_delta(base_child, updated_child);
                    }
                    (Some(base_child), Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(base_child, updated_child);
                        *slot = Some(child);
                    }
                    (None, Some(Some(self_child))) => {
                        let empty = SearchNode::new(updated_child.is_chance_node);
                        self_child.apply_delta(&empty, updated_child);
                    }
                    (None, Some(slot @ None)) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        *slot = Some(child);
                    }
                    _ => {}
                }
            }
        }
    }

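    /// UCB1-style action selection: expands a random unvisited action first,
    /// otherwise maximizes normalized exploitation plus an exploration bonus,
    /// then conditions the model on the chosen action.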
    fn select_action(
        &mut self,
        agent: &mut dyn AgentSimulator,
        _horizon: usize,
    ) -> (&mut SearchNode, Action) {
        let num_actions = agent.get_num_actions();

        if self.action_children.len() < num_actions {
            self.action_children.resize_with(num_actions, || None);
        }

        let mut unvisited = Vec::new();
        for a in 0..num_actions {
            if self.action_children[a].is_none() {
                unvisited.push(a as u64);
            }
        }

        let action;
        if !unvisited.is_empty() {
            // Expand a uniformly chosen unvisited action.
            let idx = agent.gen_range(unvisited.len());
            action = unvisited[idx];
            self.action_children[action as usize] = Some(SearchNode::new(true));
        } else {
            let c = agent.get_explore_exploit_ratio();
            let mut best_val = -f64::INFINITY;
            let mut best_action = 0;
            // ln(0) is -inf; clamp so a fresh node yields a zero bonus term.
            let log_visits = (self.visits as f64).ln().max(0.0);
            for (a, child) in self.action_children.iter().enumerate() {
                let Some(child) = child.as_ref() else {
                    continue;
                };
                let exploit = agent.norm_reward(child.expectation());
                let explore = if child.visits > 0 {
                    c * (log_visits / child.visits as f64).sqrt()
                } else {
                    f64::INFINITY
                };
                let val = exploit + explore;
                if val > best_val {
                    best_val = val;
                    best_action = a as u64;
                }
            }
            action = best_action;
        }

        agent.model_update_action(action as Action);
        (
            self.action_children[action as usize]
                .as_mut()
                .expect("missing action child"),
            action as Action,
        )
    }

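    /// Runs one search sample from this node, alternating decision and chance
    /// levels, playing out from unvisited leaves, and backing up the
    /// discounted return.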
    pub fn sample(
        &mut self,
        agent: &mut dyn AgentSimulator,
        horizon: usize,
        total_horizon: usize,
    ) -> f64 {
        if horizon == 0 {
            agent.model_revert(total_horizon);
            return 0.0;
        }

        let reward;
        if self.is_chance_node {
            // Sample a percept and descend into (or create) the matching decision node.
            let (obs, rew) = agent.gen_percepts_and_update();
            let key = PerceptOutcome {
                observations: obs,
                reward: rew,
            };
            let child = self
                .percept_children
                .entry(key)
                .or_insert_with(|| SearchNode::new(false));
            reward = (rew as f64)
                + agent.discount_gamma() * child.sample(agent, horizon - 1, total_horizon);
        } else if self.visits == 0 {
            // Unvisited decision node: estimate its value with a random playout.
            reward = Self::playout(agent, horizon, total_horizon);
        } else {
            let (child, _act) = self.select_action(agent, horizon);
            reward = child.sample(agent, horizon, total_horizon);
        }

        // Incremental update of the running mean return.
        self.mean = (reward + (self.visits as f64) * self.mean) / ((self.visits + 1) as f64);
        self.visits += 1;

        reward
    }

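    /// Random rollout: plays uniformly random actions out to the horizon, sums
    /// the discounted rewards, then reverts the model to its pre-sample state.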
    fn playout(agent: &mut dyn AgentSimulator, horizon: usize, total_horizon: usize) -> f64 {
        let mut total_rew = 0.0;
        let num_actions = agent.get_num_actions();
        let gamma = agent.discount_gamma().clamp(0.0, 1.0);
        let mut discount = 1.0;

        for _ in 0..horizon {
            let act = agent.gen_range(num_actions);
            agent.model_update_action(act as Action);
            let (_key, rew) = agent.gen_percepts_and_update();
            total_rew += discount * (rew as f64);
            discount *= gamma;
        }

        agent.model_revert(total_horizon);
        total_rew
    }
}

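/// Search tree that persists across agent steps, reusing the subtree
/// consistent with the action taken and the percept observed.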
pub struct SearchTree {
    /// Root decision node; taken and replaced during pruning.
    root: Option<SearchNode>,
}

impl SearchTree {
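    /// Creates a tree with a fresh decision-node root.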
    pub fn new() -> Self {
        Self {
            root: Some(SearchNode::new(false)),
        }
    }

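    /// Prunes the tree with the previous step's action and percept, runs the
    /// requested number of samples (in parallel when the thread pool and
    /// budget allow), and returns the best action at the root.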
    pub fn search(
        &mut self,
        agent: &mut dyn AgentSimulator,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: u64,
        samples: usize,
    ) -> Action {
        self.prune_tree(agent, prev_obs_stream, prev_rew, prev_act);

        let root = self.root.as_mut().unwrap();
        let h = agent.horizon();
        let threads = rayon::current_num_threads().max(1);
        if samples < 2 || threads < 2 {
            // Serial fallback for tiny budgets or single-threaded pools.
            for _ in 0..samples {
                root.sample(agent, h, h);
            }
            return root.best_action(agent);
        }

        // Root parallelism: each worker grows a clone of the current root, and
        // the deltas against the snapshot are merged back afterwards.
        let workers = threads.min(samples);
        let base = samples / workers;
        let extra = samples % workers;
        let snapshot = root.clone();

        // Derive a distinct seed per worker from the agent's RNG.
        let mut agents = Vec::with_capacity(workers);
        for i in 0..workers {
            let seed = agent.gen_f64().to_bits() ^ (i as u64);
            agents.push(agent.boxed_clone_with_seed(seed));
        }

        let results: Vec<SearchNode> = agents
            .into_par_iter()
            .enumerate()
            .map(|(i, mut local_agent)| {
                let mut local_root = snapshot.clone();
                // Spread the remainder over the first `extra` workers.
                let iterations = base + usize::from(i < extra);
                for _ in 0..iterations {
                    local_root.sample(local_agent.as_mut(), h, h);
                }
                local_root
            })
            .collect();

        for local in &results {
            root.apply_delta(&snapshot, local);
        }

        root.best_action(agent)
    }

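    /// Moves the root down through the edge for the action actually taken and
    /// the percept actually observed; resets to a fresh root if either edge is
    /// missing.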
    fn prune_tree(
        &mut self,
        agent: &mut dyn AgentSimulator,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: u64,
    ) {
        if self.root.is_none() {
            self.root = Some(SearchNode::new(false));
            return;
        }

        let mut old_root = self.root.take().unwrap();

        // Follow the edge for the action that was actually taken.
        let action_child_opt = if old_root.action_children.len() > prev_act as usize {
            old_root.action_children[prev_act as usize].take()
        } else {
            None
        };

        if let Some(mut chance_child) = action_child_opt {
            // Follow the edge for the percept that was actually observed.
            let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
            let key = PerceptOutcome {
                observations: obs_repr,
                reward: prev_rew,
            };

            if let Some(next_root) = chance_child.percept_children.remove(&key) {
                self.root = Some(next_root);
            } else {
                self.root = Some(SearchNode::new(false));
            }
        } else {
            self.root = Some(SearchNode::new(false));
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::aixi::common::ObservationKeyMode;

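    /// Minimal deterministic simulator: all percepts and RNG draws are zero.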
    #[derive(Clone)]
    struct DummyAgent {
        obs_bits: usize,
        rew_bits: usize,
        horizon: usize,
        min_reward: Reward,
        max_reward: Reward,
        key_mode: ObservationKeyMode,
    }

    impl DummyAgent {
        fn new(obs_bits: usize, key_mode: ObservationKeyMode) -> Self {
            Self {
                obs_bits,
                rew_bits: 8,
                horizon: 5,
                min_reward: -1,
                max_reward: 1,
                key_mode,
            }
        }
    }

    impl AgentSimulator for DummyAgent {
        fn get_num_actions(&self) -> usize {
            4
        }

        fn get_num_observation_bits(&self) -> usize {
            self.obs_bits
        }

        fn observation_key_mode(&self) -> ObservationKeyMode {
            self.key_mode
        }

        fn get_num_reward_bits(&self) -> usize {
            self.rew_bits
        }

        fn horizon(&self) -> usize {
            self.horizon
        }

        fn max_reward(&self) -> Reward {
            self.max_reward
        }

        fn min_reward(&self) -> Reward {
            self.min_reward
        }

        fn model_update_action(&mut self, _action: Action) {}

        fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
            0
        }

        fn model_revert(&mut self, _steps: usize) {}

        fn gen_range(&mut self, _end: usize) -> usize {
            0
        }

        fn gen_f64(&mut self) -> f64 {
            0.0
        }

        fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
            Box::new(self.clone())
        }
    }

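    /// Builds a tree whose root has a chance child under `prev_act` holding a
    /// single decision node keyed by the given percept outcome.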
    fn build_tree_with_key(
        agent: &DummyAgent,
        prev_act: u64,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        kept_mean: f64,
        kept_visits: u32,
    ) -> SearchTree {
        let mut old_root = SearchNode::new(false);
        old_root.action_children.resize(prev_act as usize + 1, None);

        let mut chance_child = SearchNode::new(true);
        let mut kept = SearchNode::new(false);
        kept.mean = kept_mean;
        kept.visits = kept_visits;

        let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
        let key = PerceptOutcome {
            observations: obs_repr,
            reward: prev_rew,
        };
        chance_child.percept_children.insert(key, kept);

        old_root.action_children[prev_act as usize] = Some(chance_child);
        SearchTree {
            root: Some(old_root),
        }
    }

    #[test]
    fn prune_tree_keeps_matching_subtree() {
        let prev_act = 2u64;
        let prev_obs_stream = vec![9u64, 2u64, 7u64];
        let prev_rew: Reward = 3;

        let mut agent = DummyAgent::new(3, ObservationKeyMode::FullStream);
        let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew, 123.0, 7);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().expect("root should exist");
        assert!(!root.is_chance_node);
        assert_eq!(root.mean, 123.0);
        assert_eq!(root.visits, 7);
    }

    #[test]
    fn prune_tree_resets_when_action_missing() {
        let prev_act = 10u64;
        let prev_obs_stream = vec![1u64];
        let prev_rew: Reward = 0;

        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        let mut tree = SearchTree::new();

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }

    #[test]
    fn prune_tree_resets_when_percept_key_missing() {
        let prev_act = 0u64;
        let prev_obs_stream = vec![1u64, 2u64];
        let prev_rew: Reward = 1;

        let mut agent = DummyAgent::new(4, ObservationKeyMode::Last);

        // Stored under a different reward, so the percept key will not match.
        let mut tree =
            build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew + 1, 9.0, 2);

        tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);

        let root = tree.root.as_ref().unwrap();
        assert!(!root.is_chance_node);
        assert_eq!(root.visits, 0);
        assert_eq!(root.mean, 0.0);
    }
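
    // Two sketch tests covering arithmetic that is fully determined by the
    // implementations above: reward normalization, and the snapshot/delta
    // merge used by the parallel search path.

    #[test]
    fn norm_reward_maps_cumulative_reward_to_unit_interval() {
        // DummyAgent: min = -1, max = 1, horizon = 5, gamma = 1.0 (default),
        // so achievable returns span [-5, 5].
        let agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        assert_eq!(agent.norm_reward(-5.0), 0.0);
        assert_eq!(agent.norm_reward(0.0), 0.5);
        assert_eq!(agent.norm_reward(5.0), 1.0);
    }

    #[test]
    fn apply_delta_merges_visit_weighted_means() {
        let mut snapshot = SearchNode::new(false);
        snapshot.visits = 2;
        snapshot.mean = 1.0;

        // A worker grew its copy by 3 visits, lifting the mean to 2.0
        // (reward sum 2.0 -> 10.0).
        let mut updated = snapshot.clone();
        updated.visits = 5;
        updated.mean = 2.0;

        // Merging that delta into another copy of the snapshot must reproduce
        // the worker's totals: 5 visits, sum 10.0, mean 2.0.
        let mut root = snapshot.clone();
        root.apply_delta(&snapshot, &updated);
        assert_eq!(root.visits, 5);
        assert_eq!(root.mean, 2.0);
    }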
}