1use crate::aixi::common::{
7 Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
8 observation_repr_from_stream,
9};
10use crate::aixi::mcts::{AgentSimulator, SearchTree};
11#[cfg(feature = "backend-mamba")]
12use crate::aixi::model::MambaPredictor;
13#[cfg(feature = "backend-rwkv")]
14use crate::aixi::model::RwkvPredictor;
15use crate::aixi::model::{CtwPredictor, FacCtwPredictor, Predictor, RosaPredictor, ZpaqPredictor};
16#[cfg(feature = "backend-mamba")]
17use crate::load_mamba_model_from_path;
18#[cfg(feature = "backend-rwkv")]
19use crate::load_rwkv7_model_from_path;
20use crate::validate_zpaq_rate_method;
21
/// Configuration bundle used to construct an [`Agent`].
///
/// Carried (cloned) into simulation copies of the agent, so it derives
/// `Clone`. Field semantics below are grounded in how this file uses them.
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// Predictor backend selector; one of "ctw", "fac-ctw", "ac-ctw",
    /// "ctw-context-tree", "rosa", "rwkv", "mamba", "zpaq"
    /// (see the match in `Agent::new`).
    pub algorithm: String,
    /// Context-tree depth passed to the CTW-family predictors.
    pub ct_depth: usize,
    /// Planning horizon (in agent/environment interaction steps) reported
    /// to the planner via `AgentSimulator::horizon`.
    pub agent_horizon: usize,
    /// Bit width of a single observation word.
    pub observation_bits: usize,
    /// Number of observation words per percept; treated as at least 1
    /// everywhere it is read (`.max(1)`).
    pub observation_stream_len: usize,
    /// How an observation stream is collapsed into its key/representation
    /// (forwarded to `observation_repr_from_stream`).
    pub observation_key_mode: ObservationKeyMode,
    /// Bit width of the encoded reward.
    pub reward_bits: usize,
    /// Size of the discrete action space.
    pub agent_actions: usize,
    /// Number of MCTS simulations per planning call.
    pub num_simulations: usize,
    /// Exploration/exploitation ratio exposed to the planner.
    pub exploration_exploitation_ratio: f64,
    /// Discount factor exposed to the planner.
    pub discount_gamma: f64,
    /// Smallest reward the environment can emit (planner normalization).
    pub min_reward: Reward,
    /// Largest reward the environment can emit (planner normalization).
    pub max_reward: Reward,
    /// Offset added when encoding rewards to bits and subtracted when
    /// decoding sampled rewards (keeps encoded values non-negative).
    pub reward_offset: Reward,
    /// RNG seed; `None` means seed from entropy (non-reproducible runs).
    pub random_seed: Option<u64>,
    /// Filesystem path to an RWKV model (used when `rwkv_method` is None).
    pub rwkv_model_path: Option<String>,
    /// RWKV construction method string; takes precedence over the path.
    pub rwkv_method: Option<String>,
    /// Filesystem path to a Mamba model (used when `mamba_method` is None).
    pub mamba_model_path: Option<String>,
    /// Mamba construction method string; takes precedence over the path.
    pub mamba_method: Option<String>,
    /// Maximum model order for the ROSA predictor; defaults to 20.
    pub rosa_max_order: Option<i64>,
    /// ZPAQ method string; defaults to "1" and is validated with
    /// `validate_zpaq_rate_method`.
    pub zpaq_method: Option<String>,
}
72
/// Model-based planning agent: a sequence [`Predictor`] acting as the
/// environment model plus an MCTS [`SearchTree`] planner.
pub struct Agent {
    // Environment/percept model; boxed so the backend is chosen at runtime.
    model: Box<dyn Predictor>,
    // Planner; `None` only transiently while a search borrows the agent,
    // and permanently in simulation clones (which never plan themselves).
    planner: Option<SearchTree>,
    // Construction-time configuration, cloned into simulation copies.
    config: AgentConfig,

    // Number of interaction cycles experienced so far.
    age: u64,
    // Running sum of all rewards received via percept updates.
    total_reward: f64,

    // Bits needed to encode one action index (derived from agent_actions).
    action_bits: usize,

    // Random source for sampling percepts and planner decisions.
    rng: RandomGenerator,

    // Scratch buffers reused across calls to avoid per-call allocation.
    obs_buffer: Vec<u64>,
    sym_buffer: Vec<bool>,
}
102
103impl Agent {
104 pub fn new(config: AgentConfig) -> Self {
106 let mut action_bits = 0;
107 let mut c = 1;
108 let mut i = 1;
109 while i < config.agent_actions {
110 i *= 2;
111 action_bits = c;
112 c += 1;
113 }
114 if config.agent_actions == 1 {
115 action_bits = 1;
116 }
117
118 let model: Box<dyn Predictor> = match config.algorithm.as_str() {
119 "ctw" | "fac-ctw" => {
121 let obs_len = config.observation_stream_len.max(1);
122 let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
123 Box::new(FacCtwPredictor::new(config.ct_depth, percept_bits))
124 }
125 "ac-ctw" | "ctw-context-tree" => Box::new(CtwPredictor::new(config.ct_depth)),
127 "rosa" => {
128 let max_order = config.rosa_max_order.unwrap_or(20);
129 Box::new(RosaPredictor::new(max_order))
130 }
131 #[cfg(feature = "backend-rwkv")]
132 "rwkv" => {
133 if let Some(method) = config.rwkv_method.as_deref() {
134 let predictor = RwkvPredictor::from_method(method)
135 .unwrap_or_else(|err| panic!("Invalid RWKV method for AIXI: {err}"));
136 Box::new(predictor)
137 } else {
138 let path = config
139 .rwkv_model_path
140 .as_ref()
141 .expect("RWKV model path required");
142 let model_arc = load_rwkv7_model_from_path(path);
143 Box::new(RwkvPredictor::new(model_arc))
144 }
145 }
146 #[cfg(not(feature = "backend-rwkv"))]
147 "rwkv" => panic!("RWKV backend disabled at compile time"),
148 #[cfg(feature = "backend-mamba")]
149 "mamba" => {
150 if let Some(method) = config.mamba_method.as_deref() {
151 let predictor = MambaPredictor::from_method(method)
152 .unwrap_or_else(|err| panic!("Invalid Mamba method for AIXI: {err}"));
153 Box::new(predictor)
154 } else {
155 let path = config
156 .mamba_model_path
157 .as_ref()
158 .expect("Mamba model path required");
159 let model_arc = load_mamba_model_from_path(path);
160 Box::new(MambaPredictor::new(model_arc))
161 }
162 }
163 #[cfg(not(feature = "backend-mamba"))]
164 "mamba" => panic!("Mamba backend disabled at compile time"),
165 "zpaq" => {
166 let method = config
167 .zpaq_method
168 .clone()
169 .unwrap_or_else(|| "1".to_string());
170 if let Err(err) = validate_zpaq_rate_method(&method) {
171 panic!("Invalid zpaq method for AIXI: {err}");
172 }
173 Box::new(ZpaqPredictor::new(method, 2f64.powi(-24)))
174 }
175 _ => panic!("Unknown algorithm: {}", config.algorithm),
176 };
177
178 let rng = if let Some(seed) = config.random_seed {
179 RandomGenerator::from_seed(seed)
180 } else {
181 RandomGenerator::new()
182 };
183
184 Self {
185 model,
186 planner: Some(SearchTree::new()),
187 config,
188 age: 0,
189 total_reward: 0.0,
190 action_bits,
191 rng,
192 obs_buffer: Vec::with_capacity(128),
193 sym_buffer: Vec::with_capacity(64),
194 }
195 }
196
197 fn clone_for_simulation(&self, seed: u64) -> Self {
198 Self {
199 model: self.model.boxed_clone(),
200 planner: None,
201 config: self.config.clone(),
202 age: self.age,
203 total_reward: self.total_reward,
204 action_bits: self.action_bits,
205 rng: self.rng.fork_with(seed),
206 obs_buffer: Vec::with_capacity(128),
207 sym_buffer: Vec::with_capacity(64),
208 }
209 }
210
211 pub fn reset(&mut self) {
213 self.age = 0;
214 self.total_reward = 0.0;
215 }
216
217 pub fn get_planned_action(
221 &mut self,
222 prev_obs_stream: &[PerceptVal],
223 prev_rew: Reward,
224 prev_act: Action,
225 ) -> Action {
226 let mut planner = self.planner.take().expect("Planner missing");
227 let num_sim = self.config.num_simulations;
228 let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
229 self.planner = Some(planner);
230 action
231 }
232
233 pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
235 self.model_update_percept_stream(&[observation], reward);
236 }
237
238 pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
240 debug_assert!(
241 !observations.is_empty() || self.config.observation_bits == 0,
242 "percept update missing observation stream"
243 );
244 let mut percept_syms = Vec::new();
245 for &obs in observations {
246 encode(&mut percept_syms, obs, self.config.observation_bits);
247 }
248 crate::aixi::common::encode_reward_offset(
249 &mut percept_syms,
250 reward,
251 self.config.reward_bits,
252 self.config.reward_offset,
253 );
254
255 for &sym in &percept_syms {
256 self.model.update(sym);
257 }
258
259 self.total_reward += reward as f64;
260 }
261
262 pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
264 observation_repr_from_stream(
265 self.config.observation_key_mode,
266 observations,
267 self.config.observation_bits,
268 )
269 }
270
271 pub fn model_update_action_external(&mut self, action: Action) {
273 self.model_update_action(action);
274 }
275}
276
impl AgentSimulator for Agent {
    /// Size of the discrete action space.
    fn get_num_actions(&self) -> usize {
        self.config.agent_actions
    }

    /// Bit width of a single observation word.
    fn get_num_observation_bits(&self) -> usize {
        self.config.observation_bits
    }

    /// Number of observation words per percept (never reported as 0).
    fn observation_stream_len(&self) -> usize {
        self.config.observation_stream_len.max(1)
    }

    /// How observation streams are collapsed into keys/representations.
    fn observation_key_mode(&self) -> ObservationKeyMode {
        self.config.observation_key_mode
    }

    /// Bit width of the encoded reward.
    fn get_num_reward_bits(&self) -> usize {
        self.config.reward_bits
    }

    /// Planning horizon in interaction steps.
    fn horizon(&self) -> usize {
        self.config.agent_horizon
    }

    /// Largest reward the environment can emit.
    fn max_reward(&self) -> Reward {
        self.config.max_reward
    }

    /// Smallest reward the environment can emit.
    fn min_reward(&self) -> Reward {
        self.config.min_reward
    }

    /// Offset used when encoding/decoding rewards as unsigned bit strings.
    fn reward_offset(&self) -> i64 {
        self.config.reward_offset
    }

    /// Exploration/exploitation ratio for the planner.
    fn get_explore_exploit_ratio(&self) -> f64 {
        self.config.exploration_exploitation_ratio
    }

    /// Discount factor for the planner.
    fn discount_gamma(&self) -> f64 {
        self.config.discount_gamma
    }

    /// Pushes an executed action into the model's conditioning history
    /// (via `update_history`, i.e. without training on the symbols).
    fn model_update_action(&mut self, action: Action) {
        self.sym_buffer.clear();
        encode(&mut self.sym_buffer, action, self.action_bits);

        for &sym in &self.sym_buffer {
            self.model.update_history(sym);
        }
    }

    /// Samples `bits` symbols from the model one at a time — each sampled
    /// symbol is fed back via `update` before the next is predicted — and
    /// decodes the sequence into an integer.
    ///
    /// NOTE: the exact interleaving of `predict_one`, the RNG draw, and
    /// `update` is load-bearing; do not reorder.
    fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
        self.sym_buffer.clear();
        for _ in 0..bits {
            let prob_1 = self.model.predict_one();
            let sym = self.rng.gen_bool(prob_1);
            self.model.update(sym);
            self.sym_buffer.push(sym);
        }
        decode(&self.sym_buffer, bits)
    }

    /// Samples a full percept for simulation: the observation stream word by
    /// word, then the reward, all fed back into the model. Returns the
    /// observation representation and the de-offset reward.
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.config.observation_bits;
        let obs_len = self.config.observation_stream_len.max(1);

        self.obs_buffer.clear();
        for _ in 0..obs_len {
            let p = self.gen_percept_and_update(obs_bits);
            self.obs_buffer.push(p);
        }

        let obs_repr = observation_repr_from_stream(
            self.config.observation_key_mode,
            &self.obs_buffer,
            obs_bits,
        );
        let rew_bits = self.config.reward_bits;
        let rew_u = self.gen_percept_and_update(rew_bits);
        // Undo the encoding offset to recover the signed reward.
        let rew = (rew_u as i64) - self.config.reward_offset;

        (obs_repr, rew)
    }

    /// Uniform integer in `0..end` from the agent's RNG.
    fn gen_range(&mut self, end: usize) -> usize {
        self.rng.gen_range(end)
    }

    /// Uniform f64 from the agent's RNG.
    fn gen_f64(&mut self) -> f64 {
        self.rng.gen_f64()
    }

    /// Rewinds the model by `steps` interaction cycles: each cycle reverts
    /// one percept's worth of trained symbols (all observation words plus
    /// the reward) and pops one action's worth of history symbols —
    /// mirroring `gen_percepts_and_update` / `model_update_action`.
    fn model_revert(&mut self, steps: usize) {
        let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
        let percept_bits = obs_bits + self.config.reward_bits;

        for _ in 0..steps {
            for _ in 0..percept_bits {
                self.model.revert();
            }
            for _ in 0..self.action_bits {
                self.model.pop_history();
            }
        }
    }

    /// Boxed simulation copy with an RNG forked from `seed`
    /// (see `Agent::clone_for_simulation`).
    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
        Box::new(self.clone_for_simulation(seed))
    }
}