1use anyhow::{Result, bail};
9
10use crate::backends::calibration::CalibratorCore;
11use crate::backends::match_model::MatchModel;
12use crate::backends::ppmd::PpmdModel;
13use crate::backends::sequitur::SequiturModel;
14use crate::backends::sparse_match::SparseMatchModel;
15use crate::backends::text_context::TextContextAnalyzer;
16use crate::coders::{
17 ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
18 CDF_TOTAL, Cdf, CoderType, crc32, quantize_pdf_to_rans_cdf_with_buffer,
19};
20use crate::ctw::FacContextTree;
21#[cfg(feature = "backend-mamba")]
22use crate::mambazip;
23use crate::mixture::{
24 DEFAULT_MIN_PROB, convex_step_size_for_update, project_simplex_with_scratch,
25 switching_alpha_for_update,
26};
27use crate::neural_mix::NeuralMixCore;
28use crate::rosaplus::RosaPlus;
29#[cfg(feature = "backend-rwkv")]
30use crate::rwkvzip;
31use crate::zpaq_rate::ZpaqRateModel;
32use crate::{CalibratedSpec, MixtureKind, MixtureScheduleMode, MixtureSpec, RateBackend};
33use rayon::{ThreadPool, prelude::*};
34
35const FRAMED_MAGIC: u32 = 0x4354_4946; const FRAMED_VERSION: u8 = 1;
37const PDF_MIN: f64 = DEFAULT_MIN_PROB;
38const DIAGNOSTIC_PARALLEL_THRESHOLD: usize = 4;
39
40#[inline]
41fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
42 CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
43}
44
/// Controls whether compressed output carries a `FramedHeader`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FramingMode {
    /// Emit only the coder payload, with no header or integrity check.
    Raw,
    /// Prepend a `FramedHeader` (magic, version, coder tag, original length,
    /// CRC32) to the coder payload.
    #[default]
    Framed,
}
54
/// Fixed-size header prepended to framed payloads; see `FramedHeader::SIZE`.
#[derive(Clone, Copy, Debug)]
struct FramedHeader {
    // Always `FRAMED_MAGIC` on the write side; validated on read.
    magic: u32,
    // Always `FRAMED_VERSION` on the write side; validated on read.
    version: u8,
    // Entropy-coder tag: 0 = arithmetic coding, 1 = rANS.
    coder: u8,
    // Uncompressed payload length in bytes.
    original_len: u64,
    // CRC32 of the uncompressed payload.
    crc32: u32,
}
63
64impl FramedHeader {
65 const SIZE: usize = 4 + 1 + 1 + 8 + 4;
66
67 fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
68 Self {
69 magic: FRAMED_MAGIC,
70 version: FRAMED_VERSION,
71 coder: match coder {
72 CoderType::AC => 0,
73 CoderType::RANS => 1,
74 },
75 original_len,
76 crc32,
77 }
78 }
79
80 fn write(&self, out: &mut Vec<u8>) {
81 out.extend_from_slice(&self.magic.to_le_bytes());
82 out.push(self.version);
83 out.push(self.coder);
84 out.extend_from_slice(&self.original_len.to_le_bytes());
85 out.extend_from_slice(&self.crc32.to_le_bytes());
86 }
87
88 fn read(input: &[u8]) -> Result<Self> {
89 if input.len() < Self::SIZE {
90 bail!("framed payload too short");
91 }
92 let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
93 if magic != FRAMED_MAGIC {
94 bail!("invalid framed magic: expected 0x{FRAMED_MAGIC:08X}, got 0x{magic:08X}");
95 }
96 let version = input[4];
97 if version != FRAMED_VERSION {
98 bail!("unsupported framed version: {version}");
99 }
100 let coder = input[5];
101 let original_len = u64::from_le_bytes([
102 input[6], input[7], input[8], input[9], input[10], input[11], input[12], input[13],
103 ]);
104 let crc32 = u32::from_le_bytes([input[14], input[15], input[16], input[17]]);
105 Ok(Self {
106 magic,
107 version,
108 coder,
109 original_len,
110 crc32,
111 })
112 }
113
114 fn coder_type(&self) -> CoderType {
115 match self.coder {
116 0 => CoderType::AC,
117 _ => CoderType::RANS,
118 }
119 }
120}
121
/// Byte-level predictor backed by a (FAC) context-tree-weighting model.
#[derive(Clone)]
struct CtwPredictor {
    // Underlying bitwise context tree.
    tree: FacContextTree,
    // Bits consumed per symbol: 8 for plain CTW, configurable for FAC.
    bits_per_symbol: usize,
    // Bit order within a byte: true = MSB-first (CTW), false = LSB-first (FAC).
    msb_first: bool,
    // Cached 256-entry predictive distribution; current iff `valid`.
    pdf: Vec<f64>,
    // Scratch for per-pattern log probabilities, filled by
    // `fill_pattern_log_probs`.
    pattern_logps: Vec<f64>,
    // Whether `pdf` reflects the current tree state.
    valid: bool,
}
131
impl CtwPredictor {
    /// Classic byte CTW: 8 bits per symbol, MSB-first bit order.
    fn new_ctw(depth: usize) -> Self {
        Self {
            tree: FacContextTree::new(depth, 8),
            bits_per_symbol: 8,
            msb_first: true,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// FAC variant: configurable bits per symbol, LSB-first bit order.
    fn new_fac(base_depth: usize, bits_per_symbol: usize) -> Self {
        Self {
            tree: FacContextTree::new(base_depth, bits_per_symbol),
            bits_per_symbol,
            msb_first: false,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// Fills `pattern_logps[p]` with the conditional log probability of each
    /// `bits_per_symbol`-bit pattern `p` under the current tree state, and
    /// returns the number of patterns (2^bits).
    fn fill_pattern_log_probs(&mut self) -> usize {
        // Depth-first walk over all bit patterns: tentatively update the tree
        // with each bit, record the log probability delta at the leaves, then
        // revert so the tree state is unchanged on exit.
        fn rec(
            tree: &mut FacContextTree,
            bits: usize,
            msb_first: bool,
            depth: usize,
            pattern: usize,
            log_before: f64,
            out: &mut [f64],
        ) {
            if depth == bits {
                // Leaf: conditional log prob of this pattern = total log prob
                // after the tentative updates minus the baseline.
                out[pattern] = tree.get_log_block_probability() - log_before;
                return;
            }
            for bit in [false, true] {
                tree.update(bit, depth);
                // Pattern index accumulates in the same bit order the tree
                // consumes: shift-left for MSB-first, OR-into-position for
                // LSB-first.
                let next_pattern = if msb_first {
                    (pattern << 1) | (bit as usize)
                } else {
                    pattern | ((bit as usize) << depth)
                };
                rec(
                    tree,
                    bits,
                    msb_first,
                    depth + 1,
                    next_pattern,
                    log_before,
                    out,
                );
                // Undo the tentative update before trying the sibling bit.
                tree.revert(depth);
            }
        }

        let bits = self.bits_per_symbol.clamp(1, 8);
        let patterns = 1usize << bits;
        let log_before = self.tree.get_log_block_probability();
        self.pattern_logps[..patterns].fill(f64::NEG_INFINITY);
        rec(
            &mut self.tree,
            bits,
            self.msb_first,
            0,
            0,
            log_before,
            &mut self.pattern_logps[..patterns],
        );
        patterns
    }

    /// Test-only reference: log probability of one symbol computed by
    /// updating the tree bit-by-bit and reverting, without the batched
    /// pattern walk above. Used to cross-check `fill_pattern_log_probs`.
    #[cfg(test)]
    fn log_prob_symbol_bruteforce(&mut self, symbol: u8) -> f64 {
        let bits = self.bits_per_symbol.clamp(1, 8);
        let before = self.tree.get_log_block_probability();
        if self.msb_first {
            for bit_idx in 0..bits {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            // Revert in reverse order so each revert undoes the matching update.
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        } else {
            for bit_idx in 0..bits {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        }
    }

    /// Clamps entries to at least `PDF_MIN`, zeroes non-finite values, and
    /// rescales to sum to 1; falls back to uniform if the sum is degenerate.
    fn normalize_pdf(pdf: &mut [f64]) {
        let mut sum = 0.0f64;
        for p in pdf.iter_mut() {
            let v = if p.is_finite() { *p } else { 0.0 };
            *p = v.max(PDF_MIN);
            sum += *p;
        }
        if sum <= 0.0 || !sum.is_finite() {
            // Degenerate mass: fall back to the uniform distribution.
            let u = 1.0 / (pdf.len() as f64);
            for p in pdf.iter_mut() {
                *p = u;
            }
            return;
        }
        let inv = 1.0 / sum;
        for p in pdf.iter_mut() {
            *p *= inv;
        }
    }

    /// Returns the 256-entry predictive distribution for the next byte,
    /// recomputing it lazily from the tree if stale.
    fn pdf_next(&mut self) -> &[f64] {
        if !self.valid {
            let bits = self.bits_per_symbol.clamp(1, 8);
            let patterns = self.fill_pattern_log_probs();
            if bits == 8 {
                // One pattern per byte: direct exp of the pattern log probs.
                for sym in 0..256usize {
                    self.pdf[sym] = self.pattern_logps[sym].exp();
                }
            } else {
                // Fewer than 8 modeled bits: each pattern covers 2^(8-bits)
                // byte values ("aliases"); split its mass uniformly among them.
                let aliases = 1usize << (8 - bits);
                for byte in 0..256usize {
                    let pat = if self.msb_first {
                        byte >> (8 - bits)
                    } else {
                        byte & (patterns - 1)
                    };
                    self.pdf[byte] = self.pattern_logps[pat].exp() / (aliases as f64);
                }
            }
            Self::normalize_pdf(&mut self.pdf);
            self.valid = true;
        }
        &self.pdf
    }

    /// Commits an observed byte to the tree (bit order per `msb_first`) and
    /// invalidates the cached PDF.
    fn update(&mut self, symbol: u8) {
        if self.msb_first {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        } else {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        }
        self.valid = false;
    }

    /// True when the fast bitwise arithmetic-coding path applies: full-byte,
    /// MSB-first configuration (i.e. `new_ctw`-style predictors).
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        self.bits_per_symbol == 8 && self.msb_first
    }

    /// P(bit == 1) at `bit_idx`, clamped away from 0/1. Fast path only.
    #[inline]
    fn bit_prob_one_msb(&mut self, bit_idx: usize) -> f64 {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.predict_one(bit_idx).clamp(PDF_MIN, 1.0 - PDF_MIN)
    }

    /// Commits one predicted bit on the fast path and invalidates the PDF.
    #[inline]
    fn update_bit_msb(&mut self, bit_idx: usize, bit: bool) {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.update_predicted(bit, bit_idx);
        self.valid = false;
    }
}
310
/// Predictor wrapping the `RosaPlus` backend with lazily cached PDF/CDF rows.
#[derive(Clone)]
struct RosaPredictor {
    // Underlying RosaPlus model.
    model: RosaPlus,
    // Cached 256-entry predictive distribution; current iff `valid`.
    pdf: Vec<f64>,
    // Cached cumulative distribution; current iff `cdf_valid`.
    cdf: [f64; 257],
    valid: bool,
    cdf_valid: bool,
}
319
320impl RosaPredictor {
321 fn new(max_order: i64) -> Self {
322 let mut model = RosaPlus::new(max_order, false, 0, 42);
323 model.build_lm_full_bytes_no_finalize_endpos();
324 Self {
325 model,
326 pdf: vec![0.0; 256],
327 cdf: uniform_cdf_row(),
328 valid: false,
329 cdf_valid: false,
330 }
331 }
332
333 fn pdf_next(&mut self) -> &[f64] {
334 self.ensure_pdf(false);
335 &self.pdf
336 }
337
338 fn cdf_next(&mut self) -> &[f64; 257] {
339 self.ensure_pdf(true);
340 &self.cdf
341 }
342
343 fn ensure_pdf(&mut self, want_cdf: bool) {
344 if self.valid {
345 if want_cdf && !self.cdf_valid {
346 build_cdf_row_from_pdf_slice(&self.pdf, &mut self.cdf);
347 self.cdf_valid = true;
348 }
349 return;
350 }
351 self.model.fill_probs_for_last_bytes(&mut self.pdf);
352 normalize_pdf_vec_and_maybe_build_cdf(
353 &mut self.pdf,
354 if want_cdf { Some(&mut self.cdf) } else { None },
355 );
356 self.valid = true;
357 self.cdf_valid = want_cdf;
358 }
359
360 fn update(&mut self, symbol: u8) {
361 self.model.train_byte(symbol);
362 self.valid = false;
363 self.cdf_valid = false;
364 }
365
366 fn begin_stream(&mut self, total_len: usize) {
367 self.model.reserve_for_stream(total_len);
368 }
369}
370
/// Mamba-backed neural predictor (only with the `backend-mamba` feature).
#[derive(Clone)]
#[cfg(feature = "backend-mamba")]
struct MambaPredictor {
    compressor: mambazip::Compressor,
    // True once the model has produced its first PDF (conditioned on token 0).
    primed: bool,
    // Vocab-sized PDF; the first 256 entries cover byte symbols.
    pdf: Vec<f64>,
    // Byte-symbol CDF derived from `pdf[..256]`; current iff `cdf_valid`.
    cdf: [f64; 257],
    valid: bool,
    cdf_valid: bool,
}
381
/// RWKV-backed neural predictor (only with the `backend-rwkv` feature).
#[derive(Clone)]
#[cfg(feature = "backend-rwkv")]
struct RwkvPredictor {
    // The compressor owns the live PDF buffer; no separate pdf cache here.
    compressor: rwkvzip::Compressor,
    // True once the model state has been reset and primed.
    primed: bool,
    // Byte-symbol CDF derived from the compressor's PDF; current iff `cdf_valid`.
    cdf: [f64; 257],
    cdf_valid: bool,
}
390
/// Rate predictor that scores candidate bytes by replaying its history
/// through a fresh ZPAQ rate model.
#[derive(Clone)]
struct ZpaqPredictor {
    // ZPAQ method/config string used to construct the rate model.
    method: String,
    // Every byte observed so far; replayed to rebuild model state.
    history: Vec<u8>,
    // Cached 256-entry predictive distribution; current iff `valid`.
    pdf: Vec<f64>,
    valid: bool,
}
398
399impl ZpaqPredictor {
400 fn new(method: String) -> Self {
401 Self {
402 method,
403 history: Vec::new(),
404 pdf: vec![0.0; 256],
405 valid: false,
406 }
407 }
408
409 fn pdf_next(&mut self) -> &[f64] {
410 if !self.valid {
411 for sym in 0..256usize {
412 let mut model = ZpaqRateModel::new(self.method.clone(), PDF_MIN);
413 if !self.history.is_empty() {
414 let _ = model.update_and_score(&self.history);
415 }
416 let logp = model.log_prob(sym as u8);
417 self.pdf[sym] = logp.exp().max(PDF_MIN);
418 }
419 normalize_pdf(&mut self.pdf);
420 self.valid = true;
421 }
422 &self.pdf
423 }
424
425 fn update(&mut self, symbol: u8) {
426 self.history.push(symbol);
427 self.valid = false;
428 }
429}
430
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Shared constructor body: wraps a ready compressor with empty caches.
    fn with_compressor(compressor: mambazip::Compressor) -> Self {
        let vocab = compressor.vocab_size();
        Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        }
    }

    /// Wraps an already-loaded shared model.
    fn from_model(model: std::sync::Arc<mambazip::Model>) -> Self {
        Self::with_compressor(mambazip::Compressor::new_from_model(model))
    }

    /// Loads a model from a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = mambazip::Compressor::new_from_method(method)?;
        Ok(Self::with_compressor(compressor))
    }

    /// Lazily primes the model on first use and refreshes the PDF/CDF caches.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.valid {
            if !self.primed {
                // First use: produce the initial PDF conditioned on token 0.
                self.compressor.forward_to_pdf(0, &mut self.pdf);
                self.primed = true;
            }
            self.valid = true;
            self.cdf_valid = false;
        }
        // Single shared CDF rebuild (the original repeated this block three
        // times across its branches).
        if want_cdf && !self.cdf_valid {
            debug_assert!(self.pdf.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.pdf[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Predictive distribution for the next symbol.
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.pdf
    }

    /// Byte-symbol cumulative distribution for the next symbol.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Online-updates the model on the observed symbol and advances the PDF.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.online_update_from_pdf(symbol, &self.pdf)?;
        self.compressor.forward_to_pdf(symbol as u32, &mut self.pdf);
        self.valid = true;
        self.cdf_valid = false;
        Ok(())
    }

    /// Signals stream start with a total-length hint.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }
}
513
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Wraps an already-loaded shared model.
    fn from_model(model: std::sync::Arc<rwkvzip::Model>) -> Self {
        let compressor = rwkvzip::Compressor::new_from_model(model);
        Self {
            compressor,
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        }
    }

    /// Loads a model from a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = rwkvzip::Compressor::new_from_method(method)?;
        Ok(Self {
            compressor,
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        })
    }

    /// Lazily primes the model and, when requested, rebuilds the cached CDF
    /// from the compressor's current PDF buffer.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.primed {
            self.compressor.reset_and_prime();
            self.primed = true;
            self.cdf_valid = false;
        }
        if want_cdf && !self.cdf_valid {
            debug_assert!(self.compressor.pdf_buffer.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.compressor.pdf_buffer[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Predictive distribution for the next symbol (borrowed from the
    /// compressor's own buffer).
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.compressor.pdf_buffer
    }

    /// Byte-symbol cumulative distribution for the next symbol.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Feeds the observed symbol back into the model; invalidates the CDF.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.observe_symbol_from_current_pdf(symbol)?;
        self.cdf_valid = false;
        Ok(())
    }

    /// Signals stream start with a total-length hint.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }

    /// Signals stream end so online-policy state can be finalized.
    fn finish_stream(&mut self) -> Result<()> {
        self.compressor.finish_online_policy_stream()
    }
}
575
/// One expert inside a `MixturePredictor`: its predictor plus mixing state.
#[derive(Clone)]
struct MixExpert {
    predictor: Box<RatePdfPredictor>,
    // Current mixture weight in log domain.
    log_weight: f64,
    // Log prior; switching-style updates pull weights back toward this.
    log_prior: f64,
    // Cumulative negative log likelihood accrued by this expert.
    cum_log_loss: f64,
}
583
/// Per-node diagnostic row: predicted probability and mixing weights for one
/// node in the predictor tree.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct AcLogLossNodeValue {
    // Probability this node assigned to the observed symbol.
    pub(crate) prob: f64,
    // Weight within the node's immediate parent mixture.
    pub(crate) local_weight: f64,
    // Product of weights along the path from the root to this node.
    pub(crate) effective_weight: f64,
}
590
/// Diagnostic snapshot of a predictor subtree: the subtree's mixed
/// probability plus a flattened row per node (root first, children after).
#[derive(Clone, Debug, Default)]
pub(crate) struct AcLogLossSubtreeSnapshot {
    pub(crate) prob: f64,
    pub(crate) rows: Vec<AcLogLossNodeValue>,
}
596
/// Root-level diagnostic summary for one mixture step.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct AcLogLossRootSnapshot {
    // Mixture probability assigned to the observed symbol.
    pub(crate) mix_prob: f64,
    // Shannon entropy (bits) of the root weight distribution.
    pub(crate) root_weight_entropy_bits: f64,
    // Highest-weighted root child, if any experts exist.
    pub(crate) root_top1_child_index: Option<usize>,
    pub(crate) root_top1_weight: f64,
    // Second-highest-weighted root child, if at least two experts exist.
    pub(crate) root_top2_child_index: Option<usize>,
    pub(crate) root_top2_weight: f64,
}
606
/// Mixture-of-experts predictor supporting several weighting schemes
/// (Bayes, fading Bayes, switching, convex/OGD, MDL, neural).
#[derive(Clone)]
struct MixturePredictor {
    kind: MixtureKind,
    schedule: MixtureScheduleMode,
    // Base learning/switching rate; interpretation depends on `kind`.
    alpha: f64,
    // Discount factor in [0, 1] used by `FadingBayes`.
    decay: f64,
    experts: Vec<MixExpert>,
    // Normalized prior weights (targets for switching-style updates).
    prior_weights: Vec<f64>,
    // Context-conditioned neural mixer (used by `MixtureKind::Neural`).
    neural: NeuralMixCore,
    // Text-context feature extractor feeding `neural`.
    analyzer: TextContextAnalyzer,
    // Reusable buffers for the bitwise AC path and neural updates.
    neural_logps: Vec<f64>,
    // Per-expert mode in the bitwise path: 0 = fast CTW, 2 = cached CDF,
    // 1 = PDF-derived CDF row.
    neural_bit_modes: Vec<u8>,
    // Per-expert [lo, hi) symbol interval narrowed bit by bit.
    neural_lo: Vec<usize>,
    neural_hi: Vec<usize>,
    // Per-expert 257-entry CDF rows built from PDFs in the bitwise path.
    neural_pdf_cdf_rows: Vec<Vec<f64>>,
    // General-purpose scratch vectors reused across updates.
    scratch: Vec<f64>,
    scratch2: Vec<f64>,
    projection_scratch: Vec<f64>,
    // Cached mixed 256-entry distribution; current iff `valid`.
    pdf: Vec<f64>,
    valid: bool,
    // Update counters driving the switching/convex schedules.
    switch_updates: u64,
    convex_updates: u64,
}
630
631impl MixturePredictor {
    /// Builds the mixture from a validated spec: one predictor per expert
    /// entry, normalized log-prior weights, and a neural mixer primed with
    /// the clamped prior distribution.
    fn new(spec: &MixtureSpec) -> Result<Self> {
        spec.validate().map_err(anyhow::Error::msg)?;
        let mut experts = Vec::with_capacity(spec.experts.len());
        for e in &spec.experts {
            experts.push(MixExpert {
                predictor: Box::new(RatePdfPredictor::from_rate_backend(
                    e.backend.clone(),
                    e.max_order,
                )?),
                log_weight: e.log_prior,
                log_prior: e.log_prior,
                cum_log_loss: 0.0,
            });
        }
        // Normalize initial log weights so exp(log_weight) sums to 1.
        let m = logsumexp_expert_weights(&experts);
        for e in &mut experts {
            e.log_weight -= m;
        }

        let mut prior_weights = vec![0.0; experts.len()];
        normalized_mix_expert_prior_weights(&experts, &mut prior_weights);
        // The neural mixer gets priors clamped away from 0 and 1.
        let mut neural_prior_weights = prior_weights.clone();
        for weight in &mut neural_prior_weights {
            *weight = weight.clamp(PDF_MIN, 1.0 - PDF_MIN);
        }

        // NOTE(review): the 25x boost and the 0.5 factor below look like
        // empirically tuned learning-rate constants — confirm before changing.
        let base_lr = spec.alpha.abs().clamp(1e-6, 1.0);
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        let mut neural = NeuralMixCore::new(
            experts.len(),
            &neural_prior_weights,
            effective_lr * 0.5,
            effective_lr,
            1e-5,
        );
        neural.set_context_state(analyzer.state());
        Ok(Self {
            kind: spec.kind,
            schedule: spec.schedule,
            alpha: spec.alpha,
            // Missing decay defaults to 1.0 (no fading).
            decay: spec.decay.unwrap_or(1.0).clamp(0.0, 1.0),
            experts,
            prior_weights,
            neural,
            analyzer,
            neural_logps: vec![0.0; spec.experts.len()],
            neural_bit_modes: vec![0; spec.experts.len()],
            neural_lo: vec![0; spec.experts.len()],
            neural_hi: vec![256; spec.experts.len()],
            neural_pdf_cdf_rows: vec![vec![0.0; 257]; spec.experts.len()],
            scratch: Vec::new(),
            scratch2: Vec::new(),
            projection_scratch: Vec::new(),
            pdf: vec![0.0; 256],
            valid: false,
            switch_updates: 0,
            convex_updates: 0,
        })
    }
692
693 fn best_expert_index(&self) -> Option<usize> {
694 let mut best_idx = None;
695 let mut best_loss = f64::INFINITY;
696 for (index, expert) in self.experts.iter().enumerate() {
697 if expert.cum_log_loss < best_loss {
698 best_loss = expert.cum_log_loss;
699 best_idx = Some(index);
700 }
701 }
702 best_idx
703 }
704
    /// Returns the current mixing weights, one per expert, normalized to a
    /// simplex. The shape depends on the mixture kind:
    /// - `Neural`: weights from the context-conditioned neural mixer;
    /// - `Mdl`: winner-take-all on lowest cumulative log loss;
    /// - `FadingBayes`: softmax of decay-discounted log weights;
    /// - `Convex`: exponentiated (already simplex-projected) log weights;
    /// - `Bayes`/`Switching`: softmax of the posterior log weights.
    fn predictive_weights(&mut self) -> Vec<f64> {
        if self.experts.is_empty() {
            return Vec::new();
        }

        match self.kind {
            MixtureKind::Neural => {
                if self.experts.len() == 1 {
                    // Single expert: nothing to mix.
                    return vec![1.0];
                }
                self.neural.set_context_state(self.analyzer.state());
                self.neural.evaluate_expert_weights();
                let mut weights = self.neural.expert_weights().to_vec();
                normalize_simplex_weights(&mut weights);
                weights
            }
            MixtureKind::Mdl => {
                // Hard selection of the MDL-best expert; all-zero if none.
                let mut weights = vec![0.0; self.experts.len()];
                if let Some(best_idx) = self.best_expert_index() {
                    weights[best_idx] = 1.0;
                }
                weights
            }
            MixtureKind::FadingBayes => {
                // Softmax over decay-discounted log weights; subtracting the
                // max keeps the exponentials numerically stable.
                let max_log = self
                    .experts
                    .iter()
                    .map(|expert| self.decay * expert.log_weight)
                    .fold(f64::NEG_INFINITY, f64::max)
                let mut weights = self
                    .experts
                    .iter()
                    .map(|expert| {
                        if max_log.is_finite() {
                            (self.decay * expert.log_weight - max_log).exp()
                        } else {
                            0.0
                        }
                    })
                    .collect::<Vec<_>>();
                normalize_simplex_weights(&mut weights);
                weights
            }
            MixtureKind::Convex => {
                // Convex weights are maintained directly on the simplex (see
                // the Convex arm of `update`); just exponentiate.
                let mut weights = self
                    .experts
                    .iter()
                    .map(|expert| expert.log_weight.exp())
                    .collect::<Vec<_>>();
                normalize_simplex_weights(&mut weights);
                weights
            }
            MixtureKind::Bayes | MixtureKind::Switching => {
                // Stable softmax over the posterior log weights.
                let max_log = self
                    .experts
                    .iter()
                    .map(|expert| expert.log_weight)
                    .fold(f64::NEG_INFINITY, f64::max);
                let mut weights = self
                    .experts
                    .iter()
                    .map(|expert| {
                        if max_log.is_finite() {
                            (expert.log_weight - max_log).exp()
                        } else {
                            0.0
                        }
                    })
                    .collect::<Vec<_>>();
                normalize_simplex_weights(&mut weights);
                weights
            }
        }
    }
779
780 fn ensure_pdf(&mut self) -> Result<&[f64]> {
781 if self.valid {
782 return Ok(&self.pdf);
783 }
784 let weights = self.predictive_weights();
785 if weights.len() == 1 && matches!(self.kind, MixtureKind::Mdl | MixtureKind::Neural) {
786 self.pdf.fill(0.0);
787 } else {
788 self.pdf.fill(0.0);
789 }
790 for (index, expert) in self.experts.iter_mut().enumerate() {
791 let weight = weights.get(index).copied().unwrap_or(0.0);
792 if weight <= 0.0 {
793 continue;
794 }
795 let epdf = expert.predictor.pdf_next()?;
796 for (slot, &p) in self.pdf.iter_mut().zip(epdf.iter()) {
797 *slot += weight * p;
798 }
799 }
800
801 normalize_pdf(&mut self.pdf);
802 self.valid = true;
803 Ok(&self.pdf)
804 }
805
806 fn begin_stream(&mut self, total_len: usize) -> Result<()> {
807 for expert in &mut self.experts {
808 match &mut *expert.predictor {
809 RatePdfPredictor::Ctw(_) | RatePdfPredictor::FacCtw(_) => {}
812 _ => expert.predictor.begin_stream(total_len)?,
813 }
814 }
815 Ok(())
816 }
817
    /// Collects one diagnostic subtree per expert, in expert order.
    ///
    /// `effective_prefix` is the product of ancestor weights; each child's
    /// effective weight is that prefix times its local weight. Work is sent
    /// to `pool` only when enough experts exist to amortize the dispatch.
    fn diagnostic_collect_children(
        &mut self,
        symbol: u8,
        weights: &[f64],
        effective_prefix: f64,
        pool: Option<&ThreadPool>,
    ) -> Result<Vec<AcLogLossSubtreeSnapshot>> {
        let use_parallel = pool.is_some() && self.experts.len() >= DIAGNOSTIC_PARALLEL_THRESHOLD;
        if use_parallel {
            let pool = pool.expect("checked is_some");
            pool.install(|| {
                self.experts
                    .par_iter_mut()
                    .enumerate()
                    .map(|(index, expert)| {
                        let local_weight = weights.get(index).copied().unwrap_or(0.0);
                        let effective_weight = effective_prefix * local_weight;
                        // Children run with pool = None — presumably to keep
                        // nested snapshots serial inside a parallel task.
                        expert.predictor.diagnostic_snapshot_subtree(
                            symbol,
                            local_weight,
                            effective_weight,
                            None,
                        )
                    })
                    .collect()
            })
        } else {
            let mut children = Vec::with_capacity(self.experts.len());
            for (index, expert) in self.experts.iter_mut().enumerate() {
                let local_weight = weights.get(index).copied().unwrap_or(0.0);
                let effective_weight = effective_prefix * local_weight;
                children.push(expert.predictor.diagnostic_snapshot_subtree(
                    symbol,
                    local_weight,
                    effective_weight,
                    pool,
                )?);
            }
            Ok(children)
        }
    }
859
860 fn diagnostic_subtree_snapshot(
861 &mut self,
862 symbol: u8,
863 local_weight: f64,
864 effective_weight: f64,
865 pool: Option<&ThreadPool>,
866 ) -> Result<AcLogLossSubtreeSnapshot> {
867 let weights = self.predictive_weights();
868 let children =
869 self.diagnostic_collect_children(symbol, &weights, effective_weight, pool)?;
870 let mix_prob = children
871 .iter()
872 .enumerate()
873 .map(|(index, child)| weights.get(index).copied().unwrap_or(0.0) * child.prob)
874 .sum::<f64>()
875 .max(PDF_MIN);
876 let total_rows = 1 + children.iter().map(|child| child.rows.len()).sum::<usize>();
877 let mut rows = Vec::with_capacity(total_rows);
878 rows.push(AcLogLossNodeValue {
879 prob: mix_prob,
880 local_weight,
881 effective_weight,
882 });
883 for child in children {
884 rows.extend(child.rows);
885 }
886 Ok(AcLogLossSubtreeSnapshot {
887 prob: mix_prob,
888 rows,
889 })
890 }
891
    /// Builds the root-level diagnostic summary for `symbol`: flattens all
    /// child rows into `out` and reports the mixed probability, the entropy
    /// of the root weights, and the top-two weighted children.
    fn diagnostic_root_snapshot(
        &mut self,
        symbol: u8,
        pool: Option<&ThreadPool>,
        out: &mut Vec<AcLogLossNodeValue>,
    ) -> Result<AcLogLossRootSnapshot> {
        let weights = self.predictive_weights();
        // Root node: effective prefix weight is 1.0 by definition.
        let children = self.diagnostic_collect_children(symbol, &weights, 1.0, pool)?;
        out.clear();
        out.reserve(children.iter().map(|child| child.rows.len()).sum::<usize>());
        for child in &children {
            out.extend_from_slice(&child.rows);
        }

        let mix_prob = children
            .iter()
            .enumerate()
            .map(|(index, child)| weights.get(index).copied().unwrap_or(0.0) * child.prob)
            .sum::<f64>()
            .max(PDF_MIN);

        // Single-pass top-2 selection; on ties the earlier index wins
        // (strict `>` comparisons never displace an equal earlier entry).
        let mut top1 = None;
        let mut top2 = None;
        for (index, &weight) in weights.iter().enumerate() {
            match top1 {
                None => top1 = Some((index, weight)),
                Some((best_idx, best_weight)) if weight > best_weight => {
                    // New leader; old leader becomes runner-up.
                    top2 = Some((best_idx, best_weight));
                    top1 = Some((index, weight));
                }
                _ => match top2 {
                    None => top2 = Some((index, weight)),
                    Some((_, second_weight)) if weight > second_weight => {
                        top2 = Some((index, weight));
                    }
                    _ => {}
                },
            }
        }

        // Shannon entropy in bits; zero-weight terms contribute nothing.
        let root_weight_entropy_bits = weights
            .iter()
            .copied()
            .filter(|weight| *weight > 0.0)
            .map(|weight| -weight * weight.log2())
            .sum::<f64>();

        Ok(AcLogLossRootSnapshot {
            mix_prob,
            root_weight_entropy_bits,
            root_top1_child_index: top1.map(|(index, _)| index),
            root_top1_weight: top1.map(|(_, weight)| weight).unwrap_or(0.0),
            root_top2_child_index: top2.map(|(index, _)| index),
            root_top2_weight: top2.map(|(_, weight)| weight).unwrap_or(0.0),
        })
    }
948
949 fn update(&mut self, symbol: u8) -> Result<()> {
950 let _ = self.ensure_pdf()?;
951
952 match self.kind {
953 MixtureKind::Bayes => {
954 let n = self.experts.len();
955 self.scratch.resize(n, 0.0);
956 self.scratch2.resize(n, 0.0);
957 for (i, e) in self.experts.iter_mut().enumerate() {
958 let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
959 let lp = p.ln();
960 self.scratch[i] = lp;
961 self.scratch2[i] = e.log_weight + lp;
962 }
963 let log_mix = logsumexp_slice(&self.scratch2[..n]);
964 for (i, e) in self.experts.iter_mut().enumerate() {
965 e.log_weight = e.log_weight + self.scratch[i] - log_mix;
966 e.cum_log_loss -= self.scratch[i];
967 e.predictor.update(symbol)?;
968 }
969 }
970 MixtureKind::FadingBayes => {
971 let n = self.experts.len();
972 self.scratch.resize(n, 0.0);
973 self.scratch2.resize(n, 0.0);
974 for (i, e) in self.experts.iter_mut().enumerate() {
975 let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
976 let lp = p.ln();
977 self.scratch[i] = lp;
978 self.scratch2[i] = e.log_weight + lp;
979 }
980 for (i, e) in self.experts.iter_mut().enumerate() {
981 self.scratch2[i] = self.decay * e.log_weight + self.scratch[i];
982 }
983 let log_mix = logsumexp_slice(&self.scratch2[..n]);
984 for (i, e) in self.experts.iter_mut().enumerate() {
985 e.log_weight = self.decay * e.log_weight + self.scratch[i] - log_mix;
986 e.cum_log_loss -= self.scratch[i];
987 e.predictor.update(symbol)?;
988 }
989 }
990 MixtureKind::Switching => {
991 let n = self.experts.len();
992 self.scratch.resize(n, 0.0);
993 self.scratch2.resize(n, 0.0);
994 for (i, e) in self.experts.iter_mut().enumerate() {
995 let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
996 let lp = p.ln();
997 self.scratch[i] = lp;
998 self.scratch2[i] = e.log_weight + lp;
999 }
1000 let log_mix = logsumexp_slice(&self.scratch2[..n]);
1001 for (i, e) in self.experts.iter_mut().enumerate() {
1002 self.scratch2[i] = (self.scratch2[i] - log_mix).exp();
1003 e.cum_log_loss -= self.scratch[i];
1004 e.predictor.update(symbol)?;
1005 }
1006 let alpha =
1007 switching_alpha_for_update(self.schedule, self.alpha, self.switch_updates);
1008 self.switch_updates = self.switch_updates.saturating_add(1);
1009 apply_switching_weights(
1010 &mut self.experts,
1011 &self.prior_weights[..n],
1012 alpha,
1013 &mut self.scratch2[..n],
1014 &mut self.scratch[..n],
1015 );
1016 }
1017 MixtureKind::Convex => {
1018 let n = self.experts.len();
1019 self.scratch.resize(n, 0.0);
1020 self.scratch2.resize(n, 0.0);
1021 for (i, e) in self.experts.iter_mut().enumerate() {
1022 let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
1023 let lp = p.ln();
1024 self.scratch[i] = lp;
1025 self.scratch2[i] = e.log_weight.exp();
1026 e.cum_log_loss -= lp;
1027 e.predictor.update(symbol)?;
1028 }
1029 let mix_prob = self
1030 .scratch
1031 .iter()
1032 .zip(self.scratch2.iter())
1033 .map(|(&lp, &w)| w * lp.exp())
1034 .sum::<f64>()
1035 .max(PDF_MIN);
1036 let log_mix = mix_prob.ln();
1037 self.convex_updates = self.convex_updates.saturating_add(1);
1038 let eta =
1039 convex_step_size_for_update(self.schedule, self.alpha, self.convex_updates);
1040 for i in 0..n {
1041 let grad = -(self.scratch[i] - log_mix).exp();
1042 self.scratch2[i] -= eta * grad;
1043 }
1044 project_simplex_with_scratch(&mut self.scratch2[..n], &mut self.projection_scratch);
1045 for i in 0..n {
1046 self.experts[i].log_weight = self.scratch2[i].max(PDF_MIN).ln();
1047 }
1048 }
1049 MixtureKind::Mdl => {
1050 let n = self.experts.len();
1051 self.scratch.resize(n, 0.0);
1052 for (i, e) in self.experts.iter_mut().enumerate() {
1053 let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
1054 let lp = p.ln();
1055 self.scratch[i] = lp;
1056 }
1057 for (i, e) in self.experts.iter_mut().enumerate() {
1058 e.cum_log_loss -= self.scratch[i];
1059 e.predictor.update(symbol)?;
1060 }
1061 }
1062 MixtureKind::Neural => {
1063 let y = symbol as usize;
1064 if self.experts.len() == 1 {
1065 let lp = self.experts[0].predictor.pdf_next()?[y].max(PDF_MIN).ln();
1066 self.experts[0].cum_log_loss -= lp;
1067 self.experts[0].predictor.update(symbol)?;
1068 self.analyzer.update(symbol);
1069 self.neural.set_context_state(self.analyzer.state());
1070 self.valid = false;
1071 return Ok(());
1072 }
1073 let n = self.experts.len();
1074 self.neural.set_context_state(self.analyzer.state());
1075 self.neural_logps.resize(n, 0.0);
1076 for i in 0..n {
1077 let p = self.experts[i].predictor.pdf_next()?[y].max(PDF_MIN);
1078 let lp = p.ln();
1079 self.neural_logps[i] = lp;
1080 self.experts[i].cum_log_loss -= lp;
1081 }
1082 self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
1083 self.neural
1084 .update_weights_symbol(&self.neural_logps, PDF_MIN);
1085 for e in &mut self.experts {
1086 e.predictor.update(symbol)?;
1087 }
1088 self.analyzer.update(symbol);
1089 self.neural.set_context_state(self.analyzer.state());
1090 }
1091 }
1092
1093 self.valid = false;
1094 Ok(())
1095 }
1096
1097 fn finish_stream(&mut self) -> Result<()> {
1098 for expert in &mut self.experts {
1099 expert.predictor.finish_stream()?;
1100 }
1101 Ok(())
1102 }
1103
1104 #[inline]
1105 fn can_fast_ac_bitwise(&self) -> bool {
1106 self.experts.iter().any(|e| {
1107 if let RatePdfPredictor::Ctw(ctw) = &*e.predictor {
1108 ctw.can_fast_ac_bitwise()
1109 } else {
1110 false
1111 }
1112 })
1113 }
1114
1115 fn ac_step_bitwise<F>(&mut self, mut choose_bit: F) -> Result<u8>
1116 where
1117 F: FnMut(usize, f64) -> Result<u8>,
1118 {
1119 let n = self.experts.len();
1120 self.scratch.resize(n, 0.0);
1121 match self.kind {
1122 MixtureKind::Neural if n > 1 => {
1123 self.neural.set_context_state(self.analyzer.state());
1124 self.neural.evaluate_expert_weights();
1125 self.scratch.copy_from_slice(self.neural.expert_weights());
1126 }
1127 MixtureKind::FadingBayes => {
1128 let weights = self.predictive_weights();
1129 self.scratch.copy_from_slice(&weights);
1130 }
1131 MixtureKind::Mdl => {
1132 let weights = self.predictive_weights();
1133 self.scratch.copy_from_slice(&weights);
1134 }
1135 _ => {
1136 let weights = self.predictive_weights();
1137 self.scratch.copy_from_slice(&weights);
1138 }
1139 }
1140 self.scratch2.resize(n, 1.0);
1141 self.scratch2.fill(1.0);
1142 self.neural_logps.resize(n, 0.0);
1143 self.neural_bit_modes.resize(n, 0);
1144 self.neural_lo.resize(n, 0);
1145 self.neural_hi.resize(n, 256);
1146 if self.neural_pdf_cdf_rows.len() < n {
1147 self.neural_pdf_cdf_rows.resize_with(n, || vec![0.0; 257]);
1148 }
1149
1150 for i in 0..n {
1151 self.neural_bit_modes[i] = 1;
1152 self.neural_lo[i] = 0;
1153 self.neural_hi[i] = 256;
1154
1155 let mut handled_ctw = false;
1156 if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor
1157 && ctw.can_fast_ac_bitwise()
1158 {
1159 self.neural_bit_modes[i] = 0;
1160 handled_ctw = true;
1161 }
1162 if handled_ctw {
1163 continue;
1164 }
1165
1166 if self.experts[i]
1167 .predictor
1168 .prepare_cached_cdf_fast_bitwise()?
1169 {
1170 self.neural_bit_modes[i] = 2;
1171 continue;
1172 }
1173
1174 let pdf = self.experts[i].predictor.pdf_next()?;
1175 let row = &mut self.neural_pdf_cdf_rows[i];
1176 if row.len() != 257 {
1177 row.resize(257, 0.0);
1178 }
1179 row[0] = 0.0;
1180 for b in 0..256usize {
1181 row[b + 1] = row[b] + pdf[b].max(PDF_MIN);
1182 }
1183 if !row[256].is_finite() || row[256] <= 0.0 {
1184 for (j, v) in row.iter_mut().enumerate() {
1185 *v = (j as f64) / 256.0;
1186 }
1187 }
1188 }
1189
1190 let mut symbol = 0u8;
1191 for bit_idx in 0..8usize {
1192 let mut denom = 0.0;
1193 let mut numer1 = 0.0;
1194
1195 for i in 0..n {
1196 let p1 = if self.neural_bit_modes[i] == 0 {
1197 match &mut *self.experts[i].predictor {
1198 RatePdfPredictor::Ctw(ctw) => ctw.bit_prob_one_msb(bit_idx),
1199 _ => 0.5,
1200 }
1201 } else if self.neural_bit_modes[i] == 2 {
1202 self.experts[i]
1203 .predictor
1204 .cached_cdf_bit_prob_one_msb(self.neural_lo[i], self.neural_hi[i])
1205 .unwrap_or(0.5)
1206 } else {
1207 let lo = self.neural_lo[i];
1208 let hi = self.neural_hi[i];
1209 let mid = (lo + hi) >> 1;
1210 let row = &self.neural_pdf_cdf_rows[i];
1211 let total = (row[hi] - row[lo]).max(PDF_MIN);
1212 let one = (row[hi] - row[mid]).max(0.0);
1213 (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
1214 };
1215 self.neural_logps[i] = p1;
1216 let wp = self.scratch[i] * self.scratch2[i];
1217 denom += wp;
1218 numer1 += wp * p1;
1219 }
1220
1221 let p1_mix = if denom.is_finite() && denom > 0.0 {
1222 (numer1 / denom).clamp(PDF_MIN, 1.0 - PDF_MIN)
1223 } else {
1224 0.5
1225 };
1226 let bit = choose_bit(bit_idx, p1_mix)? & 1;
1227 symbol |= bit << (7 - bit_idx);
1228
1229 for i in 0..n {
1230 let p1 = self.neural_logps[i];
1231 let pb = if bit == 1 { p1 } else { 1.0 - p1 };
1232 self.scratch2[i] = (self.scratch2[i] * pb).max(PDF_MIN);
1233
1234 if self.neural_bit_modes[i] == 0 {
1235 if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor {
1236 ctw.update_bit_msb(bit_idx, bit == 1);
1237 }
1238 } else {
1239 let lo = self.neural_lo[i];
1240 let hi = self.neural_hi[i];
1241 let mid = (lo + hi) >> 1;
1242 if bit == 1 {
1243 self.neural_lo[i] = mid;
1244 self.neural_hi[i] = hi;
1245 } else {
1246 self.neural_lo[i] = lo;
1247 self.neural_hi[i] = mid;
1248 }
1249 }
1250 }
1251 }
1252
1253 for i in 0..n {
1254 let lp = self.scratch2[i].max(PDF_MIN).ln();
1255 self.neural_logps[i] = lp;
1256 self.experts[i].cum_log_loss -= lp;
1257 if self.neural_bit_modes[i] != 0 {
1258 self.experts[i].predictor.update(symbol)?;
1259 }
1260 }
1261
1262 match self.kind {
1263 MixtureKind::Bayes => {
1264 for i in 0..n {
1265 self.scratch[i] = self.experts[i].log_weight + self.neural_logps[i];
1266 }
1267 let log_mix = logsumexp_slice(&self.scratch[..n]);
1268 for i in 0..n {
1269 self.experts[i].log_weight += self.neural_logps[i] - log_mix;
1270 }
1271 }
1272 MixtureKind::FadingBayes => {
1273 for i in 0..n {
1274 self.scratch[i] =
1275 self.decay * self.experts[i].log_weight + self.neural_logps[i];
1276 }
1277 let log_mix = logsumexp_slice(&self.scratch[..n]);
1278 for i in 0..n {
1279 self.experts[i].log_weight = self.scratch[i] - log_mix;
1280 }
1281 }
1282 MixtureKind::Switching => {
1283 for i in 0..n {
1284 self.scratch[i] = self.experts[i].log_weight + self.neural_logps[i];
1285 }
1286 let log_mix = logsumexp_slice(&self.scratch[..n]);
1287 for weight in &mut self.scratch[..n] {
1288 *weight = (*weight - log_mix).exp();
1289 }
1290 let alpha =
1291 switching_alpha_for_update(self.schedule, self.alpha, self.switch_updates);
1292 self.switch_updates = self.switch_updates.saturating_add(1);
1293 apply_switching_weights(
1294 &mut self.experts,
1295 &self.prior_weights[..n],
1296 alpha,
1297 &mut self.scratch[..n],
1298 &mut self.scratch2[..n],
1299 );
1300 }
1301 MixtureKind::Convex => {
1302 self.scratch.resize(n, 0.0);
1303 self.scratch2.resize(n, 0.0);
1304 for i in 0..n {
1305 self.scratch2[i] = self.experts[i].log_weight.exp();
1306 }
1307 let mix_prob = self
1308 .neural_logps
1309 .iter()
1310 .zip(self.scratch2.iter())
1311 .map(|(&lp, &w)| w * lp.exp())
1312 .sum::<f64>()
1313 .max(PDF_MIN);
1314 let log_mix = mix_prob.ln();
1315 self.convex_updates = self.convex_updates.saturating_add(1);
1316 let eta =
1317 convex_step_size_for_update(self.schedule, self.alpha, self.convex_updates);
1318 for i in 0..n {
1319 let grad = -(self.neural_logps[i] - log_mix).exp();
1320 self.scratch2[i] -= eta * grad;
1321 }
1322 project_simplex_with_scratch(&mut self.scratch2[..n], &mut self.projection_scratch);
1323 for i in 0..n {
1324 self.experts[i].log_weight = self.scratch2[i].max(PDF_MIN).ln();
1325 }
1326 }
1327 MixtureKind::Mdl => {}
1328 MixtureKind::Neural => {
1329 if n > 1 {
1330 self.neural.set_context_state(self.analyzer.state());
1331 self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
1332 self.neural
1333 .update_weights_symbol(&self.neural_logps, PDF_MIN);
1334 }
1335 self.analyzer.update(symbol);
1336 self.neural.set_context_state(self.analyzer.state());
1337 }
1338 }
1339 self.valid = false;
1340 Ok(symbol)
1341 }
1342}
1343
/// Thin wrapper around [`RatePdfPredictor`] that exposes only the subset of
/// the predictor API needed by diagnostic code paths (stream lifecycle,
/// AC encoding steps, and log-loss snapshots).
pub(crate) struct DiagnosticRatePredictor {
    // The concrete backend predictor all calls are forwarded to.
    inner: RatePdfPredictor,
}
1347
impl DiagnosticRatePredictor {
    /// Builds the wrapped predictor from a `RateBackend` configuration.
    /// `max_order` is forwarded to [`RatePdfPredictor::from_rate_backend`].
    pub(crate) fn from_rate_backend(backend: RateBackend, max_order: i64) -> Result<Self> {
        Ok(Self {
            inner: RatePdfPredictor::from_rate_backend(backend, max_order)?,
        })
    }

    /// Prepares the predictor for a new stream of `total_len` symbols.
    pub(crate) fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.inner.begin_stream(total_len)
    }

    /// Finalizes any per-stream state held by the predictor.
    pub(crate) fn finish_stream(&mut self) -> Result<()> {
        self.inner.finish_stream()
    }

    /// Test-only access to the next-symbol probability distribution.
    #[cfg(test)]
    pub(crate) fn pdf_next(&mut self) -> Result<&[f64]> {
        self.inner.pdf_next()
    }

    /// Test-only model update with an observed symbol.
    #[cfg(test)]
    pub(crate) fn update(&mut self, symbol: u8) -> Result<()> {
        self.inner.update(symbol)
    }

    /// Collects per-expert log-loss diagnostics for `symbol`.
    ///
    /// Errors unless the inner predictor is a top-level mixture backend
    /// (see `RatePdfPredictor::diagnostic_root_snapshot`).
    pub(crate) fn diagnostic_root_snapshot(
        &mut self,
        symbol: u8,
        pool: Option<&ThreadPool>,
        out: &mut Vec<AcLogLossNodeValue>,
    ) -> Result<AcLogLossRootSnapshot> {
        self.inner.diagnostic_root_snapshot(symbol, pool, out)
    }

    /// Encodes one symbol into `encoder`, advancing the model state.
    pub(crate) fn encode_symbol_ac_step<W: std::io::Write>(
        &mut self,
        symbol: u8,
        encoder: &mut ArithmeticEncoder<W>,
    ) -> Result<()> {
        self.inner.encode_symbol_ac_step(symbol, encoder)
    }
}
1390
/// Dispatch enum over every supported symbol-probability backend.
///
/// Each variant wraps one model that can produce a 256-entry next-symbol
/// pdf (`pdf_next`) and be advanced with observed symbols (`update`).
/// `Calibrated` decorates another predictor with a secondary calibration
/// stage; `Mixture` combines several predictors as experts.
#[derive(Clone)]
#[allow(clippy::large_enum_variant)]
enum RatePdfPredictor {
    Rosa(RosaPredictor),
    /// Contiguous match model.
    Match {
        model: MatchModel,
    },
    /// Match model over gapped (sparse) contexts.
    SparseMatch {
        model: SparseMatchModel,
    },
    Ppmd {
        model: PpmdModel,
    },
    Sequitur {
        model: SequiturModel,
    },
    /// Plain context-tree-weighting predictor.
    Ctw(CtwPredictor),
    /// Factored CTW variant (sub-byte encodings share the CtwPredictor type).
    FacCtw(CtwPredictor),
    #[cfg(feature = "backend-mamba")]
    Mamba(MambaPredictor),
    #[cfg(feature = "backend-rwkv")]
    Rwkv(RwkvPredictor),
    Zpaq(ZpaqPredictor),
    Mixture(MixturePredictor),
    Particle(crate::particle::ParticleRuntime),
    /// Wraps `base` and post-processes its pdf through a `CalibratorCore`.
    Calibrated {
        base: Box<RatePdfPredictor>,
        core: CalibratorCore,
        // Cached calibrated pdf; `valid` tracks whether it reflects the
        // current model state (invalidated by `update`).
        pdf: Vec<f64>,
        valid: bool,
    },
}
1423
impl RatePdfPredictor {
    /// Instantiates the predictor variant selected by `backend`.
    ///
    /// `max_order` is only consumed by the RosaPlus backend; the other
    /// backends carry their configuration inside the `RateBackend` variant.
    fn from_rate_backend(backend: RateBackend, max_order: i64) -> Result<Self> {
        match backend {
            RateBackend::RosaPlus => Ok(Self::Rosa(RosaPredictor::new(max_order))),
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Ok(Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Ok(Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::Ppmd { order, memory_mb } => Ok(Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
            }),
            RateBackend::Sequitur { context_bytes } => Ok(Self::Sequitur {
                model: SequiturModel::new(context_bytes),
            }),
            RateBackend::Ctw { depth } => Ok(Self::Ctw(CtwPredictor::new_ctw(depth))),
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                // Sub-byte encodings are restricted to 1..=8 bits per symbol.
                let bits = encoding_bits.clamp(1, 8);
                Ok(Self::FacCtw(CtwPredictor::new_fac(base_depth, bits)))
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => Ok(Self::Mamba(MambaPredictor::from_model(model))),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                Ok(Self::Mamba(MambaPredictor::from_method(&method)?))
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => Ok(Self::Rwkv(RwkvPredictor::from_model(model))),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                Ok(Self::Rwkv(RwkvPredictor::from_method(&method)?))
            }
            RateBackend::Zpaq { method } => Ok(Self::Zpaq(ZpaqPredictor::new(method))),
            RateBackend::Mixture { spec } => {
                Ok(Self::Mixture(MixturePredictor::new(spec.as_ref())?))
            }
            RateBackend::Particle { spec } => Ok(Self::Particle(
                crate::particle::ParticleRuntime::new(spec.as_ref()),
            )),
            // Calibrated wraps a recursively-constructed base predictor with
            // a calibration core and a uniform initial pdf cache.
            RateBackend::Calibrated { spec } => Ok(Self::Calibrated {
                base: Box::new(Self::from_rate_backend(spec.base.clone(), max_order)?),
                core: build_calibrator(spec.as_ref()),
                pdf: vec![1.0 / 256.0; 256],
                valid: false,
            }),
        }
    }

    /// Prepares the predictor for a stream of `total_len` symbols.
    ///
    /// Always finalizes any previous stream first; backends that keep no
    /// per-stream state simply return `Ok(())`.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.finish_stream()?;
        match self {
            Self::Rosa(m) => {
                m.begin_stream(total_len);
                Ok(())
            }
            Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            Self::Sequitur { model } => {
                model.begin_stream(Some(total_len as u64));
                Ok(())
            }
            Self::Ctw(m) | Self::FacCtw(m) => {
                // Pre-size the context tree to avoid growth during coding.
                m.tree.reserve_for_symbols(total_len);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.begin_stream(total_len),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.begin_stream(total_len),
            Self::Mixture(m) => m.begin_stream(total_len),
            Self::Calibrated { base, .. } => base.begin_stream(total_len),
        }
    }

    /// Finalizes per-stream state; a no-op for most backends.
    fn finish_stream(&mut self) -> Result<()> {
        match self {
            Self::Rosa(_)
            | Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Sequitur { .. }
            | Self::Ctw(_)
            | Self::FacCtw(_)
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(_) => Ok(()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.finish_stream(),
            Self::Mixture(m) => m.finish_stream(),
            Self::Calibrated { base, .. } => base.finish_stream(),
        }
    }

    /// Returns the probability distribution over the next symbol.
    ///
    /// For `Calibrated`, the base pdf is pulled lazily, passed through the
    /// calibrator, normalized, and cached until the next `update`.
    fn pdf_next(&mut self) -> Result<&[f64]> {
        match self {
            Self::Rosa(m) => Ok(m.pdf_next()),
            Self::Match { model } => Ok(model.pdf()),
            Self::Ctw(m) => Ok(m.pdf_next()),
            Self::FacCtw(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Ok(m.pdf_next()),
            Self::Zpaq(m) => Ok(m.pdf_next()),
            Self::Mixture(m) => m.ensure_pdf(),
            Self::Particle(m) => Ok(m.pdf_next()),
            Self::SparseMatch { model } => Ok(model.pdf()),
            Self::Ppmd { model } => Ok(model.pdf()),
            Self::Sequitur { model } => Ok(model.pdf()),
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                if !*valid {
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                    *valid = true;
                }
                Ok(pdf)
            }
        }
    }

    /// Advances the model with an observed `symbol`.
    ///
    /// For `Calibrated`, the calibrated pdf for the current step is
    /// (re)computed if stale so the calibrator core can learn from it, then
    /// both the core and the base are updated and the cache is invalidated.
    fn update(&mut self, symbol: u8) -> Result<()> {
        match self {
            Self::Rosa(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::Match { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::SparseMatch { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Ppmd { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Sequitur { model } => {
                model.update(symbol);
                Ok(())
            }
            Self::Ctw(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::FacCtw(m) => {
                m.update(symbol);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.update(symbol),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.update(symbol),
            Self::Zpaq(m) => {
                m.update(symbol);
                Ok(())
            }
            Self::Mixture(m) => m.update(symbol),
            Self::Particle(m) => {
                m.step(symbol);
                Ok(())
            }
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                if !*valid {
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                }
                core.update(symbol, pdf);
                base.update(symbol)?;
                *valid = false;
                Ok(())
            }
        }
    }

    /// Forces the backend to materialize its CDF so that subsequent
    /// `cached_cdf_bit_prob_one_msb` calls can read it cheaply.
    ///
    /// Returns `Ok(true)` when this backend supports the cached-CDF bitwise
    /// path, `Ok(false)` otherwise.
    fn prepare_cached_cdf_fast_bitwise(&mut self) -> Result<bool> {
        match self {
            Self::Rosa(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            Self::Match { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            Self::SparseMatch { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            Self::Ppmd { model } => {
                let _ = model.cdf();
                Ok(true)
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => {
                let _ = m.cdf_next();
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// P(next MSB-order bit = 1) given the current symbol range `[lo, hi)`,
    /// read from the backend's cached CDF. `None` for backends without the
    /// cached-CDF path (see `prepare_cached_cdf_fast_bitwise`).
    fn cached_cdf_bit_prob_one_msb(&mut self, lo: usize, hi: usize) -> Option<f64> {
        match self {
            Self::Rosa(m) => Some(cdf_bit_prob_one_msb(&m.cdf, lo, hi)),
            Self::Match { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            Self::SparseMatch { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            Self::Ppmd { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
            _ => None,
        }
    }

    /// Whether this predictor can code a byte as eight binary AC decisions
    /// without building a dense 256-symbol CDF.
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        match self {
            Self::Ctw(m) => m.can_fast_ac_bitwise(),
            Self::Mixture(m) => m.can_fast_ac_bitwise(),
            _ => false,
        }
    }

    /// Codes one byte bit-by-bit (MSB first). `choose_bit(bit_idx, p1)`
    /// resolves each bit from the model's P(bit = 1); the model is updated
    /// with each resolved bit. Only valid when `can_fast_ac_bitwise()`.
    fn ac_step_fast_bitwise<F>(&mut self, choose_bit: F) -> Result<u8>
    where
        F: FnMut(usize, f64) -> Result<u8>,
    {
        match self {
            Self::Ctw(m) => ctw_ac_step_bitwise(m, choose_bit),
            Self::Mixture(m) => m.ac_step_bitwise(choose_bit),
            _ => unreachable!("fast bitwise path requested for unsupported predictor"),
        }
    }

    /// Diagnostic snapshot for one subtree of a mixture hierarchy.
    ///
    /// Non-mixture backends contribute a single leaf row holding their
    /// probability for `symbol` (floored at PDF_MIN) and the given weights.
    /// NOTE(review): this reads `pdf_next` without calling `update` — the
    /// caller appears responsible for advancing the model; confirm.
    fn diagnostic_snapshot_subtree(
        &mut self,
        symbol: u8,
        local_weight: f64,
        effective_weight: f64,
        pool: Option<&ThreadPool>,
    ) -> Result<AcLogLossSubtreeSnapshot> {
        match self {
            Self::Mixture(m) => {
                m.diagnostic_subtree_snapshot(symbol, local_weight, effective_weight, pool)
            }
            _ => {
                let prob = self.pdf_next()?[symbol as usize].max(PDF_MIN);
                Ok(AcLogLossSubtreeSnapshot {
                    prob,
                    rows: vec![AcLogLossNodeValue {
                        prob,
                        local_weight,
                        effective_weight,
                    }],
                })
            }
        }
    }

    /// Root-level diagnostic snapshot; only mixtures can serve as the root.
    fn diagnostic_root_snapshot(
        &mut self,
        symbol: u8,
        pool: Option<&ThreadPool>,
        out: &mut Vec<AcLogLossNodeValue>,
    ) -> Result<AcLogLossRootSnapshot> {
        match self {
            Self::Mixture(m) => m.diagnostic_root_snapshot(symbol, pool, out),
            _ => anyhow::bail!("AC log-loss diagnostics require a top-level mixture backend"),
        }
    }

    /// Encodes one symbol into the arithmetic encoder and advances the model.
    ///
    /// Uses the bitwise fast path when available (eight binary splits, the
    /// model updating itself per bit); otherwise quantizes the full pdf into
    /// an integer CDF, encodes the symbol, and calls `update`.
    fn encode_symbol_ac_step<W: std::io::Write>(
        &mut self,
        symbol: u8,
        encoder: &mut ArithmeticEncoder<W>,
    ) -> Result<()> {
        if self.can_fast_ac_bitwise() {
            self.ac_step_fast_bitwise(|bit_idx, p1_mix| {
                // Bits are taken MSB-first; the split point partitions
                // [0, CDF_TOTAL) into the bit-0 and bit-1 masses.
                let bit = (symbol >> (7 - bit_idx)) & 1;
                let split = binary_split_from_prob_one(p1_mix);
                if bit == 0 {
                    encoder.encode_counts(0, split as u64, CDF_TOTAL as u64)?;
                } else {
                    encoder.encode_counts(split as u64, CDF_TOTAL as u64, CDF_TOTAL as u64)?;
                }
                Ok(bit)
            })?;
            return Ok(());
        }

        let pdf = self.pdf_next()?;
        let mut cdf = vec![0u32; 257];
        crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
            pdf, CDF_TOTAL, &mut cdf,
        );
        let sym = symbol as usize;
        encoder.encode_counts(cdf[sym] as u64, cdf[sym + 1] as u64, CDF_TOTAL as u64)?;
        self.update(symbol)
    }
}
1781
1782fn ctw_ac_step_bitwise<F>(ctw: &mut CtwPredictor, mut choose_bit: F) -> Result<u8>
1783where
1784 F: FnMut(usize, f64) -> Result<u8>,
1785{
1786 debug_assert!(ctw.can_fast_ac_bitwise());
1787 let mut symbol = 0u8;
1788 for bit_idx in 0..8usize {
1789 let p1 = ctw.bit_prob_one_msb(bit_idx);
1790 let bit = choose_bit(bit_idx, p1)? & 1;
1791 symbol |= bit << (7 - bit_idx);
1792 ctw.update_bit_msb(bit_idx, bit == 1);
1793 }
1794 Ok(symbol)
1795}
1796
1797#[inline]
1798fn binary_split_from_prob_one(p1: f64) -> u32 {
1799 let p1 = p1.clamp(PDF_MIN, 1.0 - PDF_MIN);
1800 let p0 = 1.0 - p1;
1801 let mut split = (p0 * (CDF_TOTAL as f64)) as u32;
1802 if split == 0 {
1803 split = 1;
1804 } else if split >= CDF_TOTAL {
1805 split = CDF_TOTAL - 1;
1806 }
1807 split
1808}
1809
1810fn encode_payload_ac(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
1811 predictor.begin_stream(data.len())?;
1812 let mut out = Vec::new();
1813 {
1814 let mut enc = ArithmeticEncoder::new(&mut out);
1815 for &symbol in data {
1816 predictor.encode_symbol_ac_step(symbol, &mut enc)?;
1817 }
1818 let _ = enc.finish()?;
1819 }
1820 predictor.finish_stream()?;
1821 Ok(out)
1822}
1823
/// Decodes `out_len` symbols from an arithmetic-coded `payload`.
///
/// If the predictor supports the fast bitwise path, each byte is decoded as
/// eight binary decisions; otherwise a dense 256-symbol CDF is quantized for
/// every output byte and the model is updated explicitly.
fn decode_payload_ac(
    payload: &[u8],
    out_len: usize,
    predictor: &mut RatePdfPredictor,
) -> Result<Vec<u8>> {
    predictor.begin_stream(out_len)?;
    if predictor.can_fast_ac_bitwise() {
        let mut dec = ArithmeticDecoder::new(payload)?;
        let mut out = Vec::with_capacity(out_len);
        for _ in 0..out_len {
            // The closure maps the model's P(bit = 1) onto a two-symbol CDF
            // and asks the decoder which side the coded stream falls on;
            // the predictor updates itself with each resolved bit.
            let symbol = predictor.ac_step_fast_bitwise(|_, p1_mix| {
                let split = binary_split_from_prob_one(p1_mix);
                let cdf = [0u32, split, CDF_TOTAL];
                Ok(dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8)
            })?;
            out.push(symbol);
        }
        predictor.finish_stream()?;
        return Ok(out);
    }

    let mut dec = ArithmeticDecoder::new(payload)?;
    let mut out = Vec::with_capacity(out_len);
    // Reused quantization buffer for the per-symbol integer CDF.
    let mut cdf = vec![0u32; 257];
    for _ in 0..out_len {
        let pdf = predictor.pdf_next()?;
        crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
            pdf, CDF_TOTAL, &mut cdf,
        );
        let sym = dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8;
        out.push(sym);
        predictor.update(sym)?;
    }
    predictor.finish_stream()?;
    Ok(out)
}
1860
/// Encodes `data` with blocked rANS, using `predictor` for per-symbol models.
///
/// Output layout: `u32` little-endian block count, then for each block a
/// `u32` little-endian length followed by the block bytes (the format
/// `decode_payload_rans` parses).
fn encode_payload_rans(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
    predictor.begin_stream(data.len())?;
    let mut encoder = BlockedRansEncoder::new();
    // Reused buffers for the per-symbol quantized CDF and frequencies.
    let mut cdf = vec![0u32; 257];
    let mut freq = vec![0i64; 256];

    for &b in data {
        let pdf = predictor.pdf_next()?;
        quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
        let s = b as usize;
        encoder.encode(Cdf::new(cdf[s], cdf[s + 1], ANS_TOTAL));
        // Model must observe the symbol in the same order the decoder will.
        predictor.update(b)?;
    }

    let blocks = encoder.finish();
    let mut out = Vec::new();
    out.extend_from_slice(&(blocks.len() as u32).to_le_bytes());
    for block in blocks {
        out.extend_from_slice(&(block.len() as u32).to_le_bytes());
        out.extend_from_slice(&block);
    }
    predictor.finish_stream()?;
    Ok(out)
}
1885
1886fn decode_payload_rans(
1887 payload: &[u8],
1888 out_len: usize,
1889 predictor: &mut RatePdfPredictor,
1890) -> Result<Vec<u8>> {
1891 predictor.begin_stream(out_len)?;
1892 if payload.len() < 4 {
1893 bail!("rANS payload too short");
1894 }
1895 let block_count = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]) as usize;
1896 let mut pos = 4usize;
1897 let mut blocks = Vec::with_capacity(block_count);
1898 for _ in 0..block_count {
1899 if pos + 4 > payload.len() {
1900 bail!("truncated rANS block header");
1901 }
1902 let len = u32::from_le_bytes([
1903 payload[pos],
1904 payload[pos + 1],
1905 payload[pos + 2],
1906 payload[pos + 3],
1907 ]) as usize;
1908 pos += 4;
1909 if pos + len > payload.len() {
1910 bail!("truncated rANS block data");
1911 }
1912 blocks.push(&payload[pos..pos + len]);
1913 pos += len;
1914 }
1915
1916 let mut dec = BlockedRansDecoder::new(blocks, out_len)?;
1917 let mut out = Vec::with_capacity(out_len);
1918 let mut cdf = vec![0u32; 257];
1919 let mut freq = vec![0i64; 256];
1920
1921 for _ in 0..out_len {
1922 let pdf = predictor.pdf_next()?;
1923 quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
1924 let sym = dec.decode(&cdf)? as u8;
1925 out.push(sym);
1926 predictor.update(sym)?;
1927 }
1928 predictor.finish_stream()?;
1929 Ok(out)
1930}
1931
1932pub fn compress_rate_bytes(
1937 data: &[u8],
1938 rate_backend: &RateBackend,
1939 max_order: i64,
1940 coder: CoderType,
1941 framing: FramingMode,
1942) -> Result<Vec<u8>> {
1943 let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
1944 let payload = match coder {
1945 CoderType::AC => encode_payload_ac(data, &mut predictor)?,
1946 CoderType::RANS => encode_payload_rans(data, &mut predictor)?,
1947 };
1948
1949 if framing == FramingMode::Raw {
1950 return Ok(payload);
1951 }
1952
1953 let mut out = Vec::with_capacity(FramedHeader::SIZE + payload.len());
1954 let hdr = FramedHeader::new(coder, data.len() as u64, crc32(data));
1955 hdr.write(&mut out);
1956 out.extend_from_slice(&payload);
1957 Ok(out)
1958}
1959
1960pub fn compress_rate_size(
1962 data: &[u8],
1963 rate_backend: &RateBackend,
1964 max_order: i64,
1965 coder: CoderType,
1966 framing: FramingMode,
1967) -> Result<u64> {
1968 let encoded = compress_rate_bytes(data, rate_backend, max_order, coder, framing)?;
1969 Ok(encoded.len() as u64)
1970}
1971
1972pub fn compress_rate_size_chain(
1974 parts: &[&[u8]],
1975 rate_backend: &RateBackend,
1976 max_order: i64,
1977 coder: CoderType,
1978 framing: FramingMode,
1979) -> Result<u64> {
1980 let total = parts.iter().map(|p| p.len()).sum();
1981 let mut data = Vec::with_capacity(total);
1982 for p in parts {
1983 data.extend_from_slice(p);
1984 }
1985 compress_rate_size(&data, rate_backend, max_order, coder, framing)
1986}
1987
1988pub fn decompress_rate_bytes(
1990 input: &[u8],
1991 rate_backend: &RateBackend,
1992 max_order: i64,
1993 _coder: CoderType,
1994 framing: FramingMode,
1995) -> Result<Vec<u8>> {
1996 let (payload, coder, out_len, expected_crc) = if framing == FramingMode::Framed {
1997 let hdr = FramedHeader::read(input)?;
1998 (
1999 &input[FramedHeader::SIZE..],
2000 hdr.coder_type(),
2001 hdr.original_len as usize,
2002 Some(hdr.crc32),
2003 )
2004 } else {
2005 bail!("raw payload decompression requires explicit output length and is not supported");
2006 };
2007
2008 let _ = coder;
2009 let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
2010 let decoded = match coder {
2011 CoderType::AC => decode_payload_ac(payload, out_len, &mut predictor)?,
2012 CoderType::RANS => decode_payload_rans(payload, out_len, &mut predictor)?,
2013 };
2014
2015 if let Some(crc) = expected_crc {
2016 let got = crc32(&decoded);
2017 if got != crc {
2018 bail!("CRC32 mismatch: expected 0x{crc:08X}, got 0x{got:08X}");
2019 }
2020 }
2021
2022 Ok(decoded)
2023}
2024
2025fn normalize_pdf(pdf: &mut [f64]) {
2026 let mut sum = 0.0;
2027 for p in pdf.iter_mut() {
2028 *p = if p.is_finite() {
2029 (*p).max(PDF_MIN)
2030 } else {
2031 PDF_MIN
2032 };
2033 sum += *p;
2034 }
2035 if !(sum.is_finite()) || sum <= 0.0 {
2036 let u = 1.0 / (pdf.len() as f64);
2037 for p in pdf.iter_mut() {
2038 *p = u;
2039 }
2040 return;
2041 }
2042 let inv = 1.0 / sum;
2043 for p in pdf.iter_mut() {
2044 *p *= inv;
2045 }
2046}
2047
/// CDF row of the uniform byte distribution: entry `i` holds `i / 256`.
#[inline]
fn uniform_cdf_row() -> [f64; 257] {
    std::array::from_fn(|i| (i as f64) / 256.0)
}
2057
/// Prefix-sums the first 256 entries of `pdf` into `cdf[1..]`; `cdf[0]` is
/// always 0. Panics if `pdf` has fewer than 256 entries.
#[inline]
fn build_cdf_row_from_pdf_slice(pdf: &[f64], cdf: &mut [f64; 257]) {
    cdf[0] = 0.0;
    let mut running = 0.0;
    for (idx, slot) in cdf.iter_mut().enumerate().skip(1) {
        running += pdf[idx - 1];
        *slot = running;
    }
}
2067
2068fn normalize_pdf_vec_and_maybe_build_cdf(pdf: &mut [f64], mut cdf: Option<&mut [f64; 257]>) {
2069 let mut sum = 0.0;
2070 for p in pdf.iter_mut() {
2071 *p = if p.is_finite() {
2072 (*p).max(PDF_MIN)
2073 } else {
2074 PDF_MIN
2075 };
2076 sum += *p;
2077 }
2078 if !(sum.is_finite()) || sum <= 0.0 {
2079 let u = 1.0 / (pdf.len() as f64);
2080 pdf.fill(u);
2081 if let Some(cdf) = cdf.as_deref_mut() {
2082 *cdf = uniform_cdf_row();
2083 }
2084 return;
2085 }
2086 let inv = 1.0 / sum;
2087 if let Some(cdf) = cdf.as_deref_mut() {
2088 cdf[0] = 0.0;
2089 let mut acc = 0.0;
2090 for i in 0..256 {
2091 pdf[i] *= inv;
2092 acc += pdf[i];
2093 cdf[i + 1] = acc;
2094 }
2095 } else {
2096 for p in pdf.iter_mut() {
2097 *p *= inv;
2098 }
2099 }
2100}
2101
2102#[inline]
2103fn cdf_bit_prob_one_msb(cdf: &[f64; 257], lo: usize, hi: usize) -> f64 {
2104 let mid = (lo + hi) >> 1;
2105 let total = (cdf[hi] - cdf[lo]).max(PDF_MIN);
2106 let one = (cdf[hi] - cdf[mid]).max(0.0);
2107 (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
2108}
2109
/// Numerically stable log-sum-exp over `vals` (max-shifted).
/// Returns the non-finite maximum directly (e.g. `-inf` for an empty slice).
#[inline]
fn logsumexp_slice(vals: &[f64]) -> f64 {
    let peak = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        return peak;
    }
    let total: f64 = vals.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
2127
2128#[inline]
2129fn logsumexp_expert_weights(experts: &[MixExpert]) -> f64 {
2130 let mut m = f64::NEG_INFINITY;
2131 for e in experts {
2132 if e.log_weight > m {
2133 m = e.log_weight;
2134 }
2135 }
2136 if !m.is_finite() {
2137 return m;
2138 }
2139 let mut s = 0.0;
2140 for e in experts {
2141 s += (e.log_weight - m).exp();
2142 }
2143 m + s.ln()
2144}
2145
/// Projects `weights` onto the probability simplex by sanitizing and
/// renormalizing: non-finite or negative entries become 0, then everything
/// is divided by the sum. A degenerate sum yields the uniform distribution;
/// an empty slice is left untouched.
fn normalize_simplex_weights(weights: &mut [f64]) {
    if weights.is_empty() {
        return;
    }
    let mut total = 0.0;
    for w in weights.iter_mut() {
        if !w.is_finite() || *w < 0.0 {
            *w = 0.0;
        }
        total += *w;
    }
    if total.is_finite() && total > 0.0 {
        for w in weights.iter_mut() {
            *w /= total;
        }
    } else {
        weights.fill(1.0 / (weights.len() as f64));
    }
}
2166
2167fn normalized_mix_expert_prior_weights(experts: &[MixExpert], out: &mut [f64]) {
2168 debug_assert_eq!(experts.len(), out.len());
2169 let max_log = experts
2170 .iter()
2171 .map(|expert| expert.log_prior)
2172 .fold(f64::NEG_INFINITY, f64::max);
2173 for (slot, expert) in out.iter_mut().zip(experts.iter()) {
2174 *slot = if max_log.is_finite() {
2175 (expert.log_prior - max_log).exp()
2176 } else {
2177 0.0
2178 };
2179 }
2180 normalize_simplex_weights(out);
2181}
2182
2183fn set_mix_expert_log_weights_from_linear(experts: &mut [MixExpert], weights: &[f64]) {
2184 for (expert, &weight) in experts.iter_mut().zip(weights.iter()) {
2185 expert.log_weight = if weight > 0.0 {
2186 weight.ln()
2187 } else {
2188 f64::NEG_INFINITY
2189 };
2190 }
2191}
2192
/// Applies a switching (transition) step to the experts' posterior weights.
///
/// `posterior` holds the experts' current linear-domain weights (normalized
/// in place here); `prior_weights` the normalized prior simplex; `alpha`
/// the per-update switching rate; `scratch` a work buffer of the same
/// length. The result is written back into each expert's `log_weight`.
///
/// Each expert keeps `(1 - alpha)` of its own mass and receives
/// `alpha * prior[i]` of the mass switching out of the *other* experts,
/// where an expert's switch-out mass is scaled by `1 / (1 - prior)` so that
/// redistribution respects the prior.
fn apply_switching_weights(
    experts: &mut [MixExpert],
    prior_weights: &[f64],
    alpha: f64,
    posterior: &mut [f64],
    scratch: &mut [f64],
) {
    if experts.is_empty() {
        return;
    }
    debug_assert_eq!(experts.len(), prior_weights.len());

    normalize_simplex_weights(posterior);
    // With one expert or a zero switching rate there is nothing to
    // redistribute: the normalized posterior is final.
    if experts.len() == 1 || alpha <= 0.0 {
        set_mix_expert_log_weights_from_linear(experts, posterior);
        return;
    }

    // Switching needs at least two experts that can receive mass (prior < 1);
    // otherwise the posterior again stands as-is.
    let num_switch_targets = prior_weights.iter().filter(|&&prior| prior < 1.0).count();
    if num_switch_targets <= 1 {
        set_mix_expert_log_weights_from_linear(experts, posterior);
        return;
    }

    // Total switch-out mass, with each expert's contribution scaled by
    // 1 / (1 - prior_i).
    let mut switch_out_sum = 0.0;
    for i in 0..experts.len() {
        let denom = 1.0 - prior_weights[i];
        if denom > 0.0 {
            switch_out_sum += posterior[i] / denom;
        }
    }

    for i in 0..experts.len() {
        let prior = prior_weights[i];
        // Mass the expert keeps for itself.
        let stay = (1.0 - alpha) * posterior[i];
        // Mass switching in from all *other* experts (own contribution
        // subtracted back out of the pooled sum), weighted by the prior.
        let switch_in = if prior > 0.0 {
            let denom = 1.0 - prior;
            let switchable_mass = if denom > 0.0 {
                switch_out_sum - posterior[i] / denom
            } else {
                0.0
            };
            alpha * prior * switchable_mass
        } else {
            0.0
        };
        scratch[i] = stay + switch_in;
    }

    normalize_simplex_weights(scratch);
    set_mix_expert_log_weights_from_linear(experts, scratch);
}
2245
// NOTE(review): never called; presumably exists only to keep the
// `ZpaqRateModel` import referenced — confirm before removing.
#[allow(dead_code)]
fn _zpaq_marker(_: &ZpaqRateModel) {}
2248
2249#[cfg(test)]
2250mod tests {
2251 use super::*;
2252 use std::sync::Arc;
2253
2254 fn assert_pdf_close(lhs: &[f64], rhs: &[f64], tol: f64) {
2255 assert_eq!(lhs.len(), rhs.len());
2256 for (idx, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
2257 let delta = (a - b).abs();
2258 assert!(
2259 delta <= tol,
2260 "pdf mismatch at symbol {idx}: lhs={a} rhs={b} delta={delta}"
2261 );
2262 }
2263 }
2264
2265 fn brute_force_pdf(predictor: &mut CtwPredictor) -> Vec<f64> {
2266 let bits = predictor.bits_per_symbol.clamp(1, 8);
2267 let mut out = vec![0.0; 256];
2268
2269 if bits == 8 {
2270 for (sym, slot) in out.iter_mut().enumerate().take(256usize) {
2271 *slot = predictor.log_prob_symbol_bruteforce(sym as u8).exp();
2272 }
2273 } else {
2274 let patterns = 1usize << bits;
2275 let aliases = 1usize << (8 - bits);
2276 let mut pat_prob = vec![0.0; patterns];
2277 for (pat, value) in pat_prob.iter_mut().enumerate() {
2278 let symbol = if predictor.msb_first {
2279 (pat as u8) << (8 - bits)
2280 } else {
2281 pat as u8
2282 };
2283 *value = predictor.log_prob_symbol_bruteforce(symbol).exp();
2284 }
2285 for (byte, slot) in out.iter_mut().enumerate().take(256usize) {
2286 let pat = if predictor.msb_first {
2287 byte >> (8 - bits)
2288 } else {
2289 byte & (patterns - 1)
2290 };
2291 *slot = pat_prob[pat] / (aliases as f64);
2292 }
2293 }
2294
2295 CtwPredictor::normalize_pdf(&mut out);
2296 out
2297 }
2298
2299 #[test]
2300 fn ctw_pdf_fast_matches_bruteforce() {
2301 let mut predictor = CtwPredictor::new_ctw(6);
2302 for &b in b"ctw fast-path regression corpus 1234567890" {
2303 predictor.update(b);
2304 }
2305
2306 let fast = predictor.pdf_next().to_vec();
2307 predictor.valid = false;
2308 let brute = brute_force_pdf(&mut predictor);
2309
2310 for i in 0..256usize {
2311 let delta = (fast[i] - brute[i]).abs();
2312 assert!(
2313 delta < 1e-12,
2314 "symbol={i} fast={} brute={} delta={delta}",
2315 fast[i],
2316 brute[i]
2317 );
2318 }
2319 }
2320
2321 #[test]
2322 fn fac_pdf_fast_matches_bruteforce_subbyte() {
2323 let mut predictor = CtwPredictor::new_fac(5, 5);
2324 for &b in b"fac ctw subbyte regression corpus abcdefghijklmnopqrstuvwxyz" {
2325 predictor.update(b);
2326 }
2327
2328 let fast = predictor.pdf_next().to_vec();
2329 predictor.valid = false;
2330 let brute = brute_force_pdf(&mut predictor);
2331
2332 for i in 0..256usize {
2333 let delta = (fast[i] - brute[i]).abs();
2334 assert!(
2335 delta < 1e-12,
2336 "symbol={i} fast={} brute={} delta={delta}",
2337 fast[i],
2338 brute[i]
2339 );
2340 }
2341 }
2342
2343 fn assert_ctw_pdf_next_preserves_state(mut predictor: CtwPredictor) {
2344 for &b in b"ctw predictor state preservation payload" {
2345 predictor.update(b);
2346 }
2347 let mut before_p0 = [0.0f64; 8];
2348 let mut before_p1 = [0.0f64; 8];
2349 for bit_idx in 0..8usize {
2350 before_p0[bit_idx] = predictor.tree.predict(false, bit_idx);
2351 before_p1[bit_idx] = predictor.tree.predict(true, bit_idx);
2352 }
2353 let log_before = predictor.tree.get_log_block_probability();
2354 let _ = predictor.pdf_next();
2355 let log_after = predictor.tree.get_log_block_probability();
2356 assert!(
2357 (log_before - log_after).abs() < 1e-12,
2358 "log drift: before={log_before} after={log_after}"
2359 );
2360 for bit_idx in 0..8usize {
2361 let after_p0 = predictor.tree.predict(false, bit_idx);
2362 let after_p1 = predictor.tree.predict(true, bit_idx);
2363 assert!(
2364 (before_p0[bit_idx] - after_p0).abs() < 1e-12,
2365 "bit {bit_idx} p0 drift: {} vs {}",
2366 before_p0[bit_idx],
2367 after_p0
2368 );
2369 assert!(
2370 (before_p1[bit_idx] - after_p1).abs() < 1e-12,
2371 "bit {bit_idx} p1 drift: {} vs {}",
2372 before_p1[bit_idx],
2373 after_p1
2374 );
2375 }
2376 }
2377
2378 #[test]
2379 fn ctw_pdf_next_preserves_state() {
2380 assert_ctw_pdf_next_preserves_state(CtwPredictor::new_ctw(7));
2381 }
2382
2383 #[test]
2384 fn fac_pdf_next_preserves_state() {
2385 assert_ctw_pdf_next_preserves_state(CtwPredictor::new_fac(7, 8));
2386 }
2387
2388 fn assert_fill_pattern_preserves_symbol_log_probs(mut predictor: CtwPredictor) {
2389 for &b in b"fill-pattern preservation regression payload" {
2390 predictor.update(b);
2391 }
2392 let mut baseline = [0.0f64; 256];
2393 for (sym, slot) in baseline.iter_mut().enumerate() {
2394 *slot = predictor.log_prob_symbol_bruteforce(sym as u8);
2395 }
2396 let _ = predictor.fill_pattern_log_probs();
2397 for (sym, &expected) in baseline.iter().enumerate() {
2398 let got = predictor.log_prob_symbol_bruteforce(sym as u8);
2399 let diff = (expected - got).abs();
2400 assert!(
2401 diff < 1e-12,
2402 "symbol={sym} expected={expected} got={got} diff={diff}"
2403 );
2404 }
2405 }
2406
2407 #[test]
2408 fn ctw_fill_pattern_preserves_symbol_log_probs() {
2409 assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_ctw(7));
2410 }
2411
2412 #[test]
2413 fn fac_fill_pattern_preserves_symbol_log_probs() {
2414 assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_fac(7, 8));
2415 }
2416
2417 fn assert_pdf_then_update_matches_plain_update(mut base: CtwPredictor) {
2418 for &b in b"pdf then update parity payload" {
2419 base.update(b);
2420 }
2421 let observed = b'n';
2422 let mut with_pdf = base.clone();
2423 let mut plain = base;
2424
2425 let _ = with_pdf.pdf_next();
2426 with_pdf.update(observed);
2427 plain.update(observed);
2428
2429 for sym in 0u8..=255u8 {
2430 let lp_with_pdf = with_pdf.log_prob_symbol_bruteforce(sym);
2431 let lp_plain = plain.log_prob_symbol_bruteforce(sym);
2432 let diff = (lp_with_pdf - lp_plain).abs();
2433 assert!(
2434 diff < 1e-12,
2435 "symbol={sym} with_pdf={lp_with_pdf} plain={lp_plain} diff={diff}"
2436 );
2437 }
2438 }
2439
2440 #[test]
2441 fn ctw_pdf_then_update_matches_plain_update() {
2442 assert_pdf_then_update_matches_plain_update(CtwPredictor::new_ctw(7));
2443 }
2444
2445 #[test]
2446 fn fac_pdf_then_update_matches_plain_update() {
2447 assert_pdf_then_update_matches_plain_update(CtwPredictor::new_fac(7, 8));
2448 }
2449
2450 #[test]
2451 fn roundtrip_rate_ac_ctw() {
2452 let data = b"ctw backend roundtrip payload";
2453 let backend = RateBackend::Ctw { depth: 8 };
2454 let enc =
2455 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2456 let dec =
2457 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2458 assert_eq!(dec, data);
2459 }
2460
2461 #[test]
2462 fn roundtrip_rate_ac_match_family_and_ppmd() {
2463 let data = b"repeat repeat repeat sparse sparse repeat payload";
2464 for backend in [
2465 RateBackend::Match {
2466 hash_bits: 20,
2467 min_len: 4,
2468 max_len: 255,
2469 base_mix: 0.02,
2470 confidence_scale: 1.0,
2471 },
2472 RateBackend::SparseMatch {
2473 hash_bits: 19,
2474 min_len: 3,
2475 max_len: 64,
2476 gap_min: 1,
2477 gap_max: 2,
2478 base_mix: 0.05,
2479 confidence_scale: 1.0,
2480 },
2481 RateBackend::Ppmd {
2482 order: 8,
2483 memory_mb: 8,
2484 },
2485 ] {
2486 let enc = compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed)
2487 .unwrap();
2488 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
2489 .unwrap();
2490 assert_eq!(dec, data);
2491 }
2492 }
2493
2494 #[test]
2495 fn roundtrip_rate_ac_ppmd_high_order_text_payload() {
2496 let seed = include_bytes!("../../README.md");
2497 let mut data = Vec::with_capacity(4096);
2498 while data.len() < 4096 {
2499 data.extend_from_slice(seed);
2500 }
2501 data.truncate(4096);
2502
2503 let backend = RateBackend::Ppmd {
2504 order: 12,
2505 memory_mb: 256,
2506 };
2507 let enc = compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed)
2508 .expect("ppmd high-order compression");
2509 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
2510 .expect("ppmd high-order decompression");
2511 assert_eq!(dec, data);
2512 }
2513
2514 #[test]
2515 fn roundtrip_rate_ac_calibrated_backend() {
2516 let data = b"calibration wrapper payload calibration wrapper payload";
2517 let backend = RateBackend::Calibrated {
2518 spec: Arc::new(crate::CalibratedSpec {
2519 base: RateBackend::Ctw { depth: 8 },
2520 context: crate::CalibrationContextKind::Text,
2521 bins: 33,
2522 learning_rate: 0.02,
2523 bias_clip: 4.0,
2524 }),
2525 };
2526 let enc =
2527 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2528 let dec =
2529 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2530 assert_eq!(dec, data);
2531 }
2532
2533 #[test]
2534 fn roundtrip_rate_ac_single_expert_ctw_neural_mixture() {
2535 let data = b"single expert neural ctw fast path payload";
2536 let spec = MixtureSpec::new(
2537 MixtureKind::Neural,
2538 vec![crate::MixtureExpertSpec {
2539 name: Some("ctw".to_string()),
2540 log_prior: 0.0,
2541 max_order: -1,
2542 backend: RateBackend::Ctw { depth: 8 },
2543 }],
2544 )
2545 .with_alpha(0.03);
2546 let backend = RateBackend::Mixture {
2547 spec: Arc::new(spec),
2548 };
2549 let enc =
2550 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2551 let dec =
2552 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2553 assert_eq!(dec, data);
2554 }
2555
2556 #[test]
2557 fn roundtrip_rate_ac_single_expert_ctw_bayes_mixture() {
2558 let data = b"single expert bayes ctw fast path payload";
2559 let spec = MixtureSpec::new(
2560 MixtureKind::Bayes,
2561 vec![crate::MixtureExpertSpec {
2562 name: Some("ctw".to_string()),
2563 log_prior: 0.0,
2564 max_order: -1,
2565 backend: RateBackend::Ctw { depth: 8 },
2566 }],
2567 )
2568 .with_alpha(0.03);
2569 let backend = RateBackend::Mixture {
2570 spec: Arc::new(spec),
2571 };
2572 let enc =
2573 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2574 let dec =
2575 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2576 assert_eq!(dec, data);
2577 }
2578
2579 #[test]
2580 fn roundtrip_rate_rans_recursive_mixture() {
2581 let data = b"recursive mixture payload";
2582 let nested = MixtureSpec::new(
2583 MixtureKind::Bayes,
2584 vec![
2585 crate::MixtureExpertSpec {
2586 name: Some("ctw".to_string()),
2587 log_prior: 0.0,
2588 max_order: -1,
2589 backend: RateBackend::Ctw { depth: 6 },
2590 },
2591 crate::MixtureExpertSpec {
2592 name: Some("fac".to_string()),
2593 log_prior: 0.0,
2594 max_order: -1,
2595 backend: RateBackend::FacCtw {
2596 base_depth: 6,
2597 num_percept_bits: 8,
2598 encoding_bits: 8,
2599 },
2600 },
2601 ],
2602 );
2603 let root = MixtureSpec::new(
2604 MixtureKind::Switching,
2605 vec![
2606 crate::MixtureExpertSpec {
2607 name: Some("nested".to_string()),
2608 log_prior: 0.0,
2609 max_order: -1,
2610 backend: RateBackend::Mixture {
2611 spec: Arc::new(nested),
2612 },
2613 },
2614 crate::MixtureExpertSpec {
2615 name: Some("zpaq".to_string()),
2616 log_prior: 0.0,
2617 max_order: -1,
2618 backend: RateBackend::Zpaq {
2619 method: "1".to_string(),
2620 },
2621 },
2622 ],
2623 )
2624 .with_alpha(0.05);
2625
2626 let backend = RateBackend::Mixture {
2627 spec: Arc::new(root),
2628 };
2629 let enc =
2630 compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
2631 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
2632 .unwrap();
2633 assert_eq!(dec, data);
2634 }
2635
2636 #[test]
2637 fn roundtrip_rate_ac_recursive_neural_mixture() {
2638 let data = b"neural recursive mixture payload for ac coder";
2639 let inner = MixtureSpec::new(
2640 MixtureKind::Bayes,
2641 vec![
2642 crate::MixtureExpertSpec {
2643 name: Some("ctw".to_string()),
2644 log_prior: 0.0,
2645 max_order: -1,
2646 backend: RateBackend::Ctw { depth: 6 },
2647 },
2648 crate::MixtureExpertSpec {
2649 name: Some("fac".to_string()),
2650 log_prior: 0.0,
2651 max_order: -1,
2652 backend: RateBackend::FacCtw {
2653 base_depth: 6,
2654 num_percept_bits: 8,
2655 encoding_bits: 8,
2656 },
2657 },
2658 ],
2659 );
2660 let root = MixtureSpec::new(
2661 MixtureKind::Neural,
2662 vec![
2663 crate::MixtureExpertSpec {
2664 name: Some("nested".to_string()),
2665 log_prior: 0.0,
2666 max_order: -1,
2667 backend: RateBackend::Mixture {
2668 spec: Arc::new(inner),
2669 },
2670 },
2671 crate::MixtureExpertSpec {
2672 name: Some("zpaq".to_string()),
2673 log_prior: 0.0,
2674 max_order: -1,
2675 backend: RateBackend::Zpaq {
2676 method: "1".to_string(),
2677 },
2678 },
2679 ],
2680 )
2681 .with_alpha(0.03);
2682
2683 let backend = RateBackend::Mixture {
2684 spec: Arc::new(root),
2685 };
2686 let enc =
2687 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2688 let dec =
2689 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2690 assert_eq!(dec, data);
2691 }
2692
2693 fn assert_runtime_and_compression_predictor_align(spec: MixtureSpec, data: &[u8], tol: f64) {
2694 let backend = RateBackend::Mixture {
2695 spec: Arc::new(spec.clone()),
2696 };
2697 let mut predictor = RatePdfPredictor::from_rate_backend(backend, -1).unwrap();
2698 let experts = spec.build_experts();
2699 let mut runtime = crate::mixture::build_mixture_runtime(&spec, &experts).unwrap();
2700
2701 for (t, &symbol) in data.iter().enumerate() {
2702 let pdf = predictor.pdf_next().unwrap();
2703 let p_comp = pdf[symbol as usize];
2704 let p_runtime = runtime.peek_log_prob(symbol).exp();
2705 assert!(
2706 (p_comp - p_runtime).abs() < tol,
2707 "t={t} p_comp={p_comp} p_runtime={p_runtime} symbol={symbol}"
2708 );
2709 predictor.update(symbol).unwrap();
2710 runtime.step(symbol);
2711 }
2712 }
2713
2714 fn alignment_experts() -> Vec<crate::MixtureExpertSpec> {
2715 vec![
2716 crate::MixtureExpertSpec {
2717 name: Some("ctw".to_string()),
2718 log_prior: 0.0,
2719 max_order: -1,
2720 backend: RateBackend::Ctw { depth: 7 },
2721 },
2722 crate::MixtureExpertSpec {
2723 name: Some("fac".to_string()),
2724 log_prior: -0.7,
2725 max_order: -1,
2726 backend: RateBackend::FacCtw {
2727 base_depth: 7,
2728 num_percept_bits: 8,
2729 encoding_bits: 8,
2730 },
2731 },
2732 ]
2733 }
2734
2735 #[test]
2736 fn bayes_runtime_and_compression_predictor_align() {
2737 let spec = MixtureSpec::new(MixtureKind::Bayes, alignment_experts());
2738 assert_runtime_and_compression_predictor_align(
2739 spec,
2740 b"bayes predictor alignment check sequence",
2741 1e-8,
2742 );
2743 }
2744
2745 #[test]
2746 fn fading_runtime_and_compression_predictor_align() {
2747 let spec = MixtureSpec::new(MixtureKind::FadingBayes, alignment_experts()).with_decay(0.97);
2748 assert_runtime_and_compression_predictor_align(
2749 spec,
2750 b"fading predictor alignment check sequence",
2751 1e-8,
2752 );
2753 }
2754
2755 #[test]
2756 fn switching_runtime_and_compression_predictor_align() {
2757 let spec = MixtureSpec::new(MixtureKind::Switching, alignment_experts()).with_alpha(0.17);
2758 assert_runtime_and_compression_predictor_align(
2759 spec,
2760 b"switching predictor alignment check sequence",
2761 1e-8,
2762 );
2763 }
2764
2765 #[test]
2766 fn switching_theorem_runtime_and_compression_predictor_align() {
2767 let spec = MixtureSpec::new(MixtureKind::Switching, alignment_experts())
2768 .with_schedule(MixtureScheduleMode::Theorem)
2769 .with_alpha(0.91);
2770 assert_runtime_and_compression_predictor_align(
2771 spec,
2772 b"switching theorem predictor alignment check sequence",
2773 1e-8,
2774 );
2775 }
2776
2777 #[test]
2778 fn convex_runtime_and_compression_predictor_align_for_alpha_above_one() {
2779 let spec = MixtureSpec::new(MixtureKind::Convex, alignment_experts()).with_alpha(1.25);
2780 assert_runtime_and_compression_predictor_align(
2781 spec,
2782 b"convex predictor alignment check sequence",
2783 1e-8,
2784 );
2785 }
2786
2787 #[test]
2788 fn convex_theorem_runtime_and_compression_predictor_align() {
2789 let spec = MixtureSpec::new(MixtureKind::Convex, alignment_experts())
2790 .with_schedule(MixtureScheduleMode::Theorem)
2791 .with_alpha(7.5);
2792 assert_runtime_and_compression_predictor_align(
2793 spec,
2794 b"convex theorem predictor alignment check sequence",
2795 1e-8,
2796 );
2797 }
2798
2799 #[test]
2800 fn neural_runtime_and_compression_predictor_align() {
2801 let spec = MixtureSpec::new(MixtureKind::Neural, alignment_experts()).with_alpha(0.03);
2802 assert_runtime_and_compression_predictor_align(
2803 spec,
2804 b"neural alignment check sequence",
2805 1e-8,
2806 );
2807 }
2808
2809 #[test]
2810 fn mdl_runtime_and_compression_predictor_align() {
2811 let spec = MixtureSpec::new(MixtureKind::Mdl, alignment_experts());
2812 assert_runtime_and_compression_predictor_align(spec, b"mdl alignment check sequence", 1e-8);
2813 }
2814
2815 #[test]
2816 fn nested_runtime_and_compression_predictor_align() {
2817 let nested = MixtureSpec::new(MixtureKind::Bayes, alignment_experts());
2818 let spec = MixtureSpec::new(
2819 MixtureKind::Switching,
2820 vec![
2821 crate::MixtureExpertSpec {
2822 name: Some("nested".to_string()),
2823 log_prior: 0.0,
2824 max_order: -1,
2825 backend: RateBackend::Mixture {
2826 spec: Arc::new(nested),
2827 },
2828 },
2829 crate::MixtureExpertSpec {
2830 name: Some("ppmd".to_string()),
2831 log_prior: -0.2,
2832 max_order: -1,
2833 backend: RateBackend::Ppmd {
2834 order: 5,
2835 memory_mb: 8,
2836 },
2837 },
2838 ],
2839 )
2840 .with_alpha(0.13);
2841 assert_runtime_and_compression_predictor_align(
2842 spec,
2843 b"nested mixture predictor alignment check sequence",
2844 1e-8,
2845 );
2846 }
2847
2848 fn assert_cached_cdf_fast_bitwise_matches_pdf_rows(mut predictor: RatePdfPredictor) {
2849 let data = b"cached cdf parity check payload";
2850 for &symbol in data {
2851 let pdf = predictor.pdf_next().unwrap().to_vec();
2852 assert!(predictor.prepare_cached_cdf_fast_bitwise().unwrap());
2853
2854 let mut row = [0.0; 257];
2855 row[0] = 0.0;
2856 for i in 0..256 {
2857 row[i + 1] = row[i] + pdf[i].max(PDF_MIN);
2858 }
2859
2860 let mut stack = vec![(0usize, 256usize)];
2861 while let Some((lo, hi)) = stack.pop() {
2862 if hi - lo <= 1 {
2863 continue;
2864 }
2865 let expected = cdf_bit_prob_one_msb(&row, lo, hi);
2866 let got = predictor
2867 .cached_cdf_bit_prob_one_msb(lo, hi)
2868 .expect("cached cdf branch probability");
2869 let diff = (expected - got).abs();
2870 assert!(
2871 diff <= 1e-12,
2872 "lo={lo} hi={hi} expected={expected} got={got} diff={diff}"
2873 );
2874 let mid = (lo + hi) >> 1;
2875 stack.push((lo, mid));
2876 stack.push((mid, hi));
2877 }
2878
2879 predictor.update(symbol).unwrap();
2880 }
2881 }
2882
2883 #[test]
2884 fn cached_cdf_fast_bitwise_matches_pdf_rows_for_specialized_predictors() {
2885 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2886 RatePdfPredictor::from_rate_backend(RateBackend::RosaPlus, -1).unwrap(),
2887 );
2888 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2889 RatePdfPredictor::from_rate_backend(
2890 RateBackend::Ppmd {
2891 order: 6,
2892 memory_mb: 8,
2893 },
2894 -1,
2895 )
2896 .unwrap(),
2897 );
2898 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2899 RatePdfPredictor::from_rate_backend(
2900 RateBackend::Match {
2901 hash_bits: 20,
2902 min_len: 4,
2903 max_len: 255,
2904 base_mix: 0.02,
2905 confidence_scale: 1.0,
2906 },
2907 -1,
2908 )
2909 .unwrap(),
2910 );
2911 #[cfg(feature = "backend-rwkv")]
2912 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2913 RatePdfPredictor::from_rate_backend(
2914 RateBackend::Rwkv7Method {
2915 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2916 },
2917 -1,
2918 )
2919 .unwrap(),
2920 );
2921 #[cfg(feature = "backend-mamba")]
2922 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2923 RatePdfPredictor::from_rate_backend(
2924 RateBackend::MambaMethod {
2925 method: "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2926 },
2927 -1,
2928 )
2929 .unwrap(),
2930 );
2931 }
2932
2933 #[test]
2934 fn raw_size_not_larger_than_framed_size() {
2935 let data = b"raw/framed size check payload";
2936 let backend = RateBackend::RosaPlus;
2937 let raw = compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Raw).unwrap();
2938 let framed =
2939 compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Framed).unwrap();
2940 assert!(framed >= raw);
2941 }
2942
2943 #[cfg(feature = "backend-rwkv")]
2944 #[test]
2945 fn roundtrip_rate_rwkv_method_cfg() {
2946 let data = b"rwkv cfg method backend";
2947 let backend = RateBackend::Rwkv7Method {
2948 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2949 };
2950 let enc =
2951 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2952 let dec =
2953 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2954 assert_eq!(dec, data);
2955 }
2956
2957 #[cfg(feature = "backend-rwkv")]
2958 #[test]
2959 fn rwkv_rate_predictor_preserves_backend_pdf_exactly() {
2960 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
2961 let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
2962 let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
2963 let mut direct = vec![0.0; backend.vocab_size()];
2964
2965 let predicted = predictor.pdf_next().to_vec();
2966 backend.forward_to_pdf(0, &mut direct);
2967 assert_pdf_close(&predicted, &direct, 1e-18);
2968
2969 predictor.update(b'x').expect("predictor update");
2970 backend
2971 .online_update_from_pdf(b'x', &direct)
2972 .expect("backend update");
2973 backend.forward_to_pdf(u32::from(b'x'), &mut direct);
2974 assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
2975 }
2976
2977 #[cfg(feature = "backend-rwkv")]
2978 #[test]
2979 fn rwkv_rate_predictor_matches_backend_after_partial_tbptt_stream() {
2980 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=29,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
2981 let data = b"abcdefghij";
2982 let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
2983 let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
2984 let mut direct = vec![0.0; backend.vocab_size()];
2985
2986 predictor
2987 .begin_stream(data.len())
2988 .expect("begin predictor stream");
2989 backend
2990 .begin_online_policy_stream(Some(data.len() as u64))
2991 .expect("begin backend stream");
2992 backend.reset_and_prime();
2993
2994 for &byte in data {
2995 let predicted = predictor.pdf_next().to_vec();
2996 backend.copy_current_pdf_to(&mut direct);
2997 assert_pdf_close(&predicted, &direct, 1e-18);
2998
2999 predictor.update(byte).expect("predictor update");
3000 backend
3001 .observe_symbol_from_current_pdf(byte)
3002 .expect("backend update");
3003 }
3004
3005 predictor.finish_stream().expect("finish predictor stream");
3006 backend
3007 .finish_online_policy_stream()
3008 .expect("finish backend stream");
3009 backend.copy_current_pdf_to(&mut direct);
3010 assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
3011 }
3012
3013 #[cfg(feature = "backend-rwkv")]
3014 #[test]
3015 fn roundtrip_rate_rwkv_two_json_method_2m() {
3016 let two_json: serde_json::Value =
3017 serde_json::from_str(include_str!("../../examples/two.json")).unwrap();
3018 let method = two_json["experts"]
3019 .as_array()
3020 .unwrap()
3021 .iter()
3022 .find(|expert| expert["name"].as_str() == Some("rwkv"))
3023 .and_then(|expert| expert["method"].as_str())
3024 .unwrap()
3025 .to_string();
3026
3027 let backend = RateBackend::Rwkv7Method { method };
3028 let seed = include_bytes!("../../README.md");
3029 let target_len = 2_097_152usize;
3030 let mut data = Vec::with_capacity(target_len);
3031 while data.len() < target_len {
3032 let remaining = target_len - data.len();
3033 data.extend_from_slice(&seed[..seed.len().min(remaining)]);
3034 }
3035
3036 let enc =
3037 compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3038 let dec =
3039 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3040 assert_eq!(dec, data);
3041 }
3042
3043 #[cfg(feature = "backend-mamba")]
3044 #[test]
3045 fn mamba_rate_predictor_preserves_backend_pdf_exactly() {
3046 let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
3047 let mut predictor = MambaPredictor::from_method(method).expect("mamba predictor");
3048 let mut backend = mambazip::Compressor::new_from_method(method).expect("mamba backend");
3049 let mut direct = vec![0.0; backend.vocab_size()];
3050
3051 let predicted = predictor.pdf_next().to_vec();
3052 backend.forward_to_pdf(0, &mut direct);
3053 assert_pdf_close(&predicted, &direct, 1e-18);
3054
3055 predictor.update(b'x').expect("predictor update");
3056 backend
3057 .online_update_from_pdf(b'x', &direct)
3058 .expect("backend update");
3059 backend.forward_to_pdf(u32::from(b'x'), &mut direct);
3060 assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
3061 }
3062
3063 #[test]
3064 fn roundtrip_rate_ac_particle() {
3065 let spec = crate::ParticleSpec {
3066 num_particles: 4,
3067 num_cells: 4,
3068 cell_dim: 8,
3069 num_rules: 2,
3070 selector_hidden: 16,
3071 rule_hidden: 16,
3072 context_window: 8,
3073 unroll_steps: 1,
3074 ..crate::ParticleSpec::default()
3075 };
3076 let data = b"particle ac roundtrip payload";
3077 let backend = RateBackend::Particle {
3078 spec: Arc::new(spec),
3079 };
3080 let enc =
3081 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3082 let dec =
3083 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3084 assert_eq!(dec, data);
3085 }
3086
3087 #[test]
3088 fn roundtrip_rate_rans_particle() {
3089 let spec = crate::ParticleSpec {
3090 num_particles: 4,
3091 num_cells: 4,
3092 cell_dim: 8,
3093 num_rules: 2,
3094 selector_hidden: 16,
3095 rule_hidden: 16,
3096 context_window: 8,
3097 unroll_steps: 1,
3098 ..crate::ParticleSpec::default()
3099 };
3100 let data = b"particle rans roundtrip payload";
3101 let backend = RateBackend::Particle {
3102 spec: Arc::new(spec),
3103 };
3104 let enc =
3105 compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
3106 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
3107 .unwrap();
3108 assert_eq!(dec, data);
3109 }
3110
3111 #[test]
3112 fn mixture_with_particle_expert_roundtrip() {
3113 let particle_spec = crate::ParticleSpec {
3114 num_particles: 4,
3115 num_cells: 4,
3116 cell_dim: 8,
3117 num_rules: 2,
3118 selector_hidden: 16,
3119 rule_hidden: 16,
3120 context_window: 8,
3121 unroll_steps: 1,
3122 ..crate::ParticleSpec::default()
3123 };
3124 let spec = MixtureSpec::new(
3125 MixtureKind::Bayes,
3126 vec![
3127 crate::MixtureExpertSpec {
3128 name: Some("particle".to_string()),
3129 log_prior: 0.0,
3130 max_order: -1,
3131 backend: RateBackend::Particle {
3132 spec: Arc::new(particle_spec),
3133 },
3134 },
3135 crate::MixtureExpertSpec {
3136 name: Some("ctw".to_string()),
3137 log_prior: 0.0,
3138 max_order: -1,
3139 backend: RateBackend::Ctw { depth: 6 },
3140 },
3141 ],
3142 );
3143 let backend = RateBackend::Mixture {
3144 spec: Arc::new(spec),
3145 };
3146 let data = b"mixture with particle expert roundtrip";
3147 let enc =
3148 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3149 let dec =
3150 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
3151 assert_eq!(dec, data);
3152 }
3153}