use crate::aixi::common::{
    Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
    observation_repr_from_stream,
};
use crate::aixi::mcts::{AgentSimulator, SearchTree};
use crate::aixi::model::{
    CtwPredictor, FacCtwPredictor, Predictor, RosaPredictor, RwkvPredictor, ZpaqPredictor,
};
use crate::{load_rwkv7_model_from_path, validate_zpaq_rate_method};

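/// Configuration for an [`Agent`]: which predictor backs the environment
/// model (`algorithm` plus its backend-specific options), the bit-level shape
/// of percepts, and the MCTS planning parameters.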
#[derive(Clone, Debug)]
pub struct AgentConfig {
    pub algorithm: String,
    pub ct_depth: usize,
    pub agent_horizon: usize,
    pub observation_bits: usize,
    pub observation_stream_len: usize,
    pub observation_key_mode: ObservationKeyMode,
    pub reward_bits: usize,
    pub agent_actions: usize,
    pub num_simulations: usize,
    pub exploration_exploitation_ratio: f64,
    pub discount_gamma: f64,
    pub min_reward: Reward,
    pub max_reward: Reward,
    pub reward_offset: Reward,
    pub rwkv_model_path: Option<String>,
    pub rosa_max_order: Option<i64>,
    pub zpaq_method: Option<String>,
}

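/// An AIXI agent: a sequence predictor (`model`) serving as the environment
/// model, driven by an MCTS planner (`planner`) that searches over simulated
/// futures. The planner lives in an `Option` so it can be taken out during a
/// search while the agent itself is borrowed as the simulator.
///
/// A minimal interaction-loop sketch, assuming a prepared `config` and an
/// environment with a `step` method (the `env` name and its API are
/// illustrative, not part of this crate):
///
/// ```ignore
/// let mut agent = Agent::new(config);
/// let (mut obs, mut rew, mut act) = (vec![0u64], 0, 0);
/// loop {
///     // Plan from the previous percept/action, then commit the action.
///     act = agent.get_planned_action(&obs, rew, act);
///     agent.model_update_action_external(act);
///     // Step the environment and feed the new percept back into the model.
///     let (next_obs, next_rew) = env.step(act);
///     agent.model_update_percept_stream(&next_obs, next_rew);
///     (obs, rew) = (next_obs, next_rew);
/// }
/// ```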
pub struct Agent {
    model: Box<dyn Predictor>,
    planner: Option<SearchTree>,
    config: AgentConfig,

    age: u64,
    total_reward: f64,

    action_bits: usize,

    rng: RandomGenerator,

    obs_buffer: Vec<u64>,
    sym_buffer: Vec<bool>,
}

impl Agent {
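    /// Builds an agent from `config`, computing the action encoding width and
    /// constructing the predictor named by `config.algorithm`.
    ///
    /// Panics on an unknown algorithm, a missing RWKV model path, or an
    /// invalid zpaq method string.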
    pub fn new(config: AgentConfig) -> Self {
        // Bits needed to encode an action index: ceil(log2(agent_actions)),
        // clamped to at least one bit so even a single-action agent emits a
        // history symbol per step.
        let action_bits =
            (config.agent_actions.next_power_of_two().trailing_zeros() as usize).max(1);

        let model: Box<dyn Predictor> = match config.algorithm.as_str() {
            "ctw" | "fac-ctw" => {
                let obs_len = config.observation_stream_len.max(1);
                let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
                Box::new(FacCtwPredictor::new(config.ct_depth, percept_bits))
            }
            "ac-ctw" | "ctw-context-tree" => Box::new(CtwPredictor::new(config.ct_depth)),
            "rosa" => {
                let max_order = config.rosa_max_order.unwrap_or(20);
                Box::new(RosaPredictor::new(max_order))
            }
            "rwkv" => {
                let path = config
                    .rwkv_model_path
                    .as_ref()
                    .expect("RWKV model path required");
                let model_arc = load_rwkv7_model_from_path(path);
                Box::new(RwkvPredictor::new(model_arc))
            }
            "zpaq" => {
                let method = config
                    .zpaq_method
                    .clone()
                    .unwrap_or_else(|| "1".to_string());
                if let Err(err) = validate_zpaq_rate_method(&method) {
                    panic!("Invalid zpaq method for AIXI: {err}");
                }
                Box::new(ZpaqPredictor::new(method, 2f64.powi(-24)))
            }
            _ => panic!("Unknown algorithm: {}", config.algorithm),
        };

        Self {
            model,
            planner: Some(SearchTree::new()),
            config,
            age: 0,
            total_reward: 0.0,
            action_bits,
            rng: RandomGenerator::new(),
            obs_buffer: Vec::with_capacity(128),
            sym_buffer: Vec::with_capacity(64),
        }
    }

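    /// Clones the agent for use inside a rollout: the model is deep-cloned
    /// and the RNG forked, but the planner is dropped since simulated copies
    /// never re-plan.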
    fn clone_for_simulation(&self, seed: u64) -> Self {
        Self {
            model: self.model.boxed_clone(),
            planner: None,
            config: self.config.clone(),
            age: self.age,
            total_reward: self.total_reward,
            action_bits: self.action_bits,
            rng: self.rng.fork_with(seed),
            obs_buffer: Vec::with_capacity(128),
            sym_buffer: Vec::with_capacity(64),
        }
    }

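    /// Resets age and accumulated reward without touching the learned model.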
    pub fn reset(&mut self) {
        self.age = 0;
        self.total_reward = 0.0;
    }

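    /// Runs the MCTS search from the current history and returns the chosen
    /// action. The planner is temporarily moved out of `self` so the search
    /// can borrow the agent as its simulator.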
    pub fn get_planned_action(
        &mut self,
        prev_obs_stream: &[PerceptVal],
        prev_rew: Reward,
        prev_act: Action,
    ) -> Action {
        let mut planner = self.planner.take().expect("Planner missing");
        let num_sim = self.config.num_simulations;
        let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
        self.planner = Some(planner);
        action
    }

    pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
        self.model_update_percept_stream(&[observation], reward);
    }

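    /// Encodes the observation stream and offset reward into symbols, feeds
    /// them to the model, and accumulates the reward.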
    pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
        debug_assert!(
            !observations.is_empty() || self.config.observation_bits == 0,
            "percept update missing observation stream"
        );
        let mut percept_syms = Vec::new();
        for &obs in observations {
            encode(&mut percept_syms, obs, self.config.observation_bits);
        }
        crate::aixi::common::encode_reward_offset(
            &mut percept_syms,
            reward,
            self.config.reward_bits,
            self.config.reward_offset,
        );

        for &sym in &percept_syms {
            self.model.update(sym);
        }

        self.total_reward += reward as f64;
    }

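    /// Reduces a raw observation stream to the key representation selected by
    /// `observation_key_mode`.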
    pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
        observation_repr_from_stream(
            self.config.observation_key_mode,
            observations,
            self.config.observation_bits,
        )
    }

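    /// Records a real (non-simulated) action in the model's history.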
    pub fn model_update_action_external(&mut self, action: Action) {
        self.model_update_action(action);
    }
}

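/// The simulator interface the MCTS planner drives: it exposes the percept
/// geometry and planning parameters, and lets rollouts push simulated
/// actions and percepts through the model and revert them afterwards.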
impl AgentSimulator for Agent {
    fn get_num_actions(&self) -> usize {
        self.config.agent_actions
    }

    fn get_num_observation_bits(&self) -> usize {
        self.config.observation_bits
    }

    fn observation_stream_len(&self) -> usize {
        self.config.observation_stream_len.max(1)
    }

    fn observation_key_mode(&self) -> ObservationKeyMode {
        self.config.observation_key_mode
    }

    fn get_num_reward_bits(&self) -> usize {
        self.config.reward_bits
    }

    fn horizon(&self) -> usize {
        self.config.agent_horizon
    }

    fn max_reward(&self) -> Reward {
        self.config.max_reward
    }

    fn min_reward(&self) -> Reward {
        self.config.min_reward
    }

    fn reward_offset(&self) -> i64 {
        self.config.reward_offset
    }

    fn get_explore_exploit_ratio(&self) -> f64 {
        self.config.exploration_exploitation_ratio
    }

    fn discount_gamma(&self) -> f64 {
        self.config.discount_gamma
    }

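    /// Encodes `action` into `action_bits` binary symbols and appends them to
    /// the model's history (actions condition the model but are not learned
    /// as predictions, hence `update_history` rather than `update`).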
    fn model_update_action(&mut self, action: Action) {
        self.sym_buffer.clear();
        encode(&mut self.sym_buffer, action, self.action_bits);

        for &sym in &self.sym_buffer {
            self.model.update_history(sym);
        }
    }

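    /// Samples `bits` symbols from the model's predictive distribution,
    /// updating the model with each symbol as it is drawn, and decodes them
    /// into an unsigned value.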
    fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
        self.sym_buffer.clear();
        for _ in 0..bits {
            let prob_1 = self.model.predict_one();
            let sym = self.rng.gen_bool(prob_1);
            self.model.update(sym);
            self.sym_buffer.push(sym);
        }
        decode(&self.sym_buffer, bits)
    }

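    /// Samples a full percept from the model: the observation stream first,
    /// then the reward, which is shifted back by `reward_offset` into its
    /// signed range.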
    fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
        let obs_bits = self.config.observation_bits;
        let obs_len = self.config.observation_stream_len.max(1);

        self.obs_buffer.clear();
        for _ in 0..obs_len {
            let p = self.gen_percept_and_update(obs_bits);
            self.obs_buffer.push(p);
        }

        let obs_repr = observation_repr_from_stream(
            self.config.observation_key_mode,
            &self.obs_buffer,
            obs_bits,
        );
        let rew_bits = self.config.reward_bits;
        let rew_u = self.gen_percept_and_update(rew_bits);
        let rew = (rew_u as i64) - self.config.reward_offset;

        (obs_repr, rew)
    }

    fn gen_range(&mut self, end: usize) -> usize {
        self.rng.gen_range(end)
    }

    fn gen_f64(&mut self) -> f64 {
        self.rng.gen_f64()
    }

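    /// Rolls the model back `steps` interaction cycles: each cycle reverts
    /// one full percept (observation stream plus reward symbols) and pops the
    /// action symbols from the history.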
    fn model_revert(&mut self, steps: usize) {
        let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
        let percept_bits = obs_bits + self.config.reward_bits;

        for _ in 0..steps {
            for _ in 0..percept_bits {
                self.model.revert();
            }
            for _ in 0..self.action_bits {
                self.model.pop_history();
            }
        }
    }

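    /// Forks the agent for an independent simulation, seeding the clone's RNG
    /// with `seed`.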
    fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
        Box::new(self.clone_for_simulation(seed))
    }
}