1use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
11use crate::aixi::model::{CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor};
12use crate::aixi::rate_backend::rate_backend_contains_zpaq;
13#[cfg(feature = "backend-rwkv")]
14use crate::load_rwkv7_model_from_path;
15use crate::{RateBackend, validate_rate_backend};
16
/// Configuration for an AIQI agent.
///
/// NOTE(review): field semantics below are inferred from how `validate()`
/// and `AiqiAgent` consume them; confirm against the AIQI paper docs.
#[derive(Clone)]
pub struct AiqiConfig {
    /// Predictor selector: "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
    /// | "rosa" | "rwkv" (feature-gated); "zpaq" is rejected in strict mode.
    pub algorithm: String,
    /// Context-tree depth for the CTW-family predictors.
    pub ct_depth: usize,
    /// Bit width used to encode each observation value.
    pub observation_bits: usize,
    /// Number of observation values per step (treated as at least 1).
    pub observation_stream_len: usize,
    /// Bit width used to encode the offset-shifted reward.
    pub reward_bits: usize,
    /// Number of discrete actions (must be >= 1).
    pub agent_actions: usize,
    /// Smallest reward the environment may emit (inclusive).
    pub min_reward: Reward,
    /// Largest reward the environment may emit (inclusive).
    pub max_reward: Reward,
    /// Added to rewards before encoding so encodings are non-negative.
    pub reward_offset: Reward,
    /// Discount factor gamma; must lie strictly in (0, 1).
    pub discount_gamma: f64,
    /// Return horizon H: number of steps summed into each return.
    pub return_horizon: usize,
    /// Number of discretized return bins; must be a power of two.
    pub return_bins: usize,
    /// Augmentation period N (one phase model per phase); must be >= H.
    pub augmentation_period: usize,
    /// Optional cap on retained history steps; `None` disables pruning.
    pub history_prune_keep_steps: Option<usize>,
    /// Baseline exploration rate tau; must lie in (0, 1].
    pub baseline_exploration: f64,
    /// Optional RNG seed for deterministic action selection.
    pub random_seed: Option<u64>,
    /// Optional backend override; takes precedence over `algorithm`.
    pub rate_backend: Option<RateBackend>,
    /// Max order handed to rate-backend bit predictors.
    pub rate_backend_max_order: i64,
    /// Model file path required when `algorithm == "rwkv"` without override.
    pub rwkv_model_path: Option<String>,
    /// Optional ROSA-specific max order (falls back to
    /// `rate_backend_max_order`).
    pub rosa_max_order: Option<i64>,
    /// zpaq method string; presumably consumed elsewhere — unused in this
    /// file (zpaq is rejected by `validate`). TODO confirm.
    pub zpaq_method: Option<String>,
}
85
impl AiqiConfig {
    /// Validate the configuration against AIQI strict-mode requirements.
    ///
    /// Checks basic cardinalities, the power-of-two return binning, the
    /// N >= H period/horizon relation, the paper-mandated open/half-open
    /// ranges for `discount_gamma` and `baseline_exploration`, backend
    /// restrictions (no zpaq; RWKV needs a model path or feature), and
    /// that the shifted reward range fits `reward_bits`.
    ///
    /// # Errors
    /// Returns a human-readable message for the first violated constraint.
    pub fn validate(&self) -> Result<(), String> {
        if self.agent_actions == 0 {
            return Err("agent_actions must be >= 1".to_string());
        }
        if self.return_horizon == 0 {
            return Err("return_horizon must be >= 1".to_string());
        }
        if self.return_bins == 0 {
            return Err("return_bins must be >= 1".to_string());
        }
        // Power-of-two bins make the fixed-width binary return encoding
        // exact: every bit pattern of `return_bits` maps to a real bin.
        if !self.return_bins.is_power_of_two() {
            return Err(format!(
                "return_bins must be a power of two for exact binary return encoding, got {}",
                self.return_bins
            ));
        }
        if self.augmentation_period < self.return_horizon {
            return Err(format!(
                "augmentation_period must be >= return_horizon (got N={}, H={})",
                self.augmentation_period, self.return_horizon
            ));
        }
        if !(0.0 < self.discount_gamma && self.discount_gamma < 1.0) {
            return Err(format!(
                "discount_gamma must be in (0, 1) for AIQI as defined in \"A Model-Free Universal AI\", got {}",
                self.discount_gamma
            ));
        }
        if !(0.0 < self.baseline_exploration && self.baseline_exploration <= 1.0) {
            return Err(format!(
                "baseline_exploration (tau) must be in (0, 1] for AIQI as defined in \"A Model-Free Universal AI\", got {}",
                self.baseline_exploration
            ));
        }
        if self.max_reward < self.min_reward {
            return Err(format!(
                "max_reward must be >= min_reward (got {} < {})",
                self.max_reward, self.min_reward
            ));
        }

        // Without a rate_backend override the algorithm string selects the
        // predictor; reject names that cannot provide strict frozen
        // conditioning (zpaq) or whose feature is not compiled in (rwkv).
        if self.rate_backend.is_none() {
            match self.algorithm.as_str() {
                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
                "zpaq" => {
                    return Err(
                        "AIQI strict mode does not support algorithm=zpaq: zpaq backends do not provide strict frozen conditioning"
                            .to_string(),
                    )
                }
                #[cfg(feature = "backend-rwkv")]
                "rwkv" => {}
                #[cfg(not(feature = "backend-rwkv"))]
                "rwkv" => {
                    return Err("algorithm=rwkv requires backend-rwkv feature".to_string())
                }
                other => return Err(format!("Unknown AIQI algorithm: {other}")),
            }
        }

        if let Some(rate_backend) = &self.rate_backend {
            validate_rate_backend(rate_backend)
                .map_err(|err| format!("invalid rate_backend: {err}"))?;
            if !rate_backend_supports_aiqi_frozen_conditioning(rate_backend) {
                return Err(
                    "AIQI strict mode requires frozen context updates; configured rate_backend contains zpaq which does not provide strict frozen conditioning"
                        .to_string(),
                );
            }
        }

        // RWKV selected via the algorithm string (not a backend override)
        // must point at a non-empty model path.
        #[cfg(feature = "backend-rwkv")]
        if self.rate_backend.is_none() && self.algorithm == "rwkv" {
            match self.rwkv_model_path.as_deref() {
                Some(path) if !path.trim().is_empty() => {}
                _ => {
                    return Err(
                        "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                            .to_string(),
                    )
                }
            }
        }

        // Widen to i128 so the offset arithmetic itself cannot overflow.
        let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
        let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
        if min_shifted < 0 {
            return Err(format!(
                "reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
                min_shifted
            ));
        }
        // With >= 64 bits any shifted reward fits, so only narrower widths
        // need the capacity check (and the shift below stays well-defined).
        if self.reward_bits < 64 {
            let max_enc = (1u128 << self.reward_bits) - 1;
            if (max_shifted as u128) > max_enc {
                return Err(format!(
                    "reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
                    max_shifted, max_enc
                ));
            }
        }

        Ok(())
    }
}
195
/// One recorded interaction step: the action taken, the observation stream
/// the environment returned, and the raw (unshifted) reward.
#[derive(Clone, Debug)]
struct StepRecord {
    action: Action,
    observations: Vec<PerceptVal>,
    reward: Reward,
}
202
/// Per-phase predictor state: one sequence model plus the highest global
/// step whose augmented tokens have been committed into it.
struct PhaseModel {
    predictor: Box<dyn Predictor>,
    // 0 means "nothing committed yet" (global steps are 1-based).
    last_augmented_step: usize,
}
209
/// Model-free AIQI agent: one predictor per augmentation phase over the
/// action/return/percept token stream; plans by predicting the discounted
/// return-bin distribution for each candidate action.
pub struct AiqiAgent {
    config: AiqiConfig,
    // One model per phase of the augmentation period N.
    phases: Vec<PhaseModel>,
    // Retained interaction history; may be pruned from the front.
    steps: Vec<StepRecord>,
    // Per retained step, its return bin once the horizon behind it is done.
    return_bins_by_step: Vec<Option<u64>>,
    // Global step number of steps[0]; global steps are 1-based.
    history_base_step: usize,
    // Monotonic count of all observed transitions (unaffected by pruning).
    total_steps_observed: usize,
    // Cached bit widths for the binary action / return-bin encodings.
    action_bits: usize,
    return_bits: usize,
    // Use the clone-based planner instead of history push/pop (non-CTW).
    use_generic_planner: bool,
    // CTW family uses real (revertible) training updates during planning.
    distribution_uses_training_updates: bool,
    rng: RandomGenerator,
}
226
227impl AiqiAgent {
    /// Build an agent from a validated config, creating one predictor per
    /// augmentation phase and seeding the RNG (deterministic when
    /// `random_seed` is set).
    ///
    /// # Errors
    /// Propagates config-validation and predictor-construction failures.
    pub fn new(config: AiqiConfig) -> Result<Self, String> {
        config.validate()?;

        let action_bits = bits_for_cardinality(config.agent_actions);
        let return_bits = bits_for_cardinality(config.return_bins);
        let use_generic_planner = aiqi_requires_generic_planner(&config);
        // Only the built-in CTW family (no backend override) uses real
        // training updates while rolling out return bits during planning.
        let distribution_uses_training_updates = config.rate_backend.is_none()
            && matches!(
                config.algorithm.as_str(),
                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
            );

        // One independent model per phase of the augmentation period N.
        let mut phases = Vec::with_capacity(config.augmentation_period);
        for _ in 0..config.augmentation_period {
            phases.push(PhaseModel {
                predictor: build_predictor(&config, return_bits)?,
                last_augmented_step: 0,
            });
        }

        let rng = if let Some(seed) = config.random_seed {
            RandomGenerator::from_seed(seed)
        } else {
            RandomGenerator::new()
        };

        Ok(Self {
            action_bits,
            return_bits,
            use_generic_planner,
            distribution_uses_training_updates,
            config,
            phases,
            steps: Vec::new(),
            return_bins_by_step: Vec::new(),
            // Global steps are 1-based; nothing retained yet.
            history_base_step: 1,
            total_steps_observed: 0,
            rng,
        })
    }
269
    /// Total number of transitions observed so far (monotonic; not reduced
    /// by history pruning).
    pub fn steps_observed(&self) -> usize {
        self.total_steps_observed
    }
274
    /// Size of the discrete action space.
    pub fn num_actions(&self) -> usize {
        self.config.agent_actions
    }
279
280 pub fn get_planned_action(&mut self) -> Action {
282 let q_values = self.estimate_q_values();
283 let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
284 if self.config.baseline_exploration > 0.0
285 && self
286 .rng
287 .gen_bool(self.config.baseline_exploration.clamp(0.0, 1.0))
288 {
289 self.rng.gen_range(self.config.agent_actions) as u64
290 } else {
291 greedy_action
292 }
293 }
294
295 pub fn get_planned_action_with_extra_exploration(&mut self, extra_exploration: f64) -> Action {
301 let extra = extra_exploration.clamp(0.0, 1.0);
302 let tau = self.config.baseline_exploration.clamp(0.0, 1.0);
303 let effective = 1.0 - (1.0 - tau) * (1.0 - extra);
304 let q_values = self.estimate_q_values();
305 let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
306 if effective > 0.0 && self.rng.gen_bool(effective) {
307 self.rng.gen_range(self.config.agent_actions) as u64
308 } else {
309 greedy_action
310 }
311 }
312
    /// Record one interaction step `(action, observations, reward)`.
    ///
    /// Validates inputs against the configured ranges and bit widths,
    /// appends the step to retained history, learns the newly completed
    /// return window (if any), and prunes history no longer needed.
    ///
    /// # Errors
    /// Rejects out-of-range actions, wrong-length observation streams,
    /// observations that exceed `observation_bits`, and rewards outside
    /// `[min_reward, max_reward]` or whose shifted encoding does not fit
    /// `reward_bits`.
    pub fn observe_transition(
        &mut self,
        action: Action,
        observations: &[PerceptVal],
        reward: Reward,
    ) -> Result<(), String> {
        if action as usize >= self.config.agent_actions {
            return Err(format!(
                "action out of range: action={} but agent_actions={}",
                action, self.config.agent_actions
            ));
        }

        // At least one observation per step even if the config says 0.
        let expected_obs = self.config.observation_stream_len.max(1);
        if observations.len() != expected_obs {
            return Err(format!(
                "observation stream length mismatch: expected {}, got {}",
                expected_obs,
                observations.len()
            ));
        }

        if reward < self.config.min_reward || reward > self.config.max_reward {
            return Err(format!(
                "reward out of configured range: reward={} not in [{}, {}]",
                reward, self.config.min_reward, self.config.max_reward
            ));
        }

        let obs_max = max_value_for_bits(self.config.observation_bits);
        for &obs in observations {
            if obs > obs_max {
                return Err(format!(
                    "observation value {} does not fit observation_bits={} (max={})",
                    obs, self.config.observation_bits, obs_max
                ));
            }
        }

        // Widen to i128 so the offset arithmetic cannot overflow.
        let rew_shifted = (reward as i128) + (self.config.reward_offset as i128);
        if rew_shifted < 0 {
            return Err(format!(
                "encoded reward became negative after offset: reward={} offset={}",
                reward, self.config.reward_offset
            ));
        }
        if self.config.reward_bits < 64 {
            let max_enc = (1u128 << self.config.reward_bits) - 1;
            if (rew_shifted as u128) > max_enc {
                return Err(format!(
                    "encoded reward {} exceeds reward_bits={} capacity {}",
                    rew_shifted, self.config.reward_bits, max_enc
                ));
            }
        }

        self.steps.push(StepRecord {
            action,
            observations: observations.to_vec(),
            reward,
        });
        self.total_steps_observed += 1;
        // This step's return bin stays unknown until its horizon completes.
        self.return_bins_by_step.push(None);

        self.maybe_learn_new_return()?;
        self.maybe_prune_history();
        Ok(())
    }
385
    /// If the oldest incomplete return window just completed (we now have
    /// `return_horizon` steps from its start), bin that window's return
    /// and commit the step's augmented tokens into its phase model.
    fn maybe_learn_new_return(&mut self) -> Result<(), String> {
        let t = self.total_steps_observed;
        let h = self.config.return_horizon;
        if t < h {
            return Ok(());
        }

        // Step whose H-step window [i, i + H - 1] ends at the current t.
        let i = t + 1 - h;
        let bin = self.compute_return_bin(i);
        let local_idx = self.local_index(i)?;
        self.return_bins_by_step[local_idx] = Some(bin);

        // Each step's return is taught only to the model of its own phase.
        let phase = i % self.config.augmentation_period;
        self.advance_phase_model_to_step(phase, i)
    }
402
    /// Estimate Q-values for every action at the upcoming step using the
    /// phase model assigned to that step.
    ///
    /// Fast path for predictors with cheap history push/pop: temporarily
    /// pushes any history not yet committed to this phase model, probes
    /// each action's predicted return distribution, then pops everything
    /// so the model is left exactly as it was.
    fn estimate_q_values(&mut self) -> Vec<f64> {
        if self.use_generic_planner {
            return self.estimate_q_values_generic();
        }

        let step = self.total_steps_observed + 1;
        let phase = step % self.config.augmentation_period;
        // Bind the read-only fields up front so they remain usable while
        // the phase model below is mutably borrowed.
        let config = &self.config;
        let steps = &self.steps;
        let return_bins_by_step = &self.return_bins_by_step;
        let history_base_step = self.history_base_step;
        let action_bits = self.action_bits;
        let return_bits = self.return_bits;

        let mut q_values = vec![0.0; self.config.agent_actions];
        let mut pushed_fast_forward = 0usize;

        {
            let model = &mut self.phases[phase];
            // Fast-forward: push (as frozen history) every step after the
            // model's last committed one so predictions condition on the
            // full history up to now.
            let start = (model.last_augmented_step + 1).max(history_base_step);
            let end = step.saturating_sub(1);
            if start <= end {
                for idx in start..=end {
                    pushed_fast_forward += push_step_tokens_history(
                        config,
                        history_base_step,
                        steps,
                        return_bins_by_step,
                        action_bits,
                        return_bits,
                        model.predictor.as_mut(),
                        phase,
                        idx,
                    );
                }
            }

            // Probe each candidate action, undoing its bits afterwards.
            for action in 0..self.config.agent_actions {
                let pushed_action = push_encoded_bits_history(
                    model.predictor.as_mut(),
                    action as u64,
                    self.action_bits,
                );
                let dist = Self::predict_return_distribution(
                    self.config.return_bins,
                    self.return_bits,
                    model.predictor.as_mut(),
                    self.distribution_uses_training_updates,
                );
                q_values[action] = expectation_from_distribution(&dist);
                pop_history_bits(model.predictor.as_mut(), pushed_action);
            }

            // Undo the fast-forward so the model state is unchanged.
            pop_history_bits(model.predictor.as_mut(), pushed_fast_forward);
        }

        q_values
    }
461
    /// Clone-based planner for predictors without cheap history push/pop:
    /// commits pending history into a clone of the phase model, then scores
    /// each action on a fresh clone of that context so probes never touch
    /// the persistent model.
    fn estimate_q_values_generic(&mut self) -> Vec<f64> {
        let step = self.total_steps_observed + 1;
        let phase = step % self.config.augmentation_period;

        let model = &self.phases[phase];
        let mut context_predictor = model.predictor.boxed_clone();

        // Bring the clone up to date with any not-yet-committed history.
        let start = (model.last_augmented_step + 1).max(self.history_base_step);
        let end = step.saturating_sub(1);
        if start <= end {
            for idx in start..=end {
                push_augmented_step_tokens_commit(
                    &self.config,
                    self.history_base_step,
                    &self.steps,
                    &self.return_bins_by_step,
                    self.action_bits,
                    self.return_bits,
                    context_predictor.as_mut(),
                    phase,
                    idx,
                )
                .expect("generic planner retained history must contain required augmented return");
            }
        }

        let mut q_values = vec![0.0; self.config.agent_actions];
        for action in 0..self.config.agent_actions {
            // Fresh clone per action so probes cannot contaminate each other.
            let mut action_predictor = context_predictor.boxed_clone();
            let _ = push_encoded_bits_commit_history(
                action_predictor.as_mut(),
                action as u64,
                self.action_bits,
            );
            let dist = Self::predict_return_distribution_from_base_predictor(
                self.config.return_bins,
                self.return_bits,
                action_predictor.as_ref(),
            );
            q_values[action] = expectation_from_distribution(&dist);
        }

        q_values
    }
506
    /// Predict the probability of every return bin under `predictor` by
    /// chaining per-bit probabilities of each bin's LSB-first binary code.
    ///
    /// While scanning a bin's bits the predictor is advanced — real
    /// training updates (revertible) when `use_training_updates`, frozen
    /// history pushes (poppable) otherwise — and fully restored after each
    /// bin, so the predictor state is unchanged on return. The result is
    /// renormalized; a non-finite or zero total falls back to uniform.
    fn predict_return_distribution(
        return_bins: usize,
        return_bits: usize,
        predictor: &mut dyn Predictor,
        use_training_updates: bool,
    ) -> Vec<f64> {
        debug_assert!(return_bins.is_power_of_two());
        // A single bin is certain; nothing to predict.
        if return_bins == 1 {
            return vec![1.0];
        }

        let mut probs = vec![0.0; return_bins];
        for (bin, slot) in probs.iter_mut().enumerate() {
            let mut p = 1.0f64;
            let mut v = bin as u64;
            for _ in 0..return_bits {
                let bit = (v & 1) == 1;
                v >>= 1;
                // Clamp so a degenerate 0/1 probability cannot zero the
                // chain or produce NaN downstream.
                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
                p *= q;
                if use_training_updates {
                    predictor.update(bit);
                } else {
                    predictor.update_history(bit);
                }
            }
            // Undo this bin's bits before scoring the next one.
            if use_training_updates {
                revert_bits(predictor, return_bits);
            } else {
                pop_history_bits(predictor, return_bits);
            }
            *slot = p;
        }

        let sum: f64 = probs.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let u = 1.0 / (return_bins as f64);
            probs.fill(u);
            return probs;
        }

        for p in &mut probs {
            *p /= sum;
        }
        probs
    }
553
    /// Like `predict_return_distribution`, but for predictors that cannot
    /// undo updates: each bin is scored on a fresh clone of
    /// `base_predictor`, committing bits into the clone only.
    fn predict_return_distribution_from_base_predictor(
        return_bins: usize,
        return_bits: usize,
        base_predictor: &dyn Predictor,
    ) -> Vec<f64> {
        debug_assert!(return_bins.is_power_of_two());
        // A single bin is certain; nothing to predict.
        if return_bins == 1 {
            return vec![1.0];
        }

        let mut probs = vec![0.0; return_bins];
        for (bin, slot) in probs.iter_mut().enumerate() {
            let mut predictor = base_predictor.boxed_clone();
            let mut p = 1.0f64;
            let mut v = bin as u64;
            for _ in 0..return_bits {
                let bit = (v & 1) == 1;
                v >>= 1;
                // Clamp away degenerate 0/1 probabilities (NaN guard).
                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
                p *= q;
                predictor.commit_update(bit);
            }
            *slot = p;
        }

        let sum: f64 = probs.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            // Degenerate total: fall back to the uniform distribution.
            let u = 1.0 / (return_bins as f64);
            probs.fill(u);
            return probs;
        }

        for p in &mut probs {
            *p /= sum;
        }
        probs
    }
591
    /// Commit augmented tokens for every step in
    /// `(last_augmented_step, target_step]` into `phase`'s model, skipping
    /// steps already pruned below `history_base_step`.
    ///
    /// # Errors
    /// Fails if a step in range is missing its augmented return bin.
    fn advance_phase_model_to_step(
        &mut self,
        phase: usize,
        target_step: usize,
    ) -> Result<(), String> {
        // Bind shared fields first so the phase model can be borrowed &mut.
        let config = &self.config;
        let steps = &self.steps;
        let return_bins_by_step = &self.return_bins_by_step;
        let history_base_step = self.history_base_step;
        let action_bits = self.action_bits;
        let return_bits = self.return_bits;
        let model = &mut self.phases[phase];
        if target_step <= model.last_augmented_step {
            return Ok(());
        }

        let start = (model.last_augmented_step + 1).max(history_base_step);
        for idx in start..=target_step {
            push_augmented_step_tokens_commit(
                config,
                history_base_step,
                steps,
                return_bins_by_step,
                action_bits,
                return_bits,
                model.predictor.as_mut(),
                phase,
                idx,
            )?;
        }

        model.last_augmented_step = target_step;
        Ok(())
    }
626
    /// Discretize the discounted return of the H-step window starting at
    /// `start_step` into a bin index in `[0, return_bins)`.
    ///
    /// Rewards are normalized to [0, 1] over the configured range, and the
    /// discounted sum is scaled by (1 - gamma) so the return also lies in
    /// [0, 1] before binning.
    ///
    /// # Panics
    /// Panics (via `expect`) if any step of the window falls outside
    /// retained history.
    fn compute_return_bin(&self, start_step: usize) -> u64 {
        let h = self.config.return_horizon;
        let gamma = self.config.discount_gamma;

        debug_assert!(gamma > 0.0 && gamma < 1.0);
        let reward_range = (self.config.max_reward - self.config.min_reward) as f64;

        let mut total = 0.0f64;
        // gk tracks gamma^k across the window.
        let mut gk = 1.0f64;
        for k in 0..h {
            let idx = start_step + k;
            let local_idx = self
                .local_index(idx)
                .expect("return computation requires in-range history");
            let r = self.steps[local_idx].reward;
            // Degenerate range (max == min): every reward normalizes to 0.
            let rn = if reward_range <= 0.0 {
                0.0
            } else {
                ((r - self.config.min_reward) as f64 / reward_range).clamp(0.0, 1.0)
            };
            total += gk * rn;
            gk *= gamma;
        }
        let ret = ((1.0 - gamma) * total).clamp(0.0, 1.0);

        // floor(ret * bins), clamped so ret == 1.0 lands in the top bin.
        let mut bin = (ret * (self.config.return_bins as f64)).floor() as u64;
        let max_bin = (self.config.return_bins as u64).saturating_sub(1);
        if bin > max_bin {
            bin = max_bin;
        }
        bin
    }
660
661 fn local_index(&self, global_step: usize) -> Result<usize, String> {
662 if global_step < self.history_base_step || global_step > self.total_steps_observed {
663 return Err(format!(
664 "global step {} out of retained history range [{}, {}]",
665 global_step, self.history_base_step, self.total_steps_observed
666 ));
667 }
668 Ok(global_step - self.history_base_step)
669 }
670
    /// Drop leading history steps that are no longer needed by any phase
    /// model, by future return computation, or by the configured retention
    /// floor. No-op unless `history_prune_keep_steps` is set.
    fn maybe_prune_history(&mut self) {
        let Some(keep_steps) = self.config.history_prune_keep_steps else {
            return;
        };
        if self.steps.is_empty() {
            return;
        }

        // A step may only be dropped once every phase model committed it.
        let min_phase_committed = self
            .phases
            .iter()
            .map(|phase| phase.last_augmented_step)
            .min()
            .unwrap_or(0);

        // The next return window to complete starts at t + 2 - H; steps
        // before that start are no longer needed for return computation.
        let next_start_needed = self
            .total_steps_observed
            .saturating_add(2)
            .saturating_sub(self.config.return_horizon);
        let returns_safe_drop_upto = next_start_needed.saturating_sub(1);

        let mut safe_drop_upto = min_phase_committed.min(returns_safe_drop_upto);

        // Additionally always retain the most recent `keep_steps` steps.
        let keep_floor_drop_upto = self.total_steps_observed.saturating_sub(keep_steps);
        safe_drop_upto = safe_drop_upto.min(keep_floor_drop_upto);

        if safe_drop_upto < self.history_base_step {
            return;
        }

        let drain_count = safe_drop_upto - self.history_base_step + 1;
        if drain_count == 0 || drain_count > self.steps.len() {
            return;
        }

        // Both parallel vectors and the base offset move together.
        self.steps.drain(0..drain_count);
        self.return_bins_by_step.drain(0..drain_count);
        self.history_base_step += drain_count;
    }
715}
716
/// Push one step's tokens as frozen history bits: action, then (for this
/// model's own phase only, and only if already known) the augmented return
/// bin, then the percepts. Returns the number of bits pushed so the caller
/// can pop them all afterwards.
fn push_step_tokens_history(
    config: &AiqiConfig,
    history_base_step: usize,
    steps: &[StepRecord],
    return_bins_by_step: &[Option<u64>],
    action_bits: usize,
    return_bits: usize,
    predictor: &mut dyn Predictor,
    phase: usize,
    idx: usize,
) -> usize {
    let mut pushed = 0usize;
    pushed += push_action_tokens_history(history_base_step, steps, action_bits, predictor, idx);

    if idx % config.augmentation_period == phase {
        let local_idx = idx - history_base_step;
        // A still-unknown bin (window not yet complete) is simply omitted.
        if let Some(bin) = return_bins_by_step[local_idx] {
            pushed += push_encoded_bits_history(predictor, bin, return_bits);
        }
    }

    pushed + push_percept_tokens_history(config, history_base_step, steps, predictor, idx)
}
740
/// Commit one step's augmented tokens (action, this phase's return bin,
/// percepts) into `predictor` as training history; returns bits pushed.
///
/// # Errors
/// Unlike the frozen variant, a missing return bin for this phase is an
/// error: committed history must always contain the augmented return.
fn push_augmented_step_tokens_commit(
    config: &AiqiConfig,
    history_base_step: usize,
    steps: &[StepRecord],
    return_bins_by_step: &[Option<u64>],
    action_bits: usize,
    return_bits: usize,
    predictor: &mut dyn Predictor,
    phase: usize,
    idx: usize,
) -> Result<usize, String> {
    let mut pushed = 0usize;
    pushed +=
        push_action_tokens_commit_history(history_base_step, steps, action_bits, predictor, idx);

    if idx % config.augmentation_period == phase {
        let local_idx = idx - history_base_step;
        let bin = return_bins_by_step[local_idx].ok_or_else(|| {
            format!(
                "missing return bin for step {} in phase {} while pushing augmented history",
                idx, phase
            )
        })?;
        pushed += push_encoded_bits_commit(predictor, bin, return_bits);
    }

    Ok(pushed
        + push_percept_tokens_commit_history(config, history_base_step, steps, predictor, idx))
}
770
771fn push_action_tokens_history(
772 history_base_step: usize,
773 steps: &[StepRecord],
774 action_bits: usize,
775 predictor: &mut dyn Predictor,
776 idx: usize,
777) -> usize {
778 let action = steps[idx - history_base_step].action;
779 push_encoded_bits_history(predictor, action, action_bits)
780}
781
782fn push_action_tokens_commit_history(
783 history_base_step: usize,
784 steps: &[StepRecord],
785 action_bits: usize,
786 predictor: &mut dyn Predictor,
787 idx: usize,
788) -> usize {
789 let action = steps[idx - history_base_step].action;
790 push_encoded_bits_commit_history(predictor, action, action_bits)
791}
792
793fn push_percept_tokens_history(
794 config: &AiqiConfig,
795 history_base_step: usize,
796 steps: &[StepRecord],
797 predictor: &mut dyn Predictor,
798 idx: usize,
799) -> usize {
800 let step = &steps[idx - history_base_step];
801 let mut pushed = 0usize;
802 for &obs in &step.observations {
803 pushed += push_encoded_bits_history(predictor, obs, config.observation_bits);
804 }
805 pushed
806 + push_encoded_reward_history(
807 predictor,
808 step.reward,
809 config.reward_bits,
810 config.reward_offset,
811 )
812}
813
814fn push_percept_tokens_commit_history(
815 config: &AiqiConfig,
816 history_base_step: usize,
817 steps: &[StepRecord],
818 predictor: &mut dyn Predictor,
819 idx: usize,
820) -> usize {
821 let step = &steps[idx - history_base_step];
822 let mut pushed = 0usize;
823 for &obs in &step.observations {
824 pushed += push_encoded_bits_commit_history(predictor, obs, config.observation_bits);
825 }
826 pushed
827 + push_encoded_reward_commit_history(
828 predictor,
829 step.reward,
830 config.reward_bits,
831 config.reward_offset,
832 )
833}
834
/// Construct the sequence predictor selected by `config`.
///
/// A `rate_backend` override takes precedence over the `algorithm` string;
/// otherwise the name selects the CTW family, ROSA+, or (feature-gated)
/// RWKV. zpaq is rejected, mirroring `AiqiConfig::validate`.
///
/// # Errors
/// Propagates backend construction failures; rejects unknown or
/// unsupported algorithm names.
fn build_predictor(config: &AiqiConfig, return_bits: usize) -> Result<Box<dyn Predictor>, String> {
    if let Some(rate_backend) = config.rate_backend.clone() {
        let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
        let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
        return Ok(Box::new(predictor));
    }

    match config.algorithm.as_str() {
        "ctw" | "ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
        "fac-ctw" => {
            // Factored CTW sized by the return-bit width.
            Ok(Box::new(FacCtwPredictor::new(config.ct_depth, return_bits)))
        }
        "rosa" => {
            // rosa_max_order, when set, overrides the generic max order.
            let max_order = config
                .rosa_max_order
                .unwrap_or(config.rate_backend_max_order);
            let bit_backend = adapt_rate_backend_for_bit_tokens(RateBackend::RosaPlus);
            let predictor = RateBackendBitPredictor::new(bit_backend, max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(feature = "backend-rwkv")]
        "rwkv" => {
            let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
                "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                    .to_string()
            })?;
            let model_arc = load_rwkv7_model_from_path(path);
            let bit_backend =
                adapt_rate_backend_for_bit_tokens(RateBackend::Rwkv7 { model: model_arc });
            let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(not(feature = "backend-rwkv"))]
        "rwkv" => Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
        "zpaq" => Err(
            "AIQI strict mode does not support algorithm=zpaq; configure a backend with strict frozen conditioning"
                .to_string(),
        ),
        _ => Err(format!("Unknown AIQI algorithm: {}", config.algorithm)),
    }
}
877
/// Thin forwarding wrapper around the shared rate-backend adapter so local
/// call sites stay short.
fn adapt_rate_backend_for_bit_tokens(backend: RateBackend) -> RateBackend {
    crate::aixi::rate_backend::adapt_rate_backend_for_bit_tokens(backend)
}
881
/// Strict AIQI needs frozen context updates, which zpaq-containing
/// backends do not provide.
fn rate_backend_supports_aiqi_frozen_conditioning(backend: &RateBackend) -> bool {
    !rate_backend_contains_zpaq(backend)
}
885
886fn aiqi_requires_generic_planner(config: &AiqiConfig) -> bool {
887 config.rate_backend.is_some()
888 || !matches!(
889 config.algorithm.as_str(),
890 "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
891 )
892}
893
/// Number of bits needed to index `cardinality` distinct values, with a
/// minimum of one bit (so even a single-value alphabet consumes a slot).
///
/// Equivalent to `ceil(log2(cardinality))` clamped to at least 1.
/// Implemented via `leading_zeros` instead of the previous `1usize << bits`
/// probe loop, which overflowed the shift for cardinalities above
/// `1 << (usize::BITS - 1)`.
fn bits_for_cardinality(cardinality: usize) -> usize {
    let n = cardinality.max(1);
    // (n - 1) occupies exactly ceil(log2(n)) significant bits.
    let bits = (usize::BITS - (n - 1).leading_zeros()) as usize;
    bits.max(1)
}
902
/// Largest value representable in `bits` bits, saturating at `u64::MAX`
/// for widths of 64 or more; zero bits can only encode zero.
fn max_value_for_bits(bits: usize) -> u64 {
    match bits {
        0 => 0,
        b if b >= 64 => u64::MAX,
        b => (1u64 << b) - 1,
    }
}
912
913fn push_encoded_bits_commit(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
914 let mut v = value;
915 for _ in 0..bits {
916 predictor.commit_update((v & 1) == 1);
917 v >>= 1;
918 }
919 bits
920}
921
922fn push_encoded_bits_history(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
923 let mut v = value;
924 for _ in 0..bits {
925 predictor.update_history((v & 1) == 1);
926 v >>= 1;
927 }
928 bits
929}
930
931fn push_encoded_bits_commit_history(
932 predictor: &mut dyn Predictor,
933 value: u64,
934 bits: usize,
935) -> usize {
936 let mut v = value;
937 for _ in 0..bits {
938 predictor.commit_update_history((v & 1) == 1);
939 v >>= 1;
940 }
941 bits
942}
943
944fn push_encoded_reward_history(
945 predictor: &mut dyn Predictor,
946 reward: Reward,
947 bits: usize,
948 offset: Reward,
949) -> usize {
950 let shifted = (reward as i128) + (offset as i128);
951 let as_u64 = if shifted <= 0 {
952 0
953 } else if shifted > (u64::MAX as i128) {
954 u64::MAX
955 } else {
956 shifted as u64
957 };
958 push_encoded_bits_history(predictor, as_u64, bits)
959}
960
961fn push_encoded_reward_commit_history(
962 predictor: &mut dyn Predictor,
963 reward: Reward,
964 bits: usize,
965 offset: Reward,
966) -> usize {
967 let shifted = (reward as i128) + (offset as i128);
968 let as_u64 = if shifted <= 0 {
969 0
970 } else if shifted > (u64::MAX as i128) {
971 u64::MAX
972 } else {
973 shifted as u64
974 };
975 push_encoded_bits_commit_history(predictor, as_u64, bits)
976}
977
978fn pop_history_bits(predictor: &mut dyn Predictor, bits: usize) {
979 for _ in 0..bits {
980 predictor.pop_history();
981 }
982}
983
984fn revert_bits(predictor: &mut dyn Predictor, bits: usize) {
985 for _ in 0..bits {
986 predictor.revert();
987 }
988}
989
/// Expected normalized return under `probs`, treating bin `i` of `m` bins
/// as the value `i / m` (the bin's lower edge). Empty input yields 0.0.
fn expectation_from_distribution(probs: &[f64]) -> f64 {
    let bins = probs.len();
    if bins == 0 {
        return 0.0;
    }
    let scale = bins as f64;
    let mut acc = 0.0f64;
    for (i, &p) in probs.iter().enumerate() {
        acc += (i as f64 / scale) * p;
    }
    acc
}
1001
/// Index of the maximum value; the strict `>` comparison keeps the
/// EARLIEST index on ties (deterministic tie break). Empty input and
/// all-NaN input both yield 0.
fn argmax_with_fixed_tie_break(values: &[f64]) -> usize {
    let mut best = (0usize, f64::NEG_INFINITY);
    for (i, &v) in values.iter().enumerate() {
        if v > best.1 {
            best = (i, v);
        }
    }
    best.0
}
1013
1014#[cfg(test)]
1015mod tests {
1016 use super::*;
1017 use std::sync::{Arc, Mutex};
1018
    /// Small binary-observation, binary-reward config shared by the tests.
    fn basic_config() -> AiqiConfig {
        AiqiConfig {
            algorithm: "ac-ctw".to_string(),
            ct_depth: 8,
            observation_bits: 1,
            observation_stream_len: 1,
            reward_bits: 1,
            agent_actions: 2,
            min_reward: 0,
            max_reward: 1,
            reward_offset: 0,
            discount_gamma: 0.99,
            return_horizon: 2,
            return_bins: 8,
            augmentation_period: 2,
            history_prune_keep_steps: None,
            baseline_exploration: 0.01,
            random_seed: Some(7),
            rate_backend: None,
            rate_backend_max_order: 20,
            rwkv_model_path: None,
            rosa_max_order: None,
            zpaq_method: None,
        }
    }
1044
    /// Test double that counts how often each `Predictor` entry point is
    /// called; its predicted distribution is fixed (0.75 for `true`).
    #[derive(Clone, Default)]
    struct CountingPredictor {
        update_calls: usize,
        commit_update_calls: usize,
        update_history_calls: usize,
        commit_update_history_calls: usize,
        revert_calls: usize,
        pop_history_calls: usize,
    }
1054
    // Counters are per-instance, so `boxed_clone` forks the counts.
    impl Predictor for CountingPredictor {
        fn update(&mut self, _sym: bool) {
            self.update_calls += 1;
        }

        fn commit_update(&mut self, _sym: bool) {
            self.commit_update_calls += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.update_history_calls += 1;
        }

        fn commit_update_history(&mut self, _sym: bool) {
            self.commit_update_history_calls += 1;
        }

        fn revert(&mut self) {
            self.revert_calls += 1;
        }

        fn pop_history(&mut self) {
            self.pop_history_calls += 1;
        }

        // Fixed distribution: the call counts are the observable state.
        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "CountingPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }
1092
    /// Call counters meant to live behind `Arc<Mutex<..>>` so they survive
    /// `boxed_clone` (unlike `CountingPredictor`'s per-instance counters).
    #[derive(Clone, Default)]
    struct SharedCallCounts {
        update: usize,
        commit_update: usize,
        update_history: usize,
        commit_update_history: usize,
    }
1100
    /// Test double whose clones all report into one shared counter set.
    #[derive(Clone)]
    struct SharedCountingPredictor {
        counts: Arc<Mutex<SharedCallCounts>>,
    }
1105
    impl SharedCountingPredictor {
        /// Wrap an externally owned counter set.
        fn new(counts: Arc<Mutex<SharedCallCounts>>) -> Self {
            Self { counts }
        }
    }
1111
    // Only the four update entry points are counted; revert/pop_history
    // are deliberate no-ops.
    impl Predictor for SharedCountingPredictor {
        fn update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update += 1;
        }

        fn commit_update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update_history += 1;
        }

        fn commit_update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update_history += 1;
        }

        fn revert(&mut self) {}

        fn pop_history(&mut self) {}

        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "SharedCountingPredictor".to_string()
        }

        // Clones share the same Arc, so counts aggregate across clones.
        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }
1145
    /// Test double that starts predicting `true` with p=0.75 once any
    /// training update (update/commit_update) has seen a one bit;
    /// history-only updates never change its prediction.
    #[derive(Clone, Default)]
    struct ReturnLearningPredictor {
        saw_training_one: bool,
    }
1150
    // Only training updates latch the flag; history updates, revert and
    // pop_history are no-ops so frozen conditioning cannot "learn".
    impl Predictor for ReturnLearningPredictor {
        fn update(&mut self, sym: bool) {
            if sym {
                self.saw_training_one = true;
            }
        }

        fn commit_update(&mut self, sym: bool) {
            if sym {
                self.saw_training_one = true;
            }
        }

        fn update_history(&mut self, _sym: bool) {}

        fn commit_update_history(&mut self, _sym: bool) {}

        fn revert(&mut self) {}

        fn pop_history(&mut self) {}

        fn predict_prob(&mut self, sym: bool) -> f64 {
            let p1 = if self.saw_training_one { 0.75 } else { 0.25 };
            if sym { p1 } else { 1.0 - p1 }
        }

        fn model_name(&self) -> String {
            "ReturnLearningPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }
1185
    #[test]
    fn config_rejects_invalid_period() {
        // The augmentation period N must cover the return horizon H.
        let mut cfg = basic_config();
        cfg.augmentation_period = 1;
        cfg.return_horizon = 2;
        let err = cfg
            .validate()
            .expect_err("N < H must be rejected to match \"A Model-Free Universal AI\"");
        assert!(err.contains("augmentation_period"));
    }
1196
    #[test]
    fn config_rejects_non_power_of_two_return_bins() {
        // Binary return encoding is only exact for power-of-two bin counts.
        let mut cfg = basic_config();
        cfg.return_bins = 3;
        let err = cfg
            .validate()
            .expect_err("non-power-of-two return_bins should be rejected");
        assert!(err.contains("power of two"));
    }
1206
    #[test]
    fn config_rejects_zpaq_algorithm_in_strict_mode() {
        // zpaq cannot provide frozen conditioning, so strict mode bans it.
        let mut cfg = basic_config();
        cfg.algorithm = "zpaq".to_string();
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq algorithm mode");
        assert!(err.contains("strict mode"));
    }
1216
    #[test]
    fn config_rejects_zpaq_rate_backend_in_strict_mode() {
        // The ban also applies when zpaq arrives via the backend override.
        let mut cfg = basic_config();
        cfg.rate_backend = Some(RateBackend::Zpaq {
            method: "1".to_string(),
        });
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq rate backend");
        assert!(err.contains("strict frozen conditioning"));
    }
1228
1229 #[test]
1230 fn config_rejects_nonpaper_gamma_or_tau() {
1231 let mut cfg = basic_config();
1232 cfg.discount_gamma = 1.0;
1233 let err = cfg
1234 .validate()
1235 .expect_err("gamma=1 must be rejected for strict paper AIQI");
1236 assert!(err.contains("discount_gamma"));
1237
1238 cfg = basic_config();
1239 cfg.baseline_exploration = 0.0;
1240 let err = cfg
1241 .validate()
1242 .expect_err("tau=0 must be rejected for strict paper AIQI");
1243 assert!(err.contains("baseline_exploration"));
1244 }
1245
1246 #[test]
1247 fn aiqi_estimates_action_values_after_observations() {
1248 let mut agent = AiqiAgent::new(basic_config()).expect("valid aiqi config");
1249 for _ in 0..8 {
1250 agent
1251 .observe_transition(1, &[1], 1)
1252 .expect("transition should be accepted");
1253 }
1254
1255 let action = agent.get_planned_action();
1256 assert!(action <= 1);
1257 }
1258
1259 #[test]
1260 fn fac_ctw_predictor_uses_return_bit_width() {
1261 let mut cfg = basic_config();
1262 cfg.algorithm = "fac-ctw".to_string();
1263 cfg.return_bins = 8; let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1266 let name = agent.phases[0].predictor.model_name();
1267 assert!(
1268 name.contains("k=3"),
1269 "FAC-CTW should factorize over return bits only, model_name={name}"
1270 );
1271 }
1272
1273 #[test]
1274 fn ac_ctw_path_uses_single_tree_predictor() {
1275 let mut cfg = basic_config();
1276 cfg.algorithm = "ac-ctw".to_string();
1277
1278 let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1279 let name = agent.phases[0].predictor.model_name();
1280 assert!(
1281 name.starts_with("AC-CTW"),
1282 "ac-ctw should map to the single-tree CTW predictor, model_name={name}"
1283 );
1284 }
1285
1286 #[test]
1287 fn ctw_alias_matches_ac_ctw_predictor() {
1288 let mut cfg = basic_config();
1289 cfg.algorithm = "ctw".to_string();
1290
1291 let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1292 let name = agent.phases[0].predictor.model_name();
1293 assert!(
1294 name.starts_with("AC-CTW"),
1295 "ctw alias should map to paper AIQI-CTW predictor, model_name={name}"
1296 );
1297 }
1298
1299 #[test]
1300 fn distribution_rollout_uses_update_and_revert_when_requested() {
1301 let mut predictor = CountingPredictor::default();
1302 let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, true);
1303
1304 assert_eq!(probs.len(), 4);
1305 assert_eq!(predictor.update_calls, 8);
1306 assert_eq!(predictor.revert_calls, 8);
1307 assert_eq!(predictor.update_history_calls, 0);
1308 assert_eq!(predictor.pop_history_calls, 0);
1309 }
1310
1311 #[test]
1312 fn distribution_rollout_uses_history_path_when_not_requested() {
1313 let mut predictor = CountingPredictor::default();
1314 let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, false);
1315
1316 assert_eq!(probs.len(), 4);
1317 assert_eq!(predictor.update_calls, 0);
1318 assert_eq!(predictor.revert_calls, 0);
1319 assert_eq!(predictor.update_history_calls, 8);
1320 assert_eq!(predictor.pop_history_calls, 8);
1321 }
1322
1323 #[test]
1324 fn generic_distribution_rollout_trains_on_return_symbols() {
1325 let predictor = ReturnLearningPredictor::default();
1326 let probs = AiqiAgent::predict_return_distribution_from_base_predictor(4, 2, &predictor);
1327
1328 assert_eq!(probs.len(), 4);
1329 assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
1330 assert!(
1331 probs[3] > probs[1],
1332 "training on the first return bit should make bin 11 likelier than 01; got {:?}",
1333 probs
1334 );
1335 assert!(
1336 (probs[0] - 0.5625).abs() < 1e-12,
1337 "expected exact normalized mass for 00, got {:?}",
1338 probs
1339 );
1340 }
1341
1342 #[test]
1343 fn ac_ctw_rollout_uses_training_updates() {
1344 let mut cfg = basic_config();
1345 cfg.algorithm = "ac-ctw".to_string();
1346
1347 let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1348 assert!(
1349 agent.distribution_uses_training_updates,
1350 "ac-ctw should use update/revert during return distribution rollout"
1351 );
1352 }
1353
1354 #[test]
1355 fn return_bin_for_gamma_less_than_one_matches_paper_h_step_return() {
1356 let mut cfg = basic_config();
1357 cfg.discount_gamma = 0.5;
1358 cfg.return_bins = 8;
1359
1360 let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1361 agent
1362 .observe_transition(0, &[0], 1)
1363 .expect("first transition stored");
1364 agent
1365 .observe_transition(0, &[0], 0)
1366 .expect("second transition should produce first return");
1367
1368 let bin = agent.return_bins_by_step[0].expect("first return should be available");
1369 assert_eq!(bin, 4);
1373 }
1374
1375 #[test]
1376 fn optional_history_pruning_bounds_retained_state_without_losing_progress() {
1377 let mut cfg = basic_config();
1378 cfg.return_horizon = 3;
1379 cfg.augmentation_period = 4;
1380 cfg.history_prune_keep_steps = Some(8);
1381
1382 let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1383 for i in 0..256usize {
1384 let action = (i % 2) as u64;
1385 let obs = [(i % 2) as u64];
1386 let rew = (i % 2) as i64;
1387 agent
1388 .observe_transition(action, &obs, rew)
1389 .expect("transition should be accepted");
1390 }
1391
1392 assert_eq!(agent.steps_observed(), 256);
1394 assert!(
1395 agent.history_base_step > 1,
1396 "history should have been pruned"
1397 );
1398 assert!(
1399 agent.steps.len() < agent.steps_observed(),
1400 "retained history should be smaller than total observed"
1401 );
1402
1403 let action = agent.get_planned_action();
1404 assert!(action <= 1);
1405 }
1406
1407 #[test]
1408 fn committed_phase_advancement_uses_commit_predictor_paths() {
1409 let mut agent = AiqiAgent::new(basic_config()).expect("valid aiqi config");
1410 let counts = Arc::new(Mutex::new(SharedCallCounts::default()));
1411 agent.phases[1].predictor = Box::new(SharedCountingPredictor::new(counts.clone()));
1412 agent.phases[1].last_augmented_step = 0;
1413 agent.history_base_step = 1;
1414 agent.total_steps_observed = 1;
1415 agent.steps = vec![StepRecord {
1416 action: 1,
1417 observations: vec![1],
1418 reward: 1,
1419 }];
1420 agent.return_bins_by_step = vec![Some(3)];
1421
1422 agent
1423 .advance_phase_model_to_step(1, 1)
1424 .expect("phase advancement should succeed");
1425
1426 let snapshot = counts.lock().unwrap().clone();
1427 assert_eq!(snapshot.commit_update, 3);
1428 assert_eq!(snapshot.commit_update_history, 3);
1429 assert_eq!(snapshot.update, 0);
1430 assert_eq!(snapshot.update_history, 0);
1431 }
1432
1433 #[test]
1434 fn generic_planner_trains_on_returns_and_freezes_conditioning_tokens() {
1435 let mut cfg = basic_config();
1436 cfg.rate_backend = Some(RateBackend::Match {
1437 hash_bits: 16,
1438 min_len: 2,
1439 max_len: 16,
1440 base_mix: 0.05,
1441 confidence_scale: 1.0,
1442 });
1443
1444 let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
1445 let counts = Arc::new(Mutex::new(SharedCallCounts::default()));
1446 agent.phases[1].predictor = Box::new(SharedCountingPredictor::new(counts.clone()));
1447 agent.phases[1].last_augmented_step = 0;
1448 agent.history_base_step = 1;
1449 agent.total_steps_observed = 2;
1450 agent.steps = vec![
1451 StepRecord {
1452 action: 1,
1453 observations: vec![1],
1454 reward: 1,
1455 },
1456 StepRecord {
1457 action: 0,
1458 observations: vec![0],
1459 reward: 0,
1460 },
1461 ];
1462 agent.return_bins_by_step = vec![Some(3), None];
1463
1464 let q_values = agent.estimate_q_values_generic();
1465
1466 assert_eq!(q_values.len(), agent.config.agent_actions);
1467 let snapshot = counts.lock().unwrap().clone();
1468 assert_eq!(snapshot.update, 0);
1469 assert_eq!(snapshot.update_history, 0);
1470 assert!(
1471 snapshot.commit_update > 0,
1472 "generic planner should train on augmented return symbols"
1473 );
1474 assert!(
1475 snapshot.commit_update_history > 0,
1476 "generic planner should keep action/percept conditioning frozen"
1477 );
1478 }
1479}