infotheory/aixi/
environment.rs

//! Standard benchmark environments for AIXI.
//!
//! This module provides a set of environments for testing and evaluating
//! AIXI agents. Each environment implements the `Environment` trait,
//! providing a consistent interface for interaction.

use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};

/// Interface for an agent's environment.
///
/// An environment consumes actions from the agent and produces percepts
/// (observations and rewards) in response.
pub trait Environment {
    /// Executes an action in the environment and updates its internal state.
    fn perform_action(&mut self, action: Action);

    /// Returns the current observation produced by the environment.
    fn get_observation(&self) -> PerceptVal;

    /// Returns a stream of observation symbols produced by the last action.
    ///
    /// The default behavior is to return a single observation.
    fn drain_observations(&mut self) -> Vec<PerceptVal> {
        vec![self.get_observation()]
    }

    /// Returns the current reward produced by the environment.
    fn get_reward(&self) -> Reward;

    /// Returns true if the environment has reached a terminal state.
    fn is_finished(&self) -> bool;

    /// Returns the number of bits used to encode observations in this environment.
    fn get_observation_bits(&self) -> usize;

    /// Returns the number of bits used to encode rewards in this environment.
    fn get_reward_bits(&self) -> usize;

    /// Returns the number of bits required to represent all possible actions.
    fn get_action_bits(&self) -> usize;

    /// Reseeds the environment RNG for deterministic, reproducible runs.
    ///
    /// Deterministic environments can ignore this. Stochastic environments
    /// should reseed and reset any stochastic state so the initial percept
    /// sequence is reproducible from `seed`.
    fn set_random_seed(&mut self, _seed: u64) {}

    /// Returns the total number of valid actions available.
    fn get_num_actions(&self) -> usize {
        1 << self.get_action_bits()
    }

    /// Returns the maximum possible reward value in this environment.
    ///
    /// The default derives the bound from a signed two's-complement encoding
    /// of `get_reward_bits()` bits; environments typically override it.
    fn max_reward(&self) -> Reward {
        let bits = self.get_reward_bits();
        if bits == 0 {
            return 0;
        }
        // Prevent overflow for bits >= 64.
        if bits >= 64 {
            i64::MAX
        } else {
            (1i64 << (bits - 1)) - 1
        }
    }

    /// Returns the minimum possible reward value in this environment.
    ///
    /// Uses the same signed two's-complement default as `max_reward`.
    fn min_reward(&self) -> Reward {
        let bits = self.get_reward_bits();
        if bits == 0 {
            return 0;
        }
        // Prevent overflow for bits >= 64.
        if bits >= 64 {
            i64::MIN
        } else {
            -(1i64 << (bits - 1))
        }
    }
}
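
// A minimal driver sketch showing how the `Environment` trait is intended to
// be used in an agent loop. The `run_episode` name and the `choose_action`
// closure are illustrative, not part of the reference API; the sketch assumes
// `Reward` is an integer alias, as the `i64` arithmetic in
// `max_reward`/`min_reward` suggests.
#[allow(dead_code)]
fn run_episode<E, F>(env: &mut E, mut choose_action: F, cycles: usize) -> Reward
where
    E: Environment,
    F: FnMut(PerceptVal, Reward) -> Action,
{
    let mut total = 0;
    for _ in 0..cycles {
        if env.is_finished() {
            break;
        }
        // The agent acts on the latest percept, then receives a new one.
        let action = choose_action(env.get_observation(), env.get_reward());
        env.perform_action(action);
        total += env.get_reward();
    }
    total
}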

/// A simple biased coin-flip environment.
///
/// The agent predicts the outcome of a coin flip. A correct prediction
/// earns a reward of 1; an incorrect one earns 0.
pub struct CoinFlip {
    /// Probability of the coin landing heads (1).
    p: f64,
    /// Current observation (coin face).
    obs: PerceptVal,
    /// Last reward received.
    rew: Reward,
    /// Internal RNG.
    rng: RandomGenerator,
}

impl CoinFlip {
    /// Creates a new `CoinFlip` environment with bias `p`.
    pub fn new(p: f64) -> Self {
        Self::new_with_seed(p, None)
    }

    /// Creates a new `CoinFlip` environment with an optional deterministic seed.
    pub fn new_with_seed(p: f64, seed: Option<u64>) -> Self {
        let mut env = Self {
            p,
            obs: 0,
            rew: 0,
            rng: seed.map(RandomGenerator::from_seed).unwrap_or_default(),
        };
        // Generate the initial observation.
        env.gen_next();
        env
    }

    /// Flips the coin and stores the outcome as the current observation.
    fn gen_next(&mut self) {
        self.obs = if self.rng.gen_bool(self.p) { 1 } else { 0 };
    }
}

impl Environment for CoinFlip {
    fn perform_action(&mut self, action: Action) {
        self.gen_next();
        self.rew = if action == self.obs { 1 } else { 0 };
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        1
    }
    fn get_reward_bits(&self) -> usize {
        1
    }

    fn min_reward(&self) -> Reward {
        0
    }

    fn max_reward(&self) -> Reward {
        1
    }
    fn get_action_bits(&self) -> usize {
        1
    }

    fn set_random_seed(&mut self, seed: u64) {
        self.rng = RandomGenerator::from_seed(seed);
        self.rew = 0;
        self.gen_next();
    }
}
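
// A sanity-check sketch for the seeding contract (illustrative test names,
// not part of a reference suite): with the same seed, two `CoinFlip`
// instances must produce identical percept streams.
#[cfg(test)]
mod coin_flip_tests {
    use super::*;

    #[test]
    fn seeded_runs_are_reproducible() {
        let mut a = CoinFlip::new_with_seed(0.7, Some(42));
        let mut b = CoinFlip::new_with_seed(0.7, Some(42));
        for _ in 0..100 {
            a.perform_action(1);
            b.perform_action(1);
            assert_eq!(a.get_observation(), b.get_observation());
            assert_eq!(a.get_reward(), b.get_reward());
        }
    }
}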

/// A synthetic environment for testing CTW performance.
///
/// Generates a deterministic sequence designed to be perfectly
/// predictable by a sufficiently deep Context Tree.
pub struct CtwTest {
    cycle: usize,
    last_action: Action,
    obs: PerceptVal,
    rew: Reward,
}

impl CtwTest {
    /// Creates a new `CtwTest` environment.
    pub fn new() -> Self {
        Self {
            cycle: 0,
            last_action: 0,
            obs: 0,
            rew: 0,
        }
    }
}

impl Default for CtwTest {
    fn default() -> Self {
        Self::new()
    }
}

impl Environment for CtwTest {
    fn perform_action(&mut self, action: Action) {
        // The first observation is 0; every later observation is the
        // complement of the previous action. Reward is 1 iff the agent's
        // action matches the observation.
        if self.cycle == 0 {
            self.obs = 0;
        } else {
            self.obs = (self.last_action + 1) % 2;
        }
        self.rew = if self.obs == action { 1 } else { 0 };
        self.last_action = action;
        self.cycle += 1;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        1
    }
    fn get_reward_bits(&self) -> usize {
        1
    }

    fn min_reward(&self) -> Reward {
        0
    }

    fn max_reward(&self) -> Reward {
        1
    }
    fn get_action_bits(&self) -> usize {
        1
    }
}
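
// A worked check of the pattern above (illustrative test, not from the
// reference suite): since each observation is the complement of the previous
// action, an agent that alternates 0, 1, 0, 1, ... earns reward 1 every cycle.
#[cfg(test)]
mod ctw_test_tests {
    use super::*;

    #[test]
    fn alternating_policy_is_optimal() {
        let mut env = CtwTest::new();
        for t in 0..50 {
            let action: Action = if t % 2 == 0 { 0 } else { 1 };
            env.perform_action(action);
            assert_eq!(env.get_reward(), 1);
        }
    }
}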

/// A Rock-Paper-Scissors environment with a biased opponent.
///
/// The opponent plays uniformly at random, except that after winning a
/// round with rock it plays rock again.
pub struct BiasedRockPaperScissor {
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl BiasedRockPaperScissor {
    /// Creates a new `BiasedRockPaperScissor` environment.
    pub fn new() -> Self {
        Self::new_with_seed(None)
    }

    /// Creates a new `BiasedRockPaperScissor` environment with an optional seed.
    pub fn new_with_seed(seed: Option<u64>) -> Self {
        Self {
            // Match reference MC-AIXI/PyAIXI initial percept: non-rock.
            obs: 1,
            rew: 0,
            rng: seed.map(RandomGenerator::from_seed).unwrap_or_default(),
        }
    }
}

impl Default for BiasedRockPaperScissor {
    fn default() -> Self {
        Self::new()
    }
}

impl Environment for BiasedRockPaperScissor {
    fn perform_action(&mut self, action: Action) {
        // Actions: 0: Rock, 1: Paper, 2: Scissors.
        // Match reference MC-AIXI/PyAIXI bias: repeat rock iff opponent won
        // the previous round by playing rock.
        let opponent_action = if self.obs == 0 && self.rew == -1 {
            0
        } else {
            let r = self.rng.gen_f64();
            if r < 1.0 / 3.0 {
                0
            } else if r < 2.0 / 3.0 {
                1
            } else {
                2
            }
        };

        // Determine the outcome.
        if opponent_action == action {
            self.rew = 0; // Draw
        } else if (opponent_action == 0 && action == 1)
            || (opponent_action == 1 && action == 2)
            || (opponent_action == 2 && action == 0)
        {
            self.rew = 1; // Win
        } else {
            self.rew = -1; // Loss
        }
        self.obs = opponent_action as PerceptVal;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        2
    }
    fn get_reward_bits(&self) -> usize {
        2
    }

    fn min_reward(&self) -> Reward {
        -1
    }

    fn max_reward(&self) -> Reward {
        1
    }
    fn get_action_bits(&self) -> usize {
        2
    }
    fn get_num_actions(&self) -> usize {
        3
    }

    fn set_random_seed(&mut self, seed: u64) {
        self.rng = RandomGenerator::from_seed(seed);
        // Match reference initial condition after reseed.
        self.obs = 1;
        self.rew = 0;
    }
}
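
// A sketch of the exploit the bias admits (illustrative test): whenever the
// opponent has just won with rock (obs == 0, rew == -1), it will play rock
// again, so paper is a guaranteed win on the next round.
#[cfg(test)]
mod rps_tests {
    use super::*;

    #[test]
    fn rock_repeats_after_a_rock_win() {
        let mut env = BiasedRockPaperScissor::new_with_seed(Some(7));
        for _ in 0..1000 {
            if env.get_observation() == 0 && env.get_reward() == -1 {
                env.perform_action(1); // Paper beats the repeated rock.
                assert_eq!(env.get_observation(), 0);
                assert_eq!(env.get_reward(), 1);
            } else {
                // Scissors loses to rock, setting up the biased repeat.
                env.perform_action(2);
            }
        }
    }
}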

/// A more complex version of the classic Tiger problem.
///
/// Includes states for sitting and standing, with different rewards
/// and transition probabilities.
pub struct ExtendedTiger {
    state: usize, // 0: sitting, 1: standing
    /// Door (1 or 2) hiding the tiger.
    tiger_door: usize,
    /// Door (1 or 2) hiding the gold.
    gold_door: usize,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl ExtendedTiger {
    /// Creates a new `ExtendedTiger` environment.
    pub fn new() -> Self {
        let mut rng = RandomGenerator::new();
        let gold_door = if rng.gen_bool(0.5) { 1 } else { 2 };
        // The tiger is always behind the other door.
        let tiger_door = if gold_door == 1 { 2 } else { 1 };

        Self {
            state: 0,
            gold_door,
            tiger_door,
            obs: 0,
            rew: 0,
            rng,
        }
    }

    /// Randomly reassigns the gold and the tiger to the two doors.
    fn reset_doors(&mut self) {
        self.gold_door = if self.rng.gen_bool(0.5) { 1 } else { 2 };
        self.tiger_door = if self.gold_door == 1 { 2 } else { 1 };
    }
}
374impl Default for ExtendedTiger {
375    fn default() -> Self {
376        Self::new()
377    }
378}
379
380impl Environment for ExtendedTiger {
381    fn perform_action(&mut self, action: Action) {
382        // Actions: 0: Stand, 1: Listen, 2: Open 1, 3: Open 2
383        match action {
384            0 => {
385                // Stand
386                if self.state == 1 {
387                    self.rew = -1;
388                } else {
389                    self.state = 1;
390                    self.rew = -1;
391                    if self.obs < 4 {
392                        self.obs += 4;
393                    }
394                }
395            }
396            1 => {
397                // Listen
398                if self.state == 1 || self.obs != 0 {
399                    self.rew = -1;
400                    self.obs = 0;
401                } else {
402                    self.obs = if self.rng.gen_bool(0.85) {
403                        self.tiger_door as PerceptVal
404                    } else {
405                        self.gold_door as PerceptVal
406                    };
407                    self.rew = -1;
408                }
409            }
410            2 => {
411                // Open 1
412                if self.state == 0 {
413                    self.rew = -100;
414                } else {
415                    self.rew = if self.gold_door == 1 { 30 } else { -100 };
416                    self.obs = 0;
417                    self.state = 0;
418                    self.reset_doors();
419                }
420            }
421            3 => {
422                // Open 2
423                if self.state == 0 {
424                    self.rew = -100;
425                } else {
426                    self.rew = if self.gold_door == 2 { 30 } else { -100 };
427                    self.obs = 0;
428                    self.state = 0;
429                    self.reset_doors();
430                }
431            }
432            _ => {
433                self.rew = -100;
434            }
435        }
436    }
437
438    fn get_observation(&self) -> PerceptVal {
439        self.obs
440    }
441    fn get_reward(&self) -> Reward {
442        self.rew
443    }
444    fn is_finished(&self) -> bool {
445        false
446    }
447
448    fn get_observation_bits(&self) -> usize {
449        3
450    }
451    fn get_reward_bits(&self) -> usize {
452        8
453    }
454
455    fn min_reward(&self) -> Reward {
456        -100
457    }
458
459    fn max_reward(&self) -> Reward {
460        30
461    }
462    fn get_action_bits(&self) -> usize {
463        2
464    }
465    fn get_num_actions(&self) -> usize {
466        4
467    }
468
469    fn set_random_seed(&mut self, seed: u64) {
470        self.rng = RandomGenerator::from_seed(seed);
471        self.state = 0;
472        self.obs = 0;
473        self.rew = 0;
474        self.reset_doors();
475    }
476}
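
// A worked protocol example (illustrative test): listen while sitting, stand,
// then open the door the listening evidence points away from. Listening is
// only 85% reliable, so a single round can still lose; the test only checks
// that the round ends in one of the two legal open-door outcomes.
#[cfg(test)]
mod extended_tiger_tests {
    use super::*;

    #[test]
    fn listen_stand_open_protocol() {
        let mut env = ExtendedTiger::new();
        env.set_random_seed(123);
        env.perform_action(1); // Listen: hear door 1 or 2, reward -1.
        let heard = env.get_observation();
        env.perform_action(0); // Stand: reward -1.
        // The heard door is most likely the tiger's, so open the other one.
        let open: Action = if heard == 1 { 3 } else { 2 };
        env.perform_action(open);
        assert!(env.get_reward() == 30 || env.get_reward() == -100);
    }
}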

/// A standard Tic-Tac-Toe environment against a random opponent.
pub struct TicTacToe {
    board: [i8; 9], // 0: empty, 1: agent, -1: opponent.
    /// Indices of the squares that are still empty.
    open_squares: Vec<usize>,
    /// Base-4 board encoding: two bits per square (0: empty, 1: agent, 2: opponent).
    state: u64,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl TicTacToe {
    /// Creates a new `TicTacToe` environment.
    pub fn new() -> Self {
        Self {
            board: [0; 9],
            open_squares: (0..9).collect(),
            state: 0,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        }
    }

    fn reset_game(&mut self) {
        self.board = [0; 9];
        self.open_squares = (0..9).collect();
        self.state = 0;
    }

    fn check_win(&self, player: i8) -> bool {
        let b = self.board;
        let wins = [
            (0, 1, 2),
            (3, 4, 5),
            (6, 7, 8), // Rows
            (0, 3, 6),
            (1, 4, 7),
            (2, 5, 8), // Cols
            (0, 4, 8),
            (2, 4, 6), // Diags
        ];
        for &(x, y, z) in &wins {
            if b[x] == player && b[y] == player && b[z] == player {
                return true;
            }
        }
        false
    }
}

impl Default for TicTacToe {
    fn default() -> Self {
        Self::new()
    }
}

impl Environment for TicTacToe {
    fn perform_action(&mut self, action: Action) {
        if action >= 9 {
            // Out-of-range action.
            self.rew = -3;
            self.obs = self.state as PerceptVal;
            return;
        }

        if self.board[action as usize] != 0 {
            // Illegal move
            self.rew = -3;
        } else {
            // Agent move (1)
            self.state += 1 << (2 * action);
            self.board[action as usize] = 1;

            // Remove from open
            if let Some(pos) = self.open_squares.iter().position(|&x| x == action as usize) {
                self.open_squares.remove(pos);
            }

            self.rew = 0;

            if self.check_win(1) {
                // Agent won
                self.reset_game();
                self.rew = 2;
            } else if self.open_squares.is_empty() {
                // Draw
                self.reset_game();
                self.rew = 1;
            } else {
                // Opponent move (-1, mapped to 2 in base-4)

                // Pick a uniformly random open square for the opponent.
                let n = self.open_squares.len();
                if n > 0 {
                    let idx = self.rng.gen_range(n);
                    let opponent_move = self.open_squares[idx];

                    self.state += 2 << (2 * opponent_move);
                    self.board[opponent_move] = -1;

                    self.open_squares.remove(idx);

                    if self.check_win(-1) {
                        // Opponent won
                        self.reset_game();
                        self.rew = -2;
                    } else if self.open_squares.is_empty() {
                        // Draw
                        self.reset_game();
                        self.rew = 1;
                    }
                }
            }
        }
        self.obs = self.state as PerceptVal;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        18 // 9 squares * 2 bits.
    }
    fn get_reward_bits(&self) -> usize {
        3
    }
    fn min_reward(&self) -> Reward {
        -3
    }
    fn max_reward(&self) -> Reward {
        2
    }
    fn get_action_bits(&self) -> usize {
        4
    }
    fn get_num_actions(&self) -> usize {
        9
    }

    fn set_random_seed(&mut self, seed: u64) {
        self.rng = RandomGenerator::from_seed(seed);
        self.reset_game();
        self.obs = 0;
        self.rew = 0;
    }
}
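
// A small decoding sketch for the base-4 observation (hypothetical helper,
// assuming `PerceptVal` is an integer alias): two bits per square, in the
// same 0/1/2 encoding used by `perform_action` above.
#[allow(dead_code)]
fn decode_board(obs: PerceptVal) -> [u8; 9] {
    let mut board = [0u8; 9];
    for (i, cell) in board.iter_mut().enumerate() {
        // 0: empty, 1: agent, 2: opponent.
        *cell = ((obs >> (2 * i)) & 0b11) as u8;
    }
    board
}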

/// A two-player imperfect-information game: Kuhn Poker.
///
/// The agent plays a simplified three-card poker game against an opponent
/// that follows a fixed Nash-equilibrium-style strategy.
pub struct KuhnPoker {
    opponent_card: usize, // 0: Jack, 1: Queen, 2: King
    agent_card: usize,
    opponent_action: usize, // 0: bet, 1: pass
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl KuhnPoker {
    /// Creates a new `KuhnPoker` environment.
    pub fn new() -> Self {
        Self::new_with_seed(None)
    }

    /// Creates a new `KuhnPoker` environment with an optional deterministic seed.
    pub fn new_with_seed(seed: Option<u64>) -> Self {
        let mut env = Self {
            opponent_card: 0,
            agent_card: 0,
            opponent_action: 0,
            obs: 0,
            rew: 0,
            rng: seed.map(RandomGenerator::from_seed).unwrap_or_default(),
        };
        env.reset_game();
        env
    }

    #[inline]
    fn random_card(&mut self) -> usize {
        self.rng.gen_range(3)
    }

    fn reset_game(&mut self) {
        // Card encoding matches the reference implementations:
        // 0=Jack, 1=Queen, 2=King.
        self.agent_card = self.random_card();
        self.opponent_card = self.agent_card;
        while self.opponent_card == self.agent_card {
            self.opponent_card = self.random_card();
        }

        const ACTION_BET: usize = 0;
        const ACTION_PASS: usize = 1;
        const BET_PROB_KING: f64 = 0.7;
        const BET_PROB_JACK: f64 = BET_PROB_KING / 3.0;

        // Opponent first action (reference Nash policy).
        self.opponent_action = if self.opponent_card == 0 {
            if self.rng.gen_bool(BET_PROB_JACK) {
                ACTION_BET
            } else {
                ACTION_PASS
            }
        } else if self.opponent_card == 1 {
            ACTION_PASS
        } else if self.rng.gen_bool(BET_PROB_KING) {
            ACTION_BET
        } else {
            ACTION_PASS
        };

        // Observation encoding matches C++/PyAIXI:
        // observation = agent_card + (opponent_pass ? 4 : 0)
        let action_code = if self.opponent_action == ACTION_PASS {
            4
        } else {
            0
        };
        let card_code = self.agent_card;
        self.obs = (action_code + card_code) as PerceptVal;
    }
}

impl Default for KuhnPoker {
    fn default() -> Self {
        Self::new()
    }
}

impl Environment for KuhnPoker {
    fn perform_action(&mut self, action: Action) {
        const ACTION_BET: usize = 0;
        const ACTION_PASS: usize = 1;

        // Reference reward levels are encoded as {0,1,3,4}. We emit the
        // offset-removed values {-2,-1,1,2} for direct comparability.
        const R_BET_LOSS: Reward = -2;
        const R_PASS_LOSS: Reward = -1;
        const R_PASS_WIN: Reward = 1;
        const R_BET_WIN: Reward = 2;

        const BET_PROB_KING: f64 = 0.7;
        const BET_PROB_QUEEN: f64 = (1.0 + BET_PROB_KING) / 3.0;

        // Invalid actions are treated as a bet-sized loss.
        if action > 1 {
            self.rew = R_BET_LOSS;
            self.reset_game();
            return;
        }

        // If the agent did not call an opponent bet, the agent loses.
        if action as usize == ACTION_PASS && self.opponent_action == ACTION_BET {
            self.rew = R_PASS_LOSS;
            self.reset_game();
            return;
        }

        // If the opponent passed and the agent bet, the opponent may reconsider.
        if action as usize == ACTION_BET && self.opponent_action == ACTION_PASS {
            if self.opponent_card == 1 && self.rng.gen_bool(BET_PROB_QUEEN) {
                self.opponent_action = ACTION_BET;
            } else if self.opponent_card == 2 {
                self.opponent_action = ACTION_BET;
            } else {
                self.rew = R_PASS_WIN;
                self.reset_game();
                return;
            }
        }

        // Showdown.
        let agent_wins =
            self.opponent_card == 0 || (self.opponent_card == 1 && self.agent_card == 2);
        if agent_wins {
            self.rew = if self.opponent_action == ACTION_BET {
                R_BET_WIN
            } else {
                R_PASS_WIN
            };
        } else {
            self.rew = if action as usize == ACTION_BET {
                R_BET_LOSS
            } else {
                R_PASS_LOSS
            };
        }
        self.reset_game();
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        3
    }
    fn get_reward_bits(&self) -> usize {
        3
    }

    fn min_reward(&self) -> Reward {
        -2
    }

    fn max_reward(&self) -> Reward {
        2
    }
    fn get_action_bits(&self) -> usize {
        1 // Actions are 0 (bet) or 1 (pass).
    }
    fn get_num_actions(&self) -> usize {
        2
    }

    fn set_random_seed(&mut self, seed: u64) {
        self.rng = RandomGenerator::from_seed(seed);
        self.rew = 0;
        self.reset_game();
    }
}
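
// A decoding sketch for the percept layout above (hypothetical helper,
// assuming `PerceptVal` is an integer alias): the low two bits carry the
// agent's card and bit 2 is set iff the opponent's first action was a pass.
#[allow(dead_code)]
fn decode_kuhn_observation(obs: PerceptVal) -> (usize, bool) {
    let agent_card = (obs & 0b11) as usize; // 0: Jack, 1: Queen, 2: King.
    let opponent_passed = (obs & 0b100) != 0; // Encoded as `+ 4` in `reset_game`.
    (agent_card, opponent_passed)
}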