1#![allow(unsafe_op_in_unsafe_fn)]
2
3pub mod aixi;
77pub mod axioms;
79pub mod backends;
81pub mod coders;
83pub mod compression;
85pub mod datagen;
87pub mod mixture;
89pub(crate) mod neural_mix;
90pub mod search;
92pub(crate) mod simd_math;
93pub use backends::ctw;
95#[cfg(feature = "backend-mamba")]
96pub use backends::mambazip;
98pub use backends::match_model;
100pub use backends::particle;
102pub use backends::ppmd;
104pub use backends::rosaplus;
106#[cfg(feature = "backend-rwkv")]
107pub use backends::rwkvzip;
109pub use backends::sparse_match;
111pub use backends::zpaq_rate;
113
114use rayon::prelude::*;
115
116use crate::coders::CoderType;
117use std::cell::RefCell;
118#[cfg(any(feature = "backend-rwkv", feature = "backend-mamba"))]
119use std::collections::HashMap;
120use std::sync::Arc;
121use std::sync::OnceLock;
122
/// Whether the model keeps adapting while it generates.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationUpdateMode {
    /// Model statistics are updated with every emitted byte.
    Adaptive,
    /// Model statistics are left untouched during generation.
    Frozen,
}

/// How the next byte is chosen from the predicted distribution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationStrategy {
    /// Always emit the single most likely byte.
    Greedy,
    /// Draw a byte at random from the predicted distribution.
    Sample,
}

/// Knobs controlling byte generation.
#[derive(Clone, Copy, Debug)]
pub struct GenerationConfig {
    pub strategy: GenerationStrategy,
    pub update_mode: GenerationUpdateMode,
    /// PRNG seed used when `strategy` is `Sample`.
    pub seed: u64,
    /// Softmax temperature; 1.0 leaves the distribution unchanged.
    pub temperature: f64,
    /// Top-k cutoff (0 presumably disables it — confirm in `pick_generated_byte`).
    pub top_k: usize,
    /// Nucleus (top-p) mass (1.0 presumably disables it — confirm in
    /// `pick_generated_byte`).
    pub top_p: f64,
}

impl Default for GenerationConfig {
    /// Defaults to seeded sampling (seed 42) with a frozen model.
    fn default() -> Self {
        GenerationConfig::sampled_frozen(42)
    }
}

impl GenerationConfig {
    /// Deterministic generation: argmax decoding with a frozen model.
    pub const fn greedy_frozen() -> Self {
        Self {
            seed: 0xD00D_F00D_CAFE_BABEu64,
            strategy: GenerationStrategy::Greedy,
            update_mode: GenerationUpdateMode::Frozen,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
        }
    }

    /// Stochastic generation with a frozen model and the given seed.
    pub const fn sampled_frozen(seed: u64) -> Self {
        Self {
            seed,
            strategy: GenerationStrategy::Sample,
            update_mode: GenerationUpdateMode::Frozen,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
        }
    }
}
189
/// Minimal xorshift64 PRNG used for reproducible, dependency-free sampling
/// during byte generation.
struct GenerationRng {
    // Non-zero internal state (xorshift64 has a fixed point at zero).
    state: u64,
}

impl GenerationRng {
    /// Seeds the generator; a zero seed is remapped to a fixed non-zero
    /// constant because xorshift64 would otherwise emit zeros forever.
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 {
                0xD00D_F00D_CAFE_BABEu64
            } else {
                seed
            },
        }
    }

    /// Advances the state with the classic xorshift64 triple (13, 7, 17)
    /// and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Uniform sample in the half-open interval [0, 1).
    ///
    /// Uses the top 53 bits of the state so the result is an exact f64 and
    /// can never reach 1.0. (The previous `x / (u64::MAX as f64)` form
    /// divided by 2^64 after rounding and could return exactly 1.0 for
    /// large states, violating the [0, 1) contract that cumulative
    /// sampling loops rely on.)
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }
}
218
// Lazily-initialized worker-thread count. Not referenced in this part of the
// file — presumably set/read by the parallel compression helpers elsewhere
// (TODO(review): confirm at the use sites).
static NUM_THREADS: OnceLock<usize> = OnceLock::new();
220
thread_local! {
    // Per-thread caches of neural compressors: the `*_TLS` / `*_RATE_TLS`
    // maps are keyed by model pointer identity, the `*_METHOD_TLS` maps by
    // method string. Thread-local storage avoids locking and lets rayon
    // workers reuse compressor state; access goes through the `with_*_tls`
    // helpers further down, which define whether state persists across calls.
    #[cfg(feature = "backend-mamba")]
    static MAMBA_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_RATE_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_METHOD_TLS: RefCell<HashMap<String, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_RATE_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_METHOD_TLS: RefCell<HashMap<String, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
}
235
#[cfg(feature = "backend-zpaq")]
impl Default for CompressionBackend {
    /// With zpaq compiled in, default to zpaq method "5".
    fn default() -> Self {
        Self::Zpaq {
            method: String::from("5"),
        }
    }
}
244
245#[cfg(not(feature = "backend-zpaq"))]
246impl Default for CompressionBackend {
247 fn default() -> Self {
248 CompressionBackend::Rate {
249 rate_backend: RateBackend::default(),
250 coder: CoderType::AC,
251 framing: compression::FramingMode::Raw,
252 }
253 }
254}
255
thread_local! {
    // Per-thread default context backing the free-function API: read via
    // `get_default_ctx`/`with_default_ctx`, replaced via `set_default_ctx`.
    static DEFAULT_CTX: RefCell<InfotheoryCtx> = RefCell::new(InfotheoryCtx::default());
}
259
260pub fn get_default_ctx() -> InfotheoryCtx {
262 DEFAULT_CTX.with(|ctx| ctx.borrow().clone())
263}
264
265pub fn set_default_ctx(ctx: InfotheoryCtx) {
267 DEFAULT_CTX.with(|c| *c.borrow_mut() = ctx);
268}
269
270#[inline(always)]
271fn with_default_ctx<R>(f: impl FnOnce(&InfotheoryCtx) -> R) -> R {
272 DEFAULT_CTX.with(|ctx| f(&ctx.borrow()))
273}
274
275pub fn mutual_information_rate_backend(
279 x: &[u8],
280 y: &[u8],
281 max_order: i64,
282 backend: &RateBackend,
283) -> f64 {
284 let (x, y) = aligned_prefix(x, y);
285 if x.is_empty() {
286 return 0.0;
287 }
288 let h_x = entropy_rate_backend(x, max_order, backend);
291 let h_y = entropy_rate_backend(y, max_order, backend);
292 let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
293 (h_x + h_y - h_xy).max(0.0)
294}
295
296pub fn ned_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
300 let (x, y) = aligned_prefix(x, y);
301 if x.is_empty() {
302 return 0.0;
303 }
304 let h_x = entropy_rate_backend(x, max_order, backend);
305 let h_y = entropy_rate_backend(y, max_order, backend);
306 let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
307 let min_h = h_x.min(h_y);
308 let max_h = h_x.max(h_y);
309 if max_h == 0.0 {
310 0.0
311 } else {
312 ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
313 }
314}
315
316pub fn nte_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
320 let (x, y) = aligned_prefix(x, y);
321 if x.is_empty() {
322 return 0.0;
323 }
324 let h_x = entropy_rate_backend(x, max_order, backend);
325 let h_y = entropy_rate_backend(y, max_order, backend);
326 let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
327 let max_h = h_x.max(h_y);
328 if max_h == 0.0 {
329 0.0
330 } else {
331 let vi = (h_xy - h_x).max(0.0) + (h_xy - h_y).max(0.0);
334 (vi / max_h).clamp(0.0, 2.0)
335 }
336}
337
/// Selects the statistical model used for entropy-rate style estimates.
///
/// Feature-gated variants only exist when the corresponding `backend-*`
/// crate feature is enabled.
#[derive(Clone)]
pub enum RateBackend {
    /// ROSA+ estimator (see `backends::rosaplus`).
    RosaPlus,
    /// Match model (see `backends::match_model`).
    Match {
        hash_bits: usize,
        min_len: usize,
        max_len: usize,
        base_mix: f64,
        confidence_scale: f64,
    },
    /// Sparse match model with a gap range (see `backends::sparse_match`).
    SparseMatch {
        hash_bits: usize,
        min_len: usize,
        max_len: usize,
        gap_min: usize,
        gap_max: usize,
        base_mix: f64,
        confidence_scale: f64,
    },
    /// PPMd with the given model order and memory budget in MiB.
    Ppmd {
        order: usize,
        memory_mb: usize,
    },
    /// Mamba neural model loaded from disk (feature `backend-mamba`).
    #[cfg(feature = "backend-mamba")]
    Mamba {
        model: Arc<mambazip::Model>,
    },
    /// Mamba model selected by method string (feature `backend-mamba`).
    #[cfg(feature = "backend-mamba")]
    MambaMethod {
        method: String,
    },
    /// RWKV7 neural model loaded from disk (feature `backend-rwkv`).
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        model: Arc<rwkvzip::Model>,
    },
    /// RWKV7 model selected by method string (feature `backend-rwkv`).
    #[cfg(feature = "backend-rwkv")]
    Rwkv7Method {
        method: String,
    },
    /// zpaq-based rate model with the given zpaq method string.
    Zpaq {
        method: String,
    },
    /// Mixture of experts described by `spec` (see `crate::mixture`).
    Mixture {
        spec: Arc<MixtureSpec>,
    },
    /// Particle-filter predictor described by `spec` (see `crate::particle`).
    Particle {
        spec: Arc<ParticleSpec>,
    },
    /// Calibrated wrapper around another backend.
    Calibrated {
        spec: Arc<CalibratedSpec>,
    },
    /// Context-tree weighting over bits with the given context depth.
    Ctw {
        depth: usize,
    },
    /// Factored CTW variant (see `crate::ctw::FacContextTree`).
    FacCtw {
        base_depth: usize,
        num_percept_bits: usize,
        encoding_bits: usize,
    },
}
439
#[allow(clippy::derivable_impls)]
impl Default for RateBackend {
    /// Picks the best estimator compiled in, in priority order: ROSA+ if
    /// `backend-rosa` is on, else zpaq method "1", else depth-16 CTW.
    /// Exactly one of the three cfg blocks survives compilation and
    /// becomes the tail expression.
    fn default() -> Self {
        #[cfg(feature = "backend-rosa")]
        {
            RateBackend::RosaPlus
        }
        #[cfg(all(not(feature = "backend-rosa"), feature = "backend-zpaq"))]
        {
            RateBackend::Zpaq {
                method: "1".to_string(),
            }
        }
        #[cfg(all(not(feature = "backend-rosa"), not(feature = "backend-zpaq")))]
        {
            RateBackend::Ctw { depth: 16 }
        }
    }
}
459
/// Selects how compressed-size style measures obtain a compressed length.
#[derive(Clone)]
pub enum CompressionBackend {
    /// Real zpaq compression with the given zpaq method string.
    Zpaq {
        method: String,
    },
    /// RWKV7 model driving an entropy coder (feature `backend-rwkv`).
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        model: Arc<rwkvzip::Model>,
        coder: CoderType,
    },
    /// Any `RateBackend` model driving an entropy coder with the given
    /// framing mode.
    Rate {
        rate_backend: RateBackend,
        coder: CoderType,
        framing: compression::FramingMode,
    },
}
486
/// Weighting scheme for a mixture of experts; the runtime semantics of each
/// variant live in `crate::mixture::build_mixture_runtime`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureKind {
    Bayes,
    FadingBayes,
    Switching,
    Mdl,
    Neural,
}
501
/// Context signal used to bucket calibration corrections for a
/// `RateBackend::Calibrated` backend; interpreted by the calibration
/// runtime (not visible in this module).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationContextKind {
    Global,
    ByteClass,
    Text,
    Repeat,
    TextRepeat,
}
516
/// Configuration for a calibrated wrapper around a base rate backend.
#[derive(Clone)]
pub struct CalibratedSpec {
    /// Backend whose predictions are being calibrated.
    pub base: RateBackend,
    /// How predictions are bucketed before correction.
    pub context: CalibrationContextKind,
    /// Number of calibration bins.
    pub bins: usize,
    /// Step size of the online calibration update.
    pub learning_rate: f64,
    /// Hard limit on the learned bias (units defined by the runtime).
    pub bias_clip: f64,
}
531
/// One expert inside a `MixtureSpec`.
#[derive(Clone)]
pub struct MixtureExpertSpec {
    /// Optional human-readable label.
    pub name: Option<String>,
    /// Prior log-weight of this expert in the mixture.
    pub log_prior: f64,
    /// Context order handed to the expert (-1 appears to mean "unbounded"
    /// elsewhere in this file — confirm in `crate::mixture`).
    pub max_order: i64,
    /// The model backing this expert.
    pub backend: RateBackend,
}
544
/// Declarative description of an expert mixture.
#[derive(Clone)]
pub struct MixtureSpec {
    /// Weighting scheme used to combine the experts.
    pub kind: MixtureKind,
    /// Mixing/switching parameter (defaults to 0.01 in `new`).
    pub alpha: f64,
    /// Optional decay factor; `None` disables decay.
    pub decay: Option<f64>,
    /// The experts to combine.
    pub experts: Vec<MixtureExpertSpec>,
}
557
558impl MixtureSpec {
559 pub fn new(kind: MixtureKind, experts: Vec<MixtureExpertSpec>) -> Self {
561 Self {
562 kind,
563 alpha: 0.01,
564 decay: None,
565 experts,
566 }
567 }
568
569 pub fn with_alpha(mut self, alpha: f64) -> Self {
571 self.alpha = alpha;
572 self
573 }
574
575 pub fn with_decay(mut self, decay: f64) -> Self {
577 self.decay = Some(decay);
578 self
579 }
580
581 pub fn build_experts(&self) -> Vec<crate::mixture::ExpertConfig> {
583 self.experts
584 .iter()
585 .map(|spec| {
586 crate::mixture::ExpertConfig::from_rate_backend(
587 spec.name.clone(),
588 spec.log_prior,
589 spec.backend.clone(),
590 spec.max_order,
591 )
592 })
593 .collect()
594 }
595}
596
/// Hyper-parameters for the particle-filter predictor
/// (`crate::particle::ParticleRuntime`). Field semantics beyond the
/// validation ranges below are defined by that runtime.
#[derive(Clone, Debug)]
pub struct ParticleSpec {
    /// Number of particles in the filter (> 0).
    pub num_particles: usize,
    /// Context window length in bytes (> 0).
    pub context_window: usize,
    /// Unroll steps per update (> 0).
    pub unroll_steps: usize,
    /// Number of state cells (> 0).
    pub num_cells: usize,
    /// Dimension of each cell (> 0).
    pub cell_dim: usize,
    /// Number of update rules (> 0).
    pub num_rules: usize,
    /// Hidden width of the rule selector (> 0).
    pub selector_hidden: usize,
    /// Hidden width of each rule network (> 0).
    pub rule_hidden: usize,
    /// Dimension of the injected noise vector.
    pub noise_dim: usize,
    /// Deterministic mode flag.
    pub deterministic: bool,
    /// Whether noise injection is enabled at all.
    pub enable_noise: bool,
    /// Noise magnitude (finite, >= 0).
    pub noise_scale: f64,
    /// Steps over which noise is annealed.
    pub noise_anneal_steps: usize,
    /// Learning rate of the readout head (finite, >= 0).
    pub learning_rate_readout: f64,
    /// Learning rate of the selector (finite, >= 0).
    pub learning_rate_selector: f64,
    /// Learning rate of the rules (finite, >= 0).
    pub learning_rate_rule: f64,
    /// Backprop-through-time depth (> 0).
    pub bptt_depth: usize,
    /// Optimizer momentum (finite, in [0, 1)).
    pub optimizer_momentum: f64,
    /// Gradient clipping threshold.
    pub grad_clip: f64,
    /// State clipping threshold.
    pub state_clip: f64,
    /// Forgetting factor.
    pub forget_lambda: f64,
    /// Effective-sample-size resampling threshold (in (0, 1]).
    pub resample_threshold: f64,
    /// Fraction of particles mutated on resample (in [0, 1]).
    pub mutate_fraction: f64,
    /// Mutation magnitude.
    pub mutate_scale: f64,
    /// Whether mutation also perturbs model parameters.
    pub mutate_model_params: bool,
    /// Diagnostics cadence in steps (0 disables).
    pub diagnostics_interval: usize,
    /// Probability floor for predictions (in (0, 0.5)).
    pub min_prob: f64,
    /// RNG seed.
    pub seed: u64,
}

impl Default for ParticleSpec {
    fn default() -> Self {
        Self {
            num_particles: 16,
            context_window: 32,
            unroll_steps: 2,
            num_cells: 8,
            cell_dim: 32,
            num_rules: 4,
            selector_hidden: 64,
            rule_hidden: 64,
            noise_dim: 8,
            deterministic: true,
            enable_noise: false,
            noise_scale: 0.10,
            noise_anneal_steps: 8192,
            learning_rate_readout: 0.01,
            learning_rate_selector: 1e-4,
            learning_rate_rule: 3e-4,
            bptt_depth: 3,
            optimizer_momentum: 0.05,
            grad_clip: 1.0,
            state_clip: 8.0,
            forget_lambda: 0.0,
            resample_threshold: 0.5,
            mutate_fraction: 0.1,
            mutate_scale: 0.01,
            mutate_model_params: false,
            diagnostics_interval: 0,
            min_prob: 2f64.powi(-24),
            seed: 42,
        }
    }
}

impl ParticleSpec {
    /// Checks every field against its documented range, returning the
    /// first violation found (checks run in the field order above).
    pub fn validate(&self) -> Result<(), String> {
        // Structural sizes that must be strictly positive.
        let count_checks: [(usize, &str); 8] = [
            (self.num_particles, "num_particles must be > 0"),
            (self.context_window, "context_window must be > 0"),
            (self.unroll_steps, "unroll_steps must be > 0"),
            (self.num_cells, "num_cells must be > 0"),
            (self.cell_dim, "cell_dim must be > 0"),
            (self.num_rules, "num_rules must be > 0"),
            (self.selector_hidden, "selector_hidden must be > 0"),
            (self.rule_hidden, "rule_hidden must be > 0"),
        ];
        for (value, msg) in count_checks {
            if value == 0 {
                return Err(msg.into());
            }
        }
        // Rates and scales that must be finite and non-negative.
        let rate_checks: [(f64, &str); 4] = [
            (
                self.learning_rate_readout,
                "learning_rate_readout must be finite and non-negative",
            ),
            (
                self.learning_rate_selector,
                "learning_rate_selector must be finite and non-negative",
            ),
            (
                self.learning_rate_rule,
                "learning_rate_rule must be finite and non-negative",
            ),
            (
                self.noise_scale,
                "noise_scale must be finite and non-negative",
            ),
        ];
        for (value, msg) in rate_checks {
            if !value.is_finite() || value < 0.0 {
                return Err(msg.into());
            }
        }
        let momentum = self.optimizer_momentum;
        if !momentum.is_finite() || momentum < 0.0 || momentum >= 1.0 {
            return Err("optimizer_momentum must be finite and in [0, 1)".into());
        }
        if self.bptt_depth == 0 {
            return Err("bptt_depth must be > 0".into());
        }
        if !(self.resample_threshold > 0.0 && self.resample_threshold <= 1.0) {
            return Err("resample_threshold must be in (0, 1]".into());
        }
        if !(self.mutate_fraction >= 0.0 && self.mutate_fraction <= 1.0) {
            return Err("mutate_fraction must be in [0, 1]".into());
        }
        if !(self.min_prob > 0.0 && self.min_prob < 0.5) {
            return Err("min_prob must be in (0, 0.5)".into());
        }
        Ok(())
    }
}
753
/// Bundle of the two backends that all information-theoretic measures in
/// this module are computed against.
#[derive(Clone, Default)]
pub struct InfotheoryCtx {
    /// Backend used for entropy-rate style estimates.
    pub rate_backend: RateBackend,
    /// Backend used for compressed-size / NCD style estimates.
    pub compression_backend: CompressionBackend,
}
762
763pub struct RateBackendSession {
765 predictor: crate::mixture::RateBackendPredictor,
766}
767
768impl RateBackendSession {
769 pub fn from_backend(
771 backend: RateBackend,
772 max_order: i64,
773 total_symbols: Option<u64>,
774 ) -> Result<Self, String> {
775 use crate::mixture::OnlineBytePredictor;
776
777 let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
778 backend,
779 max_order,
780 crate::mixture::DEFAULT_MIN_PROB,
781 );
782 predictor.begin_stream(total_symbols)?;
783 Ok(Self { predictor })
784 }
785
786 pub fn observe(&mut self, data: &[u8]) {
788 use crate::mixture::OnlineBytePredictor;
789
790 for &byte in data {
791 self.predictor.update(byte);
792 }
793 }
794
795 pub fn condition(&mut self, data: &[u8]) {
797 use crate::mixture::OnlineBytePredictor;
798
799 for &byte in data {
800 self.predictor.update_frozen(byte);
801 }
802 }
803
804 pub fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
806 use crate::mixture::OnlineBytePredictor;
807
808 self.predictor.reset_frozen(total_symbols)
809 }
810
811 pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
813 use crate::mixture::OnlineBytePredictor;
814
815 self.predictor.fill_log_probs(out);
816 }
817
818 pub fn generate_bytes(&mut self, bytes: usize, config: GenerationConfig) -> Vec<u8> {
820 use crate::mixture::OnlineBytePredictor;
821
822 if bytes == 0 {
823 return Vec::new();
824 }
825
826 let mut out = Vec::with_capacity(bytes);
827 let mut logps = [0.0f64; 256];
828 let mut rng = GenerationRng::new(config.seed);
829
830 for _ in 0..bytes {
831 match &mut self.predictor {
832 crate::mixture::RateBackendPredictor::Rosa { .. } => {
834 for (sym, slot) in logps.iter_mut().enumerate() {
835 *slot = self.predictor.log_prob(sym as u8);
836 }
837 }
838 _ => self.predictor.fill_log_probs(&mut logps),
839 }
840 let byte = pick_generated_byte(&logps, config, &mut rng);
841 match config.update_mode {
842 GenerationUpdateMode::Adaptive => self.predictor.update(byte),
843 GenerationUpdateMode::Frozen => self.predictor.update_frozen(byte),
844 }
845 out.push(byte);
846 }
847
848 out
849 }
850
851 pub fn finish(&mut self) -> Result<(), String> {
853 use crate::mixture::OnlineBytePredictor;
854
855 self.predictor.finish_stream()
856 }
857}
858
impl InfotheoryCtx {
    /// Builds a context from explicit backends.
    pub fn new(rate_backend: RateBackend, compression_backend: CompressionBackend) -> Self {
        Self {
            rate_backend,
            compression_backend,
        }
    }

    /// Convenience constructor: ROSA+ rate estimation with zpaq compression
    /// using the given zpaq method string.
    pub fn with_zpaq(method: impl Into<String>) -> Self {
        Self {
            rate_backend: RateBackend::RosaPlus,
            compression_backend: CompressionBackend::Zpaq {
                method: method.into(),
            },
        }
    }

    /// Compressed size of `data` under the configured compression backend.
    pub fn compress_size(&self, data: &[u8]) -> u64 {
        compress_size_backend(data, &self.compression_backend)
    }

    /// Compressed size of the logical concatenation of `parts`
    /// (see `compress_size_chain_backend`).
    pub fn compress_size_chain(&self, parts: &[&[u8]]) -> u64 {
        compress_size_chain_backend(parts, &self.compression_backend)
    }

    /// Opens a stateful prediction session on this context's rate backend.
    pub fn rate_backend_session(
        &self,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<RateBackendSession, String> {
        RateBackendSession::from_backend(self.rate_backend.clone(), max_order, total_symbols)
    }

    /// Entropy rate of `data` via the configured rate backend.
    pub fn entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Biased variant of the entropy-rate estimate
    /// (see `biased_entropy_rate_backend`).
    pub fn biased_entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        biased_entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Cross-entropy rate of `test_data` under a model driven by `train_data`.
    pub fn cross_entropy_rate_bytes(
        &self,
        test_data: &[u8],
        train_data: &[u8],
        max_order: i64,
    ) -> f64 {
        cross_entropy_rate_backend(test_data, train_data, max_order, &self.rate_backend)
    }

    /// Cross entropy in bits/byte. `max_order == 0` uses a closed-form
    /// order-0 histogram estimate; any other order delegates to the rate
    /// backend.
    pub fn cross_entropy_bytes(&self, test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            if test_data.is_empty() {
                return 0.0;
            }
            let p_x = byte_histogram(test_data);
            let p_y = byte_histogram(train_data);
            let mut h = 0.0f64;
            for i in 0..256 {
                if p_x[i] > 0.0 {
                    // Floor the train probability so symbols unseen in
                    // `train_data` contribute a large but finite cost.
                    let q_y = p_y[i].max(1e-12);
                    h -= p_x[i] * q_y.log2();
                }
            }
            h
        } else {
            self.cross_entropy_rate_bytes(test_data, train_data, max_order)
        }
    }

    /// Joint entropy rate H(X, Y) over the aligned common prefix.
    pub fn joint_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        joint_entropy_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Conditional entropy rate H(X|Y) = H(X,Y) - H(Y), floored at zero.
    pub fn conditional_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
        let h_y = self.entropy_rate_bytes(y, max_order);
        (h_xy - h_y).max(0.0)
    }

    /// Cross entropy (bits/byte) of `data` given that the model has already
    /// consumed every slice in `prefix_parts`, dispatched per backend.
    pub fn cross_entropy_conditional_chain(&self, prefix_parts: &[&[u8]], data: &[u8]) -> f64 {
        match &self.rate_backend {
            // ROSA+ has no incremental API here: flatten the prefix and
            // score `data` as a train/test cross entropy.
            RateBackend::RosaPlus => {
                let mut prefix = Vec::new();
                let total: usize = prefix_parts.iter().map(|p| p.len()).sum();
                prefix.reserve(total);
                for p in prefix_parts {
                    prefix.extend_from_slice(p);
                }
                cross_entropy_rate_backend(data, &prefix, -1, &RateBackend::RosaPlus)
            }
            // Online predictors score `data` prequentially after warming
            // up on the prefix.
            RateBackend::Match { .. }
            | RateBackend::SparseMatch { .. }
            | RateBackend::Ppmd { .. }
            | RateBackend::Calibrated { .. } => {
                prequential_rate_backend(data, prefix_parts, -1, &self.rate_backend)
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("rwkv conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("rwkv method conditional-chain scoring failed: {e:#}")
                    })
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("mamba conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("mamba method conditional-chain scoring failed: {e:#}")
                    })
            }),
            // CTW: conditional block probability = joint minus prefix;
            // each byte is fed to the tree MSB-first.
            RateBackend::Ctw { depth } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut tree = crate::ctw::ContextTree::new(*depth);
                for &part in prefix_parts {
                    for &b in part {
                        for i in (0..8).rev() {
                            tree.update(((b >> i) & 1) == 1);
                        }
                    }
                }
                let log_p_prefix = tree.get_log_block_probability();
                for &b in data {
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_p_joint = tree.get_log_block_probability();
                let log_p_cond = log_p_joint - log_p_prefix;
                // Natural-log probability -> bits, normalized per byte.
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
            RateBackend::Zpaq { method } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut model =
                    crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
                // Warm the model on the prefix; its score is discarded.
                for &part in prefix_parts {
                    model.update_and_score(part);
                }
                let bits = model.update_and_score(data);
                bits / (data.len() as f64)
            }
            RateBackend::Mixture { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let experts = spec.build_experts();
                let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                // Declare the full stream length (prefix + data) up front.
                let total = prefix_parts
                    .iter()
                    .map(|p| p.len() as u64)
                    .sum::<u64>()
                    .saturating_add(data.len() as u64);
                mix.begin_stream(Some(total))
                    .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
                for &part in prefix_parts {
                    for &b in part {
                        mix.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    // `step` returns a natural-log probability; accumulate bits.
                    bits -= mix.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::Particle { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                for &part in prefix_parts {
                    for &b in part {
                        runtime.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    bits -= runtime.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::FacCtw {
                base_depth,
                // NOTE(review): num_percept_bits is unused on this path —
                // confirm that is intentional.
                num_percept_bits: _,
                encoding_bits,
            } => {
                if data.is_empty() {
                    return 0.0;
                }
                let bits_per_byte = (*encoding_bits).clamp(1, 8);
                let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
                // NOTE(review): bits are fed LSB-first here, unlike the
                // MSB-first order in the Ctw arm above — confirm intended.
                for &part in prefix_parts {
                    for &b in part {
                        for i in 0..bits_per_byte {
                            let bit_idx = i;
                            fac.update(((b >> i) & 1) == 1, bit_idx);
                        }
                    }
                }
                let log_p_prefix = fac.get_log_block_probability();
                for &b in data {
                    for i in 0..bits_per_byte {
                        let bit_idx = i;
                        fac.update(((b >> i) & 1) == 1, bit_idx);
                    }
                }
                let log_p_joint = fac.get_log_block_probability();
                let log_p_cond = log_p_joint - log_p_prefix;
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
        }
    }

    /// Generates `bytes` bytes after conditioning on `prompt`, using the
    /// default config (seeded sampling, frozen model).
    pub fn generate_bytes(&self, prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
        self.generate_bytes_with_config(prompt, bytes, max_order, GenerationConfig::default())
    }

    /// As `generate_bytes`, with an explicit generation config.
    pub fn generate_bytes_with_config(
        &self,
        prompt: &[u8],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(&[prompt], bytes, max_order, &self.rate_backend, config)
    }

    /// Generates bytes after conditioning on a chain of prefix slices.
    pub fn generate_bytes_conditional_chain(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
    ) -> Vec<u8> {
        self.generate_bytes_conditional_chain_with_config(
            prefix_parts,
            bytes,
            max_order,
            GenerationConfig::default(),
        )
    }

    /// As `generate_bytes_conditional_chain`, with an explicit config.
    pub fn generate_bytes_conditional_chain_with_config(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(prefix_parts, bytes, max_order, &self.rate_backend, config)
    }

    /// Normalized compression distance between `x` and `y`.
    pub fn ncd_bytes(&self, x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
        ncd_bytes_backend(x, y, &self.compression_backend, variant)
    }

    /// Mutual-information rate I(X;Y) via the rate backend.
    pub fn mutual_information_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        mutual_information_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Mutual information; `max_order == 0` uses marginal (order-0)
    /// histograms, anything else uses the rate backend.
    pub fn mutual_information_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            mutual_information_marg_bytes(x, y)
        } else {
            self.mutual_information_rate_bytes(x, y, max_order)
        }
    }

    /// Conditional entropy H(X|Y) = H(X,Y) - H(Y), order-0 or rate-based.
    pub fn conditional_entropy_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if max_order == 0 {
            let h_xy = joint_marginal_entropy_bytes(x, y);
            let h_y = marginal_entropy_bytes(y);
            (h_xy - h_y).max(0.0)
        } else {
            let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
            let h_y = self.entropy_rate_bytes(y, max_order);
            (h_xy - h_y).max(0.0)
        }
    }

    /// Normalized entropy distance, order-0 or rate-based
    /// (see `ned_rate_backend`).
    pub fn ned_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            ned_marg_bytes(x, y)
        } else {
            ned_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Conservative NED variant: normalizes (H(X,Y) - min(H)) by the joint
    /// entropy H(X,Y) instead of max(H(X), H(Y)).
    pub fn ned_cons_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        let (h_x, h_y, h_xy) = if max_order == 0 {
            (
                marginal_entropy_bytes(x),
                marginal_entropy_bytes(y),
                joint_marginal_entropy_bytes(x, y),
            )
        } else {
            (
                self.entropy_rate_bytes(x, max_order),
                self.entropy_rate_bytes(y, max_order),
                self.joint_entropy_rate_bytes(x, y, max_order),
            )
        };
        let min_h = h_x.min(h_y);
        if h_xy == 0.0 {
            0.0
        } else {
            ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
        }
    }

    /// Normalized variation-of-information distance, order-0 or rate-based
    /// (see `nte_rate_backend`).
    pub fn nte_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            nte_marg_bytes(x, y)
        } else {
            nte_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Fraction of the marginal entropy explained by sequential structure:
    /// (H_marginal - H_rate) / H_marginal, clamped to [0, 1].
    pub fn intrinsic_dependence_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        let h_marginal = marginal_entropy_bytes(data);
        if h_marginal < 1e-9 {
            // (Near-)constant input: nothing to measure.
            return 0.0;
        }
        let h_rate = self.entropy_rate_bytes(data, max_order);
        ((h_marginal - h_rate) / h_marginal).clamp(0.0, 1.0)
    }

    /// How much information about `x` survives the transformation that
    /// produced `tx`: I(X; T(X)) / H(X), clamped to [0, 1].
    pub fn resistance_to_transformation_bytes(&self, x: &[u8], tx: &[u8], max_order: i64) -> f64 {
        let (x, tx) = aligned_prefix(x, tx);
        let h_x = if max_order == 0 {
            marginal_entropy_bytes(x)
        } else {
            self.entropy_rate_bytes(x, max_order)
        };
        if h_x < 1e-9 {
            return 0.0;
        }
        let mi = self.mutual_information_bytes(x, tx, max_order);
        (mi / h_x).clamp(0.0, 1.0)
    }
}
1256
/// Loads an RWKV7 model from disk and shares it behind an `Arc`.
///
/// # Panics
/// Panics if the model file cannot be loaded.
#[cfg(feature = "backend-rwkv")]
pub fn load_rwkv7_model_from_path(path: &str) -> Arc<rwkvzip::Model> {
    rwkvzip::Compressor::load_model(path).expect("failed to load RWKV7 model")
}
1262
/// Loads a Mamba model from disk and shares it behind an `Arc`.
///
/// # Panics
/// Panics if the model file cannot be loaded.
#[cfg(feature = "backend-mamba")]
pub fn load_mamba_model_from_path(path: &str) -> Arc<mambazip::Model> {
    mambazip::Compressor::load_model(path).expect("failed to load Mamba model")
}
1268
/// Truncates both slices to their common length so pairwise statistics
/// always compare equally many bytes.
#[inline(always)]
fn aligned_prefix<'a>(x: &'a [u8], y: &'a [u8]) -> (&'a [u8], &'a [u8]) {
    if x.len() <= y.len() {
        (x, &y[..x.len()])
    } else {
        (&x[..y.len()], y)
    }
}
1274
/// Compressed size of `data` under zpaq `method`; 0 if compression fails.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_rs::compress_size(data, method).unwrap_or(0)
}

/// Stub used when zpaq support is compiled out: reaching it means the zpaq
/// backend was selected anyway, which is a configuration error, so panic.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_bytes(_data: &[u8], _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
1286
/// Parallel variant of `zpaq_compress_size_bytes` using `threads` workers;
/// 0 if compression fails.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(data: &[u8], method: &str, threads: usize) -> u64 {
    zpaq_rs::compress_size_parallel(data, method, threads).unwrap_or(0)
}

/// Stub used when zpaq support is compiled out (see
/// `zpaq_compress_size_bytes`).
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(_data: &[u8], _method: &str, _threads: usize) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
1298
/// Streaming variant: compressed size of everything read from `reader`;
/// 0 if compression fails.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(reader: R, method: &str) -> u64 {
    zpaq_rs::compress_size_stream(reader, method, None, None).unwrap_or(0)
}

/// Stub used when zpaq support is compiled out (see
/// `zpaq_compress_size_bytes`).
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(_reader: R, _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
1310
/// Compresses `data` with zpaq `method` into a new buffer.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_to_vec(data: &[u8], method: &str) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::compress_to_vec(data, method)?)
}

/// Stub when zpaq support is compiled out; unlike the size helpers this
/// path is fallible, so it returns an error instead of panicking.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_to_vec(_data: &[u8], _method: &str) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
1322
/// Decompresses a zpaq archive into a new buffer.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_decompress_to_vec(data: &[u8]) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::decompress_to_vec(data)?)
}

/// Stub when zpaq support is compiled out; returns an error instead of
/// panicking (see `zpaq_compress_to_vec`).
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_decompress_to_vec(_data: &[u8]) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
1334
1335#[inline(always)]
1337pub fn get_compressed_size(path: &str, method: &str) -> u64 {
1338 zpaq_compress_size_bytes(&std::fs::read(path).unwrap(), method)
1341}
1342
/// Checks whether `method` is a valid zpaq rate-model method string.
///
/// When zpaq support is compiled out this always returns an error, letting
/// callers report a configuration problem up front instead of panicking
/// later inside the zpaq stubs.
pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
    #[cfg(feature = "backend-zpaq")]
    {
        zpaq_rate::validate_zpaq_rate_method(method)
    }
    #[cfg(not(feature = "backend-zpaq"))]
    {
        let _ = method;
        Err("zpaq backend disabled at compile time".to_string())
    }
}
1355
/// Runs `f` against the thread-local compressor cached for `model`.
///
/// The compressor is created on first use (keyed by the `Arc`'s pointer
/// identity) and then reused in place, so any state `f` leaves behind
/// persists for later calls on the same thread.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    RWKV_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| rwkvzip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}
1370
/// Runs `f` against a fresh clone of the compressor template cached for
/// `method`.
///
/// Unlike `with_rwkv_tls`, the cached entry is only a pristine template:
/// `f` receives a clone, so its mutations are discarded after each call.
/// The map borrow is dropped before `f` runs, so `f` may itself touch the
/// TLS without a `RefCell` double-borrow panic.
///
/// # Panics
/// Panics if `method` is not a valid rwkv method string.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_method_tls<R>(method: &str, f: impl FnOnce(&mut rwkvzip::Compressor) -> R) -> R {
    RWKV_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid rwkv method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        drop(map);
        f(&mut comp)
    })
}
1390
/// Runs `f` against a fresh clone of the compressor template cached for
/// `model` (keyed by `Arc` pointer identity).
///
/// Like `with_rwkv_method_tls` and unlike `with_rwkv_tls`, the cached
/// compressor is a template: `f` gets a clone and its state is discarded
/// after the call. The map borrow is released before `f` runs.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_rate_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    RWKV_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        drop(map);
        f(&mut comp)
    })
}
1410
/// Runs `f` against the thread-local compressor cached for `model`
/// (Mamba twin of `with_rwkv_tls`).
///
/// The compressor is created on first use (keyed by `Arc` pointer
/// identity) and reused in place, so state persists across calls on the
/// same thread.
#[cfg(feature = "backend-mamba")]
fn with_mamba_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    MAMBA_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| mambazip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}
1425
/// Runs `f` against a fresh clone of the compressor template cached for
/// `model` (Mamba twin of `with_rwkv_rate_tls`).
///
/// `f` gets a clone of the cached template, so its state is discarded
/// after each call. The map borrow is released before `f` runs.
#[cfg(feature = "backend-mamba")]
fn with_mamba_rate_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    MAMBA_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        drop(map);
        f(&mut comp)
    })
}
1445
#[cfg(feature = "backend-mamba")]
/// Runs `f` against a fresh clone of a thread-locally cached Mamba compressor
/// built from `method`; the template stays pristine across calls.
fn with_mamba_method_tls<R>(method: &str, f: impl FnOnce(&mut mambazip::Compressor) -> R) -> R {
    MAMBA_METHOD_TLS.with(|cell| {
        // Resolve the template under the borrow, release it, then run `f`.
        let template = {
            let mut cache = cell.borrow_mut();
            match cache.get(method) {
                Some(existing) => existing.clone(),
                None => {
                    let built = mambazip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                        panic!("invalid mamba method '{method}': {e:#}");
                    });
                    cache.insert(method.to_string(), built.clone());
                    built
                }
            }
        };
        let mut working = template;
        f(&mut working)
    })
}
1463
/// Zero-copy `Read` adapter presenting a sequence of byte slices as one
/// contiguous stream — used to feed concatenations like `x ‖ y` to streaming
/// compressors without materialising the joined buffer.
struct SliceChainReader<'a> {
    // Slices consumed strictly in order.
    parts: &'a [&'a [u8]],
    // Index of the slice currently being read.
    i: usize,
    // Byte offset into `parts[i]`.
    off: usize,
}

impl<'a> SliceChainReader<'a> {
    /// Positions the reader at the start of the first slice.
    fn new(parts: &'a [&'a [u8]]) -> Self {
        Self { parts, i: 0, off: 0 }
    }
}

impl<'a> std::io::Read for SliceChainReader<'a> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        let mut written = 0;
        // Fill `buf` greedily, crossing slice boundaries as needed; empty
        // slices are skipped transparently.
        while written < buf.len() && self.i < self.parts.len() {
            let remaining = &self.parts[self.i][self.off..];
            if remaining.is_empty() {
                self.i += 1;
                self.off = 0;
                continue;
            }
            let n = remaining.len().min(buf.len() - written);
            buf[written..written + n].copy_from_slice(&remaining[..n]);
            self.off += n;
            written += n;
        }
        Ok(written)
    }
}
1512
/// Compressed size, in bytes, of the logical concatenation of `parts` under
/// `backend`, without materialising the joined buffer. Backend errors are
/// reported as size 0 rather than propagated.
pub fn compress_size_chain_backend(parts: &[&[u8]], backend: &CompressionBackend) -> u64 {
    match backend {
        CompressionBackend::Zpaq { method } => {
            // Stream the parts through a chained reader so zpaq sees one input.
            let r = SliceChainReader::new(parts);
            zpaq_compress_size_stream(r, method.as_str())
        }
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress_size_chain(parts, *coder).unwrap_or(0))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => {
            // NOTE(review): the -1 appears to mean "use the backend's default
            // max order" — confirm against compress_rate_size_chain.
            crate::compression::compress_rate_size_chain(parts, rate_backend, -1, *coder, *framing)
                .unwrap_or(0)
        }
    }
}
1534
/// Compressed size, in bytes, of `data` under `backend`. Backend errors are
/// reported as size 0 rather than propagated.
pub fn compress_size_backend(data: &[u8], backend: &CompressionBackend) -> u64 {
    match backend {
        CompressionBackend::Zpaq { method } => zpaq_compress_size_bytes(data, method.as_str()),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress_size(data, *coder).unwrap_or(0))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::compress_rate_size(data, rate_backend, -1, *coder, *framing)
            .unwrap_or(0),
    }
}
1551
/// Compresses `data` with `backend` and returns the encoded bytes; errors
/// from the underlying compressor are propagated.
pub fn compress_bytes_backend(
    data: &[u8],
    backend: &CompressionBackend,
) -> anyhow::Result<Vec<u8>> {
    match backend {
        CompressionBackend::Zpaq { method } => zpaq_compress_to_vec(data, method),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress(data, *coder))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::compress_rate_bytes(data, rate_backend, -1, *coder, *framing),
    }
}
1570
/// Decompresses `input` previously produced by `compress_bytes_backend` with
/// the same backend. Note the zpaq and rwkv arms ignore part of the backend
/// configuration (the `..` patterns): zpaq needs no method and rwkv no coder
/// for decoding here.
pub fn decompress_bytes_backend(
    input: &[u8],
    backend: &CompressionBackend,
) -> anyhow::Result<Vec<u8>> {
    match backend {
        CompressionBackend::Zpaq { .. } => zpaq_decompress_to_vec(input),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, .. } => with_rwkv_tls(model, |c| c.decompress(input)),
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::decompress_rate_bytes(input, rate_backend, -1, *coder, *framing),
    }
}
1587
/// Prequential (adaptive) code length of `data`, in bits per byte, under a
/// rate backend: each byte is scored with the current model and then used to
/// update it. `prefix_parts` are fed to the model first (updates only, no
/// scoring), so the returned rate is conditional on that prefix.
fn prequential_rate_backend(
    data: &[u8],
    prefix_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    use crate::mixture::OnlineBytePredictor;

    if data.is_empty() {
        return 0.0;
    }
    // Total stream length (prefix + scored data), passed to the backend as a
    // stream-size hint.
    let total = prefix_parts
        .iter()
        .map(|p| p.len() as u64)
        .sum::<u64>()
        .saturating_add(data.len() as u64);
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(total))
        .unwrap_or_else(|e| panic!("rate backend stream init failed: {e}"));
    // Warm the model on the prefix without charging any bits for it.
    for prefix in prefix_parts {
        for &b in *prefix {
            predictor.update(b);
        }
    }
    // Score-then-update: the classic prequential (online coding) loop.
    // log_prob is natural log, hence the division by ln(2) to get bits.
    let mut bits = 0.0;
    for &b in data {
        bits -= predictor.log_prob(b) / std::f64::consts::LN_2;
        predictor.update(b);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend stream finalize failed: {e}"));
    bits / (data.len() as f64)
}
1627
/// Frozen "plug-in" cross-entropy, in bits per byte: fit the backend on
/// `fit_parts`, freeze it, then score `score_data` without further updates.
///
/// RosaPlus and the neural backends (rwkv/mamba) have dedicated frozen
/// scoring paths; every other backend goes through the generic two-pass
/// fit-then-freeze predictor loop at the bottom.
fn frozen_plugin_rate_backend(
    score_data: &[u8],
    fit_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    if score_data.is_empty() {
        return 0.0;
    }
    if matches!(backend, RateBackend::RosaPlus) {
        // RosaPlus: explicit train / build / score API.
        let mut model = rosaplus::RosaPlus::new(max_order, false, 0, 42);
        for part in fit_parts {
            model.train_example(part);
        }
        model.build_lm();
        return model.cross_entropy(score_data);
    }
    // Neural backends delegate frozen-plugin scoring to the model itself; the
    // *_rate_tls / *_method_tls helpers hand `f` a per-call compressor clone.
    #[cfg(feature = "backend-rwkv")]
    match backend {
        RateBackend::Rwkv7 { model } => {
            return with_rwkv_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::Rwkv7Method { method } => {
            return with_rwkv_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }
    #[cfg(feature = "backend-mamba")]
    match backend {
        RateBackend::Mamba { model } => {
            return with_mamba_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::MambaMethod { method } => {
            return with_mamba_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }

    use crate::mixture::OnlineBytePredictor;

    // Pass 1: adaptively fit the predictor on all fit parts.
    let fit_total = fit_parts.iter().map(|part| part.len() as u64).sum::<u64>();
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(fit_total))
        .unwrap_or_else(|e| panic!("rate backend fit-pass init failed: {e}"));
    for part in fit_parts {
        for &byte in *part {
            predictor.update(byte);
        }
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend fit-pass finalize failed: {e}"));
    // Pass 2: reset the streaming state while keeping the fitted model
    // (reset_frozen), then score with frozen updates only.
    predictor
        .reset_frozen(Some(score_data.len() as u64))
        .unwrap_or_else(|e| panic!("rate backend frozen-score reset failed: {e}"));
    let mut bits = 0.0;
    for &byte in score_data {
        bits -= predictor.log_prob(byte) / std::f64::consts::LN_2;
        predictor.update_frozen(byte);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend frozen-score finalize failed: {e}"));
    bits / (score_data.len() as f64)
}
1710
#[inline(always)]
/// Returns the byte with the highest finite log-probability. Non-finite
/// entries (NaN, ±inf) never win; ties keep the lowest byte value; if no
/// entry is finite the result is byte 0.
fn argmax_log_prob_byte(logps: &[f64; 256]) -> u8 {
    let mut winner = 0usize;
    let mut winner_score = f64::NEG_INFINITY;
    for (byte, &raw) in logps.iter().enumerate() {
        let candidate = if raw.is_finite() { raw } else { f64::NEG_INFINITY };
        if candidate > winner_score {
            winner_score = candidate;
            winner = byte;
        }
    }
    winner as u8
}
1728
/// Chooses the next generated byte from a 256-way log-probability table.
///
/// The greedy strategy (or a non-positive / non-finite temperature) takes the
/// argmax. Otherwise the log-probs are temperature-scaled, sorted descending,
/// truncated by top-k and then nucleus (top-p), and a byte is sampled from
/// the renormalised remainder. Any degenerate state (all candidates at
/// -inf, zero or non-finite total mass) falls back to the greedy argmax.
fn pick_generated_byte(
    logps: &[f64; 256],
    config: GenerationConfig,
    rng: &mut GenerationRng,
) -> u8 {
    if matches!(config.strategy, GenerationStrategy::Greedy)
        || !config.temperature.is_finite()
        || config.temperature <= 0.0
    {
        return argmax_log_prob_byte(logps);
    }

    // (byte, logp / T) pairs; non-finite log-probs sink to -inf.
    let mut entries = [(0u8, f64::NEG_INFINITY); 256];
    for (idx, &logp) in logps.iter().enumerate() {
        let scaled = if logp.is_finite() {
            logp / config.temperature
        } else {
            f64::NEG_INFINITY
        };
        entries[idx] = (idx as u8, scaled);
    }
    // Highest score first; total_cmp gives a total order even with -inf.
    entries.sort_by(|a, b| b.1.total_cmp(&a.1));

    // top_k == 0 means "no top-k truncation".
    let keep_k = if config.top_k == 0 {
        entries.len()
    } else {
        config.top_k.min(entries.len())
    };

    // Sanitise top_p into [0, 1]; a non-finite value disables nucleus cutoff.
    let top_p = if config.top_p.is_finite() {
        config.top_p.clamp(0.0, 1.0)
    } else {
        1.0
    };

    // Softmax is computed relative to the max kept score for stability.
    let mut max_logp = f64::NEG_INFINITY;
    for &(_, logp) in entries.iter().take(keep_k) {
        if logp.is_finite() {
            max_logp = max_logp.max(logp);
        }
    }
    if !max_logp.is_finite() {
        // Every kept candidate is -inf: nothing samplable.
        return argmax_log_prob_byte(logps);
    }

    // Unnormalised softmax weights over the kept candidates (still sorted
    // in descending-score order).
    let mut weights = [(0u8, 0.0f64); 256];
    let mut total = 0.0;
    for (idx, &(byte, logp)) in entries.iter().take(keep_k).enumerate() {
        let w = if logp.is_finite() {
            (logp - max_logp).exp()
        } else {
            0.0
        };
        weights[idx] = (byte, w);
        total += w;
    }
    if !(total.is_finite()) || total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Nucleus truncation: smallest prefix whose probability mass reaches
    // top_p, but always at least one candidate.
    let cutoff_count = if top_p >= 1.0 {
        keep_k
    } else {
        let mut cumulative = 0.0;
        let mut keep = 0usize;
        for &(_, w) in weights.iter().take(keep_k) {
            cumulative += w / total;
            keep += 1;
            if cumulative >= top_p {
                break;
            }
        }
        keep.max(1)
    };

    // Renormalisation mass of the truncated set.
    let mut truncated_total = 0.0;
    for &(_, w) in weights.iter().take(cutoff_count) {
        truncated_total += w;
    }
    if !(truncated_total.is_finite()) || truncated_total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Inverse-CDF sampling; if rounding leaves `target` unreached, the
    // default `picked` is the highest-weight candidate (weights[0]).
    let target = rng.next_f64() * truncated_total;
    let mut cumulative = 0.0;
    let mut picked = weights[0].0;
    for &(byte, weight) in weights.iter().take(cutoff_count) {
        cumulative += weight;
        if cumulative >= target {
            picked = byte;
            break;
        }
    }
    picked
}
1824
/// Generates `bytes` bytes from a rate backend after conditioning it on
/// `prefix_parts`, using the sampling settings in `config`.
fn generate_rate_backend_chain(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
    backend: &RateBackend,
    config: GenerationConfig,
) -> Vec<u8> {
    if bytes == 0 {
        return Vec::new();
    }

    // Declared stream length = prefix + generated bytes (backend size hint).
    let total = prefix_parts
        .iter()
        .map(|p| p.len() as u64)
        .sum::<u64>()
        .saturating_add(bytes as u64);
    let mut session = RateBackendSession::from_backend(backend.clone(), max_order, Some(total))
        .unwrap_or_else(|e| panic!("rate backend generation init failed: {e}"));
    // Condition on the prompt/prefix before sampling.
    for &part in prefix_parts {
        session.observe(part);
    }
    let out = session.generate_bytes(bytes, config);
    session
        .finish()
        .unwrap_or_else(|e| panic!("rate backend generation finalize failed: {e}"));
    out
}
1852
/// Adaptive (prequential) entropy rate of `data`, in bits per byte, under the
/// given rate backend. Empty input yields 0.0; backend failures panic.
pub fn entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    match backend {
        RateBackend::RosaPlus => {
            // NOTE(review): assuming (false, 0, 42) are RosaPlus defaults with
            // 42 as a seed — confirm against RosaPlus::new.
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.predictive_entropy_rate(data)
        }
        // Predictor-style backends share the generic prequential loop.
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Calibrated { .. } => prequential_rate_backend(data, &[], max_order, backend),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv method entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba method entropy scoring failed: {e:#}"))
        }),
        RateBackend::Zpaq { method } => {
            if data.is_empty() {
                return 0.0;
            }
            // 2^-24 is the probability floor handed to the zpaq rate model.
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(data);
            bits / (data.len() as f64)
        }
        RateBackend::Mixture { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(data.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            // step() returns a natural-log score; convert to bits.
            let mut bits = 0.0;
            for &b in data {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (data.len() as f64)
        }
        RateBackend::Particle { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in data {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (data.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            if data.is_empty() {
                return 0.0;
            }
            // 8 bit-positions per byte, fed MSB-first.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_msb(b);
            }
            // Block probability is in nats; convert to bits per byte.
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            if data.is_empty() {
                return 0.0;
            }
            // Only the low `encoding_bits` bits of each byte are modelled,
            // fed LSB-first.
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_lsb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
    }
}
1954
1955pub fn biased_entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
1957 match backend {
1958 RateBackend::Zpaq { .. } => {
1959 panic!("biased/plugin entropy is not supported for zpaq rate backends in 1.1.0")
1960 }
1961 _ => frozen_plugin_rate_backend(data, &[data], max_order, backend),
1962 }
1963}
1964
1965pub fn cross_entropy_rate_backend(
1967 test_data: &[u8],
1968 train_data: &[u8],
1969 max_order: i64,
1970 backend: &RateBackend,
1971) -> f64 {
1972 match backend {
1973 RateBackend::Zpaq { method } => {
1974 if test_data.is_empty() {
1975 return 0.0;
1976 }
1977 let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
1978 model.update_and_score(train_data);
1979 let bits = model.update_and_score(test_data);
1980 bits / (test_data.len() as f64)
1981 }
1982 _ => frozen_plugin_rate_backend(test_data, &[train_data], max_order, backend),
1983 }
1984}
1985
/// Joint entropy rate H(X,Y), in bits per aligned byte pair, computed over
/// the common prefix of `x` and `y`. Each backend encodes the pair stream
/// differently (16-bit joint symbols, interleaved bytes, or interleaved
/// bits), so values are comparable only within a single backend.
pub fn joint_entropy_rate_backend(
    x: &[u8],
    y: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    if x.is_empty() {
        return 0.0;
    }
    match backend {
        RateBackend::RosaPlus => {
            // One 16-bit symbol per pair: s = x*256 + y.
            let joint_symbols: Vec<u32> = (0..x.len())
                .map(|i| (x[i] as u32) * 256 + (y[i] as u32))
                .collect();
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.entropy_rate_cps(&joint_symbols)
        }
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Calibrated { .. } => {
            // Interleave the two streams byte-by-byte and score as one
            // stream; double the per-byte rate to get bits per pair.
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            entropy_rate_backend(&joint, max_order, backend) * 2.0
        }
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv method joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba method joint-entropy scoring failed: {e:#}"))
        }),
        RateBackend::Zpaq { method } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(&joint);
            // Normalise by pair count, not by interleaved byte count.
            bits / (x.len() as f64)
        }
        RateBackend::Mixture { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(joint.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in &joint {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (x.len() as f64)
        }
        RateBackend::Particle { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in &joint {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (x.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            // 16 bit-slots per pair: x's MSB-first bits occupy slots 0..8 and
            // y's occupy slots 8..16, interleaved x-bit/y-bit in update order.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 16);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for bit_idx in 0..8 {
                    let bit_x = ((bx >> (7 - bit_idx)) & 1) == 1;
                    let bit_y = ((by >> (7 - bit_idx)) & 1) == 1;
                    fac.update(bit_x, bit_idx);
                    fac.update(bit_y, bit_idx + 8);
                }
            }
            // Block probability in nats -> bits per pair.
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            // Low `encoding_bits` bits of each byte only, LSB-first; x-bit
            // and y-bit for position i go to adjacent slots 2i and 2i+1.
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte * 2);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for i in 0..bits_per_byte {
                    let bit_idx_x = i * 2;
                    let bit_idx_y = bit_idx_x + 1;
                    fac.update(((bx >> i) & 1) == 1, bit_idx_x);
                    fac.update(((by >> i) & 1) == 1, bit_idx_y);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
    }
}
2127#[inline(always)]
2128pub fn get_compressed_size_parallel(path: &str, method: &str, threads: usize) -> u64 {
2130 zpaq_compress_size_parallel_bytes(&std::fs::read(path).unwrap(), method, threads)
2133}
2134
2135#[inline(always)]
2136pub fn get_bytes_from_paths(paths: &[&str]) -> Vec<Vec<u8>> {
2138 paths
2139 .par_iter()
2140 .map(|path| std::fs::read(*path).expect("failed to read file"))
2141 .collect()
2142}
2143
#[inline(always)]
/// Per-file compressed sizes: all files are buffered up front (via the
/// parallel reader), then each buffer is compressed with a single zpaq call,
/// buffers iterated in parallel across the rayon pool.
/// NOTE(review): "from_sequential_paths" appears to mean "read everything
/// before compressing" (vs. reading inside each task) — confirm naming.
pub fn get_sequential_compressed_sizes_from_sequential_paths(
    paths: &[&str],
    method: &str,
) -> Vec<u64> {
    get_bytes_from_paths(paths)
        .par_iter()
        .map(|data| zpaq_compress_size_bytes(data, method))
        .collect()
}
2159
#[inline(always)]
/// Per-file compressed sizes: files buffered up front, then each buffer is
/// compressed with the multi-threaded zpaq entry point (`threads` workers per
/// file), buffers themselves iterated in parallel as well.
pub fn get_parallel_compressed_sizes_from_sequential_paths(
    paths: &[&str],
    method: &str,
    threads: usize,
) -> Vec<u64> {
    get_bytes_from_paths(paths)
        .par_iter()
        .map(|data| zpaq_compress_size_parallel_bytes(data, method, threads))
        .collect()
}
2175
#[inline(always)]
/// Per-file compressed sizes: paths iterated in parallel, each task reading
/// its own file and compressing it with a single zpaq call.
pub fn get_sequential_compressed_sizes_from_parallel_paths(
    paths: &[&str],
    method: &str,
) -> Vec<u64> {
    paths
        .par_iter()
        .map(|path| get_compressed_size(path, method))
        .collect()
}
2190
#[inline(always)]
/// Per-file compressed sizes: paths iterated in parallel, each task reading
/// its own file and compressing it with the multi-threaded zpaq entry point
/// (`threads` workers per file).
pub fn get_parallel_compressed_sizes_from_parallel_paths(
    paths: &[&str],
    method: &str,
    threads: usize,
) -> Vec<u64> {
    paths
        .par_iter()
        .map(|path| get_compressed_size_parallel(path, method, threads))
        .collect()
}
2207
2208#[inline(always)]
2210pub fn get_compressed_sizes_from_paths(paths: &[&str], method: &str) -> Vec<u64> {
2211 let n: usize = paths.len();
2212 let num_threads: usize = *NUM_THREADS.get_or_init(num_cpus::get);
2213 if n < num_threads {
2214 get_parallel_compressed_sizes_from_parallel_paths(paths, method, num_threads.div_ceil(n))
2215 } else {
2216 get_sequential_compressed_sizes_from_parallel_paths(paths, method)
2217 }
2218}
2219
/// Which normalised compression distance (NCD) formula to use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcdVariant {
    /// Classic NCD: (C(xy) − min(C(x), C(y))) / max(C(x), C(y)).
    Vitanyi,
    /// Symmetrised classic NCD: the joint term is min(C(xy), C(yx)).
    SymVitanyi,
    /// Conservative normalisation: divides by C(xy) instead of max(C(x), C(y)).
    Cons,
    /// Symmetrised conservative variant: joint term and denominator are both
    /// min(C(xy), C(yx)).
    SymCons,
}
2248
#[inline(always)]
/// Convenience wrapper: zpaq compressed size of `data` under `method`.
fn compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_compress_size_bytes(data, method)
}
2253
2254#[inline(always)]
2255fn ncd_from_sizes(cx: u64, cy: u64, cxy: u64, cyx: Option<u64>, variant: NcdVariant) -> f64 {
2256 let min_c = cx.min(cy) as f64;
2257 let max_c = cx.max(cy) as f64;
2258
2259 match variant {
2260 NcdVariant::Vitanyi => {
2261 if max_c == 0.0 {
2262 0.0
2263 } else {
2264 (cxy as f64 - min_c) / max_c
2265 }
2266 }
2267 NcdVariant::SymVitanyi => {
2268 let m = cxy.min(cyx.expect("cyx required for SymVitanyi")) as f64;
2269 if max_c == 0.0 {
2270 0.0
2271 } else {
2272 (m - min_c) / max_c
2273 }
2274 }
2275 NcdVariant::Cons => {
2276 let denom = cxy as f64;
2277 if denom == 0.0 {
2278 0.0
2279 } else {
2280 (cxy as f64 - min_c) / denom
2281 }
2282 }
2283 NcdVariant::SymCons => {
2284 let m = cxy.min(cyx.expect("cyx required for SymCons")) as f64;
2285 if m == 0.0 { 0.0 } else { (m - min_c) / m }
2286 }
2287 }
2288}
2289
#[inline(always)]
/// NCD between two byte buffers using a zpaq compression backend built from
/// `method`.
pub fn ncd_bytes(x: &[u8], y: &[u8], method: &str, variant: NcdVariant) -> f64 {
    let backend = CompressionBackend::Zpaq {
        method: method.to_string(),
    };
    ncd_bytes_backend(x, y, &backend, variant)
}
2298
#[inline(always)]
/// NCD between two byte buffers using the process-wide default context.
pub fn ncd_bytes_default(x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
    with_default_ctx(|ctx| ctx.ncd_bytes(x, y, variant))
}
2304
/// NCD between two byte buffers under an arbitrary compression backend.
///
/// C(x) and C(y) are computed concurrently; C(yx) is only computed when the
/// chosen variant is symmetric.
pub fn ncd_bytes_backend(
    x: &[u8],
    y: &[u8],
    backend: &CompressionBackend,
    variant: NcdVariant,
) -> f64 {
    let (cx, cy) = rayon::join(
        || compress_size_backend(x, backend),
        || compress_size_backend(y, backend),
    );

    let cxy = compress_size_chain_backend(&[x, y], backend);

    let cyx = match variant {
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            Some(compress_size_chain_backend(&[y, x], backend))
        }
        _ => None,
    };

    ncd_from_sizes(cx, cy, cxy, cyx, variant)
}
2328
#[inline(always)]
/// NCD between two files (read concurrently) with a zpaq `method`. Panics if
/// either file cannot be read.
pub fn ncd_paths(x: &str, y: &str, method: &str, variant: NcdVariant) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ncd_bytes(&bx, &by, method, variant)
}
2338
/// NCD between two files (read concurrently) under an arbitrary compression
/// backend. Panics if either file cannot be read.
pub fn ncd_paths_backend(
    x: &str,
    y: &str,
    backend: &CompressionBackend,
    variant: NcdVariant,
) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ncd_bytes_backend(&bx, &by, backend, variant)
}
2352
#[inline(always)]
/// Classic (Vitányi) NCD between two files.
pub fn ncd_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Vitanyi)
}
#[inline(always)]
/// Symmetrised classic NCD between two files.
pub fn ncd_sym_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymVitanyi)
}
#[inline(always)]
/// Conservative-normalised NCD between two files.
pub fn ncd_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Cons)
}
#[inline(always)]
/// Symmetrised conservative NCD between two files.
pub fn ncd_sym_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymCons)
}
2373
/// Full n×n NCD matrix (row-major) over `datas` with zpaq `method`.
///
/// C(x_i) for every input is computed once up front. Symmetric variants
/// visit each unordered pair (i < j) once and fill both mirror cells;
/// asymmetric variants compute each (i, j) cell independently with the
/// diagonal forced to 0. Cells are written in parallel through a raw
/// pointer (see the SAFETY notes).
pub fn ncd_matrix_bytes(datas: &[Vec<u8>], method: &str, variant: NcdVariant) -> Vec<f64> {
    let n = datas.len();
    // Per-input compressed sizes C(x_i), computed in parallel.
    let cx: Vec<u64> = datas
        .par_iter()
        .map(|d| compress_size_bytes(d, method))
        .collect();

    let mut out = vec![0.0f64; n * n];
    // AtomicPtr only carries the pointer across rayon's Send/Sync bounds;
    // no atomic ordering is relied upon for the element writes themselves.
    let out_ptr = std::sync::atomic::AtomicPtr::new(out.as_mut_ptr());

    match variant {
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            // Upper-triangle pairs only; each pair writes its two mirrors.
            // `for_each_init` gives each worker a reusable concat buffer.
            (0..n)
                .into_par_iter()
                .flat_map_iter(|i| (i + 1..n).map(move |j| (i, j)))
                .for_each_init(Vec::<u8>::new, |buf, (i, j)| {
                    let x = &datas[i];
                    let y = &datas[j];

                    // C(xy): forward concatenation into the reused buffer.
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(x);
                    buf.extend_from_slice(y);
                    let cxy = compress_size_bytes(buf, method);

                    // C(yx): reverse concatenation.
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(y);
                    buf.extend_from_slice(x);
                    let cyx = compress_size_bytes(buf, method);

                    let d = ncd_from_sizes(cx[i], cx[j], cxy, Some(cyx), variant);

                    let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                    // SAFETY: every pair (i, j) with i < j is produced by
                    // exactly one task, so cells (i,j) and (j,i) each have a
                    // single writer and no task touches another task's cells;
                    // `out` outlives this parallel loop.
                    unsafe {
                        *p.add(i * n + j) = d;
                        *p.add(j * n + i) = d;
                    }
                });
        }
        NcdVariant::Vitanyi | NcdVariant::Cons => {
            // Asymmetric variants: each task owns one full row i.
            (0..n)
                .into_par_iter()
                .for_each_init(Vec::<u8>::new, |buf, i| {
                    let x = &datas[i];
                    for j in 0..n {
                        let d = if i == j {
                            0.0
                        } else {
                            let y = &datas[j];
                            buf.clear();
                            buf.reserve(x.len() + y.len());
                            buf.extend_from_slice(x);
                            buf.extend_from_slice(y);
                            let cxy = compress_size_bytes(buf, method);
                            ncd_from_sizes(cx[i], cx[j], cxy, None, variant)
                        };

                        let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                        // SAFETY: row i is written only by this task, so cell
                        // (i, j) has a single writer; `out` outlives the loop.
                        unsafe {
                            *p.add(i * n + j) = d;
                        }
                    }
                });
        }
    }

    out
}
2447
/// NCD matrix over files: reads all paths (in parallel), then delegates to
/// `ncd_matrix_bytes`.
pub fn ncd_matrix_paths(paths: &[&str], method: &str, variant: NcdVariant) -> Vec<f64> {
    let datas = get_bytes_from_paths(paths);
    ncd_matrix_bytes(&datas, method, variant)
}
2453
#[inline(always)]
/// Order-0 (i.i.d.) Shannon entropy of the byte distribution of `data`, in
/// bits per byte. Empty input yields 0.0.
pub fn marginal_entropy_bytes(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let mut counts = [0u64; 256];
    for &byte in data {
        counts[byte as usize] += 1;
    }

    let total = data.len() as f64;
    // H = -sum p log2 p over symbols that actually occur.
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / total;
            -(p * p.log2())
        })
        .sum()
}
2486
#[inline(always)]
/// Adaptive entropy rate of `data` via the process-wide default context.
pub fn entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.entropy_rate_bytes(data, max_order))
}
2505
#[inline(always)]
/// Biased (plug-in) entropy rate of `data` via the default context.
pub fn biased_entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.biased_entropy_rate_bytes(data, max_order))
}
2515
2516#[inline(always)]
2521pub fn joint_marginal_entropy_bytes(x: &[u8], y: &[u8]) -> f64 {
2522 let (x, y) = aligned_prefix(x, y);
2523 let n = x.len();
2524 if n == 0 {
2525 return 0.0;
2526 }
2527
2528 let mut counts = vec![0u64; 256 * 256];
2531 for i in 0..n {
2532 let pair_idx = (x[i] as usize) * 256 + (y[i] as usize);
2533 counts[pair_idx] += 1;
2534 }
2535
2536 let n_f64 = n as f64;
2537 let mut h = 0.0f64;
2538 for &c in &counts {
2539 if c > 0 {
2540 let p = c as f64 / n_f64;
2541 h -= p * p.log2();
2542 }
2543 }
2544 h
2545}
2546
#[inline(always)]
/// Joint entropy rate of two byte streams via the default context.
pub fn joint_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.joint_entropy_rate_bytes(x, y, max_order))
}
2562
#[inline(always)]
/// Conditional entropy rate H(X|Y) via the default context.
pub fn conditional_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_rate_bytes(x, y, max_order))
}
2574
#[inline(always)]
/// Conditional entropy H(X|Y) via the default context.
pub fn conditional_entropy_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_bytes(x, y, max_order))
}
2582
#[inline(always)]
/// Mutual information I(X;Y) via the default context.
pub fn mutual_information_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_bytes(x, y, max_order))
}
2592
2593pub fn mutual_information_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2595 let (x, y) = aligned_prefix(x, y);
2596 let h_x = marginal_entropy_bytes(x);
2597 let h_y = marginal_entropy_bytes(y);
2598 let h_xy = joint_marginal_entropy_bytes(x, y);
2599 (h_x + h_y - h_xy).max(0.0)
2600}
2601
#[inline(always)]
/// Mutual information rate via the default context.
pub fn mutual_information_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_rate_bytes(x, y, max_order))
}
2607
#[inline(always)]
/// Normalised entropy distance (NED) via the default context.
pub fn ned_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
2623
2624pub fn ned_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2626 let (x, y) = aligned_prefix(x, y);
2627 let h_x = marginal_entropy_bytes(x);
2628 let h_y = marginal_entropy_bytes(y);
2629 let h_xy = joint_marginal_entropy_bytes(x, y);
2630 let min_h = h_x.min(h_y);
2631 let max_h = h_x.max(h_y);
2632 if max_h == 0.0 {
2633 0.0
2634 } else {
2635 ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
2636 }
2637}
2638
#[inline(always)]
/// Rate-based NED via the default context.
/// NOTE(review): delegates to the same `ctx.ned_bytes` as `ned_bytes` —
/// presumably an intentional alias; confirm no separate rate path exists.
pub fn ned_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
2644
#[inline(always)]
/// Conservative NED (joint-entropy denominator) via the default context.
pub fn ned_cons_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
2652
2653pub fn ned_cons_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2655 let h_x = marginal_entropy_bytes(x);
2656 let h_y = marginal_entropy_bytes(y);
2657 let h_xy = joint_marginal_entropy_bytes(x, y);
2658 let min_h = h_x.min(h_y);
2659 if h_xy == 0.0 {
2660 0.0
2661 } else {
2662 ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
2663 }
2664}
2665
#[inline(always)]
/// Rate-based conservative NED via the default context.
/// NOTE(review): delegates to the same `ctx.ned_cons_bytes` as
/// `ned_cons_bytes` — presumably an intentional alias; confirm.
pub fn ned_cons_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
2671
2672#[inline(always)]
2687pub fn nte_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2688 with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
2689}
2690
2691pub fn nte_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2693 let (x, y) = aligned_prefix(x, y);
2694 let h_x = marginal_entropy_bytes(x);
2695 let h_y = marginal_entropy_bytes(y);
2696 let h_xy = joint_marginal_entropy_bytes(x, y);
2697 let vi = 2.0 * h_xy - h_x - h_y;
2698 let max_h = h_x.max(h_y);
2699 if max_h == 0.0 {
2700 0.0
2701 } else {
2702 (vi / max_h).clamp(0.0, 2.0)
2703 }
2704}
2705
/// Rate-based NTE via the default context.
///
/// NOTE(review): forwards to `ctx.nte_bytes`, not a dedicated `*_rate_*`
/// context method (compare `mutual_information_rate_bytes`). Confirm the
/// context method is already rate-based, or this alias duplicates `nte_bytes`.
#[inline(always)]
pub fn nte_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
2711
/// Order-0 byte histogram of `data`, normalized to a probability
/// distribution over all 256 byte values.
///
/// Returns the all-zero array for empty input (not a uniform distribution),
/// so callers must treat empty data specially.
#[inline(always)]
fn byte_histogram(data: &[u8]) -> [f64; 256] {
    let mut probs = [0.0f64; 256];
    if data.is_empty() {
        return probs;
    }
    let mut counts = [0u64; 256];
    for &byte in data {
        counts[byte as usize] += 1;
    }
    let total = data.len() as f64;
    for (slot, &count) in probs.iter_mut().zip(counts.iter()) {
        *slot = count as f64 / total;
    }
    probs
}
2731
2732#[inline(always)]
2738pub fn tvd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2739 if x.is_empty() || y.is_empty() {
2740 return 0.0;
2741 }
2742 let p_x = byte_histogram(x);
2743 let p_y = byte_histogram(y);
2744
2745 let mut sum = 0.0f64;
2746 for i in 0..256 {
2747 sum += (p_x[i] - p_y[i]).abs();
2748 }
2749
2750 (sum / 2.0).clamp(0.0, 1.0)
2751}
2752
2753#[inline(always)]
2760pub fn nhd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2761 if x.is_empty() || y.is_empty() {
2762 return 0.0;
2763 }
2764 let p_x = byte_histogram(x);
2765 let p_y = byte_histogram(y);
2766
2767 let mut bc = 0.0f64;
2769 for i in 0..256 {
2770 bc += (p_x[i] * p_y[i]).sqrt();
2771 }
2772
2773 (1.0 - bc).max(0.0).sqrt()
2775}
2776
2777#[inline(always)]
2783pub fn cross_entropy_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
2784 with_default_ctx(|ctx| ctx.cross_entropy_bytes(test_data, train_data, max_order))
2785}
2786
2787#[inline(always)]
2790pub fn cross_entropy_rate_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
2791 with_default_ctx(|ctx| ctx.cross_entropy_rate_bytes(test_data, train_data, max_order))
2792}
2793
2794#[inline(always)]
2799pub fn generate_bytes(prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
2800 with_default_ctx(|ctx| ctx.generate_bytes(prompt, bytes, max_order))
2801}
2802
2803#[inline(always)]
2805pub fn generate_bytes_with_config(
2806 prompt: &[u8],
2807 bytes: usize,
2808 max_order: i64,
2809 config: GenerationConfig,
2810) -> Vec<u8> {
2811 with_default_ctx(|ctx| ctx.generate_bytes_with_config(prompt, bytes, max_order, config))
2812}
2813
2814#[inline(always)]
2817pub fn generate_bytes_conditional_chain(
2818 prefix_parts: &[&[u8]],
2819 bytes: usize,
2820 max_order: i64,
2821) -> Vec<u8> {
2822 with_default_ctx(|ctx| ctx.generate_bytes_conditional_chain(prefix_parts, bytes, max_order))
2823}
2824
2825#[inline(always)]
2828pub fn generate_bytes_conditional_chain_with_config(
2829 prefix_parts: &[&[u8]],
2830 bytes: usize,
2831 max_order: i64,
2832 config: GenerationConfig,
2833) -> Vec<u8> {
2834 with_default_ctx(|ctx| {
2835 ctx.generate_bytes_conditional_chain_with_config(prefix_parts, bytes, max_order, config)
2836 })
2837}
2838
2839pub fn d_kl_bytes(x: &[u8], y: &[u8]) -> f64 {
2843 if x.is_empty() || y.is_empty() {
2844 return 0.0;
2845 }
2846 let p_x = byte_histogram(x);
2847 let p_y = byte_histogram(y);
2848 let mut d_kl = 0.0f64;
2849 for i in 0..256 {
2850 if p_x[i] > 0.0 {
2851 let q_y = p_y[i].max(1e-12);
2852 d_kl += p_x[i] * (p_x[i] / q_y).log2();
2853 }
2854 }
2855 d_kl.max(0.0)
2856}
2857
2858pub fn js_div_bytes(x: &[u8], y: &[u8]) -> f64 {
2863 if x.is_empty() || y.is_empty() {
2864 return 0.0;
2865 }
2866 let p_x = byte_histogram(x);
2867 let p_y = byte_histogram(y);
2868 let mut m = [0.0f64; 256];
2869 for i in 0..256 {
2870 m[i] = 0.5 * (p_x[i] + p_y[i]);
2871 }
2872
2873 let mut kl_pm = 0.0f64;
2874 let mut kl_qm = 0.0f64;
2875 for i in 0..256 {
2876 if p_x[i] > 0.0 {
2877 kl_pm += p_x[i] * (p_x[i] / m[i]).log2();
2878 }
2879 if p_y[i] > 0.0 {
2880 kl_qm += p_y[i] * (p_y[i] / m[i]).log2();
2881 }
2882 }
2883 (0.5 * kl_pm + 0.5 * kl_qm).max(0.0)
2884}
2885
2886pub fn ned_paths(x: &str, y: &str, max_order: i64) -> f64 {
2890 let (bx, by) = rayon::join(
2891 || std::fs::read(x).expect("failed to read x"),
2892 || std::fs::read(y).expect("failed to read y"),
2893 );
2894 ned_bytes(&bx, &by, max_order)
2895}
2896
2897pub fn nte_paths(x: &str, y: &str, max_order: i64) -> f64 {
2899 let (bx, by) = rayon::join(
2900 || std::fs::read(x).expect("failed to read x"),
2901 || std::fs::read(y).expect("failed to read y"),
2902 );
2903 nte_bytes(&bx, &by, max_order)
2904}
2905
2906pub fn tvd_paths(x: &str, y: &str, max_order: i64) -> f64 {
2908 let (bx, by) = rayon::join(
2909 || std::fs::read(x).expect("failed to read x"),
2910 || std::fs::read(y).expect("failed to read y"),
2911 );
2912 tvd_bytes(&bx, &by, max_order)
2913}
2914
2915pub fn nhd_paths(x: &str, y: &str, max_order: i64) -> f64 {
2917 let (bx, by) = rayon::join(
2918 || std::fs::read(x).expect("failed to read x"),
2919 || std::fs::read(y).expect("failed to read y"),
2920 );
2921 nhd_bytes(&bx, &by, max_order)
2922}
2923
2924pub fn mutual_information_paths(x: &str, y: &str, max_order: i64) -> f64 {
2926 let (bx, by) = rayon::join(
2927 || std::fs::read(x).expect("failed to read x"),
2928 || std::fs::read(y).expect("failed to read y"),
2929 );
2930 mutual_information_bytes(&bx, &by, max_order)
2931}
2932
2933pub fn conditional_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
2935 let (bx, by) = rayon::join(
2936 || std::fs::read(x).expect("failed to read x"),
2937 || std::fs::read(y).expect("failed to read y"),
2938 );
2939 conditional_entropy_bytes(&bx, &by, max_order)
2940}
2941
2942pub fn cross_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
2944 let (bx, by) = rayon::join(
2945 || std::fs::read(x).expect("failed to read x"),
2946 || std::fs::read(y).expect("failed to read y"),
2947 );
2948 cross_entropy_bytes(&bx, &by, max_order)
2949}
2950
2951pub fn kl_divergence_paths(x: &str, y: &str) -> f64 {
2953 let (bx, by) = rayon::join(
2954 || std::fs::read(x).expect("failed to read x"),
2955 || std::fs::read(y).expect("failed to read y"),
2956 );
2957 d_kl_bytes(&bx, &by)
2958}
2959
2960pub fn js_divergence_paths(x: &str, y: &str) -> f64 {
2962 let (bx, by) = rayon::join(
2963 || std::fs::read(x).expect("failed to read x"),
2964 || std::fs::read(y).expect("failed to read y"),
2965 );
2966 js_div_bytes(&bx, &by)
2967}
2968
2969#[inline(always)]
2984pub fn intrinsic_dependence_bytes(data: &[u8], max_order: i64) -> f64 {
2985 with_default_ctx(|ctx| ctx.intrinsic_dependence_bytes(data, max_order))
2986}
2987
2988#[inline(always)]
3000pub fn resistance_to_transformation_bytes(x: &[u8], tx: &[u8], max_order: i64) -> f64 {
3001 with_default_ctx(|ctx| ctx.resistance_to_transformation_bytes(x, tx, max_order))
3002}
3003
#[cfg(test)]
mod tests {
    use super::*;

    // --- Backend fixtures: deliberately tiny parameterizations so the suite
    // stays fast while still exercising each backend family. ---

    fn test_match_backend() -> RateBackend {
        RateBackend::Match {
            hash_bits: 12,
            min_len: 2,
            max_len: 16,
            base_mix: 0.01,
            confidence_scale: 1.0,
        }
    }

    fn test_ppmd_backend() -> RateBackend {
        RateBackend::Ppmd {
            order: 4,
            memory_mb: 1,
        }
    }

    fn test_calibrated_backend() -> RateBackend {
        RateBackend::Calibrated {
            spec: Arc::new(CalibratedSpec {
                base: test_match_backend(),
                context: CalibrationContextKind::Text,
                bins: 16,
                learning_rate: 0.05,
                bias_clip: 4.0,
            }),
        }
    }

    fn test_mixture_backend() -> RateBackend {
        RateBackend::Mixture {
            spec: Arc::new(MixtureSpec::new(
                MixtureKind::Bayes,
                vec![
                    MixtureExpertSpec {
                        name: Some("match".to_string()),
                        log_prior: 0.0,
                        max_order: -1,
                        backend: test_match_backend(),
                    },
                    MixtureExpertSpec {
                        name: Some("ppmd".to_string()),
                        log_prior: 0.0,
                        max_order: -1,
                        backend: test_ppmd_backend(),
                    },
                ],
            )),
        }
    }

    fn test_particle_backend() -> RateBackend {
        RateBackend::Particle {
            spec: Arc::new(ParticleSpec {
                num_particles: 4,
                num_cells: 4,
                cell_dim: 8,
                num_rules: 2,
                selector_hidden: 16,
                rule_hidden: 16,
                context_window: 8,
                unroll_steps: 1,
                ..ParticleSpec::default()
            }),
        }
    }

    // Rule-pattern prompt shared by the generation tests; see
    // rosaplus_sampled_generation_predicts_green_continuation for the
    // expected continuation.
    fn continuation_prompt() -> &'static [u8] {
        b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\nIf a dog is green, frogs are red.\nIf a cat is green, toads are red.\nIf a frog is red, dogs are green.\nIf a toad is red, cats are green.\nIf a dog is red, frogs are green.\nIf a cat is red, toads are \n"
    }

    // Generates twice with identical (default) config and requires
    // byte-identical output of exactly the requested length.
    fn assert_deterministic_generate_for_backend(
        backend: RateBackend,
        max_order: i64,
        bytes: usize,
        label: &str,
    ) {
        let prompt = continuation_prompt();
        let a = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        let b = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        assert_eq!(
            a, b,
            "{label} generation should be deterministic for identical input"
        );
        assert_eq!(
            a.len(),
            bytes,
            "{label} generation should emit requested byte count"
        );
    }

    // Same determinism check, but with an explicit sampled-frozen config
    // (fixed seed 42).
    fn assert_sampled_generate_for_backend(
        backend: RateBackend,
        max_order: i64,
        bytes: usize,
        label: &str,
    ) {
        let prompt = continuation_prompt();
        let config = GenerationConfig::sampled_frozen(42);
        let a = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
        let b = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
        assert_eq!(
            a, b,
            "{label} sampled generation should be deterministic for a fixed seed"
        );
        assert_eq!(
            a.len(),
            bytes,
            "{label} sampled generation should emit requested byte count"
        );
    }

    #[cfg(feature = "backend-zpaq")]
    #[test]
    fn ncd_basic_identity_nonnegative() {
        let x = b"abcdabcdabcd";
        let d = ncd_bytes(x, x, "5", NcdVariant::Vitanyi);
        assert!(d >= -1e-9);
    }

    #[test]
    fn shannon_identities_marginal_aligned() {
        let x = b"abracadabra";
        let y = b"abracadabra";

        let h = marginal_entropy_bytes(x);
        let mi = mutual_information_bytes(x, y, 0);
        let h_xy = joint_marginal_entropy_bytes(x, y);
        let h_x_given_y = conditional_entropy_bytes(x, y, 0);
        let ned = ned_bytes(x, y, 0);
        let nte = nte_bytes(x, y, 0);

        // For identical inputs: H(X,Y)=H(X), H(X|Y)=0, I(X;Y)=H(X),
        // and both normalized distances collapse to 0.
        assert!((h_xy - h).abs() < 1e-12);
        assert!(h_x_given_y.abs() < 1e-12);
        assert!((mi - h).abs() < 1e-12);
        assert!(ned.abs() < 1e-12);
        assert!(nte.abs() < 1e-12);
    }

    #[test]
    fn shannon_identities_rate_aligned_reasonable() {
        let x = b"the quick brown fox jumps over the lazy dog";
        let y = b"the quick brown fox jumps over the lazy dog";
        let max_order = 8;
        // Mutates the process-wide default context; restored at the end.
        let prev = get_default_ctx();
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::RosaPlus,
            CompressionBackend::default(),
        ));

        let h_x = entropy_rate_bytes(x, max_order);
        let h_xy = joint_entropy_rate_bytes(x, y, max_order);
        let h_x_given_y = conditional_entropy_rate_bytes(x, y, max_order);
        let mi = mutual_information_bytes(x, y, max_order);
        let ned = ned_bytes(x, y, max_order);

        let tol = 0.2;
        assert!((h_xy - h_x).abs() < tol);
        assert!(h_x_given_y < tol);
        assert!((mi - h_x).abs() < tol);
        assert!(ned < tol);
        set_default_ctx(prev);
    }

    #[test]
    fn resistance_identity_is_one() {
        let x = b"some repeated repeated repeated text";
        // Mutates the process-wide default context; restored at the end.
        let prev = get_default_ctx();
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::RosaPlus,
            CompressionBackend::default(),
        ));
        let r0 = resistance_to_transformation_bytes(x, x, 0);
        let r8 = resistance_to_transformation_bytes(x, x, 8);
        assert!((r0 - 1.0).abs() < 1e-12);
        assert!((r8 - 1.0).abs() < 1e-6);
        set_default_ctx(prev);
    }

    #[test]
    fn marginal_metrics_empty_inputs_are_zero() {
        let empty: &[u8] = &[];
        let x = b"abc";

        assert_eq!(tvd_bytes(empty, x, 0), 0.0);
        assert_eq!(tvd_bytes(x, empty, 0), 0.0);
        assert_eq!(nhd_bytes(empty, x, 0), 0.0);
        assert_eq!(nhd_bytes(x, empty, 0), 0.0);
        assert_eq!(d_kl_bytes(empty, x), 0.0);
        assert_eq!(d_kl_bytes(x, empty), 0.0);
        assert_eq!(js_div_bytes(empty, x), 0.0);
        assert_eq!(js_div_bytes(x, empty), 0.0);
    }

    #[test]
    fn marginal_cross_entropy_empty_test_is_zero() {
        let empty: &[u8] = &[];
        let y = b"abc";
        let ctx = InfotheoryCtx::with_zpaq("5");
        assert_eq!(ctx.cross_entropy_bytes(empty, y, 0), 0.0);
    }

    #[cfg(not(feature = "backend-zpaq"))]
    #[test]
    #[should_panic(expected = "CompressionBackend::Zpaq is unavailable")]
    fn explicit_zpaq_backend_fails_loudly() {
        let backend = CompressionBackend::Zpaq {
            method: "5".to_string(),
        };
        let _ = compress_size_backend(b"abc", &backend);
    }

    #[cfg(not(feature = "backend-zpaq"))]
    #[test]
    fn default_compression_backend_falls_back_to_rate_coding() {
        let backend = CompressionBackend::default();
        assert!(matches!(
            &backend,
            CompressionBackend::Rate {
                coder: crate::coders::CoderType::AC,
                framing: crate::compression::FramingMode::Raw,
                ..
            }
        ));
        assert!(compress_size_backend(b"abc", &backend) > 0);
    }

    #[test]
    fn backend_switching_test() {
        let x = b"hello world context";

        let h_rosa = entropy_rate_bytes(x, 8);

        // Swap in a CTW-backed default context, then restore the default.
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 16 },
            CompressionBackend::default(),
        ));

        let h_ctw = entropy_rate_bytes(x, 8);

        assert!(h_ctw > 0.0);

        set_default_ctx(InfotheoryCtx::default());
        let h_rosa_back = entropy_rate_bytes(x, 8);
        assert!((h_rosa - h_rosa_back).abs() < 1e-12);
    }

    #[test]
    fn ctw_early_updates_work() {
        use crate::ctw::ContextTree;

        let mut tree = ContextTree::new(16);

        // Fresh tree: both bit predictions should be ~0.5 and sum to 1.
        let p0 = tree.predict(false);
        let p1 = tree.predict(true);

        assert!((p0 - 0.5).abs() < 1e-10, "p0 should be ~0.5, got {}", p0);
        assert!((p1 - 0.5).abs() < 1e-10, "p1 should be ~0.5, got {}", p1);
        assert!((p0 + p1 - 1.0).abs() < 1e-10, "p0 + p1 should = 1.0");

        for _ in 0..5 {
            tree.update(true);
            tree.update(false);
        }

        let log_prob = tree.get_log_block_probability();
        assert!(
            log_prob < 0.0,
            "log_prob should be negative (< log 1), got {}",
            log_prob
        );
        assert!(log_prob.is_finite(), "log_prob should be finite");
    }

    #[test]
    fn nte_can_exceed_one() {
        // Mutates the process-wide default context; restored at the end.
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 8 },
            CompressionBackend::default(),
        ));

        let x: Vec<u8> = (0..200).map(|i| (i % 2) as u8).collect(); let y: Vec<u8> = (0..200).map(|i| ((i + 1) % 2) as u8).collect(); let nte_rate = nte_rate_backend(&x, &y, -1, &RateBackend::Ctw { depth: 8 });

        assert!(
            (0.0..=2.0 + 1e-9).contains(&nte_rate),
            "NTE should be in [0, 2], got {}",
            nte_rate
        );

        set_default_ctx(InfotheoryCtx::default());
    }

    #[test]
    fn ctw_empty_data_returns_zero() {
        // Mutates the process-wide default context; restored at the end.
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 16 },
            CompressionBackend::default(),
        ));

        let empty: &[u8] = &[];
        let h = entropy_rate_bytes(empty, -1);
        assert_eq!(h, 0.0, "empty data should return 0.0 entropy");

        set_default_ctx(InfotheoryCtx::default());
    }

    #[test]
    fn joint_entropy_rate_aligns_inputs_and_handles_empty_cases() {
        let cases = vec![
            ("ctw", RateBackend::Ctw { depth: 8 }),
            (
                "fac-ctw",
                RateBackend::FacCtw {
                    base_depth: 8,
                    num_percept_bits: 8,
                    encoding_bits: 8,
                },
            ),
            ("match", test_match_backend()),
        ];

        for (name, backend) in cases {
            assert_eq!(
                joint_entropy_rate_backend(b"", b"nonempty", -1, &backend),
                0.0,
                "{name} should return 0.0 for empty aligned pairs"
            );
            assert_eq!(
                joint_entropy_rate_backend(b"nonempty", b"", -1, &backend),
                0.0,
                "{name} should return 0.0 when alignment truncates to empty"
            );

            let aligned = joint_entropy_rate_backend(b"abcd", b"wxyz", -1, &backend);
            let truncated = joint_entropy_rate_backend(b"abcdextra", b"wxyz", -1, &backend);
            assert!(
                (aligned - truncated).abs() < 1e-12,
                "{name} should score only the aligned prefix: aligned={aligned} truncated={truncated}"
            );
        }
    }

    #[test]
    fn biased_entropy_is_repeatable_across_backend_families() {
        let data = b"ABABABAABBABABABAABB";
        let cases = vec![
            ("match", test_match_backend()),
            ("ppmd", test_ppmd_backend()),
            ("calibrated", test_calibrated_backend()),
            ("ctw", RateBackend::Ctw { depth: 8 }),
            ("mixture", test_mixture_backend()),
            ("particle", test_particle_backend()),
        ];

        for (name, backend) in cases {
            let h1 = biased_entropy_rate_backend(data, -1, &backend);
            let h2 = biased_entropy_rate_backend(data, -1, &backend);
            assert!(h1.is_finite(), "{name} biased entropy should be finite");
            assert!(
                (h1 - h2).abs() < 1e-12,
                "{name} biased entropy leaked mutable state across calls: h1={h1} h2={h2}"
            );
        }
    }

    #[test]
    fn generate_bytes_chain_matches_flat_prompt() {
        let prompt = continuation_prompt();
        let split_at = prompt.len() / 2;
        let front = &prompt[..split_at];
        let back = &prompt[split_at..];
        let backend = RateBackend::Ctw { depth: 32 };
        let bytes = 8usize;
        let max_order = -1;

        let flat = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        let chained = generate_rate_backend_chain(
            &[front, back],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        assert_eq!(
            flat, chained,
            "chain conditioning should match flat prompt conditioning"
        );
    }

    #[test]
    fn generate_bytes_api_is_deterministic_for_ctw_rosa_match_ppmd() {
        assert_deterministic_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
        assert_deterministic_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
        assert_deterministic_generate_for_backend(test_match_backend(), -1, 8, "match");
        assert_deterministic_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn generate_bytes_api_is_deterministic_for_rwkv_method() {
        let backend = RateBackend::Rwkv7Method {
            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
        };
        assert_deterministic_generate_for_backend(backend, -1, 8, "rwkv7");
    }

    #[test]
    fn sampled_generation_is_deterministic_for_ctw_rosa_match_ppmd() {
        assert_sampled_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
        assert_sampled_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
        assert_sampled_generate_for_backend(test_match_backend(), -1, 8, "match");
        assert_sampled_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn sampled_generation_is_deterministic_for_rwkv_method() {
        let backend = RateBackend::Rwkv7Method {
            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
        };
        assert_sampled_generate_for_backend(backend, -1, 8, "rwkv7");
    }

    #[test]
    fn rosaplus_sampled_generation_predicts_green_continuation() {
        let out = generate_rate_backend_chain(
            &[continuation_prompt()],
            8,
            -1,
            &RateBackend::RosaPlus,
            GenerationConfig::sampled_frozen(42),
        );
        assert_eq!(out, b" green.\n");
    }

    #[test]
    fn rate_backend_session_matches_ctx_generation() {
        let prompt = continuation_prompt();
        let backend = RateBackend::Ppmd {
            order: 12,
            memory_mb: 8,
        };
        let mut session =
            RateBackendSession::from_backend(backend.clone(), -1, Some((prompt.len() + 8) as u64))
                .expect("session init");
        session.observe(prompt);
        let from_session = session.generate_bytes(8, GenerationConfig::sampled_frozen(42));
        session.finish().expect("session finish");

        let ctx = InfotheoryCtx::new(backend, CompressionBackend::default());
        let from_ctx =
            ctx.generate_bytes_with_config(prompt, 8, -1, GenerationConfig::sampled_frozen(42));
        assert_eq!(from_session, from_ctx);
    }

    #[test]
    fn biased_entropy_ctw_uses_frozen_plugin_scoring() {
        let backend = RateBackend::Ctw { depth: 8 };
        let data = b"AAAAAAAA";
        let plugin = biased_entropy_rate_backend(data, -1, &backend);
        let prequential = entropy_rate_backend(data, -1, &backend);
        assert!(
            plugin + 1e-9 < prequential,
            "expected plugin scoring to beat prequential scoring: plugin={plugin} prequential={prequential}"
        );
    }

    #[test]
    fn rosa_plugin_entropy_matches_direct_model_api() {
        let data = b"abracadabra";
        let backend = RateBackend::RosaPlus;

        let plugin = biased_entropy_rate_backend(data, 3, &backend);

        let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
        direct.train_example(data);
        direct.build_lm();
        let expected = direct.cross_entropy(data);

        assert!(
            (plugin - expected).abs() < 1e-12,
            "rosa plugin entropy must match direct model API: plugin={plugin} expected={expected}"
        );
    }

    #[test]
    fn rosa_plugin_cross_entropy_matches_direct_model_api() {
        let train = b"alakazam";
        let test = b"abracadabra";
        let backend = RateBackend::RosaPlus;

        let plugin = cross_entropy_rate_backend(test, train, 3, &backend);

        let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
        direct.train_example(train);
        direct.build_lm();
        let expected = direct.cross_entropy(test);

        assert!(
            (plugin - expected).abs() < 1e-12,
            "rosa plugin cross entropy must match direct model API: plugin={plugin} expected={expected}"
        );
    }

    #[test]
    fn datagen_bernoulli_entropy_estimate() {
        let p = 0.5;
        let theoretical_h = crate::datagen::bernoulli_entropy(p);
        assert!((theoretical_h - 1.0).abs() < 1e-10);

        let data = crate::datagen::bernoulli(10000, p, 42);
        let estimated_h = marginal_entropy_bytes(&data);

        assert!(
            (estimated_h - theoretical_h).abs() < 0.1,
            "estimated H={} should be close to theoretical H={}",
            estimated_h,
            theoretical_h
        );
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_method_entropy_is_stable_across_calls() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
        let backend = RateBackend::Rwkv7Method {
            method: method.to_string(),
        };
        let data = b"rwkv method entropy stability regression sample";

        let h1 = entropy_rate_backend(data, -1, &backend);
        let h2 = entropy_rate_backend(data, -1, &backend);
        assert!(
            (h1 - h2).abs() < 1e-12,
            "rwkv method entropy leaked mutable state across calls: h1={h1}, h2={h2}"
        );
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_method_without_policy_is_accepted_by_public_api() {
        let backend = RateBackend::Rwkv7Method {
            method: "cfg:hidden=64,layers=1,intermediate=64".to_string(),
        };
        let data = b"rwkv method without policy";
        let h1 = entropy_rate_backend(data, -1, &backend);
        let h2 = biased_entropy_rate_backend(data, -1, &backend);
        assert!(h1.is_finite());
        assert!(h2.is_finite());
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_infer_only_plugin_collapses_to_single_pass_entropy() {
        let backend = RateBackend::Rwkv7Method {
            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=25,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
        };
        let data = b"rwkv infer-only plugin equality sample";
        let h = entropy_rate_backend(data, -1, &backend);
        let plugin = biased_entropy_rate_backend(data, -1, &backend);
        assert!(
            (h - plugin).abs() < 1e-12,
            "infer-only rwkv plugin should equal single-pass entropy: h={h}, plugin={plugin}"
        );
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_method_biased_entropy_is_stable_across_calls_with_training_policy() {
        let backend = RateBackend::Rwkv7Method {
            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
        };
        let data = b"rwkv plugin stability sample";
        let h1 = biased_entropy_rate_backend(data, -1, &backend);
        let h2 = biased_entropy_rate_backend(data, -1, &backend);
        assert!(
            (h1 - h2).abs() < 1e-12,
            "rwkv method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
        );
    }

    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_method_conditional_chain_is_stable_across_calls() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=22,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
        let ctx = InfotheoryCtx::new(
            RateBackend::Rwkv7Method {
                method: method.to_string(),
            },
            CompressionBackend::default(),
        );

        let prefix = b"universal prior slice";
        let data = b"query payload";
        let h1 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
        let h2 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
        assert!(
            (h1 - h2).abs() < 1e-12,
            "rwkv method conditional chain leaked mutable state across calls: h1={h1}, h2={h2}"
        );
    }

    #[cfg(feature = "backend-mamba")]
    #[test]
    fn mamba_method_without_policy_is_accepted_by_public_api() {
        let backend = RateBackend::MambaMethod {
            method: "cfg:hidden=64,layers=1,intermediate=96".to_string(),
        };
        let data = b"mamba method without policy";
        let h1 = entropy_rate_backend(data, -1, &backend);
        let h2 = biased_entropy_rate_backend(data, -1, &backend);
        assert!(h1.is_finite());
        assert!(h2.is_finite());
    }

    #[cfg(feature = "backend-mamba")]
    #[test]
    fn mamba_infer_only_plugin_collapses_to_single_pass_entropy() {
        let backend = RateBackend::MambaMethod {
            method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=26,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
        };
        let data = b"mamba infer-only plugin equality sample";
        let h = entropy_rate_backend(data, -1, &backend);
        let plugin = biased_entropy_rate_backend(data, -1, &backend);
        assert!(
            (h - plugin).abs() < 1e-12,
            "infer-only mamba plugin should equal single-pass entropy: h={h}, plugin={plugin}"
        );
    }

    #[cfg(feature = "backend-mamba")]
    #[test]
    fn mamba_method_biased_entropy_is_stable_across_calls_with_training_policy() {
        let backend = RateBackend::MambaMethod {
            method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=24,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
        };
        let data = b"mamba plugin stability sample";
        let h1 = biased_entropy_rate_backend(data, -1, &backend);
        let h2 = biased_entropy_rate_backend(data, -1, &backend);
        assert!(
            (h1 - h2).abs() < 1e-12,
            "mamba method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
        );
    }

    #[test]
    fn particle_entropy_rate_in_valid_range() {
        let rb = test_particle_backend();
        let data = b"hello world particle backend test";
        let rate = entropy_rate_backend(data, -1, &rb);
        assert!(
            rate > 0.0 && rate < 8.0,
            "particle entropy rate out of (0, 8) range: {rate}"
        );
    }

    #[test]
    fn particle_cross_entropy_stability() {
        let rb = test_particle_backend();
        let train = b"ABCABC";
        let test = b"ABC";
        let h1 = cross_entropy_rate_backend(test, train, -1, &rb);
        let h2 = cross_entropy_rate_backend(test, train, -1, &rb);
        assert!(
            (h1 - h2).abs() < 1e-12,
            "particle cross entropy not deterministic: h1={h1}, h2={h2}"
        );
    }

    #[test]
    fn particle_empty_input() {
        let rb = RateBackend::Particle {
            spec: Arc::new(ParticleSpec::default()),
        };
        let rate = entropy_rate_backend(b"", -1, &rb);
        assert!(
            rate == 0.0,
            "particle entropy rate for empty input should be 0.0, got {rate}"
        );
    }

    #[test]
    fn particle_joint_entropy_rate() {
        let rb = test_particle_backend();
        let x = b"AAAA";
        let y = b"BBBB";
        let joint = joint_entropy_rate_backend(x, y, -1, &rb);
        assert!(
            joint > 0.0 && joint < 16.0,
            "particle joint entropy rate out of range: {joint}"
        );
    }
}