1use anyhow::{Result, bail};
9
10use crate::backends::calibration::CalibratorCore;
11use crate::backends::match_model::MatchModel;
12use crate::backends::ppmd::PpmdModel;
13use crate::backends::sparse_match::SparseMatchModel;
14use crate::backends::text_context::TextContextAnalyzer;
15use crate::coders::{
16 ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
17 CDF_TOTAL, Cdf, CoderType, crc32, quantize_pdf_to_rans_cdf_with_buffer,
18};
19use crate::ctw::FacContextTree;
20#[cfg(feature = "backend-mamba")]
21use crate::mambazip;
22use crate::mixture::DEFAULT_MIN_PROB;
23use crate::neural_mix::NeuralMixCore;
24use crate::rosaplus::RosaPlus;
25#[cfg(feature = "backend-rwkv")]
26use crate::rwkvzip;
27use crate::zpaq_rate::ZpaqRateModel;
28use crate::{CalibratedSpec, MixtureKind, MixtureSpec, RateBackend};
29
30const FRAMED_MAGIC: u32 = 0x4354_4946; const FRAMED_VERSION: u8 = 1;
32const PDF_MIN: f64 = DEFAULT_MIN_PROB;
33
/// Build a `CalibratorCore` from the tuning knobs carried in a `CalibratedSpec`
/// (context size, bin count, learning rate, and bias clip).
#[inline]
fn build_calibrator(spec: &CalibratedSpec) -> CalibratorCore {
    CalibratorCore::new(spec.context, spec.bins, spec.learning_rate, spec.bias_clip)
}
38
/// How an encoded payload is wrapped on the wire.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FramingMode {
    /// Bare coder output with no header; length and integrity checking must
    /// come from the caller.
    Raw,
    /// Payload prefixed with a `FramedHeader` (magic, version, coder id,
    /// original length, CRC32).
    #[default]
    Framed,
}
48
/// Fixed-size header prepended to framed payloads (`FramedHeader::SIZE` bytes,
/// all multi-byte fields little-endian).
#[derive(Clone, Copy, Debug)]
struct FramedHeader {
    magic: u32,        // must equal FRAMED_MAGIC
    version: u8,       // must equal FRAMED_VERSION
    coder: u8,         // 0 = arithmetic coder, non-zero = rANS (see coder_type)
    original_len: u64, // uncompressed payload length in bytes
    crc32: u32,        // checksum of the original (uncompressed) data
}
57
58impl FramedHeader {
59 const SIZE: usize = 4 + 1 + 1 + 8 + 4;
60
61 fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
62 Self {
63 magic: FRAMED_MAGIC,
64 version: FRAMED_VERSION,
65 coder: match coder {
66 CoderType::AC => 0,
67 CoderType::RANS => 1,
68 },
69 original_len,
70 crc32,
71 }
72 }
73
74 fn write(&self, out: &mut Vec<u8>) {
75 out.extend_from_slice(&self.magic.to_le_bytes());
76 out.push(self.version);
77 out.push(self.coder);
78 out.extend_from_slice(&self.original_len.to_le_bytes());
79 out.extend_from_slice(&self.crc32.to_le_bytes());
80 }
81
82 fn read(input: &[u8]) -> Result<Self> {
83 if input.len() < Self::SIZE {
84 bail!("framed payload too short");
85 }
86 let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
87 if magic != FRAMED_MAGIC {
88 bail!("invalid framed magic: expected 0x{FRAMED_MAGIC:08X}, got 0x{magic:08X}");
89 }
90 let version = input[4];
91 if version != FRAMED_VERSION {
92 bail!("unsupported framed version: {version}");
93 }
94 let coder = input[5];
95 let original_len = u64::from_le_bytes([
96 input[6], input[7], input[8], input[9], input[10], input[11], input[12], input[13],
97 ]);
98 let crc32 = u32::from_le_bytes([input[14], input[15], input[16], input[17]]);
99 Ok(Self {
100 magic,
101 version,
102 coder,
103 original_len,
104 crc32,
105 })
106 }
107
108 fn coder_type(&self) -> CoderType {
109 match self.coder {
110 0 => CoderType::AC,
111 _ => CoderType::RANS,
112 }
113 }
114}
115
/// Byte-level predictor backed by a (factored) context-tree-weighting model.
#[derive(Clone)]
struct CtwPredictor {
    tree: FacContextTree,    // underlying per-bit CTW model
    bits_per_symbol: usize,  // bits consumed per byte (8 for plain CTW, <=8 for FAC)
    msb_first: bool,         // bit order within a byte (true for plain CTW, false for FAC)
    pdf: Vec<f64>,           // cached 256-entry next-byte distribution
    pattern_logps: Vec<f64>, // scratch: log P(pattern) for each bit pattern
    valid: bool,             // whether `pdf` reflects the current tree state
}
125
impl CtwPredictor {
    /// Plain CTW over whole bytes: depth-`depth` tree, 8 bits per symbol,
    /// most-significant bit first.
    fn new_ctw(depth: usize) -> Self {
        Self {
            tree: FacContextTree::new(depth, 8),
            bits_per_symbol: 8,
            msb_first: true,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// Factored CTW: `bits_per_symbol` bits per byte, least-significant bit first.
    fn new_fac(base_depth: usize, bits_per_symbol: usize) -> Self {
        Self {
            tree: FacContextTree::new(base_depth, bits_per_symbol),
            bits_per_symbol,
            msb_first: false,
            pdf: vec![0.0; 256],
            pattern_logps: vec![f64::NEG_INFINITY; 256],
            valid: false,
        }
    }

    /// Fill `pattern_logps[..patterns]` with the conditional log-probability of
    /// every `bits`-bit pattern by tentatively updating the tree down each
    /// branch of the bit trie and reverting on the way back up. Returns the
    /// number of patterns (`1 << bits`).
    fn fill_pattern_log_probs(&mut self) -> usize {
        // Depth-first walk over all bit patterns. `log_before` is the tree's
        // block log-probability before any tentative update, so the leaf value
        // is the conditional log-prob of the full pattern.
        fn rec(
            tree: &mut FacContextTree,
            bits: usize,
            msb_first: bool,
            depth: usize,
            pattern: usize,
            log_before: f64,
            out: &mut [f64],
        ) {
            if depth == bits {
                out[pattern] = tree.get_log_block_probability() - log_before;
                return;
            }
            for bit in [false, true] {
                tree.update(bit, depth);
                // Pattern index construction must mirror the bit order used by
                // `update(symbol)` below.
                let next_pattern = if msb_first {
                    (pattern << 1) | (bit as usize)
                } else {
                    pattern | ((bit as usize) << depth)
                };
                rec(
                    tree,
                    bits,
                    msb_first,
                    depth + 1,
                    next_pattern,
                    log_before,
                    out,
                );
                // Undo the tentative update before exploring the sibling bit.
                tree.revert(depth);
            }
        }

        let bits = self.bits_per_symbol.clamp(1, 8);
        let patterns = 1usize << bits;
        let log_before = self.tree.get_log_block_probability();
        self.pattern_logps[..patterns].fill(f64::NEG_INFINITY);
        rec(
            &mut self.tree,
            bits,
            self.msb_first,
            0,
            0,
            log_before,
            &mut self.pattern_logps[..patterns],
        );
        patterns
    }

    /// Test-only reference implementation: log-prob of one symbol computed by
    /// updating bit-by-bit and reverting, for cross-checking the trie walk.
    #[cfg(test)]
    fn log_prob_symbol_bruteforce(&mut self, symbol: u8) -> f64 {
        let bits = self.bits_per_symbol.clamp(1, 8);
        let before = self.tree.get_log_block_probability();
        if self.msb_first {
            for bit_idx in 0..bits {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        } else {
            for bit_idx in 0..bits {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
            let after = self.tree.get_log_block_probability();
            for bit_idx in (0..bits).rev() {
                self.tree.revert(bit_idx);
            }
            after - before
        }
    }

    /// Clamp each entry to at least PDF_MIN (non-finite values become the
    /// floor) and rescale to sum to 1. Falls back to uniform if the sum is
    /// degenerate.
    fn normalize_pdf(pdf: &mut [f64]) {
        let mut sum = 0.0f64;
        for p in pdf.iter_mut() {
            let v = if p.is_finite() { *p } else { 0.0 };
            *p = v.max(PDF_MIN);
            sum += *p;
        }
        if sum <= 0.0 || !sum.is_finite() {
            let u = 1.0 / (pdf.len() as f64);
            for p in pdf.iter_mut() {
                *p = u;
            }
            return;
        }
        let inv = 1.0 / sum;
        for p in pdf.iter_mut() {
            *p *= inv;
        }
    }

    /// Return the cached 256-entry next-byte pdf, recomputing it from the tree
    /// if an `update` has invalidated it.
    fn pdf_next(&mut self) -> &[f64] {
        if !self.valid {
            let bits = self.bits_per_symbol.clamp(1, 8);
            let patterns = self.fill_pattern_log_probs();
            if bits == 8 {
                for sym in 0..256usize {
                    self.pdf[sym] = self.pattern_logps[sym].exp();
                }
            } else {
                // Fewer than 8 modeled bits: each pattern covers several bytes;
                // spread its mass uniformly over the aliases.
                let aliases = 1usize << (8 - bits);
                for byte in 0..256usize {
                    let pat = if self.msb_first {
                        byte >> (8 - bits)
                    } else {
                        byte & (patterns - 1)
                    };
                    self.pdf[byte] = self.pattern_logps[pat].exp() / (aliases as f64);
                }
            }
            Self::normalize_pdf(&mut self.pdf);
            self.valid = true;
        }
        &self.pdf
    }

    /// Commit one observed byte to the tree (bit order per `msb_first`) and
    /// invalidate the cached pdf.
    fn update(&mut self, symbol: u8) {
        if self.msb_first {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        } else {
            for bit_idx in 0..self.bits_per_symbol {
                let bit = ((symbol >> bit_idx) & 1) == 1;
                self.tree.update(bit, bit_idx);
            }
        }
        self.valid = false;
    }

    /// Whether the fast per-bit arithmetic-coding path applies (full-byte,
    /// MSB-first configuration only).
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        self.bits_per_symbol == 8 && self.msb_first
    }

    /// Probability that the next bit (at `bit_idx`, MSB-first) is 1, clamped
    /// away from 0 and 1 for coder safety.
    #[inline]
    fn bit_prob_one_msb(&mut self, bit_idx: usize) -> f64 {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.predict_one(bit_idx).clamp(PDF_MIN, 1.0 - PDF_MIN)
    }

    /// Commit a single bit on the fast path and invalidate the cached pdf.
    #[inline]
    fn update_bit_msb(&mut self, bit_idx: usize, bit: bool) {
        debug_assert!(self.can_fast_ac_bitwise());
        self.tree.update_predicted(bit, bit_idx);
        self.valid = false;
    }
}
304
/// Byte-level predictor backed by the ROSA+ model, with lazily maintained
/// pdf/cdf caches.
#[derive(Clone)]
struct RosaPredictor {
    model: RosaPlus,  // the underlying language model
    pdf: Vec<f64>,    // cached 256-entry next-byte distribution
    cdf: [f64; 257],  // cached cumulative distribution over the pdf
    valid: bool,      // whether `pdf` is current
    cdf_valid: bool,  // whether `cdf` matches the current `pdf`
}
313
314impl RosaPredictor {
315 fn new(max_order: i64) -> Self {
316 let mut model = RosaPlus::new(max_order, false, 0, 42);
317 model.build_lm_full_bytes_no_finalize_endpos();
318 Self {
319 model,
320 pdf: vec![0.0; 256],
321 cdf: uniform_cdf_row(),
322 valid: false,
323 cdf_valid: false,
324 }
325 }
326
327 fn pdf_next(&mut self) -> &[f64] {
328 self.ensure_pdf(false);
329 &self.pdf
330 }
331
332 fn cdf_next(&mut self) -> &[f64; 257] {
333 self.ensure_pdf(true);
334 &self.cdf
335 }
336
337 fn ensure_pdf(&mut self, want_cdf: bool) {
338 if self.valid {
339 if want_cdf && !self.cdf_valid {
340 build_cdf_row_from_pdf_slice(&self.pdf, &mut self.cdf);
341 self.cdf_valid = true;
342 }
343 return;
344 }
345 self.model.fill_probs_for_last_bytes(&mut self.pdf);
346 normalize_pdf_vec_and_maybe_build_cdf(
347 &mut self.pdf,
348 if want_cdf { Some(&mut self.cdf) } else { None },
349 );
350 self.valid = true;
351 self.cdf_valid = want_cdf;
352 }
353
354 fn update(&mut self, symbol: u8) {
355 self.model.train_byte(symbol);
356 self.valid = false;
357 self.cdf_valid = false;
358 }
359
360 fn begin_stream(&mut self, total_len: usize) {
361 self.model.reserve_for_stream(total_len);
362 }
363}
364
/// Streaming byte predictor backed by a Mamba model, with pdf/cdf caches.
#[cfg(feature = "backend-mamba")]
#[derive(Clone)]
struct MambaPredictor {
    compressor: mambazip::Compressor, // model wrapper producing per-token pdfs
    primed: bool,                     // whether the first forward pass has run
    pdf: Vec<f64>,                    // cached vocab-sized next-token distribution
    cdf: [f64; 257],                  // cached cumulative row over the first 256 entries
    valid: bool,                      // whether `pdf` is current
    cdf_valid: bool,                  // whether `cdf` matches the current `pdf`
}
375
/// Streaming byte predictor backed by an RWKV model; the pdf lives in the
/// compressor's own buffer, only the cdf is cached here.
#[cfg(feature = "backend-rwkv")]
#[derive(Clone)]
struct RwkvPredictor {
    compressor: rwkvzip::Compressor, // model wrapper exposing `pdf_buffer`
    primed: bool,                    // whether reset_and_prime has run
    cdf: [f64; 257],                 // cached cumulative row over the first 256 entries
    cdf_valid: bool,                 // whether `cdf` matches the current pdf buffer
}
384
/// Byte-level predictor that scores candidates with a ZPAQ rate model replayed
/// over the full observed history.
#[derive(Clone)]
struct ZpaqPredictor {
    method: String,   // ZPAQ method string used to construct each model
    history: Vec<u8>, // all bytes observed so far (replayed per prediction)
    pdf: Vec<f64>,    // cached 256-entry next-byte distribution
    valid: bool,      // whether `pdf` reflects the current history
}
392
impl ZpaqPredictor {
    /// Predictor with empty history for the given ZPAQ method string.
    fn new(method: String) -> Self {
        Self {
            method,
            history: Vec::new(),
            pdf: vec![0.0; 256],
            valid: false,
        }
    }

    /// Next-byte pdf. For each of the 256 candidate symbols this constructs a
    /// fresh `ZpaqRateModel` and replays the entire history before querying
    /// `log_prob` — O(256 * history_len) per predicted byte.
    /// NOTE(review): the model construction and replay do not depend on `sym`;
    /// if `log_prob` does not mutate model state, they could be hoisted out of
    /// the loop for a ~256x speedup — confirm `log_prob`'s contract first.
    fn pdf_next(&mut self) -> &[f64] {
        if !self.valid {
            for sym in 0..256usize {
                let mut model = ZpaqRateModel::new(self.method.clone(), PDF_MIN);
                if !self.history.is_empty() {
                    let _ = model.update_and_score(&self.history);
                }
                let logp = model.log_prob(sym as u8);
                self.pdf[sym] = logp.exp().max(PDF_MIN);
            }
            normalize_pdf(&mut self.pdf);
            self.valid = true;
        }
        &self.pdf
    }

    /// Record one observed byte and invalidate the cached pdf.
    fn update(&mut self, symbol: u8) {
        self.history.push(symbol);
        self.valid = false;
    }
}
424
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Wrap an already-loaded shared model.
    fn from_model(model: std::sync::Arc<mambazip::Model>) -> Self {
        let compressor = mambazip::Compressor::new_from_model(model);
        let vocab = compressor.vocab_size();
        Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        }
    }

    /// Load a model described by a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = mambazip::Compressor::new_from_method(method)?;
        let vocab = compressor.vocab_size();
        Ok(Self {
            compressor,
            primed: false,
            pdf: vec![0.0; vocab],
            cdf: uniform_cdf_row(),
            valid: false,
            cdf_valid: false,
        })
    }

    /// Ensure `pdf` (and, when `want_cdf`, `cdf`) reflect the current model
    /// state. The first call primes the model with token 0; after that
    /// `update` keeps `pdf` fresh, so only the cdf row may need rebuilding.
    ///
    /// Refactor: the original had three verbatim copies of the cdf-build
    /// sequence (one per branch); behavior is identical with a single exit
    /// path.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.valid {
            if !self.primed {
                // First prediction: run one forward pass on token 0.
                self.compressor.forward_to_pdf(0, &mut self.pdf);
                self.primed = true;
            }
            self.valid = true;
            self.cdf_valid = false;
        }
        if want_cdf && !self.cdf_valid {
            debug_assert!(self.pdf.len() >= 256);
            build_cdf_row_from_pdf_slice(&self.pdf[..256], &mut self.cdf);
            self.cdf_valid = true;
        }
    }

    /// Next-token pdf (vocab-sized).
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.pdf
    }

    /// Next-byte cdf row over the first 256 pdf entries.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Online-update the model on the observed symbol, then forward once to
    /// refresh `pdf` for the next prediction.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.online_update_from_pdf(symbol, &self.pdf)?;
        self.compressor.forward_to_pdf(symbol as u32, &mut self.pdf);
        self.valid = true;
        self.cdf_valid = false;
        Ok(())
    }

    /// Start an online-policy stream of (approximately) `total_len` symbols.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }
}
507
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Wrap an already-loaded shared RWKV model.
    fn from_model(model: std::sync::Arc<rwkvzip::Model>) -> Self {
        Self {
            compressor: rwkvzip::Compressor::new_from_model(model),
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        }
    }

    /// Load an RWKV model described by a method string.
    fn from_method(method: &str) -> Result<Self> {
        let compressor = rwkvzip::Compressor::new_from_method(method)?;
        Ok(Self {
            compressor,
            primed: false,
            cdf: uniform_cdf_row(),
            cdf_valid: false,
        })
    }

    /// Prime the model on first use; rebuild the cached cdf row when asked for
    /// and stale.
    fn ensure_predicted(&mut self, want_cdf: bool) {
        if !self.primed {
            self.compressor.reset_and_prime();
            self.primed = true;
            self.cdf_valid = false;
        }
        if !want_cdf || self.cdf_valid {
            return;
        }
        debug_assert!(self.compressor.pdf_buffer.len() >= 256);
        build_cdf_row_from_pdf_slice(&self.compressor.pdf_buffer[..256], &mut self.cdf);
        self.cdf_valid = true;
    }

    /// Next-byte pdf, read straight from the compressor's internal buffer.
    fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predicted(false);
        &self.compressor.pdf_buffer
    }

    /// Next-byte cdf row derived from the compressor's pdf buffer.
    fn cdf_next(&mut self) -> &[f64; 257] {
        self.ensure_predicted(true);
        &self.cdf
    }

    /// Observe one symbol, advancing model state and invalidating the cdf.
    fn update(&mut self, symbol: u8) -> Result<()> {
        self.ensure_predicted(false);
        self.compressor.observe_symbol_from_current_pdf(symbol)?;
        self.cdf_valid = false;
        Ok(())
    }

    /// Start an online-policy stream of (approximately) `total_len` symbols.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        self.compressor
            .begin_online_policy_stream(Some(total_len as u64))
    }

    /// Close out the current online-policy stream.
    fn finish_stream(&mut self) -> Result<()> {
        self.compressor.finish_online_policy_stream()
    }
}
569
/// One expert inside a `MixturePredictor`: the predictor itself plus its
/// mixing state.
#[derive(Clone)]
struct MixExpert {
    predictor: Box<RatePdfPredictor>, // the expert model
    log_weight: f64,                  // current (normalized) log mixing weight
    log_prior: f64,                   // prior log weight (reused by Switching)
    cum_log_loss: f64,                // running code length this expert alone would pay
}
577
/// Mixture-of-experts predictor combining several `RatePdfPredictor`s under a
/// configurable mixing rule (`MixtureKind`).
#[derive(Clone)]
struct MixturePredictor {
    kind: MixtureKind,             // mixing rule in effect
    alpha: f64,                    // switch probability (MixtureKind::Switching)
    decay: f64,                    // weight decay factor (MixtureKind::FadingBayes)
    experts: Vec<MixExpert>,
    neural: NeuralMixCore,         // gating network (MixtureKind::Neural)
    analyzer: TextContextAnalyzer, // context features fed to the neural mixer
    neural_logps: Vec<f64>,        // scratch: per-expert log-probs (bit probs in the bitwise path)
    neural_bit_modes: Vec<u8>,     // per-expert bitwise mode: 0 fast CTW, 1 pdf row, 2 cached cdf
    neural_lo: Vec<usize>,         // per-expert bitwise symbol-interval low bound
    neural_hi: Vec<usize>,         // per-expert bitwise symbol-interval high bound
    neural_pdf_cdf_rows: Vec<Vec<f64>>, // per-expert cdf rows for the bitwise path
    scratch: Vec<f64>,             // reusable scratch (weights / log terms)
    scratch2: Vec<f64>,            // reusable scratch (accumulators)
    pdf: Vec<f64>,                 // cached mixed 256-entry pdf
    valid: bool,                   // whether `pdf` is current
}
596
impl MixturePredictor {
    /// Build a mixture from its spec: instantiate each expert from its
    /// backend, normalize the prior log-weights, and initialize the neural
    /// gating core with the (clamped) prior weights.
    fn new(spec: &MixtureSpec) -> Result<Self> {
        if spec.experts.is_empty() {
            bail!("mixture spec must include at least one expert");
        }
        let mut experts = Vec::with_capacity(spec.experts.len());
        for e in &spec.experts {
            experts.push(MixExpert {
                predictor: Box::new(RatePdfPredictor::from_rate_backend(
                    e.backend.clone(),
                    e.max_order,
                )?),
                log_weight: e.log_prior,
                log_prior: e.log_prior,
                cum_log_loss: 0.0,
            });
        }
        // Normalize log-weights so they sum (in probability space) to 1.
        let m = logsumexp_expert_weights(&experts);
        for e in &mut experts {
            e.log_weight -= m;
        }

        let mut prior_weights = vec![0.0; experts.len()];
        for (i, e) in experts.iter().enumerate() {
            let p = (e.log_weight).exp().clamp(PDF_MIN, 1.0 - PDF_MIN);
            prior_weights[i] = p;
        }

        // Learning rate for the neural mixer is derived from alpha; the *25
        // scale and the half-rate first argument are tuning choices.
        let base_lr = spec.alpha.abs().clamp(1e-6, 1.0);
        let effective_lr = (base_lr * 25.0).clamp(1e-6, 1.0);
        let analyzer = TextContextAnalyzer::new();
        let mut neural = NeuralMixCore::new(
            experts.len(),
            &prior_weights,
            effective_lr * 0.5,
            effective_lr,
            1e-5,
        );
        neural.set_context_state(analyzer.state());
        Ok(Self {
            kind: spec.kind,
            alpha: spec.alpha.clamp(1e-12, 1.0 - 1e-12),
            decay: spec.decay.unwrap_or(1.0).clamp(0.0, 1.0),
            experts,
            neural,
            analyzer,
            neural_logps: vec![0.0; spec.experts.len()],
            neural_bit_modes: vec![0; spec.experts.len()],
            neural_lo: vec![0; spec.experts.len()],
            neural_hi: vec![256; spec.experts.len()],
            neural_pdf_cdf_rows: vec![vec![0.0; 257]; spec.experts.len()],
            scratch: Vec::new(),
            scratch2: Vec::new(),
            pdf: vec![0.0; 256],
            valid: false,
        })
    }

    /// Return the mixed next-byte pdf, recomputing it if invalidated.
    /// Neural mixing uses the gating network's expert weights; every other
    /// kind mixes with the (normalized) Bayesian log-weights.
    fn ensure_pdf(&mut self) -> Result<&[f64]> {
        if self.valid {
            return Ok(&self.pdf);
        }
        match self.kind {
            MixtureKind::Neural => {
                // Single expert: pass its pdf straight through.
                if self.experts.len() == 1 {
                    self.pdf.fill(0.0);
                    let epdf = self.experts[0].predictor.pdf_next()?;
                    self.pdf.copy_from_slice(epdf);
                    normalize_pdf(&mut self.pdf);
                    self.valid = true;
                    return Ok(&self.pdf);
                }
                self.neural.set_context_state(self.analyzer.state());
                self.neural.evaluate_expert_weights();
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch.copy_from_slice(self.neural.expert_weights());
                self.pdf.fill(0.0);
                // Weighted sum of expert pdfs under the gating weights.
                for i in 0..n {
                    let epdf = self.experts[i].predictor.pdf_next()?;
                    let w = self.scratch[i];
                    for (pdf_slot, &p) in self.pdf.iter_mut().zip(epdf.iter()) {
                        *pdf_slot += w * p;
                    }
                }
                normalize_pdf(&mut self.pdf);
                self.valid = true;
                return Ok(&self.pdf);
            }
            _ => {
                self.pdf.fill(0.0);

                // Linear mixture under the current Bayesian weights.
                let lw_norm = logsumexp_expert_weights(&self.experts);
                for e in &mut self.experts {
                    let w = (e.log_weight - lw_norm).exp();
                    let epdf = e.predictor.pdf_next()?;
                    for (i, p) in epdf.iter().enumerate().take(256) {
                        self.pdf[i] += w * *p;
                    }
                }
            }
        }

        normalize_pdf(&mut self.pdf);
        self.valid = true;
        Ok(&self.pdf)
    }

    /// Forward stream-begin to every expert.
    /// NOTE(review): CTW/FacCtw experts are deliberately skipped here (their
    /// enum-level begin_stream would reserve tree storage) — confirm whether
    /// that reservation is intentionally avoided inside mixtures.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        for expert in &mut self.experts {
            match &mut *expert.predictor {
                RatePdfPredictor::Ctw(_) | RatePdfPredictor::FacCtw(_) => {}
                _ => expert.predictor.begin_stream(total_len)?,
            }
        }
        Ok(())
    }

    /// Observe one byte: score it under every expert, update the mixing
    /// weights according to `kind`, train each expert, and invalidate the
    /// cached mixed pdf.
    fn update(&mut self, symbol: u8) -> Result<()> {
        let _ = self.ensure_pdf()?;

        match self.kind {
            // Standard Bayesian mixture: posterior weight ∝ weight * p(symbol).
            MixtureKind::Bayes => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                    self.scratch2[i] = e.log_weight + lp;
                }
                let log_mix = logsumexp_slice(&self.scratch2[..n]);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.log_weight = e.log_weight + self.scratch[i] - log_mix;
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            // Bayes with exponential forgetting: prior weight is decayed
            // before the posterior update.
            MixtureKind::FadingBayes => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                    self.scratch2[i] = e.log_weight + lp;
                }
                for (i, e) in self.experts.iter_mut().enumerate() {
                    self.scratch2[i] = self.decay * e.log_weight + self.scratch[i];
                }
                let log_mix = logsumexp_slice(&self.scratch2[..n]);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.log_weight = self.decay * e.log_weight + self.scratch[i] - log_mix;
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            // Fixed-share style switching: with probability alpha, jump back
            // to the prior before applying the evidence.
            MixtureKind::Switching => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                self.scratch2.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                    self.scratch2[i] = e.log_weight + lp;
                }
                let log_alpha = self.alpha.ln();
                let log_1m_alpha = (1.0 - self.alpha).ln();
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let switched = logsumexp2(log_1m_alpha + e.log_weight, log_alpha + e.log_prior);
                    self.scratch2[i] = switched + self.scratch[i];
                }
                let log_mix = logsumexp_slice(&self.scratch2[..n]);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.log_weight = self.scratch2[i] - log_mix;
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            // MDL: only track each expert's cumulative code length; weights
            // are left untouched.
            MixtureKind::Mdl => {
                let n = self.experts.len();
                self.scratch.resize(n, 0.0);
                for (i, e) in self.experts.iter_mut().enumerate() {
                    let p = e.predictor.pdf_next()?[symbol as usize].max(PDF_MIN);
                    let lp = p.ln();
                    self.scratch[i] = lp;
                }
                for (i, e) in self.experts.iter_mut().enumerate() {
                    e.cum_log_loss -= self.scratch[i];
                    e.predictor.update(symbol)?;
                }
            }
            // Neural gating: feed per-expert log-probs to the gating core and
            // update its weights; context features advance afterwards.
            MixtureKind::Neural => {
                let y = symbol as usize;
                if self.experts.len() == 1 {
                    let lp = self.experts[0].predictor.pdf_next()?[y].max(PDF_MIN).ln();
                    self.experts[0].cum_log_loss -= lp;
                    self.experts[0].predictor.update(symbol)?;
                    self.analyzer.update(symbol);
                    self.neural.set_context_state(self.analyzer.state());
                    self.valid = false;
                    return Ok(());
                }
                let n = self.experts.len();
                self.neural.set_context_state(self.analyzer.state());
                self.neural_logps.resize(n, 0.0);
                for i in 0..n {
                    let p = self.experts[i].predictor.pdf_next()?[y].max(PDF_MIN);
                    let lp = p.ln();
                    self.neural_logps[i] = lp;
                    self.experts[i].cum_log_loss -= lp;
                }
                self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
                self.neural
                    .update_weights_symbol(&self.neural_logps, PDF_MIN);
                for e in &mut self.experts {
                    e.predictor.update(symbol)?;
                }
                self.analyzer.update(symbol);
                self.neural.set_context_state(self.analyzer.state());
            }
        }

        self.valid = false;
        Ok(())
    }

    /// Forward stream-finish to every expert.
    fn finish_stream(&mut self) -> Result<()> {
        for expert in &mut self.experts {
            expert.predictor.finish_stream()?;
        }
        Ok(())
    }

    /// The bitwise AC path is available if at least one expert is a CTW in its
    /// fast (8-bit, MSB-first) configuration.
    #[inline]
    fn can_fast_ac_bitwise(&self) -> bool {
        self.experts.iter().any(|e| {
            if let RatePdfPredictor::Ctw(ctw) = &*e.predictor {
                ctw.can_fast_ac_bitwise()
            } else {
                false
            }
        })
    }

    /// Code or decode one byte bit-by-bit (MSB first). `choose_bit(bit_idx,
    /// p1)` either encodes and echoes the known bit or decodes one from the
    /// mixed probability `p1`. Fast-path CTW experts predict bits natively;
    /// all others are narrowed through a cdf interval [lo, hi). Afterwards the
    /// per-expert symbol log-probs are reconstructed and the mixing weights
    /// updated exactly as in `update`.
    fn ac_step_bitwise<F>(&mut self, mut choose_bit: F) -> Result<u8>
    where
        F: FnMut(usize, f64) -> Result<u8>,
    {
        let n = self.experts.len();
        self.scratch.resize(n, 0.0);
        // scratch[i] <- current mixing weight of expert i.
        match self.kind {
            MixtureKind::Neural if n > 1 => {
                self.neural.set_context_state(self.analyzer.state());
                self.neural.evaluate_expert_weights();
                self.scratch.copy_from_slice(self.neural.expert_weights());
            }
            _ => {
                let lw_norm = logsumexp_expert_weights(&self.experts);
                for (i, expert) in self.experts.iter().enumerate() {
                    self.scratch[i] = (expert.log_weight - lw_norm).exp();
                }
            }
        }
        // scratch2[i] accumulates expert i's probability of the bits seen so
        // far this symbol (starts at 1).
        self.scratch2.resize(n, 1.0);
        self.scratch2.fill(1.0);
        self.neural_logps.resize(n, 0.0);
        self.neural_bit_modes.resize(n, 0);
        self.neural_lo.resize(n, 0);
        self.neural_hi.resize(n, 256);
        if self.neural_pdf_cdf_rows.len() < n {
            self.neural_pdf_cdf_rows.resize_with(n, || vec![0.0; 257]);
        }

        // Pick a per-expert bit mode: 0 = native CTW bit prediction,
        // 2 = backend-cached cdf, 1 = local cdf row built from the pdf.
        for i in 0..n {
            self.neural_bit_modes[i] = 1;
            self.neural_lo[i] = 0;
            self.neural_hi[i] = 256;

            let mut handled_ctw = false;
            if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor
                && ctw.can_fast_ac_bitwise()
            {
                self.neural_bit_modes[i] = 0;
                handled_ctw = true;
            }
            if handled_ctw {
                continue;
            }

            if self.experts[i]
                .predictor
                .prepare_cached_cdf_fast_bitwise()?
            {
                self.neural_bit_modes[i] = 2;
                continue;
            }

            // Mode 1: materialize a cumulative row from the expert's pdf;
            // fall back to uniform if the row is degenerate.
            let pdf = self.experts[i].predictor.pdf_next()?;
            let row = &mut self.neural_pdf_cdf_rows[i];
            if row.len() != 257 {
                row.resize(257, 0.0);
            }
            row[0] = 0.0;
            for b in 0..256usize {
                row[b + 1] = row[b] + pdf[b].max(PDF_MIN);
            }
            if !row[256].is_finite() || row[256] <= 0.0 {
                for (j, v) in row.iter_mut().enumerate() {
                    *v = (j as f64) / 256.0;
                }
            }
        }

        let mut symbol = 0u8;
        for bit_idx in 0..8usize {
            let mut denom = 0.0;
            let mut numer1 = 0.0;

            // Mix each expert's P(bit=1), weighted by mixing weight times its
            // running probability of the bits already fixed this symbol.
            for i in 0..n {
                let p1 = if self.neural_bit_modes[i] == 0 {
                    match &mut *self.experts[i].predictor {
                        RatePdfPredictor::Ctw(ctw) => ctw.bit_prob_one_msb(bit_idx),
                        _ => 0.5,
                    }
                } else if self.neural_bit_modes[i] == 2 {
                    self.experts[i]
                        .predictor
                        .cached_cdf_bit_prob_one_msb(self.neural_lo[i], self.neural_hi[i])
                        .unwrap_or(0.5)
                } else {
                    // P(bit=1) = mass of the upper half of [lo, hi).
                    let lo = self.neural_lo[i];
                    let hi = self.neural_hi[i];
                    let mid = (lo + hi) >> 1;
                    let row = &self.neural_pdf_cdf_rows[i];
                    let total = (row[hi] - row[lo]).max(PDF_MIN);
                    let one = (row[hi] - row[mid]).max(0.0);
                    (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
                };
                // neural_logps temporarily holds bit probabilities here, not
                // log-probs; it is overwritten with log-probs after the loop.
                self.neural_logps[i] = p1;
                let wp = self.scratch[i] * self.scratch2[i];
                denom += wp;
                numer1 += wp * p1;
            }

            let p1_mix = if denom.is_finite() && denom > 0.0 {
                (numer1 / denom).clamp(PDF_MIN, 1.0 - PDF_MIN)
            } else {
                0.5
            };
            let bit = choose_bit(bit_idx, p1_mix)? & 1;
            symbol |= bit << (7 - bit_idx);

            // Fold the realized bit into every expert's state.
            for i in 0..n {
                let p1 = self.neural_logps[i];
                let pb = if bit == 1 { p1 } else { 1.0 - p1 };
                self.scratch2[i] = (self.scratch2[i] * pb).max(PDF_MIN);

                if self.neural_bit_modes[i] == 0 {
                    if let RatePdfPredictor::Ctw(ctw) = &mut *self.experts[i].predictor {
                        ctw.update_bit_msb(bit_idx, bit == 1);
                    }
                } else {
                    // Narrow the symbol interval toward the chosen half.
                    let lo = self.neural_lo[i];
                    let hi = self.neural_hi[i];
                    let mid = (lo + hi) >> 1;
                    if bit == 1 {
                        self.neural_lo[i] = mid;
                        self.neural_hi[i] = hi;
                    } else {
                        self.neural_lo[i] = lo;
                        self.neural_hi[i] = mid;
                    }
                }
            }
        }

        // Reconstruct per-expert symbol log-probs and train the non-CTW
        // experts (fast-path CTW experts already absorbed all eight bits).
        for i in 0..n {
            let lp = self.scratch2[i].max(PDF_MIN).ln();
            self.neural_logps[i] = lp;
            self.experts[i].cum_log_loss -= lp;
            if self.neural_bit_modes[i] != 0 {
                self.experts[i].predictor.update(symbol)?;
            }
        }

        // Weight update mirrors `update` for each mixing kind.
        match self.kind {
            MixtureKind::Bayes => {
                for i in 0..n {
                    self.scratch[i] = self.experts[i].log_weight + self.neural_logps[i];
                }
                let log_mix = logsumexp_slice(&self.scratch[..n]);
                for i in 0..n {
                    self.experts[i].log_weight += self.neural_logps[i] - log_mix;
                }
            }
            MixtureKind::FadingBayes => {
                for i in 0..n {
                    self.scratch[i] =
                        self.decay * self.experts[i].log_weight + self.neural_logps[i];
                }
                let log_mix = logsumexp_slice(&self.scratch[..n]);
                for i in 0..n {
                    self.experts[i].log_weight = self.scratch[i] - log_mix;
                }
            }
            MixtureKind::Switching => {
                let log_alpha = self.alpha.ln();
                let log_1m_alpha = (1.0 - self.alpha).ln();
                for i in 0..n {
                    let expert = &self.experts[i];
                    let switched = logsumexp2(
                        log_1m_alpha + expert.log_weight,
                        log_alpha + expert.log_prior,
                    );
                    self.scratch[i] = switched + self.neural_logps[i];
                }
                let log_mix = logsumexp_slice(&self.scratch[..n]);
                for i in 0..n {
                    self.experts[i].log_weight = self.scratch[i] - log_mix;
                }
            }
            MixtureKind::Mdl => {}
            MixtureKind::Neural => {
                if n > 1 {
                    self.neural.set_context_state(self.analyzer.state());
                    self.neural.evaluate_symbol(&self.neural_logps, PDF_MIN);
                    self.neural
                        .update_weights_symbol(&self.neural_logps, PDF_MIN);
                }
                self.analyzer.update(symbol);
                self.neural.set_context_state(self.analyzer.state());
            }
        }
        self.valid = false;
        Ok(symbol)
    }
}
1040
/// A byte-level probability model usable as a rate backend; one variant per
/// supported backend kind.
#[derive(Clone)]
#[allow(clippy::large_enum_variant)]
enum RatePdfPredictor {
    /// ROSA+ model.
    Rosa(RosaPredictor),
    /// Contiguous match model.
    Match {
        model: MatchModel,
    },
    /// Gapped (sparse) match model.
    SparseMatch {
        model: SparseMatchModel,
    },
    /// PPMd model.
    Ppmd {
        model: PpmdModel,
    },
    /// Plain byte-level CTW (8 bits, MSB first).
    Ctw(CtwPredictor),
    /// Factored CTW over a subset of low-order bits (LSB first).
    FacCtw(CtwPredictor),
    #[cfg(feature = "backend-mamba")]
    Mamba(MambaPredictor),
    #[cfg(feature = "backend-rwkv")]
    Rwkv(RwkvPredictor),
    /// ZPAQ rate model replayed over history.
    Zpaq(ZpaqPredictor),
    /// Mixture of experts (recursive).
    Mixture(MixturePredictor),
    /// Particle-filter runtime.
    Particle(crate::particle::ParticleRuntime),
    /// Any backend wrapped with an online calibration layer.
    Calibrated {
        base: Box<RatePdfPredictor>,
        core: CalibratorCore,
        pdf: Vec<f64>,   // cached calibrated pdf
        valid: bool,     // whether `pdf` is current
    },
}
1070
1071impl RatePdfPredictor {
    /// Instantiate a predictor for the given backend spec. `max_order` is only
    /// consumed by the ROSA+ backend (and forwarded to calibrated bases); the
    /// other backends carry their parameters in their own variants.
    fn from_rate_backend(backend: RateBackend, max_order: i64) -> Result<Self> {
        match backend {
            RateBackend::RosaPlus => Ok(Self::Rosa(RosaPredictor::new(max_order))),
            RateBackend::Match {
                hash_bits,
                min_len,
                max_len,
                base_mix,
                confidence_scale,
            } => Ok(Self::Match {
                model: MatchModel::new_contiguous(
                    hash_bits,
                    min_len,
                    max_len,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::SparseMatch {
                hash_bits,
                min_len,
                max_len,
                gap_min,
                gap_max,
                base_mix,
                confidence_scale,
            } => Ok(Self::SparseMatch {
                model: SparseMatchModel::new(
                    hash_bits,
                    min_len,
                    max_len,
                    gap_min,
                    gap_max,
                    base_mix,
                    confidence_scale,
                ),
            }),
            RateBackend::Ppmd { order, memory_mb } => Ok(Self::Ppmd {
                model: PpmdModel::new(order, memory_mb),
            }),
            RateBackend::Ctw { depth } => Ok(Self::Ctw(CtwPredictor::new_ctw(depth))),
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _, // not used by this predictor
                encoding_bits,
            } => {
                // Clamp to a sane bit width; CtwPredictor assumes 1..=8.
                let bits = encoding_bits.clamp(1, 8);
                Ok(Self::FacCtw(CtwPredictor::new_fac(base_depth, bits)))
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => Ok(Self::Mamba(MambaPredictor::from_model(model))),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                Ok(Self::Mamba(MambaPredictor::from_method(&method)?))
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => Ok(Self::Rwkv(RwkvPredictor::from_model(model))),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => {
                Ok(Self::Rwkv(RwkvPredictor::from_method(&method)?))
            }
            RateBackend::Zpaq { method } => Ok(Self::Zpaq(ZpaqPredictor::new(method))),
            RateBackend::Mixture { spec } => {
                Ok(Self::Mixture(MixturePredictor::new(spec.as_ref())?))
            }
            RateBackend::Particle { spec } => Ok(Self::Particle(
                crate::particle::ParticleRuntime::new(spec.as_ref()),
            )),
            RateBackend::Calibrated { spec } => Ok(Self::Calibrated {
                base: Box::new(Self::from_rate_backend(spec.base.clone(), max_order)?),
                core: build_calibrator(spec.as_ref()),
                pdf: vec![1.0 / 256.0; 256],
                valid: false,
            }),
        }
    }
1148
    /// Prepare for a stream of `total_len` bytes. Closes out any in-flight
    /// stream first, then lets each backend pre-size its buffers; backends
    /// with no per-stream state are no-ops.
    fn begin_stream(&mut self, total_len: usize) -> Result<()> {
        // Reset any previous stream before starting a new one.
        self.finish_stream()?;
        match self {
            Self::Rosa(m) => {
                m.begin_stream(total_len);
                Ok(())
            }
            Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            Self::Ctw(m) | Self::FacCtw(m) => {
                m.tree.reserve_for_symbols(total_len);
                Ok(())
            }
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => m.begin_stream(total_len),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.begin_stream(total_len),
            Self::Mixture(m) => m.begin_stream(total_len),
            Self::Calibrated { base, .. } => base.begin_stream(total_len),
        }
    }
1173
    /// Close out the current stream. Only RWKV (and backends wrapping one,
    /// via Mixture/Calibrated recursion) carry per-stream state to flush.
    fn finish_stream(&mut self) -> Result<()> {
        match self {
            Self::Rosa(_)
            | Self::Match { .. }
            | Self::SparseMatch { .. }
            | Self::Ppmd { .. }
            | Self::Ctw(_)
            | Self::FacCtw(_)
            | Self::Zpaq(_)
            | Self::Particle(_) => Ok(()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(_) => Ok(()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => m.finish_stream(),
            Self::Mixture(m) => m.finish_stream(),
            Self::Calibrated { base, .. } => base.finish_stream(),
        }
    }
1192
    /// Next-byte pdf from the underlying backend. The Calibrated variant
    /// lazily applies its calibration layer on top of the base pdf and caches
    /// the result until the next `update`.
    fn pdf_next(&mut self) -> Result<&[f64]> {
        match self {
            Self::Rosa(m) => Ok(m.pdf_next()),
            Self::Match { model } => Ok(model.pdf()),
            Self::Ctw(m) => Ok(m.pdf_next()),
            Self::FacCtw(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-mamba")]
            Self::Mamba(m) => Ok(m.pdf_next()),
            #[cfg(feature = "backend-rwkv")]
            Self::Rwkv(m) => Ok(m.pdf_next()),
            Self::Zpaq(m) => Ok(m.pdf_next()),
            Self::Mixture(m) => m.ensure_pdf(),
            Self::Particle(m) => Ok(m.pdf_next()),
            Self::SparseMatch { model } => Ok(model.pdf()),
            Self::Ppmd { model } => Ok(model.pdf()),
            Self::Calibrated {
                base,
                core,
                pdf,
                valid,
            } => {
                if !*valid {
                    let base_pdf = base.pdf_next()?;
                    core.apply_pdf(base_pdf, pdf);
                    normalize_pdf(pdf);
                    *valid = true;
                }
                Ok(pdf)
            }
        }
    }
1224
1225 fn update(&mut self, symbol: u8) -> Result<()> {
1226 match self {
1227 Self::Rosa(m) => {
1228 m.update(symbol);
1229 Ok(())
1230 }
1231 Self::Match { model } => {
1232 model.update(symbol);
1233 Ok(())
1234 }
1235 Self::SparseMatch { model } => {
1236 model.update(symbol);
1237 Ok(())
1238 }
1239 Self::Ppmd { model } => {
1240 model.update(symbol);
1241 Ok(())
1242 }
1243 Self::Ctw(m) => {
1244 m.update(symbol);
1245 Ok(())
1246 }
1247 Self::FacCtw(m) => {
1248 m.update(symbol);
1249 Ok(())
1250 }
1251 #[cfg(feature = "backend-mamba")]
1252 Self::Mamba(m) => m.update(symbol),
1253 #[cfg(feature = "backend-rwkv")]
1254 Self::Rwkv(m) => m.update(symbol),
1255 Self::Zpaq(m) => {
1256 m.update(symbol);
1257 Ok(())
1258 }
1259 Self::Mixture(m) => m.update(symbol),
1260 Self::Particle(m) => {
1261 m.step(symbol);
1262 Ok(())
1263 }
1264 Self::Calibrated {
1265 base,
1266 core,
1267 pdf,
1268 valid,
1269 } => {
1270 if !*valid {
1271 let base_pdf = base.pdf_next()?;
1272 core.apply_pdf(base_pdf, pdf);
1273 normalize_pdf(pdf);
1274 }
1275 core.update(symbol, pdf);
1276 base.update(symbol)?;
1277 *valid = false;
1278 Ok(())
1279 }
1280 }
1281 }
1282
1283 fn prepare_cached_cdf_fast_bitwise(&mut self) -> Result<bool> {
1284 match self {
1285 Self::Rosa(m) => {
1286 let _ = m.cdf_next();
1287 Ok(true)
1288 }
1289 Self::Match { model } => {
1290 let _ = model.cdf();
1291 Ok(true)
1292 }
1293 Self::SparseMatch { model } => {
1294 let _ = model.cdf();
1295 Ok(true)
1296 }
1297 Self::Ppmd { model } => {
1298 let _ = model.cdf();
1299 Ok(true)
1300 }
1301 #[cfg(feature = "backend-mamba")]
1302 Self::Mamba(m) => {
1303 let _ = m.cdf_next();
1304 Ok(true)
1305 }
1306 #[cfg(feature = "backend-rwkv")]
1307 Self::Rwkv(m) => {
1308 let _ = m.cdf_next();
1309 Ok(true)
1310 }
1311 _ => Ok(false),
1312 }
1313 }
1314
1315 fn cached_cdf_bit_prob_one_msb(&mut self, lo: usize, hi: usize) -> Option<f64> {
1316 match self {
1317 Self::Rosa(m) => Some(cdf_bit_prob_one_msb(&m.cdf, lo, hi)),
1318 Self::Match { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
1319 Self::SparseMatch { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
1320 Self::Ppmd { model } => Some(cdf_bit_prob_one_msb(model.cdf(), lo, hi)),
1321 #[cfg(feature = "backend-mamba")]
1322 Self::Mamba(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
1323 #[cfg(feature = "backend-rwkv")]
1324 Self::Rwkv(m) => Some(cdf_bit_prob_one_msb(m.cdf_next(), lo, hi)),
1325 _ => None,
1326 }
1327 }
1328
1329 #[inline]
1330 fn can_fast_ac_bitwise(&self) -> bool {
1331 match self {
1332 Self::Ctw(m) => m.can_fast_ac_bitwise(),
1333 Self::Mixture(m) => m.can_fast_ac_bitwise(),
1334 _ => false,
1335 }
1336 }
1337
1338 fn ac_step_fast_bitwise<F>(&mut self, choose_bit: F) -> Result<u8>
1339 where
1340 F: FnMut(usize, f64) -> Result<u8>,
1341 {
1342 match self {
1343 Self::Ctw(m) => ctw_ac_step_bitwise(m, choose_bit),
1344 Self::Mixture(m) => m.ac_step_bitwise(choose_bit),
1345 _ => unreachable!("fast bitwise path requested for unsupported predictor"),
1346 }
1347 }
1348}
1349
1350fn ctw_ac_step_bitwise<F>(ctw: &mut CtwPredictor, mut choose_bit: F) -> Result<u8>
1351where
1352 F: FnMut(usize, f64) -> Result<u8>,
1353{
1354 debug_assert!(ctw.can_fast_ac_bitwise());
1355 let mut symbol = 0u8;
1356 for bit_idx in 0..8usize {
1357 let p1 = ctw.bit_prob_one_msb(bit_idx);
1358 let bit = choose_bit(bit_idx, p1)? & 1;
1359 symbol |= bit << (7 - bit_idx);
1360 ctw.update_bit_msb(bit_idx, bit == 1);
1361 }
1362 Ok(symbol)
1363}
1364
1365#[inline]
1366fn binary_split_from_prob_one(p1: f64) -> u32 {
1367 let p1 = p1.clamp(PDF_MIN, 1.0 - PDF_MIN);
1368 let p0 = 1.0 - p1;
1369 let mut split = (p0 * (CDF_TOTAL as f64)) as u32;
1370 if split == 0 {
1371 split = 1;
1372 } else if split >= CDF_TOTAL {
1373 split = CDF_TOTAL - 1;
1374 }
1375 split
1376}
1377
/// Arithmetic-codes `data` under `predictor`'s per-symbol model.
///
/// Two paths: a fast bitwise path for CTW-family predictors that codes
/// each byte as eight binary decisions, and a generic path that
/// quantizes the full 256-symbol pdf into an integer CDF per byte.
/// The decoder must mirror this path selection exactly.
fn encode_payload_ac(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
    predictor.begin_stream(data.len())?;
    if predictor.can_fast_ac_bitwise() {
        let mut out = Vec::new();
        {
            // Scope the encoder so its borrow of `out` ends before returning it.
            let mut enc = ArithmeticEncoder::new(&mut out);
            for &symbol in data {
                predictor.ac_step_fast_bitwise(|bit_idx, p1_mix| {
                    // MSB-first bit extraction; the predictor supplies P(bit = 1).
                    let bit = (symbol >> (7 - bit_idx)) & 1;
                    let split = binary_split_from_prob_one(p1_mix);
                    if bit == 0 {
                        enc.encode_counts(0, split as u64, CDF_TOTAL as u64)?;
                    } else {
                        enc.encode_counts(split as u64, CDF_TOTAL as u64, CDF_TOTAL as u64)?;
                    }
                    Ok(bit)
                })?;
            }
            let _ = enc.finish()?;
        }
        predictor.finish_stream()?;
        return Ok(out);
    }

    let mut out = Vec::new();
    {
        let mut enc = ArithmeticEncoder::new(&mut out);
        // Reusable 257-entry integer CDF; symbol s spans cdf[s]..cdf[s + 1].
        let mut cdf = vec![0u32; 257];
        for &b in data {
            let pdf = predictor.pdf_next()?;
            crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
                pdf, CDF_TOTAL, &mut cdf,
            );
            let sym = b as usize;
            enc.encode_counts(cdf[sym] as u64, cdf[sym + 1] as u64, CDF_TOTAL as u64)?;
            predictor.update(b)?;
        }
        let _ = enc.finish()?;
    }
    predictor.finish_stream()?;
    Ok(out)
}
1420
/// Inverse of `encode_payload_ac`: decodes exactly `out_len` bytes,
/// mirroring the encoder's fast-bitwise / generic path selection so the
/// predictor receives the identical sequence of model queries.
fn decode_payload_ac(
    payload: &[u8],
    out_len: usize,
    predictor: &mut RatePdfPredictor,
) -> Result<Vec<u8>> {
    predictor.begin_stream(out_len)?;
    if predictor.can_fast_ac_bitwise() {
        let mut dec = ArithmeticDecoder::new(payload)?;
        let mut out = Vec::with_capacity(out_len);
        for _ in 0..out_len {
            let symbol = predictor.ac_step_fast_bitwise(|_, p1_mix| {
                // Same binary split as the encoder; decoding against the
                // 3-entry CDF yields the bit value (0 or 1).
                let split = binary_split_from_prob_one(p1_mix);
                let cdf = [0u32, split, CDF_TOTAL];
                Ok(dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8)
            })?;
            out.push(symbol);
        }
        predictor.finish_stream()?;
        return Ok(out);
    }

    let mut dec = ArithmeticDecoder::new(payload)?;
    let mut out = Vec::with_capacity(out_len);
    let mut cdf = vec![0u32; 257];
    for _ in 0..out_len {
        // Must quantize the same pdf the encoder used at this position.
        let pdf = predictor.pdf_next()?;
        crate::coders::quantize_pdf_to_integer_cdf_dense_positive_with_buffer(
            pdf, CDF_TOTAL, &mut cdf,
        );
        let sym = dec.decode_symbol_counts(&cdf, CDF_TOTAL)? as u8;
        out.push(sym);
        predictor.update(sym)?;
    }
    predictor.finish_stream()?;
    Ok(out)
}
1457
1458fn encode_payload_rans(data: &[u8], predictor: &mut RatePdfPredictor) -> Result<Vec<u8>> {
1459 predictor.begin_stream(data.len())?;
1460 let mut encoder = BlockedRansEncoder::new();
1461 let mut cdf = vec![0u32; 257];
1462 let mut freq = vec![0i64; 256];
1463
1464 for &b in data {
1465 let pdf = predictor.pdf_next()?;
1466 quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
1467 let s = b as usize;
1468 encoder.encode(Cdf::new(cdf[s], cdf[s + 1], ANS_TOTAL));
1469 predictor.update(b)?;
1470 }
1471
1472 let blocks = encoder.finish();
1473 let mut out = Vec::new();
1474 out.extend_from_slice(&(blocks.len() as u32).to_le_bytes());
1475 for block in blocks {
1476 out.extend_from_slice(&(block.len() as u32).to_le_bytes());
1477 out.extend_from_slice(&block);
1478 }
1479 predictor.finish_stream()?;
1480 Ok(out)
1481}
1482
1483fn decode_payload_rans(
1484 payload: &[u8],
1485 out_len: usize,
1486 predictor: &mut RatePdfPredictor,
1487) -> Result<Vec<u8>> {
1488 predictor.begin_stream(out_len)?;
1489 if payload.len() < 4 {
1490 bail!("rANS payload too short");
1491 }
1492 let block_count = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]) as usize;
1493 let mut pos = 4usize;
1494 let mut blocks = Vec::with_capacity(block_count);
1495 for _ in 0..block_count {
1496 if pos + 4 > payload.len() {
1497 bail!("truncated rANS block header");
1498 }
1499 let len = u32::from_le_bytes([
1500 payload[pos],
1501 payload[pos + 1],
1502 payload[pos + 2],
1503 payload[pos + 3],
1504 ]) as usize;
1505 pos += 4;
1506 if pos + len > payload.len() {
1507 bail!("truncated rANS block data");
1508 }
1509 blocks.push(&payload[pos..pos + len]);
1510 pos += len;
1511 }
1512
1513 let mut dec = BlockedRansDecoder::new(blocks, out_len)?;
1514 let mut out = Vec::with_capacity(out_len);
1515 let mut cdf = vec![0u32; 257];
1516 let mut freq = vec![0i64; 256];
1517
1518 for _ in 0..out_len {
1519 let pdf = predictor.pdf_next()?;
1520 quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freq);
1521 let sym = dec.decode(&cdf)? as u8;
1522 out.push(sym);
1523 predictor.update(sym)?;
1524 }
1525 predictor.finish_stream()?;
1526 Ok(out)
1527}
1528
1529pub fn compress_rate_bytes(
1534 data: &[u8],
1535 rate_backend: &RateBackend,
1536 max_order: i64,
1537 coder: CoderType,
1538 framing: FramingMode,
1539) -> Result<Vec<u8>> {
1540 let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
1541 let payload = match coder {
1542 CoderType::AC => encode_payload_ac(data, &mut predictor)?,
1543 CoderType::RANS => encode_payload_rans(data, &mut predictor)?,
1544 };
1545
1546 if framing == FramingMode::Raw {
1547 return Ok(payload);
1548 }
1549
1550 let mut out = Vec::with_capacity(FramedHeader::SIZE + payload.len());
1551 let hdr = FramedHeader::new(coder, data.len() as u64, crc32(data));
1552 hdr.write(&mut out);
1553 out.extend_from_slice(&payload);
1554 Ok(out)
1555}
1556
1557pub fn compress_rate_size(
1559 data: &[u8],
1560 rate_backend: &RateBackend,
1561 max_order: i64,
1562 coder: CoderType,
1563 framing: FramingMode,
1564) -> Result<u64> {
1565 let encoded = compress_rate_bytes(data, rate_backend, max_order, coder, framing)?;
1566 Ok(encoded.len() as u64)
1567}
1568
1569pub fn compress_rate_size_chain(
1571 parts: &[&[u8]],
1572 rate_backend: &RateBackend,
1573 max_order: i64,
1574 coder: CoderType,
1575 framing: FramingMode,
1576) -> Result<u64> {
1577 let total = parts.iter().map(|p| p.len()).sum();
1578 let mut data = Vec::with_capacity(total);
1579 for p in parts {
1580 data.extend_from_slice(p);
1581 }
1582 compress_rate_size(&data, rate_backend, max_order, coder, framing)
1583}
1584
1585pub fn decompress_rate_bytes(
1587 input: &[u8],
1588 rate_backend: &RateBackend,
1589 max_order: i64,
1590 _coder: CoderType,
1591 framing: FramingMode,
1592) -> Result<Vec<u8>> {
1593 let (payload, coder, out_len, expected_crc) = if framing == FramingMode::Framed {
1594 let hdr = FramedHeader::read(input)?;
1595 (
1596 &input[FramedHeader::SIZE..],
1597 hdr.coder_type(),
1598 hdr.original_len as usize,
1599 Some(hdr.crc32),
1600 )
1601 } else {
1602 bail!("raw payload decompression requires explicit output length and is not supported");
1603 };
1604
1605 let _ = coder;
1606 let mut predictor = RatePdfPredictor::from_rate_backend(rate_backend.clone(), max_order)?;
1607 let decoded = match coder {
1608 CoderType::AC => decode_payload_ac(payload, out_len, &mut predictor)?,
1609 CoderType::RANS => decode_payload_rans(payload, out_len, &mut predictor)?,
1610 };
1611
1612 if let Some(crc) = expected_crc {
1613 let got = crc32(&decoded);
1614 if got != crc {
1615 bail!("CRC32 mismatch: expected 0x{crc:08X}, got 0x{got:08X}");
1616 }
1617 }
1618
1619 Ok(decoded)
1620}
1621
1622fn normalize_pdf(pdf: &mut [f64]) {
1623 let mut sum = 0.0;
1624 for p in pdf.iter_mut() {
1625 *p = if p.is_finite() {
1626 (*p).max(PDF_MIN)
1627 } else {
1628 PDF_MIN
1629 };
1630 sum += *p;
1631 }
1632 if !(sum.is_finite()) || sum <= 0.0 {
1633 let u = 1.0 / (pdf.len() as f64);
1634 for p in pdf.iter_mut() {
1635 *p = u;
1636 }
1637 return;
1638 }
1639 let inv = 1.0 / sum;
1640 for p in pdf.iter_mut() {
1641 *p *= inv;
1642 }
1643}
1644
/// CDF row of the uniform byte distribution: entry `i` equals `i / 256`.
#[inline]
fn uniform_cdf_row() -> [f64; 257] {
    let step = 1.0 / 256.0;
    std::array::from_fn(|i| (i as f64) * step)
}
1654
/// Writes the running-sum CDF of the first 256 entries of `pdf` into
/// `cdf`: `cdf[0] = 0` and `cdf[i + 1] = sum(pdf[..=i])`.
/// Panics if `pdf` has fewer than 256 entries.
#[inline]
fn build_cdf_row_from_pdf_slice(pdf: &[f64], cdf: &mut [f64; 257]) {
    cdf[0] = 0.0;
    let mut running = 0.0;
    for (i, slot) in cdf.iter_mut().enumerate().skip(1) {
        running += pdf[i - 1];
        *slot = running;
    }
}
1664
1665fn normalize_pdf_vec_and_maybe_build_cdf(pdf: &mut [f64], mut cdf: Option<&mut [f64; 257]>) {
1666 let mut sum = 0.0;
1667 for p in pdf.iter_mut() {
1668 *p = if p.is_finite() {
1669 (*p).max(PDF_MIN)
1670 } else {
1671 PDF_MIN
1672 };
1673 sum += *p;
1674 }
1675 if !(sum.is_finite()) || sum <= 0.0 {
1676 let u = 1.0 / (pdf.len() as f64);
1677 pdf.fill(u);
1678 if let Some(cdf) = cdf.as_deref_mut() {
1679 *cdf = uniform_cdf_row();
1680 }
1681 return;
1682 }
1683 let inv = 1.0 / sum;
1684 if let Some(cdf) = cdf.as_deref_mut() {
1685 cdf[0] = 0.0;
1686 let mut acc = 0.0;
1687 for i in 0..256 {
1688 pdf[i] *= inv;
1689 acc += pdf[i];
1690 cdf[i + 1] = acc;
1691 }
1692 } else {
1693 for p in pdf.iter_mut() {
1694 *p *= inv;
1695 }
1696 }
1697}
1698
1699#[inline]
1700fn cdf_bit_prob_one_msb(cdf: &[f64; 257], lo: usize, hi: usize) -> f64 {
1701 let mid = (lo + hi) >> 1;
1702 let total = (cdf[hi] - cdf[lo]).max(PDF_MIN);
1703 let one = (cdf[hi] - cdf[mid]).max(0.0);
1704 (one / total).clamp(PDF_MIN, 1.0 - PDF_MIN)
1705}
1706
/// Numerically stable log(sum(exp(v))) over a slice; returns -inf for
/// an empty slice and propagates an infinite maximum directly.
#[inline]
fn logsumexp_slice(vals: &[f64]) -> f64 {
    let peak = vals
        .iter()
        .fold(f64::NEG_INFINITY, |acc, &v| if v > acc { v } else { acc });
    if !peak.is_finite() {
        return peak;
    }
    let total: f64 = vals.iter().map(|&v| (v - peak).exp()).sum();
    peak + total.ln()
}
1724
1725#[inline]
1726fn logsumexp_expert_weights(experts: &[MixExpert]) -> f64 {
1727 let mut m = f64::NEG_INFINITY;
1728 for e in experts {
1729 if e.log_weight > m {
1730 m = e.log_weight;
1731 }
1732 }
1733 if !m.is_finite() {
1734 return m;
1735 }
1736 let mut s = 0.0;
1737 for e in experts {
1738 s += (e.log_weight - m).exp();
1739 }
1740 m + s.ln()
1741}
1742
/// Numerically stable log(exp(a) + exp(b)) for two values.
fn logsumexp2(a: f64, b: f64) -> f64 {
    let peak = if b < a { a } else { b };
    if !peak.is_finite() {
        return peak;
    }
    ((a - peak).exp() + (b - peak).exp()).ln() + peak
}
1750
#[allow(dead_code)]
// NOTE(review): does nothing at runtime; presumably exists only to keep the
// `ZpaqRateModel` import referenced in every feature configuration — TODO
// confirm before removing.
fn _zpaq_marker(_: &ZpaqRateModel) {}
1753
1754#[cfg(test)]
1755mod tests {
1756 use super::*;
1757 use std::sync::Arc;
1758
1759 fn assert_pdf_close(lhs: &[f64], rhs: &[f64], tol: f64) {
1760 assert_eq!(lhs.len(), rhs.len());
1761 for (idx, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
1762 let delta = (a - b).abs();
1763 assert!(
1764 delta <= tol,
1765 "pdf mismatch at symbol {idx}: lhs={a} rhs={b} delta={delta}"
1766 );
1767 }
1768 }
1769
1770 fn brute_force_pdf(predictor: &mut CtwPredictor) -> Vec<f64> {
1771 let bits = predictor.bits_per_symbol.clamp(1, 8);
1772 let mut out = vec![0.0; 256];
1773
1774 if bits == 8 {
1775 for (sym, slot) in out.iter_mut().enumerate().take(256usize) {
1776 *slot = predictor.log_prob_symbol_bruteforce(sym as u8).exp();
1777 }
1778 } else {
1779 let patterns = 1usize << bits;
1780 let aliases = 1usize << (8 - bits);
1781 let mut pat_prob = vec![0.0; patterns];
1782 for (pat, value) in pat_prob.iter_mut().enumerate() {
1783 let symbol = if predictor.msb_first {
1784 (pat as u8) << (8 - bits)
1785 } else {
1786 pat as u8
1787 };
1788 *value = predictor.log_prob_symbol_bruteforce(symbol).exp();
1789 }
1790 for (byte, slot) in out.iter_mut().enumerate().take(256usize) {
1791 let pat = if predictor.msb_first {
1792 byte >> (8 - bits)
1793 } else {
1794 byte & (patterns - 1)
1795 };
1796 *slot = pat_prob[pat] / (aliases as f64);
1797 }
1798 }
1799
1800 CtwPredictor::normalize_pdf(&mut out);
1801 out
1802 }
1803
1804 #[test]
1805 fn ctw_pdf_fast_matches_bruteforce() {
1806 let mut predictor = CtwPredictor::new_ctw(6);
1807 for &b in b"ctw fast-path regression corpus 1234567890" {
1808 predictor.update(b);
1809 }
1810
1811 let fast = predictor.pdf_next().to_vec();
1812 predictor.valid = false;
1813 let brute = brute_force_pdf(&mut predictor);
1814
1815 for i in 0..256usize {
1816 let delta = (fast[i] - brute[i]).abs();
1817 assert!(
1818 delta < 1e-12,
1819 "symbol={i} fast={} brute={} delta={delta}",
1820 fast[i],
1821 brute[i]
1822 );
1823 }
1824 }
1825
1826 #[test]
1827 fn fac_pdf_fast_matches_bruteforce_subbyte() {
1828 let mut predictor = CtwPredictor::new_fac(5, 5);
1829 for &b in b"fac ctw subbyte regression corpus abcdefghijklmnopqrstuvwxyz" {
1830 predictor.update(b);
1831 }
1832
1833 let fast = predictor.pdf_next().to_vec();
1834 predictor.valid = false;
1835 let brute = brute_force_pdf(&mut predictor);
1836
1837 for i in 0..256usize {
1838 let delta = (fast[i] - brute[i]).abs();
1839 assert!(
1840 delta < 1e-12,
1841 "symbol={i} fast={} brute={} delta={delta}",
1842 fast[i],
1843 brute[i]
1844 );
1845 }
1846 }
1847
1848 fn assert_ctw_pdf_next_preserves_state(mut predictor: CtwPredictor) {
1849 for &b in b"ctw predictor state preservation payload" {
1850 predictor.update(b);
1851 }
1852 let mut before_p0 = [0.0f64; 8];
1853 let mut before_p1 = [0.0f64; 8];
1854 for bit_idx in 0..8usize {
1855 before_p0[bit_idx] = predictor.tree.predict(false, bit_idx);
1856 before_p1[bit_idx] = predictor.tree.predict(true, bit_idx);
1857 }
1858 let log_before = predictor.tree.get_log_block_probability();
1859 let _ = predictor.pdf_next();
1860 let log_after = predictor.tree.get_log_block_probability();
1861 assert!(
1862 (log_before - log_after).abs() < 1e-12,
1863 "log drift: before={log_before} after={log_after}"
1864 );
1865 for bit_idx in 0..8usize {
1866 let after_p0 = predictor.tree.predict(false, bit_idx);
1867 let after_p1 = predictor.tree.predict(true, bit_idx);
1868 assert!(
1869 (before_p0[bit_idx] - after_p0).abs() < 1e-12,
1870 "bit {bit_idx} p0 drift: {} vs {}",
1871 before_p0[bit_idx],
1872 after_p0
1873 );
1874 assert!(
1875 (before_p1[bit_idx] - after_p1).abs() < 1e-12,
1876 "bit {bit_idx} p1 drift: {} vs {}",
1877 before_p1[bit_idx],
1878 after_p1
1879 );
1880 }
1881 }
1882
1883 #[test]
1884 fn ctw_pdf_next_preserves_state() {
1885 assert_ctw_pdf_next_preserves_state(CtwPredictor::new_ctw(7));
1886 }
1887
1888 #[test]
1889 fn fac_pdf_next_preserves_state() {
1890 assert_ctw_pdf_next_preserves_state(CtwPredictor::new_fac(7, 8));
1891 }
1892
1893 fn assert_fill_pattern_preserves_symbol_log_probs(mut predictor: CtwPredictor) {
1894 for &b in b"fill-pattern preservation regression payload" {
1895 predictor.update(b);
1896 }
1897 let mut baseline = [0.0f64; 256];
1898 for (sym, slot) in baseline.iter_mut().enumerate() {
1899 *slot = predictor.log_prob_symbol_bruteforce(sym as u8);
1900 }
1901 let _ = predictor.fill_pattern_log_probs();
1902 for (sym, &expected) in baseline.iter().enumerate() {
1903 let got = predictor.log_prob_symbol_bruteforce(sym as u8);
1904 let diff = (expected - got).abs();
1905 assert!(
1906 diff < 1e-12,
1907 "symbol={sym} expected={expected} got={got} diff={diff}"
1908 );
1909 }
1910 }
1911
1912 #[test]
1913 fn ctw_fill_pattern_preserves_symbol_log_probs() {
1914 assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_ctw(7));
1915 }
1916
1917 #[test]
1918 fn fac_fill_pattern_preserves_symbol_log_probs() {
1919 assert_fill_pattern_preserves_symbol_log_probs(CtwPredictor::new_fac(7, 8));
1920 }
1921
1922 fn assert_pdf_then_update_matches_plain_update(mut base: CtwPredictor) {
1923 for &b in b"pdf then update parity payload" {
1924 base.update(b);
1925 }
1926 let observed = b'n';
1927 let mut with_pdf = base.clone();
1928 let mut plain = base;
1929
1930 let _ = with_pdf.pdf_next();
1931 with_pdf.update(observed);
1932 plain.update(observed);
1933
1934 for sym in 0u8..=255u8 {
1935 let lp_with_pdf = with_pdf.log_prob_symbol_bruteforce(sym);
1936 let lp_plain = plain.log_prob_symbol_bruteforce(sym);
1937 let diff = (lp_with_pdf - lp_plain).abs();
1938 assert!(
1939 diff < 1e-12,
1940 "symbol={sym} with_pdf={lp_with_pdf} plain={lp_plain} diff={diff}"
1941 );
1942 }
1943 }
1944
1945 #[test]
1946 fn ctw_pdf_then_update_matches_plain_update() {
1947 assert_pdf_then_update_matches_plain_update(CtwPredictor::new_ctw(7));
1948 }
1949
1950 #[test]
1951 fn fac_pdf_then_update_matches_plain_update() {
1952 assert_pdf_then_update_matches_plain_update(CtwPredictor::new_fac(7, 8));
1953 }
1954
1955 #[test]
1956 fn roundtrip_rate_ac_ctw() {
1957 let data = b"ctw backend roundtrip payload";
1958 let backend = RateBackend::Ctw { depth: 8 };
1959 let enc =
1960 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
1961 let dec =
1962 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
1963 assert_eq!(dec, data);
1964 }
1965
1966 #[test]
1967 fn roundtrip_rate_ac_match_family_and_ppmd() {
1968 let data = b"repeat repeat repeat sparse sparse repeat payload";
1969 for backend in [
1970 RateBackend::Match {
1971 hash_bits: 20,
1972 min_len: 4,
1973 max_len: 255,
1974 base_mix: 0.02,
1975 confidence_scale: 1.0,
1976 },
1977 RateBackend::SparseMatch {
1978 hash_bits: 19,
1979 min_len: 3,
1980 max_len: 64,
1981 gap_min: 1,
1982 gap_max: 2,
1983 base_mix: 0.05,
1984 confidence_scale: 1.0,
1985 },
1986 RateBackend::Ppmd {
1987 order: 8,
1988 memory_mb: 8,
1989 },
1990 ] {
1991 let enc = compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed)
1992 .unwrap();
1993 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
1994 .unwrap();
1995 assert_eq!(dec, data);
1996 }
1997 }
1998
1999 #[test]
2000 fn roundtrip_rate_ac_ppmd_high_order_text_payload() {
2001 let seed = include_bytes!("../../README.md");
2002 let mut data = Vec::with_capacity(4096);
2003 while data.len() < 4096 {
2004 data.extend_from_slice(seed);
2005 }
2006 data.truncate(4096);
2007
2008 let backend = RateBackend::Ppmd {
2009 order: 12,
2010 memory_mb: 256,
2011 };
2012 let enc = compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed)
2013 .expect("ppmd high-order compression");
2014 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed)
2015 .expect("ppmd high-order decompression");
2016 assert_eq!(dec, data);
2017 }
2018
2019 #[test]
2020 fn roundtrip_rate_ac_calibrated_backend() {
2021 let data = b"calibration wrapper payload calibration wrapper payload";
2022 let backend = RateBackend::Calibrated {
2023 spec: Arc::new(crate::CalibratedSpec {
2024 base: RateBackend::Ctw { depth: 8 },
2025 context: crate::CalibrationContextKind::Text,
2026 bins: 33,
2027 learning_rate: 0.02,
2028 bias_clip: 4.0,
2029 }),
2030 };
2031 let enc =
2032 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2033 let dec =
2034 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2035 assert_eq!(dec, data);
2036 }
2037
2038 #[test]
2039 fn roundtrip_rate_ac_single_expert_ctw_neural_mixture() {
2040 let data = b"single expert neural ctw fast path payload";
2041 let spec = MixtureSpec::new(
2042 MixtureKind::Neural,
2043 vec![crate::MixtureExpertSpec {
2044 name: Some("ctw".to_string()),
2045 log_prior: 0.0,
2046 max_order: -1,
2047 backend: RateBackend::Ctw { depth: 8 },
2048 }],
2049 )
2050 .with_alpha(0.03);
2051 let backend = RateBackend::Mixture {
2052 spec: Arc::new(spec),
2053 };
2054 let enc =
2055 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2056 let dec =
2057 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2058 assert_eq!(dec, data);
2059 }
2060
2061 #[test]
2062 fn roundtrip_rate_ac_single_expert_ctw_bayes_mixture() {
2063 let data = b"single expert bayes ctw fast path payload";
2064 let spec = MixtureSpec::new(
2065 MixtureKind::Bayes,
2066 vec![crate::MixtureExpertSpec {
2067 name: Some("ctw".to_string()),
2068 log_prior: 0.0,
2069 max_order: -1,
2070 backend: RateBackend::Ctw { depth: 8 },
2071 }],
2072 )
2073 .with_alpha(0.03);
2074 let backend = RateBackend::Mixture {
2075 spec: Arc::new(spec),
2076 };
2077 let enc =
2078 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2079 let dec =
2080 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2081 assert_eq!(dec, data);
2082 }
2083
2084 #[test]
2085 fn roundtrip_rate_rans_recursive_mixture() {
2086 let data = b"recursive mixture payload";
2087 let nested = MixtureSpec::new(
2088 MixtureKind::Bayes,
2089 vec![
2090 crate::MixtureExpertSpec {
2091 name: Some("ctw".to_string()),
2092 log_prior: 0.0,
2093 max_order: -1,
2094 backend: RateBackend::Ctw { depth: 6 },
2095 },
2096 crate::MixtureExpertSpec {
2097 name: Some("fac".to_string()),
2098 log_prior: 0.0,
2099 max_order: -1,
2100 backend: RateBackend::FacCtw {
2101 base_depth: 6,
2102 num_percept_bits: 8,
2103 encoding_bits: 8,
2104 },
2105 },
2106 ],
2107 );
2108 let root = MixtureSpec::new(
2109 MixtureKind::Switching,
2110 vec![
2111 crate::MixtureExpertSpec {
2112 name: Some("nested".to_string()),
2113 log_prior: 0.0,
2114 max_order: -1,
2115 backend: RateBackend::Mixture {
2116 spec: Arc::new(nested),
2117 },
2118 },
2119 crate::MixtureExpertSpec {
2120 name: Some("zpaq".to_string()),
2121 log_prior: 0.0,
2122 max_order: -1,
2123 backend: RateBackend::Zpaq {
2124 method: "1".to_string(),
2125 },
2126 },
2127 ],
2128 )
2129 .with_alpha(0.05);
2130
2131 let backend = RateBackend::Mixture {
2132 spec: Arc::new(root),
2133 };
2134 let enc =
2135 compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
2136 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
2137 .unwrap();
2138 assert_eq!(dec, data);
2139 }
2140
2141 #[test]
2142 fn roundtrip_rate_ac_recursive_neural_mixture() {
2143 let data = b"neural recursive mixture payload for ac coder";
2144 let inner = MixtureSpec::new(
2145 MixtureKind::Bayes,
2146 vec![
2147 crate::MixtureExpertSpec {
2148 name: Some("ctw".to_string()),
2149 log_prior: 0.0,
2150 max_order: -1,
2151 backend: RateBackend::Ctw { depth: 6 },
2152 },
2153 crate::MixtureExpertSpec {
2154 name: Some("fac".to_string()),
2155 log_prior: 0.0,
2156 max_order: -1,
2157 backend: RateBackend::FacCtw {
2158 base_depth: 6,
2159 num_percept_bits: 8,
2160 encoding_bits: 8,
2161 },
2162 },
2163 ],
2164 );
2165 let root = MixtureSpec::new(
2166 MixtureKind::Neural,
2167 vec![
2168 crate::MixtureExpertSpec {
2169 name: Some("nested".to_string()),
2170 log_prior: 0.0,
2171 max_order: -1,
2172 backend: RateBackend::Mixture {
2173 spec: Arc::new(inner),
2174 },
2175 },
2176 crate::MixtureExpertSpec {
2177 name: Some("zpaq".to_string()),
2178 log_prior: 0.0,
2179 max_order: -1,
2180 backend: RateBackend::Zpaq {
2181 method: "1".to_string(),
2182 },
2183 },
2184 ],
2185 )
2186 .with_alpha(0.03);
2187
2188 let backend = RateBackend::Mixture {
2189 spec: Arc::new(root),
2190 };
2191 let enc =
2192 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2193 let dec =
2194 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2195 assert_eq!(dec, data);
2196 }
2197
2198 #[test]
2199 fn neural_runtime_and_compression_predictor_align() {
2200 let spec = MixtureSpec::new(
2201 MixtureKind::Neural,
2202 vec![
2203 crate::MixtureExpertSpec {
2204 name: Some("ctw".to_string()),
2205 log_prior: 0.0,
2206 max_order: -1,
2207 backend: RateBackend::Ctw { depth: 7 },
2208 },
2209 crate::MixtureExpertSpec {
2210 name: Some("fac".to_string()),
2211 log_prior: 0.0,
2212 max_order: -1,
2213 backend: RateBackend::FacCtw {
2214 base_depth: 7,
2215 num_percept_bits: 8,
2216 encoding_bits: 8,
2217 },
2218 },
2219 ],
2220 )
2221 .with_alpha(0.03);
2222
2223 let backend = RateBackend::Mixture {
2224 spec: Arc::new(spec.clone()),
2225 };
2226 let mut predictor = RatePdfPredictor::from_rate_backend(backend, -1).unwrap();
2227 let experts = spec.build_experts();
2228 let mut runtime = crate::mixture::build_mixture_runtime(&spec, &experts).unwrap();
2229
2230 let data = b"neural alignment check sequence";
2231 for &b in data {
2232 let pdf = predictor.pdf_next().unwrap();
2233 let p_comp = pdf[b as usize];
2234 let p_runtime = runtime.peek_log_prob(b).exp();
2235 assert!(
2236 (p_comp - p_runtime).abs() < 1e-8,
2237 "p_comp={p_comp} p_runtime={p_runtime} symbol={b}"
2238 );
2239 predictor.update(b).unwrap();
2240 runtime.step(b);
2241 }
2242 }
2243
2244 fn assert_cached_cdf_fast_bitwise_matches_pdf_rows(mut predictor: RatePdfPredictor) {
2245 let data = b"cached cdf parity check payload";
2246 for &symbol in data {
2247 let pdf = predictor.pdf_next().unwrap().to_vec();
2248 assert!(predictor.prepare_cached_cdf_fast_bitwise().unwrap());
2249
2250 let mut row = [0.0; 257];
2251 row[0] = 0.0;
2252 for i in 0..256 {
2253 row[i + 1] = row[i] + pdf[i].max(PDF_MIN);
2254 }
2255
2256 let mut stack = vec![(0usize, 256usize)];
2257 while let Some((lo, hi)) = stack.pop() {
2258 if hi - lo <= 1 {
2259 continue;
2260 }
2261 let expected = cdf_bit_prob_one_msb(&row, lo, hi);
2262 let got = predictor
2263 .cached_cdf_bit_prob_one_msb(lo, hi)
2264 .expect("cached cdf branch probability");
2265 let diff = (expected - got).abs();
2266 assert!(
2267 diff <= 1e-12,
2268 "lo={lo} hi={hi} expected={expected} got={got} diff={diff}"
2269 );
2270 let mid = (lo + hi) >> 1;
2271 stack.push((lo, mid));
2272 stack.push((mid, hi));
2273 }
2274
2275 predictor.update(symbol).unwrap();
2276 }
2277 }
2278
2279 #[test]
2280 fn cached_cdf_fast_bitwise_matches_pdf_rows_for_specialized_predictors() {
2281 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2282 RatePdfPredictor::from_rate_backend(RateBackend::RosaPlus, -1).unwrap(),
2283 );
2284 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2285 RatePdfPredictor::from_rate_backend(
2286 RateBackend::Ppmd {
2287 order: 6,
2288 memory_mb: 8,
2289 },
2290 -1,
2291 )
2292 .unwrap(),
2293 );
2294 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2295 RatePdfPredictor::from_rate_backend(
2296 RateBackend::Match {
2297 hash_bits: 20,
2298 min_len: 4,
2299 max_len: 255,
2300 base_mix: 0.02,
2301 confidence_scale: 1.0,
2302 },
2303 -1,
2304 )
2305 .unwrap(),
2306 );
2307 #[cfg(feature = "backend-rwkv")]
2308 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2309 RatePdfPredictor::from_rate_backend(
2310 RateBackend::Rwkv7Method {
2311 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2312 },
2313 -1,
2314 )
2315 .unwrap(),
2316 );
2317 #[cfg(feature = "backend-mamba")]
2318 assert_cached_cdf_fast_bitwise_matches_pdf_rows(
2319 RatePdfPredictor::from_rate_backend(
2320 RateBackend::MambaMethod {
2321 method: "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2322 },
2323 -1,
2324 )
2325 .unwrap(),
2326 );
2327 }
2328
2329 #[test]
2330 fn raw_size_not_larger_than_framed_size() {
2331 let data = b"raw/framed size check payload";
2332 let backend = RateBackend::RosaPlus;
2333 let raw = compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Raw).unwrap();
2334 let framed =
2335 compress_rate_size(data, &backend, 8, CoderType::AC, FramingMode::Framed).unwrap();
2336 assert!(framed >= raw);
2337 }
2338
2339 #[cfg(feature = "backend-rwkv")]
2340 #[test]
2341 fn roundtrip_rate_rwkv_method_cfg() {
2342 let data = b"rwkv cfg method backend";
2343 let backend = RateBackend::Rwkv7Method {
2344 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
2345 };
2346 let enc =
2347 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2348 let dec =
2349 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2350 assert_eq!(dec, data);
2351 }
2352
2353 #[cfg(feature = "backend-rwkv")]
2354 #[test]
2355 fn rwkv_rate_predictor_preserves_backend_pdf_exactly() {
2356 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=11,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
2357 let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
2358 let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
2359 let mut direct = vec![0.0; backend.vocab_size()];
2360
2361 let predicted = predictor.pdf_next().to_vec();
2362 backend.forward_to_pdf(0, &mut direct);
2363 assert_pdf_close(&predicted, &direct, 1e-18);
2364
2365 predictor.update(b'x').expect("predictor update");
2366 backend
2367 .online_update_from_pdf(b'x', &direct)
2368 .expect("backend update");
2369 backend.forward_to_pdf(u32::from(b'x'), &mut direct);
2370 assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
2371 }
2372
    /// Streams 10 bytes (not a multiple of the configured bptt=8 window —
    /// presumably the "partial TBPTT" case named in the test) through the
    /// predictor and the raw backend in lockstep, requiring their PDFs to
    /// agree to 1e-18 before every symbol and once more after both streams
    /// are finalized. The begin/observe/finish call order mirrors the
    /// online-policy stream protocol and must not be rearranged.
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_rate_predictor_matches_backend_after_partial_tbptt_stream() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=29,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
        let data = b"abcdefghij";
        let mut predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
        let mut backend = rwkvzip::Compressor::new_from_method(method).expect("rwkv backend");
        // Scratch buffer for the backend's PDF, sized to the full vocabulary.
        let mut direct = vec![0.0; backend.vocab_size()];

        // Open both streams with the same declared length so scheduling decisions match.
        predictor
            .begin_stream(data.len())
            .expect("begin predictor stream");
        backend
            .begin_online_policy_stream(Some(data.len() as u64))
            .expect("begin backend stream");
        backend.reset_and_prime();

        for &byte in data {
            // Compare next-symbol PDFs before feeding each byte.
            let predicted = predictor.pdf_next().to_vec();
            backend.copy_current_pdf_to(&mut direct);
            assert_pdf_close(&predicted, &direct, 1e-18);

            // Advance both models with the same observed symbol.
            predictor.update(byte).expect("predictor update");
            backend
                .observe_symbol_from_current_pdf(byte)
                .expect("backend update");
        }

        // Finalize both streams (flushing any pending partial-TBPTT training
        // step) and verify the PDFs still agree afterwards.
        predictor.finish_stream().expect("finish predictor stream");
        backend
            .finish_online_policy_stream()
            .expect("finish backend stream");
        backend.copy_current_pdf_to(&mut direct);
        assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
    }
2408
2409 #[cfg(feature = "backend-rwkv")]
2410 #[test]
2411 fn roundtrip_rate_rwkv_two_json_method_2m() {
2412 let two_json: serde_json::Value =
2413 serde_json::from_str(include_str!("../../examples/two.json")).unwrap();
2414 let method = two_json["experts"]
2415 .as_array()
2416 .unwrap()
2417 .iter()
2418 .find(|expert| expert["name"].as_str() == Some("rwkv"))
2419 .and_then(|expert| expert["method"].as_str())
2420 .unwrap()
2421 .to_string();
2422
2423 let backend = RateBackend::Rwkv7Method { method };
2424 let seed = include_bytes!("../../README.md");
2425 let target_len = 2_097_152usize;
2426 let mut data = Vec::with_capacity(target_len);
2427 while data.len() < target_len {
2428 let remaining = target_len - data.len();
2429 data.extend_from_slice(&seed[..seed.len().min(remaining)]);
2430 }
2431
2432 let enc =
2433 compress_rate_bytes(&data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2434 let dec =
2435 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2436 assert_eq!(dec, data);
2437 }
2438
2439 #[cfg(feature = "backend-mamba")]
2440 #[test]
2441 fn mamba_rate_predictor_preserves_backend_pdf_exactly() {
2442 let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
2443 let mut predictor = MambaPredictor::from_method(method).expect("mamba predictor");
2444 let mut backend = mambazip::Compressor::new_from_method(method).expect("mamba backend");
2445 let mut direct = vec![0.0; backend.vocab_size()];
2446
2447 let predicted = predictor.pdf_next().to_vec();
2448 backend.forward_to_pdf(0, &mut direct);
2449 assert_pdf_close(&predicted, &direct, 1e-18);
2450
2451 predictor.update(b'x').expect("predictor update");
2452 backend
2453 .online_update_from_pdf(b'x', &direct)
2454 .expect("backend update");
2455 backend.forward_to_pdf(u32::from(b'x'), &mut direct);
2456 assert_pdf_close(predictor.pdf_next(), &direct, 1e-18);
2457 }
2458
2459 #[test]
2460 fn roundtrip_rate_ac_particle() {
2461 let spec = crate::ParticleSpec {
2462 num_particles: 4,
2463 num_cells: 4,
2464 cell_dim: 8,
2465 num_rules: 2,
2466 selector_hidden: 16,
2467 rule_hidden: 16,
2468 context_window: 8,
2469 unroll_steps: 1,
2470 ..crate::ParticleSpec::default()
2471 };
2472 let data = b"particle ac roundtrip payload";
2473 let backend = RateBackend::Particle {
2474 spec: Arc::new(spec),
2475 };
2476 let enc =
2477 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2478 let dec =
2479 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2480 assert_eq!(dec, data);
2481 }
2482
2483 #[test]
2484 fn roundtrip_rate_rans_particle() {
2485 let spec = crate::ParticleSpec {
2486 num_particles: 4,
2487 num_cells: 4,
2488 cell_dim: 8,
2489 num_rules: 2,
2490 selector_hidden: 16,
2491 rule_hidden: 16,
2492 context_window: 8,
2493 unroll_steps: 1,
2494 ..crate::ParticleSpec::default()
2495 };
2496 let data = b"particle rans roundtrip payload";
2497 let backend = RateBackend::Particle {
2498 spec: Arc::new(spec),
2499 };
2500 let enc =
2501 compress_rate_bytes(data, &backend, -1, CoderType::RANS, FramingMode::Framed).unwrap();
2502 let dec = decompress_rate_bytes(&enc, &backend, -1, CoderType::RANS, FramingMode::Framed)
2503 .unwrap();
2504 assert_eq!(dec, data);
2505 }
2506
2507 #[test]
2508 fn mixture_with_particle_expert_roundtrip() {
2509 let particle_spec = crate::ParticleSpec {
2510 num_particles: 4,
2511 num_cells: 4,
2512 cell_dim: 8,
2513 num_rules: 2,
2514 selector_hidden: 16,
2515 rule_hidden: 16,
2516 context_window: 8,
2517 unroll_steps: 1,
2518 ..crate::ParticleSpec::default()
2519 };
2520 let spec = MixtureSpec::new(
2521 MixtureKind::Bayes,
2522 vec![
2523 crate::MixtureExpertSpec {
2524 name: Some("particle".to_string()),
2525 log_prior: 0.0,
2526 max_order: -1,
2527 backend: RateBackend::Particle {
2528 spec: Arc::new(particle_spec),
2529 },
2530 },
2531 crate::MixtureExpertSpec {
2532 name: Some("ctw".to_string()),
2533 log_prior: 0.0,
2534 max_order: -1,
2535 backend: RateBackend::Ctw { depth: 6 },
2536 },
2537 ],
2538 );
2539 let backend = RateBackend::Mixture {
2540 spec: Arc::new(spec),
2541 };
2542 let data = b"mixture with particle expert roundtrip";
2543 let enc =
2544 compress_rate_bytes(data, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2545 let dec =
2546 decompress_rate_bytes(&enc, &backend, -1, CoderType::AC, FramingMode::Framed).unwrap();
2547 assert_eq!(dec, data);
2548 }
2549}