1use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
8
9pub trait Environment {
14 fn perform_action(&mut self, action: Action);
16
17 fn get_observation(&self) -> PerceptVal;
19
20 fn drain_observations(&mut self) -> Vec<PerceptVal> {
24 vec![self.get_observation()]
25 }
26
27 fn get_reward(&self) -> Reward;
29
30 fn is_finished(&self) -> bool;
32
33 fn get_observation_bits(&self) -> usize;
35
36 fn get_reward_bits(&self) -> usize;
38
39 fn get_action_bits(&self) -> usize;
41
42 fn set_random_seed(&mut self, _seed: u64) {}
48
49 fn get_num_actions(&self) -> usize {
51 1 << self.get_action_bits()
52 }
53
54 fn max_reward(&self) -> Reward {
56 let bits = self.get_reward_bits();
57 if bits == 0 {
58 return 0;
59 }
60 if bits >= 64 {
62 i64::MAX
63 } else {
64 (1i64 << (bits - 1)) - 1
65 }
66 }
67
68 fn min_reward(&self) -> Reward {
70 let bits = self.get_reward_bits();
71 if bits == 0 {
72 return 0;
73 }
74 if bits >= 64 {
76 i64::MIN
77 } else {
78 -(1i64 << (bits - 1))
79 }
80 }
81}
82
/// Biased coin-flip environment: each step the environment flips a coin
/// with bias `p` and rewards the agent for predicting the outcome.
pub struct CoinFlip {
    /// Probability that the coin comes up 1.
    p: f64,
    /// Most recent coin outcome (0 or 1).
    obs: PerceptVal,
    /// Reward from the last action (1 on a correct prediction, else 0).
    rew: Reward,
    /// Randomness source driving the flips.
    rng: RandomGenerator,
}
97
98impl CoinFlip {
99 pub fn new(p: f64) -> Self {
101 Self::new_with_seed(p, None)
102 }
103
104 pub fn new_with_seed(p: f64, seed: Option<u64>) -> Self {
106 let mut env = Self {
107 p,
108 obs: 0,
109 rew: 0,
110 rng: seed.map(RandomGenerator::from_seed).unwrap_or_default(),
111 };
112 env.gen_next();
114 env
115 }
116
117 fn gen_next(&mut self) {
118 self.obs = if self.rng.gen_bool(self.p) { 1 } else { 0 };
119 }
120}
121
122impl Environment for CoinFlip {
123 fn perform_action(&mut self, action: Action) {
124 self.gen_next();
125 self.rew = if action == self.obs { 1 } else { 0 };
126 }
127
128 fn get_observation(&self) -> PerceptVal {
129 self.obs
130 }
131 fn get_reward(&self) -> Reward {
132 self.rew
133 }
134 fn is_finished(&self) -> bool {
135 false
136 }
137
138 fn get_observation_bits(&self) -> usize {
139 1
140 }
141 fn get_reward_bits(&self) -> usize {
142 1
143 }
144
145 fn min_reward(&self) -> Reward {
146 0
147 }
148
149 fn max_reward(&self) -> Reward {
150 1
151 }
152 fn get_action_bits(&self) -> usize {
153 1
154 }
155
156 fn set_random_seed(&mut self, seed: u64) {
157 self.rng = RandomGenerator::from_seed(seed);
158 self.rew = 0;
159 self.gen_next();
160 }
161}
162
/// Deterministic sanity-check environment: after the first cycle the
/// observation is always the binary complement of the agent's previous
/// action, so a context-tree model should learn it quickly.
pub struct CtwTest {
    /// Number of actions performed so far.
    cycle: usize,
    /// Action received on the previous step.
    last_action: Action,
    /// Current observation.
    obs: PerceptVal,
    /// Reward from the last action.
    rew: Reward,
}
173
174impl CtwTest {
175 pub fn new() -> Self {
177 Self {
178 cycle: 0,
179 last_action: 0,
180 obs: 0,
181 rew: 0,
182 }
183 }
184}
185
186impl Default for CtwTest {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192impl Environment for CtwTest {
193 fn perform_action(&mut self, action: Action) {
194 if self.cycle == 0 {
195 self.obs = 0;
196 self.rew = if self.obs == action { 1 } else { 0 };
197 } else {
198 self.obs = (self.last_action + 1) % 2;
199 self.rew = if self.obs == action { 1 } else { 0 };
200 }
201 self.last_action = action;
202 self.cycle += 1;
203 }
204
205 fn get_observation(&self) -> PerceptVal {
206 self.obs
207 }
208 fn get_reward(&self) -> Reward {
209 self.rew
210 }
211 fn is_finished(&self) -> bool {
212 false
213 }
214
215 fn get_observation_bits(&self) -> usize {
216 1
217 }
218 fn get_reward_bits(&self) -> usize {
219 1
220 }
221
222 fn min_reward(&self) -> Reward {
223 0
224 }
225
226 fn max_reward(&self) -> Reward {
227 1
228 }
229 fn get_action_bits(&self) -> usize {
230 1
231 }
232}
233
/// Rock-paper-scissors against a biased opponent: the opponent plays
/// uniformly at random, except that it repeats move 0 (rock) whenever
/// that move just won it the previous round.
pub struct BiasedRockPaperScissor {
    /// Opponent's most recent move (0, 1 or 2).
    obs: PerceptVal,
    /// Reward from the last round (-1 loss, 0 draw, 1 win).
    rew: Reward,
    /// Randomness source for the opponent's unbiased moves.
    rng: RandomGenerator,
}
243
244impl BiasedRockPaperScissor {
245 pub fn new() -> Self {
247 Self::new_with_seed(None)
248 }
249
250 pub fn new_with_seed(seed: Option<u64>) -> Self {
252 Self {
253 obs: 1,
255 rew: 0,
256 rng: seed.map(RandomGenerator::from_seed).unwrap_or_default(),
257 }
258 }
259}
260
261impl Default for BiasedRockPaperScissor {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
267impl Environment for BiasedRockPaperScissor {
268 fn perform_action(&mut self, action: Action) {
269 let opponent_action = if self.obs == 0 && self.rew == -1 {
273 0
274 } else {
275 let r = self.rng.gen_f64();
276 if r < 1.0 / 3.0 {
277 0
278 } else if r < 2.0 / 3.0 {
279 1
280 } else {
281 2
282 }
283 };
284
285 if opponent_action == action {
287 self.rew = 0; } else if (opponent_action == 0 && action == 1)
289 || (opponent_action == 1 && action == 2)
290 || (opponent_action == 2 && action == 0)
291 {
292 self.rew = 1; } else {
294 self.rew = -1; }
296 self.obs = opponent_action as PerceptVal;
297 }
298
299 fn get_observation(&self) -> PerceptVal {
300 self.obs
301 }
302 fn get_reward(&self) -> Reward {
303 self.rew
304 }
305 fn is_finished(&self) -> bool {
306 false
307 }
308
309 fn get_observation_bits(&self) -> usize {
310 2
311 }
312 fn get_reward_bits(&self) -> usize {
313 2
314 }
315
316 fn min_reward(&self) -> Reward {
317 -1
318 }
319
320 fn max_reward(&self) -> Reward {
321 1
322 }
323 fn get_action_bits(&self) -> usize {
324 2
325 }
326 fn get_num_actions(&self) -> usize {
327 3
328 }
329
330 fn set_random_seed(&mut self, seed: u64) {
331 self.rng = RandomGenerator::from_seed(seed);
332 self.obs = 1;
334 self.rew = 0;
335 }
336}
337
/// The "extended tiger" problem: the agent sits before two doors, one
/// hiding gold and one hiding a tiger, and may stand, listen, or open a
/// door.
pub struct ExtendedTiger {
    /// 0 while the agent is sitting, 1 once it has stood up.
    state: usize,
    /// Door identifier for the tiger: 2 when gold is behind door 1,
    /// otherwise 3. NOTE(review): the asymmetric 2/3 encoding (rather
    /// than the complementary door 1) looks suspicious — confirm it is
    /// intended; it only affects the listening observation, not rewards.
    tiger_door: usize,
    /// Door hiding the gold (1 or 2); determines the opening payoff.
    gold_door: usize,
    /// Current observation (listening hint; +4 marks standing).
    obs: PerceptVal,
    /// Reward from the last action.
    rew: Reward,
    /// Randomness source for door placement and listening noise.
    rng: RandomGenerator,
}
350
351impl ExtendedTiger {
352 pub fn new() -> Self {
354 let mut rng = RandomGenerator::new();
355 let gold_door = if rng.gen_bool(0.5) { 1 } else { 2 };
356 let tiger_door = if gold_door == 1 { 2 } else { 3 };
357
358 Self {
359 state: 0,
360 gold_door,
361 tiger_door,
362 obs: 0,
363 rew: 0,
364 rng,
365 }
366 }
367
368 fn reset_doors(&mut self) {
369 self.gold_door = if self.rng.gen_bool(0.5) { 1 } else { 2 };
370 self.tiger_door = if self.gold_door == 1 { 2 } else { 3 };
371 }
372}
373
374impl Default for ExtendedTiger {
375 fn default() -> Self {
376 Self::new()
377 }
378}
379
380impl Environment for ExtendedTiger {
381 fn perform_action(&mut self, action: Action) {
382 match action {
384 0 => {
385 if self.state == 1 {
387 self.rew = -1;
388 } else {
389 self.state = 1;
390 self.rew = -1;
391 if self.obs < 4 {
392 self.obs += 4;
393 }
394 }
395 }
396 1 => {
397 if self.state == 1 || self.obs != 0 {
399 self.rew = -1;
400 self.obs = 0;
401 } else {
402 self.obs = if self.rng.gen_bool(0.85) {
403 self.tiger_door as PerceptVal
404 } else {
405 self.gold_door as PerceptVal
406 };
407 self.rew = -1;
408 }
409 }
410 2 => {
411 if self.state == 0 {
413 self.rew = -100;
414 } else {
415 self.rew = if self.gold_door == 1 { 30 } else { -100 };
416 self.obs = 0;
417 self.state = 0;
418 self.reset_doors();
419 }
420 }
421 3 => {
422 if self.state == 0 {
424 self.rew = -100;
425 } else {
426 self.rew = if self.gold_door == 2 { 30 } else { -100 };
427 self.obs = 0;
428 self.state = 0;
429 self.reset_doors();
430 }
431 }
432 _ => {
433 self.rew = -100;
434 }
435 }
436 }
437
438 fn get_observation(&self) -> PerceptVal {
439 self.obs
440 }
441 fn get_reward(&self) -> Reward {
442 self.rew
443 }
444 fn is_finished(&self) -> bool {
445 false
446 }
447
448 fn get_observation_bits(&self) -> usize {
449 3
450 }
451 fn get_reward_bits(&self) -> usize {
452 8
453 }
454
455 fn min_reward(&self) -> Reward {
456 -100
457 }
458
459 fn max_reward(&self) -> Reward {
460 30
461 }
462 fn get_action_bits(&self) -> usize {
463 2
464 }
465 fn get_num_actions(&self) -> usize {
466 4
467 }
468
469 fn set_random_seed(&mut self, seed: u64) {
470 self.rng = RandomGenerator::from_seed(seed);
471 self.state = 0;
472 self.obs = 0;
473 self.rew = 0;
474 self.reset_doors();
475 }
476}
477
/// Tic-tac-toe against a uniformly random opponent. The board is packed
/// into `state` with two bits per square (01 = agent, 10 = opponent).
pub struct TicTacToe {
    /// Per-square contents: 0 empty, 1 agent, -1 opponent.
    board: [i8; 9],
    /// Squares still free, kept in their original 0..9 relative order.
    open_squares: Vec<usize>,
    /// Packed 2-bit-per-square board encoding exposed as the observation.
    state: u64,
    /// Current observation (mirrors `state` after each move).
    obs: PerceptVal,
    /// Reward from the last move.
    rew: Reward,
    /// Randomness source for the opponent's replies.
    rng: RandomGenerator,
}
487
488impl TicTacToe {
489 pub fn new() -> Self {
491 Self {
492 board: [0; 9],
493 open_squares: (0..9).collect(),
494 state: 0,
495 obs: 0,
496 rew: 0,
497 rng: RandomGenerator::new(),
498 }
499 }
500
501 fn reset_game(&mut self) {
502 self.board = [0; 9];
503 self.open_squares = (0..9).collect();
504 self.state = 0;
505 }
506
507 fn check_win(&self, player: i8) -> bool {
508 let b = self.board;
509 let wins = [
510 (0, 1, 2),
511 (3, 4, 5),
512 (6, 7, 8), (0, 3, 6),
514 (1, 4, 7),
515 (2, 5, 8), (0, 4, 8),
517 (2, 4, 6), ];
519 for &(x, y, z) in &wins {
520 if b[x] == player && b[y] == player && b[z] == player {
521 return true;
522 }
523 }
524 false
525 }
526}
527
528impl Default for TicTacToe {
529 fn default() -> Self {
530 Self::new()
531 }
532}
533
534impl Environment for TicTacToe {
535 fn perform_action(&mut self, action: Action) {
536 if action >= 9 {
537 self.rew = -3;
538 self.obs = self.state as PerceptVal;
539 return;
540 }
541
542 if self.board[action as usize] != 0 {
543 self.rew = -3;
545 } else {
546 self.state += 1 << (2 * action);
548 self.board[action as usize] = 1;
549
550 if let Some(pos) = self.open_squares.iter().position(|&x| x == action as usize) {
552 self.open_squares.remove(pos);
553 }
554
555 self.rew = 0;
556
557 if self.check_win(1) {
558 self.reset_game();
560 self.rew = 2;
561 } else if self.open_squares.is_empty() {
562 self.reset_game();
564 self.rew = 1;
565 } else {
566 let n = self.open_squares.len();
570 if n > 0 {
571 let idx = self.rng.gen_range(n);
572 let opponent_move = self.open_squares[idx];
573
574 self.state += 2 << (2 * opponent_move);
575 self.board[opponent_move] = -1;
576
577 self.open_squares.remove(idx);
578
579 if self.check_win(-1) {
580 self.reset_game();
582 self.rew = -2;
583 } else if self.open_squares.is_empty() {
584 self.reset_game();
585 self.rew = 1;
586 }
587 }
588 }
589 }
590 self.obs = self.state as PerceptVal;
591 }
592
593 fn get_observation(&self) -> PerceptVal {
594 self.obs
595 }
596 fn get_reward(&self) -> Reward {
597 self.rew
598 }
599 fn is_finished(&self) -> bool {
600 false
601 }
602
603 fn get_observation_bits(&self) -> usize {
604 18
605 } fn get_reward_bits(&self) -> usize {
607 3
608 }
609 fn min_reward(&self) -> Reward {
610 -3
611 }
612 fn max_reward(&self) -> Reward {
613 2
614 }
615 fn get_action_bits(&self) -> usize {
616 4
617 }
618 fn get_num_actions(&self) -> usize {
619 9
620 }
621
622 fn set_random_seed(&mut self, seed: u64) {
623 self.rng = RandomGenerator::from_seed(seed);
624 self.reset_game();
625 self.obs = 0;
626 self.rew = 0;
627 }
628}
629
/// Repeated one-hand Kuhn poker against a fixed stochastic opponent.
/// Cards are 0 = jack, 1 = queen, 2 = king.
pub struct KuhnPoker {
    /// Opponent's hidden card for the current hand.
    opponent_card: usize,
    /// Agent's card for the current hand.
    agent_card: usize,
    /// Opponent's opening action (0 = bet, 1 = pass).
    opponent_action: usize,
    /// Observation: the agent's card, plus 4 when the opponent passed.
    obs: PerceptVal,
    /// Reward settled at the end of the previous hand.
    rew: Reward,
    /// Randomness source for deals and opponent decisions.
    rng: RandomGenerator,
}
642
643impl KuhnPoker {
644 pub fn new() -> Self {
646 Self::new_with_seed(None)
647 }
648
649 pub fn new_with_seed(seed: Option<u64>) -> Self {
651 let mut env = Self {
652 opponent_card: 0,
653 agent_card: 0,
654 opponent_action: 0,
655 obs: 0,
656 rew: 0,
657 rng: seed.map(RandomGenerator::from_seed).unwrap_or_default(),
658 };
659 env.reset_game();
660 env
661 }
662
663 #[inline]
664 fn random_card(&mut self) -> usize {
665 self.rng.gen_range(3)
666 }
667
668 fn reset_game(&mut self) {
669 self.agent_card = self.random_card();
672 self.opponent_card = self.agent_card;
673 while self.opponent_card == self.agent_card {
674 self.opponent_card = self.random_card();
675 }
676
677 const ACTION_BET: usize = 0;
678 const ACTION_PASS: usize = 1;
679 const BET_PROB_KING: f64 = 0.7;
680 const BET_PROB_JACK: f64 = BET_PROB_KING / 3.0;
681
682 self.opponent_action = if self.opponent_card == 0 {
684 if self.rng.gen_bool(BET_PROB_JACK) {
685 ACTION_BET
686 } else {
687 ACTION_PASS
688 }
689 } else if self.opponent_card == 1 {
690 ACTION_PASS
691 } else if self.rng.gen_bool(BET_PROB_KING) {
692 ACTION_BET
693 } else {
694 ACTION_PASS
695 };
696
697 let action_code = if self.opponent_action == ACTION_PASS {
700 4
701 } else {
702 0
703 };
704 let card_code = self.agent_card;
705 self.obs = (action_code + card_code) as PerceptVal;
706 }
707}
708
709impl Default for KuhnPoker {
710 fn default() -> Self {
711 Self::new()
712 }
713}
714
715impl Environment for KuhnPoker {
716 fn perform_action(&mut self, action: Action) {
717 const ACTION_BET: usize = 0;
718 const ACTION_PASS: usize = 1;
719
720 const R_BET_LOSS: Reward = -2;
723 const R_PASS_LOSS: Reward = -1;
724 const R_PASS_WIN: Reward = 1;
725 const R_BET_WIN: Reward = 2;
726
727 const BET_PROB_KING: f64 = 0.7;
728 const BET_PROB_QUEEN: f64 = (1.0 + BET_PROB_KING) / 3.0;
729
730 if action > 1 {
731 self.rew = R_BET_LOSS;
732 self.reset_game();
733 return;
734 }
735
736 if action as usize == ACTION_PASS && self.opponent_action == ACTION_BET {
738 self.rew = R_PASS_LOSS;
739 self.reset_game();
740 return;
741 }
742
743 if action as usize == ACTION_BET && self.opponent_action == ACTION_PASS {
745 if self.opponent_card == 1 && self.rng.gen_bool(BET_PROB_QUEEN) {
746 self.opponent_action = ACTION_BET;
747 } else if self.opponent_card == 2 {
748 self.opponent_action = ACTION_BET;
749 } else {
750 self.rew = R_PASS_WIN;
751 self.reset_game();
752 return;
753 }
754 }
755
756 let agent_wins =
758 self.opponent_card == 0 || (self.opponent_card == 1 && self.agent_card == 2);
759 if agent_wins {
760 self.rew = if self.opponent_action == ACTION_BET {
761 R_BET_WIN
762 } else {
763 R_PASS_WIN
764 };
765 } else {
766 self.rew = if action as usize == ACTION_BET {
767 R_BET_LOSS
768 } else {
769 R_PASS_LOSS
770 };
771 }
772 self.reset_game();
773 }
774
775 fn get_observation(&self) -> PerceptVal {
776 self.obs
777 }
778 fn get_reward(&self) -> Reward {
779 self.rew
780 }
781 fn is_finished(&self) -> bool {
782 false
783 }
784
785 fn get_observation_bits(&self) -> usize {
786 3
787 }
788 fn get_reward_bits(&self) -> usize {
789 3
790 }
791
792 fn min_reward(&self) -> Reward {
793 -2
794 }
795
796 fn max_reward(&self) -> Reward {
797 2
798 }
799 fn get_action_bits(&self) -> usize {
800 1
801 } fn get_num_actions(&self) -> usize {
803 2
804 }
805
806 fn set_random_seed(&mut self, seed: u64) {
807 self.rng = RandomGenerator::from_seed(seed);
808 self.rew = 0;
809 self.reset_game();
810 }
811}