// infotheory/aixi/aiqi.rs
//! AIQI implementation from "A Model-Free Universal AI".
//!
//! This module implements a model-free universal agent that predicts
//! discretized H-step returns directly from augmented interaction history.
//! The implementation follows the phase-indexed periodic augmentation in
//! "A Model-Free Universal AI":
//! for return horizon `H` and period `N >= H`, each phase model only inserts
//! returns at indices `i % N == phase`.

use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
use crate::aixi::model::{CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor};
use crate::aixi::rate_backend::rate_backend_contains_zpaq;
#[cfg(feature = "backend-rwkv")]
use crate::load_rwkv7_model_from_path;
use crate::{RateBackend, validate_rate_backend};

/// Configuration parameters for an AIQI agent.
#[derive(Clone)]
pub struct AiqiConfig {
    /// Predictive backend.
    ///
    /// - `ac-ctw` / `ctw` / `ctw-context-tree`: AIQI-CTW path from
    ///   "A Model-Free Universal AI".
    /// - `fac-ctw`: factorized CTW extension.
    /// - `rosa` / `rwkv`: pluggable predictor extensions.
    /// - `zpaq`: intentionally unsupported for AIQI strict conditioning.
    pub algorithm: String,
    /// Context depth for CTW/FAC-CTW backends.
    pub ct_depth: usize,
    /// Number of bits used to encode observations.
    pub observation_bits: usize,
    /// Number of observation symbols per environment step.
    ///
    /// NOTE(review): `AiqiAgent::observe_transition` treats a value of 0 as 1
    /// (`observation_stream_len.max(1)`); this field is not range-checked in
    /// [`AiqiConfig::validate`].
    pub observation_stream_len: usize,
    /// Number of bits used to encode rewards.
    pub reward_bits: usize,
    /// Number of valid actions.
    pub agent_actions: usize,
    /// Minimum possible environment reward.
    pub min_reward: Reward,
    /// Maximum possible environment reward.
    pub max_reward: Reward,
    /// Offset applied before encoding reward bits.
    pub reward_offset: Reward,
    /// Discount factor used when constructing H-step returns.
    pub discount_gamma: f64,
    /// Return horizon `H`.
    pub return_horizon: usize,
    /// Number of discretization bins `M` for returns.
    ///
    /// This implementation uses exact fixed-width binary encoding of return bins,
    /// so `return_bins` must be a power of two.
    pub return_bins: usize,
    /// Augmentation period `N` (must satisfy `N >= H`).
    pub augmentation_period: usize,
    /// Optional history retention knob for bounded memory growth.
    ///
    /// - `None`: keep full history (default behavior, no pruning).
    /// - `Some(k)`: keep at least the most recent `k` steps, while also
    ///   preserving all steps still required for exact return construction and
    ///   deferred phase-model advancement.
    pub history_prune_keep_steps: Option<usize>,
    /// Baseline epsilon-greedy exploration probability `tau`.
    pub baseline_exploration: f64,
    /// Optional deterministic RNG seed for action selection/exploration.
    ///
    /// When `None`, a fresh runtime-derived seed is used.
    pub random_seed: Option<u64>,
    /// Optional generic rate backend.
    ///
    /// When set, this takes precedence over `algorithm` and routes AIQI
    /// prediction through the shared `RateBackend` abstraction.
    pub rate_backend: Option<RateBackend>,
    /// Max-order hint for `rate_backend` constructors that use it (for example ROSA).
    pub rate_backend_max_order: i64,
    /// Optional RWKV model path.
    ///
    /// Required only when selecting `algorithm="rwkv"` and no `rate_backend`
    /// override is configured.
    pub rwkv_model_path: Option<String>,
    /// Optional ROSA max order.
    pub rosa_max_order: Option<i64>,
    /// Optional ZPAQ method string.
    pub zpaq_method: Option<String>,
}

impl AiqiConfig {
    /// Validate configuration constraints.
    ///
    /// Checks, in order: non-empty action set, `H >= 1`, power-of-two return
    /// bins, the paper's `N >= H` requirement, `gamma` in (0, 1), `tau` in
    /// (0, 1], a non-inverted reward range, backend/algorithm compatibility,
    /// and finally that the offset-shifted reward range fits in `reward_bits`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable description of the first violated constraint.
    ///
    /// NOTE(review): `observation_bits` and `observation_stream_len` are not
    /// validated here; observation values are instead rejected per transition
    /// in `AiqiAgent::observe_transition` — confirm this split is intended.
    pub fn validate(&self) -> Result<(), String> {
        if self.agent_actions == 0 {
            return Err("agent_actions must be >= 1".to_string());
        }
        if self.return_horizon == 0 {
            return Err("return_horizon must be >= 1".to_string());
        }
        if self.return_bins == 0 {
            return Err("return_bins must be >= 1".to_string());
        }
        // Return bins are encoded as fixed-width binary strings, so M = 2^k.
        if !self.return_bins.is_power_of_two() {
            return Err(format!(
                "return_bins must be a power of two for exact binary return encoding, got {}",
                self.return_bins
            ));
        }
        // Phase-indexed augmentation requires each phase to see at most one
        // in-flight return, hence N >= H.
        if self.augmentation_period < self.return_horizon {
            return Err(format!(
                "augmentation_period must be >= return_horizon (got N={}, H={})",
                self.augmentation_period, self.return_horizon
            ));
        }
        if !(0.0 < self.discount_gamma && self.discount_gamma < 1.0) {
            return Err(format!(
                "discount_gamma must be in (0, 1) for AIQI as defined in \"A Model-Free Universal AI\", got {}",
                self.discount_gamma
            ));
        }
        if !(0.0 < self.baseline_exploration && self.baseline_exploration <= 1.0) {
            return Err(format!(
                "baseline_exploration (tau) must be in (0, 1] for AIQI as defined in \"A Model-Free Universal AI\", got {}",
                self.baseline_exploration
            ));
        }
        if self.max_reward < self.min_reward {
            return Err(format!(
                "max_reward must be >= min_reward (got {} < {})",
                self.max_reward, self.min_reward
            ));
        }

        // `rate_backend` takes precedence over `algorithm`; only validate
        // algorithm choices when no backend override is configured.
        if self.rate_backend.is_none() {
            match self.algorithm.as_str() {
                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
                "zpaq" => {
                    return Err(
                        "AIQI strict mode does not support algorithm=zpaq: zpaq backends do not provide strict frozen conditioning"
                            .to_string(),
                    )
                }
                // The rwkv arm only exists when the backend is compiled in;
                // otherwise selecting it is a configuration error.
                #[cfg(feature = "backend-rwkv")]
                "rwkv" => {}
                #[cfg(not(feature = "backend-rwkv"))]
                "rwkv" => {
                    return Err("algorithm=rwkv requires backend-rwkv feature".to_string())
                }
                other => return Err(format!("Unknown AIQI algorithm: {other}")),
            }
        }

        if let Some(rate_backend) = &self.rate_backend {
            validate_rate_backend(rate_backend)
                .map_err(|err| format!("invalid rate_backend: {err}"))?;
            // Strict AIQI conditioning needs frozen-context prediction, which
            // zpaq-containing backends cannot provide.
            if !rate_backend_supports_aiqi_frozen_conditioning(rate_backend) {
                return Err(
                    "AIQI strict mode requires frozen context updates; configured rate_backend contains zpaq which does not provide strict frozen conditioning"
                        .to_string(),
                );
            }
        }

        // A model path is mandatory for the direct rwkv algorithm path
        // (a rate_backend override supplies its own model instead).
        #[cfg(feature = "backend-rwkv")]
        if self.rate_backend.is_none() && self.algorithm == "rwkv" {
            match self.rwkv_model_path.as_deref() {
                Some(path) if !path.trim().is_empty() => {}
                _ => {
                    return Err(
                        "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                            .to_string(),
                    )
                }
            }
        }

        // Check the shifted reward range fits the configured bit width.
        // Widening to i128 avoids overflow of the intermediate sums.
        let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
        let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
        if min_shifted < 0 {
            return Err(format!(
                "reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
                min_shifted
            ));
        }
        // reward_bits >= 64 always has enough capacity; guarding also keeps
        // the `1u128 << reward_bits` shift well-defined.
        if self.reward_bits < 64 {
            let max_enc = (1u128 << self.reward_bits) - 1;
            if (max_shifted as u128) > max_enc {
                return Err(format!(
                    "reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
                    max_shifted, max_enc
                ));
            }
        }

        Ok(())
    }
}

/// One recorded interaction step: the agent's action followed by the
/// environment's percept (observation symbols plus reward).
#[derive(Clone, Debug)]
struct StepRecord {
    /// Action the agent took at this step.
    action: Action,
    /// Observation symbols returned by the environment for this step.
    observations: Vec<PerceptVal>,
    /// Reward returned by the environment for this step.
    reward: Reward,
}

/// A single phase's predictor over the augmented interaction stream.
struct PhaseModel {
    /// Sequence predictor fed this phase's augmented token stream.
    predictor: Box<dyn Predictor>,
    // Largest step index for which this phase model has consumed
    // the augmented stream up to and including that step's percept.
    last_augmented_step: usize,
}

/// AIQI agent with phase-indexed augmented return predictors.
pub struct AiqiAgent {
    /// Validated configuration (checked in [`AiqiAgent::new`]).
    config: AiqiConfig,
    /// One predictor per phase `0..N`; phase `i % N` owns returns at step `i`.
    phases: Vec<PhaseModel>,
    /// Retained transition history (possibly pruned from the front).
    steps: Vec<StepRecord>,
    /// Discretized return bin per retained step, `None` until computed.
    return_bins_by_step: Vec<Option<u64>>,
    // Global 1-based index of steps[0] / return_bins_by_step[0].
    history_base_step: usize,
    // Total number of transitions observed so far (global 1-based max step index).
    total_steps_observed: usize,
    /// Bits used to encode one action symbol.
    action_bits: usize,
    /// Bits used to encode one return bin.
    return_bits: usize,
    /// Whether planning must take the clone-based generic path.
    use_generic_planner: bool,
    /// Whether return-bit prediction uses training updates (CTW family).
    distribution_uses_training_updates: bool,
    /// RNG used for epsilon-greedy action selection.
    rng: RandomGenerator,
}

227impl AiqiAgent {
228    /// Construct a new AIQI agent.
229    pub fn new(config: AiqiConfig) -> Result<Self, String> {
230        config.validate()?;
231
232        let action_bits = bits_for_cardinality(config.agent_actions);
233        let return_bits = bits_for_cardinality(config.return_bins);
234        let use_generic_planner = aiqi_requires_generic_planner(&config);
235        let distribution_uses_training_updates = config.rate_backend.is_none()
236            && matches!(
237                config.algorithm.as_str(),
238                "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
239            );
240
241        let mut phases = Vec::with_capacity(config.augmentation_period);
242        for _ in 0..config.augmentation_period {
243            phases.push(PhaseModel {
244                predictor: build_predictor(&config, return_bits)?,
245                last_augmented_step: 0,
246            });
247        }
248
249        let rng = if let Some(seed) = config.random_seed {
250            RandomGenerator::from_seed(seed)
251        } else {
252            RandomGenerator::new()
253        };
254
255        Ok(Self {
256            action_bits,
257            return_bits,
258            use_generic_planner,
259            distribution_uses_training_updates,
260            config,
261            phases,
262            steps: Vec::new(),
263            return_bins_by_step: Vec::new(),
264            history_base_step: 1,
265            total_steps_observed: 0,
266            rng,
267        })
268    }
269
270    /// Number of transitions incorporated so far.
271    pub fn steps_observed(&self) -> usize {
272        self.total_steps_observed
273    }
274
275    /// Returns the configured number of actions.
276    pub fn num_actions(&self) -> usize {
277        self.config.agent_actions
278    }
279
280    /// Select the next action from the current history.
281    pub fn get_planned_action(&mut self) -> Action {
282        let q_values = self.estimate_q_values();
283        let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
284        if self.config.baseline_exploration > 0.0
285            && self
286                .rng
287                .gen_bool(self.config.baseline_exploration.clamp(0.0, 1.0))
288        {
289            self.rng.gen_range(self.config.agent_actions) as u64
290        } else {
291            greedy_action
292        }
293    }
294
295    /// Select the next action, adding optional extra exploration.
296    ///
297    /// The extra exploration probability is combined as
298    /// `p = 1 - (1 - tau) * (1 - extra)`, where `tau` is the baseline
299    /// exploration in [`AiqiConfig`].
300    pub fn get_planned_action_with_extra_exploration(&mut self, extra_exploration: f64) -> Action {
301        let extra = extra_exploration.clamp(0.0, 1.0);
302        let tau = self.config.baseline_exploration.clamp(0.0, 1.0);
303        let effective = 1.0 - (1.0 - tau) * (1.0 - extra);
304        let q_values = self.estimate_q_values();
305        let greedy_action = argmax_with_fixed_tie_break(&q_values) as u64;
306        if effective > 0.0 && self.rng.gen_bool(effective) {
307            self.rng.gen_range(self.config.agent_actions) as u64
308        } else {
309            greedy_action
310        }
311    }
312
313    /// Record one environment transition `(action, observations, reward)`.
314    ///
315    /// This appends to history and, when enough future rewards are known,
316    /// computes and learns one newly available discretized return.
317    pub fn observe_transition(
318        &mut self,
319        action: Action,
320        observations: &[PerceptVal],
321        reward: Reward,
322    ) -> Result<(), String> {
323        if action as usize >= self.config.agent_actions {
324            return Err(format!(
325                "action out of range: action={} but agent_actions={}",
326                action, self.config.agent_actions
327            ));
328        }
329
330        let expected_obs = self.config.observation_stream_len.max(1);
331        if observations.len() != expected_obs {
332            return Err(format!(
333                "observation stream length mismatch: expected {}, got {}",
334                expected_obs,
335                observations.len()
336            ));
337        }
338
339        if reward < self.config.min_reward || reward > self.config.max_reward {
340            return Err(format!(
341                "reward out of configured range: reward={} not in [{}, {}]",
342                reward, self.config.min_reward, self.config.max_reward
343            ));
344        }
345
346        let obs_max = max_value_for_bits(self.config.observation_bits);
347        for &obs in observations {
348            if obs > obs_max {
349                return Err(format!(
350                    "observation value {} does not fit observation_bits={} (max={})",
351                    obs, self.config.observation_bits, obs_max
352                ));
353            }
354        }
355
356        let rew_shifted = (reward as i128) + (self.config.reward_offset as i128);
357        if rew_shifted < 0 {
358            return Err(format!(
359                "encoded reward became negative after offset: reward={} offset={}",
360                reward, self.config.reward_offset
361            ));
362        }
363        if self.config.reward_bits < 64 {
364            let max_enc = (1u128 << self.config.reward_bits) - 1;
365            if (rew_shifted as u128) > max_enc {
366                return Err(format!(
367                    "encoded reward {} exceeds reward_bits={} capacity {}",
368                    rew_shifted, self.config.reward_bits, max_enc
369                ));
370            }
371        }
372
373        self.steps.push(StepRecord {
374            action,
375            observations: observations.to_vec(),
376            reward,
377        });
378        self.total_steps_observed += 1;
379        self.return_bins_by_step.push(None);
380
381        self.maybe_learn_new_return()?;
382        self.maybe_prune_history();
383        Ok(())
384    }
385
386    fn maybe_learn_new_return(&mut self) -> Result<(), String> {
387        let t = self.total_steps_observed;
388        let h = self.config.return_horizon;
389        if t < h {
390            return Ok(());
391        }
392
393        // Newly available return index (1-based): i = t - H + 1.
394        let i = t + 1 - h;
395        let bin = self.compute_return_bin(i);
396        let local_idx = self.local_index(i)?;
397        self.return_bins_by_step[local_idx] = Some(bin);
398
399        let phase = i % self.config.augmentation_period;
400        self.advance_phase_model_to_step(phase, i)
401    }
402
403    fn estimate_q_values(&mut self) -> Vec<f64> {
404        if self.use_generic_planner {
405            return self.estimate_q_values_generic();
406        }
407
408        let step = self.total_steps_observed + 1;
409        let phase = step % self.config.augmentation_period;
410        let config = &self.config;
411        let steps = &self.steps;
412        let return_bins_by_step = &self.return_bins_by_step;
413        let history_base_step = self.history_base_step;
414        let action_bits = self.action_bits;
415        let return_bits = self.return_bits;
416
417        let mut q_values = vec![0.0; self.config.agent_actions];
418        let mut pushed_fast_forward = 0usize;
419
420        {
421            let model = &mut self.phases[phase];
422            let start = (model.last_augmented_step + 1).max(history_base_step);
423            let end = step.saturating_sub(1);
424            if start <= end {
425                for idx in start..=end {
426                    pushed_fast_forward += push_step_tokens_history(
427                        config,
428                        history_base_step,
429                        steps,
430                        return_bins_by_step,
431                        action_bits,
432                        return_bits,
433                        model.predictor.as_mut(),
434                        phase,
435                        idx,
436                    );
437                }
438            }
439
440            for action in 0..self.config.agent_actions {
441                let pushed_action = push_encoded_bits_history(
442                    model.predictor.as_mut(),
443                    action as u64,
444                    self.action_bits,
445                );
446                let dist = Self::predict_return_distribution(
447                    self.config.return_bins,
448                    self.return_bits,
449                    model.predictor.as_mut(),
450                    self.distribution_uses_training_updates,
451                );
452                q_values[action] = expectation_from_distribution(&dist);
453                pop_history_bits(model.predictor.as_mut(), pushed_action);
454            }
455
456            pop_history_bits(model.predictor.as_mut(), pushed_fast_forward);
457        }
458
459        q_values
460    }
461
462    fn estimate_q_values_generic(&mut self) -> Vec<f64> {
463        let step = self.total_steps_observed + 1;
464        let phase = step % self.config.augmentation_period;
465
466        let model = &self.phases[phase];
467        let mut context_predictor = model.predictor.boxed_clone();
468
469        let start = (model.last_augmented_step + 1).max(self.history_base_step);
470        let end = step.saturating_sub(1);
471        if start <= end {
472            for idx in start..=end {
473                push_augmented_step_tokens_commit(
474                    &self.config,
475                    self.history_base_step,
476                    &self.steps,
477                    &self.return_bins_by_step,
478                    self.action_bits,
479                    self.return_bits,
480                    context_predictor.as_mut(),
481                    phase,
482                    idx,
483                )
484                .expect("generic planner retained history must contain required augmented return");
485            }
486        }
487
488        let mut q_values = vec![0.0; self.config.agent_actions];
489        for action in 0..self.config.agent_actions {
490            let mut action_predictor = context_predictor.boxed_clone();
491            let _ = push_encoded_bits_commit_history(
492                action_predictor.as_mut(),
493                action as u64,
494                self.action_bits,
495            );
496            let dist = Self::predict_return_distribution_from_base_predictor(
497                self.config.return_bins,
498                self.return_bits,
499                action_predictor.as_ref(),
500            );
501            q_values[action] = expectation_from_distribution(&dist);
502        }
503
504        q_values
505    }
506
507    fn predict_return_distribution(
508        return_bins: usize,
509        return_bits: usize,
510        predictor: &mut dyn Predictor,
511        use_training_updates: bool,
512    ) -> Vec<f64> {
513        debug_assert!(return_bins.is_power_of_two());
514        if return_bins == 1 {
515            return vec![1.0];
516        }
517
518        let mut probs = vec![0.0; return_bins];
519        for (bin, slot) in probs.iter_mut().enumerate() {
520            let mut p = 1.0f64;
521            let mut v = bin as u64;
522            for _ in 0..return_bits {
523                let bit = (v & 1) == 1;
524                v >>= 1;
525                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
526                p *= q;
527                if use_training_updates {
528                    predictor.update(bit);
529                } else {
530                    predictor.update_history(bit);
531                }
532            }
533            if use_training_updates {
534                revert_bits(predictor, return_bits);
535            } else {
536                pop_history_bits(predictor, return_bits);
537            }
538            *slot = p;
539        }
540
541        let sum: f64 = probs.iter().sum();
542        if !sum.is_finite() || sum <= 0.0 {
543            let u = 1.0 / (return_bins as f64);
544            probs.fill(u);
545            return probs;
546        }
547
548        for p in &mut probs {
549            *p /= sum;
550        }
551        probs
552    }
553
554    fn predict_return_distribution_from_base_predictor(
555        return_bins: usize,
556        return_bits: usize,
557        base_predictor: &dyn Predictor,
558    ) -> Vec<f64> {
559        debug_assert!(return_bins.is_power_of_two());
560        if return_bins == 1 {
561            return vec![1.0];
562        }
563
564        let mut probs = vec![0.0; return_bins];
565        for (bin, slot) in probs.iter_mut().enumerate() {
566            let mut predictor = base_predictor.boxed_clone();
567            let mut p = 1.0f64;
568            let mut v = bin as u64;
569            for _ in 0..return_bits {
570                let bit = (v & 1) == 1;
571                v >>= 1;
572                let q = predictor.predict_prob(bit).clamp(1e-12, 1.0 - 1e-12);
573                p *= q;
574                predictor.commit_update(bit);
575            }
576            *slot = p;
577        }
578
579        let sum: f64 = probs.iter().sum();
580        if !sum.is_finite() || sum <= 0.0 {
581            let u = 1.0 / (return_bins as f64);
582            probs.fill(u);
583            return probs;
584        }
585
586        for p in &mut probs {
587            *p /= sum;
588        }
589        probs
590    }
591
592    fn advance_phase_model_to_step(
593        &mut self,
594        phase: usize,
595        target_step: usize,
596    ) -> Result<(), String> {
597        let config = &self.config;
598        let steps = &self.steps;
599        let return_bins_by_step = &self.return_bins_by_step;
600        let history_base_step = self.history_base_step;
601        let action_bits = self.action_bits;
602        let return_bits = self.return_bits;
603        let model = &mut self.phases[phase];
604        if target_step <= model.last_augmented_step {
605            return Ok(());
606        }
607
608        let start = (model.last_augmented_step + 1).max(history_base_step);
609        for idx in start..=target_step {
610            push_augmented_step_tokens_commit(
611                config,
612                history_base_step,
613                steps,
614                return_bins_by_step,
615                action_bits,
616                return_bits,
617                model.predictor.as_mut(),
618                phase,
619                idx,
620            )?;
621        }
622
623        model.last_augmented_step = target_step;
624        Ok(())
625    }
626
627    fn compute_return_bin(&self, start_step: usize) -> u64 {
628        let h = self.config.return_horizon;
629        let gamma = self.config.discount_gamma;
630
631        debug_assert!(gamma > 0.0 && gamma < 1.0);
632        let reward_range = (self.config.max_reward - self.config.min_reward) as f64;
633
634        // Paper definition: R_{t,H} = (1-gamma) * sum_{k=0}^{H-1} gamma^k r_{t+k}.
635        let mut total = 0.0f64;
636        let mut gk = 1.0f64;
637        for k in 0..h {
638            let idx = start_step + k;
639            let local_idx = self
640                .local_index(idx)
641                .expect("return computation requires in-range history");
642            let r = self.steps[local_idx].reward;
643            let rn = if reward_range <= 0.0 {
644                0.0
645            } else {
646                ((r - self.config.min_reward) as f64 / reward_range).clamp(0.0, 1.0)
647            };
648            total += gk * rn;
649            gk *= gamma;
650        }
651        let ret = ((1.0 - gamma) * total).clamp(0.0, 1.0);
652
653        let mut bin = (ret * (self.config.return_bins as f64)).floor() as u64;
654        let max_bin = (self.config.return_bins as u64).saturating_sub(1);
655        if bin > max_bin {
656            bin = max_bin;
657        }
658        bin
659    }
660
661    fn local_index(&self, global_step: usize) -> Result<usize, String> {
662        if global_step < self.history_base_step || global_step > self.total_steps_observed {
663            return Err(format!(
664                "global step {} out of retained history range [{}, {}]",
665                global_step, self.history_base_step, self.total_steps_observed
666            ));
667        }
668        Ok(global_step - self.history_base_step)
669    }
670
671    fn maybe_prune_history(&mut self) {
672        let Some(keep_steps) = self.config.history_prune_keep_steps else {
673            return;
674        };
675        if self.steps.is_empty() {
676            return;
677        }
678
679        let min_phase_committed = self
680            .phases
681            .iter()
682            .map(|phase| phase.last_augmented_step)
683            .min()
684            .unwrap_or(0);
685
686        // For the next return update, we must retain steps from
687        // (t+2-H) onward (1-based indexing). Everything before that is no
688        // longer needed for exact H-step return construction.
689        let next_start_needed = self
690            .total_steps_observed
691            .saturating_add(2)
692            .saturating_sub(self.config.return_horizon);
693        let returns_safe_drop_upto = next_start_needed.saturating_sub(1);
694
695        let mut safe_drop_upto = min_phase_committed.min(returns_safe_drop_upto);
696
697        // Optional retention floor: keep at least `keep_steps` most recent
698        // transitions in memory for diagnostics/debugging.
699        let keep_floor_drop_upto = self.total_steps_observed.saturating_sub(keep_steps);
700        safe_drop_upto = safe_drop_upto.min(keep_floor_drop_upto);
701
702        if safe_drop_upto < self.history_base_step {
703            return;
704        }
705
706        let drain_count = safe_drop_upto - self.history_base_step + 1;
707        if drain_count == 0 || drain_count > self.steps.len() {
708            return;
709        }
710
711        self.steps.drain(0..drain_count);
712        self.return_bins_by_step.drain(0..drain_count);
713        self.history_base_step += drain_count;
714    }
715}
716
717fn push_step_tokens_history(
718    config: &AiqiConfig,
719    history_base_step: usize,
720    steps: &[StepRecord],
721    return_bins_by_step: &[Option<u64>],
722    action_bits: usize,
723    return_bits: usize,
724    predictor: &mut dyn Predictor,
725    phase: usize,
726    idx: usize,
727) -> usize {
728    let mut pushed = 0usize;
729    pushed += push_action_tokens_history(history_base_step, steps, action_bits, predictor, idx);
730
731    if idx % config.augmentation_period == phase {
732        let local_idx = idx - history_base_step;
733        if let Some(bin) = return_bins_by_step[local_idx] {
734            pushed += push_encoded_bits_history(predictor, bin, return_bits);
735        }
736    }
737
738    pushed + push_percept_tokens_history(config, history_base_step, steps, predictor, idx)
739}
740
741fn push_augmented_step_tokens_commit(
742    config: &AiqiConfig,
743    history_base_step: usize,
744    steps: &[StepRecord],
745    return_bins_by_step: &[Option<u64>],
746    action_bits: usize,
747    return_bits: usize,
748    predictor: &mut dyn Predictor,
749    phase: usize,
750    idx: usize,
751) -> Result<usize, String> {
752    let mut pushed = 0usize;
753    pushed +=
754        push_action_tokens_commit_history(history_base_step, steps, action_bits, predictor, idx);
755
756    if idx % config.augmentation_period == phase {
757        let local_idx = idx - history_base_step;
758        let bin = return_bins_by_step[local_idx].ok_or_else(|| {
759            format!(
760                "missing return bin for step {} in phase {} while pushing augmented history",
761                idx, phase
762            )
763        })?;
764        pushed += push_encoded_bits_commit(predictor, bin, return_bits);
765    }
766
767    Ok(pushed
768        + push_percept_tokens_commit_history(config, history_base_step, steps, predictor, idx))
769}
770
771fn push_action_tokens_history(
772    history_base_step: usize,
773    steps: &[StepRecord],
774    action_bits: usize,
775    predictor: &mut dyn Predictor,
776    idx: usize,
777) -> usize {
778    let action = steps[idx - history_base_step].action;
779    push_encoded_bits_history(predictor, action, action_bits)
780}
781
782fn push_action_tokens_commit_history(
783    history_base_step: usize,
784    steps: &[StepRecord],
785    action_bits: usize,
786    predictor: &mut dyn Predictor,
787    idx: usize,
788) -> usize {
789    let action = steps[idx - history_base_step].action;
790    push_encoded_bits_commit_history(predictor, action, action_bits)
791}
792
793fn push_percept_tokens_history(
794    config: &AiqiConfig,
795    history_base_step: usize,
796    steps: &[StepRecord],
797    predictor: &mut dyn Predictor,
798    idx: usize,
799) -> usize {
800    let step = &steps[idx - history_base_step];
801    let mut pushed = 0usize;
802    for &obs in &step.observations {
803        pushed += push_encoded_bits_history(predictor, obs, config.observation_bits);
804    }
805    pushed
806        + push_encoded_reward_history(
807            predictor,
808            step.reward,
809            config.reward_bits,
810            config.reward_offset,
811        )
812}
813
814fn push_percept_tokens_commit_history(
815    config: &AiqiConfig,
816    history_base_step: usize,
817    steps: &[StepRecord],
818    predictor: &mut dyn Predictor,
819    idx: usize,
820) -> usize {
821    let step = &steps[idx - history_base_step];
822    let mut pushed = 0usize;
823    for &obs in &step.observations {
824        pushed += push_encoded_bits_commit_history(predictor, obs, config.observation_bits);
825    }
826    pushed
827        + push_encoded_reward_commit_history(
828            predictor,
829            step.reward,
830            config.reward_bits,
831            config.reward_offset,
832        )
833}
834
/// Constructs the predictive model for one phase.
///
/// An explicit `rate_backend` override always wins; otherwise the backend is
/// selected by the `algorithm` string. `return_bits` is consumed only by the
/// FAC-CTW path, which factorizes the model over the discretized-return bit
/// width.
///
/// Returns an error for unknown algorithm names, for `zpaq` (which strict
/// AIQI rejects), and for backend construction failures.
fn build_predictor(config: &AiqiConfig, return_bits: usize) -> Result<Box<dyn Predictor>, String> {
    // A configured rate backend takes precedence over `algorithm`.
    if let Some(rate_backend) = config.rate_backend.clone() {
        let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
        let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
        return Ok(Box::new(predictor));
    }

    match config.algorithm.as_str() {
        "ctw" | "ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
        "fac-ctw" => {
            // AIQI-FAC-CTW extension: factorized return-bit modeling.
            Ok(Box::new(FacCtwPredictor::new(config.ct_depth, return_bits)))
        }
        "rosa" => {
            // `rosa_max_order` optionally overrides the shared max order.
            let max_order = config
                .rosa_max_order
                .unwrap_or(config.rate_backend_max_order);
            let bit_backend = adapt_rate_backend_for_bit_tokens(RateBackend::RosaPlus);
            let predictor = RateBackendBitPredictor::new(bit_backend, max_order)?;
            Ok(Box::new(predictor))
        }
        #[cfg(feature = "backend-rwkv")]
        "rwkv" => {
            let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
                "algorithm=rwkv requires rwkv_model_path when no rate_backend override is configured; for method-string RWKV configure rate_backend rwkv/rwkv7"
                    .to_string()
            })?;
            let model_arc = load_rwkv7_model_from_path(path);
            let bit_backend =
                adapt_rate_backend_for_bit_tokens(RateBackend::Rwkv7 { model: model_arc });
            let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
            Ok(Box::new(predictor))
        }
        // Keep the name recognized without the feature so the error is
        // actionable rather than "unknown algorithm".
        #[cfg(not(feature = "backend-rwkv"))]
        "rwkv" => Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
        "zpaq" => Err(
            "AIQI strict mode does not support algorithm=zpaq; configure a backend with strict frozen conditioning"
                .to_string(),
        ),
        _ => Err(format!("Unknown AIQI algorithm: {}", config.algorithm)),
    }
}
877
/// Local forwarding shim so call sites in this module stay short; delegates
/// to the shared adapter in `crate::aixi::rate_backend`.
fn adapt_rate_backend_for_bit_tokens(backend: RateBackend) -> RateBackend {
    crate::aixi::rate_backend::adapt_rate_backend_for_bit_tokens(backend)
}
881
/// Strict AIQI requires conditioning tokens to stay frozen; any backend chain
/// containing a zpaq stage cannot guarantee that, so it is rejected.
fn rate_backend_supports_aiqi_frozen_conditioning(backend: &RateBackend) -> bool {
    !rate_backend_contains_zpaq(backend)
}
885
886fn aiqi_requires_generic_planner(config: &AiqiConfig) -> bool {
887    config.rate_backend.is_some()
888        || !matches!(
889            config.algorithm.as_str(),
890            "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree"
891        )
892}
893
/// Number of bits needed to address `cardinality` distinct values, i.e.
/// `ceil(log2(cardinality))`, with a floor of 1 bit so degenerate
/// cardinalities (0 or 1) still occupy one symbol slot.
///
/// Implemented via `leading_zeros` so the computation is O(1) and cannot
/// overflow: the previous shift-loop form evaluated `1usize << BITS` for
/// cardinalities above `2^(BITS-1)`, which panics in debug builds.
fn bits_for_cardinality(cardinality: usize) -> usize {
    // Clamping to 2 makes both the 0 and 1 inputs yield the 1-bit floor.
    let n = cardinality.max(2);
    (usize::BITS - (n - 1).leading_zeros()) as usize
}
902
/// Largest unsigned value representable in `bits` bits; yields 0 for a zero
/// width and saturates at `u64::MAX` for widths of 64 or more.
fn max_value_for_bits(bits: usize) -> u64 {
    match bits {
        0 => 0,
        64.. => u64::MAX,
        width => (1u64 << width) - 1,
    }
}
912
913fn push_encoded_bits_commit(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
914    let mut v = value;
915    for _ in 0..bits {
916        predictor.commit_update((v & 1) == 1);
917        v >>= 1;
918    }
919    bits
920}
921
922fn push_encoded_bits_history(predictor: &mut dyn Predictor, value: u64, bits: usize) -> usize {
923    let mut v = value;
924    for _ in 0..bits {
925        predictor.update_history((v & 1) == 1);
926        v >>= 1;
927    }
928    bits
929}
930
931fn push_encoded_bits_commit_history(
932    predictor: &mut dyn Predictor,
933    value: u64,
934    bits: usize,
935) -> usize {
936    let mut v = value;
937    for _ in 0..bits {
938        predictor.commit_update_history((v & 1) == 1);
939        v >>= 1;
940    }
941    bits
942}
943
944fn push_encoded_reward_history(
945    predictor: &mut dyn Predictor,
946    reward: Reward,
947    bits: usize,
948    offset: Reward,
949) -> usize {
950    let shifted = (reward as i128) + (offset as i128);
951    let as_u64 = if shifted <= 0 {
952        0
953    } else if shifted > (u64::MAX as i128) {
954        u64::MAX
955    } else {
956        shifted as u64
957    };
958    push_encoded_bits_history(predictor, as_u64, bits)
959}
960
961fn push_encoded_reward_commit_history(
962    predictor: &mut dyn Predictor,
963    reward: Reward,
964    bits: usize,
965    offset: Reward,
966) -> usize {
967    let shifted = (reward as i128) + (offset as i128);
968    let as_u64 = if shifted <= 0 {
969        0
970    } else if shifted > (u64::MAX as i128) {
971        u64::MAX
972    } else {
973        shifted as u64
974    };
975    push_encoded_bits_commit_history(predictor, as_u64, bits)
976}
977
978fn pop_history_bits(predictor: &mut dyn Predictor, bits: usize) {
979    for _ in 0..bits {
980        predictor.pop_history();
981    }
982}
983
984fn revert_bits(predictor: &mut dyn Predictor, bits: usize) {
985    for _ in 0..bits {
986        predictor.revert();
987    }
988}
989
/// Expected normalized return of a binned distribution: bin `i` of `M` bins
/// represents the value `i / M`. An empty distribution has expectation zero.
fn expectation_from_distribution(probs: &[f64]) -> f64 {
    if probs.is_empty() {
        return 0.0;
    }
    let bin_count = probs.len() as f64;
    let mut expectation = 0.0;
    for (bin, &p) in probs.iter().enumerate() {
        expectation += (bin as f64 / bin_count) * p;
    }
    expectation
}
1001
/// Index of the strictly largest value; ties resolve to the earliest index so
/// planning is deterministic. An empty (or all-NaN) slice yields index 0.
fn argmax_with_fixed_tie_break(values: &[f64]) -> usize {
    values
        .iter()
        .enumerate()
        .fold((0usize, f64::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
            // Strict `>` keeps the first occurrence on ties and skips NaNs.
            if v > best_v { (i, v) } else { (best_i, best_v) }
        })
        .0
}
1013
1014#[cfg(test)]
1015mod tests {
1016    use super::*;
1017    use std::sync::{Arc, Mutex};
1018
    // Baseline strict-mode configuration shared by the tests below; each test
    // mutates individual fields to probe validation and planner behavior.
    // H=2, N=2, M=8 bins, gamma=0.99, tau=0.01, fixed seed for determinism.
    fn basic_config() -> AiqiConfig {
        AiqiConfig {
            algorithm: "ac-ctw".to_string(),
            ct_depth: 8,
            observation_bits: 1,
            observation_stream_len: 1,
            reward_bits: 1,
            agent_actions: 2,
            min_reward: 0,
            max_reward: 1,
            reward_offset: 0,
            discount_gamma: 0.99,
            return_horizon: 2,
            return_bins: 8,
            augmentation_period: 2,
            history_prune_keep_steps: None,
            baseline_exploration: 0.01,
            random_seed: Some(7),
            rate_backend: None,
            rate_backend_max_order: 20,
            rwkv_model_path: None,
            rosa_max_order: None,
            zpaq_method: None,
        }
    }
1044
    // Test double that counts how often each `Predictor` entry point is hit,
    // so tests can assert which update/revert path a code path exercised.
    // `predict_prob` is a fixed 0.75/0.25 split so rollouts are deterministic.
    #[derive(Clone, Default)]
    struct CountingPredictor {
        update_calls: usize,
        commit_update_calls: usize,
        update_history_calls: usize,
        commit_update_history_calls: usize,
        revert_calls: usize,
        pop_history_calls: usize,
    }

    impl Predictor for CountingPredictor {
        fn update(&mut self, _sym: bool) {
            self.update_calls += 1;
        }

        fn commit_update(&mut self, _sym: bool) {
            self.commit_update_calls += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.update_history_calls += 1;
        }

        fn commit_update_history(&mut self, _sym: bool) {
            self.commit_update_history_calls += 1;
        }

        fn revert(&mut self) {
            self.revert_calls += 1;
        }

        fn pop_history(&mut self) {
            self.pop_history_calls += 1;
        }

        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "CountingPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }
1092
    // Call counters shared across clones via `Arc<Mutex<..>>`, for tests where
    // the agent clones the predictor internally and the test only holds the
    // counter handle.
    #[derive(Clone, Default)]
    struct SharedCallCounts {
        update: usize,
        commit_update: usize,
        update_history: usize,
        commit_update_history: usize,
    }

    // Counting test double whose counters survive `boxed_clone`, unlike
    // `CountingPredictor` whose counts are per-instance.
    #[derive(Clone)]
    struct SharedCountingPredictor {
        counts: Arc<Mutex<SharedCallCounts>>,
    }

    impl SharedCountingPredictor {
        fn new(counts: Arc<Mutex<SharedCallCounts>>) -> Self {
            Self { counts }
        }
    }

    impl Predictor for SharedCountingPredictor {
        fn update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update += 1;
        }

        fn commit_update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update_history += 1;
        }

        fn commit_update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update_history += 1;
        }

        fn revert(&mut self) {}

        fn pop_history(&mut self) {}

        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "SharedCountingPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }
1145
    // Test double whose prediction flips from favoring 0 to favoring 1 as
    // soon as any training path (`update`/`commit_update`) sees a 1 symbol —
    // used to verify rollouts actually train on return symbols.
    #[derive(Clone, Default)]
    struct ReturnLearningPredictor {
        saw_training_one: bool,
    }

    impl Predictor for ReturnLearningPredictor {
        fn update(&mut self, sym: bool) {
            if sym {
                self.saw_training_one = true;
            }
        }

        fn commit_update(&mut self, sym: bool) {
            if sym {
                self.saw_training_one = true;
            }
        }

        fn update_history(&mut self, _sym: bool) {}

        fn commit_update_history(&mut self, _sym: bool) {}

        fn revert(&mut self) {}

        fn pop_history(&mut self) {}

        fn predict_prob(&mut self, sym: bool) -> f64 {
            // Before any trained 1: P(1)=0.25; after: P(1)=0.75.
            let p1 = if self.saw_training_one { 0.75 } else { 0.25 };
            if sym { p1 } else { 1.0 - p1 }
        }

        fn model_name(&self) -> String {
            "ReturnLearningPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }
1185
    // --- Config validation and backend-selection tests ---

    // Period N must be at least horizon H, per the paper's augmentation rule.
    #[test]
    fn config_rejects_invalid_period() {
        let mut cfg = basic_config();
        cfg.augmentation_period = 1;
        cfg.return_horizon = 2;
        let err = cfg
            .validate()
            .expect_err("N < H must be rejected to match \"A Model-Free Universal AI\"");
        assert!(err.contains("augmentation_period"));
    }

    // Return bins must be a power of two so bins map cleanly onto bit strings.
    #[test]
    fn config_rejects_non_power_of_two_return_bins() {
        let mut cfg = basic_config();
        cfg.return_bins = 3;
        let err = cfg
            .validate()
            .expect_err("non-power-of-two return_bins should be rejected");
        assert!(err.contains("power of two"));
    }

    #[test]
    fn config_rejects_zpaq_algorithm_in_strict_mode() {
        let mut cfg = basic_config();
        cfg.algorithm = "zpaq".to_string();
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq algorithm mode");
        assert!(err.contains("strict mode"));
    }

    // Same rejection must apply when zpaq arrives via the rate-backend override.
    #[test]
    fn config_rejects_zpaq_rate_backend_in_strict_mode() {
        let mut cfg = basic_config();
        cfg.rate_backend = Some(RateBackend::Zpaq {
            method: "1".to_string(),
        });
        let err = cfg
            .validate()
            .expect_err("strict AIQI must reject zpaq rate backend");
        assert!(err.contains("strict frozen conditioning"));
    }

    // gamma must be strictly below 1 and the exploration temperature positive.
    #[test]
    fn config_rejects_nonpaper_gamma_or_tau() {
        let mut cfg = basic_config();
        cfg.discount_gamma = 1.0;
        let err = cfg
            .validate()
            .expect_err("gamma=1 must be rejected for strict paper AIQI");
        assert!(err.contains("discount_gamma"));

        cfg = basic_config();
        cfg.baseline_exploration = 0.0;
        let err = cfg
            .validate()
            .expect_err("tau=0 must be rejected for strict paper AIQI");
        assert!(err.contains("baseline_exploration"));
    }

    // Smoke test: after a few transitions the agent plans a valid action index.
    #[test]
    fn aiqi_estimates_action_values_after_observations() {
        let mut agent = AiqiAgent::new(basic_config()).expect("valid aiqi config");
        for _ in 0..8 {
            agent
                .observe_transition(1, &[1], 1)
                .expect("transition should be accepted");
        }

        let action = agent.get_planned_action();
        assert!(action <= 1);
    }

    // FAC-CTW should factorize over the return bit width derived from M bins.
    #[test]
    fn fac_ctw_predictor_uses_return_bit_width() {
        let mut cfg = basic_config();
        cfg.algorithm = "fac-ctw".to_string();
        cfg.return_bins = 8; // return_bits=3

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.contains("k=3"),
            "FAC-CTW should factorize over return bits only, model_name={name}"
        );
    }

    #[test]
    fn ac_ctw_path_uses_single_tree_predictor() {
        let mut cfg = basic_config();
        cfg.algorithm = "ac-ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.starts_with("AC-CTW"),
            "ac-ctw should map to the single-tree CTW predictor, model_name={name}"
        );
    }

    // "ctw" is an alias for the paper's AC-CTW predictor.
    #[test]
    fn ctw_alias_matches_ac_ctw_predictor() {
        let mut cfg = basic_config();
        cfg.algorithm = "ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let name = agent.phases[0].predictor.model_name();
        assert!(
            name.starts_with("AC-CTW"),
            "ctw alias should map to paper AIQI-CTW predictor, model_name={name}"
        );
    }
1298
    // --- Return-distribution rollout and history-management tests ---

    // With training updates requested, rollout must use update/revert only:
    // 4 bins x 2 bits = 8 symbols pushed and reverted.
    #[test]
    fn distribution_rollout_uses_update_and_revert_when_requested() {
        let mut predictor = CountingPredictor::default();
        let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, true);

        assert_eq!(probs.len(), 4);
        assert_eq!(predictor.update_calls, 8);
        assert_eq!(predictor.revert_calls, 8);
        assert_eq!(predictor.update_history_calls, 0);
        assert_eq!(predictor.pop_history_calls, 0);
    }

    // Without training updates, rollout must use the history push/pop path.
    #[test]
    fn distribution_rollout_uses_history_path_when_not_requested() {
        let mut predictor = CountingPredictor::default();
        let probs = AiqiAgent::predict_return_distribution(4, 2, &mut predictor, false);

        assert_eq!(probs.len(), 4);
        assert_eq!(predictor.update_calls, 0);
        assert_eq!(predictor.revert_calls, 0);
        assert_eq!(predictor.update_history_calls, 8);
        assert_eq!(predictor.pop_history_calls, 8);
    }

    // The generic rollout must actually train on return symbols: once the
    // first return bit trains the double, bins starting with 1 gain mass.
    #[test]
    fn generic_distribution_rollout_trains_on_return_symbols() {
        let predictor = ReturnLearningPredictor::default();
        let probs = AiqiAgent::predict_return_distribution_from_base_predictor(4, 2, &predictor);

        assert_eq!(probs.len(), 4);
        assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
        assert!(
            probs[3] > probs[1],
            "training on the first return bit should make bin 11 likelier than 01; got {:?}",
            probs
        );
        assert!(
            (probs[0] - 0.5625).abs() < 1e-12,
            "expected exact normalized mass for 00, got {:?}",
            probs
        );
    }

    #[test]
    fn ac_ctw_rollout_uses_training_updates() {
        let mut cfg = basic_config();
        cfg.algorithm = "ac-ctw".to_string();

        let agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        assert!(
            agent.distribution_uses_training_updates,
            "ac-ctw should use update/revert during return distribution rollout"
        );
    }

    // Checks discretization of the discounted H-step return against the paper.
    #[test]
    fn return_bin_for_gamma_less_than_one_matches_paper_h_step_return() {
        let mut cfg = basic_config();
        cfg.discount_gamma = 0.5;
        cfg.return_bins = 8;

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        agent
            .observe_transition(0, &[0], 1)
            .expect("first transition stored");
        agent
            .observe_transition(0, &[0], 0)
            .expect("second transition should produce first return");

        let bin = agent.return_bins_by_step[0].expect("first return should be available");
        // Paper target: R_{t,H} = (1-gamma) * sum_{k=0}^{H-1} gamma^k r_{t+k}.
        // For rewards [1, 0], gamma=0.5, H=2 this equals 0.5.
        // With M=8 bins this maps to floor(8 * 0.5) = 4.
        assert_eq!(bin, 4);
    }

    // Pruning caps retained history while keeping global step counters intact.
    #[test]
    fn optional_history_pruning_bounds_retained_state_without_losing_progress() {
        let mut cfg = basic_config();
        cfg.return_horizon = 3;
        cfg.augmentation_period = 4;
        cfg.history_prune_keep_steps = Some(8);

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        for i in 0..256usize {
            let action = (i % 2) as u64;
            let obs = [(i % 2) as u64];
            let rew = (i % 2) as i64;
            agent
                .observe_transition(action, &obs, rew)
                .expect("transition should be accepted");
        }

        // Global progress should be preserved even when retained history is bounded.
        assert_eq!(agent.steps_observed(), 256);
        assert!(
            agent.history_base_step > 1,
            "history should have been pruned"
        );
        assert!(
            agent.steps.len() < agent.steps_observed(),
            "retained history should be smaller than total observed"
        );

        let action = agent.get_planned_action();
        assert!(action <= 1);
    }
1406
    // --- Commit-path and generic-planner tests ---

    // Phase advancement must go through commit_* predictor paths only; the
    // injected shared counter double verifies 3 committed return bits and
    // 3 committed conditioning symbols (1 action + 1 observation + 1 reward).
    #[test]
    fn committed_phase_advancement_uses_commit_predictor_paths() {
        let mut agent = AiqiAgent::new(basic_config()).expect("valid aiqi config");
        let counts = Arc::new(Mutex::new(SharedCallCounts::default()));
        agent.phases[1].predictor = Box::new(SharedCountingPredictor::new(counts.clone()));
        agent.phases[1].last_augmented_step = 0;
        agent.history_base_step = 1;
        agent.total_steps_observed = 1;
        agent.steps = vec![StepRecord {
            action: 1,
            observations: vec![1],
            reward: 1,
        }];
        agent.return_bins_by_step = vec![Some(3)];

        agent
            .advance_phase_model_to_step(1, 1)
            .expect("phase advancement should succeed");

        let snapshot = counts.lock().unwrap().clone();
        assert_eq!(snapshot.commit_update, 3);
        assert_eq!(snapshot.commit_update_history, 3);
        assert_eq!(snapshot.update, 0);
        assert_eq!(snapshot.update_history, 0);
    }

    // The generic planner trains (commit_update) on augmented return symbols
    // while keeping action/percept conditioning frozen (commit_update_history),
    // and never touches the non-committing paths.
    #[test]
    fn generic_planner_trains_on_returns_and_freezes_conditioning_tokens() {
        let mut cfg = basic_config();
        cfg.rate_backend = Some(RateBackend::Match {
            hash_bits: 16,
            min_len: 2,
            max_len: 16,
            base_mix: 0.05,
            confidence_scale: 1.0,
        });

        let mut agent = AiqiAgent::new(cfg).expect("valid aiqi config");
        let counts = Arc::new(Mutex::new(SharedCallCounts::default()));
        agent.phases[1].predictor = Box::new(SharedCountingPredictor::new(counts.clone()));
        agent.phases[1].last_augmented_step = 0;
        agent.history_base_step = 1;
        agent.total_steps_observed = 2;
        agent.steps = vec![
            StepRecord {
                action: 1,
                observations: vec![1],
                reward: 1,
            },
            StepRecord {
                action: 0,
                observations: vec![0],
                reward: 0,
            },
        ];
        agent.return_bins_by_step = vec![Some(3), None];

        let q_values = agent.estimate_q_values_generic();

        assert_eq!(q_values.len(), agent.config.agent_actions);
        let snapshot = counts.lock().unwrap().clone();
        assert_eq!(snapshot.update, 0);
        assert_eq!(snapshot.update_history, 0);
        assert!(
            snapshot.commit_update > 0,
            "generic planner should train on augmented return symbols"
        );
        assert!(
            snapshot.commit_update_history > 0,
            "generic planner should keep action/percept conditioning frozen"
        );
    }
1479}