// infotheory/aixi/aiqi.rs

1//! AIQI (Universal AI with Q-Induction) implementation.
2//!
3//! This module implements a model-free universal agent that predicts
4//! discretized H-step returns directly from augmented interaction history.
5//! The implementation follows the paper's phase-indexed periodic augmentation:
6//! for return horizon `H` and period `N >= H`, each phase model only inserts
7//! returns at indices `i % N == phase`.
8
9use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
10use crate::aixi::model::{CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor};
11#[cfg(feature = "backend-rwkv")]
12use crate::load_rwkv7_model_from_path;
13use crate::{CalibratedSpec, MixtureSpec, RateBackend};
14use std::sync::Arc;
15
16/// Configuration parameters for an AIQI agent.
/// Configuration parameters for an AIQI agent.
///
/// Validated by [`AiqiConfig::validate`] before an [`AiqiAgent`] is built;
/// see that method for the exact constraints on each field.
#[derive(Clone)]
pub struct AiqiConfig {
    /// Predictive backend.
    ///
    /// - `ac-ctw` / `ctw` / `ctw-context-tree`: paper AIQI-CTW path.
    /// - `fac-ctw`: factorized CTW extension.
    /// - `rosa` / `rwkv`: pluggable predictor extensions.
    /// - `zpaq`: intentionally unsupported for AIQI strict conditioning.
    pub algorithm: String,
    /// Context depth for CTW/FAC-CTW backends.
    pub ct_depth: usize,
    /// Number of bits used to encode observations.
    pub observation_bits: usize,
    /// Number of observation symbols per environment step.
    pub observation_stream_len: usize,
    /// Number of bits used to encode rewards.
    pub reward_bits: usize,
    /// Number of valid actions.
    pub agent_actions: usize,
    /// Minimum possible environment reward.
    pub min_reward: Reward,
    /// Maximum possible environment reward.
    pub max_reward: Reward,
    /// Offset applied before encoding reward bits (shifted reward must be >= 0).
    pub reward_offset: Reward,
    /// Discount factor used when constructing H-step returns; must lie in (0, 1).
    pub discount_gamma: f64,
    /// Return horizon `H` (number of future rewards folded into each return).
    pub return_horizon: usize,
    /// Number of discretization bins `M` for returns.
    pub return_bins: usize,
    /// Augmentation period `N` (must satisfy `N >= H`).
    pub augmentation_period: usize,
    /// Optional history retention knob for bounded memory growth.
    ///
    /// - `None`: keep full history (default behavior, no pruning).
    /// - `Some(k)`: keep at least the most recent `k` steps, while also
    ///   preserving all steps still required for exact return construction and
    ///   deferred phase-model advancement.
    pub history_prune_keep_steps: Option<usize>,
    /// Baseline epsilon-greedy exploration probability `tau`; must lie in (0, 1].
    pub baseline_exploration: f64,
    /// Optional deterministic RNG seed for action selection/exploration.
    ///
    /// When `None`, a fresh runtime-derived seed is used.
    pub random_seed: Option<u64>,
    /// Optional generic rate backend.
    ///
    /// When set, this takes precedence over `algorithm` and routes AIQI
    /// prediction through the shared `RateBackend` abstraction.
    pub rate_backend: Option<RateBackend>,
    /// Max-order hint for `rate_backend` constructors that use it (for example ROSA).
    pub rate_backend_max_order: i64,
    /// Optional RWKV model path.
    ///
    /// Required only when selecting `algorithm="rwkv"` and no `rate_backend`
    /// override is configured.
    pub rwkv_model_path: Option<String>,
    /// Optional ROSA max order; falls back to `rate_backend_max_order` when unset.
    pub rosa_max_order: Option<i64>,
    /// Optional ZPAQ method string (unused by AIQI strict mode, which rejects zpaq).
    pub zpaq_method: Option<String>,
}
80
impl AiqiConfig {
    /// Validate configuration constraints.
    ///
    /// Checks, in order: basic cardinalities (`agent_actions`, `return_horizon`,
    /// `return_bins`), the `N >= H` phase constraint, the open/half-open ranges
    /// for `discount_gamma` and `baseline_exploration`, reward-range ordering,
    /// backend selection, and finally that the shifted reward range fits into
    /// `reward_bits`. Returns the first violated constraint as a human-readable
    /// error string.
    pub fn validate(&self) -> Result<(), String> {
        if self.agent_actions == 0 {
            return Err("agent_actions must be >= 1".to_string());
        }
        if self.return_horizon == 0 {
            return Err("return_horizon must be >= 1".to_string());
        }
        if self.return_bins == 0 {
            return Err("return_bins must be >= 1".to_string());
        }
        if self.augmentation_period < self.return_horizon {
            return Err(format!(
                "augmentation_period must be >= return_horizon (got N={}, H={})",
                self.augmentation_period, self.return_horizon
            ));
        }
        if !(0.0 < self.discount_gamma && self.discount_gamma < 1.0) {
            return Err(format!(
                "discount_gamma must be in (0, 1) for paper-correct AIQI, got {}",
                self.discount_gamma
            ));
        }
        if !(0.0 < self.baseline_exploration && self.baseline_exploration <= 1.0) {
            return Err(format!(
                "baseline_exploration (tau) must be in (0, 1] for paper-correct AIQI, got {}",
                self.baseline_exploration
            ));
        }
        if self.max_reward < self.min_reward {
            return Err(format!(
                "max_reward must be >= min_reward (got {} < {})",
                self.max_reward, self.min_reward
            ));
        }

        // `rate_backend` takes precedence over `algorithm`; only validate
        // algorithm choices when no backend override is configured.
        if self.rate_backend.is_none() {
            match self.algorithm.as_str() {
                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
                "zpaq" => {
                    return Err(
                        "AIQI strict mode does not support algorithm=zpaq: zpaq backends do not provide strict frozen conditioning"
                            .to_string(),
                    )
                }
                #[cfg(feature = "backend-rwkv")]
                "rwkv" => {}
                #[cfg(not(feature = "backend-rwkv"))]
                "rwkv" => {
                    return Err("algorithm=rwkv requires backend-rwkv feature".to_string())
                }
                other => return Err(format!("Unknown AIQI algorithm: {other}")),
            }
        }

        // The backend override must still honor strict frozen conditioning
        // (e.g. mixtures containing zpaq experts are rejected recursively).
        if let Some(rate_backend) = &self.rate_backend {
            if !rate_backend_supports_aiqi_frozen_conditioning(rate_backend) {
                return Err(
                    "AIQI strict mode requires frozen context updates; configured rate_backend contains zpaq which does not provide strict frozen conditioning"
                        .to_string(),
                );
            }
        }

        // RWKV selected via `algorithm` (not via a backend override) needs a
        // non-empty model path to load weights from.
        #[cfg(feature = "backend-rwkv")]
        if self.rate_backend.is_none() && self.algorithm == "rwkv" {
            match self.rwkv_model_path.as_deref() {
                Some(path) if !path.trim().is_empty() => {}
                _ => {
                    return Err(
                        "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                            .to_string(),
                    )
                }
            }
        }

        // Reward encoding feasibility: shift into i128 to avoid overflow, then
        // require the shifted range to be non-negative and (for reward_bits < 64)
        // to fit into the configured bit width.
        let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
        let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
        if min_shifted < 0 {
            return Err(format!(
                "reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
                min_shifted
            ));
        }
        if self.reward_bits < 64 {
            let max_enc = (1u128 << self.reward_bits) - 1;
            if (max_shifted as u128) > max_enc {
                return Err(format!(
                    "reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
                    max_shifted, max_enc
                ));
            }
        }

        Ok(())
    }
}
182
/// One recorded environment transition: the agent's action plus the
/// percept (observation symbols and reward) it elicited.
#[derive(Clone, Debug)]
struct StepRecord {
    // Action the agent emitted at this step.
    action: Action,
    // Observation symbols returned by the environment for this step
    // (length equals `observation_stream_len.max(1)`).
    observations: Vec<PerceptVal>,
    // Raw (un-shifted) environment reward for this step.
    reward: Reward,
}
189
/// One phase-indexed predictor: phase `p` only has return tokens inserted
/// at global step indices `i` with `i % N == p`.
struct PhaseModel {
    // Backend predictor carrying this phase's augmented-stream state.
    predictor: Box<dyn Predictor>,
    // Largest step index for which this phase model has consumed
    // the augmented stream up to and including that step's percept.
    last_augmented_step: usize,
}
196
/// AIQI agent with phase-indexed augmented return predictors.
pub struct AiqiAgent {
    // Validated configuration this agent was constructed with.
    config: AiqiConfig,
    // One phase model per residue class modulo `config.augmentation_period`.
    phases: Vec<PhaseModel>,
    // Retained transition history; index 0 corresponds to `history_base_step`.
    steps: Vec<StepRecord>,
    // Discretized H-step return bin per retained step; `None` until the
    // H future rewards needed to compute it have been observed.
    return_bins_by_step: Vec<Option<u64>>,
    // Global 1-based index of steps[0] / return_bins_by_step[0].
    history_base_step: usize,
    // Total number of transitions observed so far (global 1-based max step index).
    total_steps_observed: usize,
    // Bit width used to encode an action symbol.
    action_bits: usize,
    // Bit width used to encode a return-bin symbol.
    return_bits: usize,
    // When true, planning clones predictors instead of push/pop on shared state.
    use_generic_planner: bool,
    // When true (CTW-family), return-distribution probing uses full training
    // updates plus revert; otherwise frozen history updates plus pop.
    distribution_uses_training_updates: bool,
    // RNG driving epsilon-greedy exploration and random action draws.
    rng: RandomGenerator,
}
213
impl AiqiAgent {
    /// Construct a new AIQI agent.
    ///
    /// Validates `config`, derives the bit widths for actions and return
    /// bins, and builds one fresh predictor per augmentation phase.
    /// Returns an error string if validation or predictor construction fails.
    pub fn new(config: AiqiConfig) -> Result<Self, String> {
        config.validate()?;

        let action_bits = bits_for_cardinality(config.agent_actions);
        let return_bits = bits_for_cardinality(config.return_bins);
        let use_generic_planner = aiqi_requires_generic_planner(&config);
        // CTW-family built-in backends support train-then-revert probing;
        // everything else is probed with frozen history updates only.
        let distribution_uses_training_updates = config.rate_backend.is_none()
            && matches!(
                config.algorithm.as_str(),
                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
            );

        // One independent phase model per residue class modulo N.
        let mut phases = Vec::with_capacity(config.augmentation_period);
        for _ in 0..config.augmentation_period {
            phases.push(PhaseModel {
                predictor: build_predictor(&config, return_bits)?,
                last_augmented_step: 0,
            });
        }

        let rng = if let Some(seed) = config.random_seed {
            RandomGenerator::from_seed(seed)
        } else {
            RandomGenerator::new()
        };

        Ok(Self {
            action_bits,
            return_bits,
            use_generic_planner,
            distribution_uses_training_updates,
            config,
            phases,
            steps: Vec::new(),
            return_bins_by_step: Vec::new(),
            // Step indices are global and 1-based; nothing retained yet.
            history_base_step: 1,
            total_steps_observed: 0,
            rng,
        })
    }

    /// Number of transitions incorporated so far.
    pub fn steps_observed(&self) -> usize {
        self.total_steps_observed
    }

    /// Returns the configured number of actions.
    pub fn num_actions(&self) -> usize {
        self.config.agent_actions
    }

    /// Select the next action from the current history.
    ///
    /// Epsilon-greedy: with probability `tau` (baseline exploration) a
    /// uniformly random action is returned, otherwise the Q-greedy one.
    pub fn get_planned_action(&mut self) -> Action {
        let q_values = self.estimate_q_values();
        let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
        if self.config.baseline_exploration > 0.0
            && self
                .rng
                .gen_bool(self.config.baseline_exploration.clamp(0.0, 1.0))
        {
            self.rng.gen_range(self.config.agent_actions) as u64
        } else {
            greedy_action
        }
    }

    /// Select the next action, adding optional extra exploration.
    ///
    /// The extra exploration probability is combined as
    /// `p = 1 - (1 - tau) * (1 - extra)`, where `tau` is the baseline
    /// exploration in [`AiqiConfig`].
    pub fn get_planned_action_with_extra_exploration(&mut self, extra_exploration: f64) -> Action {
        let extra = extra_exploration.clamp(0.0, 1.0);
        let tau = self.config.baseline_exploration.clamp(0.0, 1.0);
        // Probability that at least one of the two independent exploration
        // coins fires.
        let effective = 1.0 - (1.0 - tau) * (1.0 - extra);
        let q_values = self.estimate_q_values();
        let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
        if effective > 0.0 && self.rng.gen_bool(effective) {
            self.rng.gen_range(self.config.agent_actions) as u64
        } else {
            greedy_action
        }
    }

    /// Record one environment transition `(action, observations, reward)`.
    ///
    /// This appends to history and, when enough future rewards are known,
    /// computes and learns one newly available discretized return.
    ///
    /// Errors if the action is out of range, the observation stream length
    /// or values do not match the configuration, or the reward falls outside
    /// the configured/encodable range.
    pub fn observe_transition(
        &mut self,
        action: Action,
        observations: &[PerceptVal],
        reward: Reward,
    ) -> Result<(), String> {
        if action as usize >= self.config.agent_actions {
            return Err(format!(
                "action out of range: action={} but agent_actions={}",
                action, self.config.agent_actions
            ));
        }

        // At least one observation symbol is always expected, even when
        // observation_stream_len is configured as 0.
        let expected_obs = self.config.observation_stream_len.max(1);
        if observations.len() != expected_obs {
            return Err(format!(
                "observation stream length mismatch: expected {}, got {}",
                expected_obs,
                observations.len()
            ));
        }

        if reward < self.config.min_reward || reward > self.config.max_reward {
            return Err(format!(
                "reward out of configured range: reward={} not in [{}, {}]",
                reward, self.config.min_reward, self.config.max_reward
            ));
        }

        let obs_max = max_value_for_bits(self.config.observation_bits);
        for &obs in observations {
            if obs > obs_max {
                return Err(format!(
                    "observation value {} does not fit observation_bits={} (max={})",
                    obs, self.config.observation_bits, obs_max
                ));
            }
        }

        // Shift into i128 to rule out overflow before the range checks.
        let rew_shifted = (reward as i128) + (self.config.reward_offset as i128);
        if rew_shifted < 0 {
            return Err(format!(
                "encoded reward became negative after offset: reward={} offset={}",
                reward, self.config.reward_offset
            ));
        }
        if self.config.reward_bits < 64 {
            let max_enc = (1u128 << self.config.reward_bits) - 1;
            if (rew_shifted as u128) > max_enc {
                return Err(format!(
                    "encoded reward {} exceeds reward_bits={} capacity {}",
                    rew_shifted, self.config.reward_bits, max_enc
                ));
            }
        }

        self.steps.push(StepRecord {
            action,
            observations: observations.to_vec(),
            reward,
        });
        self.total_steps_observed += 1;
        // Return bin for this step becomes computable only after H-1 more
        // rewards arrive.
        self.return_bins_by_step.push(None);

        self.maybe_learn_new_return()?;
        self.maybe_prune_history();
        Ok(())
    }

    /// After observing step `t`, compute (if available) the discretized
    /// return starting at step `i = t - H + 1` and advance the matching
    /// phase model through step `i`.
    fn maybe_learn_new_return(&mut self) -> Result<(), String> {
        let t = self.total_steps_observed;
        let h = self.config.return_horizon;
        if t < h {
            return Ok(());
        }

        // Newly available return index (1-based): i = t - H + 1.
        let i = t + 1 - h;
        let bin = self.compute_return_bin(i);
        let local_idx = self.local_index(i)?;
        self.return_bins_by_step[local_idx] = Some(bin);

        let phase = i % self.config.augmentation_period;
        self.advance_phase_model_to_step(phase, i)
    }

    /// Estimate per-action Q values for the next step by probing the phase
    /// model responsible for the upcoming step index.
    ///
    /// Fast path: pushes the not-yet-committed history as frozen context,
    /// probes each candidate action's return distribution, then pops all
    /// pushed bits so the shared predictor state is restored exactly.
    fn estimate_q_values(&mut self) -> Vec<f64> {
        if self.use_generic_planner {
            return self.estimate_q_values_generic();
        }

        // The step about to happen, and the phase model that would carry
        // its augmented return token.
        let step = self.total_steps_observed + 1;
        let phase = step % self.config.augmentation_period;
        let config = &self.config;
        let steps = &self.steps;
        let return_bins_by_step = &self.return_bins_by_step;
        let history_base_step = self.history_base_step;
        let action_bits = self.action_bits;
        let return_bits = self.return_bits;

        let mut q_values = vec![0.0; self.config.agent_actions];
        let mut pushed_fast_forward = 0usize;

        {
            let model = &mut self.phases[phase];
            // Frozen-push any observed steps this phase model has not yet
            // committed (training-wise) so predictions condition on full history.
            let start = (model.last_augmented_step + 1).max(history_base_step);
            let end = step.saturating_sub(1);
            if start <= end {
                for idx in start..=end {
                    pushed_fast_forward += push_step_tokens_history(
                        config,
                        history_base_step,
                        steps,
                        return_bins_by_step,
                        action_bits,
                        return_bits,
                        model.predictor.as_mut(),
                        phase,
                        idx,
                    );
                }
            }

            for action in 0..self.config.agent_actions {
                // Hypothetically take `action`, read off the predicted
                // return distribution, then undo the action bits.
                let pushed_action = push_encoded_bits_history(
                    model.predictor.as_mut(),
                    action as u64,
                    self.action_bits,
                );
                let dist = Self::predict_return_distribution(
                    self.config.return_bins,
                    self.return_bits,
                    model.predictor.as_mut(),
                    self.distribution_uses_training_updates,
                );
                q_values[action] = expectation_from_distribution(&dist);
                pop_history_bits(model.predictor.as_mut(), pushed_action);
            }

            // Undo the fast-forward context so training state is untouched.
            pop_history_bits(model.predictor.as_mut(), pushed_fast_forward);
        }

        q_values
    }

    /// Clone-based Q estimation for backends that cannot safely pop pushed
    /// history (generic planner path): the phase predictor is cloned, the
    /// context and each candidate action are pushed onto clones only.
    fn estimate_q_values_generic(&mut self) -> Vec<f64> {
        let step = self.total_steps_observed + 1;
        let phase = step % self.config.augmentation_period;

        let model = &self.phases[phase];
        let mut context_predictor = model.predictor.boxed_clone();

        let start = (model.last_augmented_step + 1).max(self.history_base_step);
        let end = step.saturating_sub(1);
        if start <= end {
            for idx in start..=end {
                push_step_tokens_history(
                    &self.config,
                    self.history_base_step,
                    &self.steps,
                    &self.return_bins_by_step,
                    self.action_bits,
                    self.return_bits,
                    context_predictor.as_mut(),
                    phase,
                    idx,
                );
            }
        }

        let mut q_values = vec![0.0; self.config.agent_actions];
        for action in 0..self.config.agent_actions {
            // Fresh clone per action so candidate actions never contaminate
            // one another's context.
            let mut action_predictor = context_predictor.boxed_clone();
            let _ = push_encoded_bits_history(
                action_predictor.as_mut(),
                action as u64,
                self.action_bits,
            );
            let dist = Self::predict_return_distribution_from_base_predictor(
                self.config.return_bins,
                self.return_bits,
                action_predictor.as_ref(),
            );
            q_values[action] = expectation_from_distribution(&dist);
        }

        q_values
    }

    /// Predicted probability for every return bin, computed bit by bit
    /// (LSB first) and renormalized; falls back to uniform if the mass is
    /// degenerate. Predictor state is restored before returning
    /// (revert for training updates, pop for history updates).
    fn predict_return_distribution(
        return_bins: usize,
        return_bits: usize,
        predictor: &mut dyn Predictor,
        use_training_updates: bool,
    ) -> Vec<f64> {
        if return_bins == 1 {
            return vec![1.0];
        }

        let mut probs = vec![0.0; return_bins];
        for (bin, slot) in probs.iter_mut().enumerate() {
            let mut p = 1.0f64;
            let mut v = bin as u64;
            for _ in 0..return_bits {
                let bit = (v & 1) == 1;
                v >>= 1;
                // Clamp to avoid zero/one probabilities poisoning the product.
                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
                p *= q;
                if use_training_updates {
                    predictor.update(bit);
                } else {
                    predictor.update_history(bit);
                }
            }
            // Undo this bin's bits before probing the next one.
            if use_training_updates {
                revert_bits(predictor, return_bits);
            } else {
                pop_history_bits(predictor, return_bits);
            }
            *slot = p;
        }

        let sum: f64 = probs.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let u = 1.0 / (return_bins as f64);
            probs.fill(u);
            return probs;
        }

        for p in &mut probs {
            *p /= sum;
        }
        probs
    }

    /// Clone-based variant of [`Self::predict_return_distribution`]:
    /// each bin is probed on its own predictor clone, so the base
    /// predictor is never mutated.
    fn predict_return_distribution_from_base_predictor(
        return_bins: usize,
        return_bits: usize,
        base_predictor: &dyn Predictor,
    ) -> Vec<f64> {
        if return_bins == 1 {
            return vec![1.0];
        }

        let mut probs = vec![0.0; return_bins];
        for (bin, slot) in probs.iter_mut().enumerate() {
            let mut predictor = base_predictor.boxed_clone();
            let mut p = 1.0f64;
            let mut v = bin as u64;
            for _ in 0..return_bits {
                let bit = (v & 1) == 1;
                v >>= 1;
                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
                p *= q;
                predictor.update_history(bit);
            }
            *slot = p;
        }

        let sum: f64 = probs.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            let u = 1.0 / (return_bins as f64);
            probs.fill(u);
            return probs;
        }

        for p in &mut probs {
            *p /= sum;
        }
        probs
    }

    /// Commit the augmented stream for `phase` up to `target_step`:
    /// actions and percepts as frozen history, and — only at indices where
    /// `idx % N == phase` — the return bin as a *training* token.
    ///
    /// Errors if a required return bin is missing from retained history.
    fn advance_phase_model_to_step(
        &mut self,
        phase: usize,
        target_step: usize,
    ) -> Result<(), String> {
        let config = &self.config;
        let steps = &self.steps;
        let return_bins_by_step = &self.return_bins_by_step;
        let history_base_step = self.history_base_step;
        let action_bits = self.action_bits;
        let return_bits = self.return_bits;
        let model = &mut self.phases[phase];
        if target_step <= model.last_augmented_step {
            return Ok(());
        }

        let start = (model.last_augmented_step + 1).max(history_base_step);
        for idx in start..=target_step {
            push_action_tokens_history(
                history_base_step,
                steps,
                action_bits,
                model.predictor.as_mut(),
                idx,
            );

            if idx % config.augmentation_period == phase {
                let local_idx = idx - history_base_step;
                let bin = return_bins_by_step[local_idx].ok_or_else(|| {
                    format!(
                        "missing return bin for step {} in phase {} while advancing model",
                        idx, phase
                    )
                })?;
                // The return token is the only part trained on (learned).
                push_encoded_bits_train(model.predictor.as_mut(), bin, return_bits);
            }

            push_percept_tokens_history(
                config,
                history_base_step,
                steps,
                model.predictor.as_mut(),
                idx,
            );
        }

        model.last_augmented_step = target_step;
        Ok(())
    }

    /// Discretize the H-step return starting at `start_step` into a bin in
    /// `[0, return_bins)`. Rewards are min-max normalized to [0, 1] first;
    /// a degenerate reward range maps everything to 0.
    fn compute_return_bin(&self, start_step: usize) -> u64 {
        let h = self.config.return_horizon;
        let gamma = self.config.discount_gamma;

        debug_assert!(gamma > 0.0 && gamma < 1.0);
        let reward_range = (self.config.max_reward - self.config.min_reward) as f64;

        // Paper definition: R_{t,H} = (1-gamma) * sum_{k=0}^{H-1} gamma^k r_{t+k}.
        let mut total = 0.0f64;
        let mut gk = 1.0f64;
        for k in 0..h {
            let idx = start_step + k;
            let local_idx = self
                .local_index(idx)
                .expect("return computation requires in-range history");
            let r = self.steps[local_idx].reward;
            let rn = if reward_range <= 0.0 {
                0.0
            } else {
                ((r - self.config.min_reward) as f64 / reward_range).clamp(0.0, 1.0)
            };
            total += gk * rn;
            gk *= gamma;
        }
        let ret = ((1.0 - gamma) * total).clamp(0.0, 1.0);

        // ret == 1.0 would floor into bin M; clamp it into the top bin.
        let mut bin = (ret * (self.config.return_bins as f64)).floor() as u64;
        let max_bin = (self.config.return_bins as u64).saturating_sub(1);
        if bin > max_bin {
            bin = max_bin;
        }
        bin
    }

    /// Translate a global 1-based step index into an index into the
    /// retained `steps`/`return_bins_by_step` vectors, erroring if the
    /// step has been pruned or not yet observed.
    fn local_index(&self, global_step: usize) -> Result<usize, String> {
        if global_step < self.history_base_step || global_step > self.total_steps_observed {
            return Err(format!(
                "global step {} out of retained history range [{}, {}]",
                global_step, self.history_base_step, self.total_steps_observed
            ));
        }
        Ok(global_step - self.history_base_step)
    }

    /// Drop history that is no longer needed, honoring three floors:
    /// the slowest phase model's committed step, the oldest step still
    /// needed for the next return computation, and the configured
    /// `keep_steps` retention window.
    fn maybe_prune_history(&mut self) {
        let Some(keep_steps) = self.config.history_prune_keep_steps else {
            return;
        };
        if self.steps.is_empty() {
            return;
        }

        // No step past the slowest phase model may be dropped: phases still
        // need actions/percepts (and the return bin) to advance through it.
        let min_phase_committed = self
            .phases
            .iter()
            .map(|phase| phase.last_augmented_step)
            .min()
            .unwrap_or(0);

        // For the next return update, we must retain steps from
        // (t+2-H) onward (1-based indexing). Everything before that is no
        // longer needed for exact H-step return construction.
        let next_start_needed = self
            .total_steps_observed
            .saturating_add(2)
            .saturating_sub(self.config.return_horizon);
        let returns_safe_drop_upto = next_start_needed.saturating_sub(1);

        let mut safe_drop_upto = min_phase_committed.min(returns_safe_drop_upto);

        // Optional retention floor: keep at least `keep_steps` most recent
        // transitions in memory for diagnostics/debugging.
        let keep_floor_drop_upto = self.total_steps_observed.saturating_sub(keep_steps);
        safe_drop_upto = safe_drop_upto.min(keep_floor_drop_upto);

        if safe_drop_upto < self.history_base_step {
            return;
        }

        let drain_count = safe_drop_upto - self.history_base_step + 1;
        if drain_count == 0 || drain_count > self.steps.len() {
            return;
        }

        self.steps.drain(0..drain_count);
        self.return_bins_by_step.drain(0..drain_count);
        self.history_base_step += drain_count;
    }
}
715
716fn push_step_tokens_history(
717    config: &AiqiConfig,
718    history_base_step: usize,
719    steps: &[StepRecord],
720    return_bins_by_step: &[Option<u64>],
721    action_bits: usize,
722    return_bits: usize,
723    predictor: &mut dyn Predictor,
724    phase: usize,
725    idx: usize,
726) -> usize {
727    let mut pushed = 0usize;
728    pushed += push_action_tokens_history(history_base_step, steps, action_bits, predictor, idx);
729
730    if idx % config.augmentation_period == phase {
731        let local_idx = idx - history_base_step;
732        if let Some(bin) = return_bins_by_step[local_idx] {
733            pushed += push_encoded_bits_history(predictor, bin, return_bits);
734        }
735    }
736
737    pushed + push_percept_tokens_history(config, history_base_step, steps, predictor, idx)
738}
739
740fn push_action_tokens_history(
741    history_base_step: usize,
742    steps: &[StepRecord],
743    action_bits: usize,
744    predictor: &mut dyn Predictor,
745    idx: usize,
746) -> usize {
747    let action = steps[idx - history_base_step].action;
748    push_encoded_bits_history(predictor, action, action_bits)
749}
750
751fn push_percept_tokens_history(
752    config: &AiqiConfig,
753    history_base_step: usize,
754    steps: &[StepRecord],
755    predictor: &mut dyn Predictor,
756    idx: usize,
757) -> usize {
758    let step = &steps[idx - history_base_step];
759    let mut pushed = 0usize;
760    for &obs in &step.observations {
761        pushed += push_encoded_bits_history(predictor, obs, config.observation_bits);
762    }
763    pushed
764        + push_encoded_reward_history(
765            predictor,
766            step.reward,
767            config.reward_bits,
768            config.reward_offset,
769        )
770}
771
/// Build one phase predictor from the configuration.
///
/// A configured `rate_backend` override takes precedence over `algorithm`
/// and is adapted to single-bit tokens first. Otherwise the predictor is
/// chosen from the `algorithm` string; `zpaq` and unknown names are
/// rejected. `return_bits` is only used by the FAC-CTW factorization.
fn build_predictor(config: &AiqiConfig, return_bits: usize) -> Result<Box<dyn Predictor>, String> {
    if let Some(rate_backend) = config.rate_backend.clone() {
        let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
        let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
        return Ok(Box::new(predictor));
    }

    match config.algorithm.as_str() {
        "ctw" | "ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
        "fac-ctw" => {
            // AIQI-FAC-CTW extension: factorized return-bit modeling.
            Ok(Box::new(FacCtwPredictor::new(config.ct_depth, return_bits)))
        }
        "rosa" => {
            // `rosa_max_order` overrides the generic backend max-order hint.
            let max_order = config
                .rosa_max_order
                .unwrap_or(config.rate_backend_max_order);
            let bit_backend = adapt_rate_backend_for_bit_tokens(RateBackend::RosaPlus);
            let predictor = RateBackendBitPredictor::new(bit_backend, max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(feature = "backend-rwkv")]
        "rwkv" => {
            let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
                "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                    .to_string()
            })?;
            let model_arc = load_rwkv7_model_from_path(path);
            let bit_backend =
                adapt_rate_backend_for_bit_tokens(RateBackend::Rwkv7 { model: model_arc });
            let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(not(feature = "backend-rwkv"))]
        "rwkv" => Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
        "zpaq" => Err(
            "AIQI strict mode does not support algorithm=zpaq; configure a backend with strict frozen conditioning"
                .to_string(),
        ),
        _ => Err(format!("Unknown AIQI algorithm: {}", config.algorithm)),
    }
}
814
/// Rewrite a `RateBackend` spec so it models single-bit tokens, as AIQI
/// feeds the predictor one bit at a time.
///
/// CTW variants are normalized to a 1-bit FAC-CTW; mixtures and calibrated
/// wrappers are rebuilt with every inner backend adapted recursively; any
/// other backend is passed through unchanged.
fn adapt_rate_backend_for_bit_tokens(backend: RateBackend) -> RateBackend {
    match backend {
        RateBackend::Ctw { depth } => RateBackend::FacCtw {
            base_depth: depth,
            num_percept_bits: 1,
            encoding_bits: 1,
        },
        RateBackend::FacCtw { base_depth, .. } => RateBackend::FacCtw {
            base_depth,
            num_percept_bits: 1,
            encoding_bits: 1,
        },
        RateBackend::Mixture { spec } => {
            // Rebuild the mixture with each expert's backend adapted,
            // preserving names, priors, and max orders.
            let experts = spec
                .experts
                .iter()
                .map(|expert| crate::MixtureExpertSpec {
                    name: expert.name.clone(),
                    log_prior: expert.log_prior,
                    max_order: expert.max_order,
                    backend: adapt_rate_backend_for_bit_tokens(expert.backend.clone()),
                })
                .collect();

            // Carry over alpha and the optional decay setting verbatim.
            let mut adapted = MixtureSpec::new(spec.kind, experts).with_alpha(spec.alpha);
            if let Some(decay) = spec.decay {
                adapted = adapted.with_decay(decay);
            }
            RateBackend::Mixture {
                spec: Arc::new(adapted),
            }
        }
        RateBackend::Calibrated { spec } => RateBackend::Calibrated {
            spec: Arc::new(CalibratedSpec {
                base: adapt_rate_backend_for_bit_tokens(spec.base.clone()),
                context: spec.context,
                bins: spec.bins,
                learning_rate: spec.learning_rate,
                bias_clip: spec.bias_clip,
            }),
        },
        // Non-CTW, non-composite backends already consume raw bit tokens.
        other => other,
    }
}
859
860fn rate_backend_supports_aiqi_frozen_conditioning(backend: &RateBackend) -> bool {
861    match backend {
862        RateBackend::Zpaq { .. } => false,
863        RateBackend::Mixture { spec } => spec
864            .experts
865            .iter()
866            .all(|expert| rate_backend_supports_aiqi_frozen_conditioning(&expert.backend)),
867        RateBackend::Calibrated { spec } => {
868            rate_backend_supports_aiqi_frozen_conditioning(&spec.base)
869        }
870        _ => true,
871    }
872}
873
874fn aiqi_requires_generic_planner(config: &AiqiConfig) -> bool {
875    config.rate_backend.is_some()
876        || !matches!(
877            config.algorithm.as_str(),
878            "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
879        )
880}
881
/// Number of bits needed to index `cardinality` distinct values, with a
/// minimum of one bit.
///
/// Cardinalities of 0 and 1 are clamped up so degenerate alphabets still
/// occupy one bit in the encoding. The `leading_zeros` formulation computes
/// the bit length of `n - 1` (i.e. `ceil(log2(n))` for `n >= 2`) directly;
/// unlike the previous shift-based loop (`1usize << bits`), it cannot
/// overflow when `cardinality` exceeds `1 << 63`.
fn bits_for_cardinality(cardinality: usize) -> usize {
    let n = cardinality.max(1);
    // Bit length of n-1: 0 for n == 1, ceil(log2(n)) otherwise.
    let bits = (usize::BITS - (n - 1).leading_zeros()) as usize;
    bits.max(1)
}
890
/// Largest unsigned value representable in `bits` bits.
///
/// Saturates at `u64::MAX` for widths of 64 or more; a zero-width field can
/// only hold zero.
fn max_value_for_bits(bits: usize) -> u64 {
    match bits {
        0 => 0,
        width if width >= 64 => u64::MAX,
        width => (1u64 << width) - 1,
    }
}
900
901fn push_encoded_bits_train(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
902    let mut v = value;
903    for _ in 0..bits {
904        predictor.update((v & 1) == 1);
905        v >>= 1;
906    }
907    bits
908}
909
910fn push_encoded_bits_history(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
911    let mut v = value;
912    for _ in 0..bits {
913        predictor.update_history((v & 1) == 1);
914        v >>= 1;
915    }
916    bits
917}
918
919fn push_encoded_reward_history(
920    predictor: &mut dyn Predictor,
921    reward: Reward,
922    bits: usize,
923    offset: Reward,
924) -> usize {
925    let shifted = (reward as i128) + (offset as i128);
926    let as_u64 = if shifted <= 0 {
927        0
928    } else if shifted > (u64::MAX as i128) {
929        u64::MAX
930    } else {
931        shifted as u64
932    };
933    push_encoded_bits_history(predictor, as_u64, bits)
934}
935
936fn pop_history_bits(predictor: &mut dyn Predictor, bits: usize) {
937    for _ in 0..bits {
938        predictor.pop_history();
939    }
940}
941
942fn revert_bits(predictor: &mut dyn Predictor, bits: usize) {
943    for _ in 0..bits {
944        predictor.revert();
945    }
946}
947
/// Expected value of a discretized return distribution whose bin `i`
/// represents the value `i / M`, where `M = probs.len()`.
///
/// Returns 0.0 for an empty distribution.
fn expectation_from_distribution(probs: &[f64]) -> f64 {
    if probs.is_empty() {
        return 0.0;
    }
    let bin_count = probs.len() as f64;
    // Accumulate left-to-right, matching a plain weighted sum.
    let mut expectation = 0.0;
    for (bin, p) in probs.iter().enumerate() {
        expectation += (bin as f64 / bin_count) * p;
    }
    expectation
}
959
/// Index of the maximum value, breaking ties toward the lowest index (strict
/// `>` comparison keeps the first maximum seen).
///
/// Returns 0 for an empty slice; NaN entries never win a comparison.
fn argmax_with_fixed_tie_break(values: &[f64]) -> usize {
    values
        .iter()
        .enumerate()
        .fold((0usize, f64::NEG_INFINITY), |(best_idx, best_val), (idx, &val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        })
        .0
}
971
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal, valid strict-mode AIQI configuration used as the
    /// baseline for every test; individual tests mutate single fields.
    fn basic_config() -> AiqiConfig {
        AiqiConfig {
            algorithm: "ac-ctw".to_string(),
            ct_depth: 8,
            observation_bits: 1,
            observation_stream_len: 1,
            reward_bits: 1,
            agent_actions: 2,
            min_reward: 0,
            max_reward: 1,
            reward_offset: 0,
            discount_gamma: 0.99,
            return_horizon: 2,
            return_bins: 8,
            augmentation_period: 2,
            history_prune_keep_steps: None,
            baseline_exploration: 0.01,
            random_seed: Some(7),
            rate_backend: None,
            rate_backend_max_order: 20,
            rwkv_model_path: None,
            rosa_max_order: None,
            zpaq_method: None,
        }
    }

    /// Test double that counts calls to each `Predictor` mutation entry point
    /// and returns a fixed probability, so tests can verify which update path
    /// (training vs. history-only) a rollout exercises.
    #[derive(Clone, Default)]
    struct CountingPredictor {
        update_calls: usize,
        update_history_calls: usize,
        revert_calls: usize,
        pop_history_calls: usize,
    }

    impl Predictor for CountingPredictor {
        fn update(&mut self, _sym: bool) {
            self.update_calls += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.update_history_calls += 1;
        }

        fn revert(&mut self) {
            self.revert_calls += 1;
        }

        fn pop_history(&mut self) {
            self.pop_history_calls += 1;
        }

        // Fixed, symbol-dependent probability; any value in (0, 1) works for
        // distribution rollouts.
        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "CountingPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }

    /// `validate()` must reject an augmentation period `N` smaller than the
    /// return horizon `H` (paper-correct phase-indexed augmentation needs N >= H).
    #[test]
    fn config_rejects_invalid_period() {
        let mut cfg = basic_config();
        cfg.augmentation_period = 1;
        cfg.return_horizon = 2;
        let err = cfg
            .validate()
            .expect_err("N < H must be rejected for paper-correct augmentation");
        assert!(err.contains("augmentation_period"));
    }

    /// `algorithm = "zpaq"` is explicitly unsupported in strict AIQI mode.
    #[test]
    fn config_rejects_zpaq_algorithm_in_strict_mode() {
        let mut cfg = basic_config();
        cfg.algorithm = "zpaq".to_string();
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq algorithm mode");
        assert!(err.contains("strict mode"));
    }

    /// A Zpaq rate-backend override must also be rejected, since it cannot
    /// provide strict frozen conditioning.
    #[test]
    fn config_rejects_zpaq_rate_backend_in_strict_mode() {
        let mut cfg = basic_config();
        cfg.rate_backend = Some(RateBackend::Zpaq {
            method: "1".to_string(),
        });
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq rate backend");
        assert!(err.contains("strict frozen conditioning"));
    }

    /// Strict paper AIQI requires gamma < 1 and a strictly positive baseline
    /// exploration temperature tau.
    #[test]
    fn config_rejects_nonpaper_gamma_or_tau() {
        let mut cfg = basic_config();
        cfg.discount_gamma = 1.0;
        let err = cfg
            .validate()
            .expect_err("gamma=1 must be rejected for strict paper AIQI");
        assert!(err.contains("discount_gamma"));

        cfg = basic_config();
        cfg.baseline_exploration = 0.0;
        let err = cfg
            .validate()
            .expect_err("tau=0 must be rejected for strict paper AIQI");
        assert!(err.contains("baseline_exploration"));
    }

    /// Smoke test: after several observed transitions the agent can plan and
    /// the planned action stays within the configured action range.
    #[test]
    fn aiqi_estimates_action_values_after_observations() {
        let mut agent = AiqiAgent::new(basic_config()).expect("valid aiqi config");
        for _ in 0..8 {
            agent
                .observe_transition(1, &[1], 1)
                .expect("transition should be accepted");
        }

        let action = agent.get_planned_action();
        assert!(action <= 1);
    }

    /// FAC-CTW must factorize over the return-bit width (log2 of return_bins),
    /// not over observation/reward bits.
    #[test]
    fn fac_ctw_predictor_uses_return_bit_width() {
        let mut cfg = basic_config();
        cfg.algorithm = "fac-ctw".to_string();
        cfg.return_bins = 8; // return_bits=3

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.contains("k=3"),
            "FAC-CTW should factorize over return bits only, model_name={name}"
        );
    }

    /// `ac-ctw` must map to the single-tree CTW predictor (paper AIQI-CTW path).
    #[test]
    fn ac_ctw_path_uses_single_tree_predictor() {
        let mut cfg = basic_config();
        cfg.algorithm = "ac-ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.starts_with("AC-CTW"),
            "ac-ctw should map to the single-tree CTW predictor, model_name={name}"
        );
    }

    /// The bare `ctw` algorithm string is an alias for `ac-ctw` and must yield
    /// the same predictor type.
    #[test]
    fn ctw_alias_matches_ac_ctw_predictor() {
        let mut cfg = basic_config();
        cfg.algorithm = "ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.starts_with("AC-CTW"),
            "ctw alias should map to paper AIQI-CTW predictor, model_name={name}"
        );
    }

    /// With training updates requested, the return-distribution rollout must
    /// use update/revert pairs and never touch the history-only path.
    /// (4 bins => 2 bits per outcome, 4 outcomes => 8 update/revert calls.)
    #[test]
    fn distribution_rollout_uses_update_and_revert_when_requested() {
        let mut predictor = CountingPredictor::default();
        let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, true);

        assert_eq!(probs.len(), 4);
        assert_eq!(predictor.update_calls, 8);
        assert_eq!(predictor.revert_calls, 8);
        assert_eq!(predictor.update_history_calls, 0);
        assert_eq!(predictor.pop_history_calls, 0);
    }

    /// Without training updates, the rollout must use the history path
    /// (update_history/pop_history) exclusively, with the same call counts.
    #[test]
    fn distribution_rollout_uses_history_path_when_not_requested() {
        let mut predictor = CountingPredictor::default();
        let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, false);

        assert_eq!(probs.len(), 4);
        assert_eq!(predictor.update_calls, 0);
        assert_eq!(predictor.revert_calls, 0);
        assert_eq!(predictor.update_history_calls, 8);
        assert_eq!(predictor.pop_history_calls, 8);
    }

    /// The ac-ctw configuration should set the flag that selects the
    /// update/revert (training) path for distribution rollouts.
    #[test]
    fn ac_ctw_rollout_uses_training_updates() {
        let mut cfg = basic_config();
        cfg.algorithm = "ac-ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        assert!(
            agent.distribution_uses_training_updates,
            "ac-ctw should use update/revert during return distribution rollout"
        );
    }

    /// Verifies that the discretized H-step return matches the paper's
    /// normalized discounted sum for gamma < 1.
    #[test]
    fn return_bin_for_gamma_less_than_one_matches_paper_h_step_return() {
        let mut cfg = basic_config();
        cfg.discount_gamma = 0.5;
        cfg.return_bins = 8;

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        agent
            .observe_transition(0, &[0], 1)
            .expect("first transition stored");
        agent
            .observe_transition(0, &[0], 0)
            .expect("second transition should produce first return");

        let bin = agent.return_bins_by_step[0].expect("first return should be available");
        // Paper target: R_{t,H} = (1-gamma) * sum_{k=0}^{H-1} gamma^k r_{t+k}.
        // For rewards [1, 0], gamma=0.5, H=2 this equals 0.5.
        // With M=8 bins this maps to floor(8 * 0.5) = 4.
        assert_eq!(bin, 4);
    }

    /// With history pruning enabled, the retained step buffer must stay
    /// bounded while global counters (steps observed, base step) keep
    /// advancing, and planning must still work afterwards.
    #[test]
    fn optional_history_pruning_bounds_retained_state_without_losing_progress() {
        let mut cfg = basic_config();
        cfg.return_horizon = 3;
        cfg.augmentation_period = 4;
        cfg.history_prune_keep_steps = Some(8);

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        for i in 0..256usize {
            let action = (i % 2) as u64;
            let obs = [(i % 2) as u64];
            let rew = (i % 2) as i64;
            agent
                .observe_transition(action, &obs, rew)
                .expect("transition should be accepted");
        }

        // Global progress should be preserved even when retained history is bounded.
        assert_eq!(agent.steps_observed(), 256);
        assert!(
            agent.history_base_step > 1,
            "history should have been pruned"
        );
        assert!(
            agent.steps.len() < agent.steps_observed(),
            "retained history should be smaller than total observed"
        );

        let action = agent.get_planned_action();
        assert!(action <= 1);
    }
}