// infotheory/aixi/common.rs
//
// Shared primitive type aliases, observation-key helpers, a small PRNG, and the
// bit-level symbol encoding/decoding helpers.

/// A single binary symbol; symbol lists are written by `encode`/`encode_reward`
/// and read back by `decode`/`decode_reward`.
pub type Symbol = bool;

/// An ordered sequence of binary symbols.
pub type SymbolList = Vec<Symbol>;

/// An action value, represented as an unsigned 64-bit integer.
pub type Action = u64;

/// A signed reward value.
pub type Reward = i64;

/// A single observation/percept value.
pub type PerceptVal = u64;

/// Selects how a stream of observations is reduced by `observation_key_from_stream`
/// and `observation_repr_from_stream`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObservationKeyMode {
    /// Keep every observation. Only meaningful for `observation_repr_from_stream`;
    /// `observation_key_from_stream` debug-asserts on it and falls back to `StreamHash`.
    FullStream,
    /// Use only the first observation (0 for an empty stream).
    First,
    /// Use only the last observation (0 for an empty stream).
    Last,
    /// Fold all masked observations into one `u64` with a rotate-xor mix.
    StreamHash,
}

31pub fn observation_key_from_stream(
33 mode: ObservationKeyMode,
34 observations: &[PerceptVal],
35 observation_bits: usize,
36) -> PerceptVal {
37 match mode {
38 ObservationKeyMode::FullStream => {
39 debug_assert!(
40 false,
41 "observation_key_from_stream called with FullStream; use observation_repr_from_stream"
42 );
43 observation_key_from_stream(
45 ObservationKeyMode::StreamHash,
46 observations,
47 observation_bits,
48 )
49 }
50 ObservationKeyMode::First => observations.first().copied().unwrap_or(0),
51 ObservationKeyMode::Last => observations.last().copied().unwrap_or(0),
52 ObservationKeyMode::StreamHash => {
53 let mask = if observation_bits >= 64 {
54 u64::MAX
55 } else if observation_bits == 0 {
56 0
57 } else {
58 (1u64 << observation_bits) - 1
59 };
60 let mut h = 0u64;
61 for &obs in observations {
62 let v = obs & mask;
63 h = h.rotate_left(7) ^ v;
64 }
65 h
66 }
67 }
68}
69
70pub fn observation_repr_from_stream(
75 mode: ObservationKeyMode,
76 observations: &[PerceptVal],
77 observation_bits: usize,
78) -> Vec<PerceptVal> {
79 match mode {
80 ObservationKeyMode::FullStream => observations.to_vec(),
81 _ => vec![observation_key_from_stream(
82 mode,
83 observations,
84 observation_bits,
85 )],
86 }
87}
88
89#[derive(Clone, Copy)]
94pub struct RandomGenerator {
95 state: u64,
96}
97
98impl RandomGenerator {
99 pub fn new() -> Self {
101 let bytes = zpaq_rs::random_bytes(8).expect("Failed to get random seed");
103 let mut seed_arr = [0u8; 8];
104 seed_arr.copy_from_slice(&bytes);
105 let seed = u64::from_le_bytes(seed_arr);
106 let state = if seed == 0 { 0xCAFEBABEDEADBEEF } else { seed };
107 Self { state }
108 }
109
110 pub fn next_u64(&mut self) -> u64 {
112 let mut x = self.state;
114 x ^= x >> 12;
115 x ^= x << 25;
116 x ^= x >> 27;
117 self.state = x;
118 x.wrapping_mul(0x2545F4914F6CDD1D)
119 }
120
121 pub fn gen_range(&mut self, end: usize) -> usize {
123 if end == 0 {
124 return 0;
125 }
126 (self.next_u64() % (end as u64)) as usize
127 }
128
129 pub fn gen_bool(&mut self, p: f64) -> bool {
131 self.gen_f64() < p
132 }
133
134 pub fn gen_f64(&mut self) -> f64 {
136 let v = self.next_u64() >> 11;
138 (v as f64) * (1.0 / 9007199254740992.0)
139 }
140
141 pub fn fork_with(&self, salt: u64) -> Self {
143 let mixed = Self::splitmix64(self.state ^ salt ^ 0x9E3779B97F4A7C15);
144 let state = if mixed == 0 {
145 0xCAFEBABEDEADBEEF
146 } else {
147 mixed
148 };
149 Self { state }
150 }
151
152 fn splitmix64(mut x: u64) -> u64 {
153 x = x.wrapping_add(0x9E3779B97F4A7C15);
154 let mut z = x;
155 z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
156 z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
157 z ^ (z >> 31)
158 }
159}
160
161pub fn encode(symlist: &mut SymbolList, value: u64, bits: usize) {
165 let mut v = value;
166 for _ in 0..bits {
167 symlist.push((v & 1) == 1);
168 v >>= 1;
169 }
170}
171
172pub fn encode_reward(symlist: &mut SymbolList, value: i64, bits: usize) {
174 let mut v = value as u64;
175 for _ in 0..bits {
176 symlist.push((v & 1) == 1);
177 v >>= 1;
178 }
179}
180
181pub fn encode_reward_offset(symlist: &mut SymbolList, value: i64, bits: usize, offset: i64) {
182 let shifted = (value + offset) as u64;
183 encode(symlist, shifted, bits);
184}
185
186pub fn decode(symlist: &[Symbol], bits: usize) -> u64 {
188 if bits == 0 {
189 return 0;
190 }
191 assert!(bits <= symlist.len());
192 let mut value = 0u64;
193 for i in 0..bits {
194 let sym = symlist[symlist.len() - 1 - i];
195 value = (value << 1) + (if sym { 1 } else { 0 });
196 }
197 value
198}
199
200pub fn decode_reward(symlist: &[Symbol], bits: usize) -> i64 {
202 if bits == 0 {
203 return 0;
204 }
205 let v = decode(symlist, bits);
206 if bits < 64 && (v & (1 << (bits - 1))) != 0 {
207 (v | (!0u64 << bits)) as i64
209 } else {
210 v as i64
211 }
212}
213
214pub fn decode_reward_offset(symlist: &[Symbol], bits: usize, offset: i64) -> i64 {
215 if bits == 0 {
216 return 0;
217 }
218 let v = decode(symlist, bits) as i64;
219 v - offset
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn observation_repr_full_stream_is_identity() {
        let obs = vec![1u64, 2u64, 3u64];
        let repr = observation_repr_from_stream(ObservationKeyMode::FullStream, &obs, 8);
        assert_eq!(repr, obs);
    }

    #[test]
    fn observation_repr_single_key_modes_wrap_the_key() {
        let obs = vec![10u64, 20u64, 30u64];
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::First, &obs, 8),
            vec![10u64]
        );
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::Last, &obs, 8),
            vec![30u64]
        );
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::StreamHash, &[9u64, 2u64], 3),
            vec![130u64]
        );
    }

    #[test]
    fn observation_key_first_last() {
        let obs = vec![10u64, 20u64, 30u64];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::First, &obs, 8),
            10
        );
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::Last, &obs, 8),
            30
        );

        let empty: Vec<PerceptVal> = vec![];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::First, &empty, 8),
            0
        );
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::Last, &empty, 8),
            0
        );
    }

    #[test]
    fn observation_key_stream_hash_masks_and_mix() {
        // observation_bits = 3 gives mask 0b111, so 9 is masked to 1 and 2 stays 2.
        // h = 0.rotate_left(7) ^ 1 = 1; then h = 1.rotate_left(7) ^ 2 = 128 ^ 2 = 130.
        let obs = vec![9u64, 2u64];
        let h = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 3);
        assert_eq!(h, 130);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_zero_is_zero() {
        let obs = vec![123u64, 456u64, 789u64];
        let h = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 0);
        assert_eq!(h, 0);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_ge_64_uses_full_u64() {
        let obs = vec![u64::MAX, 0x0123_4567_89ab_cdef];
        let h1 = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 64);
        let h2 = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 128);
        assert_eq!(h1, h2);
    }

    #[test]
    fn observation_key_stream_hash_full_width_value() {
        // With a full mask: h = 0 ^ u64::MAX = u64::MAX. Rotating all-ones is a no-op,
        // so the second step gives u64::MAX ^ x = !x.
        let obs = vec![u64::MAX, 0x0123_4567_89ab_cdef];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 64),
            0xfedc_ba98_7654_3210
        );
    }

    #[test]
    fn encode_decode_round_trip() {
        let cases: [(u64, usize); 6] = [
            (0, 1),
            (1, 1),
            (0b1011, 4),
            (255, 8),
            (0x0123_4567_89ab_cdef, 64),
            (u64::MAX, 64),
        ];
        for &(value, bits) in cases.iter() {
            let mut syms = SymbolList::new();
            encode(&mut syms, value, bits);
            assert_eq!(syms.len(), bits);
            assert_eq!(decode(&syms, bits), value);
        }
    }

    #[test]
    fn encode_is_lsb_first_and_decode_reads_the_tail() {
        let mut syms = vec![true, true];
        encode(&mut syms, 0b0110, 4);
        assert_eq!(syms, vec![true, true, false, true, true, false]);
        assert_eq!(decode(&syms, 4), 0b0110);
    }

    #[test]
    fn decode_zero_bits_is_zero() {
        assert_eq!(decode(&[], 0), 0);
    }

    #[test]
    #[should_panic]
    fn decode_more_bits_than_available_panics() {
        decode(&[true], 2);
    }

    #[test]
    fn reward_round_trip_sign_extends() {
        let cases: [(i64, usize); 8] = [
            (-3, 4),
            (3, 4),
            (-8, 4),
            (7, 4),
            (-1, 1),
            (0, 1),
            (i64::MIN, 64),
            (i64::MAX, 64),
        ];
        for &(value, bits) in cases.iter() {
            let mut syms = SymbolList::new();
            encode_reward(&mut syms, value, bits);
            assert_eq!(decode_reward(&syms, bits), value);
        }
        assert_eq!(decode_reward(&[], 0), 0);
    }

    #[test]
    fn reward_offset_round_trip() {
        let mut syms = SymbolList::new();
        encode_reward_offset(&mut syms, -5, 4, 8);
        assert_eq!(syms, vec![true, true, false, false]);
        assert_eq!(decode_reward_offset(&syms, 4, 8), -5);
    }

    #[test]
    fn random_generator_ranges_and_probabilities() {
        let mut rng = RandomGenerator::new();
        assert_eq!(rng.gen_range(0), 0);
        for _ in 0..1000 {
            assert_eq!(rng.gen_range(1), 0);
            assert!(rng.gen_range(7) < 7);
            let f = rng.gen_f64();
            assert!((0.0..1.0).contains(&f));
            assert!(!rng.gen_bool(0.0));
            assert!(rng.gen_bool(1.0));
        }
    }

    #[test]
    fn random_generator_fork_is_deterministic_and_salt_dependent() {
        let parent = RandomGenerator::new();
        let mut a = parent.fork_with(7);
        let mut b = parent.fork_with(7);
        let mut c = parent.fork_with(8);
        let xs: Vec<u64> = (0..4).map(|_| a.next_u64()).collect();
        let ys: Vec<u64> = (0..4).map(|_| b.next_u64()).collect();
        let zs: Vec<u64> = (0..4).map(|_| c.next_u64()).collect();
        assert_eq!(xs, ys);
        assert_ne!(xs, zs);
    }
}