#![allow(unsafe_op_in_unsafe_fn)]

//! # InfoTheory: Information Theoretic Estimators & Metrics
//!
//! This crate provides a comprehensive suite of information-theoretic primitives for
//! quantifying complexity, dependence, and similarity between data sequences.
//!
//! It implements two primary classes of estimators:
//! 1.  **Compression-based (Kolmogorov Complexity)**: Using the ZPAQ compression algorithm to estimate
//!     Normalized Compression Distance (NCD).
//! 2.  **Entropy-based (Shannon Information)**: Using both exact marginal histograms (for i.i.d. data)
//!     and the ROSA (Rapid Online Suffix Automaton) predictive language model (for sequential data)
//!     to estimate Entropy, Mutual Information, and related distances.
//!
//! ## Mathematical Primitives
//!
//! The library implements the following core measures. For sequential data, "Rate" variants
//! use the ROSA model to estimate `Ĥ(X)` (entropy rate), while "Marginal" variants
//! treat data as a bag-of-bytes (i.i.d.) and compute `H(X)` from histograms.
//!
//! ### 1. Normalized Compression Distance (NCD)
//! Approximates the Normalized Information Distance (NID) using a compressor `C`.
//!
//! `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
//!
//! ### 2. Normalized Entropy Distance (NED)
//! An entropic analogue to NCD, defined using Shannon entropy `H`.
//!
//! `NED(X,Y) = (H(X,Y) - min(H(X), H(Y))) / max(H(X), H(Y))`
//!
//! ### 3. Normalized Transform Effort (NTE)
//! Based on the Variation of Information (VI), normalized by the maximum entropy.
//!
//! `NTE(X,Y) = (H(X|Y) + H(Y|X)) / max(H(X), H(Y)) = (2H(X,Y) - H(X) - H(Y)) / max(H(X), H(Y))`
//!
//! ### 4. Mutual Information (MI)
//! Measures the amount of information obtained about one random variable by observing another.
//!
//! `I(X;Y) = H(X) + H(Y) - H(X,Y)`
//!
//! ### 5. Divergences & Distances
//! *   **Total Variation Distance (TVD)**: `δ(P,Q) = 0.5 * Σ |P(x) - Q(x)|`
//! *   **Normalized Hellinger Distance (NHD)**: `sqrt(1 - Σ sqrt(P(x)Q(x)))`
//! *   **Kullback-Leibler Divergence (KL)**: `D_KL(P||Q) = Σ P(x) log(P(x)/Q(x))`
//! *   **Jensen-Shannon Divergence (JSD)**: Symmetrized and smoothed KL divergence.
//!
//! ### 6. Intrinsic Dependence (ID)
//! Measures the redundancy within a sequence, comparing marginal entropy to entropy rate.
//!
//! `ID(X) = (H_marginal(X) - H_rate(X)) / H_marginal(X)`
//!
//! ### 7. Resistance to Transformation
//! Quantifies how much information is preserved after a transformation `T` is applied.
//!
//! `R(X, T) = I(X; T(X)) / H(X)`
//!
//! ## Usage
//!
//! ```rust,no_run
//! use infotheory::{ncd_vitanyi, mutual_information_bytes, NcdVariant};
//!
//! let x = b"some data sequence";
//! let y = b"another data sequence";
//!
//! // Compression-based distance
//! let ncd = ncd_vitanyi("file1.txt", "file2.txt", "5");
//!
//! // Entropy-based mutual information (Marginal / i.i.d.)
//! let mi_marg = mutual_information_bytes(x, y, 0);
//!
//! // Entropy-based mutual information (Rate / Sequential, max_order=8)
//! let mi_rate = mutual_information_bytes(x, y, 8);
//! ```
75/// AIXI planning components, environments, and model abstractions.
76pub mod aixi;
77/// Core information-theoretic axioms and validation helpers.
78pub mod axioms;
79/// Entropy/compression backend implementations and backend discovery.
80pub mod backends;
81/// Entropy coder implementations (AC and rANS).
82pub mod coders;
83/// Rate-coded compression helpers built on generic rate backends.
84pub mod compression;
85/// Synthetic data generators for information-theory experiments.
86pub mod datagen;
87/// Diagnostic tooling for exact AC/log-loss mixture tracing.
88pub mod diagnostics;
89/// Online Bayesian/switching/MDL mixture predictors.
90pub mod mixture;
91pub(crate) mod neural_mix;
92/// Information-theoretic code search pipeline (3-stage: prefilter, filter, KMI rerank).
93pub mod search;
94pub(crate) mod simd_math;
95/// CTW and FAC-CTW backend types.
96pub use backends::ctw;
97#[cfg(feature = "backend-mamba")]
98/// Mamba backend types and compressor.
99pub use backends::mambazip;
100/// Match-based repeat predictor.
101pub use backends::match_model;
102/// Particle-latent filter ensemble rate backend.
103pub use backends::particle;
104/// PPMD-style byte model.
105pub use backends::ppmd;
106/// ROSA+ backend types.
107pub use backends::rosaplus;
108#[cfg(feature = "backend-rwkv")]
109/// RWKV backend types and compressor.
110pub use backends::rwkvzip;
111/// Exact online Sequitur backend types.
112pub use backends::sequitur;
113/// Sparse/gapped match predictor.
114pub use backends::sparse_match;
115/// ZPAQ rate-model adapter.
116pub use backends::zpaq_rate;
117
118use rayon::prelude::*;
119
120use crate::coders::CoderType;
121use std::cell::RefCell;
122#[cfg(any(feature = "backend-rwkv", feature = "backend-mamba"))]
123use std::collections::HashMap;
124use std::sync::Arc;
125use std::sync::OnceLock;
126
/// How generated symbols should update the model state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationUpdateMode {
    /// Keep adapting/fitting on generated bytes (the model keeps learning
    /// from its own output as it generates).
    Adaptive,
    /// Freeze fitted parameters/statistics and only advance conditioning state.
    Frozen,
}
135
/// How to pick the next byte from the model distribution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationStrategy {
    /// Deterministic argmax over the next-byte distribution.
    Greedy,
    /// Seeded sampling from the next-byte distribution (see
    /// [`GenerationConfig::seed`]).
    Sample,
}
144
/// Generation options shared by the library API and CLI.
#[derive(Clone, Copy, Debug)]
pub struct GenerationConfig {
    /// Byte-selection strategy.
    pub strategy: GenerationStrategy,
    /// Whether generated bytes should keep adapting the model.
    pub update_mode: GenerationUpdateMode,
    /// RNG seed used by [`GenerationStrategy::Sample`].
    pub seed: u64,
    /// Softmax temperature for sampling. `<= 0` behaves like greedy.
    pub temperature: f64,
    /// Optional top-k truncation of the sampling distribution. `0` disables it.
    pub top_k: usize,
    /// Optional nucleus (top-p) truncation. Values `>= 1.0` disable it.
    pub top_p: f64,
}
161
impl Default for GenerationConfig {
    /// Defaults to seeded (seed = 42), frozen-model sampling.
    fn default() -> Self {
        Self::sampled_frozen(42)
    }
}
167
impl GenerationConfig {
    /// Deterministic frozen continuation.
    ///
    /// The seed is irrelevant for greedy decoding; a fixed constant is stored
    /// so the config stays fully deterministic if later switched to sampling.
    pub const fn greedy_frozen() -> Self {
        Self {
            strategy: GenerationStrategy::Greedy,
            update_mode: GenerationUpdateMode::Frozen,
            seed: 0xD00D_F00D_CAFE_BABEu64,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
        }
    }

    /// Seeded frozen sampling from the model distribution.
    ///
    /// Temperature 1.0 with top-k/top-p disabled, i.e. unmodified sampling
    /// from the raw model distribution.
    pub const fn sampled_frozen(seed: u64) -> Self {
        Self {
            strategy: GenerationStrategy::Sample,
            update_mode: GenerationUpdateMode::Frozen,
            seed,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
        }
    }
}
193
/// Minimal xorshift64 PRNG used for seeded generation sampling.
struct GenerationRng {
    state: u64,
}

impl GenerationRng {
    /// Creates a generator, substituting a fixed non-zero constant for
    /// `seed == 0` (the all-zero state is a fixed point of xorshift).
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0xD00D_F00D_CAFE_BABEu64,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the xorshift64 state (shifts 13/7/17) and returns it.
    fn next_u64(&mut self) -> u64 {
        let mut v = self.state;
        v ^= v << 13;
        v ^= v >> 7;
        v ^= v << 17;
        self.state = v;
        v
    }

    /// Next draw scaled into `[0, 1]` by dividing by `u64::MAX`.
    fn next_f64(&mut self) -> f64 {
        self.next_u64() as f64 / u64::MAX as f64
    }
}
222
/// Lazily-initialized, process-wide thread-count setting.
static NUM_THREADS: OnceLock<usize> = OnceLock::new();
224
thread_local! {
    // Per-thread caches of heavyweight neural compressors so repeated calls
    // on the same thread can reuse an already-constructed compressor instead
    // of rebuilding it. Keys are either a usize id or a method string.
    #[cfg(feature = "backend-mamba")]
    static MAMBA_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_RATE_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_METHOD_TLS: RefCell<HashMap<String, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_RATE_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_METHOD_TLS: RefCell<HashMap<String, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
}
239
#[cfg(feature = "backend-zpaq")]
impl Default for CompressionBackend {
    /// When the ZPAQ backend is compiled in, default to ZPAQ method `"5"`.
    fn default() -> Self {
        Self::Zpaq {
            method: String::from("5"),
        }
    }
}
248
249#[cfg(not(feature = "backend-zpaq"))]
250impl Default for CompressionBackend {
251    fn default() -> Self {
252        CompressionBackend::Rate {
253            rate_backend: RateBackend::default(),
254            coder: CoderType::AC,
255            framing: compression::FramingMode::Raw,
256        }
257    }
258}
259
thread_local! {
    // Per-thread default context; read via get_default_ctx/with_default_ctx
    // and replaced via set_default_ctx.
    static DEFAULT_CTX: RefCell<InfotheoryCtx> = RefCell::new(InfotheoryCtx::default());
}
263
/// Returns the current default information theory context for the thread.
///
/// This clones the stored context; use [`set_default_ctx`] to replace it.
pub fn get_default_ctx() -> InfotheoryCtx {
    DEFAULT_CTX.with(|ctx| ctx.borrow().clone())
}
268
/// Sets the default information theory context for the thread.
///
/// Only affects the calling thread's context; other threads are untouched.
pub fn set_default_ctx(ctx: InfotheoryCtx) {
    DEFAULT_CTX.with(|c| *c.borrow_mut() = ctx);
}
273
/// Runs `f` against a borrow of the thread-local default context without
/// cloning it (unlike [`get_default_ctx`]).
#[inline(always)]
fn with_default_ctx<R>(f: impl FnOnce(&InfotheoryCtx) -> R) -> R {
    DEFAULT_CTX.with(|ctx| f(&ctx.borrow()))
}
278
279/// Mutual information rate estimate under an explicit `backend`.
280///
281/// Inputs are aligned to the shared prefix length.
282pub fn mutual_information_rate_backend(
283    x: &[u8],
284    y: &[u8],
285    max_order: i64,
286    backend: &RateBackend,
287) -> f64 {
288    let (x, y) = aligned_prefix(x, y);
289    if x.is_empty() {
290        return 0.0;
291    }
292    // For CTW, we might want a special aligned implementation?
293    // Using standard formula for now.
294    let h_x = entropy_rate_backend(x, max_order, backend);
295    let h_y = entropy_rate_backend(y, max_order, backend);
296    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
297    (h_x + h_y - h_xy).max(0.0)
298}
299
300/// Normalized entropy distance under an explicit `backend`.
301///
302/// Returns a value in `[0, 1]` after clamping.
303pub fn ned_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
304    let (x, y) = aligned_prefix(x, y);
305    if x.is_empty() {
306        return 0.0;
307    }
308    let h_x = entropy_rate_backend(x, max_order, backend);
309    let h_y = entropy_rate_backend(y, max_order, backend);
310    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
311    let min_h = h_x.min(h_y);
312    let max_h = h_x.max(h_y);
313    if max_h == 0.0 {
314        0.0
315    } else {
316        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
317    }
318}
319
320/// Normalized transform effort (variation-of-information form) under an explicit `backend`.
321///
322/// Returns a value in `[0, 2]` after clamping.
323pub fn nte_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
324    let (x, y) = aligned_prefix(x, y);
325    if x.is_empty() {
326        return 0.0;
327    }
328    let h_x = entropy_rate_backend(x, max_order, backend);
329    let h_y = entropy_rate_backend(y, max_order, backend);
330    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
331    let max_h = h_x.max(h_y);
332    if max_h == 0.0 {
333        0.0
334    } else {
335        // VI = H(X|Y) + H(Y|X) can be as large as H(X) + H(Y) ≈ 2*max(H)
336        // for independent sequences, so NTE ∈ [0, 2]
337        let vi = (h_xy - h_x).max(0.0) + (h_xy - h_y).max(0.0);
338        (vi / max_h).clamp(0.0, 2.0)
339    }
340}
341
/// Core predictive model class used by the library.
///
/// `RateBackend` is the shared model class behind entropy-rate estimation,
/// rate-coded compression, generation, and the world-model interface used by
/// MC-AIXI/AIQI planners. Nested configurations (`Mixture`, `Calibrated`) are
/// checked by [`validate_rate_backend`] up to a bounded recursion depth.
#[derive(Clone)]
pub enum RateBackend {
    /// ROSA+ suffix-automaton estimator.
    RosaPlus,
    /// Local contiguous match predictor.
    Match {
        /// Number of retained hash bits for suffix lookup.
        hash_bits: usize,
        /// Minimum repeat length required before predicting.
        min_len: usize,
        /// Maximum repeat length used for confidence scaling.
        max_len: usize,
        /// Residual probability mass left for non-match symbols.
        base_mix: f64,
        /// Confidence multiplier applied to short-match tapering.
        confidence_scale: f64,
    },
    /// Sparse/gapped local match predictor.
    SparseMatch {
        /// Number of retained hash bits for spaced-suffix lookup.
        hash_bits: usize,
        /// Minimum spaced repeat length required before predicting.
        min_len: usize,
        /// Maximum spaced repeat length used for confidence scaling.
        max_len: usize,
        /// Minimum gap between matched bytes.
        gap_min: usize,
        /// Maximum gap between matched bytes.
        gap_max: usize,
        /// Residual probability mass left for non-match symbols.
        base_mix: f64,
        /// Confidence multiplier applied to short-match tapering.
        confidence_scale: f64,
    },
    /// Pure-Rust bounded-memory PPMD-style model.
    Ppmd {
        /// Maximum context order.
        order: usize,
        /// Approximate memory budget in MiB.
        memory_mb: usize,
    },
    /// Exact online Sequitur grammar backend with byte-level predictive readout.
    Sequitur {
        /// Maximum number of terminal bytes retained per grammar-derived
        /// context (must be >= 2; enforced at validation time).
        context_bytes: usize,
    },
    #[cfg(feature = "backend-mamba")]
    /// Mamba model loaded from explicit weights.
    Mamba {
        /// Loaded Mamba model.
        model: Arc<mambazip::Model>,
    },
    #[cfg(feature = "backend-mamba")]
    /// Mamba method string (e.g. `file:...` or `cfg:...[;policy:...]`) resolved lazily.
    MambaMethod {
        /// Mamba method string.
        method: String,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV7 model loaded from explicit weights.
    Rwkv7 {
        /// Loaded RWKV7 model.
        model: Arc<rwkvzip::Model>,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV7 method string (e.g. `file:...` or `cfg:...[;policy:...]`) resolved lazily.
    Rwkv7Method {
        /// RWKV7 method string.
        method: String,
    },
    /// ZPAQ compression-based rate model (streamable methods only).
    Zpaq {
        /// ZPAQ method string (streamable modes only for rate estimation).
        method: String,
    },
    /// Online mixture over `RateBackend` experts.
    ///
    /// `Bayes`, `Switching`, and `Convex` follow
    /// "On Ensemble Techniques for AIXI Approximation"; `FadingBayes`,
    /// `Mdl`, and `Neural` are repository extensions.
    Mixture {
        /// Mixture expert/runtime specification.
        spec: Arc<MixtureSpec>,
    },
    /// Particle-latent filter ensemble.
    Particle {
        /// Particle filter specification.
        spec: Arc<ParticleSpec>,
    },
    /// Calibrated wrapper over another bytewise backend.
    Calibrated {
        /// Calibration specification.
        spec: Arc<CalibratedSpec>,
    },
    /// Action-Conditional CTW (single context tree).
    Ctw {
        /// Context tree depth.
        depth: usize,
    },
    /// Factorized Action-Conditional CTW (k trees for k-bit percepts).
    FacCtw {
        /// Base context depth.
        base_depth: usize,
        /// Number of percept bits.
        num_percept_bits: usize,
        /// Encoding width in bits.
        encoding_bits: usize,
    },
}
456
#[allow(clippy::derivable_impls)]
impl Default for RateBackend {
    /// Feature-dependent default, in preference order: ROSA+ when compiled
    /// in, else ZPAQ method `"1"`, else CTW with depth 16.
    fn default() -> Self {
        // Exactly one of the cfg blocks below survives for any feature
        // combination, so the chosen block is always the tail expression.
        #[cfg(feature = "backend-rosa")]
        {
            RateBackend::RosaPlus
        }
        #[cfg(all(not(feature = "backend-rosa"), feature = "backend-zpaq"))]
        {
            RateBackend::Zpaq {
                method: "1".to_string(),
            }
        }
        #[cfg(all(not(feature = "backend-rosa"), not(feature = "backend-zpaq")))]
        {
            RateBackend::Ctw { depth: 16 }
        }
    }
}
476
/// Compression backend used by NCD/compression-size operations.
#[derive(Clone)]
pub enum CompressionBackend {
    /// ZPAQ compressor with explicit method string.
    Zpaq {
        /// ZPAQ method (for example `"1"` or `"5"`).
        method: String,
    },
    #[cfg(feature = "backend-rwkv")]
    /// RWKV7 model as an entropy-coded compressor.
    Rwkv7 {
        /// Loaded RWKV7 model.
        model: Arc<rwkvzip::Model>,
        /// Entropy coder used for coding model PDFs.
        coder: CoderType,
    },
    /// Generic rate-coded compressor wrapping an arbitrary rate backend.
    Rate {
        /// Predictive rate backend.
        rate_backend: RateBackend,
        /// Entropy coder used for coding model PDFs.
        coder: CoderType,
        /// Framing mode for output payloads.
        framing: compression::FramingMode,
    },
}
503
/// Shared maximum nesting depth for recursive mixture specifications.
///
/// Used by [`validate_rate_backend`] and [`MixtureSpec::validate`] to bound
/// recursion through nested `Mixture`/`Calibrated` backends.
pub const MAX_MIXTURE_NESTING: usize = 8;
506
/// Mixture policy kind for rate-backend mixtures.
///
/// Parsed from user-facing names by [`parse_mixture_kind_name`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureKind {
    /// Standard Bayesian mixture with fixed expert weights.
    Bayes,
    /// Bayesian mixture with exponential weight decay.
    FadingBayes,
    /// Switching mixture using the fixed-share update from
    /// "On Ensemble Techniques for AIXI Approximation".
    Switching,
    /// Online convex mixture with projected-simplex weight updates.
    Convex,
    /// MDL-style best-expert selector.
    Mdl,
    /// Bytewise neural logistic mixer (fx2-cmix style adaptation).
    Neural,
}
524
/// Adaptive schedule family for switching and convex mixtures.
///
/// Parsed from user-facing names by [`parse_mixture_schedule_name`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureScheduleMode {
    /// Use the implementation's default exposed parameterization.
    ///
    /// - `Switching`: constant switch rate `alpha`
    /// - `Convex`: step size `alpha / sqrt(t)`
    Default,
    /// Use the theorem schedule from
    /// "On Ensemble Techniques for AIXI Approximation".
    ///
    /// - `Switching`: `alpha_t = 1 / t`
    /// - `Convex`: `eta_t = epsilon / sqrt(t)` under this implementation's
    ///   natural-log gradient, matching the bit-loss schedule analyzed in
    ///   "On Ensemble Techniques for AIXI Approximation" after
    ///   accounting for the `1 / ln(2)` factor in the base-2 gradient
    ///
    /// This preserves configured expert priors; exact theorem hypotheses for
    /// switching still additionally require uniform priors.
    Theorem,
}
546
547impl Default for MixtureScheduleMode {
548    fn default() -> Self {
549        Self::Default
550    }
551}
552
553/// Parse a mixture kind name with the shared alias table used across CLI, Python, and WASM.
554pub fn parse_mixture_kind_name(kind: &str) -> Result<MixtureKind, String> {
555    match kind.trim().to_ascii_lowercase().as_str() {
556        "bayes" | "bayes-mix" | "bayes_mix" => Ok(MixtureKind::Bayes),
557        "fading" | "fading-bayes" | "fading_bayes" => Ok(MixtureKind::FadingBayes),
558        "switch" | "switching" | "switch-mix" | "switch_mix" => Ok(MixtureKind::Switching),
559        "convex" | "convex-mix" | "convex_mix" => Ok(MixtureKind::Convex),
560        "mdl" | "selector" | "mdr" => Ok(MixtureKind::Mdl),
561        "neural" | "neural-mix" | "neural_mix" | "mix" | "mixture" | "fx2" | "fx2-cmix"
562        | "fx2_cmix" => Ok(MixtureKind::Neural),
563        other => Err(format!("unknown mixture kind '{other}'")),
564    }
565}
566
567/// Parse a mixture schedule mode with the shared alias table used across CLI, Python, and WASM.
568pub fn parse_mixture_schedule_name(schedule: &str) -> Result<MixtureScheduleMode, String> {
569    match schedule.trim().to_ascii_lowercase().as_str() {
570        "" | "default" | "constant" | "const" => Ok(MixtureScheduleMode::Default),
571        "theorem" | "paper" | "paper-theorem" | "paper_theorem" => Ok(MixtureScheduleMode::Theorem),
572        other => Err(format!("unknown mixture schedule '{other}'")),
573    }
574}
575
/// Fixed context families for calibrated PDF wrappers.
///
/// Selects which row of the calibration table a symbol update hits (see
/// [`CalibratedSpec`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationContextKind {
    /// Single global calibration row.
    Global,
    /// Previous-byte class only.
    ByteClass,
    /// Text-structure-aware context hash.
    Text,
    /// Repeat-aware context hash.
    Repeat,
    /// Joint text/repeat-aware context hash.
    TextRepeat,
}
590
/// Configuration for a calibrated wrapper rate backend
/// ([`RateBackend::Calibrated`]).
#[derive(Clone)]
pub struct CalibratedSpec {
    /// Base backend whose PDF is calibrated. May itself be nested; nesting
    /// depth is bounded by [`MAX_MIXTURE_NESTING`] during validation.
    pub base: RateBackend,
    /// Context family controlling table row selection.
    pub context: CalibrationContextKind,
    /// Number of probability bins per row.
    pub bins: usize,
    /// Online learning rate for observed-symbol updates.
    pub learning_rate: f64,
    /// Symmetric clip applied to calibration weights.
    pub bias_clip: f64,
}
605
/// Expert specification for mixture backends.
#[derive(Clone)]
pub struct MixtureExpertSpec {
    /// Optional expert display name (used in validation error messages).
    pub name: Option<String>,
    /// Log prior weight (natural log). Uniform priors can be `0.0`.
    /// Must be finite; checked during validation.
    pub log_prior: f64,
    /// Max order for ROSA experts (ignored for other backends).
    pub max_order: i64,
    /// Underlying backend for this expert.
    pub backend: RateBackend,
}
618
/// Mixture specification for rate-backend mixtures.
///
/// Construct with [`MixtureSpec::new`] and the `with_*` builders, then check
/// with [`MixtureSpec::validate`] before building runtime state.
#[derive(Clone)]
pub struct MixtureSpec {
    /// Mixture policy.
    pub kind: MixtureKind,
    /// Adaptive schedule family for supported mixture kinds.
    pub schedule: MixtureScheduleMode,
    /// Shared scalar parameter: switch rate for `Switching`, step-size scale for `Convex`,
    /// learning rate for `Neural`, and generic alpha for the remaining families.
    ///
    /// In theorem mode for `Switching` and `Convex`, this field is retained for API
    /// compatibility but is not used by the update schedule.
    pub alpha: f64,
    /// Decay factor for fading Bayes mixtures (required when
    /// `kind == MixtureKind::FadingBayes`).
    pub decay: Option<f64>,
    /// Expert list (must be non-empty; checked during validation).
    pub experts: Vec<MixtureExpertSpec>,
}
637
638impl MixtureSpec {
639    /// Build a mixture specification from kind and expert list.
640    pub fn new(kind: MixtureKind, experts: Vec<MixtureExpertSpec>) -> Self {
641        Self {
642            kind,
643            schedule: MixtureScheduleMode::Default,
644            alpha: 0.01,
645            decay: None,
646            experts,
647        }
648    }
649
650    /// Set the schedule family.
651    pub fn with_schedule(mut self, schedule: MixtureScheduleMode) -> Self {
652        self.schedule = schedule;
653        self
654    }
655
656    /// Set the family-specific alpha parameter.
657    pub fn with_alpha(mut self, alpha: f64) -> Self {
658        self.alpha = alpha;
659        self
660    }
661
662    /// Set fading decay factor.
663    pub fn with_decay(mut self, decay: f64) -> Self {
664        self.decay = Some(decay);
665        self
666    }
667
668    /// Validate the mixture configuration before building runtime state.
669    pub fn validate(&self) -> Result<(), String> {
670        validate_mixture_spec_with_depth(self, MAX_MIXTURE_NESTING)
671    }
672
673    /// Convert to executable expert configs for runtime mixture evaluation.
674    pub fn build_experts(&self) -> Vec<crate::mixture::ExpertConfig> {
675        self.experts
676            .iter()
677            .map(|spec| {
678                crate::mixture::ExpertConfig::from_rate_backend(
679                    spec.name.clone(),
680                    spec.log_prior,
681                    spec.backend.clone(),
682                    spec.max_order,
683                )
684            })
685            .collect()
686    }
687}
688
689fn validate_mixture_spec_with_depth(spec: &MixtureSpec, depth: usize) -> Result<(), String> {
690    if depth == 0 {
691        return Err("mixture spec nesting too deep".to_string());
692    }
693    validate_mixture_spec_shallow(spec)?;
694    for (index, expert) in spec.experts.iter().enumerate() {
695        validate_rate_backend_with_depth(&expert.backend, depth - 1).map_err(|err| {
696            if let Some(name) = expert.name.as_deref() {
697                format!("mixture expert '{name}' invalid: {err}")
698            } else {
699                format!("mixture expert #{} invalid: {err}", index + 1)
700            }
701        })?;
702    }
703    Ok(())
704}
705
706fn validate_mixture_spec_shallow(spec: &MixtureSpec) -> Result<(), String> {
707    if spec.experts.is_empty() {
708        return Err("mixture spec must include at least one expert".to_string());
709    }
710    if !spec.alpha.is_finite() {
711        return Err("mixture alpha must be finite".to_string());
712    }
713    if spec
714        .experts
715        .iter()
716        .any(|expert| !expert.log_prior.is_finite())
717    {
718        return Err("mixture expert log_prior must be finite".to_string());
719    }
720    if let Some(decay) = spec.decay {
721        if !decay.is_finite() || !(0.0..1.0).contains(&decay) {
722            return Err("mixture decay must be in (0, 1)".to_string());
723        }
724    }
725    if matches!(spec.kind, MixtureKind::FadingBayes) && spec.decay.is_none() {
726        return Err("fading Bayes mixture requires decay".to_string());
727    }
728    if spec.schedule != MixtureScheduleMode::Default
729        && !matches!(spec.kind, MixtureKind::Switching | MixtureKind::Convex)
730    {
731        return Err(
732            "mixture schedule is only supported for switching and convex mixtures".to_string(),
733        );
734    }
735    match (spec.kind, spec.schedule) {
736        (MixtureKind::Switching, MixtureScheduleMode::Default) => {
737            if !(0.0..=1.0).contains(&spec.alpha) {
738                return Err("switching mixture alpha must be in [0, 1]".to_string());
739            }
740        }
741        (MixtureKind::Convex, MixtureScheduleMode::Default)
742        | (MixtureKind::Neural, MixtureScheduleMode::Default) => {
743            if spec.alpha <= 0.0 {
744                return Err("mixture alpha must be > 0".to_string());
745            }
746        }
747        (MixtureKind::Neural, MixtureScheduleMode::Theorem) => unreachable!(),
748        _ => {}
749    }
750    Ok(())
751}
752
753fn validate_rate_backend_with_depth(backend: &RateBackend, depth: usize) -> Result<(), String> {
754    match backend {
755        RateBackend::Sequitur { context_bytes } => {
756            if *context_bytes < 2 {
757                Err("sequitur context_bytes must be >= 2".to_string())
758            } else {
759                Ok(())
760            }
761        }
762        RateBackend::Mixture { spec } => validate_mixture_spec_with_depth(spec.as_ref(), depth),
763        RateBackend::Particle { spec } => spec.validate(),
764        RateBackend::Calibrated { spec } => {
765            if depth == 0 {
766                return Err("calibrated spec nesting too deep".to_string());
767            }
768            validate_rate_backend_with_depth(&spec.base, depth - 1)
769                .map_err(|err| format!("calibrated base invalid: {err}"))
770        }
771        _ => Ok(()),
772    }
773}
774
/// Validate a rate backend, including nested mixture/calibrated subgraphs.
///
/// Recursion is bounded by [`MAX_MIXTURE_NESTING`] levels of nesting.
pub fn validate_rate_backend(backend: &RateBackend) -> Result<(), String> {
    validate_rate_backend_with_depth(backend, MAX_MIXTURE_NESTING)
}
779
/// Configuration for a particle-latent filter ensemble rate backend.
///
/// Field constraints are enforced by [`ParticleSpec::validate`]; the
/// [`Default`] values favor fully deterministic execution.
#[derive(Clone, Debug)]
pub struct ParticleSpec {
    /// Number of particles in the ensemble.
    pub num_particles: usize,
    /// Context window length for rolling byte context.
    pub context_window: usize,
    /// Number of latent update unroll steps per byte.
    pub unroll_steps: usize,
    /// Number of latent cells per particle.
    pub num_cells: usize,
    /// Dimensionality of each latent cell.
    pub cell_dim: usize,
    /// Number of discrete rules for soft routing.
    pub num_rules: usize,
    /// Hidden dimension for the selector MLP.
    pub selector_hidden: usize,
    /// Hidden dimension for each rule MLP.
    pub rule_hidden: usize,
    /// Dimension of per-rule noise input (ignored when deterministic).
    pub noise_dim: usize,
    /// Whether to use fully deterministic execution (no RNG).
    pub deterministic: bool,
    /// Whether to inject noise into rule inputs (ignored when deterministic).
    pub enable_noise: bool,
    /// Base scale for deterministic hash-noise injected into rule inputs.
    pub noise_scale: f64,
    /// Number of steps over which injected noise linearly anneals to zero.
    pub noise_anneal_steps: usize,
    /// Learning rate for readout layer SGD.
    pub learning_rate_readout: f64,
    /// Learning rate for selector MLP SGD.
    pub learning_rate_selector: f64,
    /// Learning rate for rule MLP SGD.
    pub learning_rate_rule: f64,
    /// Truncated backpropagation-through-time depth (number of recent steps).
    pub bptt_depth: usize,
    /// Momentum coefficient for selector/rule online updates (in [0, 1)).
    pub optimizer_momentum: f64,
    /// Gradient clipping threshold (max abs value per element).
    pub grad_clip: f64,
    /// Latent cell state clipping threshold (max abs value per element).
    pub state_clip: f64,
    /// Forgetting factor for particle log-weights (0 = no forgetting).
    pub forget_lambda: f64,
    /// Effective sample size ratio threshold for resampling (in (0, 1]).
    pub resample_threshold: f64,
    /// Fraction of particles to mutate after resampling (in [0, 1]).
    pub mutate_fraction: f64,
    /// Scale of hash-noise perturbation applied during mutation.
    pub mutate_scale: f64,
    /// Whether mutation also perturbs model parameters (state is always mutated).
    pub mutate_model_params: bool,
    /// Diagnostics print interval in steps (0 disables particle diagnostics logs).
    pub diagnostics_interval: usize,
    /// Minimum probability floor for numerical stability.
    pub min_prob: f64,
    /// Master seed for deterministic initialization and mutation.
    pub seed: u64,
}
840
impl Default for ParticleSpec {
    /// Baseline configuration: deterministic (no RNG, noise disabled), small
    /// ensemble, with diagnostics logging off.
    fn default() -> Self {
        Self {
            num_particles: 16,
            context_window: 32,
            unroll_steps: 2,
            num_cells: 8,
            cell_dim: 32,
            num_rules: 4,
            selector_hidden: 64,
            rule_hidden: 64,
            noise_dim: 8,
            deterministic: true,
            enable_noise: false,
            noise_scale: 0.10,
            noise_anneal_steps: 8192,
            learning_rate_readout: 0.01,
            learning_rate_selector: 1e-4,
            learning_rate_rule: 3e-4,
            bptt_depth: 3,
            optimizer_momentum: 0.05,
            grad_clip: 1.0,
            state_clip: 8.0,
            forget_lambda: 0.0,
            resample_threshold: 0.5,
            mutate_fraction: 0.1,
            mutate_scale: 0.01,
            mutate_model_params: false,
            diagnostics_interval: 0,
            // 2^-24 probability floor keeps log-losses bounded.
            min_prob: 2f64.powi(-24),
            seed: 42,
        }
    }
}
875
876impl ParticleSpec {
877    /// Validate all fields, returning an error message on failure.
878    pub fn validate(&self) -> Result<(), String> {
879        if self.num_particles == 0 {
880            return Err("num_particles must be > 0".into());
881        }
882        if self.context_window == 0 {
883            return Err("context_window must be > 0".into());
884        }
885        if self.unroll_steps == 0 {
886            return Err("unroll_steps must be > 0".into());
887        }
888        if self.num_cells == 0 {
889            return Err("num_cells must be > 0".into());
890        }
891        if self.cell_dim == 0 {
892            return Err("cell_dim must be > 0".into());
893        }
894        if self.num_rules == 0 {
895            return Err("num_rules must be > 0".into());
896        }
897        if self.selector_hidden == 0 {
898            return Err("selector_hidden must be > 0".into());
899        }
900        if self.rule_hidden == 0 {
901            return Err("rule_hidden must be > 0".into());
902        }
903        if !self.learning_rate_readout.is_finite() || self.learning_rate_readout < 0.0 {
904            return Err("learning_rate_readout must be finite and non-negative".into());
905        }
906        if !self.learning_rate_selector.is_finite() || self.learning_rate_selector < 0.0 {
907            return Err("learning_rate_selector must be finite and non-negative".into());
908        }
909        if !self.learning_rate_rule.is_finite() || self.learning_rate_rule < 0.0 {
910            return Err("learning_rate_rule must be finite and non-negative".into());
911        }
912        if !self.noise_scale.is_finite() || self.noise_scale < 0.0 {
913            return Err("noise_scale must be finite and non-negative".into());
914        }
915        if !self.optimizer_momentum.is_finite()
916            || self.optimizer_momentum < 0.0
917            || self.optimizer_momentum >= 1.0
918        {
919            return Err("optimizer_momentum must be finite and in [0, 1)".into());
920        }
921        if self.bptt_depth == 0 {
922            return Err("bptt_depth must be > 0".into());
923        }
924        if !(self.resample_threshold > 0.0 && self.resample_threshold <= 1.0) {
925            return Err("resample_threshold must be in (0, 1]".into());
926        }
927        if !(self.mutate_fraction >= 0.0 && self.mutate_fraction <= 1.0) {
928            return Err("mutate_fraction must be in [0, 1]".into());
929        }
930        if !(self.min_prob > 0.0 && self.min_prob < 0.5) {
931            return Err("min_prob must be in (0, 0.5)".into());
932        }
933        Ok(())
934    }
935}
936
/// Reusable execution context holding default rate and compression backends.
///
/// The context stores only backend *selection*; heavyweight per-call state
/// (predictors, runtimes) is constructed inside each method that needs it.
#[derive(Clone, Default)]
pub struct InfotheoryCtx {
    /// Default rate backend for entropy/rate metrics.
    pub rate_backend: RateBackend,
    /// Default compression backend for NCD/compression primitives.
    pub compression_backend: CompressionBackend,
}
945
/// Stateful rate-backend session for fitting, conditioning, and continuation.
///
/// Unlike the one-shot methods on [`InfotheoryCtx`], the wrapped predictor
/// persists across calls, so a model can be fit once and reused.
pub struct RateBackendSession {
    // The underlying online predictor; every session method delegates to it.
    predictor: crate::mixture::RateBackendPredictor,
}
950
impl RateBackendSession {
    /// Create a session from an explicit backend.
    ///
    /// `max_order` caps the model order and `total_symbols`, when known,
    /// lets the backend pre-size its stream.
    ///
    /// # Errors
    /// Returns an error if backend validation or stream initialization fails.
    pub fn from_backend(
        backend: RateBackend,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<Self, String> {
        // Bring the predictor trait's methods (begin_stream, update, ...) into scope.
        use crate::mixture::OnlineBytePredictor;

        validate_rate_backend(&backend)?;
        let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
            backend,
            max_order,
            crate::mixture::DEFAULT_MIN_PROB,
        );
        predictor.begin_stream(total_symbols)?;
        Ok(Self { predictor })
    }

    /// Observe bytes while adapting/fitting the model.
    pub fn observe(&mut self, data: &[u8]) {
        use crate::mixture::OnlineBytePredictor;

        for &byte in data {
            self.predictor.update(byte);
        }
    }

    /// Advance conditioning state without changing fitted parameters/statistics.
    ///
    /// Uses the frozen update path: the model's context advances over `data`
    /// but its learned statistics are left untouched.
    pub fn condition(&mut self, data: &[u8]) {
        use crate::mixture::OnlineBytePredictor;

        for &byte in data {
            self.predictor.update_frozen(byte);
        }
    }

    /// Reset dynamic conditioning state while preserving fitted parameters/statistics.
    pub fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        use crate::mixture::OnlineBytePredictor;

        self.predictor.reset_frozen(total_symbols)
    }

    /// Fill the 256-way next-byte log-probabilities.
    pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        use crate::mixture::OnlineBytePredictor;

        self.predictor.fill_log_probs(out);
    }

    /// Generate continuation bytes from the current state.
    ///
    /// Sampling is driven by `config` (seed, update mode); returns exactly
    /// `bytes` bytes (an empty vector when `bytes == 0`).
    pub fn generate_bytes(&mut self, bytes: usize, config: GenerationConfig) -> Vec<u8> {
        use crate::mixture::OnlineBytePredictor;

        if bytes == 0 {
            return Vec::new();
        }

        let mut out = Vec::with_capacity(bytes);
        let mut logps = [0.0f64; 256];
        let mut rng = GenerationRng::new(config.seed);

        for _ in 0..bytes {
            match &mut self.predictor {
                // ROSA's scalar path remains the reference for continuation generation.
                crate::mixture::RateBackendPredictor::Rosa { .. } => {
                    for (sym, slot) in logps.iter_mut().enumerate() {
                        *slot = self.predictor.log_prob(sym as u8);
                    }
                }
                _ => self.predictor.fill_log_probs(&mut logps),
            }
            let byte = pick_generated_byte(&logps, config, &mut rng);
            // Feed the sampled byte back so the next prediction is conditioned on it.
            match config.update_mode {
                GenerationUpdateMode::Adaptive => self.predictor.update(byte),
                GenerationUpdateMode::Frozen => self.predictor.update_frozen(byte),
            }
            out.push(byte);
        }

        out
    }

    /// Finalize the underlying stream if the backend needs it.
    pub fn finish(&mut self) -> Result<(), String> {
        use crate::mixture::OnlineBytePredictor;

        self.predictor.finish_stream()
    }
}
1042
impl InfotheoryCtx {
    /// Create a context from explicit rate and compression backends.
    pub fn new(rate_backend: RateBackend, compression_backend: CompressionBackend) -> Self {
        Self {
            rate_backend,
            compression_backend,
        }
    }

    /// Create a context with ROSA+ rate backend and ZPAQ compression backend.
    ///
    /// `method` is the ZPAQ method string forwarded to the compressor.
    pub fn with_zpaq(method: impl Into<String>) -> Self {
        Self {
            rate_backend: RateBackend::RosaPlus,
            compression_backend: CompressionBackend::Zpaq {
                method: method.into(),
            },
        }
    }

    /// Compressed length of one byte slice under this context's compressor.
    pub fn compress_size(&self, data: &[u8]) -> u64 {
        compress_size_backend(data, &self.compression_backend)
    }

    /// Compressed length of chained slices under one stream.
    pub fn compress_size_chain(&self, parts: &[&[u8]]) -> u64 {
        compress_size_chain_backend(parts, &self.compression_backend)
    }

    /// Create a stateful session for the active rate backend.
    ///
    /// # Errors
    /// Propagates backend validation or stream-initialization failures.
    pub fn rate_backend_session(
        &self,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<RateBackendSession, String> {
        RateBackendSession::from_backend(self.rate_backend.clone(), max_order, total_symbols)
    }

    /// Entropy-rate estimate for `data` under this context's rate backend.
    pub fn entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Biased entropy-rate estimate (plugin variant) for `data`.
    pub fn biased_entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        biased_entropy_rate_backend(data, max_order, &self.rate_backend)
    }

    /// Cross entropy of `test_data` under model trained on `train_data`.
    pub fn cross_entropy_rate_bytes(
        &self,
        test_data: &[u8],
        train_data: &[u8],
        max_order: i64,
    ) -> f64 {
        cross_entropy_rate_backend(test_data, train_data, max_order, &self.rate_backend)
    }

    /// Cross entropy with order-0 fast-path fallback when `max_order == 0`.
    ///
    /// The order-0 path computes the marginal (i.i.d.) cross entropy
    /// `-Σ p_x(i) · log2 q_y(i)` from byte histograms; any other order
    /// delegates to the rate backend.
    pub fn cross_entropy_bytes(&self, test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            if test_data.is_empty() {
                return 0.0;
            }
            let p_x = byte_histogram(test_data);
            let p_y = byte_histogram(train_data);
            let mut h = 0.0f64;
            for i in 0..256 {
                if p_x[i] > 0.0 {
                    // Floor q to avoid log2(0) when a byte never occurs in train_data.
                    let q_y = p_y[i].max(1e-12);
                    h -= p_x[i] * q_y.log2();
                }
            }
            h
        } else {
            self.cross_entropy_rate_bytes(test_data, train_data, max_order)
        }
    }

    /// Joint entropy-rate estimate `H(X,Y)` under aligned-prefix semantics.
    ///
    /// Inputs are truncated to their common prefix length; returns 0.0 when
    /// that prefix is empty.
    pub fn joint_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        joint_entropy_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Conditional entropy-rate estimate `H(X|Y)`.
    ///
    /// Chain rule `H(X|Y) = H(X,Y) - H(Y)`, clamped at zero because
    /// estimator noise can make the difference slightly negative.
    pub fn conditional_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
        let h_y = self.entropy_rate_bytes(y, max_order);
        (h_xy - h_y).max(0.0)
    }

    /// Compute `H(data | prefix_parts)` by conditioning the active rate backend
    /// on an explicit prefix chain.
    ///
    /// Returns the conditional code length of `data` in bits per byte; empty
    /// `data` yields 0.0. Each backend implements "condition on the prefix,
    /// then score `data`" in its own native way.
    pub fn cross_entropy_conditional_chain(&self, prefix_parts: &[&[u8]], data: &[u8]) -> f64 {
        match &self.rate_backend {
            RateBackend::RosaPlus => {
                // ROSA scores `data` against the concatenated prefix as training data.
                let mut prefix = Vec::new();
                let total: usize = prefix_parts.iter().map(|p| p.len()).sum();
                prefix.reserve(total);
                for p in prefix_parts {
                    prefix.extend_from_slice(p);
                }
                cross_entropy_rate_backend(data, &prefix, -1, &RateBackend::RosaPlus)
            }
            // These backends score the chain prequentially (online) without
            // materializing a concatenated prefix.
            RateBackend::Match { .. }
            | RateBackend::SparseMatch { .. }
            | RateBackend::Ppmd { .. }
            | RateBackend::Sequitur { .. }
            | RateBackend::Calibrated { .. } => {
                prequential_rate_backend(data, prefix_parts, -1, &self.rate_backend)
            }
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("rwkv conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("rwkv method conditional-chain scoring failed: {e:#}")
                    })
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("mamba conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("mamba method conditional-chain scoring failed: {e:#}")
                    })
            }),
            RateBackend::Ctw { depth } => {
                if data.is_empty() {
                    return 0.0;
                }
                // Bit-level CTW: feed each byte MSB-first into the context tree.
                let mut tree = crate::ctw::ContextTree::new(*depth);
                for &part in prefix_parts {
                    for &b in part {
                        for i in (0..8).rev() {
                            tree.update(((b >> i) & 1) == 1);
                        }
                    }
                }
                let log_p_prefix = tree.get_log_block_probability();
                for &b in data {
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_p_joint = tree.get_log_block_probability();
                // Conditional log-probability = joint - prefix; convert the
                // natural log to bits via ln 2, then normalize per byte.
                let log_p_cond = log_p_joint - log_p_prefix;
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
            RateBackend::Zpaq { method } => {
                if data.is_empty() {
                    return 0.0;
                }
                // Warm the ZPAQ rate model on the prefix, then score only `data`.
                let mut model =
                    crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
                for &part in prefix_parts {
                    model.update_and_score(part);
                }
                let bits = model.update_and_score(data);
                bits / (data.len() as f64)
            }
            RateBackend::Mixture { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let experts = spec.build_experts();
                let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                // Announce the full stream length (prefix + data) up front.
                let total = prefix_parts
                    .iter()
                    .map(|p| p.len() as u64)
                    .sum::<u64>()
                    .saturating_add(data.len() as u64);
                mix.begin_stream(Some(total))
                    .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
                for &part in prefix_parts {
                    for &b in part {
                        mix.step(b);
                    }
                }
                // `step` yields a natural-log score; accumulate as bits.
                let mut bits = 0.0;
                for &b in data {
                    bits -= mix.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::Particle { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                for &part in prefix_parts {
                    for &b in part {
                        runtime.step(b);
                    }
                }
                // Same natural-log-to-bits accumulation as the mixture arm.
                let mut bits = 0.0;
                for &b in data {
                    bits -= runtime.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                if data.is_empty() {
                    return 0.0;
                }
                // Only the low `encoding_bits` bits of each byte are modeled.
                let bits_per_byte = (*encoding_bits).clamp(1, 8);
                let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
                for &part in prefix_parts {
                    for &b in part {
                        // Fix Issue 1: LSB-first
                        for i in 0..bits_per_byte {
                            let bit_idx = i;
                            // b >> i gets the i-th bit (0 is LSB)
                            fac.update(((b >> i) & 1) == 1, bit_idx);
                        }
                    }
                }
                let log_p_prefix = fac.get_log_block_probability();
                for &b in data {
                    for i in 0..bits_per_byte {
                        let bit_idx = i;
                        fac.update(((b >> i) & 1) == 1, bit_idx);
                    }
                }
                let log_p_joint = fac.get_log_block_probability();
                // Conditional log-probability via joint minus prefix, as in CTW.
                let log_p_cond = log_p_joint - log_p_prefix;
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
        }
    }

    /// Generate a continuation from `prompt` with [`GenerationConfig::default()`].
    ///
    /// The default is deterministic frozen sampling with seed `42`.
    pub fn generate_bytes(&self, prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
        self.generate_bytes_with_config(prompt, bytes, max_order, GenerationConfig::default())
    }

    /// Generate a continuation from `prompt` using an explicit generation config.
    pub fn generate_bytes_with_config(
        &self,
        prompt: &[u8],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        // A single prompt is the one-part case of the chain API.
        generate_rate_backend_chain(&[prompt], bytes, max_order, &self.rate_backend, config)
    }

    /// Generate a continuation after conditioning on an explicit chain of prefix parts.
    pub fn generate_bytes_conditional_chain(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
    ) -> Vec<u8> {
        self.generate_bytes_conditional_chain_with_config(
            prefix_parts,
            bytes,
            max_order,
            GenerationConfig::default(),
        )
    }

    /// Generate a continuation after conditioning on an explicit chain of prefix parts.
    pub fn generate_bytes_conditional_chain_with_config(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(prefix_parts, bytes, max_order, &self.rate_backend, config)
    }

    /// NCD between byte slices using this context's compression backend.
    pub fn ncd_bytes(&self, x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
        ncd_bytes_backend(x, y, &self.compression_backend, variant)
    }

    /// Rate-backend mutual information estimate.
    pub fn mutual_information_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        mutual_information_rate_backend(x, y, max_order, &self.rate_backend)
    }

    /// Mutual information with `max_order == 0` marginal fast-path.
    pub fn mutual_information_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            mutual_information_marg_bytes(x, y)
        } else {
            self.mutual_information_rate_bytes(x, y, max_order)
        }
    }

    /// Conditional entropy with aligned-prefix semantics.
    ///
    /// `H(X|Y) = H(X,Y) - H(Y)`, clamped at zero; the order-0 path uses
    /// marginal (histogram) entropies, other orders use the rate backend.
    pub fn conditional_entropy_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if max_order == 0 {
            let h_xy = joint_marginal_entropy_bytes(x, y);
            let h_y = marginal_entropy_bytes(y);
            (h_xy - h_y).max(0.0)
        } else {
            let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
            let h_y = self.entropy_rate_bytes(y, max_order);
            (h_xy - h_y).max(0.0)
        }
    }

    /// Normalized entropy distance (NED) under this context.
    pub fn ned_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            ned_marg_bytes(x, y)
        } else {
            ned_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Conservative NED normalization variant.
    ///
    /// Normalizes by the joint entropy instead of `max(H(X), H(Y))`:
    /// `(H(X,Y) - min(H(X), H(Y))) / H(X,Y)`, clamped to `[0, 1]`.
    pub fn ned_cons_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        let (h_x, h_y, h_xy) = if max_order == 0 {
            (
                marginal_entropy_bytes(x),
                marginal_entropy_bytes(y),
                joint_marginal_entropy_bytes(x, y),
            )
        } else {
            (
                self.entropy_rate_bytes(x, max_order),
                self.entropy_rate_bytes(y, max_order),
                self.joint_entropy_rate_bytes(x, y, max_order),
            )
        };
        let min_h = h_x.min(h_y);
        if h_xy == 0.0 {
            0.0
        } else {
            ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
        }
    }

    /// Normalized transform effort (NTE) under this context.
    pub fn nte_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            nte_marg_bytes(x, y)
        } else {
            nte_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }

    /// Intrinsic dependence score in `[0,1]`.
    ///
    /// Measures how much the sequential model beats the i.i.d. model:
    /// `(H_marginal - H_rate) / H_marginal`, clamped to `[0, 1]`.
    pub fn intrinsic_dependence_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        let h_marginal = marginal_entropy_bytes(data);
        // Near-zero marginal entropy means a (nearly) constant sequence.
        if h_marginal < 1e-9 {
            return 0.0;
        }
        let h_rate = self.entropy_rate_bytes(data, max_order);
        ((h_marginal - h_rate) / h_marginal).clamp(0.0, 1.0)
    }

    /// Resistance-to-transformation ratio `I(X;T(X))/H(X)` in `[0,1]`.
    pub fn resistance_to_transformation_bytes(&self, x: &[u8], tx: &[u8], max_order: i64) -> f64 {
        let (x, tx) = aligned_prefix(x, tx);
        let h_x = if max_order == 0 {
            marginal_entropy_bytes(x)
        } else {
            self.entropy_rate_bytes(x, max_order)
        };
        // Degenerate input: no information content to resist transformation.
        if h_x < 1e-9 {
            return 0.0;
        }
        let mi = self.mutual_information_bytes(x, tx, max_order);
        (mi / h_x).clamp(0.0, 1.0)
    }
}
1441
#[cfg(feature = "backend-rwkv")]
/// Load an RWKV7 model from `.safetensors` path.
///
/// # Panics
/// Panics if the model cannot be loaded from `path`.
pub fn load_rwkv7_model_from_path(path: &str) -> Arc<rwkvzip::Model> {
    rwkvzip::Compressor::load_model(path).expect("failed to load RWKV7 model")
}
1447
#[cfg(feature = "backend-mamba")]
/// Load a Mamba-1 model from `.safetensors` path.
///
/// # Panics
/// Panics if the model cannot be loaded from `path`.
pub fn load_mamba_model_from_path(path: &str) -> Arc<mambazip::Model> {
    mambazip::Compressor::load_model(path).expect("failed to load Mamba model")
}
1453
/// Truncate both slices to the length of the shorter one, so downstream
/// joint/conditional estimators always see equal-length streams.
#[inline(always)]
fn aligned_prefix<'a>(x: &'a [u8], y: &'a [u8]) -> (&'a [u8], &'a [u8]) {
    let shared = usize::min(x.len(), y.len());
    let (head_x, _) = x.split_at(shared);
    let (head_y, _) = y.split_at(shared);
    (head_x, head_y)
}
1459
/// Compressed size in bytes of `data` under ZPAQ `method` (real backend).
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_bytes(data: &[u8], method: &str) -> u64 {
    // NOTE(review): compression errors are mapped to 0 rather than propagated
    // — confirm callers treat 0 as a sentinel for "no size".
    zpaq_rs::compress_size(data, method).unwrap_or(0)
}
1465
/// Stub when the `backend-zpaq` feature is off: any use panics at runtime.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_bytes(_data: &[u8], _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
1471
/// Compressed size of `data` using the multi-threaded ZPAQ path.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(data: &[u8], method: &str, threads: usize) -> u64 {
    // NOTE(review): errors collapse to 0, matching the serial variant.
    zpaq_rs::compress_size_parallel(data, method, threads).unwrap_or(0)
}
1477
/// Stub when the `backend-zpaq` feature is off: any use panics at runtime.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(_data: &[u8], _method: &str, _threads: usize) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
1483
/// Compressed size of everything readable from `reader` under ZPAQ `method`.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(reader: R, method: &str) -> u64 {
    // NOTE(review): both I/O and compression errors collapse to 0.
    zpaq_rs::compress_size_stream(reader, method, None, None).unwrap_or(0)
}
1489
/// Stub when the `backend-zpaq` feature is off: any use panics at runtime.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(_reader: R, _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
1495
/// Compress `data` with ZPAQ `method`, returning the compressed bytes.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_to_vec(data: &[u8], method: &str) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::compress_to_vec(data, method)?)
}
1501
/// Stub when the `backend-zpaq` feature is off: always returns an error.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_to_vec(_data: &[u8], _method: &str) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
1507
/// Decompress a ZPAQ archive back into its original bytes.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_decompress_to_vec(data: &[u8]) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::decompress_to_vec(data)?)
}
1513
/// Stub when the `backend-zpaq` feature is off: always returns an error.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_decompress_to_vec(_data: &[u8]) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
1519
1520/// ------- Base Compression Functions -------
1521#[inline(always)]
1522pub fn get_compressed_size(path: &str, method: &str) -> u64 {
1523    // Convert Input file to Vec<u8>, and reference that (compress_size only takes &[u8] input), and pass method.
1524    // Will panic if file does not exist, so it must be prevalidated.
1525    zpaq_compress_size_bytes(&std::fs::read(path).unwrap(), method)
1526}
1527
/// Validate that a ZPAQ method string is supported for rate estimation.
///
/// # Errors
/// Returns a message when the method is unsupported, or unconditionally when
/// the crate was built without the `backend-zpaq` feature.
pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
    #[cfg(feature = "backend-zpaq")]
    {
        zpaq_rate::validate_zpaq_rate_method(method)
    }
    #[cfg(not(feature = "backend-zpaq"))]
    {
        // Consume the argument so the disabled build has no unused warning.
        let _ = method;
        Err("zpaq backend disabled at compile time".to_string())
    }
}
1540
/// Run `f` against a thread-local RWKV compressor keyed by model identity.
///
/// NOTE(review): unlike `with_rwkv_method_tls`, the cached compressor is
/// reused across calls on the same thread, so any mutable runtime state it
/// accumulates persists between calls — confirm callers rely on (or
/// tolerate) this.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    // The Arc's pointer address identifies the model in the per-thread cache.
    let key = Arc::as_ptr(model) as usize;
    RWKV_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| rwkvzip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}
1555
/// Run `f` against a fresh clone of a thread-locally cached RWKV compressor
/// built from `method`.
///
/// # Panics
/// Panics if `method` is not a valid RWKV method string.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_method_tls<R>(method: &str, f: impl FnOnce(&mut rwkvzip::Compressor) -> R) -> R {
    RWKV_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        // Keep a per-method template compressor for fast cloning while ensuring
        // each call gets isolated mutable runtime state (no cross-call leakage).
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid rwkv method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        // Release the RefCell borrow before running `f`, so `f` may re-enter
        // this thread-local without panicking.
        drop(map);
        f(&mut comp)
    })
}
1575
/// Run `f` against a fresh clone of a thread-locally cached RWKV compressor
/// keyed by model identity (rate-estimation cache).
///
/// Each call clones the cached template, so mutable runtime state is
/// isolated per call (unlike `with_rwkv_tls`).
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_rate_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    // The Arc's pointer address identifies the model in the per-thread cache.
    let key = Arc::as_ptr(model) as usize;
    RWKV_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        // Release the RefCell borrow before running `f` (safe re-entrancy).
        drop(map);
        f(&mut comp)
    })
}
1595
/// Run `f` against a thread-local Mamba compressor keyed by model identity.
///
/// NOTE(review): mirrors `with_rwkv_tls` — the cached compressor is reused
/// across calls on the same thread, so mutable runtime state persists
/// between calls; confirm this is intended.
#[cfg(feature = "backend-mamba")]
fn with_mamba_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    // The Arc's pointer address identifies the model in the per-thread cache.
    let key = Arc::as_ptr(model) as usize;
    MAMBA_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| mambazip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}
1610
/// Run `f` against a fresh clone of a thread-locally cached Mamba compressor
/// keyed by model identity (rate-estimation cache).
///
/// Each call clones the cached template, isolating mutable runtime state
/// per call (unlike `with_mamba_tls`).
#[cfg(feature = "backend-mamba")]
fn with_mamba_rate_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    // The Arc's pointer address identifies the model in the per-thread cache.
    let key = Arc::as_ptr(model) as usize;
    MAMBA_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        // Release the RefCell borrow before running `f` (safe re-entrancy).
        drop(map);
        f(&mut comp)
    })
}
1630
/// Run `f` against a fresh clone of a thread-locally cached Mamba compressor
/// built from `method`.
///
/// # Panics
/// Panics if `method` is not a valid Mamba method string.
#[cfg(feature = "backend-mamba")]
fn with_mamba_method_tls<R>(method: &str, f: impl FnOnce(&mut mambazip::Compressor) -> R) -> R {
    MAMBA_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        // Per-method template for fast cloning; each call gets isolated state.
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid mamba method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        // Release the RefCell borrow before running `f` (safe re-entrancy).
        drop(map);
        f(&mut comp)
    })
}
1648
/// A zero-copy `Read` adapter that presents a sequence of byte slices as one
/// contiguous stream (used to feed concatenations to streaming compressors
/// without materializing the joined buffer).
struct SliceChainReader<'a> {
    /// The slices to read through, in order.
    parts: &'a [&'a [u8]],
    /// Index of the slice currently being read.
    i: usize,
    /// Byte offset inside `parts[i]`.
    off: usize,
}

impl<'a> SliceChainReader<'a> {
    /// Create a reader positioned at the start of the first slice.
    fn new(parts: &'a [&'a [u8]]) -> Self {
        Self { parts, i: 0, off: 0 }
    }
}

impl<'a> std::io::Read for SliceChainReader<'a> {
    /// Fill `buf` from the chained slices, crossing slice boundaries within a
    /// single call. Returns `Ok(0)` only when `buf` is empty or the chain is
    /// exhausted; never errors.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut written = 0;
        while written < buf.len() {
            let Some(part) = self.parts.get(self.i) else {
                break; // chain exhausted
            };
            if self.off >= part.len() {
                // Current slice (possibly empty) is done; move to the next.
                self.i += 1;
                self.off = 0;
                continue;
            }
            let n = (part.len() - self.off).min(buf.len() - written);
            buf[written..written + n].copy_from_slice(&part[self.off..self.off + n]);
            self.off += n;
            written += n;
        }
        Ok(written)
    }
}
1697
1698/// Compute compressed size of a chain of byte slices with a selected compression backend.
1699pub fn compress_size_chain_backend(parts: &[&[u8]], backend: &CompressionBackend) -> u64 {
1700    match backend {
1701        CompressionBackend::Zpaq { method } => {
1702            let r = SliceChainReader::new(parts);
1703            zpaq_compress_size_stream(r, method.as_str())
1704        }
1705        #[cfg(feature = "backend-rwkv")]
1706        CompressionBackend::Rwkv7 { model, coder } => {
1707            with_rwkv_tls(model, |c| c.compress_size_chain(parts, *coder).unwrap_or(0))
1708        }
1709        CompressionBackend::Rate {
1710            rate_backend,
1711            coder,
1712            framing,
1713        } => {
1714            crate::compression::compress_rate_size_chain(parts, rate_backend, -1, *coder, *framing)
1715                .unwrap_or(0)
1716        }
1717    }
1718}
1719
1720/// Compute compressed size of a single byte slice with a selected compression backend.
1721pub fn compress_size_backend(data: &[u8], backend: &CompressionBackend) -> u64 {
1722    match backend {
1723        CompressionBackend::Zpaq { method } => zpaq_compress_size_bytes(data, method.as_str()),
1724        #[cfg(feature = "backend-rwkv")]
1725        CompressionBackend::Rwkv7 { model, coder } => {
1726            with_rwkv_tls(model, |c| c.compress_size(data, *coder).unwrap_or(0))
1727        }
1728        CompressionBackend::Rate {
1729            rate_backend,
1730            coder,
1731            framing,
1732        } => crate::compression::compress_rate_size(data, rate_backend, -1, *coder, *framing)
1733            .unwrap_or(0),
1734    }
1735}
1736
1737/// Compress bytes with a selected compression backend.
1738pub fn compress_bytes_backend(
1739    data: &[u8],
1740    backend: &CompressionBackend,
1741) -> anyhow::Result<Vec<u8>> {
1742    match backend {
1743        CompressionBackend::Zpaq { method } => zpaq_compress_to_vec(data, method),
1744        #[cfg(feature = "backend-rwkv")]
1745        CompressionBackend::Rwkv7 { model, coder } => {
1746            with_rwkv_tls(model, |c| c.compress(data, *coder))
1747        }
1748        CompressionBackend::Rate {
1749            rate_backend,
1750            coder,
1751            framing,
1752        } => crate::compression::compress_rate_bytes(data, rate_backend, -1, *coder, *framing),
1753    }
1754}
1755
1756/// Decompress bytes with a selected compression backend.
1757pub fn decompress_bytes_backend(
1758    input: &[u8],
1759    backend: &CompressionBackend,
1760) -> anyhow::Result<Vec<u8>> {
1761    match backend {
1762        CompressionBackend::Zpaq { .. } => zpaq_decompress_to_vec(input),
1763        #[cfg(feature = "backend-rwkv")]
1764        CompressionBackend::Rwkv7 { model, .. } => with_rwkv_tls(model, |c| c.decompress(input)),
1765        CompressionBackend::Rate {
1766            rate_backend,
1767            coder,
1768            framing,
1769        } => crate::compression::decompress_rate_bytes(input, rate_backend, -1, *coder, *framing),
1770    }
1771}
1772
1773fn prequential_rate_backend(
1774    data: &[u8],
1775    prefix_parts: &[&[u8]],
1776    max_order: i64,
1777    backend: &RateBackend,
1778) -> f64 {
1779    use crate::mixture::OnlineBytePredictor;
1780
1781    if data.is_empty() {
1782        return 0.0;
1783    }
1784    let total = prefix_parts
1785        .iter()
1786        .map(|p| p.len() as u64)
1787        .sum::<u64>()
1788        .saturating_add(data.len() as u64);
1789    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
1790        backend.clone(),
1791        max_order,
1792        crate::mixture::DEFAULT_MIN_PROB,
1793    );
1794    predictor
1795        .begin_stream(Some(total))
1796        .unwrap_or_else(|e| panic!("rate backend stream init failed: {e}"));
1797    for prefix in prefix_parts {
1798        for &b in *prefix {
1799            predictor.update(b);
1800        }
1801    }
1802    let mut bits = 0.0;
1803    for &b in data {
1804        bits -= predictor.log_prob(b) / std::f64::consts::LN_2;
1805        predictor.update(b);
1806    }
1807    predictor
1808        .finish_stream()
1809        .unwrap_or_else(|e| panic!("rate backend stream finalize failed: {e}"));
1810    bits / (data.len() as f64)
1811}
1812
/// Fit-then-freeze ("plugin") cross-entropy: train a model on `fit_parts`,
/// freeze it, and return the cross-entropy rate (bits per byte) of
/// `score_data` under that frozen model.
///
/// Returns 0.0 for empty `score_data`. Panics if any backend fit/score
/// stream fails.
fn frozen_plugin_rate_backend(
    score_data: &[u8],
    fit_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    if score_data.is_empty() {
        return 0.0;
    }
    // ROSA+ has a native train/build/score API; use it directly.
    if matches!(backend, RateBackend::RosaPlus) {
        let mut model = rosaplus::RosaPlus::new(max_order, false, 0, 42);
        for part in fit_parts {
            model.train_example(part);
        }
        model.build_lm();
        return model.cross_entropy(score_data);
    }
    // Neural backends (feature-gated) implement frozen-plugin scoring natively;
    // any other variant falls through to the generic path below.
    #[cfg(feature = "backend-rwkv")]
    match backend {
        RateBackend::Rwkv7 { model } => {
            return with_rwkv_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::Rwkv7Method { method } => {
            return with_rwkv_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }
    #[cfg(feature = "backend-mamba")]
    match backend {
        RateBackend::Mamba { model } => {
            return with_mamba_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::MambaMethod { method } => {
            return with_mamba_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }

    use crate::mixture::OnlineBytePredictor;

    // Generic path: one adaptive pass over the fit data, then a frozen
    // re-scoring pass over score_data.
    let fit_total = fit_parts.iter().map(|part| part.len() as u64).sum::<u64>();
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(fit_total))
        .unwrap_or_else(|e| panic!("rate backend fit-pass init failed: {e}"));
    for part in fit_parts {
        for &byte in *part {
            predictor.update(byte);
        }
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend fit-pass finalize failed: {e}"));
    predictor
        .reset_frozen(Some(score_data.len() as u64))
        .unwrap_or_else(|e| panic!("rate backend frozen-score reset failed: {e}"));
    let mut bits = 0.0;
    for &byte in score_data {
        // Natural-log probability converted to bits.
        bits -= predictor.log_prob(byte) / std::f64::consts::LN_2;
        // NOTE(review): update_frozen presumably advances the context without
        // adapting the learned statistics — confirm in RateBackendPredictor.
        predictor.update_frozen(byte);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend frozen-score finalize failed: {e}"));
    bits / (score_data.len() as f64)
}
1895
#[inline(always)]
/// Return the byte with the highest log-probability. Non-finite entries
/// (NaN, ±inf) are treated as -inf; ties keep the lowest byte value; an
/// all-non-finite table yields byte 0.
fn argmax_log_prob_byte(logps: &[f64; 256]) -> u8 {
    let (best_idx, _) = logps.iter().enumerate().fold(
        (0usize, f64::NEG_INFINITY),
        |(best_idx, best), (idx, &logp)| {
            let score = if logp.is_finite() { logp } else { f64::NEG_INFINITY };
            // Strict `>` keeps the first (lowest-index) maximum on ties.
            if score > best { (idx, score) } else { (best_idx, best) }
        },
    );
    best_idx as u8
}
1913
/// Sample one byte from a 256-way log-probability table, honoring the
/// strategy, temperature, top-k, and top-p settings in `config`.
///
/// Falls back to greedy argmax whenever sampling is ill-defined: greedy
/// strategy, non-positive/non-finite temperature, or degenerate weight mass.
fn pick_generated_byte(
    logps: &[f64; 256],
    config: GenerationConfig,
    rng: &mut GenerationRng,
) -> u8 {
    if matches!(config.strategy, GenerationStrategy::Greedy)
        || !config.temperature.is_finite()
        || config.temperature <= 0.0
    {
        return argmax_log_prob_byte(logps);
    }

    // Temperature-scale log-probs; non-finite entries are floored to -inf.
    let mut entries = [(0u8, f64::NEG_INFINITY); 256];
    for (idx, &logp) in logps.iter().enumerate() {
        let scaled = if logp.is_finite() {
            logp / config.temperature
        } else {
            f64::NEG_INFINITY
        };
        entries[idx] = (idx as u8, scaled);
    }
    // Sort descending by scaled log-prob (total_cmp gives a NaN-safe order).
    entries.sort_by(|a, b| b.1.total_cmp(&a.1));

    // top_k == 0 means "no top-k truncation".
    let keep_k = if config.top_k == 0 {
        entries.len()
    } else {
        config.top_k.min(entries.len())
    };

    // Non-finite top_p disables nucleus truncation.
    let top_p = if config.top_p.is_finite() {
        config.top_p.clamp(0.0, 1.0)
    } else {
        1.0
    };

    // Max kept log-prob, for numerically stable exponentiation below.
    let mut max_logp = f64::NEG_INFINITY;
    for &(_, logp) in entries.iter().take(keep_k) {
        if logp.is_finite() {
            max_logp = max_logp.max(logp);
        }
    }
    if !max_logp.is_finite() {
        // Every kept entry is -inf: no usable distribution.
        return argmax_log_prob_byte(logps);
    }

    // Softmax weights relative to max_logp (so the largest weight is 1.0).
    let mut weights = [(0u8, 0.0f64); 256];
    let mut total = 0.0;
    for (idx, &(byte, logp)) in entries.iter().take(keep_k).enumerate() {
        let w = if logp.is_finite() {
            (logp - max_logp).exp()
        } else {
            0.0
        };
        weights[idx] = (byte, w);
        total += w;
    }
    if !(total.is_finite()) || total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Nucleus (top-p) cutoff: smallest prefix of the sorted weights whose
    // normalized mass reaches top_p, but always at least one entry.
    let cutoff_count = if top_p >= 1.0 {
        keep_k
    } else {
        let mut cumulative = 0.0;
        let mut keep = 0usize;
        for &(_, w) in weights.iter().take(keep_k) {
            cumulative += w / total;
            keep += 1;
            if cumulative >= top_p {
                break;
            }
        }
        keep.max(1)
    };

    let mut truncated_total = 0.0;
    for &(_, w) in weights.iter().take(cutoff_count) {
        truncated_total += w;
    }
    if !(truncated_total.is_finite()) || truncated_total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }

    // Inverse-CDF sampling over the truncated, unnormalized weights.
    let target = rng.next_f64() * truncated_total;
    let mut cumulative = 0.0;
    let mut picked = weights[0].0;
    for &(byte, weight) in weights.iter().take(cutoff_count) {
        cumulative += weight;
        if cumulative >= target {
            picked = byte;
            break;
        }
    }
    picked
}
2009
2010fn generate_rate_backend_chain(
2011    prefix_parts: &[&[u8]],
2012    bytes: usize,
2013    max_order: i64,
2014    backend: &RateBackend,
2015    config: GenerationConfig,
2016) -> Vec<u8> {
2017    if bytes == 0 {
2018        return Vec::new();
2019    }
2020
2021    let total = prefix_parts
2022        .iter()
2023        .map(|p| p.len() as u64)
2024        .sum::<u64>()
2025        .saturating_add(bytes as u64);
2026    let mut session = RateBackendSession::from_backend(backend.clone(), max_order, Some(total))
2027        .unwrap_or_else(|e| panic!("rate backend generation init failed: {e}"));
2028    for &part in prefix_parts {
2029        session.observe(part);
2030    }
2031    let out = session.generate_bytes(bytes, config);
2032    session
2033        .finish()
2034        .unwrap_or_else(|e| panic!("rate backend generation finalize failed: {e}"));
2035    out
2036}
2037
/// Estimate entropy rate of `data` using the explicit rate `backend`.
///
/// All arms report bits per byte of `data`; backends that special-case it
/// return 0.0 for empty input.
pub fn entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    match backend {
        // ROSA suffix-automaton LM has a native predictive entropy-rate API.
        RateBackend::RosaPlus => {
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.predictive_entropy_rate(data)
        }
        // Byte predictors without a native scorer: generic predict-then-update
        // loop with no warm-up prefix.
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Sequitur { .. }
        | RateBackend::Calibrated { .. } => prequential_rate_backend(data, &[], max_order, backend),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv method entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba method entropy scoring failed: {e:#}"))
        }),
        // ZPAQ context-mixing model used as an online probability source;
        // 2^-24 is the probability floor.
        RateBackend::Zpaq { method } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(data);
            bits / (data.len() as f64)
        }
        // Weighted mixture of expert predictors; `step` returns ln p(b),
        // hence the division by ln 2 to convert nats to bits.
        RateBackend::Mixture { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(data.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in data {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (data.len() as f64)
        }
        // Particle-filter predictor; `step` likewise returns ln p(b).
        RateBackend::Particle { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in data {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (data.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            if data.is_empty() {
                return 0.0;
            }
            // Byte-wise CTW: factorize by bit position so deterministic bits don't leak entropy.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_msb(b);
            }
            // Block probability is returned as ln p of the whole sequence.
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            if data.is_empty() {
                return 0.0;
            }
            // Only the low `encoding_bits` of each byte are modeled (LSB-first).
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_lsb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
    }
}
2140
2141/// Estimate biased/plugin entropy rate of `data` using the explicit rate `backend`.
2142pub fn biased_entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
2143    match backend {
2144        RateBackend::Zpaq { .. } => {
2145            panic!("biased/plugin entropy is not supported for zpaq rate backends in 1.1.1")
2146        }
2147        _ => frozen_plugin_rate_backend(data, &[data], max_order, backend),
2148    }
2149}
2150
2151/// Cross-entropy H_{train}(test) - score test_data under model trained on train_data.
2152pub fn cross_entropy_rate_backend(
2153    test_data: &[u8],
2154    train_data: &[u8],
2155    max_order: i64,
2156    backend: &RateBackend,
2157) -> f64 {
2158    match backend {
2159        RateBackend::Zpaq { method } => {
2160            if test_data.is_empty() {
2161                return 0.0;
2162            }
2163            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
2164            model.update_and_score(train_data);
2165            let bits = model.update_and_score(test_data);
2166            bits / (test_data.len() as f64)
2167        }
2168        _ => frozen_plugin_rate_backend(test_data, &[train_data], max_order, backend),
2169    }
2170}
2171
/// Estimate joint entropy rate `H(X,Y)` using an explicit `backend`.
///
/// Both inputs are first reduced with `aligned_prefix`; every arm normalizes
/// its total code length by the reduced `x.len()`, so the result is bits per
/// aligned (x, y) position. Returns 0.0 when the aligned prefix is empty.
pub fn joint_entropy_rate_backend(
    x: &[u8],
    y: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    if x.is_empty() {
        return 0.0;
    }
    match backend {
        // ROSA models the pair stream directly as 16-bit symbols (x << 8 | y).
        RateBackend::RosaPlus => {
            let joint_symbols: Vec<u32> = (0..x.len())
                .map(|i| (x[i] as u32) * 256 + (y[i] as u32))
                .collect();
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.entropy_rate_cps(&joint_symbols)
        }
        // Byte predictors: interleave x0,y0,x1,y1,... and score as one stream.
        // entropy_rate_backend returns bits per interleaved byte; each (x,y)
        // pair contributes two bytes, hence the factor of 2.
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Sequitur { .. }
        | RateBackend::Calibrated { .. } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            entropy_rate_backend(&joint, max_order, backend) * 2.0
        }
        // Neural backends provide aligned joint scoring natively.
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv method joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba method joint-entropy scoring failed: {e:#}"))
        }),
        // ZPAQ on the interleaved byte stream, normalized per pair (x.len()).
        RateBackend::Zpaq { method } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(&joint);
            bits / (x.len() as f64)
        }
        // Mixture on the interleaved stream; `step` returns ln p(b).
        RateBackend::Mixture { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(joint.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in &joint {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (x.len() as f64)
        }
        // Particle filter on the interleaved stream, normalized per pair.
        RateBackend::Particle { spec } => {
            let mut joint = Vec::with_capacity(x.len() * 2);
            for (&xb, &yb) in x.iter().zip(y.iter()) {
                joint.push(xb);
                joint.push(yb);
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in &joint {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (x.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            // NOTE: CTW interleaves bits: x_0, y_0, x_1, y_1...
            // This estimates the joint entropy H(X,Y) by modeling the sequence
            // of alternating bits. This is a fine-grained joint model but
            // theoretically consistent for estimating joint entropy rate.
            // ROSA uses 16-bit joint symbols (x << 8 | y). Both are valid.
            // 16 factorized trees: positions 0-7 for X's MSB-first bits,
            // 8-15 for Y's.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 16);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for bit_idx in 0..8 {
                    let bit_x = ((bx >> (7 - bit_idx)) & 1) == 1;
                    let bit_y = ((by >> (7 - bit_idx)) & 1) == 1;
                    fac.update(bit_x, bit_idx);
                    fac.update(bit_y, bit_idx + 8);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            // Joint: interleave x and y bits, use 2*encoding_bits trees
            // (LSB-first, matching the marginal FacCtw arm).
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte * 2);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for i in 0..bits_per_byte {
                    // Tree structure:
                    // bits_per_byte trees for X, bits_per_byte trees for Y.
                    // But we interleave them in the "joint" sense.
                    // Here we map bit i of X to tree 2*i, bit i of Y to tree 2*i + 1
                    let bit_idx_x = i * 2;
                    let bit_idx_y = bit_idx_x + 1;
                    fac.update(((bx >> i) & 1) == 1, bit_idx_x);
                    fac.update(((by >> i) & 1) == 1, bit_idx_y);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
    }
}
2314#[inline(always)]
2315/// Compute compressed size for a file path with an explicit ZPAQ thread count.
2316pub fn get_compressed_size_parallel(path: &str, method: &str, threads: usize) -> u64 {
2317    // Convert Input file to Vec<u8>, and reference that (compress_size only takes &[u8] input), and pass method.
2318    // Will panic if file does not exist, so it must be prevalidated.
2319    zpaq_compress_size_parallel_bytes(&std::fs::read(path).unwrap(), method, threads)
2320}
2321
2322#[inline(always)]
2323/// Read all files in `paths` in parallel and return their byte contents.
2324pub fn get_bytes_from_paths(paths: &[&str]) -> Vec<Vec<u8>> {
2325    paths
2326        .par_iter()
2327        .map(|path| std::fs::read(*path).expect("failed to read file"))
2328        .collect()
2329}
2330
2331/// ------- Bulk File Compression Functions -------
2332#[inline(always)]
2333pub fn get_sequential_compressed_sizes_from_sequential_paths(
2334    paths: &[&str],
2335    method: &str,
2336) -> Vec<u64> {
2337    // This will, in parallel load all files into memory, THEN in parallel compress each one, each with one thread.
2338    // Use when File IO is the bottleneck
2339    // Only uses ONE ZPAQ THREAD.
2340    // For VERY large n (relative to threads) with small files (relative to memory) this may be useful.
2341    get_bytes_from_paths(paths)
2342        .par_iter()
2343        .map(|data| zpaq_compress_size_bytes(data, method))
2344        .collect()
2345}
2346
2347#[inline(always)]
2348/// Compress all paths after preloading bytes, using per-file parallel ZPAQ compression.
2349pub fn get_parallel_compressed_sizes_from_sequential_paths(
2350    paths: &[&str],
2351    method: &str,
2352    threads: usize,
2353) -> Vec<u64> {
2354    // This will, in parallel load all files into memory, THEN in parallel compress each one, with THREADS. (for each file, the thread count is THREADS)
2355    // Use when File IO is the bottleneck.
2356    // Balanced parallelization between RAYON_NUM_THREADS and ZPAQ `THREADS` const. For when total dataset will fit in memory.
2357    get_bytes_from_paths(paths)
2358        .par_iter()
2359        .map(|data| zpaq_compress_size_parallel_bytes(data, method, threads))
2360        .collect()
2361}
2362
2363#[inline(always)]
2364/// Compress all paths directly from disk using single-thread ZPAQ per file.
2365pub fn get_sequential_compressed_sizes_from_parallel_paths(
2366    paths: &[&str],
2367    method: &str,
2368) -> Vec<u64> {
2369    // This will, in parallel, for each file, read it from disk and compress it with one thread. (one file, one thread)
2370    // Use when File IO is not the bottleneck. Lower memory usage. (does not preload dataset)
2371    // Only uses ONE ZPAQ THREAD. For VERY large n(relative to threads) with large files(relative to memory) this may be useful.
2372    paths
2373        .par_iter()
2374        .map(|path| get_compressed_size(path, method))
2375        .collect()
2376}
2377
2378#[inline(always)]
2379/// Compress all paths directly from disk using per-file multi-thread ZPAQ.
2380pub fn get_parallel_compressed_sizes_from_parallel_paths(
2381    paths: &[&str],
2382    method: &str,
2383    threads: usize,
2384) -> Vec<u64> {
2385    // This will, in parallel, for each file, read it from disk and compress it with THREADS. (for each file, the thread count is THREADS)
2386    // Use when File IO is not the bottleneck. Lower memory usage. (does not preload dataset)
2387    // For large n(relative to threads) with VERY large files(relative to memory) this may be useful.
2388    // This will reflect RAYON_NUM_THREADS and THREAD const values.
2389    paths
2390        .par_iter()
2391        .map(|path| get_compressed_size_parallel(path, method, threads))
2392        .collect()
2393}
2394
2395/// Optimizes parallelization
2396#[inline(always)]
2397pub fn get_compressed_sizes_from_paths(paths: &[&str], method: &str) -> Vec<u64> {
2398    let n: usize = paths.len();
2399    let num_threads: usize = *NUM_THREADS.get_or_init(num_cpus::get);
2400    if n < num_threads {
2401        get_parallel_compressed_sizes_from_parallel_paths(paths, method, num_threads.div_ceil(n))
2402    } else {
2403        get_sequential_compressed_sizes_from_parallel_paths(paths, method)
2404    }
2405}
2406
/// ------- NCD (Normalized Compression Distance) ------
///
/// NCD is a parameter-free similarity metric based on Kolmogorov complexity.
/// Since Kolmogorov complexity `K(x)` is uncomputable, we approximate it using
/// the compressed size `C(x)` provided by a real-world compressor (here, ZPAQ).
///
/// The general form is:
/// `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
///
/// Different variants handle normalization and symmetry differently.
/// With an ideal compressor all variants lie in `[0, 1]`; real compressors
/// can produce small excursions above 1.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcdVariant {
    /// Standard Vitanyi NCD:
    /// `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
    /// Note: `C(xy)` denotes compressing the concatenation of x and y.
    Vitanyi,
    /// Symmetric Vitanyi NCD:
    /// `NCD_sym(x,y) = (min(C(xy), C(yx)) - min(C(x), C(y))) / max(C(x), C(y))`
    /// Takes the best compression of `xy` or `yx` to ensure symmetry even if the compressor is not symmetric.
    SymVitanyi,
    /// Conservative NCD:
    /// `NCD_cons(x,y) = (C(xy) - min(C(x), C(y))) / C(xy)`
    /// Normalizes by the joint compressed size instead of the max marginal.
    Cons,
    /// Symmetric Conservative NCD:
    /// `NCD_sym_cons(x,y) = (min(C(xy), C(yx)) - min(C(x), C(y))) / min(C(xy), C(yx))`
    SymCons,
}
2435
/// Compressed size of `data` in bytes under the given ZPAQ `method`.
/// Thin indirection over the ZPAQ backend so NCD code stays backend-agnostic.
#[inline(always)]
fn compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_compress_size_bytes(data, method)
}
2440
2441#[inline(always)]
2442fn ncd_from_sizes(cx: u64, cy: u64, cxy: u64, cyx: Option<u64>, variant: NcdVariant) -> f64 {
2443    let min_c = cx.min(cy) as f64;
2444    let max_c = cx.max(cy) as f64;
2445
2446    match variant {
2447        NcdVariant::Vitanyi => {
2448            if max_c == 0.0 {
2449                0.0
2450            } else {
2451                (cxy as f64 - min_c) / max_c
2452            }
2453        }
2454        NcdVariant::SymVitanyi => {
2455            let m = cxy.min(cyx.expect("cyx required for SymVitanyi")) as f64;
2456            if max_c == 0.0 {
2457                0.0
2458            } else {
2459                (m - min_c) / max_c
2460            }
2461        }
2462        NcdVariant::Cons => {
2463            let denom = cxy as f64;
2464            if denom == 0.0 {
2465                0.0
2466            } else {
2467                (cxy as f64 - min_c) / denom
2468            }
2469        }
2470        NcdVariant::SymCons => {
2471            let m = cxy.min(cyx.expect("cyx required for SymCons")) as f64;
2472            if m == 0.0 { 0.0 } else { (m - min_c) / m }
2473        }
2474    }
2475}
2476
2477#[inline(always)]
2478/// Compute NCD for in-memory byte slices using the given ZPAQ `method` and `variant`.
2479pub fn ncd_bytes(x: &[u8], y: &[u8], method: &str, variant: NcdVariant) -> f64 {
2480    let backend = CompressionBackend::Zpaq {
2481        method: method.to_string(),
2482    };
2483    ncd_bytes_backend(x, y, &backend, variant)
2484}
2485
/// NCD with bytes using the default context.
///
/// Delegates to the process-wide default context's `ncd_bytes`, so the
/// compression backend used is whatever that context is configured with.
#[inline(always)]
pub fn ncd_bytes_default(x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
    with_default_ctx(|ctx| ctx.ncd_bytes(x, y, variant))
}
2491
2492/// Compute NCD for in-memory byte slices using an explicit compression `backend`.
2493pub fn ncd_bytes_backend(
2494    x: &[u8],
2495    y: &[u8],
2496    backend: &CompressionBackend,
2497    variant: NcdVariant,
2498) -> f64 {
2499    let (cx, cy) = rayon::join(
2500        || compress_size_backend(x, backend),
2501        || compress_size_backend(y, backend),
2502    );
2503
2504    let cxy = compress_size_chain_backend(&[x, y], backend);
2505
2506    let cyx = match variant {
2507        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
2508            Some(compress_size_chain_backend(&[y, x], backend))
2509        }
2510        _ => None,
2511    };
2512
2513    ncd_from_sizes(cx, cy, cxy, cyx, variant)
2514}
2515
2516#[inline(always)]
2517/// Compute NCD for two file paths using a ZPAQ `method` and `variant`.
2518pub fn ncd_paths(x: &str, y: &str, method: &str, variant: NcdVariant) -> f64 {
2519    let (bx, by) = rayon::join(
2520        || std::fs::read(x).expect("failed to read x"),
2521        || std::fs::read(y).expect("failed to read y"),
2522    );
2523    ncd_bytes(&bx, &by, method, variant)
2524}
2525
2526/// Compute NCD for two file paths using an explicit compression `backend`.
2527pub fn ncd_paths_backend(
2528    x: &str,
2529    y: &str,
2530    backend: &CompressionBackend,
2531    variant: NcdVariant,
2532) -> f64 {
2533    let (bx, by) = rayon::join(
2534        || std::fs::read(x).expect("failed to read x"),
2535        || std::fs::read(y).expect("failed to read y"),
2536    );
2537    ncd_bytes_backend(&bx, &by, backend, variant)
2538}
2539
/// Back-compat convenience wrappers (operate on file paths).
///
/// Convenience wrapper for standard Vitanyi NCD on file paths.
/// Panics (via `ncd_paths`) if either file cannot be read.
#[inline(always)]
pub fn ncd_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Vitanyi)
}
#[inline(always)]
/// Convenience wrapper for symmetric-Vitanyi NCD on file paths.
/// Panics (via `ncd_paths`) if either file cannot be read.
pub fn ncd_sym_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymVitanyi)
}
#[inline(always)]
/// Convenience wrapper for conservative NCD on file paths.
/// Panics (via `ncd_paths`) if either file cannot be read.
pub fn ncd_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Cons)
}
#[inline(always)]
/// Convenience wrapper for symmetric-conservative NCD on file paths.
/// Panics (via `ncd_paths`) if either file cannot be read.
pub fn ncd_sym_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymCons)
}
2560
/// Computes an NCD matrix (row-major, len = n*n) for in-memory byte blobs.
///
/// Note: For symmetric variants, this computes each unordered pair once and writes both (i,j) and (j,i).
///
/// Marginal sizes `C(x_i)` are computed once up front in parallel; pair
/// concatenations reuse a per-worker scratch buffer (`for_each_init`) to
/// avoid allocating for every pair. The output matrix is filled through a
/// raw pointer shared across rayon workers: this is sound because every
/// worker writes a disjoint set of cells and `out` is only read after the
/// parallel loop has joined.
pub fn ncd_matrix_bytes(datas: &[Vec<u8>], method: &str, variant: NcdVariant) -> Vec<f64> {
    let n = datas.len();
    // C(x_i) for every blob, computed in parallel.
    let cx: Vec<u64> = datas
        .par_iter()
        .map(|d| compress_size_bytes(d, method))
        .collect();

    let mut out = vec![0.0f64; n * n];
    // AtomicPtr only serves to make the raw pointer Sync for the closures;
    // the writes themselves are plain stores to disjoint cells.
    let out_ptr = std::sync::atomic::AtomicPtr::new(out.as_mut_ptr());

    match variant {
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            // Symmetric variants: iterate the strict upper triangle (i < j)
            // and mirror each result; the diagonal keeps its 0.0 initialization.
            (0..n)
                .into_par_iter()
                .flat_map_iter(|i| (i + 1..n).map(move |j| (i, j)))
                .for_each_init(Vec::<u8>::new, |buf, (i, j)| {
                    let x = &datas[i];
                    let y = &datas[j];

                    // C(xy): concatenate into the reusable scratch buffer.
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(x);
                    buf.extend_from_slice(y);
                    let cxy = compress_size_bytes(buf, method);

                    // C(yx): reverse order, required by the symmetric formulas.
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(y);
                    buf.extend_from_slice(x);
                    let cyx = compress_size_bytes(buf, method);

                    let d = ncd_from_sizes(cx[i], cx[j], cxy, Some(cyx), variant);

                    // Safety: each (i,j) cell is written exactly once across all iterations.
                    let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                    unsafe {
                        *p.add(i * n + j) = d;
                        *p.add(j * n + i) = d;
                    }
                });
        }
        NcdVariant::Vitanyi | NcdVariant::Cons => {
            // Asymmetric variants: each worker owns row i and fills all n cells,
            // so writes never overlap between workers.
            (0..n)
                .into_par_iter()
                .for_each_init(Vec::<u8>::new, |buf, i| {
                    let x = &datas[i];
                    for j in 0..n {
                        let d = if i == j {
                            0.0
                        } else {
                            let y = &datas[j];
                            buf.clear();
                            buf.reserve(x.len() + y.len());
                            buf.extend_from_slice(x);
                            buf.extend_from_slice(y);
                            let cxy = compress_size_bytes(buf, method);
                            ncd_from_sizes(cx[i], cx[j], cxy, None, variant)
                        };

                        let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                        unsafe {
                            *p.add(i * n + j) = d;
                        }
                    }
                });
        }
    }

    out
}
2634
2635/// Computes an NCD matrix (row-major, len = n*n) for files (preloads all files into memory once).
2636pub fn ncd_matrix_paths(paths: &[&str], method: &str, variant: NcdVariant) -> Vec<f64> {
2637    let datas = get_bytes_from_paths(paths);
2638    ncd_matrix_bytes(&datas, method, variant)
2639}
2640
2641// ============================================================
2642// Entropy-Based Distance Primitives (via ROSA)
2643// ============================================================
2644//
2645// These use ROSA's Witten-Bell language model to estimate entropy
2646// and compute information-theoretic distances.
2647
/// Compute marginal (Shannon) entropy H(X) = −Σ p(x) log₂ p(x) in bits/symbol.
///
/// This is the simple first-order entropy from the byte histogram,
/// NOT the context-conditional entropy rate from a language model.
/// Empty input has zero entropy.
#[inline(always)]
pub fn marginal_entropy_bytes(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    // Tally byte frequencies over the 256-symbol alphabet.
    let mut freq = [0u64; 256];
    for &byte in data {
        freq[byte as usize] += 1;
    }

    // H(X) = −Σ p·log₂ p, skipping zero-count symbols (0·log 0 := 0).
    let total = data.len() as f64;
    freq.iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / total;
            -p * p.log2()
        })
        .sum()
}
2673
/// Compute entropy rate `Ĥ(X)` in bits/symbol using ROSA LM.
///
/// This uses ROSA's context-conditional Witten-Bell model to estimate
/// the entropy rate, which accounts for sequential dependencies.
///
/// The estimator is **prequential** (predictive sequential): it sums the negative log-probability
/// of each symbol `x_t` given its past context `x_{<t}`, estimated from the model trained on `x_{<t}`.
///
/// `Ĥ(X) = -1/N * Σ log2 P(x_t | x_{t-k}^{t-1})`
///
/// For i.i.d. data, this should approximately equal `marginal_entropy_bytes`.
///
/// * `max_order`: Maximum context order for the suffix automaton LM.
///   A value of -1 means unlimited context (bounded only by memory/sequence length).
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.entropy_rate_bytes(data, max_order))
}
2692
/// Compute biased entropy rate Ĥ_biased(X) bits per symbol.
///
/// This uses the full plugin estimator (training on the whole text, then scoring the same text).
/// While biased as a source entropy estimate, it is mathematically consistent for
/// similarity metrics like Mutual Information and NED.
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn biased_entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.biased_entropy_rate_bytes(data, max_order))
}
2702
2703/// Compute joint marginal entropy H(X,Y) = −Σ p(x,y) log₂ p(x,y) in bits/symbol-pair.
2704///
2705/// Uses a direct histogram of (x_i, y_i) pairs. This is the exact first-order
2706/// joint entropy, matching the spec.md definition.
2707#[inline(always)]
2708pub fn joint_marginal_entropy_bytes(x: &[u8], y: &[u8]) -> f64 {
2709    let (x, y) = aligned_prefix(x, y);
2710    let n = x.len();
2711    if n == 0 {
2712        return 0.0;
2713    }
2714
2715    // Count pair occurrences using a HashMap for (x, y) pairs
2716    // There are up to 65536 possible pairs, so we can use a flat array
2717    let mut counts = vec![0u64; 256 * 256];
2718    for i in 0..n {
2719        let pair_idx = (x[i] as usize) * 256 + (y[i] as usize);
2720        counts[pair_idx] += 1;
2721    }
2722
2723    let n_f64 = n as f64;
2724    let mut h = 0.0f64;
2725    for &c in &counts {
2726        if c > 0 {
2727            let p = c as f64 / n_f64;
2728            h -= p * p.log2();
2729        }
2730    }
2731    h
2732}
2733
/// Compute joint entropy rate `Ĥ(X,Y)`.
///
/// Dispatches based on `max_order`:
/// - `max_order == 0`: Strictly aligned pair-symbol mapping (Marginal Joint Entropy).
///   Treats `(x_i, y_i)` as a single symbol in a product alphabet `Σ_X × Σ_Y`.
/// - `max_order != 0`: Shift-invariant algorithmic joint entropy approximated via ROSA.
///   Constructs a sequence of pair-symbols and estimates the entropy rate of that sequence.
///
/// **Note**: This is an *aligned* joint entropy-rate estimate over time-indexed pairs
/// `(x_i, y_i)`. All joint-based quantities (`H(X)`, `H(Y)`, `H(X,Y)`, `I`, NED, NTE, etc.)
/// should be computed over the same aligned sample.
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn joint_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.joint_entropy_rate_bytes(x, y, max_order))
}
2749
/// Compute conditional entropy rate `Ĥ(X|Y)`.
///
/// Dispatches based on `max_order`:
/// - `max_order == 0`: Strictly aligned `H(X,Y) - H(Y)` using marginals.
/// - `max_order != 0`: Chain rule definition `Ĥ(X|Y) = Ĥ(X,Y) - Ĥ(Y)`.
///
/// Note: This relies on the identity `H(X|Y) = H(X,Y) - H(Y)`.
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn conditional_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_rate_bytes(x, y, max_order))
}
2761
/// Compute conditional entropy H(X|Y) = H(X,Y) − H(Y)
///
/// Dispatches based on `max_order`. If 0, uses marginal histograms; otherwise
/// uses rate estimates. Delegates to the process-wide default context.
#[inline(always)]
pub fn conditional_entropy_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_bytes(x, y, max_order))
}
2769
/// Compute mutual information `I(X;Y) = H(X) + H(Y) - H(X,Y)`.
///
/// Dispatches based on `max_order`. If 0, uses marginals; else uses rates.
///
/// `I(X;Y) = Σ p(x,y) log(p(x,y) / (p(x)p(y)))`
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn mutual_information_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_bytes(x, y, max_order))
}
2779
2780/// Marginal Mutual Information (exact/histogram)
2781pub fn mutual_information_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2782    let (x, y) = aligned_prefix(x, y);
2783    let h_x = marginal_entropy_bytes(x);
2784    let h_y = marginal_entropy_bytes(y);
2785    let h_xy = joint_marginal_entropy_bytes(x, y);
2786    (h_x + h_y - h_xy).max(0.0)
2787}
2788
/// Entropy Rate Mutual Information (ROSA predictive)
///
/// Delegates to the process-wide default context's rate-based MI estimate.
#[inline(always)]
pub fn mutual_information_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_rate_bytes(x, y, max_order))
}
2794
2795// ====== NED: Normalized Entropy Distance ======
2796//
2797// A metric distance based on the overlap of information between two variables.
2798
/// NED(X,Y) = (H(X,Y) - min(H(X), H(Y))) / max(H(X), H(Y))
///
/// Dispatches based on `max_order`. If 0, uses marginals; else uses rates.
///
/// Range: [0, 1].
/// * 0: Identity (X determines Y and Y determines X).
/// * 1: Independence (X and Y share no information).
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn ned_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
2810
2811/// Marginal NED (exact/histogram)
2812pub fn ned_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2813    let (x, y) = aligned_prefix(x, y);
2814    let h_x = marginal_entropy_bytes(x);
2815    let h_y = marginal_entropy_bytes(y);
2816    let h_xy = joint_marginal_entropy_bytes(x, y);
2817    let min_h = h_x.min(h_y);
2818    let max_h = h_x.max(h_y);
2819    if max_h == 0.0 {
2820        0.0
2821    } else {
2822        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
2823    }
2824}
2825
/// Normalized Entropy Distance (Rate-based)
///
/// NOTE(review): delegates to the same `ctx.ned_bytes` as [`ned_bytes`], which
/// dispatches on `max_order` — pass a nonzero `max_order` to get the
/// rate-based estimate this wrapper's name implies.
#[inline(always)]
pub fn ned_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
2831
/// NED_cons(X,Y) = (H(X,Y) - min(H(X), H(Y))) / H(X,Y)
///
/// Conservative variant (normalizes by the joint entropy instead of the max
/// marginal). Dispatches based on `max_order`; delegates to the process-wide
/// default context.
#[inline(always)]
pub fn ned_cons_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
2839
2840/// Conservative marginal NED using histogram entropy estimates.
2841pub fn ned_cons_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2842    let h_x = marginal_entropy_bytes(x);
2843    let h_y = marginal_entropy_bytes(y);
2844    let h_xy = joint_marginal_entropy_bytes(x, y);
2845    let min_h = h_x.min(h_y);
2846    if h_xy == 0.0 {
2847        0.0
2848    } else {
2849        ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
2850    }
2851}
2852
#[inline(always)]
/// Conservative rate NED using the current default context backend.
///
/// NOTE(review): delegates to the same `ctx.ned_cons_bytes` as
/// [`ned_cons_bytes`], which dispatches on `max_order` — pass a nonzero
/// `max_order` for the rate-based estimate.
pub fn ned_cons_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
2858
2859// ====== NTE: Normalized Transform Effort (Variation of Information) ======
2860
/// NTE(X,Y) = VI(X,Y) / max(H(X), H(Y))
/// where `VI(X,Y) = H(X|Y) + H(Y|X) = 2H(X,Y) - H(X) - H(Y)`.
///
/// Represents the "effort" required to transform X into Y (and vice versa) relative
/// to their complexity.
///
/// Note: VI can be as large as `H(X) + H(Y)`. If `H(X) ≈ H(Y)`, then VI can be `≈ 2 max(H(X), H(Y))`.
/// Thus, NTE is in [0, 2].
/// * Values near 0 indicate near-identity.
/// * Values near 1+ indicate substantial effort/transform cost (e.g. independence).
///
/// Dispatches based on `max_order`; delegates to the process-wide default context.
#[inline(always)]
pub fn nte_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
2877
2878/// Marginal NTE using histogram entropy estimates.
2879pub fn nte_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
2880    let (x, y) = aligned_prefix(x, y);
2881    let h_x = marginal_entropy_bytes(x);
2882    let h_y = marginal_entropy_bytes(y);
2883    let h_xy = joint_marginal_entropy_bytes(x, y);
2884    let vi = 2.0 * h_xy - h_x - h_y;
2885    let max_h = h_x.max(h_y);
2886    if max_h == 0.0 {
2887        0.0
2888    } else {
2889        (vi / max_h).clamp(0.0, 2.0)
2890    }
2891}
2892
#[inline(always)]
/// Rate NTE using the current default context backend.
///
/// NOTE(review): delegates to the same `ctx.nte_bytes` as [`nte_bytes`],
/// which dispatches on `max_order` — pass a nonzero `max_order` for the
/// rate-based estimate.
pub fn nte_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
2898
2899// ====== TVD: Total Variation Distance ======
2900
/// Compute marginal byte histogram p(i) = count(i) / N for i ∈ [0, 255]
///
/// Empty input yields the all-zero array (no probability mass).
#[inline(always)]
fn byte_histogram(data: &[u8]) -> [f64; 256] {
    let mut probs = [0.0f64; 256];
    if data.is_empty() {
        return probs;
    }

    let mut counts = [0u64; 256];
    for &b in data {
        counts[b as usize] += 1;
    }

    let total = data.len() as f64;
    for (p, &c) in probs.iter_mut().zip(counts.iter()) {
        *p = c as f64 / total;
    }
    probs
}
2918
2919/// TVD_marg(X,Y) = (1/2) Σᵢ |p_X(i) - p_Y(i)|
2920///
2921/// Total Variation Distance over marginal byte distributions.
2922/// True metric on probability space. Range: [0, 1].
2923/// 0 = identical distributions, 1 = completely disjoint support.
2924#[inline(always)]
2925pub fn tvd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2926    if x.is_empty() || y.is_empty() {
2927        return 0.0;
2928    }
2929    let p_x = byte_histogram(x);
2930    let p_y = byte_histogram(y);
2931
2932    let mut sum = 0.0f64;
2933    for i in 0..256 {
2934        sum += (p_x[i] - p_y[i]).abs();
2935    }
2936
2937    (sum / 2.0).clamp(0.0, 1.0)
2938}
2939
2940// ====== NHD: Normalized Hellinger Distance ======
2941
2942/// NHD(X,Y) = sqrt(1 - BC(X,Y)) where BC = Σᵢ sqrt(p_X(i) · p_Y(i))
2943///
2944/// Normalized Hellinger Distance over marginal byte distributions.
2945/// True metric. Range: [0, 1]. 0 = identical, 1 = disjoint support.
2946#[inline(always)]
2947pub fn nhd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
2948    if x.is_empty() || y.is_empty() {
2949        return 0.0;
2950    }
2951    let p_x = byte_histogram(x);
2952    let p_y = byte_histogram(y);
2953
2954    // Bhattacharyya coefficient: BC = Σᵢ sqrt(p_X(i) · p_Y(i))
2955    let mut bc = 0.0f64;
2956    for i in 0..256 {
2957        bc += (p_x[i] * p_y[i]).sqrt();
2958    }
2959
2960    // NHD = sqrt(1 - BC)
2961    (1.0 - bc).max(0.0).sqrt()
2962}
2963
2964// ====== Other Information-Theoretic Measures ======
2965
/// Compute cross-entropy H_{train}(test) - score test_data under model trained on train_data.
///
/// Dispatches based on `max_order`; delegates to the process-wide default context.
#[inline(always)]
pub fn cross_entropy_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.cross_entropy_bytes(test_data, train_data, max_order))
}
2973
/// Compute cross-entropy rate using ROSA/CTW/RWKV.
/// Training model on `train_data` and evaluating probability of `test_data`.
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn cross_entropy_rate_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.cross_entropy_rate_bytes(test_data, train_data, max_order))
}
2980
/// Generate a continuation from `prompt`
/// using the current default context and [`GenerationConfig::default()`].
///
/// The default is deterministic frozen sampling with seed `42`.
/// Returns `bytes` generated bytes.
#[inline(always)]
pub fn generate_bytes(prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes(prompt, bytes, max_order))
}
2989
/// Generate a continuation from `prompt` using the current default context.
///
/// `config` controls the sampling strategy (see [`GenerationConfig`]).
#[inline(always)]
pub fn generate_bytes_with_config(
    prompt: &[u8],
    bytes: usize,
    max_order: i64,
    config: GenerationConfig,
) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes_with_config(prompt, bytes, max_order, config))
}
3000
/// Generate a continuation after conditioning on an explicit chain of prefix parts
/// using the current default context and [`GenerationConfig::default()`].
///
/// `prefix_parts` are consumed in order as conditioning context before generation.
#[inline(always)]
pub fn generate_bytes_conditional_chain(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes_conditional_chain(prefix_parts, bytes, max_order))
}
3011
/// Generate a continuation after conditioning on an explicit chain of prefix parts
/// using the current default context.
///
/// `config` controls the sampling strategy (see [`GenerationConfig`]).
#[inline(always)]
pub fn generate_bytes_conditional_chain_with_config(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
    config: GenerationConfig,
) -> Vec<u8> {
    with_default_ctx(|ctx| {
        ctx.generate_bytes_conditional_chain_with_config(prefix_parts, bytes, max_order, config)
    })
}
3025
3026/// Kullback-Leibler Divergence D_KL(P || Q) = Σ p(x) log(p(x) / q(x))
3027///
3028/// Marginal only. Measure of how one probability distribution is different from a second.
3029pub fn d_kl_bytes(x: &[u8], y: &[u8]) -> f64 {
3030    if x.is_empty() || y.is_empty() {
3031        return 0.0;
3032    }
3033    let p_x = byte_histogram(x);
3034    let p_y = byte_histogram(y);
3035    let mut d_kl = 0.0f64;
3036    for i in 0..256 {
3037        if p_x[i] > 0.0 {
3038            let q_y = p_y[i].max(1e-12);
3039            d_kl += p_x[i] * (p_x[i] / q_y).log2();
3040        }
3041    }
3042    d_kl.max(0.0)
3043}
3044
3045/// Jensen-Shannon Divergence JSD(P || Q) = 1/2 D_KL(P || M) + 1/2 D_KL(Q || M)
3046/// where M = 1/2 (P + Q)
3047///
3048/// Marginal only. Symmetrized and smoothed version of KL divergence. Range `[0,1]`.
3049pub fn js_div_bytes(x: &[u8], y: &[u8]) -> f64 {
3050    if x.is_empty() || y.is_empty() {
3051        return 0.0;
3052    }
3053    let p_x = byte_histogram(x);
3054    let p_y = byte_histogram(y);
3055    let mut m = [0.0f64; 256];
3056    for i in 0..256 {
3057        m[i] = 0.5 * (p_x[i] + p_y[i]);
3058    }
3059
3060    let mut kl_pm = 0.0f64;
3061    let mut kl_qm = 0.0f64;
3062    for i in 0..256 {
3063        if p_x[i] > 0.0 {
3064            kl_pm += p_x[i] * (p_x[i] / m[i]).log2();
3065        }
3066        if p_y[i] > 0.0 {
3067            kl_qm += p_y[i] * (p_y[i] / m[i]).log2();
3068        }
3069    }
3070    (0.5 * kl_pm + 0.5 * kl_qm).max(0.0)
3071}
3072
3073// ====== Path-based convenience wrappers ======
3074
3075/// NED for files.
3076pub fn ned_paths(x: &str, y: &str, max_order: i64) -> f64 {
3077    let (bx, by) = rayon::join(
3078        || std::fs::read(x).expect("failed to read x"),
3079        || std::fs::read(y).expect("failed to read y"),
3080    );
3081    ned_bytes(&bx, &by, max_order)
3082}
3083
3084/// NTE for files.
3085pub fn nte_paths(x: &str, y: &str, max_order: i64) -> f64 {
3086    let (bx, by) = rayon::join(
3087        || std::fs::read(x).expect("failed to read x"),
3088        || std::fs::read(y).expect("failed to read y"),
3089    );
3090    nte_bytes(&bx, &by, max_order)
3091}
3092
3093/// TVD for files.
3094pub fn tvd_paths(x: &str, y: &str, max_order: i64) -> f64 {
3095    let (bx, by) = rayon::join(
3096        || std::fs::read(x).expect("failed to read x"),
3097        || std::fs::read(y).expect("failed to read y"),
3098    );
3099    tvd_bytes(&bx, &by, max_order)
3100}
3101
3102/// NHD for files.
3103pub fn nhd_paths(x: &str, y: &str, max_order: i64) -> f64 {
3104    let (bx, by) = rayon::join(
3105        || std::fs::read(x).expect("failed to read x"),
3106        || std::fs::read(y).expect("failed to read y"),
3107    );
3108    nhd_bytes(&bx, &by, max_order)
3109}
3110
3111/// Mutual Information for files.
3112pub fn mutual_information_paths(x: &str, y: &str, max_order: i64) -> f64 {
3113    let (bx, by) = rayon::join(
3114        || std::fs::read(x).expect("failed to read x"),
3115        || std::fs::read(y).expect("failed to read y"),
3116    );
3117    mutual_information_bytes(&bx, &by, max_order)
3118}
3119
3120/// Conditional Entropy for files.
3121pub fn conditional_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
3122    let (bx, by) = rayon::join(
3123        || std::fs::read(x).expect("failed to read x"),
3124        || std::fs::read(y).expect("failed to read y"),
3125    );
3126    conditional_entropy_bytes(&bx, &by, max_order)
3127}
3128
3129/// Cross-Entropy for files.
3130pub fn cross_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
3131    let (bx, by) = rayon::join(
3132        || std::fs::read(x).expect("failed to read x"),
3133        || std::fs::read(y).expect("failed to read y"),
3134    );
3135    cross_entropy_bytes(&bx, &by, max_order)
3136}
3137
3138/// KL Divergence for files.
3139pub fn kl_divergence_paths(x: &str, y: &str) -> f64 {
3140    let (bx, by) = rayon::join(
3141        || std::fs::read(x).expect("failed to read x"),
3142        || std::fs::read(y).expect("failed to read y"),
3143    );
3144    d_kl_bytes(&bx, &by)
3145}
3146
3147/// Jensen-Shannon Divergence for files.
3148pub fn js_divergence_paths(x: &str, y: &str) -> f64 {
3149    let (bx, by) = rayon::join(
3150        || std::fs::read(x).expect("failed to read x"),
3151        || std::fs::read(y).expect("failed to read y"),
3152    );
3153    js_div_bytes(&bx, &by)
3154}
3155
3156// ====== Primitives 6 & 7 ======
3157
/// Primitive 6: Intrinsic Dependence (Redundancy Ratio).
///
/// Measures how much structure is intrinsic to the sample, relative to its
/// own marginal entropy baseline.
///
/// `R = (H_marginal - H_rate) / H_marginal`
///
/// Clamped to `[0,1]`.
///
/// Interpretation:
///   - `R → 0`: Data is close to i.i.d./max-entropy (little intrinsic structure; highly extrinsically explainable by priors).
///   - `R → 1`: Data is highly predictable from its own past (strong intrinsic dependence; e.g., periodic strings like 010101...).
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn intrinsic_dependence_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.intrinsic_dependence_bytes(data, max_order))
}
3174
/// Primitive 7: Resistance under Allowed Transformations.
///
/// Measures how much information is preserved after a transformation `T` is applied to `X`.
///
/// `Resistance(X, T) = I(X; T(X)) / H(X)`
///
/// Range `[0,1]` (with guard for `H(X)=0`).
/// * 1 means perfectly resistant (identity transformation).
/// * 0 means the transformation destroyed all information (e.g. mapping everything to a constant).
///
/// Assumes X and T(X) are aligned.
///
/// Delegates to the process-wide default context.
#[inline(always)]
pub fn resistance_to_transformation_bytes(x: &[u8], tx: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.resistance_to_transformation_bytes(x, tx, max_order))
}
3190
3191#[cfg(test)]
3192mod tests {
3193    use super::*;
3194
    /// Small, cheap Match-backend configuration shared by the tests below.
    fn test_match_backend() -> RateBackend {
        RateBackend::Match {
            hash_bits: 12,
            min_len: 2,
            max_len: 16,
            base_mix: 0.01,
            confidence_scale: 1.0,
        }
    }
3204
    /// Minimal PPMd backend (order 4, 1 MB) to keep tests fast.
    fn test_ppmd_backend() -> RateBackend {
        RateBackend::Ppmd {
            order: 4,
            memory_mb: 1,
        }
    }
3211
    /// Calibrated backend wrapping the Match backend with a small
    /// text-context calibration table.
    fn test_calibrated_backend() -> RateBackend {
        RateBackend::Calibrated {
            spec: Arc::new(CalibratedSpec {
                base: test_match_backend(),
                context: CalibrationContextKind::Text,
                bins: 16,
                learning_rate: 0.05,
                bias_clip: 4.0,
            }),
        }
    }
3223
    /// Two-expert Bayes mixture (Match + PPMd) with equal log-priors and
    /// unlimited per-expert context (`max_order: -1`).
    fn test_mixture_backend() -> RateBackend {
        RateBackend::Mixture {
            spec: Arc::new(MixtureSpec::new(
                MixtureKind::Bayes,
                vec![
                    MixtureExpertSpec {
                        name: Some("match".to_string()),
                        log_prior: 0.0,
                        max_order: -1,
                        backend: test_match_backend(),
                    },
                    MixtureExpertSpec {
                        name: Some("ppmd".to_string()),
                        log_prior: 0.0,
                        max_order: -1,
                        backend: test_ppmd_backend(),
                    },
                ],
            )),
        }
    }
3245
    /// Deliberately tiny Particle backend (few particles/cells, shallow
    /// networks) so tests stay fast; remaining fields use the spec defaults.
    fn test_particle_backend() -> RateBackend {
        RateBackend::Particle {
            spec: Arc::new(ParticleSpec {
                num_particles: 4,
                num_cells: 4,
                cell_dim: 8,
                num_rules: 2,
                selector_hidden: 16,
                rule_hidden: 16,
                context_window: 8,
                unroll_steps: 1,
                ..ParticleSpec::default()
            }),
        }
    }
3261
    /// Highly patterned prompt whose continuation is predictable; used by
    /// the generation determinism tests.
    fn continuation_prompt() -> &'static [u8] {
        b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\nIf a dog is green, frogs are red.\nIf a cat is green, toads are red.\nIf a frog is red, dogs are green.\nIf a toad is red, cats are green.\nIf a dog is red, frogs are green.\nIf a cat is red, toads are \n"
    }
3265
    /// Asserts that default-config generation is reproducible for `backend`:
    /// two runs on the same prompt must agree byte-for-byte and emit exactly
    /// `bytes` bytes. `label` identifies the backend in failure messages.
    fn assert_deterministic_generate_for_backend(
        backend: RateBackend,
        max_order: i64,
        bytes: usize,
        label: &str,
    ) {
        let prompt = continuation_prompt();
        let a = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        let b = generate_rate_backend_chain(
            &[prompt],
            bytes,
            max_order,
            &backend,
            GenerationConfig::default(),
        );
        assert_eq!(
            a, b,
            "{label} generation should be deterministic for identical input"
        );
        assert_eq!(
            a.len(),
            bytes,
            "{label} generation should emit requested byte count"
        );
    }
3297
3298    fn assert_sampled_generate_for_backend(
3299        backend: RateBackend,
3300        max_order: i64,
3301        bytes: usize,
3302        label: &str,
3303    ) {
3304        let prompt = continuation_prompt();
3305        let config = GenerationConfig::sampled_frozen(42);
3306        let a = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
3307        let b = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
3308        assert_eq!(
3309            a, b,
3310            "{label} sampled generation should be deterministic for a fixed seed"
3311        );
3312        assert_eq!(
3313            a.len(),
3314            bytes,
3315            "{label} sampled generation should emit requested byte count"
3316        );
3317    }
3318
3319    #[cfg(feature = "backend-zpaq")]
3320    #[test]
3321    fn ncd_basic_identity_nonnegative() {
3322        let x = b"abcdabcdabcd";
3323        let d = ncd_bytes(x, x, "5", NcdVariant::Vitanyi);
3324        assert!(d >= -1e-9);
3325    }
3326
3327    #[test]
3328    fn shannon_identities_marginal_aligned() {
3329        let x = b"abracadabra";
3330        let y = b"abracadabra";
3331
3332        let h = marginal_entropy_bytes(x);
3333        let mi = mutual_information_bytes(x, y, 0);
3334        let h_xy = joint_marginal_entropy_bytes(x, y);
3335        let h_x_given_y = conditional_entropy_bytes(x, y, 0);
3336        let ned = ned_bytes(x, y, 0);
3337        let nte = nte_bytes(x, y, 0);
3338
3339        assert!((h_xy - h).abs() < 1e-12);
3340        assert!(h_x_given_y.abs() < 1e-12);
3341        assert!((mi - h).abs() < 1e-12);
3342        assert!(ned.abs() < 1e-12);
3343        assert!(nte.abs() < 1e-12);
3344    }
3345
3346    #[test]
3347    fn shannon_identities_rate_aligned_reasonable() {
3348        let x = b"the quick brown fox jumps over the lazy dog";
3349        let y = b"the quick brown fox jumps over the lazy dog";
3350        let max_order = 8;
3351        let prev = get_default_ctx();
3352        set_default_ctx(InfotheoryCtx::new(
3353            RateBackend::RosaPlus,
3354            CompressionBackend::default(),
3355        ));
3356
3357        let h_x = entropy_rate_bytes(x, max_order);
3358        let h_xy = joint_entropy_rate_bytes(x, y, max_order);
3359        let h_x_given_y = conditional_entropy_rate_bytes(x, y, max_order);
3360        let mi = mutual_information_bytes(x, y, max_order);
3361        let ned = ned_bytes(x, y, max_order);
3362
3363        // Finite-sample estimators won't be exact; allow reasonable tolerance.
3364        let tol = 0.2;
3365        assert!((h_xy - h_x).abs() < tol);
3366        assert!(h_x_given_y < tol);
3367        assert!((mi - h_x).abs() < tol);
3368        assert!(ned < tol);
3369        set_default_ctx(prev);
3370    }
3371
3372    #[test]
3373    fn resistance_identity_is_one() {
3374        let x = b"some repeated repeated repeated text";
3375        let prev = get_default_ctx();
3376        set_default_ctx(InfotheoryCtx::new(
3377            RateBackend::RosaPlus,
3378            CompressionBackend::default(),
3379        ));
3380        let r0 = resistance_to_transformation_bytes(x, x, 0);
3381        let r8 = resistance_to_transformation_bytes(x, x, 8);
3382        assert!((r0 - 1.0).abs() < 1e-12);
3383        assert!((r8 - 1.0).abs() < 1e-6);
3384        set_default_ctx(prev);
3385    }
3386
3387    #[test]
3388    fn marginal_metrics_empty_inputs_are_zero() {
3389        let empty: &[u8] = &[];
3390        let x = b"abc";
3391
3392        assert_eq!(tvd_bytes(empty, x, 0), 0.0);
3393        assert_eq!(tvd_bytes(x, empty, 0), 0.0);
3394        assert_eq!(nhd_bytes(empty, x, 0), 0.0);
3395        assert_eq!(nhd_bytes(x, empty, 0), 0.0);
3396        assert_eq!(d_kl_bytes(empty, x), 0.0);
3397        assert_eq!(d_kl_bytes(x, empty), 0.0);
3398        assert_eq!(js_div_bytes(empty, x), 0.0);
3399        assert_eq!(js_div_bytes(x, empty), 0.0);
3400    }
3401
3402    #[test]
3403    fn marginal_cross_entropy_empty_test_is_zero() {
3404        let empty: &[u8] = &[];
3405        let y = b"abc";
3406        let ctx = InfotheoryCtx::with_zpaq("5");
3407        assert_eq!(ctx.cross_entropy_bytes(empty, y, 0), 0.0);
3408    }
3409
3410    #[cfg(not(feature = "backend-zpaq"))]
3411    #[test]
3412    #[should_panic(expected = "CompressionBackend::Zpaq is unavailable")]
3413    fn explicit_zpaq_backend_fails_loudly() {
3414        let backend = CompressionBackend::Zpaq {
3415            method: "5".to_string(),
3416        };
3417        let _ = compress_size_backend(b"abc", &backend);
3418    }
3419
3420    #[cfg(not(feature = "backend-zpaq"))]
3421    #[test]
3422    fn default_compression_backend_falls_back_to_rate_coding() {
3423        let backend = CompressionBackend::default();
3424        assert!(matches!(
3425            &backend,
3426            CompressionBackend::Rate {
3427                coder: crate::coders::CoderType::AC,
3428                framing: crate::compression::FramingMode::Raw,
3429                ..
3430            }
3431        ));
3432        assert!(compress_size_backend(b"abc", &backend) > 0);
3433    }
3434
3435    #[test]
3436    fn backend_switching_test() {
3437        let x = b"hello world context";
3438
3439        // Default is RosaPlus
3440        let h_rosa = entropy_rate_bytes(x, 8);
3441
3442        // Switch to CTW
3443        set_default_ctx(InfotheoryCtx::new(
3444            RateBackend::Ctw { depth: 16 },
3445            CompressionBackend::default(),
3446        ));
3447
3448        let h_ctw = entropy_rate_bytes(x, 8);
3449
3450        // They should generally be different, but most importantly, CTW worked
3451        assert!(h_ctw > 0.0);
3452
3453        // Reset to default
3454        set_default_ctx(InfotheoryCtx::default());
3455        let h_rosa_back = entropy_rate_bytes(x, 8);
3456        assert!((h_rosa - h_rosa_back).abs() < 1e-12);
3457    }
3458
3459    #[test]
3460    fn ctw_early_updates_work() {
3461        // Test that CTW produces valid predictions from the very start,
3462        // not just after `depth` symbols have been processed.
3463        use crate::ctw::ContextTree;
3464
3465        let mut tree = ContextTree::new(16);
3466
3467        // Even the first prediction should be valid (not NaN, not 0)
3468        let p0 = tree.predict(false);
3469        let p1 = tree.predict(true);
3470
3471        // Initial KT estimator gives 0.5 / 1 = 0.5 for each symbol
3472        assert!((p0 - 0.5).abs() < 1e-10, "p0 should be ~0.5, got {}", p0);
3473        assert!((p1 - 0.5).abs() < 1e-10, "p1 should be ~0.5, got {}", p1);
3474        assert!((p0 + p1 - 1.0).abs() < 1e-10, "p0 + p1 should = 1.0");
3475
3476        // Update with a few symbols and verify log_prob becomes negative (valid)
3477        for _ in 0..5 {
3478            tree.update(true);
3479            tree.update(false);
3480        }
3481
3482        let log_prob = tree.get_log_block_probability();
3483        assert!(
3484            log_prob < 0.0,
3485            "log_prob should be negative (< log 1), got {}",
3486            log_prob
3487        );
3488        assert!(log_prob.is_finite(), "log_prob should be finite");
3489    }
3490
3491    #[test]
3492    fn nte_can_exceed_one() {
3493        // Test that NTE is properly clamped to [0, 2] instead of [0, 1]
3494        // For independent sequences with similar entropy, NTE can approach 2.0
3495        //
3496        // Note: For *marginal* NTE, due to how joint entropy works for aligned pairs,
3497        // it's mathematically bounded differently. The fix for NTE clamping primarily
3498        // affects *rate*-based NTE where VI can truly be 2*max(H).
3499        //
3500        // We test that the clamp upper bound is at least > 1.0 for cases where VI > max(H)
3501
3502        // Use CTW backend for rate-based test
3503        set_default_ctx(InfotheoryCtx::new(
3504            RateBackend::Ctw { depth: 8 },
3505            CompressionBackend::default(),
3506        ));
3507
3508        // Generate two completely different patterns - should have high VI
3509        let x: Vec<u8> = (0..200).map(|i| (i % 2) as u8).collect(); // 010101...
3510        let y: Vec<u8> = (0..200).map(|i| ((i + 1) % 2) as u8).collect(); // 101010...
3511
3512        let nte_rate = nte_rate_backend(&x, &y, -1, &RateBackend::Ctw { depth: 8 });
3513
3514        // With the fix, NTE should not be clamped to 1.0
3515        // It may or may not exceed 1.0 depending on the specifics, but it should be allowed to
3516        assert!(
3517            (0.0..=2.0 + 1e-9).contains(&nte_rate),
3518            "NTE should be in [0, 2], got {}",
3519            nte_rate
3520        );
3521
3522        // Reset context
3523        set_default_ctx(InfotheoryCtx::default());
3524    }
3525
3526    #[test]
3527    fn ctw_empty_data_returns_zero() {
3528        // Verify empty data doesn't cause division-by-zero or NaN
3529        set_default_ctx(InfotheoryCtx::new(
3530            RateBackend::Ctw { depth: 16 },
3531            CompressionBackend::default(),
3532        ));
3533
3534        let empty: &[u8] = &[];
3535        let h = entropy_rate_bytes(empty, -1);
3536        assert_eq!(h, 0.0, "empty data should return 0.0 entropy");
3537
3538        // Reset
3539        set_default_ctx(InfotheoryCtx::default());
3540    }
3541
3542    #[test]
3543    fn joint_entropy_rate_aligns_inputs_and_handles_empty_cases() {
3544        let cases = vec![
3545            ("ctw", RateBackend::Ctw { depth: 8 }),
3546            (
3547                "fac-ctw",
3548                RateBackend::FacCtw {
3549                    base_depth: 8,
3550                    num_percept_bits: 8,
3551                    encoding_bits: 8,
3552                },
3553            ),
3554            ("match", test_match_backend()),
3555        ];
3556
3557        for (name, backend) in cases {
3558            assert_eq!(
3559                joint_entropy_rate_backend(b"", b"nonempty", -1, &backend),
3560                0.0,
3561                "{name} should return 0.0 for empty aligned pairs"
3562            );
3563            assert_eq!(
3564                joint_entropy_rate_backend(b"nonempty", b"", -1, &backend),
3565                0.0,
3566                "{name} should return 0.0 when alignment truncates to empty"
3567            );
3568
3569            let aligned = joint_entropy_rate_backend(b"abcd", b"wxyz", -1, &backend);
3570            let truncated = joint_entropy_rate_backend(b"abcdextra", b"wxyz", -1, &backend);
3571            assert!(
3572                (aligned - truncated).abs() < 1e-12,
3573                "{name} should score only the aligned prefix: aligned={aligned} truncated={truncated}"
3574            );
3575        }
3576    }
3577
3578    #[test]
3579    fn biased_entropy_is_repeatable_across_backend_families() {
3580        let data = b"ABABABAABBABABABAABB";
3581        let cases = vec![
3582            ("match", test_match_backend()),
3583            ("ppmd", test_ppmd_backend()),
3584            ("calibrated", test_calibrated_backend()),
3585            ("ctw", RateBackend::Ctw { depth: 8 }),
3586            ("mixture", test_mixture_backend()),
3587            ("particle", test_particle_backend()),
3588        ];
3589
3590        for (name, backend) in cases {
3591            let h1 = biased_entropy_rate_backend(data, -1, &backend);
3592            let h2 = biased_entropy_rate_backend(data, -1, &backend);
3593            assert!(h1.is_finite(), "{name} biased entropy should be finite");
3594            assert!(
3595                (h1 - h2).abs() < 1e-12,
3596                "{name} biased entropy leaked mutable state across calls: h1={h1} h2={h2}"
3597            );
3598        }
3599    }
3600
3601    #[test]
3602    fn generate_bytes_chain_matches_flat_prompt() {
3603        let prompt = continuation_prompt();
3604        let split_at = prompt.len() / 2;
3605        let front = &prompt[..split_at];
3606        let back = &prompt[split_at..];
3607        let backend = RateBackend::Ctw { depth: 32 };
3608        let bytes = 8usize;
3609        let max_order = -1;
3610
3611        let flat = generate_rate_backend_chain(
3612            &[prompt],
3613            bytes,
3614            max_order,
3615            &backend,
3616            GenerationConfig::default(),
3617        );
3618        let chained = generate_rate_backend_chain(
3619            &[front, back],
3620            bytes,
3621            max_order,
3622            &backend,
3623            GenerationConfig::default(),
3624        );
3625        assert_eq!(
3626            flat, chained,
3627            "chain conditioning should match flat prompt conditioning"
3628        );
3629    }
3630
3631    #[test]
3632    fn generate_bytes_api_is_deterministic_for_ctw_rosa_match_ppmd() {
3633        assert_deterministic_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
3634        assert_deterministic_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
3635        assert_deterministic_generate_for_backend(test_match_backend(), -1, 8, "match");
3636        assert_deterministic_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
3637    }
3638
3639    #[cfg(feature = "backend-rwkv")]
3640    #[test]
3641    fn generate_bytes_api_is_deterministic_for_rwkv_method() {
3642        let backend = RateBackend::Rwkv7Method {
3643            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3644        };
3645        assert_deterministic_generate_for_backend(backend, -1, 8, "rwkv7");
3646    }
3647
3648    #[test]
3649    fn sampled_generation_is_deterministic_for_ctw_rosa_match_ppmd() {
3650        assert_sampled_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
3651        assert_sampled_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
3652        assert_sampled_generate_for_backend(test_match_backend(), -1, 8, "match");
3653        assert_sampled_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
3654    }
3655
3656    #[cfg(feature = "backend-rwkv")]
3657    #[test]
3658    fn sampled_generation_is_deterministic_for_rwkv_method() {
3659        let backend = RateBackend::Rwkv7Method {
3660            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3661        };
3662        assert_sampled_generate_for_backend(backend, -1, 8, "rwkv7");
3663    }
3664
3665    #[test]
3666    fn rosaplus_sampled_generation_predicts_green_continuation() {
3667        let out = generate_rate_backend_chain(
3668            &[continuation_prompt()],
3669            8,
3670            -1,
3671            &RateBackend::RosaPlus,
3672            GenerationConfig::sampled_frozen(42),
3673        );
3674        assert_eq!(out, b" green.\n");
3675    }
3676
3677    #[test]
3678    fn rate_backend_session_matches_ctx_generation() {
3679        let prompt = continuation_prompt();
3680        let backend = RateBackend::Ppmd {
3681            order: 12,
3682            memory_mb: 8,
3683        };
3684        let mut session =
3685            RateBackendSession::from_backend(backend.clone(), -1, Some((prompt.len() + 8) as u64))
3686                .expect("session init");
3687        session.observe(prompt);
3688        let from_session = session.generate_bytes(8, GenerationConfig::sampled_frozen(42));
3689        session.finish().expect("session finish");
3690
3691        let ctx = InfotheoryCtx::new(backend, CompressionBackend::default());
3692        let from_ctx =
3693            ctx.generate_bytes_with_config(prompt, 8, -1, GenerationConfig::sampled_frozen(42));
3694        assert_eq!(from_session, from_ctx);
3695    }
3696
3697    #[test]
3698    fn biased_entropy_ctw_uses_frozen_plugin_scoring() {
3699        let backend = RateBackend::Ctw { depth: 8 };
3700        let data = b"AAAAAAAA";
3701        let plugin = biased_entropy_rate_backend(data, -1, &backend);
3702        let prequential = entropy_rate_backend(data, -1, &backend);
3703        assert!(
3704            plugin + 1e-9 < prequential,
3705            "expected plugin scoring to beat prequential scoring: plugin={plugin} prequential={prequential}"
3706        );
3707    }
3708
3709    #[test]
3710    fn rosa_plugin_entropy_matches_direct_model_api() {
3711        let data = b"abracadabra";
3712        let backend = RateBackend::RosaPlus;
3713
3714        let plugin = biased_entropy_rate_backend(data, 3, &backend);
3715
3716        let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
3717        direct.train_example(data);
3718        direct.build_lm();
3719        let expected = direct.cross_entropy(data);
3720
3721        assert!(
3722            (plugin - expected).abs() < 1e-12,
3723            "rosa plugin entropy must match direct model API: plugin={plugin} expected={expected}"
3724        );
3725    }
3726
3727    #[test]
3728    fn rosa_plugin_cross_entropy_matches_direct_model_api() {
3729        let train = b"alakazam";
3730        let test = b"abracadabra";
3731        let backend = RateBackend::RosaPlus;
3732
3733        let plugin = cross_entropy_rate_backend(test, train, 3, &backend);
3734
3735        let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
3736        direct.train_example(train);
3737        direct.build_lm();
3738        let expected = direct.cross_entropy(test);
3739
3740        assert!(
3741            (plugin - expected).abs() < 1e-12,
3742            "rosa plugin cross entropy must match direct model API: plugin={plugin} expected={expected}"
3743        );
3744    }
3745
3746    #[test]
3747    fn datagen_bernoulli_entropy_estimate() {
3748        // Test that estimated entropy is close to theoretical for Bernoulli(0.5)
3749        let p = 0.5;
3750        let theoretical_h = crate::datagen::bernoulli_entropy(p);
3751        assert!((theoretical_h - 1.0).abs() < 1e-10);
3752
3753        // Generate data and check marginal entropy is close to theoretical
3754        let data = crate::datagen::bernoulli(10000, p, 42);
3755        let estimated_h = marginal_entropy_bytes(&data);
3756
3757        // Should be close to 1.0 bit (since values are 0 or 1)
3758        assert!(
3759            (estimated_h - theoretical_h).abs() < 0.1,
3760            "estimated H={} should be close to theoretical H={}",
3761            estimated_h,
3762            theoretical_h
3763        );
3764    }
3765
3766    #[cfg(feature = "backend-rwkv")]
3767    #[test]
3768    fn rwkv_method_entropy_is_stable_across_calls() {
3769        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
3770        let backend = RateBackend::Rwkv7Method {
3771            method: method.to_string(),
3772        };
3773        let data = b"rwkv method entropy stability regression sample";
3774
3775        let h1 = entropy_rate_backend(data, -1, &backend);
3776        let h2 = entropy_rate_backend(data, -1, &backend);
3777        assert!(
3778            (h1 - h2).abs() < 1e-12,
3779            "rwkv method entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3780        );
3781    }
3782
3783    #[cfg(feature = "backend-rwkv")]
3784    #[test]
3785    fn rwkv_method_without_policy_is_accepted_by_public_api() {
3786        let backend = RateBackend::Rwkv7Method {
3787            method: "cfg:hidden=64,layers=1,intermediate=64".to_string(),
3788        };
3789        let data = b"rwkv method without policy";
3790        let h1 = entropy_rate_backend(data, -1, &backend);
3791        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3792        assert!(h1.is_finite());
3793        assert!(h2.is_finite());
3794    }
3795
3796    #[cfg(feature = "backend-rwkv")]
3797    #[test]
3798    fn rwkv_infer_only_plugin_collapses_to_single_pass_entropy() {
3799        let backend = RateBackend::Rwkv7Method {
3800            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=25,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3801        };
3802        let data = b"rwkv infer-only plugin equality sample";
3803        let h = entropy_rate_backend(data, -1, &backend);
3804        let plugin = biased_entropy_rate_backend(data, -1, &backend);
3805        assert!(
3806            (h - plugin).abs() < 1e-12,
3807            "infer-only rwkv plugin should equal single-pass entropy: h={h}, plugin={plugin}"
3808        );
3809    }
3810
3811    #[cfg(feature = "backend-rwkv")]
3812    #[test]
3813    fn rwkv_method_biased_entropy_is_stable_across_calls_with_training_policy() {
3814        let backend = RateBackend::Rwkv7Method {
3815            method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
3816        };
3817        let data = b"rwkv plugin stability sample";
3818        let h1 = biased_entropy_rate_backend(data, -1, &backend);
3819        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3820        assert!(
3821            (h1 - h2).abs() < 1e-12,
3822            "rwkv method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3823        );
3824    }
3825
3826    #[cfg(feature = "backend-rwkv")]
3827    #[test]
3828    fn rwkv_method_conditional_chain_is_stable_across_calls() {
3829        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=22,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
3830        let ctx = InfotheoryCtx::new(
3831            RateBackend::Rwkv7Method {
3832                method: method.to_string(),
3833            },
3834            CompressionBackend::default(),
3835        );
3836
3837        let prefix = b"universal prior slice";
3838        let data = b"query payload";
3839        let h1 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
3840        let h2 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
3841        assert!(
3842            (h1 - h2).abs() < 1e-12,
3843            "rwkv method conditional chain leaked mutable state across calls: h1={h1}, h2={h2}"
3844        );
3845    }
3846
3847    #[cfg(feature = "backend-mamba")]
3848    #[test]
3849    fn mamba_method_without_policy_is_accepted_by_public_api() {
3850        let backend = RateBackend::MambaMethod {
3851            method: "cfg:hidden=64,layers=1,intermediate=96".to_string(),
3852        };
3853        let data = b"mamba method without policy";
3854        let h1 = entropy_rate_backend(data, -1, &backend);
3855        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3856        assert!(h1.is_finite());
3857        assert!(h2.is_finite());
3858    }
3859
3860    #[cfg(feature = "backend-mamba")]
3861    #[test]
3862    fn mamba_infer_only_plugin_collapses_to_single_pass_entropy() {
3863        let backend = RateBackend::MambaMethod {
3864            method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=26,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
3865        };
3866        let data = b"mamba infer-only plugin equality sample";
3867        let h = entropy_rate_backend(data, -1, &backend);
3868        let plugin = biased_entropy_rate_backend(data, -1, &backend);
3869        assert!(
3870            (h - plugin).abs() < 1e-12,
3871            "infer-only mamba plugin should equal single-pass entropy: h={h}, plugin={plugin}"
3872        );
3873    }
3874
3875    #[cfg(feature = "backend-mamba")]
3876    #[test]
3877    fn mamba_method_biased_entropy_is_stable_across_calls_with_training_policy() {
3878        let backend = RateBackend::MambaMethod {
3879            method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=24,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
3880        };
3881        let data = b"mamba plugin stability sample";
3882        let h1 = biased_entropy_rate_backend(data, -1, &backend);
3883        let h2 = biased_entropy_rate_backend(data, -1, &backend);
3884        assert!(
3885            (h1 - h2).abs() < 1e-12,
3886            "mamba method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
3887        );
3888    }
3889
3890    #[test]
3891    fn particle_entropy_rate_in_valid_range() {
3892        let rb = test_particle_backend();
3893        let data = b"hello world particle backend test";
3894        let rate = entropy_rate_backend(data, -1, &rb);
3895        assert!(
3896            rate > 0.0 && rate < 8.0,
3897            "particle entropy rate out of (0, 8) range: {rate}"
3898        );
3899    }
3900
3901    #[test]
3902    fn particle_cross_entropy_stability() {
3903        let rb = test_particle_backend();
3904        let train = b"ABCABC";
3905        let test = b"ABC";
3906        let h1 = cross_entropy_rate_backend(test, train, -1, &rb);
3907        let h2 = cross_entropy_rate_backend(test, train, -1, &rb);
3908        assert!(
3909            (h1 - h2).abs() < 1e-12,
3910            "particle cross entropy not deterministic: h1={h1}, h2={h2}"
3911        );
3912    }
3913
3914    #[test]
3915    fn particle_empty_input() {
3916        let rb = RateBackend::Particle {
3917            spec: Arc::new(ParticleSpec::default()),
3918        };
3919        let rate = entropy_rate_backend(b"", -1, &rb);
3920        assert!(
3921            rate == 0.0,
3922            "particle entropy rate for empty input should be 0.0, got {rate}"
3923        );
3924    }
3925
3926    #[test]
3927    fn particle_joint_entropy_rate() {
3928        let rb = test_particle_backend();
3929        let x = b"AAAA";
3930        let y = b"BBBB";
3931        let joint = joint_entropy_rate_backend(x, y, -1, &rb);
3932        assert!(
3933            joint > 0.0 && joint < 16.0,
3934            "particle joint entropy rate out of range: {joint}"
3935        );
3936    }
3937}