infotheory/
lib.rs

1//! # InfoTheory: Information Theoretic Estimators & Metrics
2//!
3//! This crate provides a comprehensive suite of information-theoretic primitives for
4//! quantifying complexity, dependence, and similarity between data sequences.
5//!
6//! It implements two primary classes of estimators:
//! 1.  **Compression-based (Kolmogorov Complexity)**: Using a compressor — ZPAQ by default, or an
//!     RWKV-7 model-based compressor — to estimate Normalized Compression Distance (NCD).
//! 2.  **Entropy-based (Shannon Information)**: Using both exact marginal histograms (for i.i.d. data)
//!     and a sequential predictive model (for sequential data) — by default the ROSA
//!     (Rapid Online Suffix Automaton) language model — to estimate Entropy, Mutual Information,
//!     and related distances.
12//!
13//! ## Mathematical Primitives
14//!
//! The library implements the following core measures. For sequential data, "Rate" variants
//! use the configured rate backend (ROSA by default; see `RateBackend`) to estimate `Ĥ(X)`
//! (entropy rate), while "Marginal" variants treat data as a bag-of-bytes (i.i.d.) and compute
//! `H(X)` from histograms.
18//!
19//! ### 1. Normalized Compression Distance (NCD)
20//! Approximates the Normalized Information Distance (NID) using a compressor `C`.
21//!
22//! `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
23//!
24//! ### 2. Normalized Entropy Distance (NED)
25//! An entropic analogue to NCD, defined using Shannon entropy `H`.
26//!
27//! `NED(X,Y) = (H(X,Y) - min(H(X), H(Y))) / max(H(X), H(Y))`
28//!
29//! ### 3. Normalized Transform Effort (NTE)
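//! Based on the Variation of Information (VI), normalized by the maximum entropy.
//!
//! `NTE(X,Y) = (H(X|Y) + H(Y|X)) / max(H(X), H(Y)) = (2H(X,Y) - H(X) - H(Y)) / max(H(X), H(Y))`
//!
//! Unlike NED, NTE ranges over `[0, 2]`: for independent sequences the numerator approaches
//! `H(X) + H(Y)`.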
33//!
34//! ### 4. Mutual Information (MI)
35//! Measures the amount of information obtained about one random variable by observing another.
36//!
37//! `I(X;Y) = H(X) + H(Y) - H(X,Y)`
38//!
39//! ### 5. Divergences & Distances
40//! *   **Total Variation Distance (TVD)**: `δ(P,Q) = 0.5 * Σ |P(x) - Q(x)|`
41//! *   **Normalized Hellinger Distance (NHD)**: `sqrt(1 - Σ sqrt(P(x)Q(x)))`
42//! *   **Kullback-Leibler Divergence (KL)**: `D_KL(P||Q) = Σ P(x) log(P(x)/Q(x))`
43//! *   **Jensen-Shannon Divergence (JSD)**: Symmetrized and smoothed KL divergence.
44//!
45//! ### 6. Intrinsic Dependence (ID)
46//! Measures the redundancy within a sequence, comparing marginal entropy to entropy rate.
47//!
48//! `ID(X) = (H_marginal(X) - H_rate(X)) / H_marginal(X)`
49//!
50//! ### 7. Resistance to Transformation
51//! Quantifies how much information is preserved after a transformation `T` is applied.
52//!
53//! `R(X, T) = I(X; T(X)) / H(X)`
54//!
55//! ## Usage
56//!
57//! ```rust,no_run
58//! use infotheory::{ncd_vitanyi, mutual_information_bytes, NcdVariant};
59//!
60//! let x = b"some data sequence";
61//! let y = b"another data sequence";
62//!
63//! // Compression-based distance
64//! let ncd = ncd_vitanyi("file1.txt", "file2.txt", "5");
65//!
66//! // Entropy-based mutual information (Marginal / i.i.d.)
67//! let mi_marg = mutual_information_bytes(x, y, 0);
68//!
69//! // Entropy-based mutual information (Rate / Sequential, max_order=8)
70//! let mi_rate = mutual_information_bytes(x, y, 8);
71//! ```
72
73pub mod aixi;
74pub mod axioms;
75pub mod ctw;
76pub mod datagen;
77pub mod mixture;
78mod zpaq_rate;
79
80use rayon::prelude::*;
81
82use std::cell::RefCell;
83use std::collections::HashMap;
84use std::sync::Arc;
85use std::sync::OnceLock;
86
/// Worker-thread count, initialised at most once for the whole process.
///
/// NOTE(review): not read in this part of the file; presumably configured and consumed by
/// the parallel (rayon) code elsewhere — confirm before relying on it.
static NUM_THREADS: OnceLock<usize> = OnceLock::new();
88
89thread_local! {
90    static RWKV_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
91}
92
93impl Default for RateBackend {
94    fn default() -> Self {
95        RateBackend::RosaPlus
96    }
97}
98
99impl Default for NcdBackend {
100    fn default() -> Self {
101        NcdBackend::Zpaq {
102            method: "5".to_string(),
103        }
104    }
105}
106
107impl Default for InfotheoryCtx {
108    fn default() -> Self {
109        Self {
110            rate_backend: RateBackend::default(),
111            ncd_backend: NcdBackend::default(),
112        }
113    }
114}
115
116thread_local! {
117    static DEFAULT_CTX: RefCell<InfotheoryCtx> = RefCell::new(InfotheoryCtx::default());
118}
119
120/// Returns the current default information theory context for the thread.
121pub fn get_default_ctx() -> InfotheoryCtx {
122    DEFAULT_CTX.with(|ctx| ctx.borrow().clone())
123}
124
125/// Sets the default information theory context for the thread.
126pub fn set_default_ctx(ctx: InfotheoryCtx) {
127    DEFAULT_CTX.with(|c| *c.borrow_mut() = ctx);
128}
129
130#[inline(always)]
131fn with_default_ctx<R>(f: impl FnOnce(&InfotheoryCtx) -> R) -> R {
132    DEFAULT_CTX.with(|ctx| f(&*ctx.borrow()))
133}
134
135pub fn mutual_information_rate_backend(
136    x: &[u8],
137    y: &[u8],
138    max_order: i64,
139    backend: &RateBackend,
140) -> f64 {
141    let (x, y) = aligned_prefix(x, y);
142    if x.is_empty() {
143        return 0.0;
144    }
145    // For CTW, we might want a special aligned implementation?
146    // Using standard formula for now.
147    let h_x = entropy_rate_backend(x, max_order, backend);
148    let h_y = entropy_rate_backend(y, max_order, backend);
149    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
150    (h_x + h_y - h_xy).max(0.0)
151}
152
153pub fn ned_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
154    let (x, y) = aligned_prefix(x, y);
155    if x.is_empty() {
156        return 0.0;
157    }
158    let h_x = entropy_rate_backend(x, max_order, backend);
159    let h_y = entropy_rate_backend(y, max_order, backend);
160    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
161    let min_h = h_x.min(h_y);
162    let max_h = h_x.max(h_y);
163    if max_h == 0.0 {
164        0.0
165    } else {
166        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
167    }
168}
169
170pub fn nte_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
171    let (x, y) = aligned_prefix(x, y);
172    if x.is_empty() {
173        return 0.0;
174    }
175    let h_x = entropy_rate_backend(x, max_order, backend);
176    let h_y = entropy_rate_backend(y, max_order, backend);
177    let h_xy = joint_entropy_rate_backend(x, y, max_order, backend);
178    let max_h = h_x.max(h_y);
179    if max_h == 0.0 {
180        0.0
181    } else {
182        // VI = H(X|Y) + H(Y|X) can be as large as H(X) + H(Y) ≈ 2*max(H)
183        // for independent sequences, so NTE ∈ [0, 2]
184        let vi = (h_xy - h_x).max(0.0) + (h_xy - h_y).max(0.0);
185        (vi / max_h).clamp(0.0, 2.0)
186    }
187}
188
189#[derive(Clone)]
190pub enum RateBackend {
191    RosaPlus,
192    Rwkv7 {
193        model: Arc<rwkvzip::Model>,
194    },
195    /// ZPAQ compression-based rate model (streamable methods only).
196    Zpaq {
197        method: String,
198    },
199    /// Online mixture over rate-model experts (Bayes, fading Bayes, switching, MDL).
200    Mixture {
201        spec: Arc<MixtureSpec>,
202    },
203    /// Action-Conditional CTW (single context tree).
204    Ctw {
205        depth: usize,
206    },
207    /// Factorized Action-Conditional CTW (k trees for k-bit percepts).
208    FacCtw {
209        base_depth: usize,
210        num_percept_bits: usize,
211        encoding_bits: usize,
212    },
213}
214
215#[derive(Clone)]
216pub enum NcdBackend {
217    Zpaq {
218        method: String,
219    },
220    Rwkv7 {
221        model: Arc<rwkvzip::Model>,
222        coder: rwkvzip::CoderType,
223    },
224}
225
/// Policy used to combine the experts of a rate-backend mixture.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureKind {
    /// Bayesian mixture weighted by the experts' log priors.
    Bayes,
    /// Bayesian mixture with fading weights (see `MixtureSpec::decay`).
    FadingBayes,
    /// Switching mixture (see `MixtureSpec::alpha`).
    Switching,
    /// Minimum-description-length policy.
    Mdl,
}
234
235/// Expert specification for mixture backends.
236#[derive(Clone)]
237pub struct MixtureExpertSpec {
238    pub name: Option<String>,
239    /// Log prior weight (natural log). Uniform priors can be `0.0`.
240    pub log_prior: f64,
241    /// Max order for ROSA experts (ignored for other backends).
242    pub max_order: i64,
243    pub backend: RateBackend,
244}
245
246/// Mixture specification for rate-backend mixtures.
247#[derive(Clone)]
248pub struct MixtureSpec {
249    pub kind: MixtureKind,
250    /// Switching probability (per step) for switching mixtures.
251    pub alpha: f64,
252    /// Decay factor for fading Bayes mixtures.
253    pub decay: Option<f64>,
254    pub experts: Vec<MixtureExpertSpec>,
255}
256
257impl MixtureSpec {
258    pub fn new(kind: MixtureKind, experts: Vec<MixtureExpertSpec>) -> Self {
259        Self {
260            kind,
261            alpha: 0.01,
262            decay: None,
263            experts,
264        }
265    }
266
267    pub fn with_alpha(mut self, alpha: f64) -> Self {
268        self.alpha = alpha;
269        self
270    }
271
272    pub fn with_decay(mut self, decay: f64) -> Self {
273        self.decay = Some(decay);
274        self
275    }
276
277    pub fn build_experts(&self) -> Vec<crate::mixture::ExpertConfig> {
278        self.experts
279            .iter()
280            .map(|spec| {
281                crate::mixture::ExpertConfig::from_rate_backend(
282                    spec.name.clone(),
283                    spec.log_prior,
284                    spec.backend.clone(),
285                    spec.max_order,
286                )
287            })
288            .collect()
289    }
290}
291
292#[derive(Clone)]
293pub struct InfotheoryCtx {
294    pub rate_backend: RateBackend,
295    pub ncd_backend: NcdBackend,
296}
297
298impl InfotheoryCtx {
299    pub fn new(rate_backend: RateBackend, ncd_backend: NcdBackend) -> Self {
300        Self {
301            rate_backend,
302            ncd_backend,
303        }
304    }
305
306    pub fn with_zpaq(method: impl Into<String>) -> Self {
307        Self {
308            rate_backend: RateBackend::RosaPlus,
309            ncd_backend: NcdBackend::Zpaq {
310                method: method.into(),
311            },
312        }
313    }
314
315    pub fn compress_size(&self, data: &[u8]) -> u64 {
316        compress_size_backend(data, &self.ncd_backend)
317    }
318
319    pub fn compress_size_chain(&self, parts: &[&[u8]]) -> u64 {
320        compress_size_chain_backend(parts, &self.ncd_backend)
321    }
322
323    pub fn entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
324        entropy_rate_backend(data, max_order, &self.rate_backend)
325    }
326
327    pub fn biased_entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
328        biased_entropy_rate_backend(data, max_order, &self.rate_backend)
329    }
330
331    pub fn cross_entropy_rate_bytes(
332        &self,
333        test_data: &[u8],
334        train_data: &[u8],
335        max_order: i64,
336    ) -> f64 {
337        cross_entropy_rate_backend(test_data, train_data, max_order, &self.rate_backend)
338    }
339
340    pub fn cross_entropy_bytes(&self, test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
341        if max_order == 0 {
342            if test_data.is_empty() {
343                return 0.0;
344            }
345            let p_x = byte_histogram(test_data);
346            let p_y = byte_histogram(train_data);
347            let mut h = 0.0f64;
348            for i in 0..256 {
349                if p_x[i] > 0.0 {
350                    let q_y = p_y[i].max(1e-12);
351                    h -= p_x[i] * q_y.log2();
352                }
353            }
354            h
355        } else {
356            self.cross_entropy_rate_bytes(test_data, train_data, max_order)
357        }
358    }
359
360    pub fn joint_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
361        let (x, y) = aligned_prefix(x, y);
362        if x.is_empty() {
363            return 0.0;
364        }
365        joint_entropy_rate_backend(x, y, max_order, &self.rate_backend)
366    }
367
368    pub fn conditional_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
369        let (x, y) = aligned_prefix(x, y);
370        if x.is_empty() {
371            return 0.0;
372        }
373        let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
374        let h_y = self.entropy_rate_bytes(y, max_order);
375        (h_xy - h_y).max(0.0)
376    }
377
378    pub fn cross_entropy_conditional_chain(&self, prefix_parts: &[&[u8]], data: &[u8]) -> f64 {
379        match &self.rate_backend {
380            RateBackend::RosaPlus => {
381                let mut prefix = Vec::new();
382                let total: usize = prefix_parts.iter().map(|p| p.len()).sum();
383                prefix.reserve(total);
384                for p in prefix_parts {
385                    prefix.extend_from_slice(p);
386                }
387                cross_entropy_rate_backend(data, &prefix, -1, &RateBackend::RosaPlus)
388            }
389            RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
390                c.cross_entropy_conditional_chain(prefix_parts, data)
391                    .unwrap_or(0.0)
392            }),
393            RateBackend::Ctw { depth } => {
394                if data.is_empty() {
395                    return 0.0;
396                }
397                let mut tree = crate::ctw::ContextTree::new(*depth);
398                for &part in prefix_parts {
399                    for &b in part {
400                        for i in (0..8).rev() {
401                            tree.update(((b >> i) & 1) == 1);
402                        }
403                    }
404                }
405                let log_p_prefix = tree.get_log_block_probability();
406                for &b in data {
407                    for i in (0..8).rev() {
408                        tree.update(((b >> i) & 1) == 1);
409                    }
410                }
411                let log_p_joint = tree.get_log_block_probability();
412                let log_p_cond = log_p_joint - log_p_prefix;
413                let bits = -log_p_cond / std::f64::consts::LN_2;
414                bits / (data.len() as f64)
415            }
416            RateBackend::Zpaq { method } => {
417                if data.is_empty() {
418                    return 0.0;
419                }
420                let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
421                for &part in prefix_parts {
422                    model.update_and_score(part);
423                }
424                let bits = model.update_and_score(data);
425                bits / (data.len() as f64)
426            }
427            RateBackend::Mixture { spec } => {
428                if data.is_empty() {
429                    return 0.0;
430                }
431                let experts = spec.build_experts();
432                let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
433                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
434                for &part in prefix_parts {
435                    for &b in part {
436                        mix.step(b);
437                    }
438                }
439                let mut bits = 0.0;
440                for &b in data {
441                    bits -= mix.step(b) / std::f64::consts::LN_2;
442                }
443                bits / (data.len() as f64)
444            }
445            RateBackend::FacCtw {
446                base_depth,
447                num_percept_bits: _,
448                encoding_bits,
449            } => {
450                if data.is_empty() {
451                    return 0.0;
452                }
453                let bits_per_byte = (*encoding_bits).min(8).max(1);
454                let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
455                for &part in prefix_parts {
456                    for &b in part {
457                        // Fix Issue 1: LSB-first
458                        for i in 0..bits_per_byte {
459                            let bit_idx = i;
460                            // b >> i gets the i-th bit (0 is LSB)
461                            fac.update(((b >> i) & 1) == 1, bit_idx);
462                        }
463                    }
464                }
465                let log_p_prefix = fac.get_log_block_probability();
466                for &b in data {
467                    for i in 0..bits_per_byte {
468                        let bit_idx = i;
469                        fac.update(((b >> i) & 1) == 1, bit_idx);
470                    }
471                }
472                let log_p_joint = fac.get_log_block_probability();
473                let log_p_cond = log_p_joint - log_p_prefix;
474                let bits = -log_p_cond / std::f64::consts::LN_2;
475                bits / (data.len() as f64)
476            }
477        }
478    }
479
480    pub fn ncd_bytes(&self, x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
481        ncd_bytes_backend(x, y, &self.ncd_backend, variant)
482    }
483
484    pub fn mutual_information_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
485        mutual_information_rate_backend(x, y, max_order, &self.rate_backend)
486    }
487
488    pub fn mutual_information_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
489        if max_order == 0 {
490            mutual_information_marg_bytes(x, y)
491        } else {
492            self.mutual_information_rate_bytes(x, y, max_order)
493        }
494    }
495
496    pub fn conditional_entropy_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
497        let (x, y) = aligned_prefix(x, y);
498        if max_order == 0 {
499            let h_xy = joint_marginal_entropy_bytes(x, y);
500            let h_y = marginal_entropy_bytes(y);
501            (h_xy - h_y).max(0.0)
502        } else {
503            let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
504            let h_y = self.entropy_rate_bytes(y, max_order);
505            (h_xy - h_y).max(0.0)
506        }
507    }
508
509    pub fn ned_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
510        if max_order == 0 {
511            ned_marg_bytes(x, y)
512        } else {
513            ned_rate_backend(x, y, max_order, &self.rate_backend)
514        }
515    }
516
517    pub fn ned_cons_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
518        let (x, y) = aligned_prefix(x, y);
519        let (h_x, h_y, h_xy) = if max_order == 0 {
520            (
521                marginal_entropy_bytes(x),
522                marginal_entropy_bytes(y),
523                joint_marginal_entropy_bytes(x, y),
524            )
525        } else {
526            (
527                self.entropy_rate_bytes(x, max_order),
528                self.entropy_rate_bytes(y, max_order),
529                self.joint_entropy_rate_bytes(x, y, max_order),
530            )
531        };
532        let min_h = h_x.min(h_y);
533        if h_xy == 0.0 {
534            0.0
535        } else {
536            ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
537        }
538    }
539
540    pub fn nte_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
541        if max_order == 0 {
542            nte_marg_bytes(x, y)
543        } else {
544            nte_rate_backend(x, y, max_order, &self.rate_backend)
545        }
546    }
547
548    pub fn intrinsic_dependence_bytes(&self, data: &[u8], max_order: i64) -> f64 {
549        let h_marginal = marginal_entropy_bytes(data);
550        if h_marginal < 1e-9 {
551            return 0.0;
552        }
553        let h_rate = self.entropy_rate_bytes(data, max_order);
554        ((h_marginal - h_rate) / h_marginal).clamp(0.0, 1.0)
555    }
556
557    pub fn resistance_to_transformation_bytes(&self, x: &[u8], tx: &[u8], max_order: i64) -> f64 {
558        let (x, tx) = aligned_prefix(x, tx);
559        let h_x = if max_order == 0 {
560            marginal_entropy_bytes(x)
561        } else {
562            self.entropy_rate_bytes(x, max_order)
563        };
564        if h_x < 1e-9 {
565            return 0.0;
566        }
567        let mi = self.mutual_information_bytes(x, tx, max_order);
568        (mi / h_x).clamp(0.0, 1.0)
569    }
570}
571
572pub fn load_rwkv7_model_from_path(path: &str) -> Arc<rwkvzip::Model> {
573    rwkvzip::Compressor::load_model(path).expect("failed to load RWKV7 model")
574}
575
/// Truncates both slices to their common length, so index `i` of one pairs with index `i`
/// of the other.
#[inline(always)]
fn aligned_prefix<'a>(x: &'a [u8], y: &'a [u8]) -> (&'a [u8], &'a [u8]) {
    let common = std::cmp::min(x.len(), y.len());
    let (x_head, _) = x.split_at(common);
    let (y_head, _) = y.split_at(common);
    (x_head, y_head)
}
581
582/// ------- Base Compression Functions -------
583#[inline(always)]
584pub fn get_compressed_size(path: &str, method: &str) -> u64 {
585    // Convert Input file to Vec<u8>, and reference that (compress_size only takes &[u8] input), and pass method.
586    // Will panic if file does not exist, so it must be prevalidated.
587    zpaq_rs::compress_size(&std::fs::read(path).unwrap(), method).unwrap()
588}
589
590pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
591    zpaq_rate::validate_zpaq_rate_method(method)
592}
593
594fn with_rwkv_tls<R>(
595    model: &Arc<rwkvzip::Model>,
596    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
597) -> R {
598    let key = Arc::as_ptr(model) as usize;
599    RWKV_TLS.with(|cell| {
600        let mut map = cell.borrow_mut();
601        let comp = map
602            .entry(key)
603            .or_insert_with(|| rwkvzip::Compressor::new_from_model(model.clone()));
604        f(comp)
605    })
606}
607
/// Exposes a sequence of byte slices as one contiguous `std::io::Read` stream, without
/// copying them into a single buffer.
struct SliceChainReader<'a> {
    /// Slices to read, in order.
    parts: &'a [&'a [u8]],
    /// Index of the slice currently being read.
    i: usize,
    /// Read position inside `parts[i]`.
    off: usize,
}

impl<'a> SliceChainReader<'a> {
    /// Creates a reader positioned at the start of the first slice.
    fn new(parts: &'a [&'a [u8]]) -> Self {
        Self {
            parts,
            i: 0,
            off: 0,
        }
    }
}

impl<'a> std::io::Read for SliceChainReader<'a> {
    /// Fills as much of `buf` as possible, crossing slice boundaries and skipping empty
    /// slices. Returns `Ok(0)` when every slice is consumed or when `buf` is empty.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut written = 0;
        while written < buf.len() {
            let Some(part) = self.parts.get(self.i) else {
                break;
            };
            let rest = part.get(self.off..).unwrap_or_default();
            if rest.is_empty() {
                // Current slice exhausted (or empty): move on to the next one.
                self.i += 1;
                self.off = 0;
                continue;
            }
            let take = rest.len().min(buf.len() - written);
            buf[written..written + take].copy_from_slice(&rest[..take]);
            self.off += take;
            written += take;
        }
        Ok(written)
    }
}
656
657pub fn compress_size_chain_backend(parts: &[&[u8]], backend: &NcdBackend) -> u64 {
658    match backend {
659        NcdBackend::Zpaq { method } => {
660            let r = SliceChainReader::new(parts);
661            zpaq_rs::compress_size_stream(r, method.as_str(), None, None).unwrap_or(0)
662        }
663        NcdBackend::Rwkv7 { model, coder } => {
664            with_rwkv_tls(model, |c| c.compress_size_chain(parts, *coder).unwrap_or(0))
665        }
666    }
667}
668
669pub fn compress_size_backend(data: &[u8], backend: &NcdBackend) -> u64 {
670    match backend {
671        NcdBackend::Zpaq { method } => zpaq_rs::compress_size(data, method.as_str()).unwrap_or(0),
672        NcdBackend::Rwkv7 { model, coder } => {
673            with_rwkv_tls(model, |c| c.compress_size(data, *coder).unwrap_or(0))
674        }
675    }
676}
677
678pub fn entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
679    match backend {
680        RateBackend::RosaPlus => {
681            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
682            m.predictive_entropy_rate(data)
683        }
684        RateBackend::Rwkv7 { model } => {
685            with_rwkv_tls(model, |c| c.cross_entropy(data).unwrap_or(0.0))
686        }
687        RateBackend::Zpaq { method } => {
688            if data.is_empty() {
689                return 0.0;
690            }
691            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
692            let bits = model.update_and_score(data);
693            bits / (data.len() as f64)
694        }
695        RateBackend::Mixture { spec } => {
696            if data.is_empty() {
697                return 0.0;
698            }
699            let experts = spec.build_experts();
700            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
701                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
702            let mut bits = 0.0;
703            for &b in data {
704                bits -= mix.step(b) / std::f64::consts::LN_2;
705            }
706            bits / (data.len() as f64)
707        }
708        RateBackend::Ctw { depth } => {
709            if data.is_empty() {
710                return 0.0;
711            }
712            // Byte-wise CTW: factorize by bit position so deterministic bits don't leak entropy.
713            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
714            for &b in data {
715                for bit_idx in 0..8 {
716                    let bit = ((b >> (7 - bit_idx)) & 1) == 1;
717                    fac.update(bit, bit_idx);
718                }
719            }
720            let ln_p = fac.get_log_block_probability();
721            let bits = -ln_p / std::f64::consts::LN_2;
722            bits / (data.len() as f64)
723        }
724        RateBackend::FacCtw {
725            base_depth,
726            num_percept_bits: _,
727            encoding_bits,
728        } => {
729            if data.is_empty() {
730                return 0.0;
731            }
732            let bits_per_byte = (*encoding_bits).min(8).max(1);
733            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
734            for &b in data {
735                for i in 0..bits_per_byte {
736                    let bit_idx = i;
737                    fac.update(((b >> i) & 1) == 1, bit_idx);
738                }
739            }
740            let ln_p = fac.get_log_block_probability();
741            let bits = -ln_p / std::f64::consts::LN_2;
742            bits / (data.len() as f64)
743        }
744    }
745}
746
747pub fn biased_entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
748    match backend {
749        RateBackend::RosaPlus => {
750            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
751            m.train_example(data);
752            m.build_lm();
753            m.cross_entropy(data)
754        }
755        RateBackend::Rwkv7 { model } => {
756            with_rwkv_tls(model, |c| c.cross_entropy(data).unwrap_or(0.0))
757        }
758        RateBackend::Zpaq { .. } => entropy_rate_backend(data, max_order, backend),
759        RateBackend::Mixture { .. } => entropy_rate_backend(data, max_order, backend),
760        RateBackend::Ctw { .. } | RateBackend::FacCtw { .. } => {
761            // CTW/FAC-CTW are online, so biased=prequential
762            entropy_rate_backend(data, max_order, backend)
763        }
764    }
765}
766
767/// Cross-entropy H_{train}(test) - score test_data under model trained on train_data.
768pub fn cross_entropy_rate_backend(
769    test_data: &[u8],
770    train_data: &[u8],
771    max_order: i64,
772    backend: &RateBackend,
773) -> f64 {
774    match backend {
775        RateBackend::RosaPlus => {
776            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
777            m.train_example(train_data);
778            m.build_lm();
779            m.cross_entropy(test_data)
780        }
781        RateBackend::Rwkv7 { model } => {
782            with_rwkv_tls(model, |c| {
783                // Inverted args fix: (prefix, target) -> (train, test)
784                // This estimates H_{train}(test)
785                c.cross_entropy_conditional(train_data, test_data)
786                    .unwrap_or(0.0)
787            })
788        }
789        RateBackend::Zpaq { method } => {
790            if test_data.is_empty() {
791                return 0.0;
792            }
793            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
794            model.update_and_score(train_data);
795            let bits = model.update_and_score(test_data);
796            bits / (test_data.len() as f64)
797        }
798        RateBackend::Mixture { spec } => {
799            if test_data.is_empty() {
800                return 0.0;
801            }
802            let experts = spec.build_experts();
803            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
804                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
805            for &b in train_data {
806                mix.step(b);
807            }
808            let mut bits = 0.0;
809            for &b in test_data {
810                bits -= mix.step(b) / std::f64::consts::LN_2;
811            }
812            bits / (test_data.len() as f64)
813        }
814        RateBackend::Ctw { depth } => {
815            if test_data.is_empty() {
816                return 0.0;
817            }
818            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
819            for &b in train_data {
820                for bit_idx in 0..8 {
821                    let bit = ((b >> (7 - bit_idx)) & 1) == 1;
822                    fac.update(bit, bit_idx);
823                }
824            }
825            let log_p_y = fac.get_log_block_probability();
826            for &b in test_data {
827                for bit_idx in 0..8 {
828                    let bit = ((b >> (7 - bit_idx)) & 1) == 1;
829                    fac.update(bit, bit_idx);
830                }
831            }
832            let log_p_yx = fac.get_log_block_probability();
833            let log_p_x_given_y = log_p_yx - log_p_y;
834            let bits = -log_p_x_given_y / std::f64::consts::LN_2;
835            bits / (test_data.len() as f64)
836        }
837        RateBackend::FacCtw {
838            base_depth,
839            num_percept_bits: _,
840            encoding_bits,
841        } => {
842            if test_data.is_empty() {
843                return 0.0;
844            }
845            let bits_per_byte = (*encoding_bits).min(8).max(1);
846            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
847            for &b in train_data {
848                for i in 0..bits_per_byte {
849                    let bit_idx = i;
850                    fac.update(((b >> i) & 1) == 1, bit_idx);
851                }
852            }
853
854            let log_p_y = fac.get_log_block_probability();
855            for &b in test_data {
856                for i in 0..bits_per_byte {
857                    let bit_idx = i;
858                    fac.update(((b >> i) & 1) == 1, bit_idx);
859                }
860            }
861            let log_p_yx = fac.get_log_block_probability();
862            let log_p_x_given_y = log_p_yx - log_p_y;
863            let bits = -log_p_x_given_y / std::f64::consts::LN_2;
864            bits / (test_data.len() as f64)
865        }
866    }
867}
868
869pub fn joint_entropy_rate_backend(
870    x: &[u8],
871    y: &[u8],
872    max_order: i64,
873    backend: &RateBackend,
874) -> f64 {
875    match backend {
876        RateBackend::RosaPlus => {
877            let joint_symbols: Vec<u32> = (0..x.len())
878                .map(|i| (x[i] as u32) * 256 + (y[i] as u32))
879                .collect();
880            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
881            m.entropy_rate_cps(&joint_symbols)
882        }
883        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
884            c.joint_cross_entropy_aligned_min(x, y).unwrap_or(0.0)
885        }),
886        RateBackend::Zpaq { method } => {
887            if x.is_empty() {
888                return 0.0;
889            }
890            let mut joint = Vec::with_capacity(x.len() * 2);
891            for i in 0..x.len() {
892                joint.push(x[i]);
893                joint.push(y[i]);
894            }
895            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
896            let bits = model.update_and_score(&joint);
897            bits / (x.len() as f64)
898        }
899        RateBackend::Mixture { spec } => {
900            if x.is_empty() {
901                return 0.0;
902            }
903            let mut joint = Vec::with_capacity(x.len() * 2);
904            for i in 0..x.len() {
905                joint.push(x[i]);
906                joint.push(y[i]);
907            }
908            let experts = spec.build_experts();
909            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
910                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
911            let mut bits = 0.0;
912            for &b in &joint {
913                bits -= mix.step(b) / std::f64::consts::LN_2;
914            }
915            bits / (x.len() as f64)
916        }
917        RateBackend::Ctw { depth } => {
918            // NOTE: CTW interleaves bits: x_0, y_0, x_1, y_1...
919            // This estimates the joint entropy H(X,Y) by modeling the sequence
920            // of alternating bits. This is a fine-grained joint model but
921            // theoretically consistent for estimating joint entropy rate.
922            // ROSA uses 16-bit joint symbols (x << 8 | y). Both are valid.
923            let mut fac = crate::ctw::FacContextTree::new(*depth, 16);
924            for k in 0..x.len() {
925                let bx = x[k];
926                let by = y[k];
927                for bit_idx in 0..8 {
928                    let bit_x = ((bx >> (7 - bit_idx)) & 1) == 1;
929                    let bit_y = ((by >> (7 - bit_idx)) & 1) == 1;
930                    fac.update(bit_x, bit_idx);
931                    fac.update(bit_y, bit_idx + 8);
932                }
933            }
934            let ln_p = fac.get_log_block_probability();
935            let bits = -ln_p / std::f64::consts::LN_2;
936            bits / (x.len() as f64)
937        }
938        RateBackend::FacCtw {
939            base_depth,
940            num_percept_bits: _,
941            encoding_bits,
942        } => {
943            // Joint: interleave x and y bits, use 2*encoding_bits trees
944            let bits_per_byte = (*encoding_bits).min(8).max(1);
945            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte * 2);
946            for k in 0..x.len() {
947                let bx = x[k];
948                let by = y[k];
949                for i in 0..bits_per_byte {
950                    // Tree structure:
951                    // bits_per_byte trees for X, bits_per_byte trees for Y.
952                    // But we interleave them in the "joint" sense.
953                    // Here we map bit i of X to tree 2*i, bit i of Y to tree 2*i + 1
954                    let bit_idx_x = i * 2;
955                    let bit_idx_y = bit_idx_x + 1;
956                    fac.update(((bx >> i) & 1) == 1, bit_idx_x);
957                    fac.update(((by >> i) & 1) == 1, bit_idx_y);
958                }
959            }
960            let ln_p = fac.get_log_block_probability();
961            let bits = -ln_p / std::f64::consts::LN_2;
962            bits / (x.len() as f64)
963        }
964    }
965}
966#[inline(always)]
967pub fn get_compressed_size_parallel(path: &str, method: &str, threads: usize) -> u64 {
968    // Convert Input file to Vec<u8>, and reference that (compress_size only takes &[u8] input), and pass method.
969    // Will panic if file does not exist, so it must be prevalidated.
970    zpaq_rs::compress_size_parallel(&std::fs::read(path).unwrap(), method, threads).unwrap()
971}
972
973#[inline(always)]
974pub fn get_bytes_from_paths(paths: &[&str]) -> Vec<Vec<u8>> {
975    paths
976        .par_iter()
977        .map(|path| std::fs::read(*path).expect("failed to read file"))
978        .collect()
979}
980
// ------- Bulk File Compression Functions -------
982#[inline(always)]
983pub fn get_sequential_compressed_sizes_from_sequential_paths(
984    paths: &[&str],
985    method: &str,
986) -> Vec<u64> {
987    // This will, in parallel load all files into memory, THEN in parallel compress each one, each with one thread.
988    // Use when File IO is the bottleneck
989    // Only uses ONE ZPAQ THREAD.
990    // For VERY large n (relative to threads) with small files (relative to memory) this may be useful.
991    get_bytes_from_paths(paths)
992        .par_iter()
993        .map(|data| zpaq_rs::compress_size(data, method).unwrap())
994        .collect()
995}
996
997#[inline(always)]
998pub fn get_parallel_compressed_sizes_from_sequential_paths(
999    paths: &[&str],
1000    method: &str,
1001    threads: usize,
1002) -> Vec<u64> {
1003    // This will, in parallel load all files into memory, THEN in parallel compress each one, with THREADS. (for each file, the thread count is THREADS)
1004    // Use when File IO is the bottleneck.
1005    // Balanced parallelization between RAYON_NUM_THREADS and ZPAQ `THREADS` const. For when total dataset will fit in memory.
1006    get_bytes_from_paths(paths)
1007        .par_iter()
1008        .map(|data| zpaq_rs::compress_size_parallel(data, method, threads).unwrap())
1009        .collect()
1010}
1011
1012#[inline(always)]
1013pub fn get_sequential_compressed_sizes_from_parallel_paths(
1014    paths: &[&str],
1015    method: &str,
1016) -> Vec<u64> {
1017    // This will, in parallel, for each file, read it from disk and compress it with one thread. (one file, one thread)
1018    // Use when File IO is not the bottleneck. Lower memory usage. (does not preload dataset)
1019    // Only uses ONE ZPAQ THREAD. For VERY large n(relative to threads) with large files(relative to memory) this may be useful.
1020    paths
1021        .par_iter()
1022        .map(|path| get_compressed_size(path, method))
1023        .collect()
1024}
1025
1026#[inline(always)]
1027pub fn get_parallel_compressed_sizes_from_parallel_paths(
1028    paths: &[&str],
1029    method: &str,
1030    threads: usize,
1031) -> Vec<u64> {
1032    // This will, in parallel, for each file, read it from disk and compress it with THREADS. (for each file, the thread count is THREADS)
1033    // Use when File IO is not the bottleneck. Lower memory usage. (does not preload dataset)
1034    // For large n(relative to threads) with VERY large files(relative to memory) this may be useful.
1035    // This will reflect RAYON_NUM_THREADS and THREAD const values.
1036    paths
1037        .par_iter()
1038        .map(|path| get_compressed_size_parallel(path, method, threads))
1039        .collect()
1040}
1041
1042/// Optimizes parallelization
1043#[inline(always)]
1044pub fn get_compressed_sizes_from_paths(paths: &[&str], method: &str) -> Vec<u64> {
1045    let n: usize = paths.len();
1046    let num_threads: usize = *NUM_THREADS.get_or_init(|| num_cpus::get());
1047    if n < num_threads {
1048        get_parallel_compressed_sizes_from_parallel_paths(paths, method, (num_threads + n - 1) / n)
1049    } else {
1050        get_sequential_compressed_sizes_from_parallel_paths(paths, method)
1051    }
1052}
1053
/// Variants of the Normalized Compression Distance (NCD).
///
/// NCD is a parameter-free similarity metric based on Kolmogorov complexity.
/// Kolmogorov complexity `K(x)` is uncomputable, so it is approximated by the
/// compressed size `C(x)` from a real-world compressor (e.g. ZPAQ).
///
/// The general form is:
/// `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
///
/// The variants differ in their normalizer and in whether they symmetrize over the
/// concatenation order. The symmetric variants need the size of both `xy` and `yx`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcdVariant {
    /// Standard Vitanyi NCD:
    /// `NCD(x,y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))`
    /// Note: `C(xy)` denotes the compressed size of the concatenation of x and y.
    Vitanyi,
    /// Symmetric Vitanyi NCD:
    /// `NCD_sym(x,y) = (min(C(xy), C(yx)) - min(C(x), C(y))) / max(C(x), C(y))`
    /// Uses the better of the `xy` and `yx` compressions. This makes the distance
    /// symmetric even when the compressor is order-sensitive.
    SymVitanyi,
    /// Conservative NCD:
    /// `NCD_cons(x,y) = (C(xy) - min(C(x), C(y))) / C(xy)`
    /// Normalizes by the joint compressed size instead of the larger marginal.
    Cons,
    /// Symmetric Conservative NCD:
    /// `NCD_sym_cons(x,y) = (min(C(xy), C(yx)) - min(C(x), C(y))) / min(C(xy), C(yx))`
    /// Conservative normalization applied to the better of the two concatenation orders.
    SymCons,
}
1082
1083#[inline(always)]
1084fn compress_size_bytes(data: &[u8], method: &str) -> u64 {
1085    zpaq_rs::compress_size(data, method).unwrap_or(0)
1086}
1087
1088#[inline(always)]
1089fn ncd_from_sizes(cx: u64, cy: u64, cxy: u64, cyx: Option<u64>, variant: NcdVariant) -> f64 {
1090    let min_c = cx.min(cy) as f64;
1091    let max_c = cx.max(cy) as f64;
1092
1093    match variant {
1094        NcdVariant::Vitanyi => {
1095            if max_c == 0.0 {
1096                0.0
1097            } else {
1098                (cxy as f64 - min_c) / max_c
1099            }
1100        }
1101        NcdVariant::SymVitanyi => {
1102            let m = cxy.min(cyx.expect("cyx required for SymVitanyi")) as f64;
1103            if max_c == 0.0 {
1104                0.0
1105            } else {
1106                (m - min_c) / max_c
1107            }
1108        }
1109        NcdVariant::Cons => {
1110            let denom = cxy as f64;
1111            if denom == 0.0 {
1112                0.0
1113            } else {
1114                (cxy as f64 - min_c) / denom
1115            }
1116        }
1117        NcdVariant::SymCons => {
1118            let m = cxy.min(cyx.expect("cyx required for SymCons")) as f64;
1119            if m == 0.0 { 0.0 } else { (m - min_c) / m }
1120        }
1121    }
1122}
1123
1124#[inline(always)]
1125pub fn ncd_bytes(x: &[u8], y: &[u8], method: &str, variant: NcdVariant) -> f64 {
1126    let backend = NcdBackend::Zpaq {
1127        method: method.to_string(),
1128    };
1129    ncd_bytes_backend(x, y, &backend, variant)
1130}
1131
1132/// NCD with bytes using the default context.
1133#[inline(always)]
1134pub fn ncd_bytes_default(x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
1135    with_default_ctx(|ctx| ctx.ncd_bytes(x, y, variant))
1136}
1137
1138pub fn ncd_bytes_backend(x: &[u8], y: &[u8], backend: &NcdBackend, variant: NcdVariant) -> f64 {
1139    let (cx, cy) = rayon::join(
1140        || compress_size_backend(x, backend),
1141        || compress_size_backend(y, backend),
1142    );
1143
1144    let cxy = compress_size_chain_backend(&[x, y], backend);
1145
1146    let cyx = match variant {
1147        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
1148            Some(compress_size_chain_backend(&[y, x], backend))
1149        }
1150        _ => None,
1151    };
1152
1153    ncd_from_sizes(cx, cy, cxy, cyx, variant)
1154}
1155
1156#[inline(always)]
1157pub fn ncd_paths(x: &str, y: &str, method: &str, variant: NcdVariant) -> f64 {
1158    let (bx, by) = rayon::join(
1159        || std::fs::read(x).expect("failed to read x"),
1160        || std::fs::read(y).expect("failed to read y"),
1161    );
1162    ncd_bytes(&bx, &by, method, variant)
1163}
1164
1165pub fn ncd_paths_backend(x: &str, y: &str, backend: &NcdBackend, variant: NcdVariant) -> f64 {
1166    let (bx, by) = rayon::join(
1167        || std::fs::read(x).expect("failed to read x"),
1168        || std::fs::read(y).expect("failed to read y"),
1169    );
1170    ncd_bytes_backend(&bx, &by, backend, variant)
1171}
1172
1173/// Back-compat convenience wrappers (operate on file paths).
1174#[inline(always)]
1175pub fn ncd_vitanyi(x: &str, y: &str, method: &str) -> f64 {
1176    ncd_paths(x, y, method, NcdVariant::Vitanyi)
1177}
1178#[inline(always)]
1179pub fn ncd_sym_vitanyi(x: &str, y: &str, method: &str) -> f64 {
1180    ncd_paths(x, y, method, NcdVariant::SymVitanyi)
1181}
1182#[inline(always)]
1183pub fn ncd_cons(x: &str, y: &str, method: &str) -> f64 {
1184    ncd_paths(x, y, method, NcdVariant::Cons)
1185}
1186#[inline(always)]
1187pub fn ncd_sym_cons(x: &str, y: &str, method: &str) -> f64 {
1188    ncd_paths(x, y, method, NcdVariant::SymCons)
1189}
1190
1191/// Computes an NCD matrix (row-major, len = n*n) for in-memory byte blobs.
1192///
1193/// Note: For symmetric variants, this computes each unordered pair once and writes both (i,j) and (j,i).
1194pub fn ncd_matrix_bytes(datas: &[Vec<u8>], method: &str, variant: NcdVariant) -> Vec<f64> {
1195    let n = datas.len();
1196    let cx: Vec<u64> = datas
1197        .par_iter()
1198        .map(|d| compress_size_bytes(d, method))
1199        .collect();
1200
1201    let mut out = vec![0.0f64; n * n];
1202    let out_ptr = std::sync::atomic::AtomicPtr::new(out.as_mut_ptr());
1203
1204    match variant {
1205        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
1206            (0..n)
1207                .into_par_iter()
1208                .flat_map_iter(|i| (i + 1..n).map(move |j| (i, j)))
1209                .for_each_init(Vec::<u8>::new, |buf, (i, j)| {
1210                    let x = &datas[i];
1211                    let y = &datas[j];
1212
1213                    buf.clear();
1214                    buf.reserve(x.len() + y.len());
1215                    buf.extend_from_slice(x);
1216                    buf.extend_from_slice(y);
1217                    let cxy = compress_size_bytes(buf, method);
1218
1219                    buf.clear();
1220                    buf.reserve(x.len() + y.len());
1221                    buf.extend_from_slice(y);
1222                    buf.extend_from_slice(x);
1223                    let cyx = compress_size_bytes(buf, method);
1224
1225                    let d = ncd_from_sizes(cx[i], cx[j], cxy, Some(cyx), variant);
1226
1227                    // Safety: each (i,j) cell is written exactly once across all iterations.
1228                    let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
1229                    unsafe {
1230                        *p.add(i * n + j) = d;
1231                        *p.add(j * n + i) = d;
1232                    }
1233                });
1234        }
1235        NcdVariant::Vitanyi | NcdVariant::Cons => {
1236            (0..n)
1237                .into_par_iter()
1238                .for_each_init(Vec::<u8>::new, |buf, i| {
1239                    let x = &datas[i];
1240                    for j in 0..n {
1241                        let d = if i == j {
1242                            0.0
1243                        } else {
1244                            let y = &datas[j];
1245                            buf.clear();
1246                            buf.reserve(x.len() + y.len());
1247                            buf.extend_from_slice(x);
1248                            buf.extend_from_slice(y);
1249                            let cxy = compress_size_bytes(buf, method);
1250                            ncd_from_sizes(cx[i], cx[j], cxy, None, variant)
1251                        };
1252
1253                        let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
1254                        unsafe {
1255                            *p.add(i * n + j) = d;
1256                        }
1257                    }
1258                });
1259        }
1260    }
1261
1262    out
1263}
1264
1265/// Computes an NCD matrix (row-major, len = n*n) for files (preloads all files into memory once).
1266pub fn ncd_matrix_paths(paths: &[&str], method: &str, variant: NcdVariant) -> Vec<f64> {
1267    let datas = get_bytes_from_paths(paths);
1268    ncd_matrix_bytes(&datas, method, variant)
1269}
1270
1271// ============================================================
1272// Entropy-Based Distance Primitives (via ROSA)
1273// ============================================================
1274//
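// These estimate entropy either from exact byte histograms (the `*_marg_*`
// variants) or from a predictive rate model such as ROSA's Witten-Bell
// language model (the rate variants). The information-theoretic distances
// below are derived from those estimates.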
1277
/// Marginal (Shannon) entropy `H(X) = −Σ p(x) log₂ p(x)` in bits per byte.
///
/// Probabilities come from the byte histogram of `data`. This is a first-order,
/// bag-of-bytes estimate. It ignores sequential context, unlike the entropy-rate
/// estimators. Empty input yields `0.0`.
#[inline(always)]
pub fn marginal_entropy_bytes(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let mut histogram = [0u64; 256];
    data.iter().for_each(|&byte| histogram[usize::from(byte)] += 1);

    let total = data.len() as f64;
    histogram
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| {
            let p = count as f64 / total;
            p * p.log2()
        })
        .fold(0.0f64, |entropy, term| entropy - term)
}
1303
1304/// Compute entropy rate `Ĥ(X)` in bits/symbol using ROSA LM.
1305///
1306/// This uses ROSA's context-conditional Witten-Bell model to estimate
1307/// the entropy rate, which accounts for sequential dependencies.
1308///
1309/// The estimator is **prequential** (predictive sequential): it sums the negative log-probability
1310/// of each symbol `x_t` given its past context `x_{<t}`, estimated from the model trained on `x_{<t}`.
1311///
1312/// `Ĥ(X) = -1/N * Σ log2 P(x_t | x_{t-k}^{t-1})`
1313///
1314/// For i.i.d. data, this should approximately equal `marginal_entropy_bytes`.
1315///
1316/// * `max_order`: Maximum context order for the suffix automaton LM.
1317///   A value of -1 means unlimited context (bounded only by memory/sequence length).
1318#[inline(always)]
1319pub fn entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
1320    with_default_ctx(|ctx| ctx.entropy_rate_bytes(data, max_order))
1321}
1322
1323/// Compute biased entropy rate Ĥ_biased(X) bits per symbol.
1324///
1325/// This uses the full plugin estimator (training on the whole text, then scoring the same text).
1326/// While biased as a source entropy estimate, it is mathematically consistent for
1327/// similarity metrics like Mutual Information and NED.
1328#[inline(always)]
1329pub fn biased_entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
1330    with_default_ctx(|ctx| ctx.biased_entropy_rate_bytes(data, max_order))
1331}
1332
1333/// Compute joint marginal entropy H(X,Y) = −Σ p(x,y) log₂ p(x,y) in bits/symbol-pair.
1334///
1335/// Uses a direct histogram of (x_i, y_i) pairs. This is the exact first-order
1336/// joint entropy, matching the spec.md definition.
1337#[inline(always)]
1338pub fn joint_marginal_entropy_bytes(x: &[u8], y: &[u8]) -> f64 {
1339    let (x, y) = aligned_prefix(x, y);
1340    let n = x.len();
1341    if n == 0 {
1342        return 0.0;
1343    }
1344
1345    // Count pair occurrences using a HashMap for (x, y) pairs
1346    // There are up to 65536 possible pairs, so we can use a flat array
1347    let mut counts = vec![0u64; 256 * 256];
1348    for i in 0..n {
1349        let pair_idx = (x[i] as usize) * 256 + (y[i] as usize);
1350        counts[pair_idx] += 1;
1351    }
1352
1353    let n_f64 = n as f64;
1354    let mut h = 0.0f64;
1355    for &c in &counts {
1356        if c > 0 {
1357            let p = c as f64 / n_f64;
1358            h -= p * p.log2();
1359        }
1360    }
1361    h
1362}
1363
1364/// Compute joint entropy rate `Ĥ(X,Y)`.
1365///
1366/// Dispatches based on `max_order`:
1367/// - `max_order == 0`: Strictly aligned pair-symbol mapping (Marginal Joint Entropy).
1368///   Treats `(x_i, y_i)` as a single symbol in a product alphabet `Σ_X × Σ_Y`.
1369/// - `max_order != 0`: Shift-invariant algorithmic joint entropy approximated via ROSA.
1370///   Constructs a sequence of pair-symbols and estimates the entropy rate of that sequence.
1371///
1372/// **Note**: This is an *aligned* joint entropy-rate estimate over time-indexed pairs
1373/// `(x_i, y_i)`. All joint-based quantities (`H(X)`, `H(Y)`, `H(X,Y)`, `I`, NED, NTE, etc.)
1374/// should be computed over the same aligned sample.
1375#[inline(always)]
1376pub fn joint_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1377    with_default_ctx(|ctx| ctx.joint_entropy_rate_bytes(x, y, max_order))
1378}
1379
1380/// Compute conditional entropy rate `Ĥ(X|Y)`.
1381///
1382/// Dispatches based on `max_order`:
1383/// - `max_order == 0`: Strictly aligned `H(X,Y) - H(Y)` using marginals.
1384/// - `max_order != 0`: Chain rule definition `Ĥ(X|Y) = Ĥ(X,Y) - Ĥ(Y)`.
1385///
1386/// Note: This relies on the identity `H(X|Y) = H(X,Y) - H(Y)`.
1387#[inline(always)]
1388pub fn conditional_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1389    with_default_ctx(|ctx| ctx.conditional_entropy_rate_bytes(x, y, max_order))
1390}
1391
1392/// Compute conditional entropy H(X|Y) = H(X,Y) − H(Y)
1393///
1394/// Dispatches based on `max_order`.
1395#[inline(always)]
1396pub fn conditional_entropy_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1397    with_default_ctx(|ctx| ctx.conditional_entropy_bytes(x, y, max_order))
1398}
1399
1400/// Compute mutual information `I(X;Y) = H(X) + H(Y) - H(X,Y)`.
1401///
1402/// Dispatches based on `max_order`. If 0, uses marginals; else uses rates.
1403///
1404/// `I(X;Y) = Σ p(x,y) log(p(x,y) / (p(x)p(y)))`
1405#[inline(always)]
1406pub fn mutual_information_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1407    with_default_ctx(|ctx| ctx.mutual_information_bytes(x, y, max_order))
1408}
1409
1410/// Marginal Mutual Information (exact/histogram)
1411pub fn mutual_information_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
1412    let (x, y) = aligned_prefix(x, y);
1413    let h_x = marginal_entropy_bytes(x);
1414    let h_y = marginal_entropy_bytes(y);
1415    let h_xy = joint_marginal_entropy_bytes(x, y);
1416    (h_x + h_y - h_xy).max(0.0)
1417}
1418
1419/// Entropy Rate Mutual Information (ROSA predictive)
1420#[inline(always)]
1421pub fn mutual_information_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1422    with_default_ctx(|ctx| ctx.mutual_information_rate_bytes(x, y, max_order))
1423}
1424
1425// ====== NED: Normalized Entropy Distance ======
1426//
1427// A metric distance based on the overlap of information between two variables.
1428
1429/// NED(X,Y) = (H(X,Y) - min(H(X), H(Y))) / max(H(X), H(Y))
1430///
1431/// Dispatches based on `max_order`. If 0, uses marginals; else uses rates.
1432///
1433/// Range: [0, 1].
1434/// * 0: Identity (X determines Y and Y determines X).
1435/// * 1: Independence (X and Y share no information).
1436#[inline(always)]
1437pub fn ned_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1438    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
1439}
1440
1441/// Marginal NED (exact/histogram)
1442pub fn ned_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
1443    let (x, y) = aligned_prefix(x, y);
1444    let h_x = marginal_entropy_bytes(x);
1445    let h_y = marginal_entropy_bytes(y);
1446    let h_xy = joint_marginal_entropy_bytes(x, y);
1447    let min_h = h_x.min(h_y);
1448    let max_h = h_x.max(h_y);
1449    if max_h == 0.0 {
1450        0.0
1451    } else {
1452        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
1453    }
1454}
1455
1456/// Normalized Entropy Distance (Rate-based)
1457#[inline(always)]
1458pub fn ned_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1459    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
1460}
1461
1462/// NED_cons(X,Y) = (H(X,Y) - min(H(X), H(Y))) / H(X,Y)
1463///
1464/// Conservative variant. Dispatches based on `max_order`.
1465#[inline(always)]
1466pub fn ned_cons_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1467    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
1468}
1469
1470pub fn ned_cons_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
1471    let h_x = marginal_entropy_bytes(x);
1472    let h_y = marginal_entropy_bytes(y);
1473    let h_xy = joint_marginal_entropy_bytes(x, y);
1474    let min_h = h_x.min(h_y);
1475    if h_xy == 0.0 {
1476        0.0
1477    } else {
1478        ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
1479    }
1480}
1481
1482#[inline(always)]
1483pub fn ned_cons_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1484    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
1485}
1486
1487// ====== NTE: Normalized Transform Effort (Variation of Information) ======
1488
1489/// NTE(X,Y) = VI(X,Y) / max(H(X), H(Y))
1490/// where `VI(X,Y) = H(X|Y) + H(Y|X) = 2H(X,Y) - H(X) - H(Y)`.
1491///
1492/// Represents the "effort" required to transform X into Y (and vice versa) relative
1493/// to their complexity.
1494///
1495/// Note: VI can be as large as `H(X) + H(Y)`. If `H(X) ≈ H(Y)`, then VI can be `≈ 2 max(H(X), H(Y))`.
1496/// Thus, NTE is in [0, 2].
1497/// * Values near 0 indicate near-identity.
1498/// * Values near 1+ indicate substantial effort/transform cost (e.g. independence).
1499///
1500/// Dispatches based on `max_order`.
1501#[inline(always)]
1502pub fn nte_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1503    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
1504}
1505
1506pub fn nte_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
1507    let (x, y) = aligned_prefix(x, y);
1508    let h_x = marginal_entropy_bytes(x);
1509    let h_y = marginal_entropy_bytes(y);
1510    let h_xy = joint_marginal_entropy_bytes(x, y);
1511    let vi = 2.0 * h_xy - h_x - h_y;
1512    let max_h = h_x.max(h_y);
1513    if max_h == 0.0 {
1514        0.0
1515    } else {
1516        (vi / max_h).clamp(0.0, 2.0)
1517    }
1518}
1519
1520#[inline(always)]
1521pub fn nte_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
1522    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
1523}
1524
1525// ====== TVD: Total Variation Distance ======
1526
/// Normalized byte histogram: `p(i) = count(i) / N` for every byte value `i ∈ [0, 255]`.
///
/// Empty input yields an all-zero array rather than dividing by zero.
#[inline(always)]
fn byte_histogram(data: &[u8]) -> [f64; 256] {
    let mut probs = [0.0f64; 256];
    if data.is_empty() {
        return probs;
    }

    let mut counts = [0u64; 256];
    for &byte in data {
        counts[usize::from(byte)] += 1;
    }

    let total = data.len() as f64;
    for (p, &count) in probs.iter_mut().zip(counts.iter()) {
        *p = count as f64 / total;
    }
    probs
}
1544
1545/// TVD_marg(X,Y) = (1/2) Σᵢ |p_X(i) - p_Y(i)|
1546///
1547/// Total Variation Distance over marginal byte distributions.
1548/// True metric on probability space. Range: [0, 1].
1549/// 0 = identical distributions, 1 = completely disjoint support.
1550#[inline(always)]
1551pub fn tvd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
1552    if x.is_empty() || y.is_empty() {
1553        return 0.0;
1554    }
1555    let p_x = byte_histogram(x);
1556    let p_y = byte_histogram(y);
1557
1558    let mut sum = 0.0f64;
1559    for i in 0..256 {
1560        sum += (p_x[i] - p_y[i]).abs();
1561    }
1562
1563    (sum / 2.0).clamp(0.0, 1.0)
1564}
1565
1566// ====== NHD: Normalized Hellinger Distance ======
1567
1568/// NHD(X,Y) = sqrt(1 - BC(X,Y)) where BC = Σᵢ sqrt(p_X(i) · p_Y(i))
1569///
1570/// Normalized Hellinger Distance over marginal byte distributions.
1571/// True metric. Range: [0, 1]. 0 = identical, 1 = disjoint support.
1572#[inline(always)]
1573pub fn nhd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
1574    if x.is_empty() || y.is_empty() {
1575        return 0.0;
1576    }
1577    let p_x = byte_histogram(x);
1578    let p_y = byte_histogram(y);
1579
1580    // Bhattacharyya coefficient: BC = Σᵢ sqrt(p_X(i) · p_Y(i))
1581    let mut bc = 0.0f64;
1582    for i in 0..256 {
1583        bc += (p_x[i] * p_y[i]).sqrt();
1584    }
1585
1586    // NHD = sqrt(1 - BC)
1587    (1.0 - bc).max(0.0).sqrt()
1588}
1589
1590// ====== Other Information-Theoretic Measures ======
1591
1592/// Compute cross-entropy H_{train}(test) - score test_data under model trained on train_data.
1593///
1594/// Dispatches based on `max_order`.
1595#[inline(always)]
1596pub fn cross_entropy_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
1597    with_default_ctx(|ctx| ctx.cross_entropy_bytes(test_data, train_data, max_order))
1598}
1599
1600/// Compute cross-entropy rate using ROSA/CTW/RWKV.
1601/// Training model on `train_data` and evaluating probability of `test_data`.
1602#[inline(always)]
1603pub fn cross_entropy_rate_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
1604    with_default_ctx(|ctx| ctx.cross_entropy_rate_bytes(test_data, train_data, max_order))
1605}
1606
1607/// Kullback-Leibler Divergence D_KL(P || Q) = Σ p(x) log(p(x) / q(x))
1608///
1609/// Marginal only. Measure of how one probability distribution is different from a second.
1610pub fn d_kl_bytes(x: &[u8], y: &[u8]) -> f64 {
1611    if x.is_empty() || y.is_empty() {
1612        return 0.0;
1613    }
1614    let p_x = byte_histogram(x);
1615    let p_y = byte_histogram(y);
1616    let mut d_kl = 0.0f64;
1617    for i in 0..256 {
1618        if p_x[i] > 0.0 {
1619            let q_y = p_y[i].max(1e-12);
1620            d_kl += p_x[i] * (p_x[i] / q_y).log2();
1621        }
1622    }
1623    d_kl.max(0.0)
1624}
1625
1626/// Jensen-Shannon Divergence JSD(P || Q) = 1/2 D_KL(P || M) + 1/2 D_KL(Q || M)
1627/// where M = 1/2 (P + Q)
1628///
1629/// Marginal only. Symmetrized and smoothed version of KL divergence. Range `[0,1]`.
1630pub fn js_div_bytes(x: &[u8], y: &[u8]) -> f64 {
1631    if x.is_empty() || y.is_empty() {
1632        return 0.0;
1633    }
1634    let p_x = byte_histogram(x);
1635    let p_y = byte_histogram(y);
1636    let mut m = [0.0f64; 256];
1637    for i in 0..256 {
1638        m[i] = 0.5 * (p_x[i] + p_y[i]);
1639    }
1640
1641    let mut kl_pm = 0.0f64;
1642    let mut kl_qm = 0.0f64;
1643    for i in 0..256 {
1644        if p_x[i] > 0.0 {
1645            kl_pm += p_x[i] * (p_x[i] / m[i]).log2();
1646        }
1647        if p_y[i] > 0.0 {
1648            kl_qm += p_y[i] * (p_y[i] / m[i]).log2();
1649        }
1650    }
1651    (0.5 * kl_pm + 0.5 * kl_qm).max(0.0)
1652}
1653
1654// ====== Path-based convenience wrappers ======
1655
1656/// NED for files.
1657pub fn ned_paths(x: &str, y: &str, max_order: i64) -> f64 {
1658    let (bx, by) = rayon::join(
1659        || std::fs::read(x).expect("failed to read x"),
1660        || std::fs::read(y).expect("failed to read y"),
1661    );
1662    ned_bytes(&bx, &by, max_order)
1663}
1664
1665/// NTE for files.
1666pub fn nte_paths(x: &str, y: &str, max_order: i64) -> f64 {
1667    let (bx, by) = rayon::join(
1668        || std::fs::read(x).expect("failed to read x"),
1669        || std::fs::read(y).expect("failed to read y"),
1670    );
1671    nte_bytes(&bx, &by, max_order)
1672}
1673
1674/// TVD for files.
1675pub fn tvd_paths(x: &str, y: &str, max_order: i64) -> f64 {
1676    let (bx, by) = rayon::join(
1677        || std::fs::read(x).expect("failed to read x"),
1678        || std::fs::read(y).expect("failed to read y"),
1679    );
1680    tvd_bytes(&bx, &by, max_order)
1681}
1682
1683/// NHD for files.
1684pub fn nhd_paths(x: &str, y: &str, max_order: i64) -> f64 {
1685    let (bx, by) = rayon::join(
1686        || std::fs::read(x).expect("failed to read x"),
1687        || std::fs::read(y).expect("failed to read y"),
1688    );
1689    nhd_bytes(&bx, &by, max_order)
1690}
1691
1692/// Mutual Information for files.
1693pub fn mutual_information_paths(x: &str, y: &str, max_order: i64) -> f64 {
1694    let (bx, by) = rayon::join(
1695        || std::fs::read(x).expect("failed to read x"),
1696        || std::fs::read(y).expect("failed to read y"),
1697    );
1698    mutual_information_bytes(&bx, &by, max_order)
1699}
1700
1701/// Conditional Entropy for files.
1702pub fn conditional_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
1703    let (bx, by) = rayon::join(
1704        || std::fs::read(x).expect("failed to read x"),
1705        || std::fs::read(y).expect("failed to read y"),
1706    );
1707    conditional_entropy_bytes(&bx, &by, max_order)
1708}
1709
1710/// Cross-Entropy for files.
1711pub fn cross_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
1712    let (bx, by) = rayon::join(
1713        || std::fs::read(x).expect("failed to read x"),
1714        || std::fs::read(y).expect("failed to read y"),
1715    );
1716    cross_entropy_bytes(&bx, &by, max_order)
1717}
1718
1719/// KL Divergence for files.
1720pub fn kl_divergence_paths(x: &str, y: &str) -> f64 {
1721    let (bx, by) = rayon::join(
1722        || std::fs::read(x).expect("failed to read x"),
1723        || std::fs::read(y).expect("failed to read y"),
1724    );
1725    d_kl_bytes(&bx, &by)
1726}
1727
1728/// Jensen-Shannon Divergence for files.
1729pub fn js_divergence_paths(x: &str, y: &str) -> f64 {
1730    let (bx, by) = rayon::join(
1731        || std::fs::read(x).expect("failed to read x"),
1732        || std::fs::read(y).expect("failed to read y"),
1733    );
1734    js_div_bytes(&bx, &by)
1735}
1736
1737// ====== Primitives 6 & 7 ======
1738
1739/// Primitive 6: Intrinsic Dependence (Redundancy Ratio).
1740///
1741/// Measures how much structure is intrinsic to the sample, relative to its
1742/// own marginal entropy baseline.
1743///
1744/// `R = (H_marginal - H_rate) / H_marginal`
1745///
1746/// Clamped to `[0,1]`.
1747///
1748/// Interpretation:
1749///   - `R → 0`: Data is close to i.i.d./max-entropy (little intrinsic structure; highly extrinsically explainable by priors).
1750///   - `R → 1`: Data is highly predictable from its own past (strong intrinsic dependence; e.g., periodic strings like 010101...).
1751#[inline(always)]
1752pub fn intrinsic_dependence_bytes(data: &[u8], max_order: i64) -> f64 {
1753    with_default_ctx(|ctx| ctx.intrinsic_dependence_bytes(data, max_order))
1754}
1755
1756/// Primitive 7: Resistance under Allowed Transformations.
1757///
1758/// Measures how much information is preserved after a transformation `T` is applied to `X`.
1759///
1760/// `Resistance(X, T) = I(X; T(X)) / H(X)`
1761///
1762/// Range `[0,1]` (with guard for `H(X)=0`).
1763/// * 1 means perfectly resistant (identity transformation).
1764/// * 0 means the transformation destroyed all information (e.g. mapping everything to a constant).
1765///
1766/// Assumes X and T(X) are aligned.
1767#[inline(always)]
1768pub fn resistance_to_transformation_bytes(x: &[u8], tx: &[u8], max_order: i64) -> f64 {
1769    with_default_ctx(|ctx| ctx.resistance_to_transformation_bytes(x, tx, max_order))
1770}
1771
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that depend on, or temporarily replace, the default
    /// context installed via `set_default_ctx`.
    ///
    /// The test harness runs tests on parallel threads. If the default context is
    /// shared between threads, a backend swap in one test could otherwise leak into
    /// another test's measurements.
    static DEFAULT_CTX_LOCK: Mutex<()> = Mutex::new(());

    /// Exclusive access to the default context for the lifetime of the guard.
    ///
    /// Dropping the guard reinstalls `InfotheoryCtx::default()` before releasing
    /// the lock. Drop also runs while unwinding from a failed assertion, so a
    /// failing test cannot leave a non-default backend behind.
    struct DefaultCtxGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl DefaultCtxGuard {
        fn acquire() -> Self {
            // A test that panics while holding the lock poisons it. The protected
            // data is `()`, and `Drop` has already restored the default context,
            // so recovering the guard is sound.
            let lock = DEFAULT_CTX_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self { _lock: lock }
        }

        /// Replaces the default context; only callable while the lock is held.
        fn set(&self, ctx: InfotheoryCtx) {
            set_default_ctx(ctx);
        }
    }

    impl Drop for DefaultCtxGuard {
        fn drop(&mut self) {
            // `Drop::drop` runs before the `_lock` field is released, so the reset
            // happens while still serialized.
            set_default_ctx(InfotheoryCtx::default());
        }
    }

    #[test]
    fn ncd_basic_identity_nonnegative() {
        let x = b"abcdabcdabcd";
        let d = ncd_bytes(x, x, "5", NcdVariant::Vitanyi);
        assert!(d >= -1e-9);
    }

    #[test]
    fn shannon_identities_marginal_aligned() {
        let x = b"abracadabra";
        let y = b"abracadabra";

        let h = marginal_entropy_bytes(x);
        let mi = mutual_information_bytes(x, y, 0);
        let h_xy = joint_marginal_entropy_bytes(x, y);
        let h_x_given_y = conditional_entropy_bytes(x, y, 0);
        let ned = ned_bytes(x, y, 0);
        let nte = nte_bytes(x, y, 0);

        assert!((h_xy - h).abs() < 1e-12);
        assert!(h_x_given_y.abs() < 1e-12);
        assert!((mi - h).abs() < 1e-12);
        assert!(ned.abs() < 1e-12);
        assert!(nte.abs() < 1e-12);
    }

    #[test]
    fn shannon_identities_rate_aligned_reasonable() {
        // Rate estimates go through the default context; hold it fixed.
        let _ctx = DefaultCtxGuard::acquire();

        let x = b"the quick brown fox jumps over the lazy dog";
        let y = b"the quick brown fox jumps over the lazy dog";
        let max_order = 8;

        let h_x = entropy_rate_bytes(x, max_order);
        let h_xy = joint_entropy_rate_bytes(x, y, max_order);
        let h_x_given_y = conditional_entropy_rate_bytes(x, y, max_order);
        let mi = mutual_information_bytes(x, y, max_order);
        let ned = ned_bytes(x, y, max_order);

        // Finite-sample estimators won't be exact; allow reasonable tolerance.
        let tol = 0.2;
        assert!((h_xy - h_x).abs() < tol);
        assert!(h_x_given_y < tol);
        assert!((mi - h_x).abs() < tol);
        assert!(ned < tol);
    }

    #[test]
    fn resistance_identity_is_one() {
        // The order-8 estimate goes through the default context; hold it fixed.
        let _ctx = DefaultCtxGuard::acquire();

        let x = b"some repeated repeated repeated text";
        let r0 = resistance_to_transformation_bytes(x, x, 0);
        let r8 = resistance_to_transformation_bytes(x, x, 8);
        assert!((r0 - 1.0).abs() < 1e-12);
        assert!((r8 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn marginal_metrics_empty_inputs_are_zero() {
        let empty: &[u8] = &[];
        let x = b"abc";

        assert_eq!(tvd_bytes(empty, x, 0), 0.0);
        assert_eq!(tvd_bytes(x, empty, 0), 0.0);
        assert_eq!(nhd_bytes(empty, x, 0), 0.0);
        assert_eq!(nhd_bytes(x, empty, 0), 0.0);
        assert_eq!(d_kl_bytes(empty, x), 0.0);
        assert_eq!(d_kl_bytes(x, empty), 0.0);
        assert_eq!(js_div_bytes(empty, x), 0.0);
        assert_eq!(js_div_bytes(x, empty), 0.0);
    }

    #[test]
    fn marginal_cross_entropy_empty_test_is_zero() {
        let empty: &[u8] = &[];
        let y = b"abc";
        let ctx = InfotheoryCtx::with_zpaq("5");
        assert_eq!(ctx.cross_entropy_bytes(empty, y, 0), 0.0);
    }

    #[test]
    fn backend_switching_test() {
        // Hold the default context for the whole test; the guard restores the
        // default even if an assertion below fails.
        let ctx = DefaultCtxGuard::acquire();
        let x = b"hello world context";

        // Default is RosaPlus
        let h_rosa = entropy_rate_bytes(x, 8);

        // Switch to CTW
        ctx.set(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 16 },
            NcdBackend::default(),
        ));

        let h_ctw = entropy_rate_bytes(x, 8);

        // They should generally be different, but most importantly, CTW worked
        assert!(h_ctw > 0.0);

        // Reset to default and confirm the original estimate is reproduced.
        ctx.set(InfotheoryCtx::default());
        let h_rosa_back = entropy_rate_bytes(x, 8);
        assert!((h_rosa - h_rosa_back).abs() < 1e-12);
    }

    #[test]
    fn ctw_early_updates_work() {
        // Test that CTW produces valid predictions from the very start,
        // not just after `depth` symbols have been processed.
        use crate::ctw::ContextTree;

        let mut tree = ContextTree::new(16);

        // Even the first prediction should be valid (not NaN, not 0)
        let p0 = tree.predict(false);
        let p1 = tree.predict(true);

        // Initial KT estimator gives 0.5 / 1 = 0.5 for each symbol
        assert!((p0 - 0.5).abs() < 1e-10, "p0 should be ~0.5, got {}", p0);
        assert!((p1 - 0.5).abs() < 1e-10, "p1 should be ~0.5, got {}", p1);
        assert!((p0 + p1 - 1.0).abs() < 1e-10, "p0 + p1 should = 1.0");

        // Update with a few symbols and verify log_prob becomes negative (valid)
        for _ in 0..5 {
            tree.update(true);
            tree.update(false);
        }

        let log_prob = tree.get_log_block_probability();
        assert!(
            log_prob < 0.0,
            "log_prob should be negative (< log 1), got {}",
            log_prob
        );
        assert!(log_prob.is_finite(), "log_prob should be finite");
    }

    #[test]
    fn nte_rate_stays_within_zero_and_two() {
        // Rate-based NTE is VI / max(H(X), H(Y)), where
        // VI = H(X|Y) + H(Y|X) <= H(X) + H(Y) <= 2 * max(H(X), H(Y)).
        // Its clamp must therefore admit [0, 2], not just [0, 1].
        //
        // This is a range check only. It does not assert that any particular pair
        // exceeds 1.0. Here `y` is the element-wise complement of `x`
        // (y[i] = 1 - x[i]), so the aligned pair is fully dependent and the true
        // VI is ~0.
        let x: Vec<u8> = (0..200).map(|i| (i % 2) as u8).collect(); // 010101...
        let y: Vec<u8> = (0..200).map(|i| ((i + 1) % 2) as u8).collect(); // 101010...

        // The backend is passed explicitly, so the default context is not touched.
        let nte_rate = nte_rate_backend(&x, &y, -1, &RateBackend::Ctw { depth: 8 });

        assert!(
            nte_rate >= 0.0 && nte_rate <= 2.0 + 1e-9,
            "NTE should be in [0, 2], got {}",
            nte_rate
        );
    }

    #[test]
    fn ctw_empty_data_returns_zero() {
        // Verify empty data doesn't cause division-by-zero or NaN.
        // The guard restores the default context when the test ends, pass or fail.
        let ctx = DefaultCtxGuard::acquire();
        ctx.set(InfotheoryCtx::new(
            RateBackend::Ctw { depth: 16 },
            NcdBackend::default(),
        ));

        let empty: &[u8] = &[];
        let h = entropy_rate_bytes(empty, -1);
        assert_eq!(h, 0.0, "empty data should return 0.0 entropy");
    }

    #[test]
    fn datagen_bernoulli_entropy_estimate() {
        // Test that estimated entropy is close to theoretical for Bernoulli(0.5)
        let p = 0.5;
        let theoretical_h = crate::datagen::bernoulli_entropy(p);
        assert!((theoretical_h - 1.0).abs() < 1e-10);

        // Generate data and check marginal entropy is close to theoretical
        let data = crate::datagen::bernoulli(10000, p, 42);
        let estimated_h = marginal_entropy_bytes(&data);

        // Should be close to 1.0 bit (since values are 0 or 1)
        assert!(
            (estimated_h - theoretical_h).abs() < 0.1,
            "estimated H={} should be close to theoretical H={}",
            estimated_h,
            theoretical_h
        );
    }
}