infotheory/aixi/
common.rs

1//! Common types and utilities for the AIXI implementation.
2
/// One binary symbol (`false` = 0, `true` = 1) of the agent's interaction history.
pub type Symbol = bool;

/// A sequence of [`Symbol`]s holding an encoded observation, reward, or action.
pub type SymbolList = Vec<Symbol>;

/// An action the agent can perform.
pub type Action = u64;

/// A reward the environment hands to the agent.
pub type Reward = i64;

/// A generic percept component value (an observation or a reward).
pub type PerceptVal = u64;
17
/// Selects how an observation stream is turned into the percept key used by tree search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObservationKeyMode {
    /// Keep the whole observation stream as the key (paper-accurate expectimax).
    FullStream,
    /// Key on the first observation symbol only.
    First,
    /// Key on the last observation symbol only.
    Last,
    /// Key on a hash of the entire observation stream.
    StreamHash,
}
30
31/// Compute a percept key from an observation stream.
32pub fn observation_key_from_stream(
33    mode: ObservationKeyMode,
34    observations: &[PerceptVal],
35    observation_bits: usize,
36) -> PerceptVal {
37    match mode {
38        ObservationKeyMode::FullStream => {
39            debug_assert!(
40                false,
41                "observation_key_from_stream called with FullStream; use observation_repr_from_stream"
42            );
43            // Fallback to hash in release builds to avoid panics.
44            observation_key_from_stream(
45                ObservationKeyMode::StreamHash,
46                observations,
47                observation_bits,
48            )
49        }
50        ObservationKeyMode::First => observations.first().copied().unwrap_or(0),
51        ObservationKeyMode::Last => observations.last().copied().unwrap_or(0),
52        ObservationKeyMode::StreamHash => {
53            let mask = if observation_bits >= 64 {
54                u64::MAX
55            } else if observation_bits == 0 {
56                0
57            } else {
58                (1u64 << observation_bits) - 1
59            };
60            let mut h = 0u64;
61            for &obs in observations {
62                let v = obs & mask;
63                h = h.rotate_left(7) ^ v;
64            }
65            h
66        }
67    }
68}
69
70/// Compute the observation representation used for tree branching.
71///
72/// - `FullStream` returns the full stream (paper-accurate expectimax).
73/// - Other modes collapse to a single-key vector.
74pub fn observation_repr_from_stream(
75    mode: ObservationKeyMode,
76    observations: &[PerceptVal],
77    observation_bits: usize,
78) -> Vec<PerceptVal> {
79    match mode {
80        ObservationKeyMode::FullStream => observations.to_vec(),
81        _ => vec![observation_key_from_stream(
82            mode,
83            observations,
84            observation_bits,
85        )],
86    }
87}
88
/// A small, fast pseudo-random number generator using the xorshift64* algorithm.
///
/// The generator is deterministic for a given seed. The type is `Copy`: copying a
/// generator duplicates its state, so both copies go on to produce the same sequence.
#[derive(Clone, Copy)]
pub struct RandomGenerator {
    /// Current xorshift state; every constructor remaps zero to a non-zero value.
    state: u64,
}

impl RandomGenerator {
    /// Non-zero value substituted whenever a seed or derived state would be zero.
    const ZERO_SEED_REPLACEMENT: u64 = 0xCAFE_BABE_DEAD_BEEF;

    /// Constant XOR-ed into seeds and used as the additive step of `splitmix64`.
    const SEED_MIX: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Maps zero to [`Self::ZERO_SEED_REPLACEMENT`]; xorshift would otherwise stay at zero
    /// forever.
    #[inline]
    fn non_zero(value: u64) -> u64 {
        if value == 0 {
            Self::ZERO_SEED_REPLACEMENT
        } else {
            value
        }
    }

    /// Picks the seed used by [`RandomGenerator::new`].
    ///
    /// Sources, in order of preference:
    /// 1. with the `backend-zpaq` feature, 8 bytes from `zpaq_rs::random_bytes`
    ///    (skipped if the call fails or does not yield exactly 8 bytes);
    /// 2. on non-`wasm32` targets, the nanoseconds elapsed since the Unix epoch
    ///    (a fixed constant if the clock reads before the epoch);
    /// 3. on `wasm32`, a fixed constant.
    ///
    /// The result may be zero; callers are expected to remap it.
    #[inline]
    fn initial_seed() -> u64 {
        #[cfg(feature = "backend-zpaq")]
        {
            if let Ok(bytes) = zpaq_rs::random_bytes(8) {
                let raw: &[u8] = &bytes;
                // `copy_from_slice` panics on a length mismatch, so an unexpected length
                // falls through to the next seed source instead.
                if raw.len() == 8 {
                    let mut seed_arr = [0u8; 8];
                    seed_arr.copy_from_slice(raw);
                    return u64::from_le_bytes(seed_arr);
                }
            }
        }

        // `SystemTime::now()` is unavailable on `wasm32-unknown-unknown` without WASI,
        // so that target always starts from the fixed constant.
        #[cfg(target_arch = "wasm32")]
        let entropy = Self::ZERO_SEED_REPLACEMENT;

        // Keeping only the low 64 bits of the nanosecond count is intentional.
        #[cfg(not(target_arch = "wasm32"))]
        #[allow(clippy::cast_possible_truncation)]
        let entropy = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(Self::ZERO_SEED_REPLACEMENT);

        entropy ^ Self::SEED_MIX
    }

    /// Creates a new `RandomGenerator` with a fresh seed.
    pub fn new() -> Self {
        Self::from_seed(Self::initial_seed())
    }

    /// Creates a new `RandomGenerator` from an explicit seed.
    ///
    /// A zero seed is remapped to a fixed non-zero constant to avoid the
    /// xorshift zero-state trap.
    pub fn from_seed(seed: u64) -> Self {
        Self {
            state: Self::non_zero(seed),
        }
    }

    /// Generates the next pseudo-random `u64`.
    ///
    /// Advances the state with the three xorshift steps (`>> 12`, `<< 25`, `>> 27`) and
    /// returns the new state multiplied (wrapping) by the xorshift64* constant.
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Generates a pseudo-random `usize` in the range `[0, end)`.
    ///
    /// Returns `0` without advancing the generator when `end == 0`. The value is obtained
    /// by plain modulo reduction of [`Self::next_u64`].
    pub fn gen_range(&mut self, end: usize) -> usize {
        if end == 0 {
            return 0;
        }
        // The remainder is strictly below `end`, so it always fits back into `usize`.
        (self.next_u64() % (end as u64)) as usize
    }

    /// Generates a boolean value with probability `p` of being `true`.
    ///
    /// Implemented as `gen_f64() < p`, so `p <= 0.0` (or NaN) always yields `false` and
    /// `p >= 1.0` always yields `true`.
    pub fn gen_bool(&mut self, p: f64) -> bool {
        self.gen_f64() < p
    }

    /// Generates a pseudo-random `f64` in the range `[0, 1)`.
    pub fn gen_f64(&mut self) -> f64 {
        // Keep the top 53 bits (the width of an `f64` mantissa) and scale by 2^-53.
        let mantissa = self.next_u64() >> 11;
        (mantissa as f64) * (1.0 / 9007199254740992.0)
    }

    /// Derives a new generator from the current state and `salt`.
    ///
    /// The parent generator is not advanced; forking the same state with the same salt
    /// yields the same child.
    pub fn fork_with(&self, salt: u64) -> Self {
        let mixed = Self::splitmix64(self.state ^ salt ^ Self::SEED_MIX);
        Self {
            state: Self::non_zero(mixed),
        }
    }

    /// One round of the SplitMix64 mixing function, used to decorrelate forked states.
    fn splitmix64(x: u64) -> u64 {
        let mut z = x.wrapping_add(Self::SEED_MIX);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

impl Default for RandomGenerator {
    /// Equivalent to [`RandomGenerator::new`].
    fn default() -> Self {
        Self::new()
    }
}
199
200/// Encodes a numeric value into its bit representation and appends it to a `SymbolList`.
201///
202/// Bits are appended in least-significant-bit first order.
203pub fn encode(symlist: &mut SymbolList, value: u64, bits: usize) {
204    let mut v = value;
205    for _ in 0..bits {
206        symlist.push((v & 1) == 1);
207        v >>= 1;
208    }
209}
210
211/// Encodes a signed reward value into its bit representation.
212pub fn encode_reward(symlist: &mut SymbolList, value: i64, bits: usize) {
213    let mut v = value as u64;
214    for _ in 0..bits {
215        symlist.push((v & 1) == 1);
216        v >>= 1;
217    }
218}
219
220/// Encodes a reward after applying an additive `offset`.
221pub fn encode_reward_offset(symlist: &mut SymbolList, value: i64, bits: usize, offset: i64) {
222    let shifted = (value + offset) as u64;
223    encode(symlist, shifted, bits);
224}
225
226/// Decodes a numeric value from its bit representation.
227pub fn decode(symlist: &[Symbol], bits: usize) -> u64 {
228    if bits == 0 {
229        return 0;
230    }
231    assert!(bits <= symlist.len());
232    let mut value = 0u64;
233    for i in 0..bits {
234        let sym = symlist[symlist.len() - 1 - i];
235        value = (value << 1) + (if sym { 1 } else { 0 });
236    }
237    value
238}
239
240/// Decodes a signed reward value from its bit representation.
241pub fn decode_reward(symlist: &[Symbol], bits: usize) -> i64 {
242    if bits == 0 {
243        return 0;
244    }
245    let v = decode(symlist, bits);
246    if bits < 64 && (v & (1 << (bits - 1))) != 0 {
247        // Sign bit set, perform two's complement sign extension
248        (v | (!0u64 << bits)) as i64
249    } else {
250        v as i64
251    }
252}
253
254/// Decodes a reward encoded with [`encode_reward_offset`].
255pub fn decode_reward_offset(symlist: &[Symbol], bits: usize, offset: i64) -> i64 {
256    if bits == 0 {
257        return 0;
258    }
259    let v = decode(symlist, bits) as i64;
260    v - offset
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// `FullStream` must hand the stream back unchanged.
    #[test]
    fn observation_repr_full_stream_is_identity() {
        let stream: Vec<PerceptVal> = vec![1, 2, 3];
        let repr = observation_repr_from_stream(ObservationKeyMode::FullStream, &stream, 8);
        assert_eq!(&repr[..], &stream[..]);
    }

    /// `First` / `Last` pick the matching end of the stream and default to 0 when empty.
    #[test]
    fn observation_key_first_last() {
        let filled: Vec<PerceptVal> = vec![10, 20, 30];
        let empty: Vec<PerceptVal> = Vec::new();

        let cases = [
            (ObservationKeyMode::First, &filled, 10u64),
            (ObservationKeyMode::Last, &filled, 30u64),
            (ObservationKeyMode::First, &empty, 0u64),
            (ObservationKeyMode::Last, &empty, 0u64),
        ];
        for &(mode, stream, expected) in cases.iter() {
            assert_eq!(observation_key_from_stream(mode, stream, 8), expected);
        }
    }

    /// Each observation is masked to `observation_bits` and mixed with a 7-bit rotation.
    #[test]
    fn observation_key_stream_hash_masks_and_mix() {
        // observation_bits = 3 => mask = 0b111
        // 9 & 0b111 = 1 -> h = 0.rotate_left(7) ^ 1 = 1
        // 2 & 0b111 = 2 -> h = 1.rotate_left(7) ^ 2 = 128 ^ 2 = 130
        let stream: Vec<PerceptVal> = vec![9, 2];
        let key = observation_key_from_stream(ObservationKeyMode::StreamHash, &stream, 3);
        assert_eq!(key, 130);
    }

    /// A zero-width mask erases every observation, so the hash stays 0.
    #[test]
    fn observation_key_stream_hash_observation_bits_zero_is_zero() {
        let stream: Vec<PerceptVal> = vec![123, 456, 789];
        let key = observation_key_from_stream(ObservationKeyMode::StreamHash, &stream, 0);
        assert_eq!(key, 0);
    }

    /// Any width of 64 bits or more behaves like the full 64-bit mask.
    #[test]
    fn observation_key_stream_hash_observation_bits_ge_64_uses_full_u64() {
        let stream: Vec<PerceptVal> = vec![u64::MAX, 0x0123_4567_89ab_cdef];
        let exact = observation_key_from_stream(ObservationKeyMode::StreamHash, &stream, 64);
        let wider = observation_key_from_stream(ObservationKeyMode::StreamHash, &stream, 128);
        assert_eq!(exact, wider);
    }
}