1use crate::ctw::{ContextTree, FacContextTree};
8use crate::zpaq_rate::ZpaqRateModel;
9use rosaplus::{RosaPlus, RosaTx};
10use rwkvzip::{Compressor, Model, State};
11use std::sync::Arc;
12
/// Sequential predictor over a binary alphabet that can undo its updates.
///
/// Implementors must be `Send + Sync` so predictors can be moved to and shared
/// between threads; cloning goes through [`Predictor::boxed_clone`] so the trait
/// stays object-safe.
pub trait Predictor: Send + Sync {
    /// Feeds `sym` to the model as the next observed symbol.
    fn update(&mut self, sym: bool);

    /// Appends `sym` to the conditioning history.
    ///
    /// The default simply forwards to [`Predictor::update`]; implementations may
    /// override it with a separate history-recording path.
    fn update_history(&mut self, sym: bool) {
        self.update(sym)
    }

    /// Undoes the most recent [`Predictor::update`].
    fn revert(&mut self);

    /// Undoes the most recent [`Predictor::update_history`].
    ///
    /// The default forwards to [`Predictor::revert`], mirroring the default of
    /// `update_history`.
    fn pop_history(&mut self) {
        self.revert()
    }

    /// Returns the probability the model currently assigns to `sym` as the next
    /// symbol. Takes `&mut self` so implementations may cache per-query work.
    fn predict_prob(&mut self, sym: bool) -> f64;

    /// Convenience for `predict_prob(true)`.
    fn predict_one(&mut self) -> f64 {
        self.predict_prob(true)
    }

    /// Human-readable description of the model and its parameters.
    fn model_name(&self) -> String;

    /// Returns an independent copy of this predictor as a boxed trait object.
    fn boxed_clone(&self) -> Box<dyn Predictor>;
}
49
50pub struct CtwPredictor {
55 tree: ContextTree,
56}
57
58impl CtwPredictor {
59 pub fn new(depth: usize) -> Self {
61 Self {
62 tree: ContextTree::new(depth),
63 }
64 }
65}
66
67impl Predictor for CtwPredictor {
68 fn update(&mut self, sym: bool) {
69 self.tree.update(sym);
70 }
71 fn update_history(&mut self, sym: bool) {
72 self.tree.update_history(&[sym]);
73 }
74
75 fn revert(&mut self) {
76 self.tree.revert();
77 }
78 fn pop_history(&mut self) {
79 self.tree.revert_history();
80 }
81
82 fn predict_prob(&mut self, sym: bool) -> f64 {
83 self.tree.predict(sym)
84 }
85
86 fn model_name(&self) -> String {
87 format!("AC-CTW(d={})", self.tree.depth())
88 }
89
90 fn boxed_clone(&self) -> Box<dyn Predictor> {
91 Box::new(Self {
92 tree: self.tree.clone(),
93 })
94 }
95}
96
97pub struct FacCtwPredictor {
105 tree: FacContextTree,
106 current_bit: usize,
108 num_bits: usize,
110}
111
112impl FacCtwPredictor {
113 pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
118 Self {
119 tree: FacContextTree::new(base_depth, num_percept_bits),
120 current_bit: 0,
121 num_bits: num_percept_bits,
122 }
123 }
124}
125
126impl Predictor for FacCtwPredictor {
127 fn update(&mut self, sym: bool) {
128 self.tree.update(sym, self.current_bit);
129 self.current_bit = (self.current_bit + 1) % self.num_bits;
130 }
131
132 fn update_history(&mut self, sym: bool) {
133 self.tree.update_history(&[sym]);
134 }
135
136 fn revert(&mut self) {
137 self.current_bit = if self.current_bit == 0 {
139 self.num_bits - 1
140 } else {
141 self.current_bit - 1
142 };
143 self.tree.revert(self.current_bit);
144 }
145
146 fn pop_history(&mut self) {
147 self.tree.revert_history(1);
148 }
149
150 fn predict_prob(&mut self, sym: bool) -> f64 {
151 self.tree.predict(sym, self.current_bit)
152 }
153
154 fn model_name(&self) -> String {
155 format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
156 }
157
158 fn boxed_clone(&self) -> Box<dyn Predictor> {
159 Box::new(Self {
160 tree: self.tree.clone(),
161 current_bit: self.current_bit,
162 num_bits: self.num_bits,
163 })
164 }
165}
166
167pub struct RosaPredictor {
172 model: RosaPlus,
173 history: Vec<RosaTx>,
174}
175
176impl RosaPredictor {
177 pub fn new(max_order: i64) -> Self {
180 let mut model = RosaPlus::new(max_order, false, 0, 42);
182 model.build_lm_full_bytes_no_finalize_endpos();
184 Self {
185 model,
186 history: Vec::new(),
187 }
188 }
189}
190
191impl Predictor for RosaPredictor {
192 fn update(&mut self, sym: bool) {
193 let mut tx = self.model.begin_tx();
194 let byte = if sym { 1u8 } else { 0u8 };
196
197 self.model.train_sequence_tx(&mut tx, &[byte]);
199 self.history.push(tx);
200 }
201
202 fn revert(&mut self) {
203 if let Some(tx) = self.history.pop() {
204 self.model.rollback_tx(tx);
205 }
206 }
207
208 fn predict_prob(&mut self, sym: bool) -> f64 {
209 let p0 = self.model.prob_for_last(0);
210 let p1 = self.model.prob_for_last(1);
211 let denom = (p0 + p1).max(1e-12);
212 if sym { p1 / denom } else { p0 / denom }
213 }
214
215 fn model_name(&self) -> String {
216 "ROSA".to_string()
217 }
218
219 fn boxed_clone(&self) -> Box<dyn Predictor> {
220 Box::new(Self {
221 model: self.model.clone(),
222 history: self.history.clone(),
223 })
224 }
225}
226
227pub struct ZpaqPredictor {
232 method: String,
233 min_prob: f64,
234 model: ZpaqRateModel,
235 history: Vec<u8>,
236 pending: Option<(u8, f64)>,
237}
238
239unsafe impl Sync for ZpaqPredictor {}
240
241impl ZpaqPredictor {
242 pub fn new(method: String, min_prob: f64) -> Self {
243 let model = ZpaqRateModel::new(method.clone(), min_prob);
244 Self {
245 method,
246 min_prob,
247 model,
248 history: Vec::new(),
249 pending: None,
250 }
251 }
252
253 fn rebuild_from_history(&mut self) {
254 self.model.reset();
255 if !self.history.is_empty() {
256 self.model.update_and_score(&self.history);
257 }
258 }
259
260 fn log_prob_from_history(&self, symbol: u8) -> f64 {
261 let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
262 if !self.history.is_empty() {
263 tmp.update_and_score(&self.history);
264 }
265 tmp.log_prob(symbol)
266 }
267}
268
269impl Predictor for ZpaqPredictor {
270 fn update(&mut self, sym: bool) {
271 let byte = if sym { 1u8 } else { 0u8 };
272 if let Some((pending, _)) = self.pending {
273 if pending == byte {
274 self.model.update(byte);
275 self.pending = None;
276 self.history.push(byte);
277 return;
278 }
279 self.pending = None;
280 self.rebuild_from_history();
281 }
282 self.model.update(byte);
283 self.history.push(byte);
284 }
285
286 fn revert(&mut self) {
287 if self.history.pop().is_some() {
288 self.pending = None;
289 self.rebuild_from_history();
290 }
291 }
292
293 fn predict_prob(&mut self, sym: bool) -> f64 {
294 let byte = if sym { 1u8 } else { 0u8 };
295 if let Some((pending, logp)) = self.pending {
296 if pending == byte {
297 return logp.exp();
298 }
299 return self.log_prob_from_history(byte).exp();
300 }
301 let logp = self.model.log_prob(byte);
302 self.pending = Some((byte, logp));
303 logp.exp()
304 }
305
306 fn model_name(&self) -> String {
307 format!("ZPAQ({})", self.method)
308 }
309
310 fn boxed_clone(&self) -> Box<dyn Predictor> {
311 let mut model = ZpaqRateModel::new(self.method.clone(), self.min_prob);
312 if !self.history.is_empty() {
313 model.update_and_score(&self.history);
314 }
315 Box::new(Self {
316 method: self.method.clone(),
317 min_prob: self.min_prob,
318 model,
319 history: self.history.clone(),
320 pending: None,
321 })
322 }
323}
324
325use rwkvzip::coders::softmax_pdf_floor_inplace;
326
327pub struct RwkvPredictor {
332 compressor: Compressor,
333 history: Vec<(State, Vec<f64>)>,
334}
335
336impl RwkvPredictor {
337 pub fn new(model: Arc<Model>) -> Self {
339 let mut compressor = Compressor::new_from_model(model);
340 let vocab_size = compressor.vocab_size();
341 let logits = compressor
342 .model
343 .forward(&mut compressor.scratch, 0, &mut compressor.state);
344 softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
345
346 Self {
347 compressor,
348 history: Vec::new(),
349 }
350 }
351}
352
353impl Predictor for RwkvPredictor {
354 fn update(&mut self, sym: bool) {
355 self.history.push((
357 self.compressor.state.clone(),
358 self.compressor.pdf_buffer.clone(),
359 ));
360
361 let byte = if sym { 1u32 } else { 0u32 };
362 let vocab_size = self.compressor.vocab_size();
363
364 let logits = self.compressor.model.forward(
365 &mut self.compressor.scratch,
366 byte,
367 &mut self.compressor.state,
368 );
369 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.compressor.pdf_buffer);
370 }
371
372 fn revert(&mut self) {
373 if let Some((state, pdf)) = self.history.pop() {
374 self.compressor.state = state;
375 self.compressor.pdf_buffer = pdf;
376 }
377 }
378
379 fn predict_prob(&mut self, sym: bool) -> f64 {
380 let idx = if sym { 1 } else { 0 };
381 self.compressor.pdf_buffer[idx]
383 }
384
385 fn model_name(&self) -> String {
386 "RWKV".to_string()
387 }
388
389 fn boxed_clone(&self) -> Box<dyn Predictor> {
390 Box::new(Self {
391 compressor: self.compressor.clone(),
392 history: self.history.clone(),
393 })
394 }
395}