1use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
10use crate::aixi::model::{CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor};
11#[cfg(feature = "backend-rwkv")]
12use crate::load_rwkv7_model_from_path;
13use crate::{CalibratedSpec, MixtureSpec, RateBackend};
14use std::sync::Arc;
15
/// Configuration for a strict ("paper-correct") AIQI agent.
#[derive(Clone)]
pub struct AiqiConfig {
    /// Model selector: "ctw", "fac-ctw", "ac-ctw", "ctw-context-tree",
    /// "rosa", "rwkv" (feature-gated), or "zpaq" (rejected in strict mode).
    pub algorithm: String,
    /// Context-tree depth for the CTW-family predictors.
    pub ct_depth: usize,
    /// Bit width used to encode each observation value.
    pub observation_bits: usize,
    /// Number of observation values delivered per step (treated as at least 1).
    pub observation_stream_len: usize,
    /// Bit width used to encode the offset-shifted reward.
    pub reward_bits: usize,
    /// Size of the discrete action set (must be >= 1).
    pub agent_actions: usize,
    /// Smallest reward the environment may emit.
    pub min_reward: Reward,
    /// Largest reward the environment may emit (must be >= `min_reward`).
    pub max_reward: Reward,
    /// Offset added before encoding so shifted rewards are non-negative.
    pub reward_offset: Reward,
    /// Discount factor gamma; strict mode requires gamma in (0, 1).
    pub discount_gamma: f64,
    /// Horizon H over which truncated returns are computed (>= 1).
    pub return_horizon: usize,
    /// Number of discretization bins for the normalized return (>= 1).
    pub return_bins: usize,
    /// Augmentation period N (one phase model per slot); must be >= `return_horizon`.
    pub augmentation_period: usize,
    /// If set, retained history is pruned down toward this many recent steps.
    pub history_prune_keep_steps: Option<usize>,
    /// Baseline exploration rate tau; strict mode requires tau in (0, 1].
    pub baseline_exploration: f64,
    /// Optional RNG seed for reproducible action selection.
    pub random_seed: Option<u64>,
    /// Optional explicit rate backend; takes precedence over `algorithm`.
    pub rate_backend: Option<RateBackend>,
    /// Max order passed to rate-backend bit predictors.
    pub rate_backend_max_order: i64,
    /// Path to the RWKV model file (required for algorithm = "rwkv"
    /// when no `rate_backend` override is configured).
    pub rwkv_model_path: Option<String>,
    /// Optional max order for the ROSA backend; falls back to
    /// `rate_backend_max_order` when unset.
    pub rosa_max_order: Option<i64>,
    /// zpaq method string; not consumed by any code visible in this file
    /// (zpaq is rejected in strict mode) — presumably used elsewhere. TODO confirm.
    pub zpaq_method: Option<String>,
}
80
impl AiqiConfig {
    /// Validates the configuration against the strict ("paper-correct") AIQI
    /// requirements.
    ///
    /// # Errors
    /// Returns a human-readable message for the first violated constraint:
    /// non-zero cardinalities, period/horizon relation, gamma and tau ranges,
    /// reward range ordering, backend support for frozen conditioning, RWKV
    /// model path presence, and reward encodability in `reward_bits`.
    pub fn validate(&self) -> Result<(), String> {
        // Cardinalities must be non-zero for the bit encodings to make sense.
        if self.agent_actions == 0 {
            return Err("agent_actions must be >= 1".to_string());
        }
        if self.return_horizon == 0 {
            return Err("return_horizon must be >= 1".to_string());
        }
        if self.return_bins == 0 {
            return Err("return_bins must be >= 1".to_string());
        }
        // The augmentation period N must cover the return horizon H.
        if self.augmentation_period < self.return_horizon {
            return Err(format!(
                "augmentation_period must be >= return_horizon (got N={}, H={})",
                self.augmentation_period, self.return_horizon
            ));
        }
        // Strict mode: gamma strictly inside (0, 1) ...
        if !(0.0 < self.discount_gamma && self.discount_gamma < 1.0) {
            return Err(format!(
                "discount_gamma must be in (0, 1) for paper-correct AIQI, got {}",
                self.discount_gamma
            ));
        }
        // ... and exploration tau in (0, 1].
        if !(0.0 < self.baseline_exploration && self.baseline_exploration <= 1.0) {
            return Err(format!(
                "baseline_exploration (tau) must be in (0, 1] for paper-correct AIQI, got {}",
                self.baseline_exploration
            ));
        }
        if self.max_reward < self.min_reward {
            return Err(format!(
                "max_reward must be >= min_reward (got {} < {})",
                self.max_reward, self.min_reward
            ));
        }

        // Without an explicit rate backend, the algorithm string selects the
        // predictor; reject unsupported/unknown algorithms up front.
        if self.rate_backend.is_none() {
            match self.algorithm.as_str() {
                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
                "zpaq" => {
                    return Err(
                        "AIQI strict mode does not support algorithm=zpaq: zpaq backends do not provide strict frozen conditioning"
                            .to_string(),
                    )
                }
                #[cfg(feature = "backend-rwkv")]
                "rwkv" => {}
                #[cfg(not(feature = "backend-rwkv"))]
                "rwkv" => {
                    return Err("algorithm=rwkv requires backend-rwkv feature".to_string())
                }
                other => return Err(format!("Unknown AIQI algorithm: {other}")),
            }
        }

        // An explicit backend must support frozen conditioning, i.e. contain
        // no zpaq component anywhere in its composition.
        if let Some(rate_backend) = &self.rate_backend {
            if !rate_backend_supports_aiqi_frozen_conditioning(rate_backend) {
                return Err(
                    "AIQI strict mode requires frozen context updates; configured rate_backend contains zpaq which does not provide strict frozen conditioning"
                        .to_string(),
                );
            }
        }

        // RWKV without a backend override needs a non-empty model path.
        #[cfg(feature = "backend-rwkv")]
        if self.rate_backend.is_none() && self.algorithm == "rwkv" {
            match self.rwkv_model_path.as_deref() {
                Some(path) if !path.trim().is_empty() => {}
                _ => {
                    return Err(
                        "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                            .to_string(),
                    )
                }
            }
        }

        // The offset-shifted reward range must be non-negative and, for
        // widths below 64 bits, fit in reward_bits (the shift is guarded so
        // it cannot overflow u128).
        let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
        let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
        if min_shifted < 0 {
            return Err(format!(
                "reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
                min_shifted
            ));
        }
        if self.reward_bits < 64 {
            let max_enc = (1u128 << self.reward_bits) - 1;
            if (max_shifted as u128) > max_enc {
                return Err(format!(
                    "reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
                    max_shifted, max_enc
                ));
            }
        }

        Ok(())
    }
}
182
/// One environment transition as stored in the retained history window.
#[derive(Clone, Debug)]
struct StepRecord {
    // Action taken at this step.
    action: Action,
    // Observation values received after the action.
    observations: Vec<PerceptVal>,
    // Raw (un-shifted) reward received.
    reward: Reward,
}
189
/// Per-phase predictor used for return-augmented sequence modeling.
struct PhaseModel {
    // Sequence model fed this phase's augmented token stream.
    predictor: Box<dyn Predictor>,
    // Last global step whose tokens were committed into `predictor`
    // (0 means nothing committed yet).
    last_augmented_step: usize,
}
196
/// AIQI agent: keeps one phase model per augmentation-period slot and plans
/// actions by predicting discretized H-step return distributions.
pub struct AiqiAgent {
    // Validated configuration this agent was built from.
    config: AiqiConfig,
    // One model per phase in [0, augmentation_period).
    phases: Vec<PhaseModel>,
    // Retained transition history (may be pruned from the front).
    steps: Vec<StepRecord>,
    // Discretized return bin per retained step, set once its H-step
    // return becomes computable.
    return_bins_by_step: Vec<Option<u64>>,
    // Global (1-based) step index of steps[0]; grows when history is pruned.
    history_base_step: usize,
    // Total transitions observed over the agent's lifetime.
    total_steps_observed: usize,
    // Bit width for encoding actions.
    action_bits: usize,
    // Bit width for encoding return bins.
    return_bits: usize,
    // Whether planning must clone predictors instead of push/pop history.
    use_generic_planner: bool,
    // Whether rollouts condition via training updates (CTW family) rather
    // than history-only updates.
    distribution_uses_training_updates: bool,
    // RNG for exploration decisions.
    rng: RandomGenerator,
}
213
214impl AiqiAgent {
215 pub fn new(config: AiqiConfig) -> Result<Self, String> {
217 config.validate()?;
218
219 let action_bits = bits_for_cardinality(config.agent_actions);
220 let return_bits = bits_for_cardinality(config.return_bins);
221 let use_generic_planner = aiqi_requires_generic_planner(&config);
222 let distribution_uses_training_updates = config.rate_backend.is_none()
223 && matches!(
224 config.algorithm.as_str(),
225 "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
226 );
227
228 let mut phases = Vec::with_capacity(config.augmentation_period);
229 for _ in 0..config.augmentation_period {
230 phases.push(PhaseModel {
231 predictor: build_predictor(&config, return_bits)?,
232 last_augmented_step: 0,
233 });
234 }
235
236 let rng = if let Some(seed) = config.random_seed {
237 RandomGenerator::from_seed(seed)
238 } else {
239 RandomGenerator::new()
240 };
241
242 Ok(Self {
243 action_bits,
244 return_bits,
245 use_generic_planner,
246 distribution_uses_training_updates,
247 config,
248 phases,
249 steps: Vec::new(),
250 return_bins_by_step: Vec::new(),
251 history_base_step: 1,
252 total_steps_observed: 0,
253 rng,
254 })
255 }
256
    /// Total number of transitions observed over the agent's lifetime
    /// (unaffected by history pruning).
    pub fn steps_observed(&self) -> usize {
        self.total_steps_observed
    }
261
    /// Size of the configured action set.
    pub fn num_actions(&self) -> usize {
        self.config.agent_actions
    }
266
267 pub fn get_planned_action(&mut self) -> Action {
269 let q_values = self.estimate_q_values();
270 let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
271 if self.config.baseline_exploration > 0.0
272 && self
273 .rng
274 .gen_bool(self.config.baseline_exploration.clamp(0.0, 1.0))
275 {
276 self.rng.gen_range(self.config.agent_actions) as u64
277 } else {
278 greedy_action
279 }
280 }
281
282 pub fn get_planned_action_with_extra_exploration(&mut self, extra_exploration: f64) -> Action {
288 let extra = extra_exploration.clamp(0.0, 1.0);
289 let tau = self.config.baseline_exploration.clamp(0.0, 1.0);
290 let effective = 1.0 - (1.0 - tau) * (1.0 - extra);
291 let q_values = self.estimate_q_values();
292 let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
293 if effective > 0.0 && self.rng.gen_bool(effective) {
294 self.rng.gen_range(self.config.agent_actions) as u64
295 } else {
296 greedy_action
297 }
298 }
299
    /// Records one environment transition `(action, observations, reward)`.
    ///
    /// Validates that the action index, observation stream length and values,
    /// and the reward all fit the configured ranges and bit widths, then
    /// appends the step to history, learns any newly completed H-step return,
    /// and applies optional history pruning.
    ///
    /// # Errors
    /// Returns a descriptive message when any input is out of range.
    pub fn observe_transition(
        &mut self,
        action: Action,
        observations: &[PerceptVal],
        reward: Reward,
    ) -> Result<(), String> {
        if action as usize >= self.config.agent_actions {
            return Err(format!(
                "action out of range: action={} but agent_actions={}",
                action, self.config.agent_actions
            ));
        }

        // A configured stream length of 0 is treated as 1 observation.
        let expected_obs = self.config.observation_stream_len.max(1);
        if observations.len() != expected_obs {
            return Err(format!(
                "observation stream length mismatch: expected {}, got {}",
                expected_obs,
                observations.len()
            ));
        }

        if reward < self.config.min_reward || reward > self.config.max_reward {
            return Err(format!(
                "reward out of configured range: reward={} not in [{}, {}]",
                reward, self.config.min_reward, self.config.max_reward
            ));
        }

        // Every observation value must be encodable in observation_bits.
        let obs_max = max_value_for_bits(self.config.observation_bits);
        for &obs in observations {
            if obs > obs_max {
                return Err(format!(
                    "observation value {} does not fit observation_bits={} (max={})",
                    obs, self.config.observation_bits, obs_max
                ));
            }
        }

        // The offset-shifted reward must be non-negative and (for widths
        // below 64) fit in reward_bits.
        let rew_shifted = (reward as i128) + (self.config.reward_offset as i128);
        if rew_shifted < 0 {
            return Err(format!(
                "encoded reward became negative after offset: reward={} offset={}",
                reward, self.config.reward_offset
            ));
        }
        if self.config.reward_bits < 64 {
            let max_enc = (1u128 << self.config.reward_bits) - 1;
            if (rew_shifted as u128) > max_enc {
                return Err(format!(
                    "encoded reward {} exceeds reward_bits={} capacity {}",
                    rew_shifted, self.config.reward_bits, max_enc
                ));
            }
        }

        // Input accepted: commit the step to history.
        self.steps.push(StepRecord {
            action,
            observations: observations.to_vec(),
            reward,
        });
        self.total_steps_observed += 1;
        self.return_bins_by_step.push(None);

        self.maybe_learn_new_return()?;
        self.maybe_prune_history();
        Ok(())
    }
372
373 fn maybe_learn_new_return(&mut self) -> Result<(), String> {
374 let t = self.total_steps_observed;
375 let h = self.config.return_horizon;
376 if t < h {
377 return Ok(());
378 }
379
380 let i = t + 1 - h;
382 let bin = self.compute_return_bin(i);
383 let local_idx = self.local_index(i)?;
384 self.return_bins_by_step[local_idx] = Some(bin);
385
386 let phase = i % self.config.augmentation_period;
387 self.advance_phase_model_to_step(phase, i)
388 }
389
    /// Estimates a Q-value per action for the upcoming step.
    ///
    /// Fast path for reversible predictors: fast-forwards the current phase
    /// model's *history* (not its training state) to the present, then for
    /// each candidate action pushes its bits, reads the predicted return
    /// distribution, and pops everything pushed so the model ends unchanged.
    fn estimate_q_values(&mut self) -> Vec<f64> {
        if self.use_generic_planner {
            return self.estimate_q_values_generic();
        }

        // The step about to be taken (1-based) and its phase slot.
        let step = self.total_steps_observed + 1;
        let phase = step % self.config.augmentation_period;
        // Split borrows so the phase model can be mutated while other
        // agent state is read.
        let config = &self.config;
        let steps = &self.steps;
        let return_bins_by_step = &self.return_bins_by_step;
        let history_base_step = self.history_base_step;
        let action_bits = self.action_bits;
        let return_bits = self.return_bits;

        let mut q_values = vec![0.0; self.config.agent_actions];
        let mut pushed_fast_forward = 0usize;

        {
            let model = &mut self.phases[phase];
            // History-only replay of steps not yet committed to this model.
            let start = (model.last_augmented_step + 1).max(history_base_step);
            let end = step.saturating_sub(1);
            if start <= end {
                for idx in start..=end {
                    pushed_fast_forward += push_step_tokens_history(
                        config,
                        history_base_step,
                        steps,
                        return_bins_by_step,
                        action_bits,
                        return_bits,
                        model.predictor.as_mut(),
                        phase,
                        idx,
                    );
                }
            }

            // Score each action, undoing its pushed bits afterwards.
            for action in 0..self.config.agent_actions {
                let pushed_action = push_encoded_bits_history(
                    model.predictor.as_mut(),
                    action as u64,
                    self.action_bits,
                );
                let dist = Self::predict_return_distribution(
                    self.config.return_bins,
                    self.return_bits,
                    model.predictor.as_mut(),
                    self.distribution_uses_training_updates,
                );
                q_values[action] = expectation_from_distribution(&dist);
                pop_history_bits(model.predictor.as_mut(), pushed_action);
            }

            // Undo the fast-forward so the committed model state is preserved.
            pop_history_bits(model.predictor.as_mut(), pushed_fast_forward);
        }

        q_values
    }
448
    /// Generic planning path for backends without cheap history push/pop:
    /// clones the phase predictor, replays pending history into the clone,
    /// and scores each action on a further per-action clone, leaving the
    /// stored models untouched.
    fn estimate_q_values_generic(&mut self) -> Vec<f64> {
        let step = self.total_steps_observed + 1;
        let phase = step % self.config.augmentation_period;

        let model = &self.phases[phase];
        let mut context_predictor = model.predictor.boxed_clone();

        // Replay steps not yet committed to this phase model into the clone.
        let start = (model.last_augmented_step + 1).max(self.history_base_step);
        let end = step.saturating_sub(1);
        if start <= end {
            for idx in start..=end {
                push_step_tokens_history(
                    &self.config,
                    self.history_base_step,
                    &self.steps,
                    &self.return_bins_by_step,
                    self.action_bits,
                    self.return_bits,
                    context_predictor.as_mut(),
                    phase,
                    idx,
                );
            }
        }

        let mut q_values = vec![0.0; self.config.agent_actions];
        for action in 0..self.config.agent_actions {
            // Fresh clone per action so candidate pushes don't interact.
            let mut action_predictor = context_predictor.boxed_clone();
            let _ = push_encoded_bits_history(
                action_predictor.as_mut(),
                action as u64,
                self.action_bits,
            );
            let dist = Self::predict_return_distribution_from_base_predictor(
                self.config.return_bins,
                self.return_bits,
                action_predictor.as_ref(),
            );
            q_values[action] = expectation_from_distribution(&dist);
        }

        q_values
    }
492
    /// Computes the predicted distribution over return bins by chaining
    /// per-bit probabilities (LSB-first) of each bin's encoding.
    ///
    /// When `use_training_updates` is true the predictor is conditioned via
    /// `update`/`revert` (training path); otherwise via
    /// `update_history`/`pop_history`. Either way the predictor state is
    /// restored before returning. The result is normalized; a degenerate
    /// total falls back to the uniform distribution.
    fn predict_return_distribution(
        return_bins: usize,
        return_bits: usize,
        predictor: &mut dyn Predictor,
        use_training_updates: bool,
    ) -> Vec<f64> {
        // A single bin is trivially certain.
        if return_bins == 1 {
            return vec![1.0];
        }

        let mut probs = vec![0.0; return_bins];
        for (bin, slot) in probs.iter_mut().enumerate() {
            let mut p = 1.0f64;
            let mut v = bin as u64;
            for _ in 0..return_bits {
                let bit = (v & 1) == 1;
                v >>= 1;
                // Clamp away exact 0/1 to keep the product well-behaved.
                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
                p *= q;
                if use_training_updates {
                    predictor.update(bit);
                } else {
                    predictor.update_history(bit);
                }
            }
            // Restore predictor state before scoring the next bin.
            if use_training_updates {
                revert_bits(predictor, return_bits);
            } else {
                pop_history_bits(predictor, return_bits);
            }
            *slot = p;
        }

        // Normalize; fall back to uniform on a non-finite or zero total.
        let sum: f64 = probs.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let u = 1.0 / (return_bins as f64);
            probs.fill(u);
            return probs;
        }

        for p in &mut probs {
            *p /= sum;
        }
        probs
    }
538
    /// Clone-based variant of [`Self::predict_return_distribution`] for
    /// predictors without reversible updates: each bin is scored on a fresh
    /// clone of `base_predictor`, which itself is never mutated.
    fn predict_return_distribution_from_base_predictor(
        return_bins: usize,
        return_bits: usize,
        base_predictor: &dyn Predictor,
    ) -> Vec<f64> {
        if return_bins == 1 {
            return vec![1.0];
        }

        let mut probs = vec![0.0; return_bins];
        for (bin, slot) in probs.iter_mut().enumerate() {
            // NOTE(review): one boxed clone per bin — fine for small bin
            // counts, potentially costly for heavyweight backends.
            let mut predictor = base_predictor.boxed_clone();
            let mut p = 1.0f64;
            let mut v = bin as u64;
            for _ in 0..return_bits {
                let bit = (v & 1) == 1;
                v >>= 1;
                // Clamp away exact 0/1 to keep the product well-behaved.
                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
                p *= q;
                predictor.update_history(bit);
            }
            *slot = p;
        }

        // Normalize; fall back to uniform on a non-finite or zero total.
        let sum: f64 = probs.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let u = 1.0 / (return_bins as f64);
            probs.fill(u);
            return probs;
        }

        for p in &mut probs {
            *p /= sum;
        }
        probs
    }
575
    /// Commits (trains) the phase model's token stream up to `target_step`
    /// inclusive: action and percept tokens are pushed as history for every
    /// step, and the return bin is pushed as a *training* token on steps
    /// owned by this phase.
    ///
    /// # Errors
    /// Fails if a phase-owned step's return bin has not been computed yet.
    fn advance_phase_model_to_step(
        &mut self,
        phase: usize,
        target_step: usize,
    ) -> Result<(), String> {
        // Split borrows so the phase model can be mutated alongside reads.
        let config = &self.config;
        let steps = &self.steps;
        let return_bins_by_step = &self.return_bins_by_step;
        let history_base_step = self.history_base_step;
        let action_bits = self.action_bits;
        let return_bits = self.return_bits;
        let model = &mut self.phases[phase];
        if target_step <= model.last_augmented_step {
            return Ok(());
        }

        let start = (model.last_augmented_step + 1).max(history_base_step);
        for idx in start..=target_step {
            push_action_tokens_history(
                history_base_step,
                steps,
                action_bits,
                model.predictor.as_mut(),
                idx,
            );

            // Steps owned by this phase contribute their return bin as a
            // training token between action and percept tokens.
            if idx % config.augmentation_period == phase {
                let local_idx = idx - history_base_step;
                let bin = return_bins_by_step[local_idx].ok_or_else(|| {
                    format!(
                        "missing return bin for step {} in phase {} while advancing model",
                        idx, phase
                    )
                })?;
                push_encoded_bits_train(model.predictor.as_mut(), bin, return_bits);
            }

            push_percept_tokens_history(
                config,
                history_base_step,
                steps,
                model.predictor.as_mut(),
                idx,
            );
        }

        model.last_augmented_step = target_step;
        Ok(())
    }
625
    /// Computes the discretized, normalized H-step discounted return starting
    /// at `start_step` (global, 1-based).
    ///
    /// Rewards are min-max normalized into [0, 1], discounted by gamma^k, and
    /// the sum is scaled by (1 - gamma) so the result stays in [0, 1] before
    /// being mapped to one of `return_bins` equal-width bins.
    ///
    /// # Panics
    /// Panics (via `expect`) if any step in the window is outside retained
    /// history.
    fn compute_return_bin(&self, start_step: usize) -> u64 {
        let h = self.config.return_horizon;
        let gamma = self.config.discount_gamma;

        // validate() guarantees gamma in (0, 1).
        debug_assert!(gamma > 0.0 && gamma < 1.0);
        let reward_range = (self.config.max_reward - self.config.min_reward) as f64;

        let mut total = 0.0f64;
        // gk tracks gamma^k incrementally.
        let mut gk = 1.0f64;
        for k in 0..h {
            let idx = start_step + k;
            let local_idx = self
                .local_index(idx)
                .expect("return computation requires in-range history");
            let r = self.steps[local_idx].reward;
            // A degenerate range (max == min) normalizes to 0.
            let rn = if reward_range <= 0.0 {
                0.0
            } else {
                ((r - self.config.min_reward) as f64 / reward_range).clamp(0.0, 1.0)
            };
            total += gk * rn;
            gk *= gamma;
        }
        let ret = ((1.0 - gamma) * total).clamp(0.0, 1.0);

        // Map [0, 1] onto bin indices; ret == 1.0 clamps into the top bin.
        let mut bin = (ret * (self.config.return_bins as f64)).floor() as u64;
        let max_bin = (self.config.return_bins as u64).saturating_sub(1);
        if bin > max_bin {
            bin = max_bin;
        }
        bin
    }
659
660 fn local_index(&self, global_step: usize) -> Result<usize, String> {
661 if global_step < self.history_base_step || global_step > self.total_steps_observed {
662 return Err(format!(
663 "global step {} out of retained history range [{}, {}]",
664 global_step, self.history_base_step, self.total_steps_observed
665 ));
666 }
667 Ok(global_step - self.history_base_step)
668 }
669
    /// Optionally drops the oldest history entries, as configured by
    /// `history_prune_keep_steps`.
    ///
    /// A prefix may only be dropped once (a) every phase model has committed
    /// it, (b) no future return computation will need it, and (c) at least
    /// `keep_steps` recent steps stay retained.
    fn maybe_prune_history(&mut self) {
        let Some(keep_steps) = self.config.history_prune_keep_steps else {
            return;
        };
        if self.steps.is_empty() {
            return;
        }

        // (a) Oldest step committed by *every* phase model.
        let min_phase_committed = self
            .phases
            .iter()
            .map(|phase| phase.last_augmented_step)
            .min()
            .unwrap_or(0);

        // (b) Earliest start step a future H-step return could still need;
        // everything strictly before it is safe to drop.
        let next_start_needed = self
            .total_steps_observed
            .saturating_add(2)
            .saturating_sub(self.config.return_horizon);
        let returns_safe_drop_upto = next_start_needed.saturating_sub(1);

        let mut safe_drop_upto = min_phase_committed.min(returns_safe_drop_upto);

        // (c) Never retain fewer than keep_steps recent steps.
        let keep_floor_drop_upto = self.total_steps_observed.saturating_sub(keep_steps);
        safe_drop_upto = safe_drop_upto.min(keep_floor_drop_upto);

        if safe_drop_upto < self.history_base_step {
            return;
        }

        let drain_count = safe_drop_upto - self.history_base_step + 1;
        if drain_count == 0 || drain_count > self.steps.len() {
            return;
        }

        // Drop the prefix from both parallel vectors and shift the base.
        self.steps.drain(0..drain_count);
        self.return_bins_by_step.drain(0..drain_count);
        self.history_base_step += drain_count;
    }
714}
715
/// Pushes one full step (action, optional return bin, percepts) into the
/// predictor's *history* (no training), returning the number of bits pushed
/// so callers can later pop them.
///
/// The return bin is only included on steps owned by `phase`, and only when
/// it has already been computed.
fn push_step_tokens_history(
    config: &AiqiConfig,
    history_base_step: usize,
    steps: &[StepRecord],
    return_bins_by_step: &[Option<u64>],
    action_bits: usize,
    return_bits: usize,
    predictor: &mut dyn Predictor,
    phase: usize,
    idx: usize,
) -> usize {
    let mut pushed = 0usize;
    pushed += push_action_tokens_history(history_base_step, steps, action_bits, predictor, idx);

    if idx % config.augmentation_period == phase {
        let local_idx = idx - history_base_step;
        if let Some(bin) = return_bins_by_step[local_idx] {
            pushed += push_encoded_bits_history(predictor, bin, return_bits);
        }
    }

    pushed + push_percept_tokens_history(config, history_base_step, steps, predictor, idx)
}
739
740fn push_action_tokens_history(
741 history_base_step: usize,
742 steps: &[StepRecord],
743 action_bits: usize,
744 predictor: &mut dyn Predictor,
745 idx: usize,
746) -> usize {
747 let action = steps[idx - history_base_step].action;
748 push_encoded_bits_history(predictor, action, action_bits)
749}
750
751fn push_percept_tokens_history(
752 config: &AiqiConfig,
753 history_base_step: usize,
754 steps: &[StepRecord],
755 predictor: &mut dyn Predictor,
756 idx: usize,
757) -> usize {
758 let step = &steps[idx - history_base_step];
759 let mut pushed = 0usize;
760 for &obs in &step.observations {
761 pushed += push_encoded_bits_history(predictor, obs, config.observation_bits);
762 }
763 pushed
764 + push_encoded_reward_history(
765 predictor,
766 step.reward,
767 config.reward_bits,
768 config.reward_offset,
769 )
770}
771
/// Builds the per-phase predictor selected by the configuration.
///
/// An explicit `rate_backend` (adapted to 1-bit tokens) takes precedence;
/// otherwise the `algorithm` string picks a CTW-family, ROSA, or RWKV
/// predictor. `return_bits` sizes the FAC-CTW factorization.
fn build_predictor(config: &AiqiConfig, return_bits: usize) -> Result<Box<dyn Predictor>, String> {
    if let Some(rate_backend) = config.rate_backend.clone() {
        let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
        let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
        return Ok(Box::new(predictor));
    }

    match config.algorithm.as_str() {
        // Single-tree CTW variants share one predictor type.
        "ctw" | "ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
        "fac-ctw" => {
            // Factorized CTW: sized by the number of return bits.
            Ok(Box::new(FacCtwPredictor::new(config.ct_depth, return_bits)))
        }
        "rosa" => {
            // ROSA uses its own max order if configured, otherwise the
            // generic rate-backend max order.
            let max_order = config
                .rosa_max_order
                .unwrap_or(config.rate_backend_max_order);
            let bit_backend = adapt_rate_backend_for_bit_tokens(RateBackend::RosaPlus);
            let predictor = RateBackendBitPredictor::new(bit_backend, max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(feature = "backend-rwkv")]
        "rwkv" => {
            let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
                "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                    .to_string()
            })?;
            let model_arc = load_rwkv7_model_from_path(path);
            let bit_backend =
                adapt_rate_backend_for_bit_tokens(RateBackend::Rwkv7 { model: model_arc });
            let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(not(feature = "backend-rwkv"))]
        "rwkv" => Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
        "zpaq" => Err(
            "AIQI strict mode does not support algorithm=zpaq; configure a backend with strict frozen conditioning"
                .to_string(),
        ),
        _ => Err(format!("Unknown AIQI algorithm: {}", config.algorithm)),
    }
}
814
/// Rewrites a rate backend so it consumes 1-bit tokens, recursing into
/// composite backends (mixtures and calibration wrappers). Backends with no
/// token-width configuration pass through unchanged.
fn adapt_rate_backend_for_bit_tokens(backend: RateBackend) -> RateBackend {
    match backend {
        // Plain CTW becomes a 1-bit FAC-CTW with the same depth.
        RateBackend::Ctw { depth } => RateBackend::FacCtw {
            base_depth: depth,
            num_percept_bits: 1,
            encoding_bits: 1,
        },
        // FAC-CTW keeps its depth but is re-sized to 1-bit tokens.
        RateBackend::FacCtw { base_depth, .. } => RateBackend::FacCtw {
            base_depth,
            num_percept_bits: 1,
            encoding_bits: 1,
        },
        // Mixtures: adapt every expert, preserving names, priors, orders,
        // and the mixture's alpha/decay tuning.
        RateBackend::Mixture { spec } => {
            let experts = spec
                .experts
                .iter()
                .map(|expert| crate::MixtureExpertSpec {
                    name: expert.name.clone(),
                    log_prior: expert.log_prior,
                    max_order: expert.max_order,
                    backend: adapt_rate_backend_for_bit_tokens(expert.backend.clone()),
                })
                .collect();

            let mut adapted = MixtureSpec::new(spec.kind, experts).with_alpha(spec.alpha);
            if let Some(decay) = spec.decay {
                adapted = adapted.with_decay(decay);
            }
            RateBackend::Mixture {
                spec: Arc::new(adapted),
            }
        }
        // Calibration wrappers: adapt the base, copy calibration settings.
        RateBackend::Calibrated { spec } => RateBackend::Calibrated {
            spec: Arc::new(CalibratedSpec {
                base: adapt_rate_backend_for_bit_tokens(spec.base.clone()),
                context: spec.context,
                bins: spec.bins,
                learning_rate: spec.learning_rate,
                bias_clip: spec.bias_clip,
            }),
        },
        other => other,
    }
}
859
860fn rate_backend_supports_aiqi_frozen_conditioning(backend: &RateBackend) -> bool {
861 match backend {
862 RateBackend::Zpaq { .. } => false,
863 RateBackend::Mixture { spec } => spec
864 .experts
865 .iter()
866 .all(|expert| rate_backend_supports_aiqi_frozen_conditioning(&expert.backend)),
867 RateBackend::Calibrated { spec } => {
868 rate_backend_supports_aiqi_frozen_conditioning(&spec.base)
869 }
870 _ => true,
871 }
872}
873
874fn aiqi_requires_generic_planner(config: &AiqiConfig) -> bool {
875 config.rate_backend.is_some()
876 || !matches!(
877 config.algorithm.as_str(),
878 "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
879 )
880}
881
/// Smallest bit width (at least 1) whose value range covers `cardinality`
/// distinct symbols; zero is treated as one.
fn bits_for_cardinality(cardinality: usize) -> usize {
    let n = cardinality.max(1);
    (1usize..)
        .find(|&width| (1usize << width) >= n)
        .expect("a covering bit width always exists for practical cardinalities")
}
890
/// Largest value representable in `bits` bits: 0 for zero width, saturating
/// at `u64::MAX` for widths of 64 and above.
fn max_value_for_bits(bits: usize) -> u64 {
    match bits {
        0 => 0,
        b if b >= 64 => u64::MAX,
        b => (1u64 << b) - 1,
    }
}
900
901fn push_encoded_bits_train(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
902 let mut v = value;
903 for _ in 0..bits {
904 predictor.update((v & 1) == 1);
905 v >>= 1;
906 }
907 bits
908}
909
910fn push_encoded_bits_history(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
911 let mut v = value;
912 for _ in 0..bits {
913 predictor.update_history((v & 1) == 1);
914 v >>= 1;
915 }
916 bits
917}
918
919fn push_encoded_reward_history(
920 predictor: &mut dyn Predictor,
921 reward: Reward,
922 bits: usize,
923 offset: Reward,
924) -> usize {
925 let shifted = (reward as i128) + (offset as i128);
926 let as_u64 = if shifted <= 0 {
927 0
928 } else if shifted > (u64::MAX as i128) {
929 u64::MAX
930 } else {
931 shifted as u64
932 };
933 push_encoded_bits_history(predictor, as_u64, bits)
934}
935
936fn pop_history_bits(predictor: &mut dyn Predictor, bits: usize) {
937 for _ in 0..bits {
938 predictor.pop_history();
939 }
940}
941
942fn revert_bits(predictor: &mut dyn Predictor, bits: usize) {
943 for _ in 0..bits {
944 predictor.revert();
945 }
946}
947
/// Expected normalized return of a bin distribution: bin `i` of `m` bins
/// represents the value `i / m`, so the result is the probability-weighted
/// sum of those values (0.0 for an empty slice).
fn expectation_from_distribution(probs: &[f64]) -> f64 {
    let m = probs.len();
    if m == 0 {
        return 0.0;
    }
    let mut acc = 0.0f64;
    for (i, &p) in probs.iter().enumerate() {
        acc += (i as f64 / m as f64) * p;
    }
    acc
}
959
/// Index of the first strict maximum: ties resolve to the earliest index,
/// and an empty (or all-NaN) slice yields 0.
fn argmax_with_fixed_tie_break(values: &[f64]) -> usize {
    let mut winner = (0usize, f64::NEG_INFINITY);
    for (idx, &v) in values.iter().enumerate() {
        // Strict `>` keeps the earliest index on ties; NaN never wins.
        if v > winner.1 {
            winner = (idx, v);
        }
    }
    winner.0
}
971
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests; individual tests tweak
    /// the fields they exercise.
    fn basic_config() -> AiqiConfig {
        AiqiConfig {
            algorithm: "ac-ctw".to_string(),
            ct_depth: 8,
            observation_bits: 1,
            observation_stream_len: 1,
            reward_bits: 1,
            agent_actions: 2,
            min_reward: 0,
            max_reward: 1,
            reward_offset: 0,
            discount_gamma: 0.99,
            return_horizon: 2,
            return_bins: 8,
            augmentation_period: 2,
            history_prune_keep_steps: None,
            baseline_exploration: 0.01,
            random_seed: Some(7),
            rate_backend: None,
            rate_backend_max_order: 20,
            rwkv_model_path: None,
            rosa_max_order: None,
            zpaq_method: None,
        }
    }

    /// Test double that counts which predictor entry points are exercised.
    #[derive(Clone, Default)]
    struct CountingPredictor {
        update_calls: usize,
        update_history_calls: usize,
        revert_calls: usize,
        pop_history_calls: usize,
    }

    impl Predictor for CountingPredictor {
        fn update(&mut self, _sym: bool) {
            self.update_calls += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.update_history_calls += 1;
        }

        fn revert(&mut self) {
            self.revert_calls += 1;
        }

        fn pop_history(&mut self) {
            self.pop_history_calls += 1;
        }

        // Fixed, state-independent probabilities keep the math predictable.
        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "CountingPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn config_rejects_invalid_period() {
        let mut cfg = basic_config();
        cfg.augmentation_period = 1;
        cfg.return_horizon = 2;
        let err = cfg
            .validate()
            .expect_err("N < H must be rejected for paper-correct augmentation");
        assert!(err.contains("augmentation_period"));
    }

    #[test]
    fn config_rejects_zpaq_algorithm_in_strict_mode() {
        let mut cfg = basic_config();
        cfg.algorithm = "zpaq".to_string();
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq algorithm mode");
        assert!(err.contains("strict mode"));
    }

    #[test]
    fn config_rejects_zpaq_rate_backend_in_strict_mode() {
        let mut cfg = basic_config();
        cfg.rate_backend = Some(RateBackend::Zpaq {
            method: "1".to_string(),
        });
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq rate backend");
        assert!(err.contains("strict frozen conditioning"));
    }

    #[test]
    fn config_rejects_nonpaper_gamma_or_tau() {
        // gamma must lie strictly inside (0, 1).
        let mut cfg = basic_config();
        cfg.discount_gamma = 1.0;
        let err = cfg
            .validate()
            .expect_err("gamma=1 must be rejected for strict paper AIQI");
        assert!(err.contains("discount_gamma"));

        // tau must lie in (0, 1].
        cfg = basic_config();
        cfg.baseline_exploration = 0.0;
        let err = cfg
            .validate()
            .expect_err("tau=0 must be rejected for strict paper AIQI");
        assert!(err.contains("baseline_exploration"));
    }

    #[test]
    fn aiqi_estimates_action_values_after_observations() {
        let mut agent = AiqiAgent::new(basic_config()).expect("valid aiqi config");
        for _ in 0..8 {
            agent
                .observe_transition(1, &[1], 1)
                .expect("transition should be accepted");
        }

        let action = agent.get_planned_action();
        assert!(action <= 1);
    }

    #[test]
    fn fac_ctw_predictor_uses_return_bit_width() {
        let mut cfg = basic_config();
        cfg.algorithm = "fac-ctw".to_string();
        // 8 bins encode in 3 bits, so FAC-CTW should report k=3 factors.
        cfg.return_bins = 8;

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.contains("k=3"),
            "FAC-CTW should factorize over return bits only, model_name={name}"
        );
    }

    #[test]
    fn ac_ctw_path_uses_single_tree_predictor() {
        let mut cfg = basic_config();
        cfg.algorithm = "ac-ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.starts_with("AC-CTW"),
            "ac-ctw should map to the single-tree CTW predictor, model_name={name}"
        );
    }

    #[test]
    fn ctw_alias_matches_ac_ctw_predictor() {
        let mut cfg = basic_config();
        cfg.algorithm = "ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.starts_with("AC-CTW"),
            "ctw alias should map to paper AIQI-CTW predictor, model_name={name}"
        );
    }

    #[test]
    fn distribution_rollout_uses_update_and_revert_when_requested() {
        let mut predictor = CountingPredictor::default();
        // 4 bins x 2 bits = 8 update/revert pairs on the training path.
        let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, true);

        assert_eq!(probs.len(), 4);
        assert_eq!(predictor.update_calls, 8);
        assert_eq!(predictor.revert_calls, 8);
        assert_eq!(predictor.update_history_calls, 0);
        assert_eq!(predictor.pop_history_calls, 0);
    }

    #[test]
    fn distribution_rollout_uses_history_path_when_not_requested() {
        let mut predictor = CountingPredictor::default();
        // Same rollout on the history-only path.
        let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, false);

        assert_eq!(probs.len(), 4);
        assert_eq!(predictor.update_calls, 0);
        assert_eq!(predictor.revert_calls, 0);
        assert_eq!(predictor.update_history_calls, 8);
        assert_eq!(predictor.pop_history_calls, 8);
    }

    #[test]
    fn ac_ctw_rollout_uses_training_updates() {
        let mut cfg = basic_config();
        cfg.algorithm = "ac-ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        assert!(
            agent.distribution_uses_training_updates,
            "ac-ctw should use update/revert during return distribution rollout"
        );
    }

    #[test]
    fn return_bin_for_gamma_less_than_one_matches_paper_h_step_return() {
        let mut cfg = basic_config();
        cfg.discount_gamma = 0.5;
        cfg.return_bins = 8;

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        agent
            .observe_transition(0, &[0], 1)
            .expect("first transition stored");
        agent
            .observe_transition(0, &[0], 0)
            .expect("second transition should produce first return");

        // (1 - gamma) * (1 * gamma^0 + 0 * gamma^1) = 0.5 -> floor(0.5 * 8) = 4.
        let bin = agent.return_bins_by_step[0].expect("first return should be available");
        assert_eq!(bin, 4);
    }

    #[test]
    fn optional_history_pruning_bounds_retained_state_without_losing_progress() {
        let mut cfg = basic_config();
        cfg.return_horizon = 3;
        cfg.augmentation_period = 4;
        cfg.history_prune_keep_steps = Some(8);

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        for i in 0..256usize {
            let action = (i % 2) as u64;
            let obs = [(i % 2) as u64];
            let rew = (i % 2) as i64;
            agent
                .observe_transition(action, &obs, rew)
                .expect("transition should be accepted");
        }

        // Lifetime counter is unaffected by pruning ...
        assert_eq!(agent.steps_observed(), 256);
        // ... but the retained window has actually shrunk.
        assert!(
            agent.history_base_step > 1,
            "history should have been pruned"
        );
        assert!(
            agent.steps.len() < agent.steps_observed(),
            "retained history should be smaller than total observed"
        );

        let action = agent.get_planned_action();
        assert!(action <= 1);
    }
}