1use crate::RateBackend;
7use crate::aixi::common::{
8 Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
9 observation_repr_from_stream,
10};
11use crate::aixi::mcts::{AgentSimulator, SearchTree};
12#[cfg(feature = "backend-mamba")]
13use crate::aixi::model::MambaPredictor;
14#[cfg(feature = "backend-rwkv")]
15use crate::aixi::model::RwkvPredictor;
16use crate::aixi::model::{
17 CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor, RosaPredictor, ZpaqPredictor,
18};
19use crate::aixi::rate_backend::{adapt_rate_backend_for_bit_tokens, rate_backend_contains_zpaq};
20#[cfg(feature = "backend-mamba")]
21use crate::load_mamba_model_from_path;
22#[cfg(feature = "backend-rwkv")]
23use crate::load_rwkv7_model_from_path;
24use crate::{validate_rate_backend, validate_zpaq_rate_method};
25
/// Configuration for a Monte-Carlo AIXI agent.
///
/// Selects the predictive model (`algorithm` or the generic `rate_backend`
/// override), the binary encoding of percepts (observation/reward bit
/// widths), and the MCTS planner parameters. Call [`AgentConfig::validate`]
/// before use; `Agent::try_new` does so automatically.
#[derive(Clone)]
pub struct AgentConfig {
    /// Predictor used when `rate_backend` is `None`: one of "ctw",
    /// "fac-ctw", "ac-ctw", "ctw-context-tree", "rosa", "rwkv", "mamba",
    /// or "zpaq" (feature-gated backends must be compiled in).
    pub algorithm: String,
    /// Context-tree depth for the CTW-family predictors.
    pub ct_depth: usize,
    /// Planning horizon in cycles for the search; must be >= 1.
    pub agent_horizon: usize,
    /// Bits used to encode a single observation value.
    pub observation_bits: usize,
    /// Number of observation values per percept (treated as at least 1).
    pub observation_stream_len: usize,
    /// How an observation stream is reduced to its key representation
    /// (passed through to `observation_repr_from_stream`).
    pub observation_key_mode: ObservationKeyMode,
    /// Bits used to encode the offset-shifted (non-negative) reward.
    pub reward_bits: usize,
    /// Size of the discrete action space; must be >= 1.
    pub agent_actions: usize,
    /// MCTS simulations per planning call; must be >= 1.
    pub num_simulations: usize,
    /// Exploration constant used by the search; must be > 0.
    pub exploration_exploitation_ratio: f64,
    /// Discount factor gamma; must lie in [0, 1].
    pub discount_gamma: f64,
    /// Smallest reward the environment can emit.
    pub min_reward: Reward,
    /// Largest reward the environment can emit; must be >= `min_reward`.
    pub max_reward: Reward,
    /// Offset added before encoding so every shifted reward is >= 0 and
    /// fits in `reward_bits`.
    pub reward_offset: Reward,
    /// RNG seed for reproducible runs; `None` seeds from entropy.
    pub random_seed: Option<u64>,
    /// Generic rate-backend override; when set it takes precedence over
    /// `algorithm` (backends containing zpaq are rejected by `validate`).
    pub rate_backend: Option<RateBackend>,
    /// Max order handed to `RateBackendBitPredictor` for the override above.
    pub rate_backend_max_order: i64,
    /// Path to an RWKV model file (used when `rwkv_method` is unset/blank).
    pub rwkv_model_path: Option<String>,
    /// Inline RWKV method spec; takes precedence over `rwkv_model_path`.
    pub rwkv_method: Option<String>,
    /// Path to a Mamba model file (used when `mamba_method` is unset/blank).
    pub mamba_model_path: Option<String>,
    /// Inline Mamba method spec; takes precedence over `mamba_model_path`.
    pub mamba_method: Option<String>,
    /// Max order for the "rosa" predictor; defaults to 20 when `None`.
    pub rosa_max_order: Option<i64>,
    /// zpaq compression method string; defaults to "1" when `None`.
    pub zpaq_method: Option<String>,
}
83
84impl AgentConfig {
85 pub fn validate(&self) -> Result<(), String> {
87 if self.agent_actions == 0 {
88 return Err("agent_actions must be >= 1".to_string());
89 }
90 if self.agent_horizon == 0 {
91 return Err("agent_horizon must be >= 1".to_string());
92 }
93 if self.num_simulations == 0 {
94 return Err("num_simulations must be >= 1".to_string());
95 }
96 if self.exploration_exploitation_ratio <= 0.0 {
97 return Err("exploration_exploitation_ratio must be > 0".to_string());
98 }
99 if !(0.0..=1.0).contains(&self.discount_gamma) {
100 return Err(format!(
101 "discount_gamma must be in [0, 1] for MC-AIXI, got {}",
102 self.discount_gamma
103 ));
104 }
105 if self.max_reward < self.min_reward {
106 return Err(format!(
107 "max_reward must be >= min_reward (got {} < {})",
108 self.max_reward, self.min_reward
109 ));
110 }
111
112 let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
113 let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
114 if min_shifted < 0 {
115 return Err(format!(
116 "reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
117 min_shifted
118 ));
119 }
120 if self.reward_bits < 64 {
121 let max_enc = (1u128 << self.reward_bits) - 1;
122 if (max_shifted as u128) > max_enc {
123 return Err(format!(
124 "reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
125 max_shifted, max_enc
126 ));
127 }
128 }
129
130 if let Some(rate_backend) = &self.rate_backend {
131 validate_rate_backend(rate_backend)
132 .map_err(|err| format!("invalid rate_backend: {err}"))?;
133 if rate_backend_contains_zpaq(rate_backend) {
134 return Err(
135 "MC-AIXI strict generic rate_backend support requires reversible action conditioning; configured rate_backend contains zpaq which does not provide the reversible action conditioning required by \"A Monte-Carlo AIXI Approximation\""
136 .to_string(),
137 );
138 }
139 return Ok(());
140 }
141
142 match self.algorithm.as_str() {
143 "ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
144 #[cfg(feature = "backend-rwkv")]
145 "rwkv" => {
146 let has_method = self
147 .rwkv_method
148 .as_deref()
149 .map(str::trim)
150 .is_some_and(|v| !v.is_empty());
151 let has_path = self
152 .rwkv_model_path
153 .as_deref()
154 .map(str::trim)
155 .is_some_and(|v| !v.is_empty());
156 if !(has_method || has_path) {
157 return Err(
158 "algorithm=rwkv requires rwkv_model_path or rwkv_method when no rate_backend override is configured"
159 .to_string(),
160 );
161 }
162 }
163 #[cfg(not(feature = "backend-rwkv"))]
164 "rwkv" => return Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
165 #[cfg(feature = "backend-mamba")]
166 "mamba" => {
167 let has_method = self
168 .mamba_method
169 .as_deref()
170 .map(str::trim)
171 .is_some_and(|v| !v.is_empty());
172 let has_path = self
173 .mamba_model_path
174 .as_deref()
175 .map(str::trim)
176 .is_some_and(|v| !v.is_empty());
177 if !(has_method || has_path) {
178 return Err(
179 "algorithm=mamba requires mamba_model_path or mamba_method when no rate_backend override is configured"
180 .to_string(),
181 );
182 }
183 }
184 #[cfg(not(feature = "backend-mamba"))]
185 "mamba" => return Err("algorithm=mamba requires backend-mamba feature".to_string()),
186 "zpaq" => {
187 let method = self.zpaq_method.as_deref().unwrap_or("1");
188 if let Err(err) = validate_zpaq_rate_method(method) {
189 return Err(format!("Invalid zpaq method for AIXI: {err}"));
190 }
191 }
192 other => return Err(format!("Unknown algorithm: {other}")),
193 }
194
195 Ok(())
196 }
197}
198
/// A Monte-Carlo AIXI agent: a bit-level predictive model paired with an
/// MCTS planner, plus the bookkeeping needed to encode actions and percepts.
pub struct Agent {
    /// Predictive model over the binary action/percept history.
    model: Box<dyn Predictor>,
    /// MCTS search tree. `Some` on real agents; `None` on simulation clones
    /// (and temporarily taken while `get_planned_action` runs the search).
    planner: Option<SearchTree>,
    /// Configuration this agent was built from.
    config: AgentConfig,

    /// Agent age in cycles. NOTE(review): never advanced in this file —
    /// presumably updated elsewhere or reserved; confirm.
    age: u64,
    /// Running sum of rewards committed via `model_update_percept_stream`.
    total_reward: f64,

    /// Bits needed to encode one action index: ceil(log2(agent_actions)),
    /// but at least 1.
    action_bits: usize,

    /// Random source for percept sampling and search decisions.
    rng: RandomGenerator,

    /// Scratch buffer for generated observation values (reused per call).
    obs_buffer: Vec<u64>,
    /// Scratch buffer for encoded bit symbols (reused per call).
    sym_buffer: Vec<bool>,
}
228
229impl Agent {
230 pub fn new(config: AgentConfig) -> Self {
232 Self::try_new(config).unwrap_or_else(|err| panic!("Invalid MC-AIXI config: {err}"))
233 }
234
235 pub fn try_new(config: AgentConfig) -> Result<Self, String> {
237 config.validate()?;
238
239 let mut action_bits = 0;
240 let mut c = 1;
241 let mut i = 1;
242 while i < config.agent_actions {
243 i *= 2;
244 action_bits = c;
245 c += 1;
246 }
247 if config.agent_actions == 1 {
248 action_bits = 1;
249 }
250
251 let model = build_model(&config)?;
252
253 let rng = if let Some(seed) = config.random_seed {
254 RandomGenerator::from_seed(seed)
255 } else {
256 RandomGenerator::new()
257 };
258
259 Ok(Self {
260 model,
261 planner: Some(SearchTree::new()),
262 config,
263 age: 0,
264 total_reward: 0.0,
265 action_bits,
266 rng,
267 obs_buffer: Vec::with_capacity(128),
268 sym_buffer: Vec::with_capacity(64),
269 })
270 }
271
272 fn clone_for_simulation(&self, seed: u64) -> Self {
273 Self {
274 model: self.model.boxed_clone(),
275 planner: None,
276 config: self.config.clone(),
277 age: self.age,
278 total_reward: self.total_reward,
279 action_bits: self.action_bits,
280 rng: self.rng.fork_with(seed),
281 obs_buffer: Vec::with_capacity(128),
282 sym_buffer: Vec::with_capacity(64),
283 }
284 }
285
286 pub fn reset(&mut self) {
288 self.age = 0;
289 self.total_reward = 0.0;
290 }
291
292 pub fn get_planned_action(
296 &mut self,
297 prev_obs_stream: &[PerceptVal],
298 prev_rew: Reward,
299 prev_act: Action,
300 ) -> Action {
301 let mut planner = self.planner.take().expect("Planner missing");
302 let num_sim = self.config.num_simulations;
303 let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
304 self.planner = Some(planner);
305 action
306 }
307
308 pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
310 self.model_update_percept_stream(&[observation], reward);
311 }
312
313 pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
315 debug_assert!(
316 !observations.is_empty() || self.config.observation_bits == 0,
317 "percept update missing observation stream"
318 );
319 let mut percept_syms = Vec::new();
320 for &obs in observations {
321 encode(&mut percept_syms, obs, self.config.observation_bits);
322 }
323 crate::aixi::common::encode_reward_offset(
324 &mut percept_syms,
325 reward,
326 self.config.reward_bits,
327 self.config.reward_offset,
328 );
329
330 for &sym in &percept_syms {
331 self.model.commit_update(sym);
332 }
333
334 self.total_reward += reward as f64;
335 }
336
337 pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
339 observation_repr_from_stream(
340 self.config.observation_key_mode,
341 observations,
342 self.config.observation_bits,
343 )
344 }
345
346 pub fn model_update_action_external(&mut self, action: Action) {
348 self.sym_buffer.clear();
349 encode(&mut self.sym_buffer, action, self.action_bits);
350
351 for &sym in &self.sym_buffer {
352 self.model.commit_update_history(sym);
353 }
354 }
355}
356
357fn build_model(config: &AgentConfig) -> Result<Box<dyn Predictor>, String> {
358 if let Some(rate_backend) = config.rate_backend.clone() {
359 let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
360 let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
361 return Ok(Box::new(predictor));
362 }
363
364 match config.algorithm.as_str() {
365 "ctw" | "fac-ctw" => {
368 let obs_len = config.observation_stream_len.max(1);
369 let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
370 Ok(Box::new(FacCtwPredictor::new(
371 config.ct_depth,
372 percept_bits,
373 )))
374 }
375 "ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
377 "rosa" => {
378 let max_order = config.rosa_max_order.unwrap_or(20);
379 Ok(Box::new(RosaPredictor::new(max_order)))
380 }
381 #[cfg(feature = "backend-rwkv")]
382 "rwkv" => {
383 if let Some(method) = config
384 .rwkv_method
385 .as_deref()
386 .map(str::trim)
387 .filter(|v| !v.is_empty())
388 {
389 let predictor = RwkvPredictor::from_method(method)
390 .map_err(|err| format!("Invalid RWKV method for AIXI: {err}"))?;
391 Ok(Box::new(predictor))
392 } else {
393 let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
394 "RWKV model path required when rwkv_method is not configured".to_string()
395 })?;
396 let model_arc = load_rwkv7_model_from_path(path);
397 Ok(Box::new(RwkvPredictor::new(model_arc)))
398 }
399 }
400 #[cfg(not(feature = "backend-rwkv"))]
401 "rwkv" => Err("RWKV backend disabled at compile time".to_string()),
402 #[cfg(feature = "backend-mamba")]
403 "mamba" => {
404 if let Some(method) = config
405 .mamba_method
406 .as_deref()
407 .map(str::trim)
408 .filter(|v| !v.is_empty())
409 {
410 let predictor = MambaPredictor::from_method(method)
411 .map_err(|err| format!("Invalid Mamba method for AIXI: {err}"))?;
412 Ok(Box::new(predictor))
413 } else {
414 let path = config.mamba_model_path.as_ref().ok_or_else(|| {
415 "Mamba model path required when mamba_method is not configured".to_string()
416 })?;
417 let model_arc = load_mamba_model_from_path(path);
418 Ok(Box::new(MambaPredictor::new(model_arc)))
419 }
420 }
421 #[cfg(not(feature = "backend-mamba"))]
422 "mamba" => Err("Mamba backend disabled at compile time".to_string()),
423 "zpaq" => {
424 let method = config
425 .zpaq_method
426 .clone()
427 .unwrap_or_else(|| "1".to_string());
428 if let Err(err) = validate_zpaq_rate_method(&method) {
429 return Err(format!("Invalid zpaq method for AIXI: {err}"));
430 }
431 Ok(Box::new(ZpaqPredictor::new(method, 2f64.powi(-24))))
432 }
433 _ => Err(format!("Unknown algorithm: {}", config.algorithm)),
434 }
435}
436
/// Planner-facing view of the agent: configuration accessors plus the
/// model-sampling and rollback operations the MCTS search uses during
/// rollouts. Rollout paths use the revertible `update`/`update_history`
/// model calls, in contrast to the committed paths on `Agent`.
impl AgentSimulator for Agent {
    fn get_num_actions(&self) -> usize {
        self.config.agent_actions
    }

    fn get_num_observation_bits(&self) -> usize {
        self.config.observation_bits
    }

    // Treated as at least 1, matching `build_model` and percept generation.
    fn observation_stream_len(&self) -> usize {
        self.config.observation_stream_len.max(1)
    }

    fn observation_key_mode(&self) -> ObservationKeyMode {
        self.config.observation_key_mode
    }

    fn get_num_reward_bits(&self) -> usize {
        self.config.reward_bits
    }

    fn horizon(&self) -> usize {
        self.config.agent_horizon
    }

    fn max_reward(&self) -> Reward {
        self.config.max_reward
    }

    fn min_reward(&self) -> Reward {
        self.config.min_reward
    }

    fn reward_offset(&self) -> i64 {
        self.config.reward_offset
    }

    fn get_explore_exploit_ratio(&self) -> f64 {
        self.config.exploration_exploitation_ratio
    }

    fn discount_gamma(&self) -> f64 {
        self.config.discount_gamma
    }

    // Simulation-path action update: uses the revertible `update_history`
    // (the committed external path uses `commit_update_history` instead).
    fn model_update_action(&mut self, action: Action) {
        self.sym_buffer.clear();
        encode(&mut self.sym_buffer, action, self.action_bits);

        for &sym in &self.sym_buffer {
            self.model.update_history(sym);
        }
    }

    // Samples `bits` symbols from the model one at a time, feeding each
    // sampled bit back via the revertible `update`, then decodes the bit
    // sequence into a single value.
    fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
        self.sym_buffer.clear();
        for _ in 0..bits {
            let prob_1 = self.model.predict_one();
            let sym = self.rng.gen_bool(prob_1);
            self.model.update(sym);
            self.sym_buffer.push(sym);
        }
        decode(&self.sym_buffer, bits)
    }

    // Marks a rollback point so `model_revert` can undo the whole rollout in
    // one step when the predictor supports scoped rollback.
    fn begin_simulation(&mut self) {
        self.model.begin_rollback_scope();
    }

    // Generates one full percept: the observation stream first, then the
    // reward — the same symbol order the committed percept path writes.
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.config.observation_bits;
        let obs_len = self.config.observation_stream_len.max(1);

        self.obs_buffer.clear();
        for _ in 0..obs_len {
            let p = self.gen_percept_and_update(obs_bits);
            self.obs_buffer.push(p);
        }

        let obs_repr = observation_repr_from_stream(
            self.config.observation_key_mode,
            &self.obs_buffer,
            obs_bits,
        );
        let rew_bits = self.config.reward_bits;
        let rew_u = self.gen_percept_and_update(rew_bits);
        // Undo the non-negative reward encoding applied on the commit path.
        let rew = (rew_u as i64) - self.config.reward_offset;

        (obs_repr, rew)
    }

    fn gen_range(&mut self, end: usize) -> usize {
        self.rng.gen_range(end)
    }

    fn gen_f64(&mut self) -> f64 {
        self.rng.gen_f64()
    }

    // Undoes `steps` simulated cycles. Predictor-level scoped rollback is
    // preferred; otherwise fall back to bit-wise undo: `revert()` for each
    // percept bit and `pop_history()` for each action bit, per step.
    fn model_revert(&mut self, steps: usize) {
        if self.model.rollback_scope() {
            return;
        }
        let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
        let percept_bits = obs_bits + self.config.reward_bits;

        for _ in 0..steps {
            for _ in 0..percept_bits {
                self.model.revert();
            }
            for _ in 0..self.action_bits {
                self.model.pop_history();
            }
        }
    }

    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
        Box::new(self.clone_for_simulation(seed))
    }
}
559
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Tally of how often each predictor entry point was invoked.
    #[derive(Clone, Default)]
    struct CallCounts {
        update: usize,
        commit_update: usize,
        update_history: usize,
        commit_update_history: usize,
        begin_scope: usize,
        rollback_scope: usize,
        revert: usize,
        pop_history: usize,
    }

    /// Predictor stub that only records which update paths are exercised;
    /// predictions are fixed constants.
    #[derive(Clone)]
    struct InstrumentedPredictor {
        counts: Arc<Mutex<CallCounts>>,
    }

    impl InstrumentedPredictor {
        fn new(counts: Arc<Mutex<CallCounts>>) -> Self {
            Self { counts }
        }
    }

    impl Predictor for InstrumentedPredictor {
        fn update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update += 1;
        }

        fn commit_update(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update += 1;
        }

        fn update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().update_history += 1;
        }

        fn commit_update_history(&mut self, _sym: bool) {
            self.counts.lock().unwrap().commit_update_history += 1;
        }

        fn revert(&mut self) {
            self.counts.lock().unwrap().revert += 1;
        }

        fn pop_history(&mut self) {
            self.counts.lock().unwrap().pop_history += 1;
        }

        fn begin_rollback_scope(&mut self) {
            self.counts.lock().unwrap().begin_scope += 1;
        }

        // Reports scope rollback as handled, so `model_revert` must NOT fall
        // back to bit-wise revert/pop_history.
        fn rollback_scope(&mut self) -> bool {
            self.counts.lock().unwrap().rollback_scope += 1;
            true
        }

        // Fixed, symmetric probabilities; values are arbitrary for these tests.
        fn predict_prob(&mut self, sym: bool) -> f64 {
            if sym { 0.75 } else { 0.25 }
        }

        fn model_name(&self) -> String {
            "InstrumentedPredictor".to_string()
        }

        fn boxed_clone(&self) -> Box<dyn Predictor> {
            Box::new(self.clone())
        }
    }

    /// Small valid config: 2 observations x 2 bits, 3 reward bits,
    /// 4 actions (=> action_bits == 2), seeded RNG.
    fn basic_config() -> AgentConfig {
        AgentConfig {
            algorithm: "ac-ctw".to_string(),
            ct_depth: 8,
            agent_horizon: 2,
            observation_bits: 2,
            observation_stream_len: 2,
            observation_key_mode: ObservationKeyMode::FullStream,
            reward_bits: 3,
            agent_actions: 4,
            num_simulations: 2,
            exploration_exploitation_ratio: 1.0,
            discount_gamma: 0.95,
            min_reward: -2,
            max_reward: 3,
            reward_offset: 2,
            random_seed: Some(7),
            rate_backend: None,
            rate_backend_max_order: 8,
            rwkv_model_path: None,
            rwkv_method: None,
            mamba_model_path: None,
            mamba_method: None,
            rosa_max_order: None,
            zpaq_method: None,
        }
    }

    #[test]
    fn external_history_updates_use_committed_predictor_paths() {
        let mut agent = Agent::try_new(basic_config()).expect("valid agent config");
        let counts = Arc::new(Mutex::new(CallCounts::default()));
        // Swap in the instrumented stub after construction so only the calls
        // below are counted.
        agent.model = Box::new(InstrumentedPredictor::new(counts.clone()));

        agent.model_update_percept_stream(&[1, 2], 1);
        agent.model_update_action_external(3);

        let snapshot = counts.lock().unwrap().clone();
        // 2 observations x 2 bits + 3 reward bits = 7 committed symbols.
        assert_eq!(snapshot.commit_update, 7);
        // Action 3 encodes into action_bits = 2 committed history symbols.
        assert_eq!(snapshot.commit_update_history, 2);
        // The revertible simulation paths must not be touched.
        assert_eq!(snapshot.update, 0);
        assert_eq!(snapshot.update_history, 0);
    }

    #[test]
    fn simulation_revert_prefers_predictor_scope_when_available() {
        let mut agent = Agent::try_new(basic_config()).expect("valid agent config");
        let counts = Arc::new(Mutex::new(CallCounts::default()));
        agent.model = Box::new(InstrumentedPredictor::new(counts.clone()));

        AgentSimulator::begin_simulation(&mut agent);
        agent.model_revert(3);

        let snapshot = counts.lock().unwrap().clone();
        assert_eq!(snapshot.begin_scope, 1);
        assert_eq!(snapshot.rollback_scope, 1);
        // Scope rollback succeeded, so no bit-wise fallback happened.
        assert_eq!(snapshot.revert, 0);
        assert_eq!(snapshot.pop_history, 0);
    }
}