// infotheory/lib.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3//! # InfoTheory: Information Theoretic Estimators & Metrics
4//!
5//! This crate provides a comprehensive suite of information-theoretic primitives for
6//! quantifying complexity, dependence, and similarity between data sequences.
7//!
8//! It implements two primary classes of estimators:
9//! 1.  **Compression-based (Kolmogorov Complexity)**: Using the ZPAQ compression algorithm to estimate
10//!     Normalized Compression Distance (NCD).
11//! 2.  **Entropy-based (Shannon Information)**: Using both exact marginal histograms (for i.i.d. data)
12//!     and the ROSA (Rapid Online Suffix Automaton) predictive language model (for sequential data)
13//!     to estimate Entropy, Mutual Information, and related distances.
14//!
15//! ## Mathematical Primitives
16//!
17//! The library implements the following core measures. For sequential data, "Rate" variants
18//! use the ROSA model to estimate `Ĥ(X)` (entropy rate), while "Marginal" variants
19//! treat data as a bag-of-bytes (i.i.d.) and compute `H(X)` from histograms.
20//!
21//! ### 1. Normalized Compression Distance (NCD)
22//! Approximates the Normalized Information Distance (NID) using a compressor `C`.
23//!
24//! `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
25//!
26//! ### 2. Normalized Entropy Distance (NED)
27//! An entropic analogue to NCD, defined using Shannon entropy `H`.
28//!
29//! `NED(X,Y) = (H(X,Y) - min(H(X), H(Y))) / max(H(X), H(Y))`
30//!
31//! ### 3. Normalized Transform Effort (NTE)
32//! Based on the Variation of Information (VI), normalized by the maximum entropy.
33//!
34//! `NTE(X,Y) = (H(X|Y) + H(Y|X)) / max(H(X), H(Y)) = (2H(X,Y) - H(X) - H(Y)) / max(H(X), H(Y))`
35//!
36//! ### 4. Mutual Information (MI)
37//! Measures the amount of information obtained about one random variable by observing another.
38//!
39//! `I(X;Y) = H(X) + H(Y) - H(X,Y)`
40//!
41//! ### 5. Divergences & Distances
42//! *   **Total Variation Distance (TVD)**: `δ(P,Q) = 0.5 * Σ |P(x) - Q(x)|`
43//! *   **Normalized Hellinger Distance (NHD)**: `sqrt(1 - Σ sqrt(P(x)Q(x)))`
44//! *   **Kullback-Leibler Divergence (KL)**: `D_KL(P||Q) = Σ P(x) log(P(x)/Q(x))`
45//! *   **Jensen-Shannon Divergence (JSD)**: Symmetrized and smoothed KL divergence.
46//!
47//! ### 6. Intrinsic Dependence (ID)
48//! Measures the redundancy within a sequence, comparing marginal entropy to entropy rate.
49//!
50//! `ID(X) = (H_marginal(X) - H_rate(X)) / H_marginal(X)`
51//!
52//! ### 7. Resistance to Transformation
53//! Quantifies how much information is preserved after a transformation `T` is applied.
54//!
55//! `R(X, T) = I(X; T(X)) / H(X)`
56//!
57//! ## Usage
58//!
59//! ```rust,no_run
60//! use infotheory::{ncd_vitanyi, mutual_information_bytes, NcdVariant};
61//!
62//! let x = b"some data sequence";
63//! let y = b"another data sequence";
64//!
65//! // Compression-based distance
66//! let ncd = ncd_vitanyi("file1.txt", "file2.txt", "5");
67//!
68//! // Entropy-based mutual information (Marginal / i.i.d.)
69//! let mi_marg = mutual_information_bytes(x, y, 0);
70//!
71//! // Entropy-based mutual information (Rate / Sequential, max_order=8)
72//! let mi_rate = mutual_information_bytes(x, y, 8);
73//! ```
74
75/// AIXI planning components, environments, and model abstractions.
76pub mod aixi;
77/// Core information-theoretic axioms and validation helpers.
78pub mod axioms;
79/// Entropy/compression backend implementations and backend discovery.
80pub mod backends;
81/// Entropy coder implementations (AC and rANS).
82pub mod coders;
83/// Rate-coded compression helpers built on generic rate backends.
84pub mod compression;
85/// Synthetic data generators for information-theory experiments.
86pub mod datagen;
87/// Online Bayesian/switching/MDL mixture predictors.
88pub mod mixture;
89pub(crate) mod neural_mix;
90/// Information-theoretic code search pipeline (3-stage: prefilter, filter, KMI rerank).
91pub mod search;
92pub(crate) mod simd_math;
93/// CTW and FAC-CTW backend types.
94pub use backends::ctw;
95#[cfg(feature = "backend-mamba")]
96/// Mamba backend types and compressor.
97pub use backends::mambazip;
98/// Match-based repeat predictor.
99pub use backends::match_model;
100/// Particle-latent filter ensemble rate backend.
101pub use backends::particle;
102/// PPMD-style byte model.
103pub use backends::ppmd;
104/// ROSA+ backend types.
105pub use backends::rosaplus;
106#[cfg(feature = "backend-rwkv")]
107/// RWKV backend types and compressor.
108pub use backends::rwkvzip;
109/// Sparse/gapped match predictor.
110pub use backends::sparse_match;
111/// ZPAQ rate-model adapter.
112pub use backends::zpaq_rate;
113
114use rayon::prelude::*;
115
116use crate::coders::CoderType;
117use std::cell::RefCell;
118#[cfg(any(feature = "backend-rwkv", feature = "backend-mamba"))]
119use std::collections::HashMap;
120use std::sync::Arc;
121use std::sync::OnceLock;
122
/// How generated symbols should update the model state.
///
/// Selected via [`GenerationConfig::update_mode`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationUpdateMode {
    /// Keep adapting/fitting on generated bytes.
    Adaptive,
    /// Freeze fitted parameters/statistics and only advance conditioning state.
    Frozen,
}
131
132/// How to pick the next byte from the model distribution.
133#[derive(Clone, Copy, Debug, PartialEq, Eq)]
134pub enum GenerationStrategy {
135    /// Deterministic argmax over the next-byte distribution.
136    Greedy,
137    /// Seeded sampling from the next-byte distribution.
138    Sample,
139}
140
141/// Generation options shared by the library API and CLI.
142#[derive(Clone, Copy, Debug)]
143pub struct GenerationConfig {
144    /// Byte-selection strategy.
145    pub strategy: GenerationStrategy,
146    /// Whether generated bytes should keep adapting the model.
147    pub update_mode: GenerationUpdateMode,
148    /// RNG seed used by [`GenerationStrategy::Sample`].
149    pub seed: u64,
150    /// Softmax temperature for sampling. `<= 0` behaves like greedy.
151    pub temperature: f64,
152    /// Optional top-k truncation. `0` disables it.
153    pub top_k: usize,
154    /// Optional nucleus truncation. Values `>= 1.0` disable it.
155    pub top_p: f64,
156}
157
158impl Default for GenerationConfig {
159    fn default() -> Self {
160        Self::sampled_frozen(42)
161    }
162}
163
164impl GenerationConfig {
165    /// Deterministic frozen continuation.
166    pub const fn greedy_frozen() -> Self {
167        Self {
168            strategy: GenerationStrategy::Greedy,
169            update_mode: GenerationUpdateMode::Frozen,
170            seed: 0xD00D_F00D_CAFE_BABEu64,
171            temperature: 1.0,
172            top_k: 0,
173            top_p: 1.0,
174        }
175    }
176
177    /// Seeded frozen sampling from the model distribution.
178    pub const fn sampled_frozen(seed: u64) -> Self {
179        Self {
180            strategy: GenerationStrategy::Sample,
181            update_mode: GenerationUpdateMode::Frozen,
182            seed,
183            temperature: 1.0,
184            top_k: 0,
185            top_p: 1.0,
186        }
187    }
188}
189
/// Small xorshift64 PRNG used to make seeded generation sampling reproducible.
///
/// Not cryptographically secure; only used for [`GenerationStrategy::Sample`].
struct GenerationRng {
    // Current xorshift state; never zero (zero is a fixed point of xorshift).
    state: u64,
}

impl GenerationRng {
    /// Create a generator from `seed`, remapping the degenerate seed `0`
    /// (which would freeze the xorshift sequence) to a fixed non-zero constant.
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 {
                0xD00D_F00D_CAFE_BABEu64
            } else {
                seed
            },
        }
    }

    /// Advance the xorshift64 state and return the next pseudo-random value.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Uniform sample in `[0, 1)`.
    ///
    /// Uses the top 53 bits of the next state word so every output is an exact
    /// multiple of 2^-53. The previous `next_u64() / u64::MAX` formulation
    /// rounded the divisor to 2^64 and could round the quotient up to exactly
    /// `1.0`, which breaks cumulative-probability sampling at the boundary.
    fn next_f64(&mut self) -> f64 {
        ((self.next_u64() >> 11) as f64) * (1.0 / (1u64 << 53) as f64)
    }
}
218
219static NUM_THREADS: OnceLock<usize> = OnceLock::new();
220
// Per-thread caches of heavyweight neural compressors so parallel (rayon)
// workers never share mutable compressor state.
// NOTE(review): the `usize` keys appear to identify loaded models and the
// `String` keys are method strings — confirm at the use sites.
thread_local! {
    #[cfg(feature = "backend-mamba")]
    static MAMBA_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_RATE_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_METHOD_TLS: RefCell<HashMap<String, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_RATE_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_METHOD_TLS: RefCell<HashMap<String, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
}
235
#[cfg(feature = "backend-zpaq")]
impl Default for CompressionBackend {
    /// ZPAQ method `"5"` is the default compressor when the feature is on.
    fn default() -> Self {
        let method = String::from("5");
        CompressionBackend::Zpaq { method }
    }
}
244
245#[cfg(not(feature = "backend-zpaq"))]
246impl Default for CompressionBackend {
247    fn default() -> Self {
248        CompressionBackend::Rate {
249            rate_backend: RateBackend::default(),
250            coder: CoderType::AC,
251            framing: compression::FramingMode::Raw,
252        }
253    }
254}
255
thread_local! {
    // Per-thread default context consulted by `get_default_ctx`,
    // `set_default_ctx`, and `with_default_ctx`.
    static DEFAULT_CTX: RefCell<InfotheoryCtx> = RefCell::new(InfotheoryCtx::default());
}
259
260/// Returns the current default information theory context for the thread.
261pub fn get_default_ctx() -> InfotheoryCtx {
262    DEFAULT_CTX.with(|ctx| ctx.borrow().clone())
263}
264
265/// Sets the default information theory context for the thread.
266pub fn set_default_ctx(ctx: InfotheoryCtx) {
267    DEFAULT_CTX.with(|c| *c.borrow_mut() = ctx);
268}
269
270#[inline(always)]
271fn with_default_ctx<R>(f: impl FnOnce(&InfotheoryCtx) -> R) -> R {
272    DEFAULT_CTX.with(|ctx| f(&ctx.borrow()))
273}
274
275/// Mutual information rate estimate under an explicit `backend`.
276///
277/// Inputs are aligned to the shared prefix length.
278pub fn mutual_information_rate_backend(
279    x: &[u8],
280    y: &[u8],
281    max_order: i64,
282    backend: &RateBackend,
283) -> f64 {
284    let (x, y) = aligned_prefix(x, y);
285    if x.is_empty() {
286        return 0.0;
287    }
288    // For CTW, we might want a special aligned implementation?
289    // Using standard formula for now.
290    let h_x = entropy_rate_backend(x, max_order, backend);
291    let h_y = entropy_rate_backend(y, max_order, backend);
292    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
293    (h_x + h_y - h_xy).max(0.0)
294}
295
296/// Normalized entropy distance under an explicit `backend`.
297///
298/// Returns a value in `[0, 1]` after clamping.
299pub fn ned_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
300    let (x, y) = aligned_prefix(x, y);
301    if x.is_empty() {
302        return 0.0;
303    }
304    let h_x = entropy_rate_backend(x, max_order, backend);
305    let h_y = entropy_rate_backend(y, max_order, backend);
306    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
307    let min_h = h_x.min(h_y);
308    let max_h = h_x.max(h_y);
309    if max_h == 0.0 {
310        0.0
311    } else {
312        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
313    }
314}
315
316/// Normalized transform effort (variation-of-information form) under an explicit `backend`.
317///
318/// Returns a value in `[0, 2]` after clamping.
319pub fn nte_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
320    let (x, y) = aligned_prefix(x, y);
321    if x.is_empty() {
322        return 0.0;
323    }
324    let h_x = entropy_rate_backend(x, max_order, backend);
325    let h_y = entropy_rate_backend(y, max_order, backend);
326    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
327    let max_h = h_x.max(h_y);
328    if max_h == 0.0 {
329        0.0
330    } else {
331        // VI = H(X|Y) + H(Y|X) can be as large as H(X) + H(Y) ≈ 2*max(H)
332        // for independent sequences, so NTE ∈ [0, 2]
333        let vi = (h_xy - h_x).max(0.0) + (h_xy - h_y).max(0.0);
334        (vi / max_h).clamp(0.0, 2.0)
335    }
336}
337
/// Sequential entropy/rate backend used by context-aware metrics.
///
/// Cloning is cheap for model-carrying variants: heavyweight payloads
/// (models, mixture/particle/calibration specs) are shared behind [`Arc`].
#[derive(Clone)]
pub enum RateBackend {
    /// ROSA+ suffix-automaton estimator.
    RosaPlus,
    /// Local contiguous match predictor.
    Match {
        /// Number of retained hash bits for suffix lookup.
        hash_bits: usize,
        /// Minimum repeat length required before predicting.
        min_len: usize,
        /// Maximum repeat length used for confidence scaling.
        max_len: usize,
        /// Residual probability mass left for non-match symbols.
        base_mix: f64,
        /// Confidence multiplier applied to short-match tapering.
        confidence_scale: f64,
    },
    /// Sparse/gapped local match predictor.
    SparseMatch {
        /// Number of retained hash bits for spaced-suffix lookup.
        hash_bits: usize,
        /// Minimum spaced repeat length required before predicting.
        min_len: usize,
        /// Maximum spaced repeat length used for confidence scaling.
        max_len: usize,
        /// Minimum gap between matched bytes.
        gap_min: usize,
        /// Maximum gap between matched bytes.
        gap_max: usize,
        /// Residual probability mass left for non-match symbols.
        base_mix: f64,
        /// Confidence multiplier applied to short-match tapering.
        confidence_scale: f64,
    },
    /// Pure-Rust bounded-memory PPMD-style model.
    Ppmd {
        /// Maximum context order.
        order: usize,
        /// Approximate memory budget in MiB.
        memory_mb: usize,
    },
    #[cfg(feature = "backend-mamba")]
    /// Mamba model loaded from explicit weights.
    Mamba {
        /// Loaded Mamba model.
        model: Arc<mambazip::Model>,
    },
    #[cfg(feature = "backend-mamba")]
    /// Mamba method string (e.g. `file:...` or `cfg:...[;policy:...]`) resolved lazily.
    MambaMethod {
        /// Mamba method string.
        method: String,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV7 model loaded from explicit weights.
    Rwkv7 {
        /// Loaded RWKV7 model.
        model: Arc<rwkvzip::Model>,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV7 method string (e.g. `file:...` or `cfg:...[;policy:...]`) resolved lazily.
    Rwkv7Method {
        /// RWKV7 method string.
        method: String,
    },
    /// ZPAQ compression-based rate model (streamable methods only).
    Zpaq {
        /// ZPAQ method string (streamable modes only for rate estimation).
        method: String,
    },
    /// Online mixture over rate-model experts (Bayes, fading Bayes, switching, MDL).
    Mixture {
        /// Mixture expert/runtime specification.
        spec: Arc<MixtureSpec>,
    },
    /// Particle-latent filter ensemble.
    Particle {
        /// Particle filter specification.
        spec: Arc<ParticleSpec>,
    },
    /// Calibrated wrapper over another bytewise backend.
    Calibrated {
        /// Calibration specification.
        spec: Arc<CalibratedSpec>,
    },
    /// Action-Conditional CTW (single context tree).
    Ctw {
        /// Context tree depth.
        depth: usize,
    },
    /// Factorized Action-Conditional CTW (k trees for k-bit percepts).
    FacCtw {
        /// Base context depth.
        base_depth: usize,
        /// Number of percept bits.
        num_percept_bits: usize,
        /// Encoding width in bits.
        encoding_bits: usize,
    },
}
439
#[allow(clippy::derivable_impls)]
impl Default for RateBackend {
    /// Feature-dependent default: ROSA+ when available, else ZPAQ method `"1"`,
    /// else CTW with depth 16.
    fn default() -> Self {
        // Exactly one of these cfg'd blocks survives compilation for any
        // feature combination, and the surviving block is the function's
        // tail expression.
        #[cfg(feature = "backend-rosa")]
        {
            RateBackend::RosaPlus
        }
        #[cfg(all(not(feature = "backend-rosa"), feature = "backend-zpaq"))]
        {
            RateBackend::Zpaq {
                method: "1".to_string(),
            }
        }
        #[cfg(all(not(feature = "backend-rosa"), not(feature = "backend-zpaq")))]
        {
            RateBackend::Ctw { depth: 16 }
        }
    }
}
459
/// Compression backend used by NCD/compression-size operations.
///
/// The [`Default`] implementation is feature-dependent: ZPAQ method `"5"` when
/// `backend-zpaq` is enabled, otherwise a rate-coded compressor.
#[derive(Clone)]
pub enum CompressionBackend {
    /// ZPAQ compressor with explicit method string.
    Zpaq {
        /// ZPAQ method (for example `"1"` or `"5"`).
        method: String,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV7 model as an entropy-coded compressor.
    Rwkv7 {
        /// Loaded RWKV7 model.
        model: Arc<rwkvzip::Model>,
        /// Entropy coder used for coding model PDFs.
        coder: CoderType,
    },
    /// Generic rate-coded compressor wrapping an arbitrary rate backend.
    Rate {
        /// Predictive rate backend.
        rate_backend: RateBackend,
        /// Entropy coder used for coding model PDFs.
        coder: CoderType,
        /// Framing mode for output payloads.
        framing: compression::FramingMode,
    },
}
486
/// Mixture policy kind for rate-backend mixtures.
///
/// Selected via [`MixtureSpec::kind`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureKind {
    /// Standard Bayesian mixture with fixed expert weights.
    Bayes,
    /// Bayesian mixture with exponential weight decay.
    FadingBayes,
    /// Switching mixture with hazard `alpha`.
    Switching,
    /// MDL-style best-expert selector.
    Mdl,
    /// Bytewise neural logistic mixer (fx2-cmix style adaptation).
    Neural,
}
501
/// Fixed context families for calibrated PDF wrappers.
///
/// Selected via [`CalibratedSpec::context`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationContextKind {
    /// Single global calibration row.
    Global,
    /// Previous-byte class only.
    ByteClass,
    /// Text-structure-aware context hash.
    Text,
    /// Repeat-aware context hash.
    Repeat,
    /// Joint text/repeat-aware context hash.
    TextRepeat,
}
516
/// Configuration for a calibrated wrapper rate backend.
///
/// Consumed (behind an [`Arc`]) by [`RateBackend::Calibrated`].
#[derive(Clone)]
pub struct CalibratedSpec {
    /// Base backend whose PDF is calibrated.
    pub base: RateBackend,
    /// Context family controlling table row selection.
    pub context: CalibrationContextKind,
    /// Number of probability bins per row.
    pub bins: usize,
    /// Online learning rate for observed-symbol updates.
    pub learning_rate: f64,
    /// Symmetric clip applied to calibration weights.
    pub bias_clip: f64,
}
531
/// Expert specification for mixture backends.
///
/// Collected in [`MixtureSpec::experts`] and converted to runtime configs by
/// [`MixtureSpec::build_experts`].
#[derive(Clone)]
pub struct MixtureExpertSpec {
    /// Optional expert display name.
    pub name: Option<String>,
    /// Log prior weight (natural log). Uniform priors can be `0.0`.
    pub log_prior: f64,
    /// Max order for ROSA experts (ignored for other backends).
    pub max_order: i64,
    /// Underlying backend for this expert.
    pub backend: RateBackend,
}
544
/// Mixture specification for rate-backend mixtures.
///
/// Construct with [`MixtureSpec::new`] (which defaults `alpha` to `0.01` and
/// `decay` to `None`); consumed behind an [`Arc`] by [`RateBackend::Mixture`].
#[derive(Clone)]
pub struct MixtureSpec {
    /// Mixture policy.
    pub kind: MixtureKind,
    /// Switching probability (per step) for switching mixtures.
    pub alpha: f64,
    /// Decay factor for fading Bayes mixtures.
    pub decay: Option<f64>,
    /// Expert list.
    pub experts: Vec<MixtureExpertSpec>,
}
557
558impl MixtureSpec {
559    /// Build a mixture specification from kind and expert list.
560    pub fn new(kind: MixtureKind, experts: Vec<MixtureExpertSpec>) -> Self {
561        Self {
562            kind,
563            alpha: 0.01,
564            decay: None,
565            experts,
566        }
567    }
568
569    /// Set switching hazard / adaptation parameter.
570    pub fn with_alpha(mut self, alpha: f64) -> Self {
571        self.alpha = alpha;
572        self
573    }
574
575    /// Set fading decay factor.
576    pub fn with_decay(mut self, decay: f64) -> Self {
577        self.decay = Some(decay);
578        self
579    }
580
581    /// Convert to executable expert configs for runtime mixture evaluation.
582    pub fn build_experts(&self) -> Vec<crate::mixture::ExpertConfig> {
583        self.experts
584            .iter()
585            .map(|spec| {
586                crate::mixture::ExpertConfig::from_rate_backend(
587                    spec.name.clone(),
588                    spec.log_prior,
589                    spec.backend.clone(),
590                    spec.max_order,
591                )
592            })
593            .collect()
594    }
595}
596
/// Configuration for a particle-latent filter ensemble rate backend.
///
/// [`ParticleSpec::default`] produces a deterministic, noise-free setup that
/// satisfies [`ParticleSpec::validate`].
#[derive(Clone, Debug)]
pub struct ParticleSpec {
    /// Number of particles in the ensemble.
    pub num_particles: usize,
    /// Context window length for rolling byte context.
    pub context_window: usize,
    /// Number of latent update unroll steps per byte.
    pub unroll_steps: usize,
    /// Number of latent cells per particle.
    pub num_cells: usize,
    /// Dimensionality of each latent cell.
    pub cell_dim: usize,
    /// Number of discrete rules for soft routing.
    pub num_rules: usize,
    /// Hidden dimension for the selector MLP.
    pub selector_hidden: usize,
    /// Hidden dimension for each rule MLP.
    pub rule_hidden: usize,
    /// Dimension of per-rule noise input (ignored when deterministic).
    pub noise_dim: usize,
    /// Whether to use fully deterministic execution (no RNG).
    pub deterministic: bool,
    /// Whether to inject noise into rule inputs (ignored when deterministic).
    pub enable_noise: bool,
    /// Base scale for deterministic hash-noise injected into rule inputs.
    pub noise_scale: f64,
    /// Number of steps over which injected noise linearly anneals to zero.
    pub noise_anneal_steps: usize,
    /// Learning rate for readout layer SGD.
    pub learning_rate_readout: f64,
    /// Learning rate for selector MLP SGD.
    pub learning_rate_selector: f64,
    /// Learning rate for rule MLP SGD.
    pub learning_rate_rule: f64,
    /// Truncated backpropagation-through-time depth (number of recent steps).
    pub bptt_depth: usize,
    /// Momentum coefficient for selector/rule online updates (in [0, 1)).
    pub optimizer_momentum: f64,
    /// Gradient clipping threshold (max abs value per element).
    pub grad_clip: f64,
    /// Latent cell state clipping threshold (max abs value per element).
    pub state_clip: f64,
    /// Forgetting factor for particle log-weights (0 = no forgetting).
    pub forget_lambda: f64,
    /// Effective sample size ratio threshold for resampling (in (0, 1]).
    pub resample_threshold: f64,
    /// Fraction of particles to mutate after resampling (in [0, 1]).
    pub mutate_fraction: f64,
    /// Scale of hash-noise perturbation applied during mutation.
    pub mutate_scale: f64,
    /// Whether mutation also perturbs model parameters (state is always mutated).
    pub mutate_model_params: bool,
    /// Diagnostics print interval in steps (0 disables particle diagnostics logs).
    pub diagnostics_interval: usize,
    /// Minimum probability floor for numerical stability.
    pub min_prob: f64,
    /// Master seed for deterministic initialization and mutation.
    pub seed: u64,
}

impl Default for ParticleSpec {
    /// Deterministic, noise-free 16-particle configuration that passes
    /// [`ParticleSpec::validate`].
    fn default() -> Self {
        ParticleSpec {
            // Ensemble and latent shape.
            num_particles: 16,
            context_window: 32,
            unroll_steps: 2,
            num_cells: 8,
            cell_dim: 32,
            num_rules: 4,
            selector_hidden: 64,
            rule_hidden: 64,
            // Noise injection (disabled by default).
            noise_dim: 8,
            deterministic: true,
            enable_noise: false,
            noise_scale: 0.10,
            noise_anneal_steps: 8192,
            // Online learning.
            learning_rate_readout: 0.01,
            learning_rate_selector: 1e-4,
            learning_rate_rule: 3e-4,
            bptt_depth: 3,
            optimizer_momentum: 0.05,
            grad_clip: 1.0,
            state_clip: 8.0,
            // Particle filtering.
            forget_lambda: 0.0,
            resample_threshold: 0.5,
            mutate_fraction: 0.1,
            mutate_scale: 0.01,
            mutate_model_params: false,
            // Diagnostics and numerics.
            diagnostics_interval: 0,
            min_prob: 2f64.powi(-24),
            seed: 42,
        }
    }
}

impl ParticleSpec {
    /// Validate all fields, returning a message describing the first violated
    /// constraint on failure.
    pub fn validate(&self) -> Result<(), String> {
        // Structural sizes that must all be strictly positive, checked in
        // declaration order so the first offending field is the one reported.
        let positive = [
            (self.num_particles, "num_particles"),
            (self.context_window, "context_window"),
            (self.unroll_steps, "unroll_steps"),
            (self.num_cells, "num_cells"),
            (self.cell_dim, "cell_dim"),
            (self.num_rules, "num_rules"),
            (self.selector_hidden, "selector_hidden"),
            (self.rule_hidden, "rule_hidden"),
        ];
        for (value, name) in positive {
            if value == 0 {
                return Err(format!("{} must be > 0", name));
            }
        }
        // Rates and scales must be finite and non-negative.
        let non_negative = [
            (self.learning_rate_readout, "learning_rate_readout"),
            (self.learning_rate_selector, "learning_rate_selector"),
            (self.learning_rate_rule, "learning_rate_rule"),
            (self.noise_scale, "noise_scale"),
        ];
        for (value, name) in non_negative {
            if !value.is_finite() || value < 0.0 {
                return Err(format!("{} must be finite and non-negative", name));
            }
        }
        if !self.optimizer_momentum.is_finite() || !(0.0..1.0).contains(&self.optimizer_momentum) {
            return Err("optimizer_momentum must be finite and in [0, 1)".into());
        }
        if self.bptt_depth == 0 {
            return Err("bptt_depth must be > 0".into());
        }
        if !(self.resample_threshold > 0.0 && self.resample_threshold <= 1.0) {
            return Err("resample_threshold must be in (0, 1]".into());
        }
        if !(0.0..=1.0).contains(&self.mutate_fraction) {
            return Err("mutate_fraction must be in [0, 1]".into());
        }
        if !(self.min_prob > 0.0 && self.min_prob < 0.5) {
            return Err("min_prob must be in (0, 0.5)".into());
        }
        Ok(())
    }
}
753
/// Reusable execution context holding default rate and compression backends.
///
/// The derived `Default` simply takes each backend's own (feature-dependent)
/// default.
#[derive(Clone, Default)]
pub struct InfotheoryCtx {
    /// Default rate backend for entropy/rate metrics.
    pub rate_backend: RateBackend,
    /// Default compression backend for NCD/compression primitives.
    pub compression_backend: CompressionBackend,
}
762
/// Stateful rate-backend session for fitting, conditioning, and continuation.
///
/// Create one via [`RateBackendSession::from_backend`] or
/// [`InfotheoryCtx::rate_backend_session`].
pub struct RateBackendSession {
    // Underlying online byte predictor driving all session operations.
    predictor: crate::mixture::RateBackendPredictor,
}
767
768impl RateBackendSession {
769    /// Create a session from an explicit backend.
770    pub fn from_backend(
771        backend: RateBackend,
772        max_order: i64,
773        total_symbols: Option<u64>,
774    ) -> Result<Self, String> {
775        use crate::mixture::OnlineBytePredictor;
776
777        let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
778            backend,
779            max_order,
780            crate::mixture::DEFAULT_MIN_PROB,
781        );
782        predictor.begin_stream(total_symbols)?;
783        Ok(Self { predictor })
784    }
785
786    /// Observe bytes while adapting/fitting the model.
787    pub fn observe(&mut self, data: &[u8]) {
788        use crate::mixture::OnlineBytePredictor;
789
790        for &byte in data {
791            self.predictor.update(byte);
792        }
793    }
794
795    /// Advance conditioning state without changing fitted parameters/statistics.
796    pub fn condition(&mut self, data: &[u8]) {
797        use crate::mixture::OnlineBytePredictor;
798
799        for &byte in data {
800            self.predictor.update_frozen(byte);
801        }
802    }
803
804    /// Reset dynamic conditioning state while preserving fitted parameters/statistics.
805    pub fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
806        use crate::mixture::OnlineBytePredictor;
807
808        self.predictor.reset_frozen(total_symbols)
809    }
810
811    /// Fill the 256-way next-byte log-probabilities.
812    pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
813        use crate::mixture::OnlineBytePredictor;
814
815        self.predictor.fill_log_probs(out);
816    }
817
818    /// Generate continuation bytes from the current state.
819    pub fn generate_bytes(&mut self, bytes: usize, config: GenerationConfig) -> Vec<u8> {
820        use crate::mixture::OnlineBytePredictor;
821
822        if bytes == 0 {
823            return Vec::new();
824        }
825
826        let mut out = Vec::with_capacity(bytes);
827        let mut logps = [0.0f64; 256];
828        let mut rng = GenerationRng::new(config.seed);
829
830        for _ in 0..bytes {
831            match &mut self.predictor {
832                // ROSA's scalar path remains the reference for continuation generation.
833                crate::mixture::RateBackendPredictor::Rosa { .. } => {
834                    for (sym, slot) in logps.iter_mut().enumerate() {
835                        *slot = self.predictor.log_prob(sym as u8);
836                    }
837                }
838                _ => self.predictor.fill_log_probs(&mut logps),
839            }
840            let byte = pick_generated_byte(&logps, config, &mut rng);
841            match config.update_mode {
842                GenerationUpdateMode::Adaptive => self.predictor.update(byte),
843                GenerationUpdateMode::Frozen => self.predictor.update_frozen(byte),
844            }
845            out.push(byte);
846        }
847
848        out
849    }
850
851    /// Finalize the underlying stream if the backend needs it.
852    pub fn finish(&mut self) -> Result<(), String> {
853        use crate::mixture::OnlineBytePredictor;
854
855        self.predictor.finish_stream()
856    }
857}
858
// High-level entry point: pairs a rate backend (entropy estimation) with a
// compression backend (Kolmogorov-style estimates) behind one API.
impl InfotheoryCtx {
    /// Create a context from explicit rate and compression backends.
    pub fn new(rate_backend: RateBackend, compression_backend: CompressionBackend) -> Self {
        Self {
            rate_backend,
            compression_backend,
        }
    }

    /// Create a context with ROSA+ rate backend and ZPAQ compression backend.
    pub fn with_zpaq(method: impl Into<String>) -> Self {
        Self {
            rate_backend: RateBackend::RosaPlus,
            compression_backend: CompressionBackend::Zpaq {
                method: method.into(),
            },
        }
    }

    /// Compressed length of one byte slice under this context's compressor.
    pub fn compress_size(&self, data: &[u8]) -> u64 {
        compress_size_backend(data, &self.compression_backend)
    }

    /// Compressed length of chained slices under one stream.
    pub fn compress_size_chain(&self, parts: &[&[u8]]) -> u64 {
        compress_size_chain_backend(parts, &self.compression_backend)
    }

    /// Create a stateful session for the active rate backend.
    pub fn rate_backend_session(
        &self,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<RateBackendSession, String> {
        RateBackendSession::from_backend(self.rate_backend.clone(), max_order, total_symbols)
    }

    /// Entropy-rate estimate for `data` under this context's rate backend.
    pub fn entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Biased entropy-rate estimate (plugin variant) for `data`.
    pub fn biased_entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        biased_entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Cross entropy of `test_data` under model trained on `train_data`.
    pub fn cross_entropy_rate_bytes(
        &self,
        test_data: &[u8],
        train_data: &[u8],
        max_order: i64,
    ) -> f64 {
        cross_entropy_rate_backend(test_data, train_data, max_order, &self.rate_backend)
    }

    /// Cross entropy with order-0 fast-path fallback when `max_order == 0`.
    ///
    /// The order-0 path computes the marginal cross entropy
    /// `-Σ P_test(x) · log2(P_train(x))` over byte histograms, flooring
    /// `P_train` at 1e-12 so symbols unseen in `train_data` contribute a
    /// large-but-finite penalty instead of infinity.
    pub fn cross_entropy_bytes(&self, test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            if test_data.is_empty() {
                return 0.0;
            }
            let p_x = byte_histogram(test_data);
            let p_y = byte_histogram(train_data);
            let mut h = 0.0f64;
            for i in 0..256 {
                if p_x[i] > 0.0 {
                    // Probability floor keeps the sum finite for unseen symbols.
                    let q_y = p_y[i].max(1e-12);
                    h -= p_x[i] * q_y.log2();
                }
            }
            h
        } else {
            self.cross_entropy_rate_bytes(test_data, train_data, max_order)
        }
    }

    /// Joint entropy-rate estimate `H(X,Y)` under aligned-prefix semantics.
    ///
    /// Both inputs are truncated to their common length before estimation.
    pub fn joint_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        joint_entropy_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Conditional entropy-rate estimate `H(X|Y)`.
    ///
    /// Chain rule `H(X|Y) = H(X,Y) - H(Y)`, clamped at zero because the two
    /// independently estimated terms can cross for near-deterministic data.
    pub fn conditional_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
        let h_y = self.entropy_rate_bytes(y, max_order);
        (h_xy - h_y).max(0.0)
    }

    /// Compute `H(data | prefix_parts)` by conditioning the active rate backend
    /// on an explicit prefix chain.
    ///
    /// Every arm returns bits per byte of `data`; empty `data` yields 0.0.
    pub fn cross_entropy_conditional_chain(&self, prefix_parts: &[&[u8]], data: &[u8]) -> f64 {
        match &self.rate_backend {
            RateBackend::RosaPlus => {
                // ROSA conditions by literal concatenation: flatten the prefix
                // chain and score `data` as its continuation.
                let mut prefix = Vec::new();
                let total: usize = prefix_parts.iter().map(|p| p.len()).sum();
                prefix.reserve(total);
                for p in prefix_parts {
                    prefix.extend_from_slice(p);
                }
                cross_entropy_rate_backend(data, &prefix, -1, &RateBackend::RosaPlus)
            }
            RateBackend::Match { .. }
            | RateBackend::SparseMatch { .. }
            | RateBackend::Ppmd { .. }
            | RateBackend::Calibrated { .. } => {
                // Streaming predictors: warm on the prefix, then score
                // prequentially (predict-then-update).
                prequential_rate_backend(data, prefix_parts, -1, &self.rate_backend)
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("rwkv conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("rwkv method conditional-chain scoring failed: {e:#}")
                    })
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("mamba conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("mamba method conditional-chain scoring failed: {e:#}")
                    })
            }),
            RateBackend::Ctw { depth } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut tree = crate::ctw::ContextTree::new(*depth);
                // Condition the tree on the prefix, feeding bits MSB-first.
                for &part in prefix_parts {
                    for &b in part {
                        for i in (0..8).rev() {
                            tree.update(((b >> i) & 1) == 1);
                        }
                    }
                }
                let log_p_prefix = tree.get_log_block_probability();
                for &b in data {
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_p_joint = tree.get_log_block_probability();
                // Chain rule: log P(data|prefix) = log P(prefix,data) - log P(prefix).
                let log_p_cond = log_p_joint - log_p_prefix;
                // Block probability is treated as natural-log; convert to bits.
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
            RateBackend::Zpaq { method } => {
                if data.is_empty() {
                    return 0.0;
                }
                // NOTE(review): 2^-24 looks like a minimum-probability floor —
                // confirm against ZpaqRateModel::new.
                let mut model =
                    crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
                // Prefix passes condition the model; only the final pass is scored.
                for &part in prefix_parts {
                    model.update_and_score(part);
                }
                let bits = model.update_and_score(data);
                bits / (data.len() as f64)
            }
            RateBackend::Mixture { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let experts = spec.build_experts();
                let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                // Some experts want the total stream length up front.
                let total = prefix_parts
                    .iter()
                    .map(|p| p.len() as u64)
                    .sum::<u64>()
                    .saturating_add(data.len() as u64);
                mix.begin_stream(Some(total))
                    .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
                // Prefix bytes update the mixture but are not scored.
                for &part in prefix_parts {
                    for &b in part {
                        mix.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    // step() is treated as returning ln P(b); convert to bits.
                    bits -= mix.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::Particle { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                // Prefix bytes update the runtime but are not scored.
                for &part in prefix_parts {
                    for &b in part {
                        runtime.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    // step() is treated as returning ln P(b); convert to bits.
                    bits -= runtime.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                if data.is_empty() {
                    return 0.0;
                }
                let bits_per_byte = (*encoding_bits).clamp(1, 8);
                let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
                for &part in prefix_parts {
                    for &b in part {
                        // Unlike plain CTW above, FAC-CTW feeds bits LSB-first.
                        for i in 0..bits_per_byte {
                            let bit_idx = i;
                            // (b >> i) & 1 extracts bit i, where bit 0 is the LSB.
                            fac.update(((b >> i) & 1) == 1, bit_idx);
                        }
                    }
                }
                let log_p_prefix = fac.get_log_block_probability();
                for &b in data {
                    for i in 0..bits_per_byte {
                        let bit_idx = i;
                        fac.update(((b >> i) & 1) == 1, bit_idx);
                    }
                }
                let log_p_joint = fac.get_log_block_probability();
                // Chain rule, then natural-log -> bits, normalized per byte.
                let log_p_cond = log_p_joint - log_p_prefix;
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
        }
    }

    /// Generate a continuation from `prompt` with [`GenerationConfig::default()`].
    ///
    /// The default is deterministic frozen sampling with seed `42`.
    pub fn generate_bytes(&self, prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
        self.generate_bytes_with_config(prompt, bytes, max_order, GenerationConfig::default())
    }

    /// Generate a continuation from `prompt` using an explicit generation config.
    pub fn generate_bytes_with_config(
        &self,
        prompt: &[u8],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        // A single prompt is just a one-element prefix chain.
        generate_rate_backend_chain(&[prompt], bytes, max_order, &self.rate_backend, config)
    }

    /// Generate a continuation after conditioning on an explicit chain of prefix parts.
    pub fn generate_bytes_conditional_chain(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
    ) -> Vec<u8> {
        self.generate_bytes_conditional_chain_with_config(
            prefix_parts,
            bytes,
            max_order,
            GenerationConfig::default(),
        )
    }

    /// Generate a continuation after conditioning on an explicit chain of prefix parts.
    pub fn generate_bytes_conditional_chain_with_config(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(prefix_parts, bytes, max_order, &self.rate_backend, config)
    }

    /// NCD between byte slices using this context's compression backend.
    pub fn ncd_bytes(&self, x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
        ncd_bytes_backend(x, y, &self.compression_backend, variant)
    }

    /// Rate-backend mutual information estimate.
    pub fn mutual_information_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        mutual_information_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Mutual information with `max_order == 0` marginal fast-path.
    pub fn mutual_information_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            mutual_information_marg_bytes(x, y)
        } else {
            self.mutual_information_rate_bytes(x, y, max_order)
        }
    }

    /// Conditional entropy with aligned-prefix semantics.
    ///
    /// Order 0 uses histogram (bag-of-bytes) entropies; otherwise entropy
    /// rates. Both apply `H(X|Y) = H(X,Y) - H(Y)`, clamped at zero.
    pub fn conditional_entropy_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if max_order == 0 {
            let h_xy = joint_marginal_entropy_bytes(x, y);
            let h_y = marginal_entropy_bytes(y);
            (h_xy - h_y).max(0.0)
        } else {
            let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
            let h_y = self.entropy_rate_bytes(y, max_order);
            (h_xy - h_y).max(0.0)
        }
    }

    /// Normalized entropy distance (NED) under this context.
    pub fn ned_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            ned_marg_bytes(x, y)
        } else {
            ned_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Conservative NED normalization variant.
    ///
    /// Normalizes by the joint entropy instead of `max(H(X), H(Y))`:
    /// `(H(X,Y) - min(H(X), H(Y))) / H(X,Y)`, clamped to `[0, 1]`, with 0
    /// when the joint entropy vanishes.
    pub fn ned_cons_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        let (h_x, h_y, h_xy) = if max_order == 0 {
            (
                marginal_entropy_bytes(x),
                marginal_entropy_bytes(y),
                joint_marginal_entropy_bytes(x, y),
            )
        } else {
            (
                self.entropy_rate_bytes(x, max_order),
                self.entropy_rate_bytes(y, max_order),
                self.joint_entropy_rate_bytes(x, y, max_order),
            )
        };
        let min_h = h_x.min(h_y);
        if h_xy == 0.0 {
            0.0
        } else {
            ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
        }
    }

    /// Normalized transform effort (NTE) under this context.
    pub fn nte_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            nte_marg_bytes(x, y)
        } else {
            nte_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Intrinsic dependence score in `[0,1]`.
    ///
    /// `(H_marginal - H_rate) / H_marginal`: the fraction of the bag-of-bytes
    /// entropy explained away by sequential structure. Near-zero marginal
    /// entropy short-circuits to 0 to avoid dividing by ~0.
    pub fn intrinsic_dependence_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        let h_marginal = marginal_entropy_bytes(data);
        if h_marginal < 1e-9 {
            return 0.0;
        }
        let h_rate = self.entropy_rate_bytes(data, max_order);
        ((h_marginal - h_rate) / h_marginal).clamp(0.0, 1.0)
    }

    /// Resistance-to-transformation ratio `I(X;T(X))/H(X)` in `[0,1]`.
    ///
    /// Inputs are aligned to their common prefix; near-zero `H(X)` yields 0.
    pub fn resistance_to_transformation_bytes(&self, x: &[u8], tx: &[u8], max_order: i64) -> f64 {
        let (x, tx) = aligned_prefix(x, tx);
        let h_x = if max_order == 0 {
            marginal_entropy_bytes(x)
        } else {
            self.entropy_rate_bytes(x, max_order)
        };
        if h_x < 1e-9 {
            return 0.0;
        }
        let mi = self.mutual_information_bytes(x, tx, max_order);
        (mi / h_x).clamp(0.0, 1.0)
    }
}
1256
1257#[cfg(feature = "backend-rwkv")]
1258/// Load an RWKV7 model from `.safetensors` path.
1259pub fn load_rwkv7_model_from_path(path: &str) -> Arc<rwkvzip::Model> {
1260    rwkvzip::Compressor::load_model(path).expect("failed to load RWKV7 model")
1261}
1262
1263#[cfg(feature = "backend-mamba")]
1264/// Load a Mamba-1 model from `.safetensors` path.
1265pub fn load_mamba_model_from_path(path: &str) -> Arc<mambazip::Model> {
1266    mambazip::Compressor::load_model(path).expect("failed to load Mamba model")
1267}
1268
/// Truncate both slices to their common length, returning the aligned
/// prefixes of `x` and `y` (the longer input loses its tail).
#[inline(always)]
fn aligned_prefix<'a>(x: &'a [u8], y: &'a [u8]) -> (&'a [u8], &'a [u8]) {
    if x.len() <= y.len() {
        (x, &y[..x.len()])
    } else {
        (&x[..y.len()], y)
    }
}
1274
1275#[cfg(feature = "backend-zpaq")]
1276#[inline(always)]
1277fn zpaq_compress_size_bytes(data: &[u8], method: &str) -> u64 {
1278    zpaq_rs::compress_size(data, method).unwrap_or(0)
1279}
1280
1281#[cfg(not(feature = "backend-zpaq"))]
1282#[inline(always)]
1283fn zpaq_compress_size_bytes(_data: &[u8], _method: &str) -> u64 {
1284    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
1285}
1286
1287#[cfg(feature = "backend-zpaq")]
1288#[inline(always)]
1289fn zpaq_compress_size_parallel_bytes(data: &[u8], method: &str, threads: usize) -> u64 {
1290    zpaq_rs::compress_size_parallel(data, method, threads).unwrap_or(0)
1291}
1292
1293#[cfg(not(feature = "backend-zpaq"))]
1294#[inline(always)]
1295fn zpaq_compress_size_parallel_bytes(_data: &[u8], _method: &str, _threads: usize) -> u64 {
1296    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
1297}
1298
1299#[cfg(feature = "backend-zpaq")]
1300#[inline(always)]
1301fn zpaq_compress_size_stream<R: std::io::Read + Send>(reader: R, method: &str) -> u64 {
1302    zpaq_rs::compress_size_stream(reader, method, None, None).unwrap_or(0)
1303}
1304
1305#[cfg(not(feature = "backend-zpaq"))]
1306#[inline(always)]
1307fn zpaq_compress_size_stream<R: std::io::Read + Send>(_reader: R, _method: &str) -> u64 {
1308    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
1309}
1310
1311#[cfg(feature = "backend-zpaq")]
1312#[inline(always)]
1313fn zpaq_compress_to_vec(data: &[u8], method: &str) -> anyhow::Result<Vec<u8>> {
1314    Ok(zpaq_rs::compress_to_vec(data, method)?)
1315}
1316
1317#[cfg(not(feature = "backend-zpaq"))]
1318#[inline(always)]
1319fn zpaq_compress_to_vec(_data: &[u8], _method: &str) -> anyhow::Result<Vec<u8>> {
1320    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
1321}
1322
1323#[cfg(feature = "backend-zpaq")]
1324#[inline(always)]
1325fn zpaq_decompress_to_vec(data: &[u8]) -> anyhow::Result<Vec<u8>> {
1326    Ok(zpaq_rs::decompress_to_vec(data)?)
1327}
1328
1329#[cfg(not(feature = "backend-zpaq"))]
1330#[inline(always)]
1331fn zpaq_decompress_to_vec(_data: &[u8]) -> anyhow::Result<Vec<u8>> {
1332    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
1333}
1334
1335/// ------- Base Compression Functions -------
1336#[inline(always)]
1337pub fn get_compressed_size(path: &str, method: &str) -> u64 {
1338    // Convert Input file to Vec<u8>, and reference that (compress_size only takes &[u8] input), and pass method.
1339    // Will panic if file does not exist, so it must be prevalidated.
1340    zpaq_compress_size_bytes(&std::fs::read(path).unwrap(), method)
1341}
1342
/// Validate that a ZPAQ method string is supported for rate estimation.
///
/// Returns `Ok(())` when the method is usable by the ZPAQ rate model, and a
/// human-readable error otherwise. In builds without the `backend-zpaq`
/// feature this always fails.
pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
    #[cfg(feature = "backend-zpaq")]
    {
        // Delegate to the rate module, which knows the supported subset.
        zpaq_rate::validate_zpaq_rate_method(method)
    }
    #[cfg(not(feature = "backend-zpaq"))]
    {
        let _ = method;
        Err("zpaq backend disabled at compile time".to_string())
    }
}
1355
1356#[cfg(feature = "backend-rwkv")]
1357fn with_rwkv_tls<R>(
1358    model: &Arc<rwkvzip::Model>,
1359    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
1360) -> R {
1361    let key = Arc::as_ptr(model) as usize;
1362    RWKV_TLS.with(|cell| {
1363        let mut map = cell.borrow_mut();
1364        let comp = map
1365            .entry(key)
1366            .or_insert_with(|| rwkvzip::Compressor::new_from_model(model.clone()));
1367        f(comp)
1368    })
1369}
1370
1371#[cfg(feature = "backend-rwkv")]
1372fn with_rwkv_method_tls<R>(method: &str, f: impl FnOnce(&mut rwkvzip::Compressor) -> R) -> R {
1373    RWKV_METHOD_TLS.with(|cell| {
1374        let mut map = cell.borrow_mut();
1375        // Keep a per-method template compressor for fast cloning while ensuring
1376        // each call gets isolated mutable runtime state (no cross-call leakage).
1377        let mut comp = if let Some(template) = map.get(method) {
1378            template.clone()
1379        } else {
1380            let template = rwkvzip::Compressor::new_from_method(method).unwrap_or_else(|e| {
1381                panic!("invalid rwkv method '{method}': {e:#}");
1382            });
1383            map.insert(method.to_string(), template.clone());
1384            template
1385        };
1386        drop(map);
1387        f(&mut comp)
1388    })
1389}
1390
1391#[cfg(feature = "backend-rwkv")]
1392fn with_rwkv_rate_tls<R>(
1393    model: &Arc<rwkvzip::Model>,
1394    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
1395) -> R {
1396    let key = Arc::as_ptr(model) as usize;
1397    RWKV_RATE_TLS.with(|cell| {
1398        let mut map = cell.borrow_mut();
1399        let mut comp = if let Some(template) = map.get(&key) {
1400            template.clone()
1401        } else {
1402            let template = rwkvzip::Compressor::new_from_model(model.clone());
1403            map.insert(key, template.clone());
1404            template
1405        };
1406        drop(map);
1407        f(&mut comp)
1408    })
1409}
1410
1411#[cfg(feature = "backend-mamba")]
1412fn with_mamba_tls<R>(
1413    model: &Arc<mambazip::Model>,
1414    f: impl FnOnce(&mut mambazip::Compressor) -> R,
1415) -> R {
1416    let key = Arc::as_ptr(model) as usize;
1417    MAMBA_TLS.with(|cell| {
1418        let mut map = cell.borrow_mut();
1419        let comp = map
1420            .entry(key)
1421            .or_insert_with(|| mambazip::Compressor::new_from_model(model.clone()));
1422        f(comp)
1423    })
1424}
1425
1426#[cfg(feature = "backend-mamba")]
1427fn with_mamba_rate_tls<R>(
1428    model: &Arc<mambazip::Model>,
1429    f: impl FnOnce(&mut mambazip::Compressor) -> R,
1430) -> R {
1431    let key = Arc::as_ptr(model) as usize;
1432    MAMBA_RATE_TLS.with(|cell| {
1433        let mut map = cell.borrow_mut();
1434        let mut comp = if let Some(template) = map.get(&key) {
1435            template.clone()
1436        } else {
1437            let template = mambazip::Compressor::new_from_model(model.clone());
1438            map.insert(key, template.clone());
1439            template
1440        };
1441        drop(map);
1442        f(&mut comp)
1443    })
1444}
1445
1446#[cfg(feature = "backend-mamba")]
1447fn with_mamba_method_tls<R>(method: &str, f: impl FnOnce(&mut mambazip::Compressor) -> R) -> R {
1448    MAMBA_METHOD_TLS.with(|cell| {
1449        let mut map = cell.borrow_mut();
1450        let mut comp = if let Some(template) = map.get(method) {
1451            template.clone()
1452        } else {
1453            let template = mambazip::Compressor::new_from_method(method).unwrap_or_else(|e| {
1454                panic!("invalid mamba method '{method}': {e:#}");
1455            });
1456            map.insert(method.to_string(), template.clone());
1457            template
1458        };
1459        drop(map);
1460        f(&mut comp)
1461    })
1462}
1463
/// A `std::io::Read` view over an ordered chain of byte slices, yielding the
/// slices' bytes back-to-back without concatenating them up front.
struct SliceChainReader<'a> {
    parts: &'a [&'a [u8]],
    i: usize,
    off: usize,
}

impl<'a> SliceChainReader<'a> {
    /// Start reading at the beginning of the first slice.
    fn new(parts: &'a [&'a [u8]]) -> Self {
        SliceChainReader {
            parts,
            i: 0,
            off: 0,
        }
    }
}

impl<'a> std::io::Read for SliceChainReader<'a> {
    fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
        let mut written = 0usize;
        // Pull from successive parts until the caller's buffer is full or
        // the chain is exhausted. Never blocks, never errors.
        while !buf.is_empty() {
            let part = match self.parts.get(self.i) {
                Some(&p) => p,
                None => break,
            };
            if self.off >= part.len() {
                // Current part (possibly empty) exhausted; advance to the next.
                self.i += 1;
                self.off = 0;
                continue;
            }
            let take = buf.len().min(part.len() - self.off);
            buf[..take].copy_from_slice(&part[self.off..self.off + take]);
            self.off += take;
            written += take;
            // Shrink the destination window to its unwritten tail.
            let rest = buf;
            buf = &mut rest[take..];
        }
        Ok(written)
    }
}
1512
1513/// Compute compressed size of a chain of byte slices with a selected compression backend.
1514pub fn compress_size_chain_backend(parts: &[&[u8]], backend: &CompressionBackend) -> u64 {
1515    match backend {
1516        CompressionBackend::Zpaq { method } => {
1517            let r = SliceChainReader::new(parts);
1518            zpaq_compress_size_stream(r, method.as_str())
1519        }
1520        #[cfg(feature = "backend-rwkv")]
1521        CompressionBackend::Rwkv7 { model, coder } => {
1522            with_rwkv_tls(model, |c| c.compress_size_chain(parts, *coder).unwrap_or(0))
1523        }
1524        CompressionBackend::Rate {
1525            rate_backend,
1526            coder,
1527            framing,
1528        } => {
1529            crate::compression::compress_rate_size_chain(parts, rate_backend, -1, *coder, *framing)
1530                .unwrap_or(0)
1531        }
1532    }
1533}
1534
1535/// Compute compressed size of a single byte slice with a selected compression backend.
1536pub fn compress_size_backend(data: &[u8], backend: &CompressionBackend) -> u64 {
1537    match backend {
1538        CompressionBackend::Zpaq { method } => zpaq_compress_size_bytes(data, method.as_str()),
1539        #[cfg(feature = "backend-rwkv")]
1540        CompressionBackend::Rwkv7 { model, coder } => {
1541            with_rwkv_tls(model, |c| c.compress_size(data, *coder).unwrap_or(0))
1542        }
1543        CompressionBackend::Rate {
1544            rate_backend,
1545            coder,
1546            framing,
1547        } => crate::compression::compress_rate_size(data, rate_backend, -1, *coder, *framing)
1548            .unwrap_or(0),
1549    }
1550}
1551
1552/// Compress bytes with a selected compression backend.
1553pub fn compress_bytes_backend(
1554    data: &[u8],
1555    backend: &CompressionBackend,
1556) -> anyhow::Result<Vec<u8>> {
1557    match backend {
1558        CompressionBackend::Zpaq { method } => zpaq_compress_to_vec(data, method),
1559        #[cfg(feature = "backend-rwkv")]
1560        CompressionBackend::Rwkv7 { model, coder } => {
1561            with_rwkv_tls(model, |c| c.compress(data, *coder))
1562        }
1563        CompressionBackend::Rate {
1564            rate_backend,
1565            coder,
1566            framing,
1567        } => crate::compression::compress_rate_bytes(data, rate_backend, -1, *coder, *framing),
1568    }
1569}
1570
1571/// Decompress bytes with a selected compression backend.
1572pub fn decompress_bytes_backend(
1573    input: &[u8],
1574    backend: &CompressionBackend,
1575) -> anyhow::Result<Vec<u8>> {
1576    match backend {
1577        CompressionBackend::Zpaq { .. } => zpaq_decompress_to_vec(input),
1578        #[cfg(feature = "backend-rwkv")]
1579        CompressionBackend::Rwkv7 { model, .. } => with_rwkv_tls(model, |c| c.decompress(input)),
1580        CompressionBackend::Rate {
1581            rate_backend,
1582            coder,
1583            framing,
1584        } => crate::compression::decompress_rate_bytes(input, rate_backend, -1, *coder, *framing),
1585    }
1586}
1587
/// Prequential (predict-then-update) cross entropy of `data`, in bits per
/// byte, after first streaming `prefix_parts` through the predictor.
///
/// The model adapts on every byte — including the scored ones — making this
/// the online/adaptive counterpart of [`frozen_plugin_rate_backend`].
/// Returns 0.0 for empty `data`.
///
/// # Panics
/// Panics if the backend fails to initialize or finalize its stream.
fn prequential_rate_backend(
    data: &[u8],
    prefix_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    use crate::mixture::OnlineBytePredictor;

    if data.is_empty() {
        return 0.0;
    }
    // Some backends need the total stream length declared up front.
    let total = prefix_parts
        .iter()
        .map(|p| p.len() as u64)
        .sum::<u64>()
        .saturating_add(data.len() as u64);
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(total))
        .unwrap_or_else(|e| panic!("rate backend stream init failed: {e}"));
    // Warm the model on the conditioning prefix without scoring it.
    for prefix in prefix_parts {
        for &b in *prefix {
            predictor.update(b);
        }
    }
    let mut bits = 0.0;
    for &b in data {
        // Score first (log_prob treated as natural log, converted to bits),
        // then let the model adapt on the byte.
        bits -= predictor.log_prob(b) / std::f64::consts::LN_2;
        predictor.update(b);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend stream finalize failed: {e}"));
    bits / (data.len() as f64)
}
1627
/// Frozen-plugin cross entropy of `score_data`, in bits per byte, under a
/// model fit once on `fit_parts` and then held fixed while scoring.
///
/// ROSA and the neural backends use their dedicated frozen-scoring paths;
/// every other backend is fit with an adaptive streaming pass and then reset
/// into frozen mode before scoring. Returns 0.0 for empty `score_data`.
///
/// # Panics
/// Panics if any backend stream operation fails.
fn frozen_plugin_rate_backend(
    score_data: &[u8],
    fit_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    if score_data.is_empty() {
        return 0.0;
    }
    if matches!(backend, RateBackend::RosaPlus) {
        // Dedicated ROSA path: train on the fit parts, build the LM, score.
        // NOTE(review): the trailing `42` in RosaPlus::new looks like a fixed
        // seed for determinism — confirm against the rosaplus API.
        let mut model = rosaplus::RosaPlus::new(max_order, false, 0, 42);
        for part in fit_parts {
            model.train_example(part);
        }
        model.build_lm();
        return model.cross_entropy(score_data);
    }
    #[cfg(feature = "backend-rwkv")]
    match backend {
        RateBackend::Rwkv7 { model } => {
            return with_rwkv_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::Rwkv7Method { method } => {
            return with_rwkv_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }
    #[cfg(feature = "backend-mamba")]
    match backend {
        RateBackend::Mamba { model } => {
            return with_mamba_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::MambaMethod { method } => {
            return with_mamba_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }

    use crate::mixture::OnlineBytePredictor;

    // Generic two-pass path. Pass 1: fit adaptively on all fit parts.
    let fit_total = fit_parts.iter().map(|part| part.len() as u64).sum::<u64>();
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(fit_total))
        .unwrap_or_else(|e| panic!("rate backend fit-pass init failed: {e}"));
    for part in fit_parts {
        for &byte in *part {
            predictor.update(byte);
        }
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend fit-pass finalize failed: {e}"));
    // Pass 2: restart in frozen mode and score without adaptation.
    predictor
        .reset_frozen(Some(score_data.len() as u64))
        .unwrap_or_else(|e| panic!("rate backend frozen-score reset failed: {e}"));
    let mut bits = 0.0;
    for &byte in score_data {
        // log_prob treated as natural log; converted to bits per byte below.
        bits -= predictor.log_prob(byte) / std::f64::consts::LN_2;
        predictor.update_frozen(byte);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend frozen-score finalize failed: {e}"));
    bits / (score_data.len() as f64)
}
1710
#[inline(always)]
/// Index of the largest finite log-probability, returned as a byte.
///
/// Non-finite entries (NaN, ±inf) are treated as minus infinity so they can
/// never win. Ties keep the lowest byte value; an all-non-finite table
/// yields 0.
fn argmax_log_prob_byte(logps: &[f64; 256]) -> u8 {
    let (winner, _) = logps
        .iter()
        .enumerate()
        .fold((0usize, f64::NEG_INFINITY), |(best_i, best_v), (i, &lp)| {
            let v = if lp.is_finite() { lp } else { f64::NEG_INFINITY };
            if v > best_v { (i, v) } else { (best_i, best_v) }
        });
    winner as u8
}
1728
/// Pick the next generated byte from a table of per-byte log-probabilities.
///
/// Greedy strategy — or a non-finite / non-positive temperature — falls back
/// to the argmax. Otherwise performs temperature-scaled top-k / top-p
/// (nucleus) sampling. Any numerically degenerate intermediate (no finite
/// candidate after top-k, zero or non-finite total mass) also falls back to
/// the argmax.
fn pick_generated_byte(
    logps: &[f64; 256],
    config: GenerationConfig,
    rng: &mut GenerationRng,
) -> u8 {
    // Deterministic path: greedy selection, or a temperature that cannot
    // meaningfully rescale the distribution.
    if matches!(config.strategy, GenerationStrategy::Greedy)
        || !config.temperature.is_finite()
        || config.temperature <= 0.0
    {
        return argmax_log_prob_byte(logps);
    }

    // Temperature-scale the log-probs; non-finite entries are excluded by
    // mapping them to -inf.
    let mut entries = [(0u8, f64::NEG_INFINITY); 256];
    for (idx, &logp) in logps.iter().enumerate() {
        let scaled = if logp.is_finite() {
            logp / config.temperature
        } else {
            f64::NEG_INFINITY
        };
        entries[idx] = (idx as u8, scaled);
    }
    // Sort descending by scaled log-prob; total_cmp gives a total order even
    // with the -inf sentinels present.
    entries.sort_by(|a, b| b.1.total_cmp(&a.1));

    // top_k == 0 means "no top-k truncation".
    let keep_k = if config.top_k == 0 {
        entries.len()
    } else {
        config.top_k.min(entries.len())
    };

    // Clamp top_p into [0, 1]; a non-finite top_p disables nucleus truncation.
    let top_p = if config.top_p.is_finite() {
        config.top_p.clamp(0.0, 1.0)
    } else {
        1.0
    };

    // Max-shift over the kept candidates for numerical stability of exp().
    let mut max_logp = f64::NEG_INFINITY;
    for &(_, logp) in entries.iter().take(keep_k) {
        if logp.is_finite() {
            max_logp = max_logp.max(logp);
        }
    }
    if !max_logp.is_finite() {
        // No finite candidate survived top-k: degenerate, fall back.
        return argmax_log_prob_byte(logps);
    }

    // Unnormalized softmax weights over the kept (sorted) candidates.
    let mut weights = [(0u8, 0.0f64); 256];
    let mut total = 0.0;
    for (idx, &(byte, logp)) in entries.iter().take(keep_k).enumerate() {
        let w = if logp.is_finite() {
            (logp - max_logp).exp()
        } else {
            0.0
        };
        weights[idx] = (byte, w);
        total += w;
    }
    if !(total.is_finite()) || total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Nucleus (top-p) cutoff: keep the smallest prefix of the sorted
    // candidates whose normalized mass reaches top_p (always at least one).
    let cutoff_count = if top_p >= 1.0 {
        keep_k
    } else {
        let mut cumulative = 0.0;
        let mut keep = 0usize;
        for &(_, w) in weights.iter().take(keep_k) {
            cumulative += w / total;
            keep += 1;
            if cumulative >= top_p {
                break;
            }
        }
        keep.max(1)
    };

    // Re-total the surviving mass after the nucleus cut.
    let mut truncated_total = 0.0;
    for &(_, w) in weights.iter().take(cutoff_count) {
        truncated_total += w;
    }
    if !(truncated_total.is_finite()) || truncated_total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Inverse-CDF draw over the truncated weights; defaults to the top
    // candidate if rounding lets the loop fall through without triggering.
    let target = rng.next_f64() * truncated_total;
    let mut cumulative = 0.0;
    let mut picked = weights[0].0;
    for &(byte, weight) in weights.iter().take(cutoff_count) {
        cumulative += weight;
        if cumulative >= target {
            picked = byte;
            break;
        }
    }
    picked
}
1824
1825fn generate_rate_backend_chain(
1826    prefix_parts: &[&[u8]],
1827    bytes: usize,
1828    max_order: i64,
1829    backend: &RateBackend,
1830    config: GenerationConfig,
1831) -> Vec<u8> {
1832    if bytes == 0 {
1833        return Vec::new();
1834    }
1835
1836    let total = prefix_parts
1837        .iter()
1838        .map(|p| p.len() as u64)
1839        .sum::<u64>()
1840        .saturating_add(bytes as u64);
1841    let mut session = RateBackendSession::from_backend(backend.clone(), max_order, Some(total))
1842        .unwrap_or_else(|e| panic!("rate backend generation init failed: {e}"));
1843    for &part in prefix_parts {
1844        session.observe(part);
1845    }
1846    let out = session.generate_bytes(bytes, config);
1847    session
1848        .finish()
1849        .unwrap_or_else(|e| panic!("rate backend generation finalize failed: {e}"));
1850    out
1851}
1852
/// Estimate entropy rate of `data` using the explicit rate `backend`.
///
/// Returns bits per input byte. The arms that build a model locally return
/// 0.0 for empty input; the delegating arms (prequential, rwkv/mamba TLS)
/// rely on the callee's own empty handling — TODO confirm for the TLS paths.
pub fn entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    match backend {
        // ROSA+ suffix-automaton LM; seed fixed at 42 (presumably for
        // determinism — semantics live in `rosaplus::RosaPlus::new`).
        RateBackend::RosaPlus => {
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.predictive_entropy_rate(data)
        }
        // Online predictors: prequential scoring with no separate fit data.
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Calibrated { .. } => prequential_rate_backend(data, &[], max_order, backend),
        // Neural backends (feature-gated): cross-entropy under a
        // thread-local model instance; scoring errors are fatal.
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv method entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba method entropy scoring failed: {e:#}"))
        }),
        // ZPAQ rate model: total code length normalized by length. The
        // 2^-24 constant is passed as a floor parameter — presumably a
        // minimum probability; see `ZpaqRateModel::new`.
        RateBackend::Zpaq { method } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(data);
            bits / (data.len() as f64)
        }
        // Expert mixture: stream bytes through the mixture runtime; the
        // LN_2 division converts step()'s ln-scale scores into bits.
        RateBackend::Mixture { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(data.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in data {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (data.len() as f64)
        }
        // Particle runtime: same ln-to-bits accumulation, but with no
        // begin/finish stream protocol.
        RateBackend::Particle { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in data {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (data.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            if data.is_empty() {
                return 0.0;
            }
            // Byte-wise CTW: factorize by bit position so deterministic bits don't leak entropy.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_msb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
        // Factored CTW over a reduced (LSB-first) bit encoding of each byte;
        // encoding_bits is clamped to the valid 1..=8 range.
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            if data.is_empty() {
                return 0.0;
            }
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_lsb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
    }
}
1954
1955/// Estimate biased/plugin entropy rate of `data` using the explicit rate `backend`.
1956pub fn biased_entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
1957    match backend {
1958        RateBackend::Zpaq { .. } => {
1959            panic!("biased/plugin entropy is not supported for zpaq rate backends in 1.1.0")
1960        }
1961        _ => frozen_plugin_rate_backend(data, &[data], max_order, backend),
1962    }
1963}
1964
1965/// Cross-entropy H_{train}(test) - score test_data under model trained on train_data.
1966pub fn cross_entropy_rate_backend(
1967    test_data: &[u8],
1968    train_data: &[u8],
1969    max_order: i64,
1970    backend: &RateBackend,
1971) -> f64 {
1972    match backend {
1973        RateBackend::Zpaq { method } => {
1974            if test_data.is_empty() {
1975                return 0.0;
1976            }
1977            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
1978            model.update_and_score(train_data);
1979            let bits = model.update_and_score(test_data);
1980            bits / (test_data.len() as f64)
1981        }
1982        _ => frozen_plugin_rate_backend(test_data, &[train_data], max_order, backend),
1983    }
1984}
1985
/// Estimate joint entropy rate `H(X,Y)` using an explicit `backend`.
///
/// Both inputs are first truncated to their common prefix length via
/// `aligned_prefix`; the result is bits per aligned (x, y) position, and 0.0
/// for an empty alignment. Each backend encodes the pair stream differently
/// (16-bit joint symbols, byte interleaving, or bit interleaving) — see the
/// per-arm comments.
pub fn joint_entropy_rate_backend(
    x: &[u8],
    y: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    if x.is_empty() {
        return 0.0;
    }
    match backend {
        // ROSA+: one 16-bit joint symbol per position (x * 256 + y).
        RateBackend::RosaPlus => {
            let joint_symbols: Vec<u32> = (0..x.len())
                .map(|i| (x[i] as u32) * 256 + (y[i] as u32))
                .collect();
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.entropy_rate_cps(&joint_symbols)
        }
        // Byte-level predictors: interleave x_0, y_0, x_1, y_1, ...; the
        // per-byte rate is doubled to express bits per (x, y) pair.
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Calibrated { .. } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            entropy_rate_backend(&joint, max_order, backend) * 2.0
        }
        // Neural backends (feature-gated): joint scoring delegated to the
        // model's aligned-min helper; errors are fatal.
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv method joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba method joint-entropy scoring failed: {e:#}"))
        }),
        // ZPAQ: compress the byte-interleaved stream once; dividing total
        // bits by x.len() (the pair count) yields bits per pair.
        RateBackend::Zpaq { method } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(&joint);
            bits / (x.len() as f64)
        }
        // Mixture: stream the interleaved bytes, accumulating code length in
        // bits (the LN_2 division converts step()'s ln-scale scores).
        RateBackend::Mixture { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(joint.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in &joint {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (x.len() as f64)
        }
        // Particle runtime: same interleaved-byte scheme, no stream protocol.
        RateBackend::Particle { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in &joint {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (x.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            // NOTE: CTW interleaves bits: x_0, y_0, x_1, y_1...
            // This estimates the joint entropy H(X,Y) by modeling the sequence
            // of alternating bits. This is a fine-grained joint model but
            // theoretically consistent for estimating joint entropy rate.
            // ROSA uses 16-bit joint symbols (x << 8 | y). Both are valid.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 16);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for bit_idx in 0..8 {
                    // MSB-first extraction: X's bits feed trees 0..8, Y's
                    // bits feed trees 8..16.
                    let bit_x = ((bx >> (7 - bit_idx)) & 1) == 1;
                    let bit_y = ((by >> (7 - bit_idx)) & 1) == 1;
                    fac.update(bit_x, bit_idx);
                    fac.update(bit_y, bit_idx + 8);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            // Joint: interleave x and y bits, use 2*encoding_bits trees
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte * 2);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for i in 0..bits_per_byte {
                    // Tree structure:
                    // bits_per_byte trees for X, bits_per_byte trees for Y.
                    // But we interleave them in the "joint" sense.
                    // Here we map bit i of X to tree 2*i, bit i of Y to tree 2*i + 1
                    let bit_idx_x = i * 2;
                    let bit_idx_y = bit_idx_x + 1;
                    fac.update(((bx >> i) & 1) == 1, bit_idx_x);
                    fac.update(((by >> i) & 1) == 1, bit_idx_y);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
    }
}
2127#[inline(always)]
2128/// Compute compressed size for a file path with an explicit ZPAQ thread count.
2129pub fn get_compressed_size_parallel(path: &str, method: &str, threads: usize) -> u64 {
2130    // Convert Input file to Vec<u8>, and reference that (compress_size only takes &[u8] input), and pass method.
2131    // Will panic if file does not exist, so it must be prevalidated.
2132    zpaq_compress_size_parallel_bytes(&std::fs::read(path).unwrap(), method, threads)
2133}
2134
2135#[inline(always)]
2136/// Read all files in `paths` in parallel and return their byte contents.
2137pub fn get_bytes_from_paths(paths: &[&str]) -> Vec<Vec<u8>> {
2138    paths
2139        .par_iter()
2140        .map(|path| std::fs::read(*path).expect("failed to read file"))
2141        .collect()
2142}
2143
2144/// ------- Bulk File Compression Functions -------
2145#[inline(always)]
2146pub fn get_sequential_compressed_sizes_from_sequential_paths(
2147    paths: &[&str],
2148    method: &str,
2149) -> Vec<u64> {
2150    // This will, in parallel load all files into memory, THEN in parallel compress each one, each with one thread.
2151    // Use when File IO is the bottleneck
2152    // Only uses ONE ZPAQ THREAD.
2153    // For VERY large n (relative to threads) with small files (relative to memory) this may be useful.
2154    get_bytes_from_paths(paths)
2155        .par_iter()
2156        .map(|data| zpaq_compress_size_bytes(data, method))
2157        .collect()
2158}
2159
2160#[inline(always)]
2161/// Compress all paths after preloading bytes, using per-file parallel ZPAQ compression.
2162pub fn get_parallel_compressed_sizes_from_sequential_paths(
2163    paths: &[&str],
2164    method: &str,
2165    threads: usize,
2166) -> Vec<u64> {
2167    // This will, in parallel load all files into memory, THEN in parallel compress each one, with THREADS. (for each file, the thread count is THREADS)
2168    // Use when File IO is the bottleneck.
2169    // Balanced parallelization between RAYON_NUM_THREADS and ZPAQ `THREADS` const. For when total dataset will fit in memory.
2170    get_bytes_from_paths(paths)
2171        .par_iter()
2172        .map(|data| zpaq_compress_size_parallel_bytes(data, method, threads))
2173        .collect()
2174}
2175
2176#[inline(always)]
2177/// Compress all paths directly from disk using single-thread ZPAQ per file.
2178pub fn get_sequential_compressed_sizes_from_parallel_paths(
2179    paths: &[&str],
2180    method: &str,
2181) -> Vec<u64> {
2182    // This will, in parallel, for each file, read it from disk and compress it with one thread. (one file, one thread)
2183    // Use when File IO is not the bottleneck. Lower memory usage. (does not preload dataset)
2184    // Only uses ONE ZPAQ THREAD. For VERY large n(relative to threads) with large files(relative to memory) this may be useful.
2185    paths
2186        .par_iter()
2187        .map(|path| get_compressed_size(path, method))
2188        .collect()
2189}
2190
2191#[inline(always)]
2192/// Compress all paths directly from disk using per-file multi-thread ZPAQ.
2193pub fn get_parallel_compressed_sizes_from_parallel_paths(
2194    paths: &[&str],
2195    method: &str,
2196    threads: usize,
2197) -> Vec<u64> {
2198    // This will, in parallel, for each file, read it from disk and compress it with THREADS. (for each file, the thread count is THREADS)
2199    // Use when File IO is not the bottleneck. Lower memory usage. (does not preload dataset)
2200    // For large n(relative to threads) with VERY large files(relative to memory) this may be useful.
2201    // This will reflect RAYON_NUM_THREADS and THREAD const values.
2202    paths
2203        .par_iter()
2204        .map(|path| get_compressed_size_parallel(path, method, threads))
2205        .collect()
2206}
2207
2208/// Optimizes parallelization
2209#[inline(always)]
2210pub fn get_compressed_sizes_from_paths(paths: &[&str], method: &str) -> Vec<u64> {
2211    let n: usize = paths.len();
2212    let num_threads: usize = *NUM_THREADS.get_or_init(num_cpus::get);
2213    if n < num_threads {
2214        get_parallel_compressed_sizes_from_parallel_paths(paths, method, num_threads.div_ceil(n))
2215    } else {
2216        get_sequential_compressed_sizes_from_parallel_paths(paths, method)
2217    }
2218}
2219
/// ------- NCD (Normalized Compression Distance) ------
///
/// NCD is a parameter-free similarity metric based on Kolmogorov complexity.
/// Since Kolmogorov complexity `K(x)` is uncomputable, we approximate it using
/// the compressed size `C(x)` provided by a real-world compressor (here, ZPAQ).
///
/// The general form is:
/// `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
///
/// Different variants handle normalization and symmetry differently.
/// All variants evaluate to 0.0 when their denominator is zero (see
/// `ncd_from_sizes`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcdVariant {
    /// Standard Vitanyi NCD:
    /// `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
    /// Note: `C(xy)` denotes compressing the concatenation of x and y.
    Vitanyi,
    /// Symmetric Vitanyi NCD:
    /// `NCD_sym(x,y) = (min(C(xy), C(yx)) - min(C(x), C(y))) / max(C(x), C(y))`
    /// Takes the best compression of `xy` or `yx` to ensure symmetry even if the compressor is not symmetric.
    SymVitanyi,
    /// Conservative NCD:
    /// `NCD_cons(x,y) = (C(xy) - min(C(x), C(y))) / C(xy)`
    /// Normalizes by the joint compressed size instead of the max marginal.
    Cons,
    /// Symmetric Conservative NCD:
    /// `NCD_sym_cons(x,y) = (min(C(xy), C(yx)) - min(C(x), C(y))) / min(C(xy), C(yx))`
    SymCons,
}
2248
#[inline(always)]
/// ZPAQ-compressed size of `data` in bytes; thin local alias used by the
/// NCD helpers below.
fn compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_compress_size_bytes(data, method)
}
2253
2254#[inline(always)]
2255fn ncd_from_sizes(cx: u64, cy: u64, cxy: u64, cyx: Option<u64>, variant: NcdVariant) -> f64 {
2256    let min_c = cx.min(cy) as f64;
2257    let max_c = cx.max(cy) as f64;
2258
2259    match variant {
2260        NcdVariant::Vitanyi => {
2261            if max_c == 0.0 {
2262                0.0
2263            } else {
2264                (cxy as f64 - min_c) / max_c
2265            }
2266        }
2267        NcdVariant::SymVitanyi => {
2268            let m = cxy.min(cyx.expect("cyx required for SymVitanyi")) as f64;
2269            if max_c == 0.0 {
2270                0.0
2271            } else {
2272                (m - min_c) / max_c
2273            }
2274        }
2275        NcdVariant::Cons => {
2276            let denom = cxy as f64;
2277            if denom == 0.0 {
2278                0.0
2279            } else {
2280                (cxy as f64 - min_c) / denom
2281            }
2282        }
2283        NcdVariant::SymCons => {
2284            let m = cxy.min(cyx.expect("cyx required for SymCons")) as f64;
2285            if m == 0.0 { 0.0 } else { (m - min_c) / m }
2286        }
2287    }
2288}
2289
2290#[inline(always)]
2291/// Compute NCD for in-memory byte slices using the given ZPAQ `method` and `variant`.
2292pub fn ncd_bytes(x: &[u8], y: &[u8], method: &str, variant: NcdVariant) -> f64 {
2293    let backend = CompressionBackend::Zpaq {
2294        method: method.to_string(),
2295    };
2296    ncd_bytes_backend(x, y, &backend, variant)
2297}
2298
2299/// NCD with bytes using the default context.
2300#[inline(always)]
2301pub fn ncd_bytes_default(x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
2302    with_default_ctx(|ctx| ctx.ncd_bytes(x, y, variant))
2303}
2304
2305/// Compute NCD for in-memory byte slices using an explicit compression `backend`.
2306pub fn ncd_bytes_backend(
2307    x: &[u8],
2308    y: &[u8],
2309    backend: &CompressionBackend,
2310    variant: NcdVariant,
2311) -> f64 {
2312    let (cx, cy) = rayon::join(
2313        || compress_size_backend(x, backend),
2314        || compress_size_backend(y, backend),
2315    );
2316
2317    let cxy = compress_size_chain_backend(&[x, y], backend);
2318
2319    let cyx = match variant {
2320        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
2321            Some(compress_size_chain_backend(&[y, x], backend))
2322        }
2323        _ => None,
2324    };
2325
2326    ncd_from_sizes(cx, cy, cxy, cyx, variant)
2327}
2328
2329#[inline(always)]
2330/// Compute NCD for two file paths using a ZPAQ `method` and `variant`.
2331pub fn ncd_paths(x: &str, y: &str, method: &str, variant: NcdVariant) -> f64 {
2332    let (bx, by) = rayon::join(
2333        || std::fs::read(x).expect("failed to read x"),
2334        || std::fs::read(y).expect("failed to read y"),
2335    );
2336    ncd_bytes(&bx, &by, method, variant)
2337}
2338
2339/// Compute NCD for two file paths using an explicit compression `backend`.
2340pub fn ncd_paths_backend(
2341    x: &str,
2342    y: &str,
2343    backend: &CompressionBackend,
2344    variant: NcdVariant,
2345) -> f64 {
2346    let (bx, by) = rayon::join(
2347        || std::fs::read(x).expect("failed to read x"),
2348        || std::fs::read(y).expect("failed to read y"),
2349    );
2350    ncd_bytes_backend(&bx, &by, backend, variant)
2351}
2352
/// Back-compat convenience wrappers (operate on file paths).
///
/// Standard Vitanyi NCD over two files; see [`NcdVariant::Vitanyi`].
#[inline(always)]
pub fn ncd_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Vitanyi)
}
#[inline(always)]
/// Convenience wrapper for symmetric-Vitanyi NCD on file paths; see
/// [`NcdVariant::SymVitanyi`].
pub fn ncd_sym_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymVitanyi)
}
#[inline(always)]
/// Convenience wrapper for conservative NCD on file paths; see
/// [`NcdVariant::Cons`].
pub fn ncd_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Cons)
}
#[inline(always)]
/// Convenience wrapper for symmetric-conservative NCD on file paths; see
/// [`NcdVariant::SymCons`].
pub fn ncd_sym_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymCons)
}
2373
/// Computes an NCD matrix (row-major, len = n*n) for in-memory byte blobs.
///
/// Note: For symmetric variants, this computes each unordered pair once and writes both (i,j) and (j,i).
///
/// Marginal sizes C(x_i) are computed once up front. Joint sizes are computed
/// per pair, reusing a per-task scratch buffer for the concatenation. The
/// diagonal is 0.0: explicitly for the asymmetric variants, via the
/// zero-initialized output for the symmetric ones.
pub fn ncd_matrix_bytes(datas: &[Vec<u8>], method: &str, variant: NcdVariant) -> Vec<f64> {
    let n = datas.len();
    // C(x_i) for every blob, in parallel.
    let cx: Vec<u64> = datas
        .par_iter()
        .map(|d| compress_size_bytes(d, method))
        .collect();

    let mut out = vec![0.0f64; n * n];
    // Raw-pointer escape hatch: rayon tasks write disjoint cells of `out`
    // without a lock. AtomicPtr only serves to make the pointer shareable
    // across tasks; the pointer value itself never changes.
    let out_ptr = std::sync::atomic::AtomicPtr::new(out.as_mut_ptr());

    match variant {
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            // Upper triangle only: each unordered pair {i, j} exactly once.
            (0..n)
                .into_par_iter()
                .flat_map_iter(|i| (i + 1..n).map(move |j| (i, j)))
                .for_each_init(Vec::<u8>::new, |buf, (i, j)| {
                    let x = &datas[i];
                    let y = &datas[j];

                    // C(xy): concatenate into the reused scratch buffer.
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(x);
                    buf.extend_from_slice(y);
                    let cxy = compress_size_bytes(buf, method);

                    // C(yx): reversed order, needed by the symmetric formulas.
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(y);
                    buf.extend_from_slice(x);
                    let cyx = compress_size_bytes(buf, method);

                    let d = ncd_from_sizes(cx[i], cx[j], cxy, Some(cyx), variant);

                    // Safety: each (i,j) cell is written exactly once across all iterations.
                    let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                    unsafe {
                        *p.add(i * n + j) = d;
                        *p.add(j * n + i) = d;
                    }
                });
        }
        NcdVariant::Vitanyi | NcdVariant::Cons => {
            // Full matrix: C(xy) is order-sensitive for these variants, so
            // (i,j) and (j,i) are computed independently. Rows run in
            // parallel; columns within a row are sequential.
            (0..n)
                .into_par_iter()
                .for_each_init(Vec::<u8>::new, |buf, i| {
                    let x = &datas[i];
                    for j in 0..n {
                        let d = if i == j {
                            0.0
                        } else {
                            let y = &datas[j];
                            buf.clear();
                            buf.reserve(x.len() + y.len());
                            buf.extend_from_slice(x);
                            buf.extend_from_slice(y);
                            let cxy = compress_size_bytes(buf, method);
                            ncd_from_sizes(cx[i], cx[j], cxy, None, variant)
                        };

                        // Safety: row i is written only by this task, so no
                        // two tasks touch the same cell.
                        let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                        unsafe {
                            *p.add(i * n + j) = d;
                        }
                    }
                });
        }
    }

    out
}
2447
2448/// Computes an NCD matrix (row-major, len = n*n) for files (preloads all files into memory once).
2449pub fn ncd_matrix_paths(paths: &[&str], method: &str, variant: NcdVariant) -> Vec<f64> {
2450    let datas = get_bytes_from_paths(paths);
2451    ncd_matrix_bytes(&datas, method, variant)
2452}
2453
2454// ============================================================
2455// Entropy-Based Distance Primitives (via ROSA)
2456// ============================================================
2457//
2458// These use ROSA's Witten-Bell language model to estimate entropy
2459// and compute information-theoretic distances.
2460
/// Compute marginal (Shannon) entropy H(X) = −Σ p(x) log₂ p(x) in bits/symbol.
///
/// Treats `data` as an i.i.d. bag of bytes and evaluates the first-order
/// entropy of its byte histogram — NOT the context-conditional entropy rate
/// from a language model. Empty input yields 0.0.
#[inline(always)]
pub fn marginal_entropy_bytes(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    // Byte histogram over the full input.
    let mut histogram = [0u64; 256];
    data.iter().for_each(|&b| histogram[b as usize] += 1);

    // Sum −p·log₂(p) over the occupied bins only.
    let total = data.len() as f64;
    histogram
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| {
            let p = count as f64 / total;
            -(p * p.log2())
        })
        .sum()
}
2486
2487/// Compute entropy rate `Ĥ(X)` in bits/symbol using ROSA LM.
2488///
2489/// This uses ROSA's context-conditional Witten-Bell model to estimate
2490/// the entropy rate, which accounts for sequential dependencies.
2491///
2492/// The estimator is **prequential** (predictive sequential): it sums the negative log-probability
2493/// of each symbol `x_t` given its past context `x_{<t}`, estimated from the model trained on `x_{<t}`.
2494///
2495/// `Ĥ(X) = -1/N * Σ log2 P(x_t | x_{t-k}^{t-1})`
2496///
2497/// For i.i.d. data, this should approximately equal `marginal_entropy_bytes`.
2498///
2499/// * `max_order`: Maximum context order for the suffix automaton LM.
2500///   A value of -1 means unlimited context (bounded only by memory/sequence length).
2501#[inline(always)]
2502pub fn entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
2503    with_default_ctx(|ctx| ctx.entropy_rate_bytes(data, max_order))
2504}
2505
2506/// Compute biased entropy rate Ĥ_biased(X) bits per symbol.
2507///
2508/// This uses the full plugin estimator (training on the whole text, then scoring the same text).
2509/// While biased as a source entropy estimate, it is mathematically consistent for
2510/// similarity metrics like Mutual Information and NED.
2511#[inline(always)]
2512pub fn biased_entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
2513    with_default_ctx(|ctx| ctx.biased_entropy_rate_bytes(data, max_order))
2514}
2515
2516/// Compute joint marginal entropy H(X,Y) = −Σ p(x,y) log₂ p(x,y) in bits/symbol-pair.
2517///
2518/// Uses a direct histogram of (x_i, y_i) pairs. This is the exact first-order
2519/// joint entropy, matching the spec.md definition.
2520#[inline(always)]
2521pub fn joint_marginal_entropy_bytes(x: &[u8], y: &[u8]) -> f64 {
2522    let (x, y) = aligned_prefix(x, y);
2523    let n = x.len();
2524    if n == 0 {
2525        return 0.0;
2526    }
2527
2528    // Count pair occurrences using a HashMap for (x, y) pairs
2529    // There are up to 65536 possible pairs, so we can use a flat array
2530    let mut counts = vec![0u64; 256 * 256];
2531    for i in 0..n {
2532        let pair_idx = (x[i] as usize) * 256 + (y[i] as usize);
2533        counts[pair_idx] += 1;
2534    }
2535
2536    let n_f64 = n as f64;
2537    let mut h = 0.0f64;
2538    for &c in &counts {
2539        if c > 0 {
2540            let p = c as f64 / n_f64;
2541            h -= p * p.log2();
2542        }
2543    }
2544    h
2545}
2546
2547/// Compute joint entropy rate `Ĥ(X,Y)`.
2548///
2549/// Dispatches based on `max_order`:
2550/// - `max_order == 0`: Strictly aligned pair-symbol mapping (Marginal Joint Entropy).
2551///   Treats `(x_i, y_i)` as a single symbol in a product alphabet `Σ_X × Σ_Y`.
2552/// - `max_order != 0`: Shift-invariant algorithmic joint entropy approximated via ROSA.
2553///   Constructs a sequence of pair-symbols and estimates the entropy rate of that sequence.
2554///
2555/// **Note**: This is an *aligned* joint entropy-rate estimate over time-indexed pairs
2556/// `(x_i, y_i)`. All joint-based quantities (`H(X)`, `H(Y)`, `H(X,Y)`, `I`, NED, NTE, etc.)
2557/// should be computed over the same aligned sample.
2558#[inline(always)]
2559pub fn joint_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2560    with_default_ctx(|ctx| ctx.joint_entropy_rate_bytes(x, y, max_order))
2561}
2562
2563/// Compute conditional entropy rate `Ĥ(X|Y)`.
2564///
2565/// Dispatches based on `max_order`:
2566/// - `max_order == 0`: Strictly aligned `H(X,Y) - H(Y)` using marginals.
2567/// - `max_order != 0`: Chain rule definition `Ĥ(X|Y) = Ĥ(X,Y) - Ĥ(Y)`.
2568///
2569/// Note: This relies on the identity `H(X|Y) = H(X,Y) - H(Y)`.
2570#[inline(always)]
2571pub fn conditional_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2572    with_default_ctx(|ctx| ctx.conditional_entropy_rate_bytes(x, y, max_order))
2573}
2574
2575/// Compute conditional entropy H(X|Y) = H(X,Y) − H(Y)
2576///
2577/// Dispatches based on `max_order`.
2578#[inline(always)]
2579pub fn conditional_entropy_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2580    with_default_ctx(|ctx| ctx.conditional_entropy_bytes(x, y, max_order))
2581}
2582
2583/// Compute mutual information `I(X;Y) = H(X) + H(Y) - H(X,Y)`.
2584///
2585/// Dispatches based on `max_order`. If 0, uses marginals; else uses rates.
2586///
2587/// `I(X;Y) = Σ p(x,y) log(p(x,y) / (p(x)p(y)))`
2588#[inline(always)]
2589pub fn mutual_information_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2590    with_default_ctx(|ctx| ctx.mutual_information_bytes(x, y, max_order))
2591}
2592
2593/// Marginal Mutual Information (exact/histogram)
2594pub fn mutual_information_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2595    let (x, y) = aligned_prefix(x, y);
2596    let h_x = marginal_entropy_bytes(x);
2597    let h_y = marginal_entropy_bytes(y);
2598    let h_xy = joint_marginal_entropy_bytes(x, y);
2599    (h_x + h_y - h_xy).max(0.0)
2600}
2601
2602/// Entropy Rate Mutual Information (ROSA predictive)
2603#[inline(always)]
2604pub fn mutual_information_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2605    with_default_ctx(|ctx| ctx.mutual_information_rate_bytes(x, y, max_order))
2606}
2607
2608// ====== NED: Normalized Entropy Distance ======
2609//
2610// A metric distance based on the overlap of information between two variables.
2611
2612/// NED(X,Y) = (H(X,Y) - min(H(X), H(Y))) / max(H(X), H(Y))
2613///
2614/// Dispatches based on `max_order`. If 0, uses marginals; else uses rates.
2615///
2616/// Range: [0, 1].
2617/// * 0: Identity (X determines Y and Y determines X).
2618/// * 1: Independence (X and Y share no information).
2619#[inline(always)]
2620pub fn ned_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2621    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
2622}
2623
2624/// Marginal NED (exact/histogram)
2625pub fn ned_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2626    let (x, y) = aligned_prefix(x, y);
2627    let h_x = marginal_entropy_bytes(x);
2628    let h_y = marginal_entropy_bytes(y);
2629    let h_xy = joint_marginal_entropy_bytes(x, y);
2630    let min_h = h_x.min(h_y);
2631    let max_h = h_x.max(h_y);
2632    if max_h == 0.0 {
2633        0.0
2634    } else {
2635        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
2636    }
2637}
2638
2639/// Normalized Entropy Distance (Rate-based)
2640#[inline(always)]
2641pub fn ned_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2642    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
2643}
2644
2645/// NED_cons(X,Y) = (H(X,Y) - min(H(X), H(Y))) / H(X,Y)
2646///
2647/// Conservative variant. Dispatches based on `max_order`.
2648#[inline(always)]
2649pub fn ned_cons_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2650    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
2651}
2652
2653/// Conservative marginal NED using histogram entropy estimates.
2654pub fn ned_cons_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2655    let h_x = marginal_entropy_bytes(x);
2656    let h_y = marginal_entropy_bytes(y);
2657    let h_xy = joint_marginal_entropy_bytes(x, y);
2658    let min_h = h_x.min(h_y);
2659    if h_xy == 0.0 {
2660        0.0
2661    } else {
2662        ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
2663    }
2664}
2665
2666#[inline(always)]
2667/// Conservative rate NED using the current default context backend.
2668pub fn ned_cons_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2669    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
2670}
2671
2672// ====== NTE: Normalized Transform Effort (Variation of Information) ======
2673
2674/// NTE(X,Y) = VI(X,Y) / max(H(X), H(Y))
2675/// where `VI(X,Y) = H(X|Y) + H(Y|X) = 2H(X,Y) - H(X) - H(Y)`.
2676///
2677/// Represents the "effort" required to transform X into Y (and vice versa) relative
2678/// to their complexity.
2679///
2680/// Note: VI can be as large as `H(X) + H(Y)`. If `H(X) ≈ H(Y)`, then VI can be `≈ 2 max(H(X), H(Y))`.
2681/// Thus, NTE is in [0, 2].
2682/// * Values near 0 indicate near-identity.
2683/// * Values near 1+ indicate substantial effort/transform cost (e.g. independence).
2684///
2685/// Dispatches based on `max_order`.
2686#[inline(always)]
2687pub fn nte_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2688    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
2689}
2690
2691/// Marginal NTE using histogram entropy estimates.
2692pub fn nte_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2693    let (x, y) = aligned_prefix(x, y);
2694    let h_x = marginal_entropy_bytes(x);
2695    let h_y = marginal_entropy_bytes(y);
2696    let h_xy = joint_marginal_entropy_bytes(x, y);
2697    let vi = 2.0 * h_xy - h_x - h_y;
2698    let max_h = h_x.max(h_y);
2699    if max_h == 0.0 {
2700        0.0
2701    } else {
2702        (vi / max_h).clamp(0.0, 2.0)
2703    }
2704}
2705
2706#[inline(always)]
2707/// Rate NTE using the current default context backend.
2708pub fn nte_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
2709    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
2710}
2711
2712// ====== TVD: Total Variation Distance ======
2713
/// Marginal byte distribution: `probs[i] = count(i) / N` for `i ∈ [0, 255]`.
/// Empty input yields the all-zero array (no division by zero).
#[inline(always)]
fn byte_histogram(data: &[u8]) -> [f64; 256] {
    let mut probs = [0.0f64; 256];
    if data.is_empty() {
        return probs;
    }
    let mut counts = [0u64; 256];
    for &b in data {
        counts[usize::from(b)] += 1;
    }
    let total = data.len() as f64;
    for (slot, &c) in probs.iter_mut().zip(counts.iter()) {
        *slot = c as f64 / total;
    }
    probs
}
2731
2732/// TVD_marg(X,Y) = (1/2) Σᵢ |p_X(i) - p_Y(i)|
2733///
2734/// Total Variation Distance over marginal byte distributions.
2735/// True metric on probability space. Range: [0, 1].
2736/// 0 = identical distributions, 1 = completely disjoint support.
2737#[inline(always)]
2738pub fn tvd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2739    if x.is_empty() || y.is_empty() {
2740        return 0.0;
2741    }
2742    let p_x = byte_histogram(x);
2743    let p_y = byte_histogram(y);
2744
2745    let mut sum = 0.0f64;
2746    for i in 0..256 {
2747        sum += (p_x[i] - p_y[i]).abs();
2748    }
2749
2750    (sum / 2.0).clamp(0.0, 1.0)
2751}
2752
2753// ====== NHD: Normalized Hellinger Distance ======
2754
2755/// NHD(X,Y) = sqrt(1 - BC(X,Y)) where BC = Σᵢ sqrt(p_X(i) · p_Y(i))
2756///
2757/// Normalized Hellinger Distance over marginal byte distributions.
2758/// True metric. Range: [0, 1]. 0 = identical, 1 = disjoint support.
2759#[inline(always)]
2760pub fn nhd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2761    if x.is_empty() || y.is_empty() {
2762        return 0.0;
2763    }
2764    let p_x = byte_histogram(x);
2765    let p_y = byte_histogram(y);
2766
2767    // Bhattacharyya coefficient: BC = Σᵢ sqrt(p_X(i) · p_Y(i))
2768    let mut bc = 0.0f64;
2769    for i in 0..256 {
2770        bc += (p_x[i] * p_y[i]).sqrt();
2771    }
2772
2773    // NHD = sqrt(1 - BC)
2774    (1.0 - bc).max(0.0).sqrt()
2775}
2776
2777// ====== Other Information-Theoretic Measures ======
2778
2779/// Compute cross-entropy H_{train}(test) - score test_data under model trained on train_data.
2780///
2781/// Dispatches based on `max_order`.
2782#[inline(always)]
2783pub fn cross_entropy_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
2784    with_default_ctx(|ctx| ctx.cross_entropy_bytes(test_data, train_data, max_order))
2785}
2786
2787/// Compute cross-entropy rate using ROSA/CTW/RWKV.
2788/// Training model on `train_data` and evaluating probability of `test_data`.
2789#[inline(always)]
2790pub fn cross_entropy_rate_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
2791    with_default_ctx(|ctx| ctx.cross_entropy_rate_bytes(test_data, train_data, max_order))
2792}
2793
2794/// Generate a continuation from `prompt`
2795/// using the current default context and [`GenerationConfig::default()`].
2796///
2797/// The default is deterministic frozen sampling with seed `42`.
2798#[inline(always)]
2799pub fn generate_bytes(prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
2800    with_default_ctx(|ctx| ctx.generate_bytes(prompt, bytes, max_order))
2801}
2802
2803/// Generate a continuation from `prompt` using the current default context.
2804#[inline(always)]
2805pub fn generate_bytes_with_config(
2806    prompt: &[u8],
2807    bytes: usize,
2808    max_order: i64,
2809    config: GenerationConfig,
2810) -> Vec<u8> {
2811    with_default_ctx(|ctx| ctx.generate_bytes_with_config(prompt, bytes, max_order, config))
2812}
2813
2814/// Generate a continuation after conditioning on an explicit chain of prefix parts
2815/// using the current default context and [`GenerationConfig::default()`].
2816#[inline(always)]
2817pub fn generate_bytes_conditional_chain(
2818    prefix_parts: &[&[u8]],
2819    bytes: usize,
2820    max_order: i64,
2821) -> Vec<u8> {
2822    with_default_ctx(|ctx| ctx.generate_bytes_conditional_chain(prefix_parts, bytes, max_order))
2823}
2824
2825/// Generate a continuation after conditioning on an explicit chain of prefix parts
2826/// using the current default context.
2827#[inline(always)]
2828pub fn generate_bytes_conditional_chain_with_config(
2829    prefix_parts: &[&[u8]],
2830    bytes: usize,
2831    max_order: i64,
2832    config: GenerationConfig,
2833) -> Vec<u8> {
2834    with_default_ctx(|ctx| {
2835        ctx.generate_bytes_conditional_chain_with_config(prefix_parts, bytes, max_order, config)
2836    })
2837}
2838
2839/// Kullback-Leibler Divergence D_KL(P || Q) = Σ p(x) log(p(x) / q(x))
2840///
2841/// Marginal only. Measure of how one probability distribution is different from a second.
2842pub fn d_kl_bytes(x: &[u8], y: &[u8]) -> f64 {
2843    if x.is_empty() || y.is_empty() {
2844        return 0.0;
2845    }
2846    let p_x = byte_histogram(x);
2847    let p_y = byte_histogram(y);
2848    let mut d_kl = 0.0f64;
2849    for i in 0..256 {
2850        if p_x[i] > 0.0 {
2851            let q_y = p_y[i].max(1e-12);
2852            d_kl += p_x[i] * (p_x[i] / q_y).log2();
2853        }
2854    }
2855    d_kl.max(0.0)
2856}
2857
2858/// Jensen-Shannon Divergence JSD(P || Q) = 1/2 D_KL(P || M) + 1/2 D_KL(Q || M)
2859/// where M = 1/2 (P + Q)
2860///
2861/// Marginal only. Symmetrized and smoothed version of KL divergence. Range `[0,1]`.
2862pub fn js_div_bytes(x: &[u8], y: &[u8]) -> f64 {
2863    if x.is_empty() || y.is_empty() {
2864        return 0.0;
2865    }
2866    let p_x = byte_histogram(x);
2867    let p_y = byte_histogram(y);
2868    let mut m = [0.0f64; 256];
2869    for i in 0..256 {
2870        m[i] = 0.5 * (p_x[i] + p_y[i]);
2871    }
2872
2873    let mut kl_pm = 0.0f64;
2874    let mut kl_qm = 0.0f64;
2875    for i in 0..256 {
2876        if p_x[i] > 0.0 {
2877            kl_pm += p_x[i] * (p_x[i] / m[i]).log2();
2878        }
2879        if p_y[i] > 0.0 {
2880            kl_qm += p_y[i] * (p_y[i] / m[i]).log2();
2881        }
2882    }
2883    (0.5 * kl_pm + 0.5 * kl_qm).max(0.0)
2884}
2885
2886// ====== Path-based convenience wrappers ======
2887
2888/// NED for files.
2889pub fn ned_paths(x: &str, y: &str, max_order: i64) -> f64 {
2890    let (bx, by) = rayon::join(
2891        || std::fs::read(x).expect("failed to read x"),
2892        || std::fs::read(y).expect("failed to read y"),
2893    );
2894    ned_bytes(&bx, &by, max_order)
2895}
2896
2897/// NTE for files.
2898pub fn nte_paths(x: &str, y: &str, max_order: i64) -> f64 {
2899    let (bx, by) = rayon::join(
2900        || std::fs::read(x).expect("failed to read x"),
2901        || std::fs::read(y).expect("failed to read y"),
2902    );
2903    nte_bytes(&bx, &by, max_order)
2904}
2905
2906/// TVD for files.
2907pub fn tvd_paths(x: &str, y: &str, max_order: i64) -> f64 {
2908    let (bx, by) = rayon::join(
2909        || std::fs::read(x).expect("failed to read x"),
2910        || std::fs::read(y).expect("failed to read y"),
2911    );
2912    tvd_bytes(&bx, &by, max_order)
2913}
2914
2915/// NHD for files.
2916pub fn nhd_paths(x: &str, y: &str, max_order: i64) -> f64 {
2917    let (bx, by) = rayon::join(
2918        || std::fs::read(x).expect("failed to read x"),
2919        || std::fs::read(y).expect("failed to read y"),
2920    );
2921    nhd_bytes(&bx, &by, max_order)
2922}
2923
2924/// Mutual Information for files.
2925pub fn mutual_information_paths(x: &str, y: &str, max_order: i64) -> f64 {
2926    let (bx, by) = rayon::join(
2927        || std::fs::read(x).expect("failed to read x"),
2928        || std::fs::read(y).expect("failed to read y"),
2929    );
2930    mutual_information_bytes(&bx, &by, max_order)
2931}
2932
2933/// Conditional Entropy for files.
2934pub fn conditional_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
2935    let (bx, by) = rayon::join(
2936        || std::fs::read(x).expect("failed to read x"),
2937        || std::fs::read(y).expect("failed to read y"),
2938    );
2939    conditional_entropy_bytes(&bx, &by, max_order)
2940}
2941
2942/// Cross-Entropy for files.
2943pub fn cross_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
2944    let (bx, by) = rayon::join(
2945        || std::fs::read(x).expect("failed to read x"),
2946        || std::fs::read(y).expect("failed to read y"),
2947    );
2948    cross_entropy_bytes(&bx, &by, max_order)
2949}
2950
2951/// KL Divergence for files.
2952pub fn kl_divergence_paths(x: &str, y: &str) -> f64 {
2953    let (bx, by) = rayon::join(
2954        || std::fs::read(x).expect("failed to read x"),
2955        || std::fs::read(y).expect("failed to read y"),
2956    );
2957    d_kl_bytes(&bx, &by)
2958}
2959
2960/// Jensen-Shannon Divergence for files.
2961pub fn js_divergence_paths(x: &str, y: &str) -> f64 {
2962    let (bx, by) = rayon::join(
2963        || std::fs::read(x).expect("failed to read x"),
2964        || std::fs::read(y).expect("failed to read y"),
2965    );
2966    js_div_bytes(&bx, &by)
2967}
2968
2969// ====== Primitives 6 & 7 ======
2970
2971/// Primitive 6: Intrinsic Dependence (Redundancy Ratio).
2972///
2973/// Measures how much structure is intrinsic to the sample, relative to its
2974/// own marginal entropy baseline.
2975///
2976/// `R = (H_marginal - H_rate) / H_marginal`
2977///
2978/// Clamped to `[0,1]`.
2979///
2980/// Interpretation:
2981///   - `R → 0`: Data is close to i.i.d./max-entropy (little intrinsic structure; highly extrinsically explainable by priors).
2982///   - `R → 1`: Data is highly predictable from its own past (strong intrinsic dependence; e.g., periodic strings like 010101...).
2983#[inline(always)]
2984pub fn intrinsic_dependence_bytes(data: &[u8], max_order: i64) -> f64 {
2985    with_default_ctx(|ctx| ctx.intrinsic_dependence_bytes(data, max_order))
2986}
2987
2988/// Primitive 7: Resistance under Allowed Transformations.
2989///
2990/// Measures how much information is preserved after a transformation `T` is applied to `X`.
2991///
2992/// `Resistance(X, T) = I(X; T(X)) / H(X)`
2993///
2994/// Range `[0,1]` (with guard for `H(X)=0`).
2995/// * 1 means perfectly resistant (identity transformation).
2996/// * 0 means the transformation destroyed all information (e.g. mapping everything to a constant).
2997///
2998/// Assumes X and T(X) are aligned.
2999#[inline(always)]
3000pub fn resistance_to_transformation_bytes(x: &[u8], tx: &[u8], max_order: i64) -> f64 {
3001    with_default_ctx(|ctx| ctx.resistance_to_transformation_bytes(x, tx, max_order))
3002}
3003
3004#[cfg(test)]
3005mod tests {
3006    use super::*;
3007
    // Small Match-model backend fixture shared by the tests below (kept tiny to
    // keep test runtime low).
    fn test_match_backend() -> RateBackend {
        RateBackend::Match {
            hash_bits: 12,
            min_len: 2,
            max_len: 16,
            base_mix: 0.01,
            confidence_scale: 1.0,
        }
    }
3017
    // Tiny PPMd backend fixture (low order, 1 MB of model memory) for fast tests.
    fn test_ppmd_backend() -> RateBackend {
        RateBackend::Ppmd {
            order: 4,
            memory_mb: 1,
        }
    }
3024
    // Calibrated backend fixture: wraps the Match fixture with a text-context
    // probability calibrator (16 bins, small learning rate, clipped bias).
    fn test_calibrated_backend() -> RateBackend {
        RateBackend::Calibrated {
            spec: Arc::new(CalibratedSpec {
                base: test_match_backend(),
                context: CalibrationContextKind::Text,
                bins: 16,
                learning_rate: 0.05,
                bias_clip: 4.0,
            }),
        }
    }
3036
    // Bayes-mixture backend fixture combining the Match and PPMd fixtures with
    // equal log-priors and unbounded per-expert context order.
    fn test_mixture_backend() -> RateBackend {
        RateBackend::Mixture {
            spec: Arc::new(MixtureSpec::new(
                MixtureKind::Bayes,
                vec![
                    MixtureExpertSpec {
                        name: Some("match".to_string()),
                        log_prior: 0.0,
                        max_order: -1,
                        backend: test_match_backend(),
                    },
                    MixtureExpertSpec {
                        name: Some("ppmd".to_string()),
                        log_prior: 0.0,
                        max_order: -1,
                        backend: test_ppmd_backend(),
                    },
                ],
            )),
        }
    }
3058
    // Particle backend fixture with deliberately small dimensions; any field not
    // listed falls back to `ParticleSpec::default()`.
    fn test_particle_backend() -> RateBackend {
        RateBackend::Particle {
            spec: Arc::new(ParticleSpec {
                num_particles: 4,
                num_cells: 4,
                cell_dim: 8,
                num_rules: 2,
                selector_hidden: 16,
                rule_hidden: 16,
                context_window: 8,
                unroll_steps: 1,
                ..ParticleSpec::default()
            }),
        }
    }
3074
    // Highly patterned prompt for generation tests: repeated "If X is COLOR,
    // Y are COLOR" lines ending mid-pattern, so a sequence model has a strong
    // signal for the continuation.
    fn continuation_prompt() -> &'static [u8] {
        b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\nIf a dog is green, frogs are red.\nIf a cat is green, toads are red.\nIf a frog is red, dogs are green.\nIf a toad is red, cats are green.\nIf a dog is red, frogs are green.\nIf a cat is red, toads are \n"
    }
3078
    // Generates twice from the same prompt with the default config and asserts
    // the two outputs are identical and of the requested length. `label` is
    // interpolated into failure messages to identify the backend under test.
    fn assert_deterministic_generate_for_backend(
        backend: RateBackend,
        max_order: i64,
        bytes: usize,
        label: &str,
    ) {
        let prompt = continuation_prompt();
        let a = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        let b = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        assert_eq!(
            a, b,
            "{label} generation should be deterministic for identical input"
        );
        assert_eq!(
            a.len(),
            bytes,
            "{label} generation should emit requested byte count"
        );
    }
3110
    // Same as `assert_deterministic_generate_for_backend`, but with explicit
    // frozen sampling (fixed seed 42): sampled generation must still be
    // reproducible for a fixed seed and emit exactly `bytes` bytes.
    fn assert_sampled_generate_for_backend(
        backend: RateBackend,
        max_order: i64,
        bytes: usize,
        label: &str,
    ) {
        let prompt = continuation_prompt();
        let config = GenerationConfig::sampled_frozen(42);
        let a = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
        let b = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
        assert_eq!(
            a, b,
            "{label} sampled generation should be deterministic for a fixed seed"
        );
        assert_eq!(
            a.len(),
            bytes,
            "{label} sampled generation should emit requested byte count"
        );
    }
3131
    // NCD(x, x) should be ~0 and never meaningfully negative; allow a small
    // epsilon for compressor overhead. Requires the ZPAQ backend feature.
    #[cfg(feature = "backend-zpaq")]
    #[test]
    fn ncd_basic_identity_nonnegative() {
        let x = b"abcdabcdabcd";
        let d = ncd_bytes(x, x, "5", NcdVariant::Vitanyi);
        assert!(d >= -1e-9);
    }
3139
    // For identical inputs the exact marginal estimators must satisfy the
    // Shannon identities to floating-point precision:
    // H(X,Y) = H(X), H(X|Y) = 0, I(X;Y) = H(X), NED = 0, NTE = 0.
    #[test]
    fn shannon_identities_marginal_aligned() {
        let x = b"abracadabra";
        let y = b"abracadabra";

        let h = marginal_entropy_bytes(x);
        let mi = mutual_information_bytes(x, y, 0);
        let h_xy = joint_marginal_entropy_bytes(x, y);
        let h_x_given_y = conditional_entropy_bytes(x, y, 0);
        let ned = ned_bytes(x, y, 0);
        let nte = nte_bytes(x, y, 0);

        assert!((h_xy - h).abs() < 1e-12);
        assert!(h_x_given_y.abs() < 1e-12);
        assert!((mi - h).abs() < 1e-12);
        assert!(ned.abs() < 1e-12);
        assert!(nte.abs() < 1e-12);
    }
3158
    // Rate-based analogue of the marginal identity test: with identical inputs
    // the estimators should approximately satisfy H(X,Y) ≈ H(X), H(X|Y) ≈ 0,
    // I ≈ H(X), NED ≈ 0 — within a loose finite-sample tolerance.
    #[test]
    fn shannon_identities_rate_aligned_reasonable() {
        let x = b"the quick brown fox jumps over the lazy dog";
        let y = b"the quick brown fox jumps over the lazy dog";
        let max_order = 8;
        // Swap in a RosaPlus context, saving the previous one so it can be
        // restored at the end (the default context is shared mutable state).
        let prev = get_default_ctx();
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::RosaPlus,
            CompressionBackend::default(),
        ));

        let h_x = entropy_rate_bytes(x, max_order);
        let h_xy = joint_entropy_rate_bytes(x, y, max_order);
        let h_x_given_y = conditional_entropy_rate_bytes(x, y, max_order);
        let mi = mutual_information_bytes(x, y, max_order);
        let ned = ned_bytes(x, y, max_order);

        // Finite-sample estimators won't be exact; allow reasonable tolerance.
        let tol = 0.2;
        assert!((h_xy - h_x).abs() < tol);
        assert!(h_x_given_y < tol);
        assert!((mi - h_x).abs() < tol);
        assert!(ned < tol);
        set_default_ctx(prev);
    }
3184
    // The identity transformation must be perfectly resistant:
    // Resistance(X, X) = I(X;X)/H(X) = 1, both marginally (max_order = 0,
    // exact) and via rates (max_order = 8, looser tolerance).
    #[test]
    fn resistance_identity_is_one() {
        let x = b"some repeated repeated repeated text";
        // Save/restore the global default context around the backend swap.
        let prev = get_default_ctx();
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::RosaPlus,
            CompressionBackend::default(),
        ));
        let r0 = resistance_to_transformation_bytes(x, x, 0);
        let r8 = resistance_to_transformation_bytes(x, x, 8);
        assert!((r0 - 1.0).abs() < 1e-12);
        assert!((r8 - 1.0).abs() < 1e-6);
        set_default_ctx(prev);
    }
3199
    // Every marginal distance/divergence has an explicit empty-input guard that
    // returns exactly 0.0 (no NaN, no division by zero), in either argument slot.
    #[test]
    fn marginal_metrics_empty_inputs_are_zero() {
        let empty: &[u8] = &[];
        let x = b"abc";

        assert_eq!(tvd_bytes(empty, x, 0), 0.0);
        assert_eq!(tvd_bytes(x, empty, 0), 0.0);
        assert_eq!(nhd_bytes(empty, x, 0), 0.0);
        assert_eq!(nhd_bytes(x, empty, 0), 0.0);
        assert_eq!(d_kl_bytes(empty, x), 0.0);
        assert_eq!(d_kl_bytes(x, empty), 0.0);
        assert_eq!(js_div_bytes(empty, x), 0.0);
        assert_eq!(js_div_bytes(x, empty), 0.0);
    }
3214
    // Cross-entropy of an empty test sequence is 0 by convention
    // (nothing to score), regardless of the training data.
    #[test]
    fn marginal_cross_entropy_empty_test_is_zero() {
        let empty: &[u8] = &[];
        let y = b"abc";
        let ctx = InfotheoryCtx::with_zpaq("5");
        assert_eq!(ctx.cross_entropy_bytes(empty, y, 0), 0.0);
    }
3222
    // When the ZPAQ feature is compiled out, explicitly requesting the ZPAQ
    // compression backend must panic with a clear message rather than silently
    // substituting another compressor.
    #[cfg(not(feature = "backend-zpaq"))]
    #[test]
    #[should_panic(expected = "CompressionBackend::Zpaq is unavailable")]
    fn explicit_zpaq_backend_fails_loudly() {
        let backend = CompressionBackend::Zpaq {
            method: "5".to_string(),
        };
        let _ = compress_size_backend(b"abc", &backend);
    }
3232
    // Without the ZPAQ feature, the *default* compression backend should
    // transparently fall back to rate-model coding (arithmetic coder, raw
    // framing) and still produce a positive compressed size.
    #[cfg(not(feature = "backend-zpaq"))]
    #[test]
    fn default_compression_backend_falls_back_to_rate_coding() {
        let backend = CompressionBackend::default();
        assert!(matches!(
            &backend,
            CompressionBackend::Rate {
                coder: crate::coders::CoderType::AC,
                framing: crate::compression::FramingMode::Raw,
                ..
            }
        ));
        assert!(compress_size_backend(b"abc", &backend) > 0);
    }
3247
    // Switching the global default context to CTW and back must take effect:
    // CTW yields a valid (positive) entropy rate, and restoring the default
    // context reproduces the original ROSA estimate exactly.
    //
    // NOTE(review): this resets with `InfotheoryCtx::default()` instead of
    // saving/restoring the previous context (as other tests here do), and it
    // mutates shared default-context state — confirm tests using the default
    // context are isolated (e.g. thread-local ctx or serial execution).
    #[test]
    fn backend_switching_test() {
        let x = b"hello world context";

        // Default is RosaPlus
        let h_rosa = entropy_rate_bytes(x, 8);

        // Switch to CTW
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 16 },
            CompressionBackend::default(),
        ));

        let h_ctw = entropy_rate_bytes(x, 8);

        // They should generally be different, but most importantly, CTW worked
        assert!(h_ctw > 0.0);

        // Reset to default
        set_default_ctx(InfotheoryCtx::default());
        let h_rosa_back = entropy_rate_bytes(x, 8);
        assert!((h_rosa - h_rosa_back).abs() < 1e-12);
    }
3271
    // Exercises the raw CTW `ContextTree` directly (no entropy wrapper):
    // predictions must be valid from the very first symbol, and the block
    // log-probability must be a finite negative number after updates.
    #[test]
    fn ctw_early_updates_work() {
        // Test that CTW produces valid predictions from the very start,
        // not just after `depth` symbols have been processed.
        use crate::ctw::ContextTree;

        let mut tree = ContextTree::new(16);

        // Even the first prediction should be valid (not NaN, not 0)
        let p0 = tree.predict(false);
        let p1 = tree.predict(true);

        // Initial KT estimator gives 0.5 / 1 = 0.5 for each symbol
        assert!((p0 - 0.5).abs() < 1e-10, "p0 should be ~0.5, got {}", p0);
        assert!((p1 - 0.5).abs() < 1e-10, "p1 should be ~0.5, got {}", p1);
        assert!((p0 + p1 - 1.0).abs() < 1e-10, "p0 + p1 should = 1.0");

        // Update with a few symbols and verify log_prob becomes negative (valid)
        for _ in 0..5 {
            tree.update(true);
            tree.update(false);
        }

        let log_prob = tree.get_log_block_probability();
        assert!(
            log_prob < 0.0,
            "log_prob should be negative (< log 1), got {}",
            log_prob
        );
        assert!(log_prob.is_finite(), "log_prob should be finite");
    }
3303
    // Rate-based NTE must be clamped to [0, 2], not [0, 1]; this checks the
    // clamp bounds rather than a specific value.
    //
    // NOTE(review): the `set_default_ctx` call below looks redundant —
    // `nte_rate_backend` is invoked with an explicit backend argument — and the
    // final reset uses `InfotheoryCtx::default()` rather than restoring the
    // previous context. Confirm whether the global swap is needed at all.
    #[test]
    fn nte_can_exceed_one() {
        // Test that NTE is properly clamped to [0, 2] instead of [0, 1]
        // For independent sequences with similar entropy, NTE can approach 2.0
        //
        // Note: For *marginal* NTE, due to how joint entropy works for aligned pairs,
        // it's mathematically bounded differently. The fix for NTE clamping primarily
        // affects *rate*-based NTE where VI can truly be 2*max(H).
        //
        // We test that the clamp upper bound is at least > 1.0 for cases where VI > max(H)

        // Use CTW backend for rate-based test
        set_default_ctx(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 8 },
            CompressionBackend::default(),
        ));

        // Generate two completely different patterns - should have high VI
        let x: Vec<u8> = (0..200).map(|i| (i % 2) as u8).collect(); // 010101...
        let y: Vec<u8> = (0..200).map(|i| ((i + 1) % 2) as u8).collect(); // 101010...

        let nte_rate = nte_rate_backend(&x, &y, -1, &RateBackend::Ctw { depth: 8 });

        // With the fix, NTE should not be clamped to 1.0
        // It may or may not exceed 1.0 depending on the specifics, but it should be allowed to
        assert!(
            (0.0..=2.0 + 1e-9).contains(&nte_rate),
            "NTE should be in [0, 2], got {}",
            nte_rate
        );

        // Reset context
        set_default_ctx(InfotheoryCtx::default());
    }
3338
3339    #[test]
3340    fn ctw_empty_data_returns_zero() {
3341        // Verify empty data doesn't cause division-by-zero or NaN
3342        set_default_ctx(InfotheoryCtx::new(
3343            RateBackend::Ctw { depth: 16 },
3344            CompressionBackend::default(),
3345        ));
3346
3347        let empty: &[u8] = &[];
3348        let h = entropy_rate_bytes(empty, -1);
3349        assert_eq!(h, 0.0, "empty data should return 0.0 entropy");
3350
3351        // Reset
3352        set_default_ctx(InfotheoryCtx::default());
3353    }
3354
3355    #[test]
3356    fn joint_entropy_rate_aligns_inputs_and_handles_empty_cases() {
3357        let cases = vec![
3358            ("ctw", RateBackend::Ctw { depth: 8 }),
3359            (
3360                "fac-ctw",
3361                RateBackend::FacCtw {
3362                    base_depth: 8,
3363                    num_percept_bits: 8,
3364                    encoding_bits: 8,
3365                },
3366            ),
3367            ("match", test_match_backend()),
3368        ];
3369
3370        for (name, backend) in cases {
3371            assert_eq!(
3372                joint_entropy_rate_backend(b"", b"nonempty", -1, &backend),
3373                0.0,
3374                "{name} should return 0.0 for empty aligned pairs"
3375            );
3376            assert_eq!(
3377                joint_entropy_rate_backend(b"nonempty", b"", -1, &backend),
3378                0.0,
3379                "{name} should return 0.0 when alignment truncates to empty"
3380            );
3381
3382            let aligned = joint_entropy_rate_backend(b"abcd", b"wxyz", -1, &backend);
3383            let truncated = joint_entropy_rate_backend(b"abcdextra", b"wxyz", -1, &backend);
3384            assert!(
3385                (aligned - truncated).abs() < 1e-12,
3386                "{name} should score only the aligned prefix: aligned={aligned} truncated={truncated}"
3387            );
3388        }
3389    }
3390
3391    #[test]
3392    fn biased_entropy_is_repeatable_across_backend_families() {
3393        let data = b"ABABABAABBABABABAABB";
3394        let cases = vec![
3395            ("match", test_match_backend()),
3396            ("ppmd", test_ppmd_backend()),
3397            ("calibrated", test_calibrated_backend()),
3398            ("ctw", RateBackend::Ctw { depth: 8 }),
3399            ("mixture", test_mixture_backend()),
3400            ("particle", test_particle_backend()),
3401        ];
3402
3403        for (name, backend) in cases {
3404            let h1 = biased_entropy_rate_backend(data, -1, &backend);
3405            let h2 = biased_entropy_rate_backend(data, -1, &backend);
3406            assert!(h1.is_finite(), "{name} biased entropy should be finite");
3407            assert!(
3408                (h1 - h2).abs() < 1e-12,
3409                "{name} biased entropy leaked mutable state across calls: h1={h1} h2={h2}"
3410            );
3411        }
3412    }
3413
3414    #[test]
3415    fn generate_bytes_chain_matches_flat_prompt() {
3416        let prompt = continuation_prompt();
3417        let split_at = prompt.len() / 2;
3418        let front = &prompt[..split_at];
3419        let back = &prompt[split_at..];
3420        let backend = RateBackend::Ctw { depth: 32 };
3421        let bytes = 8usize;
3422        let max_order = -1;
3423
3424        let flat = generate_rate_backend_chain(
3425            &[prompt],
3426            bytes,
3427            max_order,
3428            &backend,
3429            GenerationConfig::default(),
3430        );
3431        let chained = generate_rate_backend_chain(
3432            &[front, back],
3433            bytes,
3434            max_order,
3435            &backend,
3436            GenerationConfig::default(),
3437        );
3438        assert_eq!(
3439            flat, chained,
3440            "chain conditioning should match flat prompt conditioning"
3441        );
3442    }
3443
3444    #[test]
3445    fn generate_bytes_api_is_deterministic_for_ctw_rosa_match_ppmd() {
3446        assert_deterministic_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
3447        assert_deterministic_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
3448        assert_deterministic_generate_for_backend(test_match_backend(), -1, 8, "match");
3449        assert_deterministic_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
3450    }
3451
3452    #[cfg(feature = "backend-rwkv")]
3453    #[test]
3454    fn generate_bytes_api_is_deterministic_for_rwkv_method() {
3455        let backend = RateBackend::Rwkv7Method {
3456            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3457        };
3458        assert_deterministic_generate_for_backend(backend, -1, 8, "rwkv7");
3459    }
3460
3461    #[test]
3462    fn sampled_generation_is_deterministic_for_ctw_rosa_match_ppmd() {
3463        assert_sampled_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
3464        assert_sampled_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
3465        assert_sampled_generate_for_backend(test_match_backend(), -1, 8, "match");
3466        assert_sampled_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
3467    }
3468
3469    #[cfg(feature = "backend-rwkv")]
3470    #[test]
3471    fn sampled_generation_is_deterministic_for_rwkv_method() {
3472        let backend = RateBackend::Rwkv7Method {
3473            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3474        };
3475        assert_sampled_generate_for_backend(backend, -1, 8, "rwkv7");
3476    }
3477
3478    #[test]
3479    fn rosaplus_sampled_generation_predicts_green_continuation() {
3480        let out = generate_rate_backend_chain(
3481            &[continuation_prompt()],
3482            8,
3483            -1,
3484            &RateBackend::RosaPlus,
3485            GenerationConfig::sampled_frozen(42),
3486        );
3487        assert_eq!(out, b" green.\n");
3488    }
3489
3490    #[test]
3491    fn rate_backend_session_matches_ctx_generation() {
3492        let prompt = continuation_prompt();
3493        let backend = RateBackend::Ppmd {
3494            order: 12,
3495            memory_mb: 8,
3496        };
3497        let mut session =
3498            RateBackendSession::from_backend(backend.clone(), -1, Some((prompt.len() + 8) as u64))
3499                .expect("session init");
3500        session.observe(prompt);
3501        let from_session = session.generate_bytes(8, GenerationConfig::sampled_frozen(42));
3502        session.finish().expect("session finish");
3503
3504        let ctx = InfotheoryCtx::new(backend, CompressionBackend::default());
3505        let from_ctx =
3506            ctx.generate_bytes_with_config(prompt, 8, -1, GenerationConfig::sampled_frozen(42));
3507        assert_eq!(from_session, from_ctx);
3508    }
3509
3510    #[test]
3511    fn biased_entropy_ctw_uses_frozen_plugin_scoring() {
3512        let backend = RateBackend::Ctw { depth: 8 };
3513        let data = b"AAAAAAAA";
3514        let plugin = biased_entropy_rate_backend(data, -1, &backend);
3515        let prequential = entropy_rate_backend(data, -1, &backend);
3516        assert!(
3517            plugin + 1e-9 < prequential,
3518            "expected plugin scoring to beat prequential scoring: plugin={plugin} prequential={prequential}"
3519        );
3520    }
3521
3522    #[test]
3523    fn rosa_plugin_entropy_matches_direct_model_api() {
3524        let data = b"abracadabra";
3525        let backend = RateBackend::RosaPlus;
3526
3527        let plugin = biased_entropy_rate_backend(data, 3, &backend);
3528
3529        let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
3530        direct.train_example(data);
3531        direct.build_lm();
3532        let expected = direct.cross_entropy(data);
3533
3534        assert!(
3535            (plugin - expected).abs() < 1e-12,
3536            "rosa plugin entropy must match direct model API: plugin={plugin} expected={expected}"
3537        );
3538    }
3539
3540    #[test]
3541    fn rosa_plugin_cross_entropy_matches_direct_model_api() {
3542        let train = b"alakazam";
3543        let test = b"abracadabra";
3544        let backend = RateBackend::RosaPlus;
3545
3546        let plugin = cross_entropy_rate_backend(test, train, 3, &backend);
3547
3548        let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
3549        direct.train_example(train);
3550        direct.build_lm();
3551        let expected = direct.cross_entropy(test);
3552
3553        assert!(
3554            (plugin - expected).abs() < 1e-12,
3555            "rosa plugin cross entropy must match direct model API: plugin={plugin} expected={expected}"
3556        );
3557    }
3558
3559    #[test]
3560    fn datagen_bernoulli_entropy_estimate() {
3561        // Test that estimated entropy is close to theoretical for Bernoulli(0.5)
3562        let p = 0.5;
3563        let theoretical_h = crate::datagen::bernoulli_entropy(p);
3564        assert!((theoretical_h - 1.0).abs() < 1e-10);
3565
3566        // Generate data and check marginal entropy is close to theoretical
3567        let data = crate::datagen::bernoulli(10000, p, 42);
3568        let estimated_h = marginal_entropy_bytes(&data);
3569
3570        // Should be close to 1.0 bit (since values are 0 or 1)
3571        assert!(
3572            (estimated_h - theoretical_h).abs() < 0.1,
3573            "estimated H={} should be close to theoretical H={}",
3574            estimated_h,
3575            theoretical_h
3576        );
3577    }
3578
3579    #[cfg(feature = "backend-rwkv")]
3580    #[test]
3581    fn rwkv_method_entropy_is_stable_across_calls() {
3582        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
3583        let backend = RateBackend::Rwkv7Method {
3584            method: method.to_string(),
3585        };
3586        let data = b"rwkv method entropy stability regression sample";
3587
3588        let h1 = entropy_rate_backend(data, -1, &backend);
3589        let h2 = entropy_rate_backend(data, -1, &backend);
3590        assert!(
3591            (h1 - h2).abs() < 1e-12,
3592            "rwkv method entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3593        );
3594    }
3595
3596    #[cfg(feature = "backend-rwkv")]
3597    #[test]
3598    fn rwkv_method_without_policy_is_accepted_by_public_api() {
3599        let backend = RateBackend::Rwkv7Method {
3600            method: "cfg:hidden=64,layers=1,intermediate=64".to_string(),
3601        };
3602        let data = b"rwkv method without policy";
3603        let h1 = entropy_rate_backend(data, -1, &backend);
3604        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3605        assert!(h1.is_finite());
3606        assert!(h2.is_finite());
3607    }
3608
3609    #[cfg(feature = "backend-rwkv")]
3610    #[test]
3611    fn rwkv_infer_only_plugin_collapses_to_single_pass_entropy() {
3612        let backend = RateBackend::Rwkv7Method {
3613            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=25,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3614        };
3615        let data = b"rwkv infer-only plugin equality sample";
3616        let h = entropy_rate_backend(data, -1, &backend);
3617        let plugin = biased_entropy_rate_backend(data, -1, &backend);
3618        assert!(
3619            (h - plugin).abs() < 1e-12,
3620            "infer-only rwkv plugin should equal single-pass entropy: h={h}, plugin={plugin}"
3621        );
3622    }
3623
3624    #[cfg(feature = "backend-rwkv")]
3625    #[test]
3626    fn rwkv_method_biased_entropy_is_stable_across_calls_with_training_policy() {
3627        let backend = RateBackend::Rwkv7Method {
3628            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
3629        };
3630        let data = b"rwkv plugin stability sample";
3631        let h1 = biased_entropy_rate_backend(data, -1, &backend);
3632        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3633        assert!(
3634            (h1 - h2).abs() < 1e-12,
3635            "rwkv method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3636        );
3637    }
3638
3639    #[cfg(feature = "backend-rwkv")]
3640    #[test]
3641    fn rwkv_method_conditional_chain_is_stable_across_calls() {
3642        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=22,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
3643        let ctx = InfotheoryCtx::new(
3644            RateBackend::Rwkv7Method {
3645                method: method.to_string(),
3646            },
3647            CompressionBackend::default(),
3648        );
3649
3650        let prefix = b"universal prior slice";
3651        let data = b"query payload";
3652        let h1 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
3653        let h2 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
3654        assert!(
3655            (h1 - h2).abs() < 1e-12,
3656            "rwkv method conditional chain leaked mutable state across calls: h1={h1}, h2={h2}"
3657        );
3658    }
3659
3660    #[cfg(feature = "backend-mamba")]
3661    #[test]
3662    fn mamba_method_without_policy_is_accepted_by_public_api() {
3663        let backend = RateBackend::MambaMethod {
3664            method: "cfg:hidden=64,layers=1,intermediate=96".to_string(),
3665        };
3666        let data = b"mamba method without policy";
3667        let h1 = entropy_rate_backend(data, -1, &backend);
3668        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3669        assert!(h1.is_finite());
3670        assert!(h2.is_finite());
3671    }
3672
3673    #[cfg(feature = "backend-mamba")]
3674    #[test]
3675    fn mamba_infer_only_plugin_collapses_to_single_pass_entropy() {
3676        let backend = RateBackend::MambaMethod {
3677            method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=26,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3678        };
3679        let data = b"mamba infer-only plugin equality sample";
3680        let h = entropy_rate_backend(data, -1, &backend);
3681        let plugin = biased_entropy_rate_backend(data, -1, &backend);
3682        assert!(
3683            (h - plugin).abs() < 1e-12,
3684            "infer-only mamba plugin should equal single-pass entropy: h={h}, plugin={plugin}"
3685        );
3686    }
3687
3688    #[cfg(feature = "backend-mamba")]
3689    #[test]
3690    fn mamba_method_biased_entropy_is_stable_across_calls_with_training_policy() {
3691        let backend = RateBackend::MambaMethod {
3692            method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=24,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
3693        };
3694        let data = b"mamba plugin stability sample";
3695        let h1 = biased_entropy_rate_backend(data, -1, &backend);
3696        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3697        assert!(
3698            (h1 - h2).abs() < 1e-12,
3699            "mamba method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3700        );
3701    }
3702
3703    #[test]
3704    fn particle_entropy_rate_in_valid_range() {
3705        let rb = test_particle_backend();
3706        let data = b"hello world particle backend test";
3707        let rate = entropy_rate_backend(data, -1, &rb);
3708        assert!(
3709            rate > 0.0 && rate < 8.0,
3710            "particle entropy rate out of (0, 8) range: {rate}"
3711        );
3712    }
3713
3714    #[test]
3715    fn particle_cross_entropy_stability() {
3716        let rb = test_particle_backend();
3717        let train = b"ABCABC";
3718        let test = b"ABC";
3719        let h1 = cross_entropy_rate_backend(test, train, -1, &rb);
3720        let h2 = cross_entropy_rate_backend(test, train, -1, &rb);
3721        assert!(
3722            (h1 - h2).abs() < 1e-12,
3723            "particle cross entropy not deterministic: h1={h1}, h2={h2}"
3724        );
3725    }
3726
3727    #[test]
3728    fn particle_empty_input() {
3729        let rb = RateBackend::Particle {
3730            spec: Arc::new(ParticleSpec::default()),
3731        };
3732        let rate = entropy_rate_backend(b"", -1, &rb);
3733        assert!(
3734            rate == 0.0,
3735            "particle entropy rate for empty input should be 0.0, got {rate}"
3736        );
3737    }
3738
3739    #[test]
3740    fn particle_joint_entropy_rate() {
3741        let rb = test_particle_backend();
3742        let x = b"AAAA";
3743        let y = b"BBBB";
3744        let joint = joint_entropy_rate_backend(x, y, -1, &rb);
3745        assert!(
3746            joint > 0.0 && joint < 16.0,
3747            "particle joint entropy rate out of range: {joint}"
3748        );
3749    }
3750}