infotheory/aixi/
common.rs

//! Common types and utilities for the AIXI implementation.
2
/// One binary symbol of the agent/environment interaction history (`false` = 0, `true` = 1).
pub type Symbol = bool;

/// A sequence of [`Symbol`]s, e.g. the bit encoding of an observation, reward or action.
pub type SymbolList = Vec<Symbol>;

/// Identifier of an action the agent can take.
pub type Action = u64;

/// Signed reward value handed to the agent by the environment.
pub type Reward = i64;

/// Raw value of a single percept component (an observation or a reward).
pub type PerceptVal = u64;
17
/// How an observation stream is reduced to the percept key used while building the search tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObservationKeyMode {
    /// Keep every observation: branching uses the whole stream (exact, paper-style expectimax).
    FullStream,
    /// Key on the first observation only (`0` for an empty stream).
    First,
    /// Key on the final observation only (`0` for an empty stream).
    Last,
    /// Fold all observations, each masked to `observation_bits`, into one 64-bit hash value.
    StreamHash,
}
30
31/// Compute a percept key from an observation stream.
32pub fn observation_key_from_stream(
33    mode: ObservationKeyMode,
34    observations: &[PerceptVal],
35    observation_bits: usize,
36) -> PerceptVal {
37    match mode {
38        ObservationKeyMode::FullStream => {
39            debug_assert!(
40                false,
41                "observation_key_from_stream called with FullStream; use observation_repr_from_stream"
42            );
43            // Fallback to hash in release builds to avoid panics.
44            observation_key_from_stream(
45                ObservationKeyMode::StreamHash,
46                observations,
47                observation_bits,
48            )
49        }
50        ObservationKeyMode::First => observations.first().copied().unwrap_or(0),
51        ObservationKeyMode::Last => observations.last().copied().unwrap_or(0),
52        ObservationKeyMode::StreamHash => {
53            let mask = if observation_bits >= 64 {
54                u64::MAX
55            } else if observation_bits == 0 {
56                0
57            } else {
58                (1u64 << observation_bits) - 1
59            };
60            let mut h = 0u64;
61            for &obs in observations {
62                let v = obs & mask;
63                h = h.rotate_left(7) ^ v;
64            }
65            h
66        }
67    }
68}
69
70/// Compute the observation representation used for tree branching.
71///
72/// - `FullStream` returns the full stream (paper-accurate expectimax).
73/// - Other modes collapse to a single-key vector.
74pub fn observation_repr_from_stream(
75    mode: ObservationKeyMode,
76    observations: &[PerceptVal],
77    observation_bits: usize,
78) -> Vec<PerceptVal> {
79    match mode {
80        ObservationKeyMode::FullStream => observations.to_vec(),
81        _ => vec![observation_key_from_stream(
82            mode,
83            observations,
84            observation_bits,
85        )],
86    }
87}
88
89/// A high-performance random number generator using the XorShift64* algorithm.
90///
91/// This generator is seeded using `zpaq_rs::random_bytes` to avoid external dependencies
92/// like the `rand` crate while using OS-provided entropy for the initial seed.
93#[derive(Clone, Copy)]
94pub struct RandomGenerator {
95    state: u64,
96}
97
98impl RandomGenerator {
99    /// Creates a new `RandomGenerator` with a fresh seed.
100    pub fn new() -> Self {
101        // Seeding from zpaq_rs
102        let bytes = zpaq_rs::random_bytes(8).expect("Failed to get random seed");
103        let mut seed_arr = [0u8; 8];
104        seed_arr.copy_from_slice(&bytes);
105        let seed = u64::from_le_bytes(seed_arr);
106        let state = if seed == 0 { 0xCAFEBABEDEADBEEF } else { seed };
107        Self { state }
108    }
109
110    /// Generates the next pseudo-random `u64`.
111    pub fn next_u64(&mut self) -> u64 {
112        // xorshift64*
113        let mut x = self.state;
114        x ^= x >> 12;
115        x ^= x << 25;
116        x ^= x >> 27;
117        self.state = x;
118        x.wrapping_mul(0x2545F4914F6CDD1D)
119    }
120
121    /// Generates a pseudo-random `usize` in the range `[0, end)`.
122    pub fn gen_range(&mut self, end: usize) -> usize {
123        if end == 0 {
124            return 0;
125        }
126        (self.next_u64() % (end as u64)) as usize
127    }
128
129    /// Generates a boolean value with probability `p` of being `true`.
130    pub fn gen_bool(&mut self, p: f64) -> bool {
131        self.gen_f64() < p
132    }
133
134    /// Generates a pseudo-random `f64` in the range `[0, 1)`.
135    pub fn gen_f64(&mut self) -> f64 {
136        // 53 bits
137        let v = self.next_u64() >> 11;
138        (v as f64) * (1.0 / 9007199254740992.0)
139    }
140
141    /// Forks the RNG state with a salt, returning an independent generator.
142    pub fn fork_with(&self, salt: u64) -> Self {
143        let mixed = Self::splitmix64(self.state ^ salt ^ 0x9E3779B97F4A7C15);
144        let state = if mixed == 0 {
145            0xCAFEBABEDEADBEEF
146        } else {
147            mixed
148        };
149        Self { state }
150    }
151
152    fn splitmix64(mut x: u64) -> u64 {
153        x = x.wrapping_add(0x9E3779B97F4A7C15);
154        let mut z = x;
155        z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
156        z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
157        z ^ (z >> 31)
158    }
159}
160
161/// Encodes a numeric value into its bit representation and appends it to a `SymbolList`.
162///
163/// Bits are appended in least-significant-bit first order.
164pub fn encode(symlist: &mut SymbolList, value: u64, bits: usize) {
165    let mut v = value;
166    for _ in 0..bits {
167        symlist.push((v & 1) == 1);
168        v >>= 1;
169    }
170}
171
172/// Encodes a signed reward value into its bit representation.
173pub fn encode_reward(symlist: &mut SymbolList, value: i64, bits: usize) {
174    let mut v = value as u64;
175    for _ in 0..bits {
176        symlist.push((v & 1) == 1);
177        v >>= 1;
178    }
179}
180
181pub fn encode_reward_offset(symlist: &mut SymbolList, value: i64, bits: usize, offset: i64) {
182    let shifted = (value + offset) as u64;
183    encode(symlist, shifted, bits);
184}
185
186/// Decodes a numeric value from its bit representation.
187pub fn decode(symlist: &[Symbol], bits: usize) -> u64 {
188    if bits == 0 {
189        return 0;
190    }
191    assert!(bits <= symlist.len());
192    let mut value = 0u64;
193    for i in 0..bits {
194        let sym = symlist[symlist.len() - 1 - i];
195        value = (value << 1) + (if sym { 1 } else { 0 });
196    }
197    value
198}
199
200/// Decodes a signed reward value from its bit representation.
201pub fn decode_reward(symlist: &[Symbol], bits: usize) -> i64 {
202    if bits == 0 {
203        return 0;
204    }
205    let v = decode(symlist, bits);
206    if bits < 64 && (v & (1 << (bits - 1))) != 0 {
207        // Sign bit set, perform two's complement sign extension
208        (v | (!0u64 << bits)) as i64
209    } else {
210        v as i64
211    }
212}
213
214pub fn decode_reward_offset(symlist: &[Symbol], bits: usize, offset: i64) -> i64 {
215    if bits == 0 {
216        return 0;
217    }
218    let v = decode(symlist, bits) as i64;
219    v - offset
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn observation_repr_full_stream_is_identity() {
        let obs = vec![1u64, 2u64, 3u64];
        let repr = observation_repr_from_stream(ObservationKeyMode::FullStream, &obs, 8);
        assert_eq!(repr, obs);
    }

    #[test]
    fn observation_repr_other_modes_collapse_to_single_key() {
        let obs = vec![10u64, 20u64, 30u64];
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::First, &obs, 8),
            vec![10]
        );
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::Last, &obs, 8),
            vec![30]
        );
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::StreamHash, &obs, 8),
            vec![observation_key_from_stream(
                ObservationKeyMode::StreamHash,
                &obs,
                8
            )]
        );
    }

    #[test]
    fn observation_key_first_last() {
        let obs = vec![10u64, 20u64, 30u64];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::First, &obs, 8),
            10
        );
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::Last, &obs, 8),
            30
        );

        let empty: Vec<PerceptVal> = vec![];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::First, &empty, 8),
            0
        );
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::Last, &empty, 8),
            0
        );
    }

    #[test]
    fn observation_key_stream_hash_masks_and_mix() {
        // observation_bits=3 => mask=0b111
        // obs[0]=9 -> 1; h=0.rotate_left(7)^1 = 1
        // obs[1]=2 -> 2; h=1.rotate_left(7)^2 = 128^2 = 130
        let obs = vec![9u64, 2u64];
        let h = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 3);
        assert_eq!(h, 130);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_zero_is_zero() {
        let obs = vec![123u64, 456u64, 789u64];
        let h = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 0);
        assert_eq!(h, 0);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_ge_64_uses_full_u64() {
        let obs = vec![u64::MAX, 0x0123_4567_89ab_cdef];
        let h1 = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 64);
        let h2 = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 128);
        assert_eq!(h1, h2);
    }

    // FullStream is not a single-key mode; debug builds flag the misuse loudly.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "use observation_repr_from_stream")]
    fn observation_key_full_stream_panics_in_debug() {
        let _ = observation_key_from_stream(ObservationKeyMode::FullStream, &[1, 2], 8);
    }

    #[test]
    fn encode_decode_round_trip_unsigned() {
        let mut syms = SymbolList::new();
        encode(&mut syms, 0b1011, 4);
        assert_eq!(syms, vec![true, true, false, true]);
        assert_eq!(decode(&syms, 4), 0b1011);
    }

    #[test]
    fn decode_reads_trailing_symbols() {
        let mut syms = SymbolList::new();
        encode(&mut syms, 0b11, 2);
        encode(&mut syms, 0b01, 2);
        assert_eq!(decode(&syms, 2), 0b01);
        assert_eq!(decode(&syms, 0), 0);
    }

    #[test]
    fn reward_round_trip_with_sign_extension() {
        for &r in &[-8i64, -3, -1, 0, 1, 7] {
            let mut syms = SymbolList::new();
            encode_reward(&mut syms, r, 4);
            assert_eq!(decode_reward(&syms, 4), r);
        }
    }

    #[test]
    fn reward_offset_round_trip() {
        for &r in &[-4i64, -1, 0, 3] {
            let mut syms = SymbolList::new();
            encode_reward_offset(&mut syms, r, 3, 4);
            assert_eq!(decode_reward_offset(&syms, 3, 4), r);
        }
        assert_eq!(decode_reward_offset(&[], 0, 4), 0);
    }

    #[test]
    fn rng_outputs_stay_in_range() {
        let mut rng = RandomGenerator::new();
        assert_eq!(rng.gen_range(0), 0);
        for _ in 0..1000 {
            assert!(rng.gen_range(7) < 7);
            let x = rng.gen_f64();
            assert!((0.0..1.0).contains(&x));
        }
        assert!(!rng.gen_bool(0.0));
        assert!(rng.gen_bool(1.0));
    }

    #[test]
    fn rng_fork_is_deterministic_per_salt() {
        let rng = RandomGenerator::new();
        let mut a = rng.fork_with(7);
        let mut b = rng.fork_with(7);
        assert_eq!(a.next_u64(), b.next_u64());
    }
}