infotheory/aixi/
environment.rs

//! Standard benchmark environments for AIXI.
//!
//! This module provides a set of environments for testing and evaluating
//! AIXI agents. Each environment implements the `Environment` trait,
//! providing a consistent interface for interaction.

use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};

/// Interface for an agent's environment.
///
/// An environment consumes actions from the agent and produces percepts
/// (observations and rewards) in response.
pub trait Environment {
    /// Executes an action in the environment and updates its internal state.
    fn perform_action(&mut self, action: Action);

    /// Returns the current observation produced by the environment.
    fn get_observation(&self) -> PerceptVal;

    /// Returns the stream of observation symbols produced by the last action.
    ///
    /// The default implementation returns a single-element vector containing
    /// the current observation.
    fn drain_observations(&mut self) -> Vec<PerceptVal> {
        vec![self.get_observation()]
    }

    /// Returns the current reward produced by the environment.
    fn get_reward(&self) -> Reward;

    /// Returns true if the environment has reached a terminal state.
    fn is_finished(&self) -> bool;

    /// Returns the number of bits used to encode observations in this environment.
    fn get_observation_bits(&self) -> usize;

    /// Returns the number of bits used to encode rewards in this environment.
    fn get_reward_bits(&self) -> usize;

    /// Returns the number of bits required to represent all possible actions.
    fn get_action_bits(&self) -> usize;

    /// Returns the total number of valid actions available.
    fn get_num_actions(&self) -> usize {
        1 << self.get_action_bits()
    }

    /// Returns the maximum possible reward value in this environment.
    ///
    /// The default assumes rewards are signed two's-complement values over
    /// `get_reward_bits()` bits; environments with narrower ranges override it.
    fn max_reward(&self) -> Reward {
        let bits = self.get_reward_bits();
        if bits == 0 {
            return 0;
        }
        // Prevent overflow for bits >= 64.
        if bits >= 64 {
            i64::MAX
        } else {
            (1i64 << (bits - 1)) - 1
        }
    }

    /// Returns the minimum possible reward value in this environment.
    ///
    /// Uses the same signed two's-complement convention as `max_reward`.
    fn min_reward(&self) -> Reward {
        let bits = self.get_reward_bits();
        if bits == 0 {
            return 0;
        }
        // Prevent overflow for bits >= 64.
        if bits >= 64 {
            i64::MIN
        } else {
            -(1i64 << (bits - 1))
        }
    }
}
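
// A minimal sketch of the action/percept cycle the trait above supports: the
// agent submits an action, then reads back the resulting observation and
// reward. `run_random_agent` is an illustrative helper, not part of the crate
// API; it assumes `Action` is an integer alias (as the comparisons elsewhere
// in this module assume) and that `RandomGenerator::gen_range` draws a
// uniform index below its argument, as in `TicTacToe` below.
#[allow(dead_code)]
fn run_random_agent<E: Environment>(env: &mut E, rng: &mut RandomGenerator, cycles: usize) -> Reward {
    let mut total: Reward = 0;
    for _ in 0..cycles {
        if env.is_finished() {
            break;
        }
        // Pick a uniformly random valid action.
        let action = rng.gen_range(env.get_num_actions()) as Action;
        env.perform_action(action);
        total += env.get_reward();
    }
    total
}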

/// A simple biased coin flip environment.
///
/// The agent predicts the outcome of a coin flip. Correct predictions
/// result in a reward of 1, otherwise 0.
pub struct CoinFlip {
    /// Probability of the coin landing heads (1).
    p: f64,
    /// Current observation (coin face).
    obs: PerceptVal,
    /// Last reward received.
    rew: Reward,
    /// Internal RNG.
    rng: RandomGenerator,
}

impl CoinFlip {
    /// Creates a new `CoinFlip` environment with bias `p`.
    pub fn new(p: f64) -> Self {
        let mut env = Self {
            p,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        };
        // Initial observation
        env.gen_next();
        env
    }

    fn gen_next(&mut self) {
        self.obs = if self.rng.gen_bool(self.p) { 1 } else { 0 };
    }
}

impl Environment for CoinFlip {
    fn perform_action(&mut self, action: Action) {
        self.gen_next();
        self.rew = if action == self.obs { 1 } else { 0 };
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        1
    }
    fn get_reward_bits(&self) -> usize {
        1
    }

    fn min_reward(&self) -> Reward {
        0
    }

    fn max_reward(&self) -> Reward {
        1
    }
    fn get_action_bits(&self) -> usize {
        1
    }
}
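
// Illustrative usage of `CoinFlip` (a sketch, not part of the original test
// suite): after each action the reward is 1 exactly when the action matched
// the freshly generated coin face.
#[cfg(test)]
mod coin_flip_example {
    use super::*;

    #[test]
    fn reward_is_one_iff_prediction_matches() {
        let mut env = CoinFlip::new(0.7);
        for _ in 0..100 {
            env.perform_action(1);
            let expected = if env.get_observation() == 1 { 1 } else { 0 };
            assert_eq!(env.get_reward(), expected);
        }
    }
}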

/// A synthetic environment for testing CTW performance.
///
/// Generates a deterministic sequence designed to be perfectly
/// predictable by a sufficiently deep Context Tree.
pub struct CtwTest {
    cycle: usize,
    last_action: Action,
    obs: PerceptVal,
    rew: Reward,
}

impl CtwTest {
    /// Creates a new `CtwTest` environment.
    pub fn new() -> Self {
        Self {
            cycle: 0,
            last_action: 0,
            obs: 0,
            rew: 0,
        }
    }
}

impl Environment for CtwTest {
    fn perform_action(&mut self, action: Action) {
        if self.cycle == 0 {
            self.obs = 0;
            self.rew = if self.obs == action { 1 } else { 0 };
        } else {
            self.obs = (self.last_action + 1) % 2;
            self.rew = if self.obs == action { 1 } else { 0 };
        }
        self.last_action = action;
        self.cycle += 1;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        1
    }
    fn get_reward_bits(&self) -> usize {
        1
    }

    fn min_reward(&self) -> Reward {
        0
    }

    fn max_reward(&self) -> Reward {
        1
    }
    fn get_action_bits(&self) -> usize {
        1
    }
}
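
// Illustrative check for `CtwTest` (a sketch, not part of the original test
// suite): the next observation is always the complement of the previous
// action, so an agent that plays 0 first and then alternates is rewarded on
// every cycle.
#[cfg(test)]
mod ctw_test_example {
    use super::*;

    #[test]
    fn alternating_agent_is_always_rewarded() {
        let mut env = CtwTest::new();
        let mut action = 0;
        for _ in 0..50 {
            env.perform_action(action);
            assert_eq!(env.get_reward(), 1);
            action = (action + 1) % 2;
        }
    }
}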

/// A Rock-Paper-Scissors environment with a biased opponent.
///
/// The opponent plays uniformly at random unless it won the previous round,
/// in which case it repeats its winning action.
pub struct BiasedRockPaperScissor {
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
    opponent_won_last_round: bool,
    opponent_last_round_action: Action,
}

impl BiasedRockPaperScissor {
    /// Creates a new `BiasedRockPaperScissor` environment.
    pub fn new() -> Self {
        Self {
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
            opponent_won_last_round: false,
            opponent_last_round_action: 0,
        }
    }
}

impl Environment for BiasedRockPaperScissor {
    fn perform_action(&mut self, action: Action) {
        // action 0: Rock, 1: Paper, 2: Scissors
        // Opponent logic
        let opponent_action = if self.opponent_won_last_round {
            self.opponent_last_round_action
        } else {
            let r = self.rng.gen_f64();
            if r < 1.0 / 3.0 {
                0
            } else if r < 2.0 / 3.0 {
                1
            } else {
                2
            }
        };

        // Determine outcome
        if opponent_action == action {
            self.rew = 0; // Draw
            self.opponent_won_last_round = false;
        } else if (opponent_action == 0 && action == 1)
            || (opponent_action == 1 && action == 2)
            || (opponent_action == 2 && action == 0)
        {
            self.rew = 1; // Agent wins
            self.opponent_won_last_round = false;
        } else {
            self.rew = -1; // Agent loses
            self.opponent_won_last_round = true;
            self.opponent_last_round_action = opponent_action;
        }
        self.obs = opponent_action as PerceptVal;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        2
    }
    fn get_reward_bits(&self) -> usize {
        2
    }

    fn min_reward(&self) -> Reward {
        -1
    }

    fn max_reward(&self) -> Reward {
        1
    }
    fn get_action_bits(&self) -> usize {
        2
    }
    fn get_num_actions(&self) -> usize {
        3
    }
}
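
// Hedged sketch of the bias this opponent exposes: after the agent loses a
// round (reward -1), the opponent repeats the same move, so replying with its
// counter is guaranteed to win the next round. `counter_move` is an
// illustrative helper, not part of the crate API.
#[allow(dead_code)]
fn counter_move(opponent_action: Action) -> Action {
    // 0: Rock, 1: Paper, 2: Scissors; the returned move beats the given one.
    (opponent_action + 1) % 3
}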

/// A more complex version of the classic Tiger problem.
///
/// The agent can sit or stand, listen for the tiger, and open one of two
/// doors, with different rewards for each outcome.
pub struct ExtendedTiger {
    state: usize, // 0: sitting, 1: standing
    tiger_door: usize,
    gold_door: usize,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl ExtendedTiger {
    /// Creates a new `ExtendedTiger` environment.
    pub fn new() -> Self {
        let mut rng = RandomGenerator::new();
        let gold_door = if rng.gen_bool(0.5) { 1 } else { 2 };
        // The tiger is behind whichever of the two doors does not hide the gold.
        let tiger_door = if gold_door == 1 { 2 } else { 1 };

        Self {
            state: 0,
            gold_door,
            tiger_door,
            obs: 0,
            rew: 0,
            rng,
        }
    }

    fn reset_doors(&mut self) {
        self.gold_door = if self.rng.gen_bool(0.5) { 1 } else { 2 };
        self.tiger_door = if self.gold_door == 1 { 2 } else { 1 };
    }
}

impl Environment for ExtendedTiger {
    fn perform_action(&mut self, action: Action) {
        // Actions: 0: Stand, 1: Listen, 2: Open 1, 3: Open 2
        match action {
            0 => {
                // Stand
                if self.state == 1 {
                    self.rew = -1;
                } else {
                    self.state = 1;
                    self.rew = -1;
                    if self.obs < 4 {
                        self.obs += 4;
                    }
                }
            }
            1 => {
                // Listen
                if self.state == 1 || self.obs != 0 {
                    self.rew = -1;
                    self.obs = 0;
                } else {
                    self.obs = if self.rng.gen_bool(0.85) {
                        self.tiger_door as PerceptVal
                    } else {
                        self.gold_door as PerceptVal
                    };
                    self.rew = -1;
                }
            }
            2 => {
                // Open 1
                if self.state == 0 {
                    self.rew = -100;
                } else {
                    self.rew = if self.gold_door == 1 { 30 } else { -100 };
                    self.obs = 0;
                    self.state = 0;
                    self.reset_doors();
                }
            }
            3 => {
                // Open 2
                if self.state == 0 {
                    self.rew = -100;
                } else {
                    self.rew = if self.gold_door == 2 { 30 } else { -100 };
                    self.obs = 0;
                    self.state = 0;
                    self.reset_doors();
                }
            }
            _ => {
                self.rew = -100;
            }
        }
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        3
    }
    fn get_reward_bits(&self) -> usize {
        8
    }

    fn min_reward(&self) -> Reward {
        -100
    }

    fn max_reward(&self) -> Reward {
        30
    }
    fn get_action_bits(&self) -> usize {
        2
    }
    fn get_num_actions(&self) -> usize {
        4
    }
}
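
// Illustrative check for `ExtendedTiger` (a sketch, not part of the original
// test suite): opening a door while still sitting is penalised with -100,
// and standing up costs -1.
#[cfg(test)]
mod extended_tiger_example {
    use super::*;

    #[test]
    fn must_stand_before_opening() {
        let mut env = ExtendedTiger::new();
        env.perform_action(2); // Try to open door 1 while sitting.
        assert_eq!(env.get_reward(), -100);
        env.perform_action(0); // Stand up.
        assert_eq!(env.get_reward(), -1);
    }
}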

/// A standard Tic-Tac-Toe environment against a random opponent.
pub struct TicTacToe {
    board: [i8; 9], // 0: empty, 1: agent, -1: opponent.
    open_squares: Vec<usize>,
    state: u64,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl TicTacToe {
    /// Creates a new `TicTacToe` environment.
    pub fn new() -> Self {
        Self {
            board: [0; 9],
            open_squares: (0..9).collect(),
            state: 0,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        }
    }

    fn reset_game(&mut self) {
        self.board = [0; 9];
        self.open_squares = (0..9).collect();
        self.state = 0;
    }

    fn check_win(&self, player: i8) -> bool {
        let b = self.board;
        let wins = [
            (0, 1, 2),
            (3, 4, 5),
            (6, 7, 8), // Rows
            (0, 3, 6),
            (1, 4, 7),
            (2, 5, 8), // Cols
            (0, 4, 8),
            (2, 4, 6), // Diags
        ];
        for &(x, y, z) in &wins {
            if b[x] == player && b[y] == player && b[z] == player {
                return true;
            }
        }
        false
    }
}

impl Environment for TicTacToe {
    fn perform_action(&mut self, action: Action) {
        if action >= 9 {
            self.rew = -3;
            self.obs = self.state as PerceptVal;
            return;
        }

        if self.board[action as usize] != 0 {
            // Illegal move
            self.rew = -3;
        } else {
            // Agent move (1)
            self.state += 1 << (2 * action);
            self.board[action as usize] = 1;

            // Remove from open squares
            if let Some(pos) = self.open_squares.iter().position(|&x| x == action as usize) {
                self.open_squares.remove(pos);
            }

            self.rew = 0;

            if self.check_win(1) {
                // Agent won
                self.reset_game();
                self.rew = 2;
            } else if self.open_squares.is_empty() {
                // Draw
                self.reset_game();
                self.rew = 1;
            } else {
                // Opponent move (-1, mapped to 2 in base-4)

                // Pick a random open square
                let n = self.open_squares.len();
                if n > 0 {
                    let idx = self.rng.gen_range(n);
                    let opponent_move = self.open_squares[idx];

                    self.state += 2 << (2 * opponent_move);
                    self.board[opponent_move] = -1;

                    self.open_squares.remove(idx);

                    if self.check_win(-1) {
                        // Opponent won
                        self.reset_game();
                        self.rew = -2;
                    } else if self.open_squares.is_empty() {
                        // Draw
                        self.reset_game();
                        self.rew = 1;
                    }
                }
            }
        }
        self.obs = self.state as PerceptVal;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        18 // 9 squares * 2 bits
    }
    fn get_reward_bits(&self) -> usize {
        3
    }
    fn min_reward(&self) -> Reward {
        -3
    }
    fn max_reward(&self) -> Reward {
        2
    }
    fn get_action_bits(&self) -> usize {
        4
    }
    fn get_num_actions(&self) -> usize {
        9
    }
}
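
// Hedged sketch of the observation layout used above: each of the nine
// squares occupies two bits of the state (0: empty, 1: agent, 2: opponent).
// `square_of` is an illustrative decoder, not part of the crate API, and it
// assumes `PerceptVal` is an integer alias wide enough for 18 bits.
#[allow(dead_code)]
fn square_of(obs: PerceptVal, index: usize) -> u8 {
    ((obs >> (2 * index)) & 0b11) as u8
}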

/// A two-player imperfect-information game: Kuhn poker.
///
/// The agent acts second against an opponent that follows a Nash-equilibrium
/// strategy in a simplified three-card poker game.
pub struct KuhnPoker {
    opponent_card: usize, // 0: J, 1: Q, 2: K
    agent_card: usize,
    agent_chips: usize,
    chips_in_play: usize,
    alpha: f64,
    opponent_action: usize, // 0: pass, 1: bet
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl KuhnPoker {
    /// Creates a new `KuhnPoker` environment.
    pub fn new() -> Self {
        let mut env = Self {
            opponent_card: 0,
            agent_card: 0,
            agent_chips: 0,
            chips_in_play: 0,
            alpha: 0.0,
            opponent_action: 0,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        };
        env.reset_game();
        env
    }

    fn reset_game(&mut self) {
        // Deal the opponent a uniformly random card.
        let r = self.rng.gen_f64();
        self.opponent_card = if r < 1.0 / 3.0 {
            2
        } else if self.rng.gen_bool(0.5) {
            1
        } else {
            0
        };

        // Agent card: one of the remaining two
        let k = if self.rng.gen_bool(0.5) { 1 } else { 2 };
        self.agent_card = (self.opponent_card + k) % 3;

        self.agent_chips = 1;
        self.chips_in_play = 2; // Ante 1 each

        // Opponent strategy parameter (Nash equilibrium)
        self.alpha = self.rng.gen_f64() / 3.0; // alpha in [0, 1/3]

        // Opponent logic matching C++
        self.opponent_action = if self.opponent_card == 0 {
            // Jack
            if self.rng.gen_bool(self.alpha) { 1 } else { 0 }
        } else if self.opponent_card == 1 {
            // Queen
            0 // Always check
        } else {
            // King
            if self.rng.gen_bool(3.0 * self.alpha) {
                1
            } else {
                0
            }
        };

        if self.opponent_action == 1 {
            self.chips_in_play += 1;
        }

        let card_code = 1 << self.agent_card;
        let action_code = if self.opponent_action == 1 { 8 } else { 0 };
        self.obs = (action_code + card_code) as PerceptVal;
    }
}

impl Environment for KuhnPoker {
    fn perform_action(&mut self, action: Action) {
        // Action: 0: Pass, 1: Bet
        self.agent_chips += action as usize;
        self.chips_in_play += action as usize;

        let opponent_bets = self.opponent_action == 1;
        let agent_bets = action == 1;

        if opponent_bets == agent_bets {
            // Showdown
            if self.agent_card > self.opponent_card {
                self.rew = self.chips_in_play as i64;
            } else {
                self.rew = -(self.agent_chips as i64);
            }
            self.reset_game();
        } else if opponent_bets && !agent_bets {
            // Opponent bet, agent folds
            self.rew = -(self.agent_chips as i64);
            self.reset_game();
        } else {
            // Opponent passed, agent bet: the opponent decides to call or fold.
            let call = self.rng.gen_bool(self.alpha + 1.0 / 3.0);
            if call {
                self.chips_in_play += 1;
                if self.agent_card > self.opponent_card {
                    self.rew = self.chips_in_play as i64;
                } else {
                    self.rew = -(self.agent_chips as i64);
                }
            } else {
                // Opponent folds
                self.rew = self.chips_in_play as i64;
            }
            self.reset_game();
        }
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }
    fn get_reward(&self) -> Reward {
        self.rew
    }
    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        4
    }
    fn get_reward_bits(&self) -> usize {
        3
    }

    fn min_reward(&self) -> Reward {
        -2
    }

    fn max_reward(&self) -> Reward {
        4
    }
    fn get_action_bits(&self) -> usize {
        1 // 0 or 1
    }
    fn get_num_actions(&self) -> usize {
        2
    }
}
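
// Hedged sketch of the observation layout produced in `reset_game` above:
// bits 0-2 one-hot encode the agent's card and bit 3 is set when the opponent
// opened with a bet. `decode_kuhn_obs` is an illustrative helper, not part of
// the crate API, and it assumes `PerceptVal` is an integer alias.
#[allow(dead_code)]
fn decode_kuhn_obs(obs: PerceptVal) -> (usize, bool) {
    let card = (obs & 0b111).trailing_zeros() as usize; // 0: Jack, 1: Queen, 2: King
    let opponent_bet = (obs & 0b1000) != 0;
    (card, opponent_bet)
}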