use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};

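/// A discrete environment the agent interacts with in cycles: each cycle
/// the agent performs an action, and the environment responds with an
/// observation and a reward, each encoded in a fixed number of bits.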
pub trait Environment {
    /// Receives the agent's action and updates the environment's state,
    /// observation, and reward for this cycle.
    fn perform_action(&mut self, action: Action);

    /// Returns the most recent observation.
    fn get_observation(&self) -> PerceptVal;

    /// Drains and returns the observations produced this cycle; by
    /// default, the single current observation.
    fn drain_observations(&mut self) -> Vec<PerceptVal> {
        vec![self.get_observation()]
    }

    /// Returns the most recent reward.
    fn get_reward(&self) -> Reward;

    /// Returns true once the environment has terminated.
    fn is_finished(&self) -> bool;

    /// Number of bits used to encode an observation.
    fn get_observation_bits(&self) -> usize;

    /// Number of bits used to encode a reward.
    fn get_reward_bits(&self) -> usize;

    /// Number of bits used to encode an action.
    fn get_action_bits(&self) -> usize;

    /// Number of distinct actions; by default, every value encodable in
    /// `get_action_bits()` bits.
    fn get_num_actions(&self) -> usize {
        1 << self.get_action_bits()
    }

    /// Largest reward representable in `get_reward_bits()` bits, treating
    /// rewards as signed two's-complement values.
    fn max_reward(&self) -> Reward {
        let bits = self.get_reward_bits();
        if bits == 0 {
            return 0;
        }
        if bits >= 64 {
            i64::MAX
        } else {
            (1i64 << (bits - 1)) - 1
        }
    }

    /// Smallest reward representable in `get_reward_bits()` bits, treating
    /// rewards as signed two's-complement values.
    fn min_reward(&self) -> Reward {
        let bits = self.get_reward_bits();
        if bits == 0 {
            return 0;
        }
        if bits >= 64 {
            i64::MIN
        } else {
            -(1i64 << (bits - 1))
        }
    }
}

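/// Biased coin flip: each cycle the coin comes up 1 with probability `p`,
/// and the agent receives a reward of 1 for correctly predicting the
/// outcome.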
pub struct CoinFlip {
    /// Probability that the coin comes up 1.
    p: f64,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl CoinFlip {
    pub fn new(p: f64) -> Self {
        let mut env = Self {
            p,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        };
        env.gen_next();
        env
    }

    /// Samples the next flip.
    fn gen_next(&mut self) {
        self.obs = if self.rng.gen_bool(self.p) { 1 } else { 0 };
    }
}

impl Environment for CoinFlip {
    fn perform_action(&mut self, action: Action) {
        // Flip the coin, then reward the agent for having predicted it.
        self.gen_next();
        self.rew = if action == self.obs { 1 } else { 0 };
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }

    fn get_reward(&self) -> Reward {
        self.rew
    }

    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        1
    }

    fn get_reward_bits(&self) -> usize {
        1
    }

    fn min_reward(&self) -> Reward {
        0
    }

    fn max_reward(&self) -> Reward {
        1
    }

    fn get_action_bits(&self) -> usize {
        1
    }
}

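/// Deterministic sanity-check environment for the CTW predictor: after
/// the first cycle the observation is always the complement of the
/// agent's previous action, and matching it yields a reward of 1.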
pub struct CtwTest {
    cycle: usize,
    last_action: Action,
    obs: PerceptVal,
    rew: Reward,
}

impl CtwTest {
    pub fn new() -> Self {
        Self {
            cycle: 0,
            last_action: 0,
            obs: 0,
            rew: 0,
        }
    }
}

impl Environment for CtwTest {
    fn perform_action(&mut self, action: Action) {
        self.obs = if self.cycle == 0 {
            0
        } else {
            (self.last_action + 1) % 2
        };
        self.rew = if self.obs == action { 1 } else { 0 };
        self.last_action = action;
        self.cycle += 1;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }

    fn get_reward(&self) -> Reward {
        self.rew
    }

    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        1
    }

    fn get_reward_bits(&self) -> usize {
        1
    }

    fn min_reward(&self) -> Reward {
        0
    }

    fn max_reward(&self) -> Reward {
        1
    }

    fn get_action_bits(&self) -> usize {
        1
    }
}

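/// Biased rock-paper-scissors: the opponent normally plays uniformly at
/// random, but after winning a round it repeats the same move, giving
/// the agent an exploitable pattern.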
pub struct BiasedRockPaperScissor {
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
    opponent_won_last_round: bool,
    opponent_last_round_action: Action,
}

impl BiasedRockPaperScissor {
    pub fn new() -> Self {
        Self {
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
            opponent_won_last_round: false,
            opponent_last_round_action: 0,
        }
    }
}

impl Environment for BiasedRockPaperScissor {
    fn perform_action(&mut self, action: Action) {
        // The bias: after winning a round the opponent repeats its move;
        // otherwise it plays uniformly at random.
        let opponent_action = if self.opponent_won_last_round {
            self.opponent_last_round_action
        } else {
            let r = self.rng.gen_f64();
            if r < 1.0 / 3.0 {
                0
            } else if r < 2.0 / 3.0 {
                1
            } else {
                2
            }
        };

        if opponent_action == action {
            // Draw.
            self.rew = 0;
            self.opponent_won_last_round = false;
        } else if (opponent_action == 0 && action == 1)
            || (opponent_action == 1 && action == 2)
            || (opponent_action == 2 && action == 0)
        {
            // Agent wins: its move is the successor of the opponent's
            // in the cycle 0 -> 1 -> 2 -> 0.
            self.rew = 1;
            self.opponent_won_last_round = false;
        } else {
            // Opponent wins and will repeat this move next round.
            self.rew = -1;
            self.opponent_won_last_round = true;
            self.opponent_last_round_action = opponent_action;
        }
        self.obs = opponent_action as PerceptVal;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }

    fn get_reward(&self) -> Reward {
        self.rew
    }

    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        2
    }

    fn get_reward_bits(&self) -> usize {
        2
    }

    fn min_reward(&self) -> Reward {
        -1
    }

    fn max_reward(&self) -> Reward {
        1
    }

    fn get_action_bits(&self) -> usize {
        2
    }

    fn get_num_actions(&self) -> usize {
        3
    }
}

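/// Extended tiger problem: gold sits behind one of two doors and a tiger
/// behind the other. While sitting, the agent may listen (action 1) for
/// a noisy hint that names the tiger's door 85% of the time; it must
/// stand up (action 0) before opening door 1 or door 2 (actions 2 and 3).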
pub struct ExtendedTiger {
    /// 0 = sitting, 1 = standing.
    state: usize,
    tiger_door: usize,
    gold_door: usize,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl ExtendedTiger {
    pub fn new() -> Self {
        let mut rng = RandomGenerator::new();
        let gold_door = if rng.gen_bool(0.5) { 1 } else { 2 };
        // The tiger is behind whichever of doors 1 and 2 the gold is not.
        let tiger_door = 3 - gold_door;

        Self {
            state: 0,
            gold_door,
            tiger_door,
            obs: 0,
            rew: 0,
            rng,
        }
    }

    fn reset_doors(&mut self) {
        self.gold_door = if self.rng.gen_bool(0.5) { 1 } else { 2 };
        self.tiger_door = 3 - self.gold_door;
    }
}

impl Environment for ExtendedTiger {
    fn perform_action(&mut self, action: Action) {
        match action {
            // Stand up.
            0 => {
                if self.state == 1 {
                    // Already standing.
                    self.rew = -1;
                } else {
                    self.state = 1;
                    self.rew = -1;
                    // Mark the observation as "standing".
                    if self.obs < 4 {
                        self.obs += 4;
                    }
                }
            }
            // Listen: only effective while sitting with no pending hint.
            1 => {
                if self.state == 1 || self.obs != 0 {
                    self.rew = -1;
                    self.obs = 0;
                } else {
                    // The hint names the tiger's door 85% of the time.
                    self.obs = if self.rng.gen_bool(0.85) {
                        self.tiger_door as PerceptVal
                    } else {
                        self.gold_door as PerceptVal
                    };
                    self.rew = -1;
                }
            }
            // Open door 1: only valid while standing.
            2 => {
                if self.state == 0 {
                    self.rew = -100;
                } else {
                    self.rew = if self.gold_door == 1 { 30 } else { -100 };
                    self.obs = 0;
                    self.state = 0;
                    self.reset_doors();
                }
            }
            // Open door 2: only valid while standing.
            3 => {
                if self.state == 0 {
                    self.rew = -100;
                } else {
                    self.rew = if self.gold_door == 2 { 30 } else { -100 };
                    self.obs = 0;
                    self.state = 0;
                    self.reset_doors();
                }
            }
            // Any other action is invalid.
            _ => {
                self.rew = -100;
            }
        }
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }

    fn get_reward(&self) -> Reward {
        self.rew
    }

    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        3
    }

    fn get_reward_bits(&self) -> usize {
        8
    }

    fn min_reward(&self) -> Reward {
        -100
    }

    fn max_reward(&self) -> Reward {
        30
    }

    fn get_action_bits(&self) -> usize {
        2
    }

    fn get_num_actions(&self) -> usize {
        4
    }
}

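/// Tic-tac-toe against an opponent that moves uniformly at random. The
/// full board is exposed as an 18-bit observation, two bits per square,
/// and the game resets automatically after a win, loss, or draw.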
pub struct TicTacToe {
    /// 0 = empty, 1 = agent, -1 = opponent.
    board: [i8; 9],
    open_squares: Vec<usize>,
    /// Board packed as an observation: two bits per square.
    state: u64,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl TicTacToe {
    pub fn new() -> Self {
        Self {
            board: [0; 9],
            open_squares: (0..9).collect(),
            state: 0,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        }
    }

    fn reset_game(&mut self) {
        self.board = [0; 9];
        self.open_squares = (0..9).collect();
        self.state = 0;
    }

    /// Returns true if `player` occupies three squares in a line.
    fn check_win(&self, player: i8) -> bool {
        let b = self.board;
        let wins = [
            // Rows.
            (0, 1, 2),
            (3, 4, 5),
            (6, 7, 8),
            // Columns.
            (0, 3, 6),
            (1, 4, 7),
            (2, 5, 8),
            // Diagonals.
            (0, 4, 8),
            (2, 4, 6),
        ];
        for &(x, y, z) in &wins {
            if b[x] == player && b[y] == player && b[z] == player {
                return true;
            }
        }
        false
    }
}

impl Environment for TicTacToe {
    fn perform_action(&mut self, action: Action) {
        // Out-of-range moves are penalised without changing the board.
        if action >= 9 {
            self.rew = -3;
            self.obs = self.state as PerceptVal;
            return;
        }

        if self.board[action as usize] != 0 {
            // Moving on an occupied square is also penalised.
            self.rew = -3;
        } else {
            // Record the agent's move (encoded as 0b01 in the packed state).
            self.state += 1 << (2 * action);
            self.board[action as usize] = 1;

            if let Some(pos) = self.open_squares.iter().position(|&x| x == action as usize) {
                self.open_squares.remove(pos);
            }

            self.rew = 0;

            if self.check_win(1) {
                self.reset_game();
                self.rew = 2;
            } else if self.open_squares.is_empty() {
                // Draw.
                self.reset_game();
                self.rew = 1;
            } else {
                // The opponent replies with a uniformly random legal move
                // (encoded as 0b10 in the packed state).
                let idx = self.rng.gen_range(self.open_squares.len());
                let opponent_move = self.open_squares[idx];

                self.state += 2 << (2 * opponent_move);
                self.board[opponent_move] = -1;

                self.open_squares.remove(idx);

                if self.check_win(-1) {
                    self.reset_game();
                    self.rew = -2;
                } else if self.open_squares.is_empty() {
                    self.reset_game();
                    self.rew = 1;
                }
            }
        }
        self.obs = self.state as PerceptVal;
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }

    fn get_reward(&self) -> Reward {
        self.rew
    }

    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        18
    }

    fn get_reward_bits(&self) -> usize {
        3
    }

    fn min_reward(&self) -> Reward {
        -3
    }

    fn max_reward(&self) -> Reward {
        2
    }

    fn get_action_bits(&self) -> usize {
        4
    }

    fn get_num_actions(&self) -> usize {
        9
    }
}

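/// Kuhn poker with a three-card deck (0 = Jack, 1 = Queen, 2 = King)
/// against an opponent that acts first. The opponent's betting
/// frequencies follow the usual `alpha` parameterisation of Kuhn poker
/// equilibrium strategies, with `alpha` redrawn each hand.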
pub struct KuhnPoker {
    opponent_card: usize,
    agent_card: usize,
    /// Chips the agent has committed this hand, including the ante.
    agent_chips: usize,
    /// Total chips in the pot.
    chips_in_play: usize,
    /// Parameter of the opponent's betting strategy, drawn from [0, 1/3).
    alpha: f64,
    /// Opponent's opening action: 0 = check, 1 = bet.
    opponent_action: usize,
    obs: PerceptVal,
    rew: Reward,
    rng: RandomGenerator,
}

impl KuhnPoker {
    pub fn new() -> Self {
        let mut env = Self {
            opponent_card: 0,
            agent_card: 0,
            agent_chips: 0,
            chips_in_play: 0,
            alpha: 0.0,
            opponent_action: 0,
            obs: 0,
            rew: 0,
            rng: RandomGenerator::new(),
        };
        env.reset_game();
        env
    }

    /// Deals a new hand, samples the opponent's strategy and opening
    /// action, and encodes the agent's observation.
    fn reset_game(&mut self) {
        // Deal the opponent a uniformly random card...
        let r = self.rng.gen_f64();
        self.opponent_card = if r < 1.0 / 3.0 {
            2
        } else if self.rng.gen_bool(0.5) {
            1
        } else {
            0
        };

        // ...and the agent one of the two remaining cards.
        let k = if self.rng.gen_bool(0.5) { 1 } else { 2 };
        self.agent_card = (self.opponent_card + k) % 3;

        // Both players ante one chip.
        self.agent_chips = 1;
        self.chips_in_play = 2;

        // The opponent bets with probability `alpha` holding the Jack,
        // never holding the Queen, and `3 * alpha` holding the King (the
        // usual parameterisation of Kuhn poker equilibrium strategies).
        self.alpha = self.rng.gen_f64() / 3.0;
        self.opponent_action = if self.opponent_card == 0 {
            if self.rng.gen_bool(self.alpha) { 1 } else { 0 }
        } else if self.opponent_card == 1 {
            0
        } else if self.rng.gen_bool(3.0 * self.alpha) {
            1
        } else {
            0
        };

        if self.opponent_action == 1 {
            self.chips_in_play += 1;
        }

        // Observation: the agent's card one-hot in bits 0-2, and the
        // opponent's opening action in bit 3.
        let card_code = 1 << self.agent_card;
        let action_code = if self.opponent_action == 1 { 8 } else { 0 };
        self.obs = (action_code + card_code) as PerceptVal;
    }
}

impl Environment for KuhnPoker {
    fn perform_action(&mut self, action: Action) {
        // Action 1 = bet, action 0 = check/fold.
        self.agent_chips += action as usize;
        self.chips_in_play += action as usize;

        let opponent_bets = self.opponent_action == 1;
        let agent_bets = action == 1;

        if opponent_bets == agent_bets {
            // Both bet or both checked: showdown.
            if self.agent_card > self.opponent_card {
                self.rew = self.chips_in_play as i64;
            } else {
                self.rew = -(self.agent_chips as i64);
            }
            self.reset_game();
        } else if opponent_bets && !agent_bets {
            // Agent folds to the opponent's bet.
            self.rew = -(self.agent_chips as i64);
            self.reset_game();
        } else {
            // The agent bet after the opponent checked; the opponent
            // calls with probability `alpha + 1/3`.
            let call = self.rng.gen_bool(self.alpha + 1.0 / 3.0);
            if call {
                self.chips_in_play += 1;
                if self.agent_card > self.opponent_card {
                    self.rew = self.chips_in_play as i64;
                } else {
                    self.rew = -(self.agent_chips as i64);
                }
            } else {
                // Opponent folds; the agent takes the pot.
                self.rew = self.chips_in_play as i64;
            }
            self.reset_game();
        }
    }

    fn get_observation(&self) -> PerceptVal {
        self.obs
    }

    fn get_reward(&self) -> Reward {
        self.rew
    }

    fn is_finished(&self) -> bool {
        false
    }

    fn get_observation_bits(&self) -> usize {
        4
    }

    fn get_reward_bits(&self) -> usize {
        3
    }

    fn min_reward(&self) -> Reward {
        -2
    }

    fn max_reward(&self) -> Reward {
        4
    }

    fn get_action_bits(&self) -> usize {
        1
    }

    fn get_num_actions(&self) -> usize {
        2
    }
}
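
// A minimal sketch of how these environments are driven, kept as a test
// so it compiles against the trait above. The "agent" here is just a
// fixed policy; nothing is assumed beyond the `Environment` trait and
// the `CoinFlip` constructor defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn coin_flip_interaction_loop() {
        let mut env = CoinFlip::new(0.7);
        let mut total_reward: Reward = 0;

        for _ in 0..100 {
            // A real agent would consult its model here; we always
            // predict 1, the more likely outcome under p = 0.7.
            let action: Action = 1;
            env.perform_action(action);

            // The observation and reward are read back after acting.
            let _obs = env.get_observation();
            total_reward += env.get_reward();

            if env.is_finished() {
                break;
            }
        }

        // CoinFlip pays 0 or 1 per cycle, so the total stays in range.
        assert!((0..=100).contains(&total_reward));
    }
}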