// infotheory/aixi/common.rs

/// One binary symbol: a single bit of a value produced by `encode` and
/// consumed by `decode` (and their reward variants).
pub type Symbol = bool;

/// A growable sequence of [`Symbol`]s. The encoders in this module append to
/// it least-significant bit first.
pub type SymbolList = Vec<Symbol>;

/// An action value.
pub type Action = u64;

/// A signed reward value.
pub type Reward = i64;

/// A single observation value (one element of an observation stream).
pub type PerceptVal = u64;

/// How an observation stream is reduced to a key or representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObservationKeyMode {
    /// Keep the whole stream: `observation_repr_from_stream` returns it unchanged.
    /// It has no single-key form.
    FullStream,
    /// Use the first observation (0 for an empty stream).
    First,
    /// Use the last observation (0 for an empty stream).
    Last,
    /// Fold the masked observations into one rotate-xor hash.
    StreamHash,
}
30
31pub fn observation_key_from_stream(
33 mode: ObservationKeyMode,
34 observations: &[PerceptVal],
35 observation_bits: usize,
36) -> PerceptVal {
37 match mode {
38 ObservationKeyMode::FullStream => {
39 debug_assert!(
40 false,
41 "observation_key_from_stream called with FullStream; use observation_repr_from_stream"
42 );
43 observation_key_from_stream(
45 ObservationKeyMode::StreamHash,
46 observations,
47 observation_bits,
48 )
49 }
50 ObservationKeyMode::First => observations.first().copied().unwrap_or(0),
51 ObservationKeyMode::Last => observations.last().copied().unwrap_or(0),
52 ObservationKeyMode::StreamHash => {
53 let mask = if observation_bits >= 64 {
54 u64::MAX
55 } else if observation_bits == 0 {
56 0
57 } else {
58 (1u64 << observation_bits) - 1
59 };
60 let mut h = 0u64;
61 for &obs in observations {
62 let v = obs & mask;
63 h = h.rotate_left(7) ^ v;
64 }
65 h
66 }
67 }
68}
69
70pub fn observation_repr_from_stream(
75 mode: ObservationKeyMode,
76 observations: &[PerceptVal],
77 observation_bits: usize,
78) -> Vec<PerceptVal> {
79 match mode {
80 ObservationKeyMode::FullStream => observations.to_vec(),
81 _ => vec![observation_key_from_stream(
82 mode,
83 observations,
84 observation_bits,
85 )],
86 }
87}

/// Small, fast, non-cryptographic pseudo-random number generator (xorshift64*).
///
/// The internal state is never zero. A zero seed is replaced by a fixed
/// non-zero constant, because the xorshift step maps the all-zero state to
/// itself and the stream would be stuck at 0.
///
/// The type is `Copy`, so an implicit copy replays the same stream as the
/// original. Use [`RandomGenerator::fork_with`] to derive a different stream.
#[derive(Clone, Copy, Debug)]
pub struct RandomGenerator {
    state: u64,
}

impl RandomGenerator {
    /// Replacement state used whenever a seed would otherwise be zero.
    const NONZERO_FALLBACK: u64 = 0xCAFEBABEDEADBEEF;

    /// 2^64 divided by the golden ratio; used to whiten seeds and as the
    /// SplitMix64 increment.
    const GOLDEN_GAMMA: u64 = 0x9E3779B97F4A7C15;

    /// Produces the seed used by [`RandomGenerator::new`].
    ///
    /// Sources, in order of preference:
    /// 1. With the `backend-zpaq` feature: 8 bytes from `zpaq_rs::random_bytes`.
    /// 2. On `wasm32`: a fixed constant, so every `new()` generator there
    ///    produces the same stream.
    /// 3. Elsewhere: wall-clock nanoseconds mixed with a process-wide counter,
    ///    so generators created within the same clock tick still get distinct seeds.
    #[inline]
    fn initial_seed() -> u64 {
        #[cfg(feature = "backend-zpaq")]
        {
            if let Ok(bytes) = zpaq_rs::random_bytes(8) {
                // `copy_from_slice` panics on a length mismatch, so only use the
                // returned bytes when at least 8 were delivered.
                if bytes.len() >= 8 {
                    let mut seed_arr = [0u8; 8];
                    seed_arr.copy_from_slice(&bytes[..8]);
                    return u64::from_le_bytes(seed_arr);
                }
            }
        }

        #[cfg(target_arch = "wasm32")]
        {
            // NOTE(review): presumably fixed because the wall clock is not
            // available on wasm32-unknown-unknown — confirm.
            return Self::NONZERO_FALLBACK ^ Self::GOLDEN_GAMMA;
        }

        #[cfg(not(target_arch = "wasm32"))]
        // Keeping only the low 64 bits of the nanosecond count is intended.
        #[allow(clippy::cast_possible_truncation)]
        {
            use std::sync::atomic::{AtomicU64, Ordering};

            // Without this, two `new()` calls within one clock tick (the wall
            // clock can be coarser than 1 ns) would return identical streams.
            static INSTANCE_COUNTER: AtomicU64 = AtomicU64::new(0);

            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(Self::NONZERO_FALLBACK);
            let instance = INSTANCE_COUNTER.fetch_add(1, Ordering::Relaxed);
            // SplitMix64 is a bijection, so distinct counters give distinct
            // seeds whenever the clock readings coincide.
            return (nanos ^ Self::GOLDEN_GAMMA) ^ Self::splitmix64(instance);
        }

        #[allow(unreachable_code)]
        Self::NONZERO_FALLBACK
    }

    /// Creates a generator from an environment-derived seed.
    ///
    /// The seed comes from OS randomness with the `backend-zpaq` feature.
    /// Otherwise it is the wall clock plus a per-process counter, or a fixed
    /// constant on `wasm32`. A zero seed is replaced exactly as in
    /// [`RandomGenerator::from_seed`].
    pub fn new() -> Self {
        Self::from_seed(Self::initial_seed())
    }

    /// Creates a generator from a caller-chosen seed; equal seeds give equal streams.
    ///
    /// A zero seed is replaced by a fixed non-zero constant (see the type docs).
    pub fn from_seed(seed: u64) -> Self {
        let state = if seed == 0 { Self::NONZERO_FALLBACK } else { seed };
        Self { state }
    }

    /// Advances the state by one xorshift64* step and returns the next 64-bit output.
    ///
    /// The step uses shifts 12, 25 and 27; the output is the new state times
    /// `0x2545F4914F6CDD1D`.
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state = x;
        x.wrapping_mul(0x2545F4914F6CDD1D)
    }

    /// Returns a value in `0..end`, or `0` when `end == 0`.
    ///
    /// Uses plain modulo reduction. When `end` does not divide 2^64, smaller
    /// results are very slightly more likely. Changing the reduction would
    /// change every seeded stream, so it is left as is.
    pub fn gen_range(&mut self, end: usize) -> usize {
        if end == 0 {
            return 0;
        }
        (self.next_u64() % (end as u64)) as usize
    }

    /// Returns `true` with probability `p`.
    ///
    /// `p <= 0.0` (or NaN) never yields `true`; `p >= 1.0` always does.
    pub fn gen_bool(&mut self, p: f64) -> bool {
        self.gen_f64() < p
    }

    /// Returns a uniform `f64` in `[0, 1)` built from the top 53 bits of [`Self::next_u64`].
    pub fn gen_f64(&mut self) -> f64 {
        let v = self.next_u64() >> 11;
        // 9007199254740992 == 2^53: scales the 53-bit integer into [0, 1).
        (v as f64) * (1.0 / 9007199254740992.0)
    }

    /// Derives a new generator from the current state and `salt`, without advancing `self`.
    ///
    /// The same parent state and the same `salt` always produce the same child.
    pub fn fork_with(&self, salt: u64) -> Self {
        let mixed = Self::splitmix64(self.state ^ salt ^ Self::GOLDEN_GAMMA);
        Self::from_seed(mixed)
    }

    /// SplitMix64 output function: adds the golden gamma, then applies the
    /// standard xor-shift/multiply mixing rounds.
    fn splitmix64(mut x: u64) -> u64 {
        x = x.wrapping_add(Self::GOLDEN_GAMMA);
        let mut z = x;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
        z ^ (z >> 31)
    }
}

impl Default for RandomGenerator {
    /// Same as [`RandomGenerator::new`]: an environment-seeded generator.
    fn default() -> Self {
        Self::new()
    }
}
199
200pub fn encode(symlist: &mut SymbolList, value: u64, bits: usize) {
204 let mut v = value;
205 for _ in 0..bits {
206 symlist.push((v & 1) == 1);
207 v >>= 1;
208 }
209}
210
211pub fn encode_reward(symlist: &mut SymbolList, value: i64, bits: usize) {
213 let mut v = value as u64;
214 for _ in 0..bits {
215 symlist.push((v & 1) == 1);
216 v >>= 1;
217 }
218}
219
220pub fn encode_reward_offset(symlist: &mut SymbolList, value: i64, bits: usize, offset: i64) {
222 let shifted = (value + offset) as u64;
223 encode(symlist, shifted, bits);
224}
225
226pub fn decode(symlist: &[Symbol], bits: usize) -> u64 {
228 if bits == 0 {
229 return 0;
230 }
231 assert!(bits <= symlist.len());
232 let mut value = 0u64;
233 for i in 0..bits {
234 let sym = symlist[symlist.len() - 1 - i];
235 value = (value << 1) + (if sym { 1 } else { 0 });
236 }
237 value
238}
239
240pub fn decode_reward(symlist: &[Symbol], bits: usize) -> i64 {
242 if bits == 0 {
243 return 0;
244 }
245 let v = decode(symlist, bits);
246 if bits < 64 && (v & (1 << (bits - 1))) != 0 {
247 (v | (!0u64 << bits)) as i64
249 } else {
250 v as i64
251 }
252}
253
254pub fn decode_reward_offset(symlist: &[Symbol], bits: usize, offset: i64) -> i64 {
256 if bits == 0 {
257 return 0;
258 }
259 let v = decode(symlist, bits) as i64;
260 v - offset
261}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn observation_repr_full_stream_is_identity() {
        let obs = vec![1u64, 2u64, 3u64];
        let repr = observation_repr_from_stream(ObservationKeyMode::FullStream, &obs, 8);
        assert_eq!(repr, obs);
    }

    #[test]
    fn observation_repr_non_full_stream_wraps_key() {
        let obs = vec![1u64, 2u64, 3u64];
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::Last, &obs, 8),
            vec![3]
        );
        assert_eq!(
            observation_repr_from_stream(ObservationKeyMode::First, &obs, 8),
            vec![1]
        );
    }

    #[test]
    fn observation_key_first_last() {
        let obs = vec![10u64, 20u64, 30u64];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::First, &obs, 8),
            10
        );
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::Last, &obs, 8),
            30
        );

        let empty: Vec<PerceptVal> = vec![];
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::First, &empty, 8),
            0
        );
        assert_eq!(
            observation_key_from_stream(ObservationKeyMode::Last, &empty, 8),
            0
        );
    }

    #[test]
    fn observation_key_stream_hash_masks_and_mix() {
        let obs = vec![9u64, 2u64];
        // mask = 0b111: 9 & 7 = 1, so h = 1; then 1.rotate_left(7) ^ (2 & 7) = 128 ^ 2 = 130.
        let h = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 3);
        assert_eq!(h, 130);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_zero_is_zero() {
        let obs = vec![123u64, 456u64, 789u64];
        let h = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 0);
        assert_eq!(h, 0);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_ge_64_uses_full_u64() {
        let obs = vec![u64::MAX, 0x0123_4567_89ab_cdef];
        let h1 = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 64);
        let h2 = observation_key_from_stream(ObservationKeyMode::StreamHash, &obs, 128);
        assert_eq!(h1, h2);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "use observation_repr_from_stream")]
    fn observation_key_full_stream_panics_in_debug() {
        let _ = observation_key_from_stream(ObservationKeyMode::FullStream, &[1, 2], 8);
    }

    #[test]
    fn encode_decode_roundtrip_lsb_first() {
        let mut syms = SymbolList::new();
        encode(&mut syms, 0b1011, 4);
        assert_eq!(syms, vec![true, true, false, true]);
        assert_eq!(decode(&syms, 4), 0b1011);
    }

    #[test]
    fn decode_reads_only_trailing_bits() {
        let mut syms: SymbolList = vec![true, false];
        encode(&mut syms, 5, 3);
        assert_eq!(decode(&syms, 3), 5);
    }

    #[test]
    fn reward_roundtrip_twos_complement() {
        for &r in &[-4i64, -1, 0, 1, 3] {
            let mut syms = SymbolList::new();
            encode_reward(&mut syms, r, 3);
            assert_eq!(decode_reward(&syms, 3), r);
        }
        for &r in &[i64::MIN, -1, 0, i64::MAX] {
            let mut syms = SymbolList::new();
            encode_reward(&mut syms, r, 64);
            assert_eq!(decode_reward(&syms, 64), r);
        }
    }

    #[test]
    fn reward_offset_roundtrip() {
        for r in -5i64..=10 {
            let mut syms = SymbolList::new();
            encode_reward_offset(&mut syms, r, 4, 5);
            assert_eq!(decode_reward_offset(&syms, 4, 5), r);
        }
    }

    #[test]
    fn zero_width_fields() {
        let mut syms = SymbolList::new();
        encode(&mut syms, 123, 0);
        encode_reward(&mut syms, -7, 0);
        assert!(syms.is_empty());
        assert_eq!(decode(&syms, 0), 0);
        assert_eq!(decode_reward(&syms, 0), 0);
        assert_eq!(decode_reward_offset(&syms, 0, 7), 0);
    }

    #[test]
    fn rng_seeded_streams_are_reproducible() {
        let mut a = RandomGenerator::from_seed(42);
        let mut b = RandomGenerator::from_seed(42);
        for _ in 0..16 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    #[test]
    fn rng_zero_seed_uses_fallback_state() {
        let mut zero = RandomGenerator::from_seed(0);
        let mut fallback = RandomGenerator::from_seed(0xCAFEBABEDEADBEEF);
        assert_eq!(zero.next_u64(), fallback.next_u64());
    }

    #[test]
    fn rng_ranges() {
        let mut rng = RandomGenerator::from_seed(7);
        assert_eq!(rng.gen_range(0), 0);
        for _ in 0..1000 {
            assert!(rng.gen_range(10) < 10);
            let f = rng.gen_f64();
            assert!((0.0..1.0).contains(&f));
        }
        assert!(!rng.gen_bool(0.0));
        assert!(rng.gen_bool(1.0));
    }

    #[test]
    fn rng_fork_is_deterministic_and_salt_dependent() {
        let parent = RandomGenerator::from_seed(99);
        let mut f1 = parent.fork_with(1);
        let mut f1_again = parent.fork_with(1);
        let mut f2 = parent.fork_with(2);
        let first = f1.next_u64();
        assert_eq!(first, f1_again.next_u64());
        assert_ne!(first, f2.next_u64());
    }
}