1#![allow(unsafe_op_in_unsafe_fn)]
2
3pub mod aixi;
77pub mod axioms;
79pub mod backends;
81pub mod coders;
83pub mod compression;
85pub mod datagen;
87pub mod diagnostics;
89pub mod mixture;
91pub(crate) mod neural_mix;
92pub mod search;
94pub(crate) mod simd_math;
95pub use backends::ctw;
97#[cfg(feature = "backend-mamba")]
98pub use backends::mambazip;
100pub use backends::match_model;
102pub use backends::particle;
104pub use backends::ppmd;
106pub use backends::rosaplus;
108#[cfg(feature = "backend-rwkv")]
109pub use backends::rwkvzip;
111pub use backends::sequitur;
113pub use backends::sparse_match;
115pub use backends::zpaq_rate;
117
118use rayon::prelude::*;
119
120use crate::coders::CoderType;
121use std::cell::RefCell;
122#[cfg(any(feature = "backend-rwkv", feature = "backend-mamba"))]
123use std::collections::HashMap;
124use std::sync::Arc;
125use std::sync::OnceLock;
126
/// Controls whether the model keeps adapting while generating.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationUpdateMode {
    /// Each generated byte is fed back through the predictor's adaptive
    /// `update`, so the model keeps training on its own output.
    Adaptive,
    /// Generated bytes are replayed via `update_frozen`: context advances
    /// but the model does not adapt.
    Frozen,
}

/// How the next byte is chosen from the predicted distribution
/// (interpreted by the sampler defined elsewhere in this file).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationStrategy {
    /// Deterministic selection (per the name, argmax — sampler not in view).
    Greedy,
    /// Stochastic sampling, shaped by temperature / top_k / top_p.
    Sample,
}

/// Knobs consumed by the `generate_bytes*` entry points.
#[derive(Clone, Copy, Debug)]
pub struct GenerationConfig {
    pub strategy: GenerationStrategy,
    pub update_mode: GenerationUpdateMode,
    // RNG seed; 0 is remapped to a fixed constant by `GenerationRng::new`.
    pub seed: u64,
    // 1.0 is the neutral default; exact shaping happens in the sampler.
    pub temperature: f64,
    // 0 presumably disables top-k truncation — TODO confirm in the sampler.
    pub top_k: usize,
    // 1.0 keeps the full distribution (no nucleus truncation).
    pub top_p: f64,
}

impl Default for GenerationConfig {
    /// Default: seeded sampling (seed 42) with a frozen model.
    fn default() -> Self {
        Self::sampled_frozen(42)
    }
}
167
168impl GenerationConfig {
169 pub const fn greedy_frozen() -> Self {
171 Self {
172 strategy: GenerationStrategy::Greedy,
173 update_mode: GenerationUpdateMode::Frozen,
174 seed: 0xD00D_F00D_CAFE_BABEu64,
175 temperature: 1.0,
176 top_k: 0,
177 top_p: 1.0,
178 }
179 }
180
181 pub const fn sampled_frozen(seed: u64) -> Self {
183 Self {
184 strategy: GenerationStrategy::Sample,
185 update_mode: GenerationUpdateMode::Frozen,
186 seed,
187 temperature: 1.0,
188 top_k: 0,
189 top_p: 1.0,
190 }
191 }
192}
193
/// Minimal xorshift64 PRNG used for deterministic, seedable generation.
struct GenerationRng {
    state: u64,
}

impl GenerationRng {
    /// Creates a generator; a zero seed is remapped to a fixed constant,
    /// since xorshift64 would stay at zero forever.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 {
            0xD00D_F00D_CAFE_BABEu64
        } else {
            seed
        };
        Self { state }
    }

    /// Advances the xorshift64 state (shifts 13/7/17) and returns it.
    fn next_u64(&mut self) -> u64 {
        let mut s = self.state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.state = s;
        s
    }

    /// Next value scaled into the unit interval (inclusive of 1.0 when the
    /// raw draw is u64::MAX).
    fn next_f64(&mut self) -> f64 {
        self.next_u64() as f64 / u64::MAX as f64
    }
}
222
// Worker-thread count, set once; not read anywhere in this chunk — TODO
// confirm usage (presumably the parallel compression paths) elsewhere in file.
static NUM_THREADS: OnceLock<usize> = OnceLock::new();

// Per-thread caches of neural compressor instances so parallel scoring does
// not rebuild model state on every call. The `usize` keys presumably identify
// a loaded model instance and the `String` keys a method spec — TODO confirm
// against the `with_*_tls` helpers defined later in this file.
thread_local! {
    #[cfg(feature = "backend-mamba")]
    static MAMBA_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_RATE_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_METHOD_TLS: RefCell<HashMap<String, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_RATE_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_METHOD_TLS: RefCell<HashMap<String, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
}
239
// Default compression backend: ZPAQ method "5" when the zpaq feature is
// compiled in; otherwise fall back to rate-model coding.
#[cfg(feature = "backend-zpaq")]
impl Default for CompressionBackend {
    fn default() -> Self {
        CompressionBackend::Zpaq {
            method: "5".to_string(),
        }
    }
}

#[cfg(not(feature = "backend-zpaq"))]
impl Default for CompressionBackend {
    /// Without zpaq: default rate backend through the `AC` coder, raw framing.
    fn default() -> Self {
        CompressionBackend::Rate {
            rate_backend: RateBackend::default(),
            coder: CoderType::AC,
            framing: compression::FramingMode::Raw,
        }
    }
}
259
// Per-thread default context: reads/writes through get/set_default_ctx are
// thread-local, so configuring one thread does not affect others.
thread_local! {
    static DEFAULT_CTX: RefCell<InfotheoryCtx> = RefCell::new(InfotheoryCtx::default());
}
263
264pub fn get_default_ctx() -> InfotheoryCtx {
266 DEFAULT_CTX.with(|ctx| ctx.borrow().clone())
267}
268
269pub fn set_default_ctx(ctx: InfotheoryCtx) {
271 DEFAULT_CTX.with(|c| *c.borrow_mut() = ctx);
272}
273
274#[inline(always)]
275fn with_default_ctx<R>(f: impl FnOnce(&InfotheoryCtx) -> R) -> R {
276 DEFAULT_CTX.with(|ctx| f(&ctx.borrow()))
277}
278
279pub fn mutual_information_rate_backend(
283 x: &[u8],
284 y: &[u8],
285 max_order: i64,
286 backend: &RateBackend,
287) -> f64 {
288 let (x, y) = aligned_prefix(x, y);
289 if x.is_empty() {
290 return 0.0;
291 }
292 let h_x = entropy_rate_backend(x, max_order, backend);
295 let h_y = entropy_rate_backend(y, max_order, backend);
296 let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
297 (h_x + h_y - h_xy).max(0.0)
298}
299
300pub fn ned_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
304 let (x, y) = aligned_prefix(x, y);
305 if x.is_empty() {
306 return 0.0;
307 }
308 let h_x = entropy_rate_backend(x, max_order, backend);
309 let h_y = entropy_rate_backend(y, max_order, backend);
310 let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
311 let min_h = h_x.min(h_y);
312 let max_h = h_x.max(h_y);
313 if max_h == 0.0 {
314 0.0
315 } else {
316 ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
317 }
318}
319
320pub fn nte_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
324 let (x, y) = aligned_prefix(x, y);
325 if x.is_empty() {
326 return 0.0;
327 }
328 let h_x = entropy_rate_backend(x, max_order, backend);
329 let h_y = entropy_rate_backend(y, max_order, backend);
330 let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
331 let max_h = h_x.max(h_y);
332 if max_h == 0.0 {
333 0.0
334 } else {
335 let vi = (h_xy - h_x).max(0.0) + (h_xy - h_y).max(0.0);
338 (vi / max_h).clamp(0.0, 2.0)
339 }
340}
341
/// Selects the model used for entropy-rate estimation. Variants carrying a
/// `spec`/`model` share it via `Arc`, so cloning a backend is cheap.
#[derive(Clone)]
pub enum RateBackend {
    /// ROSA+ model (the default when `backend-rosa` is enabled).
    RosaPlus,
    /// Hash-based match model; `base_mix`/`confidence_scale` tune mixing —
    /// TODO confirm exact semantics in `backends::match_model`.
    Match {
        hash_bits: usize,
        min_len: usize,
        max_len: usize,
        base_mix: f64,
        confidence_scale: f64,
    },
    /// Match model with gapped contexts between `gap_min` and `gap_max`.
    SparseMatch {
        hash_bits: usize,
        min_len: usize,
        max_len: usize,
        gap_min: usize,
        gap_max: usize,
        base_mix: f64,
        confidence_scale: f64,
    },
    /// PPMd with the given context order and memory budget (MB).
    Ppmd {
        order: usize,
        memory_mb: usize,
    },
    /// Sequitur grammar model; `context_bytes` must be >= 2 (validated).
    Sequitur {
        context_bytes: usize,
    },
    /// Pre-loaded Mamba model.
    #[cfg(feature = "backend-mamba")]
    Mamba {
        model: Arc<mambazip::Model>,
    },
    /// Mamba selected by method string (instances cached per thread).
    #[cfg(feature = "backend-mamba")]
    MambaMethod {
        method: String,
    },
    /// Pre-loaded RWKV7 model.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        model: Arc<rwkvzip::Model>,
    },
    /// RWKV7 selected by method string (instances cached per thread).
    #[cfg(feature = "backend-rwkv")]
    Rwkv7Method {
        method: String,
    },
    /// ZPAQ used as a rate model (see `zpaq_rate`).
    Zpaq {
        method: String,
    },
    /// Mixture-of-experts over nested backends; nesting depth is capped at
    /// `MAX_MIXTURE_NESTING` by validation.
    Mixture {
        spec: Arc<MixtureSpec>,
    },
    /// Particle-filter neural model (see `ParticleSpec`).
    Particle {
        spec: Arc<ParticleSpec>,
    },
    /// Calibrated wrapper around another backend (see `CalibratedSpec`).
    Calibrated {
        spec: Arc<CalibratedSpec>,
    },
    /// Binary context-tree weighting with the given bit-context depth.
    Ctw {
        depth: usize,
    },
    /// Factored CTW over the low `encoding_bits` of each byte
    /// (`num_percept_bits` is unused in the scoring path visible here).
    FacCtw {
        base_depth: usize,
        num_percept_bits: usize,
        encoding_bits: usize,
    },
}
456
#[allow(clippy::derivable_impls)]
impl Default for RateBackend {
    /// Picks the default backend by feature priority: rosa > zpaq > CTW.
    /// Exactly one of the three cfg blocks below is compiled in; each block
    /// is the function's tail expression in its configuration.
    fn default() -> Self {
        #[cfg(feature = "backend-rosa")]
        {
            RateBackend::RosaPlus
        }
        #[cfg(all(not(feature = "backend-rosa"), feature = "backend-zpaq"))]
        {
            RateBackend::Zpaq {
                method: "1".to_string(),
            }
        }
        #[cfg(all(not(feature = "backend-rosa"), not(feature = "backend-zpaq")))]
        {
            RateBackend::Ctw { depth: 16 }
        }
    }
}
476
/// Whole-buffer compression backend used for compressed-size measures (NCD).
#[derive(Clone)]
pub enum CompressionBackend {
    /// ZPAQ with a method string (defaults elsewhere use "1"/"5"); the shims
    /// at the bottom of this file panic if `backend-zpaq` is not compiled in.
    Zpaq {
        method: String,
    },
    /// RWKV7 model driving an entropy coder.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        model: Arc<rwkvzip::Model>,
        coder: CoderType,
    },
    /// Any rate backend driving an entropy coder with a framing mode.
    Rate {
        rate_backend: RateBackend,
        coder: CoderType,
        framing: compression::FramingMode,
    },
}
503
/// Maximum depth of recursively nested mixture/calibrated backend specs;
/// validation rejects anything deeper.
pub const MAX_MIXTURE_NESTING: usize = 8;

/// Strategy used to combine expert predictions in a mixture backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureKind {
    /// Plain Bayesian model averaging.
    Bayes,
    /// Bayes with forgetting; requires `MixtureSpec::decay` (validated).
    FadingBayes,
    /// Switching mixture; `alpha` must be in [0, 1] under the default
    /// schedule (validated).
    Switching,
    /// Convex weight updates; `alpha` must be > 0 (validated).
    Convex,
    /// MDL-style selector — presumably picks a single best expert; TODO
    /// confirm in `crate::mixture`.
    Mdl,
    /// Learned (neural) mixing; `alpha` must be > 0 (validated).
    Neural,
}
524
/// Weight-update schedule for switching/convex mixtures. Validation rejects
/// `Theorem` for any other mixture kind.
///
/// The manual `impl Default` was replaced by `#[derive(Default)]` with the
/// `#[default]` variant attribute — same behavior, less code.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
pub enum MixtureScheduleMode {
    /// Constant schedule driven by `MixtureSpec::alpha`.
    #[default]
    Default,
    /// Theorem-driven (time-varying) schedule.
    Theorem,
}
552
553pub fn parse_mixture_kind_name(kind: &str) -> Result<MixtureKind, String> {
555 match kind.trim().to_ascii_lowercase().as_str() {
556 "bayes" | "bayes-mix" | "bayes_mix" => Ok(MixtureKind::Bayes),
557 "fading" | "fading-bayes" | "fading_bayes" => Ok(MixtureKind::FadingBayes),
558 "switch" | "switching" | "switch-mix" | "switch_mix" => Ok(MixtureKind::Switching),
559 "convex" | "convex-mix" | "convex_mix" => Ok(MixtureKind::Convex),
560 "mdl" | "selector" | "mdr" => Ok(MixtureKind::Mdl),
561 "neural" | "neural-mix" | "neural_mix" | "mix" | "mixture" | "fx2" | "fx2-cmix"
562 | "fx2_cmix" => Ok(MixtureKind::Neural),
563 other => Err(format!("unknown mixture kind '{other}'")),
564 }
565}
566
567pub fn parse_mixture_schedule_name(schedule: &str) -> Result<MixtureScheduleMode, String> {
569 match schedule.trim().to_ascii_lowercase().as_str() {
570 "" | "default" | "constant" | "const" => Ok(MixtureScheduleMode::Default),
571 "theorem" | "paper" | "paper-theorem" | "paper_theorem" => Ok(MixtureScheduleMode::Theorem),
572 other => Err(format!("unknown mixture schedule '{other}'")),
573 }
574}
575
/// Context bucketing used by the calibrated backend's bias correction —
/// exact classification logic lives in the calibrator; TODO confirm there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationContextKind {
    Global,
    ByteClass,
    Text,
    Repeat,
    TextRepeat,
}

/// Wraps a base rate backend with an online probability calibrator.
#[derive(Clone)]
pub struct CalibratedSpec {
    // Backend being calibrated; validated recursively (depth-limited).
    pub base: RateBackend,
    pub context: CalibrationContextKind,
    pub bins: usize,
    pub learning_rate: f64,
    // Clamp applied to the learned bias — TODO confirm units in calibrator.
    pub bias_clip: f64,
}

/// One expert inside a `MixtureSpec`.
#[derive(Clone)]
pub struct MixtureExpertSpec {
    // Optional label, used in validation error messages.
    pub name: Option<String>,
    // Log-domain prior weight; must be finite (validated).
    pub log_prior: f64,
    // Context order handed to the expert; -1 presumably means "unbounded",
    // matching its use in the conditional-chain paths — TODO confirm.
    pub max_order: i64,
    pub backend: RateBackend,
}

/// Declarative description of a mixture-of-experts rate model.
#[derive(Clone)]
pub struct MixtureSpec {
    pub kind: MixtureKind,
    // Only meaningful for switching/convex kinds (validation enforces this).
    pub schedule: MixtureScheduleMode,
    // Switching probability or step size, interpreted per `kind`.
    pub alpha: f64,
    // Forgetting factor; required for `FadingBayes` (validated).
    pub decay: Option<f64>,
    pub experts: Vec<MixtureExpertSpec>,
}
637
638impl MixtureSpec {
639 pub fn new(kind: MixtureKind, experts: Vec<MixtureExpertSpec>) -> Self {
641 Self {
642 kind,
643 schedule: MixtureScheduleMode::Default,
644 alpha: 0.01,
645 decay: None,
646 experts,
647 }
648 }
649
650 pub fn with_schedule(mut self, schedule: MixtureScheduleMode) -> Self {
652 self.schedule = schedule;
653 self
654 }
655
656 pub fn with_alpha(mut self, alpha: f64) -> Self {
658 self.alpha = alpha;
659 self
660 }
661
662 pub fn with_decay(mut self, decay: f64) -> Self {
664 self.decay = Some(decay);
665 self
666 }
667
668 pub fn validate(&self) -> Result<(), String> {
670 validate_mixture_spec_with_depth(self, MAX_MIXTURE_NESTING)
671 }
672
673 pub fn build_experts(&self) -> Vec<crate::mixture::ExpertConfig> {
675 self.experts
676 .iter()
677 .map(|spec| {
678 crate::mixture::ExpertConfig::from_rate_backend(
679 spec.name.clone(),
680 spec.log_prior,
681 spec.backend.clone(),
682 spec.max_order,
683 )
684 })
685 .collect()
686 }
687}
688
689fn validate_mixture_spec_with_depth(spec: &MixtureSpec, depth: usize) -> Result<(), String> {
690 if depth == 0 {
691 return Err("mixture spec nesting too deep".to_string());
692 }
693 validate_mixture_spec_shallow(spec)?;
694 for (index, expert) in spec.experts.iter().enumerate() {
695 validate_rate_backend_with_depth(&expert.backend, depth - 1).map_err(|err| {
696 if let Some(name) = expert.name.as_deref() {
697 format!("mixture expert '{name}' invalid: {err}")
698 } else {
699 format!("mixture expert #{} invalid: {err}", index + 1)
700 }
701 })?;
702 }
703 Ok(())
704}
705
706fn validate_mixture_spec_shallow(spec: &MixtureSpec) -> Result<(), String> {
707 if spec.experts.is_empty() {
708 return Err("mixture spec must include at least one expert".to_string());
709 }
710 if !spec.alpha.is_finite() {
711 return Err("mixture alpha must be finite".to_string());
712 }
713 if spec
714 .experts
715 .iter()
716 .any(|expert| !expert.log_prior.is_finite())
717 {
718 return Err("mixture expert log_prior must be finite".to_string());
719 }
720 if let Some(decay) = spec.decay {
721 if !decay.is_finite() || !(0.0..1.0).contains(&decay) {
722 return Err("mixture decay must be in (0, 1)".to_string());
723 }
724 }
725 if matches!(spec.kind, MixtureKind::FadingBayes) && spec.decay.is_none() {
726 return Err("fading Bayes mixture requires decay".to_string());
727 }
728 if spec.schedule != MixtureScheduleMode::Default
729 && !matches!(spec.kind, MixtureKind::Switching | MixtureKind::Convex)
730 {
731 return Err(
732 "mixture schedule is only supported for switching and convex mixtures".to_string(),
733 );
734 }
735 match (spec.kind, spec.schedule) {
736 (MixtureKind::Switching, MixtureScheduleMode::Default) => {
737 if !(0.0..=1.0).contains(&spec.alpha) {
738 return Err("switching mixture alpha must be in [0, 1]".to_string());
739 }
740 }
741 (MixtureKind::Convex, MixtureScheduleMode::Default)
742 | (MixtureKind::Neural, MixtureScheduleMode::Default) => {
743 if spec.alpha <= 0.0 {
744 return Err("mixture alpha must be > 0".to_string());
745 }
746 }
747 (MixtureKind::Neural, MixtureScheduleMode::Theorem) => unreachable!(),
748 _ => {}
749 }
750 Ok(())
751}
752
753fn validate_rate_backend_with_depth(backend: &RateBackend, depth: usize) -> Result<(), String> {
754 match backend {
755 RateBackend::Sequitur { context_bytes } => {
756 if *context_bytes < 2 {
757 Err("sequitur context_bytes must be >= 2".to_string())
758 } else {
759 Ok(())
760 }
761 }
762 RateBackend::Mixture { spec } => validate_mixture_spec_with_depth(spec.as_ref(), depth),
763 RateBackend::Particle { spec } => spec.validate(),
764 RateBackend::Calibrated { spec } => {
765 if depth == 0 {
766 return Err("calibrated spec nesting too deep".to_string());
767 }
768 validate_rate_backend_with_depth(&spec.base, depth - 1)
769 .map_err(|err| format!("calibrated base invalid: {err}"))
770 }
771 _ => Ok(()),
772 }
773}
774
775pub fn validate_rate_backend(backend: &RateBackend) -> Result<(), String> {
777 validate_rate_backend_with_depth(backend, MAX_MIXTURE_NESTING)
778}
779
/// Configuration for the particle-filter rate model (`crate::particle`).
/// Numeric invariants are enforced by `validate` below; field semantics
/// beyond those invariants live in the runtime — hedged notes only here.
#[derive(Clone, Debug)]
pub struct ParticleSpec {
    // Must be > 0.
    pub num_particles: usize,
    // Must be > 0.
    pub context_window: usize,
    // Must be > 0.
    pub unroll_steps: usize,
    // Must be > 0.
    pub num_cells: usize,
    // Must be > 0.
    pub cell_dim: usize,
    // Must be > 0.
    pub num_rules: usize,
    // Must be > 0.
    pub selector_hidden: usize,
    // Must be > 0.
    pub rule_hidden: usize,
    pub noise_dim: usize,
    pub deterministic: bool,
    pub enable_noise: bool,
    // Finite, non-negative.
    pub noise_scale: f64,
    pub noise_anneal_steps: usize,
    // Finite, non-negative.
    pub learning_rate_readout: f64,
    // Finite, non-negative.
    pub learning_rate_selector: f64,
    // Finite, non-negative.
    pub learning_rate_rule: f64,
    // Must be > 0.
    pub bptt_depth: usize,
    // Finite, in [0, 1).
    pub optimizer_momentum: f64,
    pub grad_clip: f64,
    pub state_clip: f64,
    pub forget_lambda: f64,
    // In (0, 1]; presumably an effective-sample-size trigger — TODO confirm.
    pub resample_threshold: f64,
    // In [0, 1].
    pub mutate_fraction: f64,
    pub mutate_scale: f64,
    pub mutate_model_params: bool,
    // 0 presumably disables periodic diagnostics — TODO confirm in runtime.
    pub diagnostics_interval: usize,
    // Probability floor, in (0, 0.5).
    pub min_prob: f64,
    pub seed: u64,
}

impl Default for ParticleSpec {
    /// Conservative defaults: deterministic, noise disabled, small model,
    /// min_prob = 2^-24 (matches the floor used by the ZPAQ rate model).
    fn default() -> Self {
        Self {
            num_particles: 16,
            context_window: 32,
            unroll_steps: 2,
            num_cells: 8,
            cell_dim: 32,
            num_rules: 4,
            selector_hidden: 64,
            rule_hidden: 64,
            noise_dim: 8,
            deterministic: true,
            enable_noise: false,
            noise_scale: 0.10,
            noise_anneal_steps: 8192,
            learning_rate_readout: 0.01,
            learning_rate_selector: 1e-4,
            learning_rate_rule: 3e-4,
            bptt_depth: 3,
            optimizer_momentum: 0.05,
            grad_clip: 1.0,
            state_clip: 8.0,
            forget_lambda: 0.0,
            resample_threshold: 0.5,
            mutate_fraction: 0.1,
            mutate_scale: 0.01,
            mutate_model_params: false,
            diagnostics_interval: 0,
            min_prob: 2f64.powi(-24),
            seed: 42,
        }
    }
}
875
876impl ParticleSpec {
877 pub fn validate(&self) -> Result<(), String> {
879 if self.num_particles == 0 {
880 return Err("num_particles must be > 0".into());
881 }
882 if self.context_window == 0 {
883 return Err("context_window must be > 0".into());
884 }
885 if self.unroll_steps == 0 {
886 return Err("unroll_steps must be > 0".into());
887 }
888 if self.num_cells == 0 {
889 return Err("num_cells must be > 0".into());
890 }
891 if self.cell_dim == 0 {
892 return Err("cell_dim must be > 0".into());
893 }
894 if self.num_rules == 0 {
895 return Err("num_rules must be > 0".into());
896 }
897 if self.selector_hidden == 0 {
898 return Err("selector_hidden must be > 0".into());
899 }
900 if self.rule_hidden == 0 {
901 return Err("rule_hidden must be > 0".into());
902 }
903 if !self.learning_rate_readout.is_finite() || self.learning_rate_readout < 0.0 {
904 return Err("learning_rate_readout must be finite and non-negative".into());
905 }
906 if !self.learning_rate_selector.is_finite() || self.learning_rate_selector < 0.0 {
907 return Err("learning_rate_selector must be finite and non-negative".into());
908 }
909 if !self.learning_rate_rule.is_finite() || self.learning_rate_rule < 0.0 {
910 return Err("learning_rate_rule must be finite and non-negative".into());
911 }
912 if !self.noise_scale.is_finite() || self.noise_scale < 0.0 {
913 return Err("noise_scale must be finite and non-negative".into());
914 }
915 if !self.optimizer_momentum.is_finite()
916 || self.optimizer_momentum < 0.0
917 || self.optimizer_momentum >= 1.0
918 {
919 return Err("optimizer_momentum must be finite and in [0, 1)".into());
920 }
921 if self.bptt_depth == 0 {
922 return Err("bptt_depth must be > 0".into());
923 }
924 if !(self.resample_threshold > 0.0 && self.resample_threshold <= 1.0) {
925 return Err("resample_threshold must be in (0, 1]".into());
926 }
927 if !(self.mutate_fraction >= 0.0 && self.mutate_fraction <= 1.0) {
928 return Err("mutate_fraction must be in [0, 1]".into());
929 }
930 if !(self.min_prob > 0.0 && self.min_prob < 0.5) {
931 return Err("min_prob must be in (0, 0.5)".into());
932 }
933 Ok(())
934 }
935}
936
/// Bundle of the rate backend (entropy-rate estimation) and compression
/// backend (compressed-size estimation) used by the measures below.
#[derive(Clone, Default)]
pub struct InfotheoryCtx {
    pub rate_backend: RateBackend,
    pub compression_backend: CompressionBackend,
}

/// Stateful streaming wrapper around a rate backend's online predictor;
/// supports adaptive observation, frozen conditioning, and byte generation.
pub struct RateBackendSession {
    predictor: crate::mixture::RateBackendPredictor,
}
950
impl RateBackendSession {
    /// Builds a session: validates the backend, constructs the predictor,
    /// then opens the stream (some backends want the total length up front).
    pub fn from_backend(
        backend: RateBackend,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<Self, String> {
        use crate::mixture::OnlineBytePredictor;

        validate_rate_backend(&backend)?;
        let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
            backend,
            max_order,
            crate::mixture::DEFAULT_MIN_PROB,
        );
        predictor.begin_stream(total_symbols)?;
        Ok(Self { predictor })
    }

    /// Feeds bytes with adaptive updates (the model learns from them).
    pub fn observe(&mut self, data: &[u8]) {
        use crate::mixture::OnlineBytePredictor;

        for &byte in data {
            self.predictor.update(byte);
        }
    }

    /// Feeds bytes as context only: the predictor advances without adapting.
    pub fn condition(&mut self, data: &[u8]) {
        use crate::mixture::OnlineBytePredictor;

        for &byte in data {
            self.predictor.update_frozen(byte);
        }
    }

    /// Resets the frozen stream state — exact semantics live in
    /// `RateBackendPredictor::reset_frozen`; TODO confirm there.
    pub fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        use crate::mixture::OnlineBytePredictor;

        self.predictor.reset_frozen(total_symbols)
    }

    /// Writes the current 256-way next-byte log-probabilities into `out`.
    pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        use crate::mixture::OnlineBytePredictor;

        self.predictor.fill_log_probs(out);
    }

    /// Generates `bytes` bytes from the current model state.
    ///
    /// Per step: obtain the log-prob table, pick a byte with
    /// `pick_generated_byte` (seeded xorshift RNG, per `config`), then
    /// advance the predictor adaptively or frozen per `config.update_mode`.
    pub fn generate_bytes(&mut self, bytes: usize, config: GenerationConfig) -> Vec<u8> {
        use crate::mixture::OnlineBytePredictor;

        if bytes == 0 {
            return Vec::new();
        }

        let mut out = Vec::with_capacity(bytes);
        let mut logps = [0.0f64; 256];
        let mut rng = GenerationRng::new(config.seed);

        for _ in 0..bytes {
            match &mut self.predictor {
                // The Rosa variant is special-cased to per-symbol queries —
                // presumably it has no bulk fill path; TODO confirm.
                crate::mixture::RateBackendPredictor::Rosa { .. } => {
                    for (sym, slot) in logps.iter_mut().enumerate() {
                        *slot = self.predictor.log_prob(sym as u8);
                    }
                }
                _ => self.predictor.fill_log_probs(&mut logps),
            }
            let byte = pick_generated_byte(&logps, config, &mut rng);
            match config.update_mode {
                GenerationUpdateMode::Adaptive => self.predictor.update(byte),
                GenerationUpdateMode::Frozen => self.predictor.update_frozen(byte),
            }
            out.push(byte);
        }

        out
    }

    /// Finalizes the stream via the predictor's `finish_stream`.
    pub fn finish(&mut self) -> Result<(), String> {
        use crate::mixture::OnlineBytePredictor;

        self.predictor.finish_stream()
    }
}
1042
impl InfotheoryCtx {
    /// Creates a context from explicit backends.
    pub fn new(rate_backend: RateBackend, compression_backend: CompressionBackend) -> Self {
        Self {
            rate_backend,
            compression_backend,
        }
    }

    /// Convenience: RosaPlus rate estimation + ZPAQ compression with `method`.
    pub fn with_zpaq(method: impl Into<String>) -> Self {
        Self {
            rate_backend: RateBackend::RosaPlus,
            compression_backend: CompressionBackend::Zpaq {
                method: method.into(),
            },
        }
    }

    /// Compressed size of `data` under the compression backend.
    pub fn compress_size(&self, data: &[u8]) -> u64 {
        compress_size_backend(data, &self.compression_backend)
    }

    /// Compressed size of the chained `parts`.
    pub fn compress_size_chain(&self, parts: &[&[u8]]) -> u64 {
        compress_size_chain_backend(parts, &self.compression_backend)
    }

    /// Opens a streaming session over a clone of this context's rate backend.
    pub fn rate_backend_session(
        &self,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<RateBackendSession, String> {
        RateBackendSession::from_backend(self.rate_backend.clone(), max_order, total_symbols)
    }

    /// Entropy rate of `data` under the rate backend.
    pub fn entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Biased entropy-rate variant — see `biased_entropy_rate_backend`.
    pub fn biased_entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        biased_entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Cross-entropy rate of `test_data` relative to `train_data`.
    pub fn cross_entropy_rate_bytes(
        &self,
        test_data: &[u8],
        train_data: &[u8],
        max_order: i64,
    ) -> f64 {
        cross_entropy_rate_backend(test_data, train_data, max_order, &self.rate_backend)
    }

    /// Cross-entropy; `max_order == 0` uses order-0 byte histograms (with a
    /// 1e-12 floor so unseen test symbols don't produce infinities),
    /// otherwise defers to the rate backend.
    pub fn cross_entropy_bytes(&self, test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            if test_data.is_empty() {
                return 0.0;
            }
            let p_x = byte_histogram(test_data);
            let p_y = byte_histogram(train_data);
            let mut h = 0.0f64;
            for i in 0..256 {
                if p_x[i] > 0.0 {
                    // Floor the train probability to keep log2 finite.
                    let q_y = p_y[i].max(1e-12);
                    h -= p_x[i] * q_y.log2();
                }
            }
            h
        } else {
            self.cross_entropy_rate_bytes(test_data, train_data, max_order)
        }
    }

    /// Joint entropy rate H(X, Y) over the aligned common prefix.
    pub fn joint_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        joint_entropy_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Conditional entropy rate H(X|Y) = H(X,Y) - H(Y), floored at 0.
    pub fn conditional_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
        let h_y = self.entropy_rate_bytes(y, max_order);
        (h_xy - h_y).max(0.0)
    }

    /// Average bits/byte to encode `data` given `prefix_parts` as context.
    /// Each backend arm first conditions on the prefix, then scores only
    /// `data`; empty `data` yields 0.0 in the arms computed locally.
    pub fn cross_entropy_conditional_chain(&self, prefix_parts: &[&[u8]], data: &[u8]) -> f64 {
        match &self.rate_backend {
            // Rosa scores against a flattened copy of the prefix.
            RateBackend::RosaPlus => {
                let mut prefix = Vec::new();
                let total: usize = prefix_parts.iter().map(|p| p.len()).sum();
                prefix.reserve(total);
                for p in prefix_parts {
                    prefix.extend_from_slice(p);
                }
                cross_entropy_rate_backend(data, &prefix, -1, &RateBackend::RosaPlus)
            }
            // These backends score prequentially over the parts.
            RateBackend::Match { .. }
            | RateBackend::SparseMatch { .. }
            | RateBackend::Ppmd { .. }
            | RateBackend::Sequitur { .. }
            | RateBackend::Calibrated { .. } => {
                prequential_rate_backend(data, prefix_parts, -1, &self.rate_backend)
            }
            // Neural backends delegate to thread-local cached compressors.
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("rwkv conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("rwkv method conditional-chain scoring failed: {e:#}")
                    })
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("mamba conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("mamba method conditional-chain scoring failed: {e:#}")
                    })
            }),
            // CTW: conditional log-prob = log P(prefix, data) - log P(prefix),
            // feeding bytes MSB-first one bit at a time.
            RateBackend::Ctw { depth } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut tree = crate::ctw::ContextTree::new(*depth);
                for &part in prefix_parts {
                    for &b in part {
                        for i in (0..8).rev() {
                            tree.update(((b >> i) & 1) == 1);
                        }
                    }
                }
                let log_p_prefix = tree.get_log_block_probability();
                for &b in data {
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_p_joint = tree.get_log_block_probability();
                let log_p_cond = log_p_joint - log_p_prefix;
                // Divide by ln 2 to convert the natural-log score to bits.
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
            RateBackend::Zpaq { method } => {
                if data.is_empty() {
                    return 0.0;
                }
                // 2^-24 probability floor matches ParticleSpec's min_prob.
                let mut model =
                    crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
                for &part in prefix_parts {
                    model.update_and_score(part);
                }
                let bits = model.update_and_score(data);
                bits / (data.len() as f64)
            }
            RateBackend::Mixture { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let experts = spec.build_experts();
                let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                // Total stream length (prefix + data), saturating on overflow.
                let total = prefix_parts
                    .iter()
                    .map(|p| p.len() as u64)
                    .sum::<u64>()
                    .saturating_add(data.len() as u64);
                mix.begin_stream(Some(total))
                    .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
                for &part in prefix_parts {
                    for &b in part {
                        mix.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    // step returns a log score; the LN_2 division converts
                    // to bits (so the score is presumably in nats).
                    bits -= mix.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::Particle { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                for &part in prefix_parts {
                    for &b in part {
                        runtime.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    bits -= runtime.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            // Factored CTW models only the low `encoding_bits` of each byte
            // (LSB-first); `num_percept_bits` is unused in this path.
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                if data.is_empty() {
                    return 0.0;
                }
                let bits_per_byte = (*encoding_bits).clamp(1, 8);
                let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
                for &part in prefix_parts {
                    for &b in part {
                        for i in 0..bits_per_byte {
                            let bit_idx = i;
                            fac.update(((b >> i) & 1) == 1, bit_idx);
                        }
                    }
                }
                let log_p_prefix = fac.get_log_block_probability();
                for &b in data {
                    for i in 0..bits_per_byte {
                        let bit_idx = i;
                        fac.update(((b >> i) & 1) == 1, bit_idx);
                    }
                }
                let log_p_joint = fac.get_log_block_probability();
                let log_p_cond = log_p_joint - log_p_prefix;
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
        }
    }

    /// Generates bytes after conditioning on `prompt` (default config).
    pub fn generate_bytes(&self, prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
        self.generate_bytes_with_config(prompt, bytes, max_order, GenerationConfig::default())
    }

    /// Generates bytes after conditioning on `prompt` with explicit config.
    pub fn generate_bytes_with_config(
        &self,
        prompt: &[u8],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(&[prompt], bytes, max_order, &self.rate_backend, config)
    }

    /// Generation conditioned on multiple prefix chunks (default config).
    pub fn generate_bytes_conditional_chain(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
    ) -> Vec<u8> {
        self.generate_bytes_conditional_chain_with_config(
            prefix_parts,
            bytes,
            max_order,
            GenerationConfig::default(),
        )
    }

    /// Generation conditioned on multiple prefix chunks, explicit config.
    pub fn generate_bytes_conditional_chain_with_config(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(prefix_parts, bytes, max_order, &self.rate_backend, config)
    }

    /// Normalized compression distance under the compression backend.
    pub fn ncd_bytes(&self, x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
        ncd_bytes_backend(x, y, &self.compression_backend, variant)
    }

    /// Mutual-information rate I(X;Y) via the rate backend.
    pub fn mutual_information_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        mutual_information_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// MI with an order-0 (marginal histogram) shortcut at `max_order == 0`.
    pub fn mutual_information_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            mutual_information_marg_bytes(x, y)
        } else {
            self.mutual_information_rate_bytes(x, y, max_order)
        }
    }

    /// Conditional entropy H(X|Y), order-0 or rate-based per `max_order`.
    pub fn conditional_entropy_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if max_order == 0 {
            let h_xy = joint_marginal_entropy_bytes(x, y);
            let h_y = marginal_entropy_bytes(y);
            (h_xy - h_y).max(0.0)
        } else {
            let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
            let h_y = self.entropy_rate_bytes(y, max_order);
            (h_xy - h_y).max(0.0)
        }
    }

    /// Normalized entropy distance, order-0 or rate-based per `max_order`.
    pub fn ned_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            ned_marg_bytes(x, y)
        } else {
            ned_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Conservative NED variant normalized by the joint entropy instead of
    /// max(H): (H(X,Y) - min(H)) / H(X,Y), clamped to [0, 1].
    pub fn ned_cons_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        let (h_x, h_y, h_xy) = if max_order == 0 {
            (
                marginal_entropy_bytes(x),
                marginal_entropy_bytes(y),
                joint_marginal_entropy_bytes(x, y),
            )
        } else {
            (
                self.entropy_rate_bytes(x, max_order),
                self.entropy_rate_bytes(y, max_order),
                self.joint_entropy_rate_bytes(x, y, max_order),
            )
        };
        let min_h = h_x.min(h_y);
        if h_xy == 0.0 {
            0.0
        } else {
            ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
        }
    }

    /// Normalized total-entropy distance in [0, 2].
    pub fn nte_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            nte_marg_bytes(x, y)
        } else {
            nte_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// How much internal structure lowers the entropy rate below the
    /// marginal: (H_marg - H_rate) / H_marg, clamped to [0, 1].
    pub fn intrinsic_dependence_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        let h_marginal = marginal_entropy_bytes(data);
        if h_marginal < 1e-9 {
            // Near-constant input: no measurable dependence (avoids /0).
            return 0.0;
        }
        let h_rate = self.entropy_rate_bytes(data, max_order);
        ((h_marginal - h_rate) / h_marginal).clamp(0.0, 1.0)
    }

    /// Fraction of X's entropy retained by its transformed version:
    /// I(X; T(X)) / H(X), clamped to [0, 1].
    pub fn resistance_to_transformation_bytes(&self, x: &[u8], tx: &[u8], max_order: i64) -> f64 {
        let (x, tx) = aligned_prefix(x, tx);
        let h_x = if max_order == 0 {
            marginal_entropy_bytes(x)
        } else {
            self.entropy_rate_bytes(x, max_order)
        };
        if h_x < 1e-9 {
            return 0.0;
        }
        let mi = self.mutual_information_bytes(x, tx, max_order);
        (mi / h_x).clamp(0.0, 1.0)
    }
}
1441
/// Loads an RWKV7 model from disk.
/// NOTE(review): panics on any load failure rather than returning `Result`.
#[cfg(feature = "backend-rwkv")]
pub fn load_rwkv7_model_from_path(path: &str) -> Arc<rwkvzip::Model> {
    rwkvzip::Compressor::load_model(path).expect("failed to load RWKV7 model")
}

/// Loads a Mamba model from disk.
/// NOTE(review): panics on any load failure rather than returning `Result`.
#[cfg(feature = "backend-mamba")]
pub fn load_mamba_model_from_path(path: &str) -> Arc<mambazip::Model> {
    mambazip::Compressor::load_model(path).expect("failed to load Mamba model")
}
1453
/// Truncates both slices to their common (shorter) length so that
/// position-aligned comparisons are well defined.
#[inline(always)]
fn aligned_prefix<'a>(x: &'a [u8], y: &'a [u8]) -> (&'a [u8], &'a [u8]) {
    let common = x.len().min(y.len());
    (&x[..common], &y[..common])
}
1459
/// Compressed size of `data` under the given zpaq `method`.
/// Compression errors are silently mapped to a size of 0.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_rs::compress_size(data, method).unwrap_or(0)
}

/// Stub used when the zpaq backend is compiled out: always panics.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_bytes(_data: &[u8], _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}

/// Compressed size using zpaq's multi-threaded compressor with `threads`
/// workers. Errors are silently mapped to 0.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(data: &[u8], method: &str, threads: usize) -> u64 {
    zpaq_rs::compress_size_parallel(data, method, threads).unwrap_or(0)
}

/// Stub used when the zpaq backend is compiled out: always panics.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(_data: &[u8], _method: &str, _threads: usize) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}

/// Compressed size of a byte stream (no full in-memory buffer required).
/// Errors are silently mapped to 0.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(reader: R, method: &str) -> u64 {
    zpaq_rs::compress_size_stream(reader, method, None, None).unwrap_or(0)
}

/// Stub used when the zpaq backend is compiled out: always panics.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(_reader: R, _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}

/// Compresses `data` into an owned buffer; errors propagate (unlike the
/// size-only shims above, which map errors to 0).
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_to_vec(data: &[u8], method: &str) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::compress_to_vec(data, method)?)
}

/// Stub used when the zpaq backend is compiled out: returns an error
/// (fallible API, so no panic here).
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_to_vec(_data: &[u8], _method: &str) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}

/// Decompresses a zpaq archive produced by `zpaq_compress_to_vec`.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_decompress_to_vec(data: &[u8]) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::decompress_to_vec(data)?)
}

/// Stub used when the zpaq backend is compiled out: returns an error.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_decompress_to_vec(_data: &[u8]) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
1519
1520#[inline(always)]
1522pub fn get_compressed_size(path: &str, method: &str) -> u64 {
1523 zpaq_compress_size_bytes(&std::fs::read(path).unwrap(), method)
1526}
1527
/// Validates a zpaq rate-method string before it is used to build a rate
/// model. With the zpaq backend compiled out, every method is rejected
/// with an explanatory error instead of panicking.
pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
    #[cfg(feature = "backend-zpaq")]
    {
        zpaq_rate::validate_zpaq_rate_method(method)
    }
    #[cfg(not(feature = "backend-zpaq"))]
    {
        let _ = method;
        Err("zpaq backend disabled at compile time".to_string())
    }
}
1540
/// Runs `f` against a thread-local RWKV compressor cached per model
/// instance (keyed by the `Arc`'s pointer address). The compressor is
/// mutated in place, so its state persists across calls on this thread.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    RWKV_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| rwkvzip::Compressor::new_from_model(model.clone()));
        // Note: the RefCell borrow is held while `f` runs here.
        f(comp)
    })
}

/// Runs `f` against a *clone* of a cached per-method compressor template,
/// so mutations by `f` never leak into later calls. The template itself is
/// built once per method string and kept in thread-local storage.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_method_tls<R>(method: &str, f: impl FnOnce(&mut rwkvzip::Compressor) -> R) -> R {
    RWKV_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid rwkv method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        // Release the RefCell borrow before invoking `f`.
        drop(map);
        f(&mut comp)
    })
}

/// Like `with_rwkv_tls`, but hands `f` a clone of the cached compressor
/// (fresh state per call) rather than the shared in-place instance.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_rate_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    RWKV_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        // Release the RefCell borrow before invoking `f`.
        drop(map);
        f(&mut comp)
    })
}

/// Mamba analogue of `with_rwkv_tls`: thread-local compressor cached per
/// model, mutated in place (state persists across calls on this thread).
#[cfg(feature = "backend-mamba")]
fn with_mamba_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    MAMBA_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| mambazip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}

/// Mamba analogue of `with_rwkv_rate_tls`: `f` gets a clone of the cached
/// compressor, so every call starts from the template's state.
#[cfg(feature = "backend-mamba")]
fn with_mamba_rate_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    MAMBA_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        // Release the RefCell borrow before invoking `f`.
        drop(map);
        f(&mut comp)
    })
}

/// Mamba analogue of `with_rwkv_method_tls`: per-method template cached in
/// TLS, cloned for each call.
#[cfg(feature = "backend-mamba")]
fn with_mamba_method_tls<R>(method: &str, f: impl FnOnce(&mut mambazip::Compressor) -> R) -> R {
    MAMBA_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid mamba method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        // Release the RefCell borrow before invoking `f`.
        drop(map);
        f(&mut comp)
    })
}
1648
/// A `std::io::Read` adapter that presents several byte slices as one
/// logical stream, without copying them into a single buffer first.
struct SliceChainReader<'a> {
    parts: &'a [&'a [u8]],
    i: usize,
    off: usize,
}

impl<'a> SliceChainReader<'a> {
    /// Creates a reader positioned at the start of the first slice.
    fn new(parts: &'a [&'a [u8]]) -> Self {
        Self { parts, i: 0, off: 0 }
    }
}

impl<'a> std::io::Read for SliceChainReader<'a> {
    /// Fills `buf` from the remaining parts, skipping empty/exhausted
    /// slices. Returns the number of bytes written (0 only when `buf` is
    /// empty or the stream is exhausted).
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        let mut written = 0usize;
        while written < buf.len() && self.i < self.parts.len() {
            let part = self.parts[self.i];
            if self.off >= part.len() {
                // Current slice exhausted; advance to the next one.
                self.i += 1;
                self.off = 0;
                continue;
            }
            let n = (part.len() - self.off).min(buf.len() - written);
            buf[written..written + n].copy_from_slice(&part[self.off..self.off + n]);
            self.off += n;
            written += n;
        }
        Ok(written)
    }
}
1697
/// Compressed size of the logical concatenation of `parts` under `backend`.
/// The zpaq arm streams the parts without materializing the joined buffer.
/// Backend errors are mapped to a size of 0 where the call is infallible.
pub fn compress_size_chain_backend(parts: &[&[u8]], backend: &CompressionBackend) -> u64 {
    match backend {
        CompressionBackend::Zpaq { method } => {
            // Stream through the chaining reader to avoid one big copy.
            let r = SliceChainReader::new(parts);
            zpaq_compress_size_stream(r, method.as_str())
        }
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress_size_chain(parts, *coder).unwrap_or(0))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => {
            crate::compression::compress_rate_size_chain(parts, rate_backend, -1, *coder, *framing)
                .unwrap_or(0)
        }
    }
}

/// Compressed size of a single buffer under `backend`.
/// Backend errors are mapped to a size of 0.
pub fn compress_size_backend(data: &[u8], backend: &CompressionBackend) -> u64 {
    match backend {
        CompressionBackend::Zpaq { method } => zpaq_compress_size_bytes(data, method.as_str()),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress_size(data, *coder).unwrap_or(0))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::compress_rate_size(data, rate_backend, -1, *coder, *framing)
            .unwrap_or(0),
    }
}

/// Compresses `data` under `backend`, returning the encoded bytes.
/// Unlike the size-only helpers, errors propagate to the caller.
pub fn compress_bytes_backend(
    data: &[u8],
    backend: &CompressionBackend,
) -> anyhow::Result<Vec<u8>> {
    match backend {
        CompressionBackend::Zpaq { method } => zpaq_compress_to_vec(data, method),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress(data, *coder))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::compress_rate_bytes(data, rate_backend, -1, *coder, *framing),
    }
}

/// Inverse of [`compress_bytes_backend`]; the backend configuration must
/// match the one used for compression. Note the zpaq arm ignores the
/// configured method on decompression.
pub fn decompress_bytes_backend(
    input: &[u8],
    backend: &CompressionBackend,
) -> anyhow::Result<Vec<u8>> {
    match backend {
        CompressionBackend::Zpaq { .. } => zpaq_decompress_to_vec(input),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, .. } => with_rwkv_tls(model, |c| c.decompress(input)),
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::decompress_rate_bytes(input, rate_backend, -1, *coder, *framing),
    }
}
1772
1773fn prequential_rate_backend(
1774 data: &[u8],
1775 prefix_parts: &[&[u8]],
1776 max_order: i64,
1777 backend: &RateBackend,
1778) -> f64 {
1779 use crate::mixture::OnlineBytePredictor;
1780
1781 if data.is_empty() {
1782 return 0.0;
1783 }
1784 let total = prefix_parts
1785 .iter()
1786 .map(|p| p.len() as u64)
1787 .sum::<u64>()
1788 .saturating_add(data.len() as u64);
1789 let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
1790 backend.clone(),
1791 max_order,
1792 crate::mixture::DEFAULT_MIN_PROB,
1793 );
1794 predictor
1795 .begin_stream(Some(total))
1796 .unwrap_or_else(|e| panic!("rate backend stream init failed: {e}"));
1797 for prefix in prefix_parts {
1798 for &b in *prefix {
1799 predictor.update(b);
1800 }
1801 }
1802 let mut bits = 0.0;
1803 for &b in data {
1804 bits -= predictor.log_prob(b) / std::f64::consts::LN_2;
1805 predictor.update(b);
1806 }
1807 predictor
1808 .finish_stream()
1809 .unwrap_or_else(|e| panic!("rate backend stream finalize failed: {e}"));
1810 bits / (data.len() as f64)
1811}
1812
/// Frozen "plug-in" cross-entropy: fit the backend on `fit_parts`, then
/// score `score_data` in bits per byte with learning frozen (the generic
/// path below uses `reset_frozen`/`update_frozen`).
///
/// RosaPlus and the neural backends (rwkv/mamba) have dedicated
/// fit-then-score entry points and are dispatched first; everything else
/// goes through the generic two-pass predictor path.
///
/// Returns 0.0 for empty `score_data`.
///
/// # Panics
/// Panics if any backend fit/score call fails.
fn frozen_plugin_rate_backend(
    score_data: &[u8],
    fit_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    if score_data.is_empty() {
        return 0.0;
    }
    // RosaPlus: explicit train / build / score API.
    if matches!(backend, RateBackend::RosaPlus) {
        let mut model = rosaplus::RosaPlus::new(max_order, false, 0, 42);
        for part in fit_parts {
            model.train_example(part);
        }
        model.build_lm();
        return model.cross_entropy(score_data);
    }
    // Neural backends score via their own frozen-plugin chain method.
    // (Separate cfg'd matches so each compiles only with its feature.)
    #[cfg(feature = "backend-rwkv")]
    match backend {
        RateBackend::Rwkv7 { model } => {
            return with_rwkv_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::Rwkv7Method { method } => {
            return with_rwkv_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }
    #[cfg(feature = "backend-mamba")]
    match backend {
        RateBackend::Mamba { model } => {
            return with_mamba_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::MambaMethod { method } => {
            return with_mamba_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }

    use crate::mixture::OnlineBytePredictor;

    // Generic path, pass 1: adapt the predictor on all fit parts.
    let fit_total = fit_parts.iter().map(|part| part.len() as u64).sum::<u64>();
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(fit_total))
        .unwrap_or_else(|e| panic!("rate backend fit-pass init failed: {e}"));
    for part in fit_parts {
        for &byte in *part {
            predictor.update(byte);
        }
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend fit-pass finalize failed: {e}"));
    // Pass 2: score with learning frozen (frozen update only advances
    // context, it does not adapt the model).
    predictor
        .reset_frozen(Some(score_data.len() as u64))
        .unwrap_or_else(|e| panic!("rate backend frozen-score reset failed: {e}"));
    let mut bits = 0.0;
    for &byte in score_data {
        bits -= predictor.log_prob(byte) / std::f64::consts::LN_2;
        predictor.update_frozen(byte);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend frozen-score finalize failed: {e}"));
    bits / (score_data.len() as f64)
}
1895
/// Index of the byte with the highest log-probability. Non-finite entries
/// (including NaN) are treated as -inf; ties keep the earliest index, and
/// an all-non-finite input yields byte 0.
#[inline(always)]
fn argmax_log_prob_byte(logps: &[f64; 256]) -> u8 {
    let (winner, _) = logps
        .iter()
        .enumerate()
        .fold((0usize, f64::NEG_INFINITY), |(best_i, best_s), (i, &lp)| {
            // Sanitize NaN/inf so the comparison below is well-defined.
            let s = if lp.is_finite() { lp } else { f64::NEG_INFINITY };
            if s > best_s { (i, s) } else { (best_i, best_s) }
        });
    winner as u8
}
1913
/// Selects the next generated byte from per-byte log-probabilities.
///
/// Greedy strategy — or a non-finite / non-positive temperature — returns
/// the argmax byte. Otherwise performs temperature-scaled top-k / top-p
/// (nucleus) sampling, falling back to argmax whenever the candidate set
/// degenerates (all -inf, or a zero/non-finite weight total).
fn pick_generated_byte(
    logps: &[f64; 256],
    config: GenerationConfig,
    rng: &mut GenerationRng,
) -> u8 {
    if matches!(config.strategy, GenerationStrategy::Greedy)
        || !config.temperature.is_finite()
        || config.temperature <= 0.0
    {
        return argmax_log_prob_byte(logps);
    }

    // Temperature-scale every candidate; non-finite logps become -inf.
    let mut entries = [(0u8, f64::NEG_INFINITY); 256];
    for (idx, &logp) in logps.iter().enumerate() {
        let scaled = if logp.is_finite() {
            logp / config.temperature
        } else {
            f64::NEG_INFINITY
        };
        entries[idx] = (idx as u8, scaled);
    }
    // Sort descending by scaled log-prob (total_cmp gives a total order).
    entries.sort_by(|a, b| b.1.total_cmp(&a.1));

    // top_k == 0 means "no top-k truncation".
    let keep_k = if config.top_k == 0 {
        entries.len()
    } else {
        config.top_k.min(entries.len())
    };

    // Sanitize top_p into [0, 1]; non-finite means "no nucleus truncation".
    let top_p = if config.top_p.is_finite() {
        config.top_p.clamp(0.0, 1.0)
    } else {
        1.0
    };

    // Max finite log-prob among the kept candidates, used to stabilize exp().
    let mut max_logp = f64::NEG_INFINITY;
    for &(_, logp) in entries.iter().take(keep_k) {
        if logp.is_finite() {
            max_logp = max_logp.max(logp);
        }
    }
    if !max_logp.is_finite() {
        // Every kept candidate is -inf: nothing to sample from.
        return argmax_log_prob_byte(logps);
    }

    // Convert to unnormalized weights exp(logp - max_logp).
    let mut weights = [(0u8, 0.0f64); 256];
    let mut total = 0.0;
    for (idx, &(byte, logp)) in entries.iter().take(keep_k).enumerate() {
        let w = if logp.is_finite() {
            (logp - max_logp).exp()
        } else {
            0.0
        };
        weights[idx] = (byte, w);
        total += w;
    }
    if !(total.is_finite()) || total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Nucleus cutoff: keep the smallest prefix whose probability mass
    // reaches top_p (always at least one candidate).
    let cutoff_count = if top_p >= 1.0 {
        keep_k
    } else {
        let mut cumulative = 0.0;
        let mut keep = 0usize;
        for &(_, w) in weights.iter().take(keep_k) {
            cumulative += w / total;
            keep += 1;
            if cumulative >= top_p {
                break;
            }
        }
        keep.max(1)
    };

    let mut truncated_total = 0.0;
    for &(_, w) in weights.iter().take(cutoff_count) {
        truncated_total += w;
    }
    if !(truncated_total.is_finite()) || truncated_total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Inverse-CDF sampling over the truncated weights; if floating-point
    // drift leaves the loop without a break, the top candidate is returned.
    let target = rng.next_f64() * truncated_total;
    let mut cumulative = 0.0;
    let mut picked = weights[0].0;
    for &(byte, weight) in weights.iter().take(cutoff_count) {
        cumulative += weight;
        if cumulative >= target {
            picked = byte;
            break;
        }
    }
    picked
}
2009
2010fn generate_rate_backend_chain(
2011 prefix_parts: &[&[u8]],
2012 bytes: usize,
2013 max_order: i64,
2014 backend: &RateBackend,
2015 config: GenerationConfig,
2016) -> Vec<u8> {
2017 if bytes == 0 {
2018 return Vec::new();
2019 }
2020
2021 let total = prefix_parts
2022 .iter()
2023 .map(|p| p.len() as u64)
2024 .sum::<u64>()
2025 .saturating_add(bytes as u64);
2026 let mut session = RateBackendSession::from_backend(backend.clone(), max_order, Some(total))
2027 .unwrap_or_else(|e| panic!("rate backend generation init failed: {e}"));
2028 for &part in prefix_parts {
2029 session.observe(part);
2030 }
2031 let out = session.generate_bytes(bytes, config);
2032 session
2033 .finish()
2034 .unwrap_or_else(|e| panic!("rate backend generation finalize failed: {e}"));
2035 out
2036}
2037
/// Entropy rate of `data` in bits per byte under the selected backend.
/// Every arm returns 0.0 for empty `data` (the RosaPlus / match-family /
/// neural arms delegate that handling to their implementations).
pub fn entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    match backend {
        RateBackend::RosaPlus => {
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.predictive_entropy_rate(data)
        }
        // Predictor-style backends share the generic prequential path
        // (no warm-up prefix).
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Sequitur { .. }
        | RateBackend::Calibrated { .. } => prequential_rate_backend(data, &[], max_order, backend),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv method entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba method entropy scoring failed: {e:#}"))
        }),
        RateBackend::Zpaq { method } => {
            if data.is_empty() {
                return 0.0;
            }
            // 2^-24 is the model's minimum probability floor.
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(data);
            bits / (data.len() as f64)
        }
        RateBackend::Mixture { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(data.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in data {
                // step() returns a natural-log probability; convert to bits.
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (data.len() as f64)
        }
        RateBackend::Particle { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in data {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (data.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            if data.is_empty() {
                return 0.0;
            }
            // 8 bits per symbol, fed MSB-first.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_msb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            if data.is_empty() {
                return 0.0;
            }
            // Only the low `encoding_bits` of each byte are modeled, LSB-first.
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_lsb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
    }
}
2140
2141pub fn biased_entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
2143 match backend {
2144 RateBackend::Zpaq { .. } => {
2145 panic!("biased/plugin entropy is not supported for zpaq rate backends in 1.1.1")
2146 }
2147 _ => frozen_plugin_rate_backend(data, &[data], max_order, backend),
2148 }
2149}
2150
2151pub fn cross_entropy_rate_backend(
2153 test_data: &[u8],
2154 train_data: &[u8],
2155 max_order: i64,
2156 backend: &RateBackend,
2157) -> f64 {
2158 match backend {
2159 RateBackend::Zpaq { method } => {
2160 if test_data.is_empty() {
2161 return 0.0;
2162 }
2163 let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
2164 model.update_and_score(train_data);
2165 let bits = model.update_and_score(test_data);
2166 bits / (test_data.len() as f64)
2167 }
2168 _ => frozen_plugin_rate_backend(test_data, &[train_data], max_order, backend),
2169 }
2170}
2171
/// Joint entropy rate of the aligned prefixes of `x` and `y`, in bits per
/// byte *pair* (all arms normalize by `x.len()` after alignment).
/// Returns 0.0 when the aligned prefix is empty.
pub fn joint_entropy_rate_backend(
    x: &[u8],
    y: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    if x.is_empty() {
        return 0.0;
    }
    match backend {
        RateBackend::RosaPlus => {
            // One 16-bit symbol per aligned byte pair: x*256 + y.
            let joint_symbols: Vec<u32> = (0..x.len())
                .map(|i| (x[i] as u32) * 256 + (y[i] as u32))
                .collect();
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.entropy_rate_cps(&joint_symbols)
        }
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Sequitur { .. }
        | RateBackend::Calibrated { .. } => {
            // Interleave x and y bytes and score as a single stream; the
            // per-byte rate is doubled to get a per-pair rate.
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            entropy_rate_backend(&joint, max_order, backend) * 2.0
        }
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv method joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba method joint-entropy scoring failed: {e:#}"))
        }),
        RateBackend::Zpaq { method } => {
            // Interleaved stream, scored once; total bits / number of pairs.
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(&joint);
            bits / (x.len() as f64)
        }
        RateBackend::Mixture { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(joint.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in &joint {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (x.len() as f64)
        }
        RateBackend::Particle { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in &joint {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (x.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            // 16-bit symbols: for each position, x's bit then y's bit,
            // MSB-first (x bits occupy indices 0..8, y bits 8..16).
            let mut fac = crate::ctw::FacContextTree::new(*depth, 16);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for bit_idx in 0..8 {
                    let bit_x = ((bx >> (7 - bit_idx)) & 1) == 1;
                    let bit_y = ((by >> (7 - bit_idx)) & 1) == 1;
                    fac.update(bit_x, bit_idx);
                    fac.update(bit_y, bit_idx + 8);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            // Low `encoding_bits` of each byte, LSB-first, with x and y
            // bits interleaved per bit position.
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte * 2);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for i in 0..bits_per_byte {
                    let bit_idx_x = i * 2;
                    let bit_idx_y = bit_idx_x + 1;
                    fac.update(((bx >> i) & 1) == 1, bit_idx_x);
                    fac.update(((by >> i) & 1) == 1, bit_idx_y);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
    }
}
2314#[inline(always)]
2315pub fn get_compressed_size_parallel(path: &str, method: &str, threads: usize) -> u64 {
2317 zpaq_compress_size_parallel_bytes(&std::fs::read(path).unwrap(), method, threads)
2320}
2321
/// Reads every file in parallel (rayon) and returns their raw bytes in the
/// same order as `paths`.
///
/// # Panics
/// Panics if any file cannot be read.
#[inline(always)]
pub fn get_bytes_from_paths(paths: &[&str]) -> Vec<Vec<u8>> {
    paths
        .par_iter()
        .map(|path| std::fs::read(*path).expect("failed to read file"))
        .collect()
}
2330
/// Compressed sizes: all files are read up front (via
/// `get_bytes_from_paths`), then each buffer is compressed with the
/// single-call zpaq compressor, one rayon task per buffer.
#[inline(always)]
pub fn get_sequential_compressed_sizes_from_sequential_paths(
    paths: &[&str],
    method: &str,
) -> Vec<u64> {
    get_bytes_from_paths(paths)
        .par_iter()
        .map(|data| zpaq_compress_size_bytes(data, method))
        .collect()
}

/// Like the function above, but each buffer is compressed with zpaq's
/// multi-threaded compressor using `threads` workers.
#[inline(always)]
pub fn get_parallel_compressed_sizes_from_sequential_paths(
    paths: &[&str],
    method: &str,
    threads: usize,
) -> Vec<u64> {
    get_bytes_from_paths(paths)
        .par_iter()
        .map(|data| zpaq_compress_size_parallel_bytes(data, method, threads))
        .collect()
}

/// Compressed sizes with one rayon task per path; each task reads its file
/// and compresses it with the single-call zpaq compressor.
#[inline(always)]
pub fn get_sequential_compressed_sizes_from_parallel_paths(
    paths: &[&str],
    method: &str,
) -> Vec<u64> {
    paths
        .par_iter()
        .map(|path| get_compressed_size(path, method))
        .collect()
}

/// One rayon task per path, each compressing with zpaq's multi-threaded
/// compressor (`threads` workers per file).
#[inline(always)]
pub fn get_parallel_compressed_sizes_from_parallel_paths(
    paths: &[&str],
    method: &str,
    threads: usize,
) -> Vec<u64> {
    paths
        .par_iter()
        .map(|path| get_compressed_size_parallel(path, method, threads))
        .collect()
}
2394
2395#[inline(always)]
2397pub fn get_compressed_sizes_from_paths(paths: &[&str], method: &str) -> Vec<u64> {
2398 let n: usize = paths.len();
2399 let num_threads: usize = *NUM_THREADS.get_or_init(num_cpus::get);
2400 if n < num_threads {
2401 get_parallel_compressed_sizes_from_parallel_paths(paths, method, num_threads.div_ceil(n))
2402 } else {
2403 get_sequential_compressed_sizes_from_parallel_paths(paths, method)
2404 }
2405}
2406
/// Variant of the normalized compression distance (NCD) formula; see
/// `ncd_from_sizes` for the exact arithmetic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcdVariant {
    /// Classic NCD: (C(xy) - min(C(x), C(y))) / max(C(x), C(y)).
    Vitanyi,
    /// Classic NCD with min(C(xy), C(yx)) as the joint size.
    SymVitanyi,
    /// Consensus form: normalizes by C(xy) instead of max(C(x), C(y)).
    Cons,
    /// Consensus form with min(C(xy), C(yx)) in numerator and denominator.
    SymCons,
}
2435
/// Compressed size (bytes) of `data` under the given zpaq `method` string.
#[inline(always)]
fn compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_compress_size_bytes(data, method)
}
2440
2441#[inline(always)]
2442fn ncd_from_sizes(cx: u64, cy: u64, cxy: u64, cyx: Option<u64>, variant: NcdVariant) -> f64 {
2443 let min_c = cx.min(cy) as f64;
2444 let max_c = cx.max(cy) as f64;
2445
2446 match variant {
2447 NcdVariant::Vitanyi => {
2448 if max_c == 0.0 {
2449 0.0
2450 } else {
2451 (cxy as f64 - min_c) / max_c
2452 }
2453 }
2454 NcdVariant::SymVitanyi => {
2455 let m = cxy.min(cyx.expect("cyx required for SymVitanyi")) as f64;
2456 if max_c == 0.0 {
2457 0.0
2458 } else {
2459 (m - min_c) / max_c
2460 }
2461 }
2462 NcdVariant::Cons => {
2463 let denom = cxy as f64;
2464 if denom == 0.0 {
2465 0.0
2466 } else {
2467 (cxy as f64 - min_c) / denom
2468 }
2469 }
2470 NcdVariant::SymCons => {
2471 let m = cxy.min(cyx.expect("cyx required for SymCons")) as f64;
2472 if m == 0.0 { 0.0 } else { (m - min_c) / m }
2473 }
2474 }
2475}
2476
2477#[inline(always)]
2478pub fn ncd_bytes(x: &[u8], y: &[u8], method: &str, variant: NcdVariant) -> f64 {
2480 let backend = CompressionBackend::Zpaq {
2481 method: method.to_string(),
2482 };
2483 ncd_bytes_backend(x, y, &backend, variant)
2484}
2485
/// NCD between two byte slices using the process-wide default context.
#[inline(always)]
pub fn ncd_bytes_default(x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
    with_default_ctx(|ctx| ctx.ncd_bytes(x, y, variant))
}
2491
2492pub fn ncd_bytes_backend(
2494 x: &[u8],
2495 y: &[u8],
2496 backend: &CompressionBackend,
2497 variant: NcdVariant,
2498) -> f64 {
2499 let (cx, cy) = rayon::join(
2500 || compress_size_backend(x, backend),
2501 || compress_size_backend(y, backend),
2502 );
2503
2504 let cxy = compress_size_chain_backend(&[x, y], backend);
2505
2506 let cyx = match variant {
2507 NcdVariant::SymVitanyi | NcdVariant::SymCons => {
2508 Some(compress_size_chain_backend(&[y, x], backend))
2509 }
2510 _ => None,
2511 };
2512
2513 ncd_from_sizes(cx, cy, cxy, cyx, variant)
2514}
2515
/// NCD between two files (read concurrently) under a zpaq `method`.
///
/// # Panics
/// Panics if either file cannot be read.
#[inline(always)]
pub fn ncd_paths(x: &str, y: &str, method: &str, variant: NcdVariant) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ncd_bytes(&bx, &by, method, variant)
}

/// NCD between two files (read concurrently) under an arbitrary backend.
///
/// # Panics
/// Panics if either file cannot be read.
pub fn ncd_paths_backend(
    x: &str,
    y: &str,
    backend: &CompressionBackend,
    variant: NcdVariant,
) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ncd_bytes_backend(&bx, &by, backend, variant)
}
2539
/// Classic (Vitanyi) NCD between two files.
#[inline(always)]
pub fn ncd_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Vitanyi)
}
/// Symmetric classic NCD between two files.
#[inline(always)]
pub fn ncd_sym_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymVitanyi)
}
/// Consensus-normalized NCD between two files.
#[inline(always)]
pub fn ncd_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Cons)
}
/// Symmetric consensus-normalized NCD between two files.
#[inline(always)]
pub fn ncd_sym_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymCons)
}
2560
/// Full pairwise NCD matrix over `datas`, returned row-major as an
/// `n * n` Vec. Symmetric variants compute each unordered pair once and
/// mirror it; the other variants fill full rows (diagonal is 0.0).
///
/// Parallel tasks write results through a raw pointer to the output
/// buffer. This is sound because every (i, j) cell is written by exactly
/// one task (see SAFETY notes below) and the buffer outlives the parallel
/// scope.
pub fn ncd_matrix_bytes(datas: &[Vec<u8>], method: &str, variant: NcdVariant) -> Vec<f64> {
    let n = datas.len();
    // Singleton compressed sizes, computed once up front.
    let cx: Vec<u64> = datas
        .par_iter()
        .map(|d| compress_size_bytes(d, method))
        .collect();

    let mut out = vec![0.0f64; n * n];
    let out_ptr = std::sync::atomic::AtomicPtr::new(out.as_mut_ptr());

    match variant {
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            // One task per unordered pair (i < j); `buf` is a per-worker
            // scratch buffer reused for both concatenation orders.
            (0..n)
                .into_par_iter()
                .flat_map_iter(|i| (i + 1..n).map(move |j| (i, j)))
                .for_each_init(Vec::<u8>::new, |buf, (i, j)| {
                    let x = &datas[i];
                    let y = &datas[j];

                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(x);
                    buf.extend_from_slice(y);
                    let cxy = compress_size_bytes(buf, method);

                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(y);
                    buf.extend_from_slice(x);
                    let cyx = compress_size_bytes(buf, method);

                    let d = ncd_from_sizes(cx[i], cx[j], cxy, Some(cyx), variant);

                    let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                    // SAFETY: i < j < n, so both indices are in bounds of
                    // the n*n buffer, and each unordered pair (i, j) is
                    // produced by exactly one task, so the two cells
                    // written here are touched by no other task.
                    unsafe {
                        *p.add(i * n + j) = d;
                        *p.add(j * n + i) = d;
                    }
                });
        }
        NcdVariant::Vitanyi | NcdVariant::Cons => {
            // One task per row i; the diagonal stays 0.0.
            (0..n)
                .into_par_iter()
                .for_each_init(Vec::<u8>::new, |buf, i| {
                    let x = &datas[i];
                    for j in 0..n {
                        let d = if i == j {
                            0.0
                        } else {
                            let y = &datas[j];
                            buf.clear();
                            buf.reserve(x.len() + y.len());
                            buf.extend_from_slice(x);
                            buf.extend_from_slice(y);
                            let cxy = compress_size_bytes(buf, method);
                            ncd_from_sizes(cx[i], cx[j], cxy, None, variant)
                        };

                        let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                        // SAFETY: i < n and j < n, so the index is in
                        // bounds; row i is written only by the task for i,
                        // so no two tasks write the same cell.
                        unsafe {
                            *p.add(i * n + j) = d;
                        }
                    }
                });
        }
    }

    out
}
2634
/// NCD distance matrix for the files at `paths`; reads all files via
/// `get_bytes_from_paths` and delegates to [`ncd_matrix_bytes`].
pub fn ncd_matrix_paths(paths: &[&str], method: &str, variant: NcdVariant) -> Vec<f64> {
    let datas = get_bytes_from_paths(paths);
    ncd_matrix_bytes(&datas, method, variant)
}
2640
/// Order-0 (marginal) Shannon entropy of the byte distribution of `data`,
/// in bits per byte. Empty input yields 0.0.
#[inline(always)]
pub fn marginal_entropy_bytes(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    // Byte-value histogram.
    let mut counts = [0u64; 256];
    data.iter().for_each(|&byte| counts[byte as usize] += 1);

    // H = sum over occupied symbols of -p * log2(p).
    let total = data.len() as f64;
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / total;
            -p * p.log2()
        })
        .sum()
}
2673
/// Entropy rate of `data` (bits/byte) under the process-default
/// [`InfotheoryCtx`] backend. `max_order` bounds the model context; callers
/// in this file pass -1, which appears to mean "unbounded" — confirm in the
/// ctx implementation.
#[inline(always)]
pub fn entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.entropy_rate_bytes(data, max_order))
}
2692
/// Biased ("plugin"-style) entropy rate of `data` under the default ctx;
/// tests in this file expect it to lower-bound the prequential
/// [`entropy_rate_bytes`] on repetitive data.
#[inline(always)]
pub fn biased_entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.biased_entropy_rate_bytes(data, max_order))
}
2702
2703#[inline(always)]
2708pub fn joint_marginal_entropy_bytes(x: &[u8], y: &[u8]) -> f64 {
2709 let (x, y) = aligned_prefix(x, y);
2710 let n = x.len();
2711 if n == 0 {
2712 return 0.0;
2713 }
2714
2715 let mut counts = vec![0u64; 256 * 256];
2718 for i in 0..n {
2719 let pair_idx = (x[i] as usize) * 256 + (y[i] as usize);
2720 counts[pair_idx] += 1;
2721 }
2722
2723 let n_f64 = n as f64;
2724 let mut h = 0.0f64;
2725 for &c in &counts {
2726 if c > 0 {
2727 let p = c as f64 / n_f64;
2728 h -= p * p.log2();
2729 }
2730 }
2731 h
2732}
2733
/// Joint entropy rate of `x` and `y` under the process-default ctx;
/// per the tests below, inputs are truncated to their aligned prefix.
#[inline(always)]
pub fn joint_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.joint_entropy_rate_bytes(x, y, max_order))
}
2749
/// Conditional entropy rate H(X|Y) under the process-default ctx.
#[inline(always)]
pub fn conditional_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_rate_bytes(x, y, max_order))
}
2761
/// Conditional entropy H(X|Y) (total, not rate) under the process-default ctx.
#[inline(always)]
pub fn conditional_entropy_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_bytes(x, y, max_order))
}
2769
/// Model-based mutual information I(X;Y) under the process-default ctx;
/// for identical inputs the tests expect it to approach H(X).
#[inline(always)]
pub fn mutual_information_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_bytes(x, y, max_order))
}
2779
/// Marginal (order-0) mutual information over byte histograms:
/// `max(H(X) + H(Y) - H(X,Y), 0)` on the aligned common prefix.
pub fn mutual_information_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    let h_x = marginal_entropy_bytes(x);
    let h_y = marginal_entropy_bytes(y);
    // joint_marginal_entropy_bytes aligns again internally; harmless here
    // because x and y are already equal-length.
    let h_xy = joint_marginal_entropy_bytes(x, y);
    (h_x + h_y - h_xy).max(0.0)
}
2788
/// Mutual-information rate I(X;Y) per byte under the process-default ctx.
#[inline(always)]
pub fn mutual_information_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_rate_bytes(x, y, max_order))
}
2794
/// Normalized entropy distance (NED) in [0, 1] under the process-default
/// ctx; tests expect ~0 for identical inputs.
#[inline(always)]
pub fn ned_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
2810
2811pub fn ned_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2813 let (x, y) = aligned_prefix(x, y);
2814 let h_x = marginal_entropy_bytes(x);
2815 let h_y = marginal_entropy_bytes(y);
2816 let h_xy = joint_marginal_entropy_bytes(x, y);
2817 let min_h = h_x.min(h_y);
2818 let max_h = h_x.max(h_y);
2819 if max_h == 0.0 {
2820 0.0
2821 } else {
2822 ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
2823 }
2824}
2825
/// Rate-flavored NED under the process-default ctx.
/// NOTE(review): this delegates to `ctx.ned_bytes`, the same method
/// [`ned_bytes`] calls — confirm whether a distinct `ctx.ned_rate_bytes`
/// was intended (cf. `entropy_rate_bytes` vs `conditional_entropy_bytes`).
#[inline(always)]
pub fn ned_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
2831
/// Conservative NED variant under the process-default ctx.
#[inline(always)]
pub fn ned_cons_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
2839
2840pub fn ned_cons_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2842 let h_x = marginal_entropy_bytes(x);
2843 let h_y = marginal_entropy_bytes(y);
2844 let h_xy = joint_marginal_entropy_bytes(x, y);
2845 let min_h = h_x.min(h_y);
2846 if h_xy == 0.0 {
2847 0.0
2848 } else {
2849 ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
2850 }
2851}
2852
/// Rate-flavored conservative NED under the process-default ctx.
/// NOTE(review): delegates to `ctx.ned_cons_bytes`, the same method as
/// [`ned_cons_bytes`] — confirm the alias is intentional.
#[inline(always)]
pub fn ned_cons_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
2858
/// Normalized total entropy distance (NTE) under the process-default ctx;
/// tests below bound it to [0, 2].
#[inline(always)]
pub fn nte_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
2877
2878pub fn nte_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2880 let (x, y) = aligned_prefix(x, y);
2881 let h_x = marginal_entropy_bytes(x);
2882 let h_y = marginal_entropy_bytes(y);
2883 let h_xy = joint_marginal_entropy_bytes(x, y);
2884 let vi = 2.0 * h_xy - h_x - h_y;
2885 let max_h = h_x.max(h_y);
2886 if max_h == 0.0 {
2887 0.0
2888 } else {
2889 (vi / max_h).clamp(0.0, 2.0)
2890 }
2891}
2892
/// Rate-flavored NTE under the process-default ctx.
/// NOTE(review): delegates to `ctx.nte_bytes`, the same method as
/// [`nte_bytes`] — confirm the alias is intentional.
#[inline(always)]
pub fn nte_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
2898
2899#[inline(always)]
/// Empirical byte-value probability distribution of `data`; all zeros for
/// empty input.
#[inline(always)]
fn byte_histogram(data: &[u8]) -> [f64; 256] {
    let mut probs = [0.0f64; 256];
    if data.is_empty() {
        return probs;
    }
    let mut counts = [0u64; 256];
    for &byte in data {
        counts[byte as usize] += 1;
    }
    let total = data.len() as f64;
    for (slot, &count) in probs.iter_mut().zip(counts.iter()) {
        *slot = count as f64 / total;
    }
    probs
}
2918
2919#[inline(always)]
2925pub fn tvd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2926 if x.is_empty() || y.is_empty() {
2927 return 0.0;
2928 }
2929 let p_x = byte_histogram(x);
2930 let p_y = byte_histogram(y);
2931
2932 let mut sum = 0.0f64;
2933 for i in 0..256 {
2934 sum += (p_x[i] - p_y[i]).abs();
2935 }
2936
2937 (sum / 2.0).clamp(0.0, 1.0)
2938}
2939
2940#[inline(always)]
2947pub fn nhd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2948 if x.is_empty() || y.is_empty() {
2949 return 0.0;
2950 }
2951 let p_x = byte_histogram(x);
2952 let p_y = byte_histogram(y);
2953
2954 let mut bc = 0.0f64;
2956 for i in 0..256 {
2957 bc += (p_x[i] * p_y[i]).sqrt();
2958 }
2959
2960 (1.0 - bc).max(0.0).sqrt()
2962}
2963
/// Cross entropy of `test_data` under a model conditioned on `train_data`,
/// via the process-default ctx.
#[inline(always)]
pub fn cross_entropy_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.cross_entropy_bytes(test_data, train_data, max_order))
}
2973
/// Cross-entropy rate (per byte) of `test_data` under a model conditioned
/// on `train_data`, via the process-default ctx.
#[inline(always)]
pub fn cross_entropy_rate_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.cross_entropy_rate_bytes(test_data, train_data, max_order))
}
2980
/// Generates a continuation of `prompt` using the process-default ctx;
/// generation tests elsewhere expect the output length to equal `bytes`.
#[inline(always)]
pub fn generate_bytes(prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes(prompt, bytes, max_order))
}
2989
/// Like [`generate_bytes`], but with an explicit [`GenerationConfig`]
/// (strategy, seed, temperature, top-k/top-p).
#[inline(always)]
pub fn generate_bytes_with_config(
    prompt: &[u8],
    bytes: usize,
    max_order: i64,
    config: GenerationConfig,
) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes_with_config(prompt, bytes, max_order, config))
}
3000
/// Generates a continuation conditioned on a chain of prefix segments;
/// tests elsewhere expect this to match conditioning on the concatenated
/// flat prompt.
#[inline(always)]
pub fn generate_bytes_conditional_chain(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes_conditional_chain(prefix_parts, bytes, max_order))
}
3011
/// Like [`generate_bytes_conditional_chain`], but with an explicit
/// [`GenerationConfig`].
#[inline(always)]
pub fn generate_bytes_conditional_chain_with_config(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
    config: GenerationConfig,
) -> Vec<u8> {
    with_default_ctx(|ctx| {
        ctx.generate_bytes_conditional_chain_with_config(prefix_parts, bytes, max_order, config)
    })
}
3025
3026pub fn d_kl_bytes(x: &[u8], y: &[u8]) -> f64 {
3030 if x.is_empty() || y.is_empty() {
3031 return 0.0;
3032 }
3033 let p_x = byte_histogram(x);
3034 let p_y = byte_histogram(y);
3035 let mut d_kl = 0.0f64;
3036 for i in 0..256 {
3037 if p_x[i] > 0.0 {
3038 let q_y = p_y[i].max(1e-12);
3039 d_kl += p_x[i] * (p_x[i] / q_y).log2();
3040 }
3041 }
3042 d_kl.max(0.0)
3043}
3044
3045pub fn js_div_bytes(x: &[u8], y: &[u8]) -> f64 {
3050 if x.is_empty() || y.is_empty() {
3051 return 0.0;
3052 }
3053 let p_x = byte_histogram(x);
3054 let p_y = byte_histogram(y);
3055 let mut m = [0.0f64; 256];
3056 for i in 0..256 {
3057 m[i] = 0.5 * (p_x[i] + p_y[i]);
3058 }
3059
3060 let mut kl_pm = 0.0f64;
3061 let mut kl_qm = 0.0f64;
3062 for i in 0..256 {
3063 if p_x[i] > 0.0 {
3064 kl_pm += p_x[i] * (p_x[i] / m[i]).log2();
3065 }
3066 if p_y[i] > 0.0 {
3067 kl_qm += p_y[i] * (p_y[i] / m[i]).log2();
3068 }
3069 }
3070 (0.5 * kl_pm + 0.5 * kl_qm).max(0.0)
3071}
3072
3073pub fn ned_paths(x: &str, y: &str, max_order: i64) -> f64 {
3077 let (bx, by) = rayon::join(
3078 || std::fs::read(x).expect("failed to read x"),
3079 || std::fs::read(y).expect("failed to read y"),
3080 );
3081 ned_bytes(&bx, &by, max_order)
3082}
3083
3084pub fn nte_paths(x: &str, y: &str, max_order: i64) -> f64 {
3086 let (bx, by) = rayon::join(
3087 || std::fs::read(x).expect("failed to read x"),
3088 || std::fs::read(y).expect("failed to read y"),
3089 );
3090 nte_bytes(&bx, &by, max_order)
3091}
3092
3093pub fn tvd_paths(x: &str, y: &str, max_order: i64) -> f64 {
3095 let (bx, by) = rayon::join(
3096 || std::fs::read(x).expect("failed to read x"),
3097 || std::fs::read(y).expect("failed to read y"),
3098 );
3099 tvd_bytes(&bx, &by, max_order)
3100}
3101
3102pub fn nhd_paths(x: &str, y: &str, max_order: i64) -> f64 {
3104 let (bx, by) = rayon::join(
3105 || std::fs::read(x).expect("failed to read x"),
3106 || std::fs::read(y).expect("failed to read y"),
3107 );
3108 nhd_bytes(&bx, &by, max_order)
3109}
3110
3111pub fn mutual_information_paths(x: &str, y: &str, max_order: i64) -> f64 {
3113 let (bx, by) = rayon::join(
3114 || std::fs::read(x).expect("failed to read x"),
3115 || std::fs::read(y).expect("failed to read y"),
3116 );
3117 mutual_information_bytes(&bx, &by, max_order)
3118}
3119
3120pub fn conditional_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
3122 let (bx, by) = rayon::join(
3123 || std::fs::read(x).expect("failed to read x"),
3124 || std::fs::read(y).expect("failed to read y"),
3125 );
3126 conditional_entropy_bytes(&bx, &by, max_order)
3127}
3128
3129pub fn cross_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
3131 let (bx, by) = rayon::join(
3132 || std::fs::read(x).expect("failed to read x"),
3133 || std::fs::read(y).expect("failed to read y"),
3134 );
3135 cross_entropy_bytes(&bx, &by, max_order)
3136}
3137
3138pub fn kl_divergence_paths(x: &str, y: &str) -> f64 {
3140 let (bx, by) = rayon::join(
3141 || std::fs::read(x).expect("failed to read x"),
3142 || std::fs::read(y).expect("failed to read y"),
3143 );
3144 d_kl_bytes(&bx, &by)
3145}
3146
3147pub fn js_divergence_paths(x: &str, y: &str) -> f64 {
3149 let (bx, by) = rayon::join(
3150 || std::fs::read(x).expect("failed to read x"),
3151 || std::fs::read(y).expect("failed to read y"),
3152 );
3153 js_div_bytes(&bx, &by)
3154}
3155
/// Intrinsic dependence score of `data` under the process-default ctx.
#[inline(always)]
pub fn intrinsic_dependence_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.intrinsic_dependence_bytes(data, max_order))
}
3174
/// Resistance of `x` to a transformation with output `tx`, under the
/// process-default ctx; the identity transformation is expected to score
/// ~1.0 (see `resistance_identity_is_one` test below).
#[inline(always)]
pub fn resistance_to_transformation_bytes(x: &[u8], tx: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.resistance_to_transformation_bytes(x, tx, max_order))
}
3190
3191#[cfg(test)]
3192mod tests {
3193 use super::*;
3194
3195 fn test_match_backend() -> RateBackend {
3196 RateBackend::Match {
3197 hash_bits: 12,
3198 min_len: 2,
3199 max_len: 16,
3200 base_mix: 0.01,
3201 confidence_scale: 1.0,
3202 }
3203 }
3204
3205 fn test_ppmd_backend() -> RateBackend {
3206 RateBackend::Ppmd {
3207 order: 4,
3208 memory_mb: 1,
3209 }
3210 }
3211
3212 fn test_calibrated_backend() -> RateBackend {
3213 RateBackend::Calibrated {
3214 spec: Arc::new(CalibratedSpec {
3215 base: test_match_backend(),
3216 context: CalibrationContextKind::Text,
3217 bins: 16,
3218 learning_rate: 0.05,
3219 bias_clip: 4.0,
3220 }),
3221 }
3222 }
3223
3224 fn test_mixture_backend() -> RateBackend {
3225 RateBackend::Mixture {
3226 spec: Arc::new(MixtureSpec::new(
3227 MixtureKind::Bayes,
3228 vec![
3229 MixtureExpertSpec {
3230 name: Some("match".to_string()),
3231 log_prior: 0.0,
3232 max_order: -1,
3233 backend: test_match_backend(),
3234 },
3235 MixtureExpertSpec {
3236 name: Some("ppmd".to_string()),
3237 log_prior: 0.0,
3238 max_order: -1,
3239 backend: test_ppmd_backend(),
3240 },
3241 ],
3242 )),
3243 }
3244 }
3245
3246 fn test_particle_backend() -> RateBackend {
3247 RateBackend::Particle {
3248 spec: Arc::new(ParticleSpec {
3249 num_particles: 4,
3250 num_cells: 4,
3251 cell_dim: 8,
3252 num_rules: 2,
3253 selector_hidden: 16,
3254 rule_hidden: 16,
3255 context_window: 8,
3256 unroll_steps: 1,
3257 ..ParticleSpec::default()
3258 }),
3259 }
3260 }
3261
3262 fn continuation_prompt() -> &'static [u8] {
3263 b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\nIf a dog is green, frogs are red.\nIf a cat is green, toads are red.\nIf a frog is red, dogs are green.\nIf a toad is red, cats are green.\nIf a dog is red, frogs are green.\nIf a cat is red, toads are \n"
3264 }
3265
3266 fn assert_deterministic_generate_for_backend(
3267 backend: RateBackend,
3268 max_order: i64,
3269 bytes: usize,
3270 label: &str,
3271 ) {
3272 let prompt = continuation_prompt();
3273 let a = generate_rate_backend_chain(
3274 &[prompt],
3275 bytes,
3276 max_order,
3277 &backend,
3278 GenerationConfig::default(),
3279 );
3280 let b = generate_rate_backend_chain(
3281 &[prompt],
3282 bytes,
3283 max_order,
3284 &backend,
3285 GenerationConfig::default(),
3286 );
3287 assert_eq!(
3288 a, b,
3289 "{label} generation should be deterministic for identical input"
3290 );
3291 assert_eq!(
3292 a.len(),
3293 bytes,
3294 "{label} generation should emit requested byte count"
3295 );
3296 }
3297
3298 fn assert_sampled_generate_for_backend(
3299 backend: RateBackend,
3300 max_order: i64,
3301 bytes: usize,
3302 label: &str,
3303 ) {
3304 let prompt = continuation_prompt();
3305 let config = GenerationConfig::sampled_frozen(42);
3306 let a = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
3307 let b = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
3308 assert_eq!(
3309 a, b,
3310 "{label} sampled generation should be deterministic for a fixed seed"
3311 );
3312 assert_eq!(
3313 a.len(),
3314 bytes,
3315 "{label} sampled generation should emit requested byte count"
3316 );
3317 }
3318
3319 #[cfg(feature = "backend-zpaq")]
3320 #[test]
3321 fn ncd_basic_identity_nonnegative() {
3322 let x = b"abcdabcdabcd";
3323 let d = ncd_bytes(x, x, "5", NcdVariant::Vitanyi);
3324 assert!(d >= -1e-9);
3325 }
3326
3327 #[test]
3328 fn shannon_identities_marginal_aligned() {
3329 let x = b"abracadabra";
3330 let y = b"abracadabra";
3331
3332 let h = marginal_entropy_bytes(x);
3333 let mi = mutual_information_bytes(x, y, 0);
3334 let h_xy = joint_marginal_entropy_bytes(x, y);
3335 let h_x_given_y = conditional_entropy_bytes(x, y, 0);
3336 let ned = ned_bytes(x, y, 0);
3337 let nte = nte_bytes(x, y, 0);
3338
3339 assert!((h_xy - h).abs() < 1e-12);
3340 assert!(h_x_given_y.abs() < 1e-12);
3341 assert!((mi - h).abs() < 1e-12);
3342 assert!(ned.abs() < 1e-12);
3343 assert!(nte.abs() < 1e-12);
3344 }
3345
3346 #[test]
3347 fn shannon_identities_rate_aligned_reasonable() {
3348 let x = b"the quick brown fox jumps over the lazy dog";
3349 let y = b"the quick brown fox jumps over the lazy dog";
3350 let max_order = 8;
3351 let prev = get_default_ctx();
3352 set_default_ctx(InfotheoryCtx::new(
3353 RateBackend::RosaPlus,
3354 CompressionBackend::default(),
3355 ));
3356
3357 let h_x = entropy_rate_bytes(x, max_order);
3358 let h_xy = joint_entropy_rate_bytes(x, y, max_order);
3359 let h_x_given_y = conditional_entropy_rate_bytes(x, y, max_order);
3360 let mi = mutual_information_bytes(x, y, max_order);
3361 let ned = ned_bytes(x, y, max_order);
3362
3363 let tol = 0.2;
3365 assert!((h_xy - h_x).abs() < tol);
3366 assert!(h_x_given_y < tol);
3367 assert!((mi - h_x).abs() < tol);
3368 assert!(ned < tol);
3369 set_default_ctx(prev);
3370 }
3371
3372 #[test]
3373 fn resistance_identity_is_one() {
3374 let x = b"some repeated repeated repeated text";
3375 let prev = get_default_ctx();
3376 set_default_ctx(InfotheoryCtx::new(
3377 RateBackend::RosaPlus,
3378 CompressionBackend::default(),
3379 ));
3380 let r0 = resistance_to_transformation_bytes(x, x, 0);
3381 let r8 = resistance_to_transformation_bytes(x, x, 8);
3382 assert!((r0 - 1.0).abs() < 1e-12);
3383 assert!((r8 - 1.0).abs() < 1e-6);
3384 set_default_ctx(prev);
3385 }
3386
3387 #[test]
3388 fn marginal_metrics_empty_inputs_are_zero() {
3389 let empty: &[u8] = &[];
3390 let x = b"abc";
3391
3392 assert_eq!(tvd_bytes(empty, x, 0), 0.0);
3393 assert_eq!(tvd_bytes(x, empty, 0), 0.0);
3394 assert_eq!(nhd_bytes(empty, x, 0), 0.0);
3395 assert_eq!(nhd_bytes(x, empty, 0), 0.0);
3396 assert_eq!(d_kl_bytes(empty, x), 0.0);
3397 assert_eq!(d_kl_bytes(x, empty), 0.0);
3398 assert_eq!(js_div_bytes(empty, x), 0.0);
3399 assert_eq!(js_div_bytes(x, empty), 0.0);
3400 }
3401
3402 #[test]
3403 fn marginal_cross_entropy_empty_test_is_zero() {
3404 let empty: &[u8] = &[];
3405 let y = b"abc";
3406 let ctx = InfotheoryCtx::with_zpaq("5");
3407 assert_eq!(ctx.cross_entropy_bytes(empty, y, 0), 0.0);
3408 }
3409
3410 #[cfg(not(feature = "backend-zpaq"))]
3411 #[test]
3412 #[should_panic(expected = "CompressionBackend::Zpaq is unavailable")]
3413 fn explicit_zpaq_backend_fails_loudly() {
3414 let backend = CompressionBackend::Zpaq {
3415 method: "5".to_string(),
3416 };
3417 let _ = compress_size_backend(b"abc", &backend);
3418 }
3419
3420 #[cfg(not(feature = "backend-zpaq"))]
3421 #[test]
3422 fn default_compression_backend_falls_back_to_rate_coding() {
3423 let backend = CompressionBackend::default();
3424 assert!(matches!(
3425 &backend,
3426 CompressionBackend::Rate {
3427 coder: crate::coders::CoderType::AC,
3428 framing: crate::compression::FramingMode::Raw,
3429 ..
3430 }
3431 ));
3432 assert!(compress_size_backend(b"abc", &backend) > 0);
3433 }
3434
3435 #[test]
3436 fn backend_switching_test() {
3437 let x = b"hello world context";
3438
3439 let h_rosa = entropy_rate_bytes(x, 8);
3441
3442 set_default_ctx(InfotheoryCtx::new(
3444 RateBackend::Ctw { depth: 16 },
3445 CompressionBackend::default(),
3446 ));
3447
3448 let h_ctw = entropy_rate_bytes(x, 8);
3449
3450 assert!(h_ctw > 0.0);
3452
3453 set_default_ctx(InfotheoryCtx::default());
3455 let h_rosa_back = entropy_rate_bytes(x, 8);
3456 assert!((h_rosa - h_rosa_back).abs() < 1e-12);
3457 }
3458
3459 #[test]
3460 fn ctw_early_updates_work() {
3461 use crate::ctw::ContextTree;
3464
3465 let mut tree = ContextTree::new(16);
3466
3467 let p0 = tree.predict(false);
3469 let p1 = tree.predict(true);
3470
3471 assert!((p0 - 0.5).abs() < 1e-10, "p0 should be ~0.5, got {}", p0);
3473 assert!((p1 - 0.5).abs() < 1e-10, "p1 should be ~0.5, got {}", p1);
3474 assert!((p0 + p1 - 1.0).abs() < 1e-10, "p0 + p1 should = 1.0");
3475
3476 for _ in 0..5 {
3478 tree.update(true);
3479 tree.update(false);
3480 }
3481
3482 let log_prob = tree.get_log_block_probability();
3483 assert!(
3484 log_prob < 0.0,
3485 "log_prob should be negative (< log 1), got {}",
3486 log_prob
3487 );
3488 assert!(log_prob.is_finite(), "log_prob should be finite");
3489 }
3490
3491 #[test]
3492 fn nte_can_exceed_one() {
3493 set_default_ctx(InfotheoryCtx::new(
3504 RateBackend::Ctw { depth: 8 },
3505 CompressionBackend::default(),
3506 ));
3507
3508 let x: Vec<u8> = (0..200).map(|i| (i % 2) as u8).collect(); let y: Vec<u8> = (0..200).map(|i| ((i + 1) % 2) as u8).collect(); let nte_rate = nte_rate_backend(&x, &y, -1, &RateBackend::Ctw { depth: 8 });
3513
3514 assert!(
3517 (0.0..=2.0 + 1e-9).contains(&nte_rate),
3518 "NTE should be in [0, 2], got {}",
3519 nte_rate
3520 );
3521
3522 set_default_ctx(InfotheoryCtx::default());
3524 }
3525
3526 #[test]
3527 fn ctw_empty_data_returns_zero() {
3528 set_default_ctx(InfotheoryCtx::new(
3530 RateBackend::Ctw { depth: 16 },
3531 CompressionBackend::default(),
3532 ));
3533
3534 let empty: &[u8] = &[];
3535 let h = entropy_rate_bytes(empty, -1);
3536 assert_eq!(h, 0.0, "empty data should return 0.0 entropy");
3537
3538 set_default_ctx(InfotheoryCtx::default());
3540 }
3541
3542 #[test]
3543 fn joint_entropy_rate_aligns_inputs_and_handles_empty_cases() {
3544 let cases = vec![
3545 ("ctw", RateBackend::Ctw { depth: 8 }),
3546 (
3547 "fac-ctw",
3548 RateBackend::FacCtw {
3549 base_depth: 8,
3550 num_percept_bits: 8,
3551 encoding_bits: 8,
3552 },
3553 ),
3554 ("match", test_match_backend()),
3555 ];
3556
3557 for (name, backend) in cases {
3558 assert_eq!(
3559 joint_entropy_rate_backend(b"", b"nonempty", -1, &backend),
3560 0.0,
3561 "{name} should return 0.0 for empty aligned pairs"
3562 );
3563 assert_eq!(
3564 joint_entropy_rate_backend(b"nonempty", b"", -1, &backend),
3565 0.0,
3566 "{name} should return 0.0 when alignment truncates to empty"
3567 );
3568
3569 let aligned = joint_entropy_rate_backend(b"abcd", b"wxyz", -1, &backend);
3570 let truncated = joint_entropy_rate_backend(b"abcdextra", b"wxyz", -1, &backend);
3571 assert!(
3572 (aligned - truncated).abs() < 1e-12,
3573 "{name} should score only the aligned prefix: aligned={aligned} truncated={truncated}"
3574 );
3575 }
3576 }
3577
3578 #[test]
3579 fn biased_entropy_is_repeatable_across_backend_families() {
3580 let data = b"ABABABAABBABABABAABB";
3581 let cases = vec![
3582 ("match", test_match_backend()),
3583 ("ppmd", test_ppmd_backend()),
3584 ("calibrated", test_calibrated_backend()),
3585 ("ctw", RateBackend::Ctw { depth: 8 }),
3586 ("mixture", test_mixture_backend()),
3587 ("particle", test_particle_backend()),
3588 ];
3589
3590 for (name, backend) in cases {
3591 let h1 = biased_entropy_rate_backend(data, -1, &backend);
3592 let h2 = biased_entropy_rate_backend(data, -1, &backend);
3593 assert!(h1.is_finite(), "{name} biased entropy should be finite");
3594 assert!(
3595 (h1 - h2).abs() < 1e-12,
3596 "{name} biased entropy leaked mutable state across calls: h1={h1} h2={h2}"
3597 );
3598 }
3599 }
3600
3601 #[test]
3602 fn generate_bytes_chain_matches_flat_prompt() {
3603 let prompt = continuation_prompt();
3604 let split_at = prompt.len() / 2;
3605 let front = &prompt[..split_at];
3606 let back = &prompt[split_at..];
3607 let backend = RateBackend::Ctw { depth: 32 };
3608 let bytes = 8usize;
3609 let max_order = -1;
3610
3611 let flat = generate_rate_backend_chain(
3612 &[prompt],
3613 bytes,
3614 max_order,
3615 &backend,
3616 GenerationConfig::default(),
3617 );
3618 let chained = generate_rate_backend_chain(
3619 &[front, back],
3620 bytes,
3621 max_order,
3622 &backend,
3623 GenerationConfig::default(),
3624 );
3625 assert_eq!(
3626 flat, chained,
3627 "chain conditioning should match flat prompt conditioning"
3628 );
3629 }
3630
3631 #[test]
3632 fn generate_bytes_api_is_deterministic_for_ctw_rosa_match_ppmd() {
3633 assert_deterministic_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
3634 assert_deterministic_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
3635 assert_deterministic_generate_for_backend(test_match_backend(), -1, 8, "match");
3636 assert_deterministic_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
3637 }
3638
3639 #[cfg(feature = "backend-rwkv")]
3640 #[test]
3641 fn generate_bytes_api_is_deterministic_for_rwkv_method() {
3642 let backend = RateBackend::Rwkv7Method {
3643 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3644 };
3645 assert_deterministic_generate_for_backend(backend, -1, 8, "rwkv7");
3646 }
3647
3648 #[test]
3649 fn sampled_generation_is_deterministic_for_ctw_rosa_match_ppmd() {
3650 assert_sampled_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
3651 assert_sampled_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
3652 assert_sampled_generate_for_backend(test_match_backend(), -1, 8, "match");
3653 assert_sampled_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
3654 }
3655
3656 #[cfg(feature = "backend-rwkv")]
3657 #[test]
3658 fn sampled_generation_is_deterministic_for_rwkv_method() {
3659 let backend = RateBackend::Rwkv7Method {
3660 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3661 };
3662 assert_sampled_generate_for_backend(backend, -1, 8, "rwkv7");
3663 }
3664
3665 #[test]
3666 fn rosaplus_sampled_generation_predicts_green_continuation() {
3667 let out = generate_rate_backend_chain(
3668 &[continuation_prompt()],
3669 8,
3670 -1,
3671 &RateBackend::RosaPlus,
3672 GenerationConfig::sampled_frozen(42),
3673 );
3674 assert_eq!(out, b" green.\n");
3675 }
3676
3677 #[test]
3678 fn rate_backend_session_matches_ctx_generation() {
3679 let prompt = continuation_prompt();
3680 let backend = RateBackend::Ppmd {
3681 order: 12,
3682 memory_mb: 8,
3683 };
3684 let mut session =
3685 RateBackendSession::from_backend(backend.clone(), -1, Some((prompt.len() + 8) as u64))
3686 .expect("session init");
3687 session.observe(prompt);
3688 let from_session = session.generate_bytes(8, GenerationConfig::sampled_frozen(42));
3689 session.finish().expect("session finish");
3690
3691 let ctx = InfotheoryCtx::new(backend, CompressionBackend::default());
3692 let from_ctx =
3693 ctx.generate_bytes_with_config(prompt, 8, -1, GenerationConfig::sampled_frozen(42));
3694 assert_eq!(from_session, from_ctx);
3695 }
3696
3697 #[test]
3698 fn biased_entropy_ctw_uses_frozen_plugin_scoring() {
3699 let backend = RateBackend::Ctw { depth: 8 };
3700 let data = b"AAAAAAAA";
3701 let plugin = biased_entropy_rate_backend(data, -1, &backend);
3702 let prequential = entropy_rate_backend(data, -1, &backend);
3703 assert!(
3704 plugin + 1e-9 < prequential,
3705 "expected plugin scoring to beat prequential scoring: plugin={plugin} prequential={prequential}"
3706 );
3707 }
3708
3709 #[test]
3710 fn rosa_plugin_entropy_matches_direct_model_api() {
3711 let data = b"abracadabra";
3712 let backend = RateBackend::RosaPlus;
3713
3714 let plugin = biased_entropy_rate_backend(data, 3, &backend);
3715
3716 let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
3717 direct.train_example(data);
3718 direct.build_lm();
3719 let expected = direct.cross_entropy(data);
3720
3721 assert!(
3722 (plugin - expected).abs() < 1e-12,
3723 "rosa plugin entropy must match direct model API: plugin={plugin} expected={expected}"
3724 );
3725 }
3726
3727 #[test]
3728 fn rosa_plugin_cross_entropy_matches_direct_model_api() {
3729 let train = b"alakazam";
3730 let test = b"abracadabra";
3731 let backend = RateBackend::RosaPlus;
3732
3733 let plugin = cross_entropy_rate_backend(test, train, 3, &backend);
3734
3735 let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
3736 direct.train_example(train);
3737 direct.build_lm();
3738 let expected = direct.cross_entropy(test);
3739
3740 assert!(
3741 (plugin - expected).abs() < 1e-12,
3742 "rosa plugin cross entropy must match direct model API: plugin={plugin} expected={expected}"
3743 );
3744 }
3745
3746 #[test]
3747 fn datagen_bernoulli_entropy_estimate() {
3748 let p = 0.5;
3750 let theoretical_h = crate::datagen::bernoulli_entropy(p);
3751 assert!((theoretical_h - 1.0).abs() < 1e-10);
3752
3753 let data = crate::datagen::bernoulli(10000, p, 42);
3755 let estimated_h = marginal_entropy_bytes(&data);
3756
3757 assert!(
3759 (estimated_h - theoretical_h).abs() < 0.1,
3760 "estimated H={} should be close to theoretical H={}",
3761 estimated_h,
3762 theoretical_h
3763 );
3764 }
3765
3766 #[cfg(feature = "backend-rwkv")]
3767 #[test]
3768 fn rwkv_method_entropy_is_stable_across_calls() {
3769 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
3770 let backend = RateBackend::Rwkv7Method {
3771 method: method.to_string(),
3772 };
3773 let data = b"rwkv method entropy stability regression sample";
3774
3775 let h1 = entropy_rate_backend(data, -1, &backend);
3776 let h2 = entropy_rate_backend(data, -1, &backend);
3777 assert!(
3778 (h1 - h2).abs() < 1e-12,
3779 "rwkv method entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3780 );
3781 }
3782
3783 #[cfg(feature = "backend-rwkv")]
3784 #[test]
3785 fn rwkv_method_without_policy_is_accepted_by_public_api() {
3786 let backend = RateBackend::Rwkv7Method {
3787 method: "cfg:hidden=64,layers=1,intermediate=64".to_string(),
3788 };
3789 let data = b"rwkv method without policy";
3790 let h1 = entropy_rate_backend(data, -1, &backend);
3791 let h2 = biased_entropy_rate_backend(data, -1, &backend);
3792 assert!(h1.is_finite());
3793 assert!(h2.is_finite());
3794 }
3795
3796 #[cfg(feature = "backend-rwkv")]
3797 #[test]
3798 fn rwkv_infer_only_plugin_collapses_to_single_pass_entropy() {
3799 let backend = RateBackend::Rwkv7Method {
3800 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=25,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3801 };
3802 let data = b"rwkv infer-only plugin equality sample";
3803 let h = entropy_rate_backend(data, -1, &backend);
3804 let plugin = biased_entropy_rate_backend(data, -1, &backend);
3805 assert!(
3806 (h - plugin).abs() < 1e-12,
3807 "infer-only rwkv plugin should equal single-pass entropy: h={h}, plugin={plugin}"
3808 );
3809 }
3810
3811 #[cfg(feature = "backend-rwkv")]
3812 #[test]
3813 fn rwkv_method_biased_entropy_is_stable_across_calls_with_training_policy() {
3814 let backend = RateBackend::Rwkv7Method {
3815 method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
3816 };
3817 let data = b"rwkv plugin stability sample";
3818 let h1 = biased_entropy_rate_backend(data, -1, &backend);
3819 let h2 = biased_entropy_rate_backend(data, -1, &backend);
3820 assert!(
3821 (h1 - h2).abs() < 1e-12,
3822 "rwkv method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3823 );
3824 }
3825
3826 #[cfg(feature = "backend-rwkv")]
3827 #[test]
3828 fn rwkv_method_conditional_chain_is_stable_across_calls() {
3829 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=22,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
3830 let ctx = InfotheoryCtx::new(
3831 RateBackend::Rwkv7Method {
3832 method: method.to_string(),
3833 },
3834 CompressionBackend::default(),
3835 );
3836
3837 let prefix = b"universal prior slice";
3838 let data = b"query payload";
3839 let h1 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
3840 let h2 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
3841 assert!(
3842 (h1 - h2).abs() < 1e-12,
3843 "rwkv method conditional chain leaked mutable state across calls: h1={h1}, h2={h2}"
3844 );
3845 }
3846
3847 #[cfg(feature = "backend-mamba")]
3848 #[test]
3849 fn mamba_method_without_policy_is_accepted_by_public_api() {
3850 let backend = RateBackend::MambaMethod {
3851 method: "cfg:hidden=64,layers=1,intermediate=96".to_string(),
3852 };
3853 let data = b"mamba method without policy";
3854 let h1 = entropy_rate_backend(data, -1, &backend);
3855 let h2 = biased_entropy_rate_backend(data, -1, &backend);
3856 assert!(h1.is_finite());
3857 assert!(h2.is_finite());
3858 }
3859
3860 #[cfg(feature = "backend-mamba")]
3861 #[test]
3862 fn mamba_infer_only_plugin_collapses_to_single_pass_entropy() {
3863 let backend = RateBackend::MambaMethod {
3864 method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=26,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3865 };
3866 let data = b"mamba infer-only plugin equality sample";
3867 let h = entropy_rate_backend(data, -1, &backend);
3868 let plugin = biased_entropy_rate_backend(data, -1, &backend);
3869 assert!(
3870 (h - plugin).abs() < 1e-12,
3871 "infer-only mamba plugin should equal single-pass entropy: h={h}, plugin={plugin}"
3872 );
3873 }
3874
3875 #[cfg(feature = "backend-mamba")]
3876 #[test]
3877 fn mamba_method_biased_entropy_is_stable_across_calls_with_training_policy() {
3878 let backend = RateBackend::MambaMethod {
3879 method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=24,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
3880 };
3881 let data = b"mamba plugin stability sample";
3882 let h1 = biased_entropy_rate_backend(data, -1, &backend);
3883 let h2 = biased_entropy_rate_backend(data, -1, &backend);
3884 assert!(
3885 (h1 - h2).abs() < 1e-12,
3886 "mamba method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3887 );
3888 }
3889
3890 #[test]
3891 fn particle_entropy_rate_in_valid_range() {
3892 let rb = test_particle_backend();
3893 let data = b"hello world particle backend test";
3894 let rate = entropy_rate_backend(data, -1, &rb);
3895 assert!(
3896 rate > 0.0 && rate < 8.0,
3897 "particle entropy rate out of (0, 8) range: {rate}"
3898 );
3899 }
3900
3901 #[test]
3902 fn particle_cross_entropy_stability() {
3903 let rb = test_particle_backend();
3904 let train = b"ABCABC";
3905 let test = b"ABC";
3906 let h1 = cross_entropy_rate_backend(test, train, -1, &rb);
3907 let h2 = cross_entropy_rate_backend(test, train, -1, &rb);
3908 assert!(
3909 (h1 - h2).abs() < 1e-12,
3910 "particle cross entropy not deterministic: h1={h1}, h2={h2}"
3911 );
3912 }
3913
3914 #[test]
3915 fn particle_empty_input() {
3916 let rb = RateBackend::Particle {
3917 spec: Arc::new(ParticleSpec::default()),
3918 };
3919 let rate = entropy_rate_backend(b"", -1, &rb);
3920 assert!(
3921 rate == 0.0,
3922 "particle entropy rate for empty input should be 0.0, got {rate}"
3923 );
3924 }
3925
3926 #[test]
3927 fn particle_joint_entropy_rate() {
3928 let rb = test_particle_backend();
3929 let x = b"AAAA";
3930 let y = b"BBBB";
3931 let joint = joint_entropy_rate_backend(x, y, -1, &rb);
3932 assert!(
3933 joint > 0.0 && joint < 16.0,
3934 "particle joint entropy rate out of range: {joint}"
3935 );
3936 }
3937}