1use crate::RateBackend;
8use crate::ctw::{ContextTree, FacContextTree};
9#[cfg(feature = "backend-mamba")]
10use crate::mambazip::{Compressor as MambaCompressor, Model as MambaModel, State as MambaState};
11use crate::mixture::{DEFAULT_MIN_PROB, OnlineBytePredictor, RateBackendPredictor};
12use crate::rosaplus::{RosaPlus, RosaTx};
13#[cfg(feature = "backend-rwkv")]
14use crate::rwkvzip::{Compressor as RwkvCompressor, Model as RwkvModel, State as RwkvState};
15use crate::zpaq_rate::ZpaqRateModel;
16#[cfg(any(feature = "backend-mamba", feature = "backend-rwkv"))]
17use std::sync::Arc;
18
/// A sequential predictor over binary symbols.
///
/// Symbols are fed one at a time, each step can later be undone, and the
/// predictor can be asked for the probability of the next symbol.
pub trait Predictor: Send {
    /// Conditions the model on the observed symbol `sym`.
    fn update(&mut self, sym: bool);

    /// Appends `sym` to the model's history.
    ///
    /// The default is identical to [`Predictor::update`]; implementors override
    /// it when a history-only append should differ from a full update.
    fn update_history(&mut self, sym: bool) {
        Predictor::update(self, sym)
    }

    /// Undoes the most recent [`Predictor::update`] (implementations may not
    /// support this).
    fn revert(&mut self);

    /// Undoes the most recent [`Predictor::update_history`].
    ///
    /// The default is identical to [`Predictor::revert`].
    fn pop_history(&mut self) {
        Predictor::revert(self)
    }

    /// Returns the probability that the next symbol equals `sym`.
    fn predict_prob(&mut self, sym: bool) -> f64;

    /// Shorthand for `predict_prob(true)`.
    fn predict_one(&mut self) -> f64 {
        Predictor::predict_prob(self, true)
    }

    /// Human-readable identifier of the underlying model.
    fn model_name(&self) -> String;

    /// Clones the predictor, including its current state, behind a trait object.
    fn boxed_clone(&self) -> Box<dyn Predictor>;
}
56
/// Bit predictor backed by a single [`ContextTree`] (reported as `AC-CTW`).
pub struct CtwPredictor {
    /// Underlying context tree; every `Predictor` method delegates to it.
    tree: ContextTree,
}
64
65impl CtwPredictor {
66 pub fn new(depth: usize) -> Self {
68 Self {
69 tree: ContextTree::new(depth),
70 }
71 }
72}
73
74impl Predictor for CtwPredictor {
75 fn update(&mut self, sym: bool) {
76 self.tree.update(sym);
77 }
78 fn update_history(&mut self, sym: bool) {
79 self.tree.update_history(&[sym]);
80 }
81
82 fn revert(&mut self) {
83 self.tree.revert();
84 }
85 fn pop_history(&mut self) {
86 self.tree.revert_history();
87 }
88
89 fn predict_prob(&mut self, sym: bool) -> f64 {
90 self.tree.predict(sym)
91 }
92
93 fn model_name(&self) -> String {
94 format!("AC-CTW(d={})", self.tree.depth())
95 }
96
97 fn boxed_clone(&self) -> Box<dyn Predictor> {
98 Box::new(Self {
99 tree: self.tree.clone(),
100 })
101 }
102}
103
/// Factored CTW bit predictor (reported as `FAC-CTW`).
///
/// Wraps a [`FacContextTree`] and tracks which bit position of a
/// `num_bits`-bit percept comes next; that position is passed to every tree
/// update, revert and prediction.
pub struct FacCtwPredictor {
    /// Underlying factored context tree.
    tree: FacContextTree,
    /// Position of the next bit within the current percept, kept in `0..num_bits`.
    current_bit: usize,
    /// Number of bit positions per percept (`num_percept_bits` at construction).
    num_bits: usize,
}
118
119impl FacCtwPredictor {
120 pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
125 Self {
126 tree: FacContextTree::new(base_depth, num_percept_bits),
127 current_bit: 0,
128 num_bits: num_percept_bits,
129 }
130 }
131}
132
133impl Predictor for FacCtwPredictor {
134 fn update(&mut self, sym: bool) {
135 self.tree.update(sym, self.current_bit);
136 self.current_bit = (self.current_bit + 1) % self.num_bits;
137 }
138
139 fn update_history(&mut self, sym: bool) {
140 self.tree.update_history(&[sym]);
141 }
142
143 fn revert(&mut self) {
144 self.current_bit = if self.current_bit == 0 {
146 self.num_bits - 1
147 } else {
148 self.current_bit - 1
149 };
150 self.tree.revert(self.current_bit);
151 }
152
153 fn pop_history(&mut self) {
154 self.tree.revert_history(1);
155 }
156
157 fn predict_prob(&mut self, sym: bool) -> f64 {
158 self.tree.predict(sym, self.current_bit)
159 }
160
161 fn model_name(&self) -> String {
162 format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
163 }
164
165 fn boxed_clone(&self) -> Box<dyn Predictor> {
166 Box::new(Self {
167 tree: self.tree.clone(),
168 current_bit: self.current_bit,
169 num_bits: self.num_bits,
170 })
171 }
172}
173
/// Bit predictor backed by a [`RosaPlus`] model; each bit is trained as the
/// byte `0` or `1`.
pub struct RosaPredictor {
    /// Underlying ROSA model.
    model: RosaPlus,
    /// One transaction per `update`, most recent last; popped and rolled back
    /// by `revert`.
    history: Vec<RosaTx>,
}
182
183impl RosaPredictor {
184 pub fn new(max_order: i64) -> Self {
187 let mut model = RosaPlus::new(max_order, false, 0, 42);
189 model.build_lm_full_bytes_no_finalize_endpos();
191 Self {
192 model,
193 history: Vec::new(),
194 }
195 }
196}
197
198impl Predictor for RosaPredictor {
199 fn update(&mut self, sym: bool) {
200 let mut tx = self.model.begin_tx();
201 let byte = if sym { 1u8 } else { 0u8 };
203
204 self.model.train_sequence_tx(&mut tx, &[byte]);
206 self.history.push(tx);
207 }
208
209 fn revert(&mut self) {
210 if let Some(tx) = self.history.pop() {
211 self.model.rollback_tx(tx);
212 }
213 }
214
215 fn predict_prob(&mut self, sym: bool) -> f64 {
216 let p0 = self.model.prob_for_last(0);
217 let p1 = self.model.prob_for_last(1);
218 let denom = (p0 + p1).max(1e-12);
219 if sym { p1 / denom } else { p0 / denom }
220 }
221
222 fn model_name(&self) -> String {
223 "ROSA".to_string()
224 }
225
226 fn boxed_clone(&self) -> Box<dyn Predictor> {
227 Box::new(Self {
228 model: self.model.clone(),
229 history: self.history.clone(),
230 })
231 }
232}
233
/// Bit predictor backed by a [`ZpaqRateModel`]; each bit is fed as the byte
/// `0` or `1`.
///
/// Rollback works by resetting the model and replaying the recorded history.
pub struct ZpaqPredictor {
    /// ZPAQ method string, reused whenever a fresh model has to be built.
    method: String,
    /// Forwarded to `ZpaqRateModel::new` (presumably a probability floor).
    min_prob: f64,
    /// Live model; `update` / `revert` keep it in sync with `history`.
    model: ZpaqRateModel,
    /// Every byte committed through `update`, in order; replayed on rebuild.
    history: Vec<u8>,
    /// `(byte, log-probability)` from the last `predict_prob` that queried
    /// `model` directly; cleared when the next bit is committed or reverted.
    pending: Option<(u8, f64)>,
}
245
246impl ZpaqPredictor {
247 pub fn new(method: String, min_prob: f64) -> Self {
249 let model = ZpaqRateModel::new(method.clone(), min_prob);
250 Self {
251 method,
252 min_prob,
253 model,
254 history: Vec::new(),
255 pending: None,
256 }
257 }
258
259 fn rebuild_from_history(&mut self) {
260 self.model.reset();
261 if !self.history.is_empty() {
262 self.model.update_and_score(&self.history);
263 }
264 }
265
266 fn log_prob_from_history(&self, symbol: u8) -> f64 {
267 let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
268 if !self.history.is_empty() {
269 tmp.update_and_score(&self.history);
270 }
271 tmp.log_prob(symbol)
272 }
273}
274
275impl Predictor for ZpaqPredictor {
276 fn update(&mut self, sym: bool) {
277 let byte = if sym { 1u8 } else { 0u8 };
278 if let Some((pending, _)) = self.pending {
279 if pending == byte {
280 self.model.update(byte);
281 self.pending = None;
282 self.history.push(byte);
283 return;
284 }
285 self.pending = None;
286 self.rebuild_from_history();
287 }
288 self.model.update(byte);
289 self.history.push(byte);
290 }
291
292 fn revert(&mut self) {
293 if self.history.pop().is_some() {
294 self.pending = None;
295 self.rebuild_from_history();
296 }
297 }
298
299 fn predict_prob(&mut self, sym: bool) -> f64 {
300 let byte = if sym { 1u8 } else { 0u8 };
301 if let Some((pending, logp)) = self.pending {
302 if pending == byte {
303 return logp.exp();
304 }
305 return self.log_prob_from_history(byte).exp();
306 }
307 let logp = self.model.log_prob(byte);
308 self.pending = Some((byte, logp));
309 logp.exp()
310 }
311
312 fn model_name(&self) -> String {
313 format!("ZPAQ({})", self.method)
314 }
315
316 fn boxed_clone(&self) -> Box<dyn Predictor> {
317 Box::new(Self {
318 method: self.method.clone(),
319 min_prob: self.min_prob,
320 model: self.model.clone(),
321 history: self.history.clone(),
322 pending: self.pending,
323 })
324 }
325}
326
/// Bit predictor that feeds each bit, as the byte `0` or `1`, to a byte-level
/// [`RateBackendPredictor`].
///
/// ZPAQ-based backends are rejected at construction. Generic rollback
/// (`revert` / `pop_history`) is not supported; callers clone the predictor
/// instead.
pub struct RateBackendBitPredictor {
    /// Backend description, kept for `model_name` and for cloning.
    backend: RateBackend,
    /// Forwarded to `RateBackendPredictor::from_backend` and `default_name`.
    max_order: i64,
    /// Predictions are clamped to `[min_prob, 1 - min_prob]`.
    min_prob: f64,
    /// Streaming byte predictor that does the actual modelling.
    predictor: RateBackendPredictor,
}
338
339impl RateBackendBitPredictor {
340 pub fn new(backend: RateBackend, max_order: i64) -> Result<Self, String> {
342 Self::new_with_min_prob(backend, max_order, DEFAULT_MIN_PROB)
343 }
344
345 pub fn new_with_min_prob(
347 backend: RateBackend,
348 max_order: i64,
349 min_prob: f64,
350 ) -> Result<Self, String> {
351 if rate_backend_contains_zpaq(&backend) {
352 return Err(
353 "RateBackendBitPredictor does not support zpaq backends; use a non-zpaq rate_backend"
354 .to_string(),
355 );
356 }
357 let mut predictor =
358 RateBackendPredictor::from_backend(backend.clone(), max_order, min_prob);
359 predictor
360 .begin_stream(None)
361 .map_err(|err| format!("failed to start RateBackend predictor stream: {err}"))?;
362 Ok(Self {
363 backend,
364 max_order,
365 min_prob,
366 predictor,
367 })
368 }
369
370 #[inline(always)]
371 fn bit_to_byte(sym: bool) -> u8 {
372 if sym { 1u8 } else { 0u8 }
373 }
374
375 fn clone_state(&self) -> Self {
376 Self {
377 backend: self.backend.clone(),
378 max_order: self.max_order,
379 min_prob: self.min_prob,
380 predictor: self.predictor.clone(),
381 }
382 }
383}
384
385fn rate_backend_contains_zpaq(backend: &RateBackend) -> bool {
386 match backend {
387 RateBackend::Zpaq { .. } => true,
388 RateBackend::Mixture { spec } => spec
389 .experts
390 .iter()
391 .any(|expert| rate_backend_contains_zpaq(&expert.backend)),
392 RateBackend::Calibrated { spec } => rate_backend_contains_zpaq(&spec.base),
393 _ => false,
394 }
395}
396
397impl Predictor for RateBackendBitPredictor {
398 fn update(&mut self, sym: bool) {
399 self.predictor.update(Self::bit_to_byte(sym));
400 }
401
402 fn update_history(&mut self, sym: bool) {
403 self.predictor.update_frozen(Self::bit_to_byte(sym));
404 }
405
406 fn revert(&mut self) {
407 panic!(
408 "RateBackendBitPredictor does not support generic rollback; callers must use cloned temporary predictors"
409 );
410 }
411
412 fn pop_history(&mut self) {
413 panic!(
414 "RateBackendBitPredictor does not support generic rollback; callers must use cloned temporary predictors"
415 );
416 }
417
418 fn predict_prob(&mut self, sym: bool) -> f64 {
419 let p = self.predictor.log_prob(Self::bit_to_byte(sym)).exp();
420 if p.is_finite() {
421 p.clamp(self.min_prob, 1.0 - self.min_prob)
422 } else {
423 0.5
424 }
425 }
426
427 fn model_name(&self) -> String {
428 format!(
429 "RateBackendBits({})",
430 RateBackendPredictor::default_name(&self.backend, self.max_order)
431 )
432 }
433
434 fn boxed_clone(&self) -> Box<dyn Predictor> {
435 Box::new(self.clone_state())
436 }
437}
438
439#[cfg(feature = "backend-rwkv")]
440use crate::coders::softmax_pdf_floor_inplace;
441
#[cfg(feature = "backend-rwkv")]
/// Bit predictor driven by an RWKV model: each bit is fed as token `0` or `1`,
/// and predictions are read from the compressor's `pdf_buffer`.
pub struct RwkvPredictor {
    /// RWKV compressor holding the model, its state, scratch space and current pdf.
    compressor: RwkvCompressor,
    /// `(state, pdf_buffer)` snapshot taken before each `update`, popped by `revert`.
    history: Vec<(RwkvState, Vec<f64>)>,
}
451
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Wraps `model` in a fresh compressor and primes it by running token `0`
    /// through the network, so `pdf_buffer` holds a prediction before the
    /// first bit arrives.
    pub fn new(model: Arc<RwkvModel>) -> Self {
        let mut compressor = RwkvCompressor::new_from_model(model);
        // Read the vocabulary size before `forward` borrows the compressor's fields.
        let vocab = compressor.vocab_size();
        let prime_logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state);
        softmax_pdf_floor_inplace(prime_logits, vocab, &mut compressor.pdf_buffer);
        Self {
            compressor,
            history: Vec::new(),
        }
    }

    /// Builds the compressor from a method string and primes it with token `0`.
    ///
    /// # Errors
    ///
    /// Returns the stringified error from `RwkvCompressor::new_from_method`.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor = match RwkvCompressor::new_from_method(method) {
            Ok(built) => built,
            Err(err) => return Err(err.to_string()),
        };
        compressor.forward_to_internal_pdf(0);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
480
#[cfg(feature = "backend-rwkv")]
impl Predictor for RwkvPredictor {
    /// Snapshots `(state, pdf_buffer)` for `revert`, feeds `sym` to the model as
    /// token `0`/`1`, and refreshes `pdf_buffer` from the new logits.
    fn update(&mut self, sym: bool) {
        let snapshot = (
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        );
        self.history.push(snapshot);

        let token = u32::from(sym);
        let vocab = self.compressor.vocab_size();
        let c = &mut self.compressor;
        let logits = c.model.forward(&mut c.scratch, token, &mut c.state);
        softmax_pdf_floor_inplace(logits, vocab, &mut c.pdf_buffer);
    }

    /// Restores the snapshot taken by the most recent `update`, if any.
    fn revert(&mut self) {
        match self.history.pop() {
            Some((state, pdf)) => {
                self.compressor.state = state;
                self.compressor.pdf_buffer = pdf;
            }
            None => {}
        }
    }

    /// Reads the pdf entry for token `0`/`1`.
    ///
    /// NOTE(review): `pdf_buffer` is produced over the whole vocabulary, so
    /// `predict_prob(true) + predict_prob(false)` may be below 1; unlike
    /// `RosaPredictor`, no renormalisation over `{0, 1}` happens here. Confirm
    /// callers expect this.
    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.compressor.pdf_buffer[usize::from(sym)]
    }

    /// Always `RWKV`.
    fn model_name(&self) -> String {
        String::from("RWKV")
    }

    /// Copies the compressor and the rollback snapshots.
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        let copy = Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        };
        Box::new(copy)
    }
}
525
#[cfg(feature = "backend-mamba")]
/// Bit predictor driven by a Mamba model: each bit is fed as token `0` or `1`,
/// and predictions are read from the compressor's `pdf_buffer`.
pub struct MambaPredictor {
    /// Mamba compressor holding the model, its state, scratch space and current pdf.
    compressor: MambaCompressor,
    /// `(state, pdf_buffer)` snapshot taken before each `update`, popped by `revert`.
    history: Vec<(MambaState, Vec<f64>)>,
}
532
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Wraps `model` in a fresh compressor and primes it by running token `0`
    /// through the network. The resulting logits, combined with the
    /// compressor's online-bias snapshot (if any), fill `pdf_buffer`.
    pub fn new(model: Arc<MambaModel>) -> Self {
        let mut compressor = MambaCompressor::new_from_model(model);
        // Copy the logits out: the slice returned by `forward` borrows the
        // compressor, which would conflict with `online_bias_snapshot` below.
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state)
            .to_vec();
        let bias = compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut compressor.pdf_buffer);

        Self {
            compressor,
            history: Vec::new(),
        }
    }

    /// Builds the compressor from a method string and primes it with token `0`.
    ///
    /// # Errors
    ///
    /// Returns the stringified error from `MambaCompressor::new_from_method`.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            MambaCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        // `forward_to_pdf` borrows the compressor mutably, so it cannot write
        // into `compressor.pdf_buffer` directly. Fill a scratch vector and move it
        // in; moving avoids a second vocabulary-sized copy.
        let mut pdf = vec![0.0f64; compressor.vocab_size()];
        compressor.forward_to_pdf(0, &mut pdf);
        compressor.pdf_buffer = pdf;
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
564
#[cfg(feature = "backend-mamba")]
impl Predictor for MambaPredictor {
    /// Snapshots `(state, pdf_buffer)` for `revert` and feeds `sym` to the model
    /// as token `0`/`1`. It then rebuilds `pdf_buffer` from the new logits plus
    /// the compressor's current online-bias snapshot (if any).
    fn update(&mut self, sym: bool) {
        let snapshot = (
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        );
        self.history.push(snapshot);

        let token = u32::from(sym);
        let c = &mut self.compressor;
        // Copy the logits out: the slice returned by `forward` borrows the
        // compressor, which would conflict with `online_bias_snapshot` below.
        let logits = c.model.forward(&mut c.scratch, token, &mut c.state).to_vec();
        let bias = c.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut c.pdf_buffer);
    }

    /// Restores the snapshot taken by the most recent `update`, if any.
    fn revert(&mut self) {
        match self.history.pop() {
            Some((state, pdf)) => {
                self.compressor.state = state;
                self.compressor.pdf_buffer = pdf;
            }
            None => {}
        }
    }

    /// Reads the pdf entry for token `0`/`1`.
    ///
    /// NOTE(review): `pdf_buffer` is produced over the whole vocabulary, so the
    /// two binary probabilities may not sum to 1; confirm callers expect this.
    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.compressor.pdf_buffer[usize::from(sym)]
    }

    /// Always `Mamba`.
    fn model_name(&self) -> String {
        String::from("Mamba")
    }

    /// Copies the compressor and the rollback snapshots.
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        let copy = Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        };
        Box::new(copy)
    }
}