rwkvzip/
lib.rs

//! rwkvzip - High-performance neural network compressor using RWKV7.
//!
//! This library provides lossless compression by leveraging the RWKV7 language model's
//! predictive capabilities to generate probability distributions, which are then
//! compressed via entropy coding (arithmetic coding or rANS).
//!
//! # Architecture
//!
//! - **Byte-level compression**: Operates directly on raw bytes (vocab_size=256)
//! - **Infinite context**: RWKV7's recurrent architecture maintains state indefinitely
//! - **x86_64 optimized**: AVX2/FMA SIMD throughout, no external BLAS dependencies
//! - **Correct-by-construction**: Information-theoretically sound implementation
13
14use anyhow::{bail, Result};
15use std::io::{Cursor, Read, Write};
16use std::path::Path;
17use std::sync::Arc;
18
19pub mod coders;
20pub mod rwkv7;
21
22use coders::{
23    quantize_pdf_to_cdf_inplace, quantize_pdf_to_rans_cdf_with_buffer, softmax_pdf_floor_inplace,
24    softmax_pdf_inplace, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder,
25    BlockedRansEncoder, Cdf, ANS_TOTAL, CDF_TOTAL,
26};
27
28pub use rwkv7::{Config, Model, ScratchBuffers, State};
29
// =============================================================================
// File Format Constants
// =============================================================================

/// File format magic number. When serialized little-endian the on-disk bytes
/// are 0x47 0x54 0x50 0x5A, i.e. the ASCII string "GTPZ".
/// Used to identify valid rwkvzip compressed files.
pub const MAGIC: u32 = 0x5a505447;

/// File format version. Increment on breaking changes to ensure compatibility.
pub const VERSION: u8 = 2;

/// Vocabulary size for byte-level compression.
/// Each byte (0-255) is treated as a separate symbol.
pub const VOCAB_SIZE: usize = 256;
44
45// =============================================================================
46// Entropy Coder Selection
47// =============================================================================
48
/// Entropy coder type for compression.
///
/// Both coders are lossless and produce equivalent results; the choice
/// affects compression speed and ratio. The type is serialized in the file
/// header as a single byte: 0 = AC, 1 = rANS (see `Header::new` /
/// `Header::coder_type`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum CoderType {
    /// Arithmetic coding: optimal compression ratio, slightly slower.
    /// Recommended for small files or when compression ratio is critical.
    #[default]
    AC,
    /// rANS coding: near-optimal compression with better throughput.
    /// Recommended for larger files where speed matters more.
    RANS,
}
63
/// A sink that discards every byte while counting how many were "written".
/// Used to measure compressed output size without materializing the output.
struct CountingWriter {
    n: u64,
}

impl CountingWriter {
    /// Create a counter starting at zero bytes written.
    #[inline]
    fn new() -> Self {
        CountingWriter { n: 0 }
    }

    /// Total number of bytes accepted so far.
    #[inline]
    fn bytes_written(&self) -> u64 {
        self.n
    }
}

impl Write for CountingWriter {
    /// Accept the whole buffer, adding its length to the running total.
    /// Saturating add: the count pins at u64::MAX rather than wrapping.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let len = buf.len();
        self.n = self.n.saturating_add(len as u64);
        Ok(len)
    }

    /// Nothing is buffered, so flushing is a no-op.
    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
93
94impl std::fmt::Display for CoderType {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        match self {
97            CoderType::AC => write!(f, "AC"),
98            CoderType::RANS => write!(f, "rANS"),
99        }
100    }
101}
102
103// =============================================================================
104// File Header
105// =============================================================================
106
/// Header structure for compressed data files.
///
/// Layout (18 bytes total, all multi-byte fields little-endian):
/// - magic: 4 bytes (little-endian u32, must equal [`MAGIC`])
/// - version: 1 byte (must be <= [`VERSION`] to be readable)
/// - coder: 1 byte (0=AC, 1=rANS)
/// - original_len: 8 bytes (little-endian u64)
/// - crc32: 4 bytes (little-endian u32)
#[derive(Debug, Clone)]
pub struct Header {
    /// Magic number for format identification (must be MAGIC).
    pub magic: u32,
    /// Format version for compatibility checking.
    pub version: u8,
    /// Coder type used (0=AC, 1=rANS).
    pub coder: u8,
    /// Original uncompressed data length in bytes.
    pub original_len: u64,
    /// CRC32 checksum of original data for integrity verification.
    pub crc32: u32,
}
128
129impl Header {
130    /// Total header size in bytes.
131    pub const SIZE: usize = 4 + 1 + 1 + 8 + 4; // 18 bytes
132
133    /// Create a new header for compressed data.
134    pub fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
135        Self {
136            magic: MAGIC,
137            version: VERSION,
138            coder: match coder {
139                CoderType::AC => 0,
140                CoderType::RANS => 1,
141            },
142            original_len,
143            crc32,
144        }
145    }
146
147    /// Serialize header to a writer (little-endian format).
148    pub fn write<W: Write>(&self, w: &mut W) -> Result<()> {
149        w.write_all(&self.magic.to_le_bytes())?;
150        w.write_all(&[self.version])?;
151        w.write_all(&[self.coder])?;
152        w.write_all(&self.original_len.to_le_bytes())?;
153        w.write_all(&self.crc32.to_le_bytes())?;
154        Ok(())
155    }
156
157    /// Deserialize header from a reader (little-endian format).
158    pub fn read<R: Read>(r: &mut R) -> Result<Self> {
159        let mut buf4 = [0u8; 4];
160        let mut buf8 = [0u8; 8];
161        let mut buf1 = [0u8; 1];
162
163        r.read_exact(&mut buf4)?;
164        let magic = u32::from_le_bytes(buf4);
165        if magic != MAGIC {
166            bail!(
167                "Invalid magic number: expected 0x{:08X}, got 0x{:08X}",
168                MAGIC,
169                magic
170            );
171        }
172
173        r.read_exact(&mut buf1)?;
174        let version = buf1[0];
175        if version > VERSION {
176            bail!(
177                "Unsupported version: {} (max supported: {})",
178                version,
179                VERSION
180            );
181        }
182
183        r.read_exact(&mut buf1)?;
184        let coder = buf1[0];
185
186        r.read_exact(&mut buf8)?;
187        let original_len = u64::from_le_bytes(buf8);
188
189        r.read_exact(&mut buf4)?;
190        let crc32 = u32::from_le_bytes(buf4);
191
192        Ok(Self {
193            magic,
194            version,
195            coder,
196            original_len,
197            crc32,
198        })
199    }
200
201    /// Get the coder type from the header byte.
202    pub fn coder_type(&self) -> CoderType {
203        match self.coder {
204            0 => CoderType::AC,
205            _ => CoderType::RANS,
206        }
207    }
208}
209
210// =============================================================================
211// CRC32 Checksum
212// =============================================================================
213
214/// Compute CRC32 checksum for data integrity verification.
215///
216/// Uses the crc32fast crate for hardware-accelerated computation.
217pub fn crc32(data: &[u8]) -> u32 {
218    let mut hasher = crc32fast::Hasher::new();
219    hasher.update(data);
220    hasher.finalize()
221}
222
223// =============================================================================
224// Compressor
225// =============================================================================
226
/// Main compressor/decompressor that combines RWKV7 inference with entropy coding.
///
/// The compressor maintains internal state and pre-allocated buffers to minimize
/// allocations during the compression/decompression hot path.
pub struct Compressor {
    /// RWKV7 model for generating probability distributions.
    pub model: Arc<Model>,
    /// Model state (recurrent hidden states).
    pub state: State,
    /// Scratch buffers handed to `Model::forward` on every step so the
    /// per-symbol inference loop performs no allocations.
    pub scratch: ScratchBuffers,
    /// Pre-allocated PDF buffer (eliminates allocations in compression loop).
    pub pdf_buffer: Vec<f64>,
    /// Reusable AC CDF buffer (vocab_size + 1 entries).
    pub cdf_buffer_ac: Vec<u32>,
    /// Reusable rANS CDF buffer (vocab_size + 1 entries).
    pub cdf_buffer_rans: Vec<u32>,
    /// Scratch frequencies for rANS quantization.
    pub rans_freq_buffer: Vec<i64>,
}
246
247impl Clone for Compressor {
248    fn clone(&self) -> Self {
249        let mut cloned = Self::new_from_model(self.model.clone());
250        cloned.state = self.state.clone();
251        cloned.pdf_buffer.clone_from(&self.pdf_buffer);
252        cloned.cdf_buffer_ac.clone_from(&self.cdf_buffer_ac);
253        cloned.cdf_buffer_rans.clone_from(&self.cdf_buffer_rans);
254        cloned.rans_freq_buffer.clone_from(&self.rans_freq_buffer);
255        cloned
256    }
257}
258
259impl Compressor {
260    /// Create a new compressor with the given model.
261    ///
262    /// # Arguments
263    /// * `model_path` - Path to RWKV7 model weights (.safetensors format)
264    ///
265    /// # Returns
266    /// A new Compressor ready for compression/decompression operations.
267    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
268        let model = Arc::new(Model::load(model_path)?);
269        Ok(Self::new_from_model(model))
270    }
271
272    pub fn load_model<P: AsRef<Path>>(model_path: P) -> Result<Arc<Model>> {
273        Ok(Arc::new(Model::load(model_path)?))
274    }
275
276    pub fn new_from_model(model: Arc<Model>) -> Self {
277        let state = model.new_state();
278        let vocab_size = model.config().vocab_size;
279        let scratch = ScratchBuffers::new(model.config());
280        Self {
281            model,
282            state,
283            scratch,
284            pdf_buffer: vec![0.0f64; vocab_size],
285            cdf_buffer_ac: vec![0u32; vocab_size + 1],
286            cdf_buffer_rans: vec![0u32; vocab_size + 1],
287            rans_freq_buffer: vec![0i64; vocab_size],
288        }
289    }
290
291    /// Reset the model state to initial values.
292    ///
293    /// Call this between independent compression/decompression operations
294    /// to ensure a clean state.
295    pub fn reset(&mut self) {
296        self.state.reset();
297    }
298
299    /// Get the vocabulary size (should always be 256 for byte-level).
300    pub fn vocab_size(&self) -> usize {
301        self.model.config().vocab_size
302    }
303
304    /// Compress data using the specified entropy coder.
305    ///
306    /// # Arguments
307    /// * `data` - Raw bytes to compress
308    /// * `coder` - Entropy coder to use (AC or rANS)
309    ///
310    /// # Returns
311    /// Compressed data including header with checksum.
312    pub fn compress(&mut self, data: &[u8], coder: CoderType) -> Result<Vec<u8>> {
313        let mut output = Vec::new();
314        self.compress_into(data, coder, &mut output)?;
315        Ok(output)
316    }
317
318    pub fn compress_into<W: Write>(
319        &mut self,
320        data: &[u8],
321        coder: CoderType,
322        w: &mut W,
323    ) -> Result<()> {
324        self.state.reset();
325
326        let checksum = crc32(data);
327        let header = Header::new(coder, data.len() as u64, checksum);
328        header.write(w)?;
329
330        match coder {
331            CoderType::AC => self.compress_ac(data, w)?,
332            CoderType::RANS => self.compress_rans(data, w)?,
333        }
334
335        Ok(())
336    }
337
338    pub fn compress_chain_into<W: Write>(
339        &mut self,
340        parts: &[&[u8]],
341        coder: CoderType,
342        w: &mut W,
343    ) -> Result<()> {
344        self.state.reset();
345
346        let mut total_len: u64 = 0;
347        let mut hasher = crc32fast::Hasher::new();
348        for p in parts {
349            total_len = total_len.saturating_add(p.len() as u64);
350            hasher.update(p);
351        }
352        let checksum = hasher.finalize();
353
354        let header = Header::new(coder, total_len, checksum);
355        header.write(w)?;
356
357        let it = parts.iter().flat_map(|p| p.iter().copied());
358        match coder {
359            CoderType::AC => self.compress_ac_iter(it, w)?,
360            CoderType::RANS => self.compress_rans_iter(it, w)?,
361        }
362
363        Ok(())
364    }
365
366    pub fn compress_size(&mut self, data: &[u8], coder: CoderType) -> Result<u64> {
367        let mut w = CountingWriter::new();
368        self.compress_into(data, coder, &mut w)?;
369        Ok(w.bytes_written())
370    }
371
372    pub fn compress_size_chain(&mut self, parts: &[&[u8]], coder: CoderType) -> Result<u64> {
373        let mut w = CountingWriter::new();
374        self.compress_chain_into(parts, coder, &mut w)?;
375        Ok(w.bytes_written())
376    }
377
378    /// Compress using arithmetic coding.
379    fn compress_ac<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
380        self.compress_ac_iter(data.iter().copied(), output)
381    }
382
383    fn compress_ac_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
384    where
385        I: IntoIterator<Item = u8>,
386    {
387        let mut encoder = ArithmeticEncoder::new(output);
388        let vocab_size = self.vocab_size();
389
390        // Prime the model with a null byte to establish initial state
391        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
392        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
393
394        for byte in data {
395            quantize_pdf_to_cdf_inplace(&self.pdf_buffer, &mut self.cdf_buffer_ac);
396            let sym = byte as usize;
397            let c_lo = self.cdf_buffer_ac[sym] as u64;
398            let c_hi = self.cdf_buffer_ac[sym + 1] as u64;
399            encoder.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)?;
400
401            // Update model state with actual byte for next prediction
402            let logits = self
403                .model
404                .forward(&mut self.scratch, byte as u32, &mut self.state);
405            softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
406        }
407
408        let _ = encoder.finish()?;
409        Ok(())
410    }
411
412    /// Compress using rANS coding with block-based encoding.
413    fn compress_rans<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
414        self.compress_rans_iter(data.iter().copied(), output)
415    }
416
417    fn compress_rans_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
418    where
419        I: IntoIterator<Item = u8>,
420    {
421        let vocab_size = self.vocab_size();
422
423        // Use blocked encoder (128KB blocks) for streaming large files
424        let mut encoder = BlockedRansEncoder::new();
425
426        // Prime the model with a null byte
427        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
428        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
429
430        for byte in data {
431            quantize_pdf_to_rans_cdf_with_buffer(
432                &self.pdf_buffer,
433                &mut self.cdf_buffer_rans,
434                &mut self.rans_freq_buffer,
435            );
436            let sym = byte as usize;
437            let cdf = Cdf::new(
438                self.cdf_buffer_rans[sym],
439                self.cdf_buffer_rans[sym + 1],
440                ANS_TOTAL,
441            );
442            encoder.encode(cdf);
443
444            // Update model state
445            let logits = self
446                .model
447                .forward(&mut self.scratch, byte as u32, &mut self.state);
448            softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
449        }
450
451        // Finish encoding and write blocks
452        let blocks = encoder.finish();
453
454        // Write block count
455        output.write_all(&(blocks.len() as u32).to_le_bytes())?;
456
457        // Write each block with length prefix
458        for block in &blocks {
459            output.write_all(&(block.len() as u32).to_le_bytes())?;
460            output.write_all(block)?;
461        }
462
463        Ok(())
464    }
465
466    /// Decompress data.
467    ///
468    /// # Arguments
469    /// * `data` - Compressed data (must include header)
470    ///
471    /// # Returns
472    /// Original decompressed data. Returns error if checksum doesn't match.
473    pub fn decompress(&mut self, data: &[u8]) -> Result<Vec<u8>> {
474        let mut cursor = Cursor::new(data);
475        let header = Header::read(&mut cursor)?;
476
477        self.state.reset();
478
479        let compressed = &data[Header::SIZE..];
480        let result = match header.coder_type() {
481            CoderType::AC => self.decompress_ac(compressed, header.original_len as usize)?,
482            CoderType::RANS => self.decompress_rans(compressed, header.original_len as usize)?,
483        };
484
485        // Verify checksum for data integrity
486        let actual_crc = crc32(&result);
487        if actual_crc != header.crc32 {
488            bail!(
489                "CRC32 mismatch: expected 0x{:08X}, got 0x{:08X}",
490                header.crc32,
491                actual_crc
492            );
493        }
494
495        Ok(result)
496    }
497
498    /// Decompress using arithmetic coding.
499    fn decompress_ac(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
500        let mut decoder = ArithmeticDecoder::new(compressed)?;
501        let vocab_size = self.vocab_size();
502
503        let mut result = Vec::with_capacity(original_len);
504
505        // Prime with null byte (must match compression)
506        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
507        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
508
509        for _ in 0..original_len {
510            quantize_pdf_to_cdf_inplace(&self.pdf_buffer, &mut self.cdf_buffer_ac);
511            let sym = decoder.decode_symbol_counts(&self.cdf_buffer_ac, CDF_TOTAL)?;
512            result.push(sym as u8);
513
514            // Update model state with decoded byte
515            let logits = self
516                .model
517                .forward(&mut self.scratch, sym as u32, &mut self.state);
518            softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
519        }
520
521        Ok(result)
522    }
523
524    /// Decompress using rANS coding.
525    fn decompress_rans(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
526        // Read block count
527        if compressed.len() < 4 {
528            bail!("rANS data too short");
529        }
530        let block_count =
531            u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]])
532                as usize;
533
534        // Read blocks
535        let mut blocks = Vec::with_capacity(block_count);
536        let mut pos = 4;
537
538        for _ in 0..block_count {
539            if pos + 4 > compressed.len() {
540                bail!("Truncated block header");
541            }
542            let block_len = u32::from_le_bytes([
543                compressed[pos],
544                compressed[pos + 1],
545                compressed[pos + 2],
546                compressed[pos + 3],
547            ]) as usize;
548            pos += 4;
549
550            if pos + block_len > compressed.len() {
551                bail!("Truncated block data");
552            }
553            blocks.push(&compressed[pos..pos + block_len]);
554            pos += block_len;
555        }
556
557        // Decode using blocked decoder
558        let mut decoder = BlockedRansDecoder::new(blocks);
559        let vocab_size = self.vocab_size();
560        let mut result = Vec::with_capacity(original_len);
561
562        // Prime with null byte
563        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
564        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
565
566        for _ in 0..original_len {
567            quantize_pdf_to_rans_cdf_with_buffer(
568                &self.pdf_buffer,
569                &mut self.cdf_buffer_rans,
570                &mut self.rans_freq_buffer,
571            );
572            let sym = decoder.decode(&self.cdf_buffer_rans)?;
573            result.push(sym as u8);
574
575            // Update model state
576            let logits = self
577                .model
578                .forward(&mut self.scratch, sym as u32, &mut self.state);
579            softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
580        }
581
582        Ok(result)
583    }
584
585    /// Calculate cross-entropy (bits per byte) for data without compression.
586    ///
587    /// This measures how well the model predicts the data, giving a theoretical
588    /// lower bound on achievable compression. Useful for evaluating model quality.
589    ///
590    /// # Arguments
591    /// * `data` - Data to analyze
592    ///
593    /// # Returns
594    /// Average bits per byte (lower is better, 8.0 means no compression possible).
595    pub fn cross_entropy(&mut self, data: &[u8]) -> Result<f64> {
596        if data.is_empty() {
597            return Ok(0.0);
598        }
599
600        self.state.reset();
601        let vocab_size = self.vocab_size();
602
603        let mut total_bits = 0.0f64;
604
605        // Prime with null byte
606        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
607        softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
608
609        for &byte in data {
610            let p = self.pdf_buffer[byte as usize];
611            total_bits -= p.log2();
612            let logits = self
613                .model
614                .forward(&mut self.scratch, byte as u32, &mut self.state);
615            softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
616        }
617
618        Ok(total_bits / (data.len() as f64))
619    }
620
621    pub fn cross_entropy_conditional_chain(
622        &mut self,
623        prefix_parts: &[&[u8]],
624        data: &[u8],
625    ) -> Result<f64> {
626        if data.is_empty() {
627            return Ok(0.0);
628        }
629
630        self.state.reset();
631        let vocab_size = self.vocab_size();
632
633        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
634        softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
635
636        for p in prefix_parts {
637            for &byte in *p {
638                let logits = self
639                    .model
640                    .forward(&mut self.scratch, byte as u32, &mut self.state);
641                softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
642            }
643        }
644
645        let mut total_bits = 0.0f64;
646        for &byte in data {
647            let p = self.pdf_buffer[byte as usize];
648            total_bits -= p.log2();
649            let logits = self
650                .model
651                .forward(&mut self.scratch, byte as u32, &mut self.state);
652            softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
653        }
654
655        Ok(total_bits / (data.len() as f64))
656    }
657
658    pub fn cross_entropy_conditional(&mut self, prefix: &[u8], data: &[u8]) -> Result<f64> {
659        if data.is_empty() {
660            return Ok(0.0);
661        }
662
663        self.state.reset();
664        let vocab_size = self.vocab_size();
665
666        // Prime with null byte
667        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
668        softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
669
670        // Condition on prefix (update state, no scoring)
671        for &byte in prefix {
672            let logits = self
673                .model
674                .forward(&mut self.scratch, byte as u32, &mut self.state);
675            softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
676        }
677
678        let mut total_bits = 0.0f64;
679        for &byte in data {
680            let p = self.pdf_buffer[byte as usize];
681            total_bits -= p.log2();
682            let logits = self
683                .model
684                .forward(&mut self.scratch, byte as u32, &mut self.state);
685            softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
686        }
687
688        Ok(total_bits / (data.len() as f64))
689    }
690
691    pub fn joint_cross_entropy_aligned_min(&mut self, x: &[u8], y: &[u8]) -> Result<f64> {
692        let n = x.len().min(y.len());
693        if n == 0 {
694            return Ok(0.0);
695        }
696
697        let h_xy = self.joint_cross_entropy_aligned_order(x, y, false)?;
698        let h_yx = self.joint_cross_entropy_aligned_order(x, y, true)?;
699        Ok(h_xy.min(h_yx))
700    }
701
702    fn joint_cross_entropy_aligned_order(&mut self, x: &[u8], y: &[u8], swap: bool) -> Result<f64> {
703        let n = x.len().min(y.len());
704        if n == 0 {
705            return Ok(0.0);
706        }
707
708        self.state.reset();
709        let vocab_size = self.vocab_size();
710
711        let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
712        softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
713
714        let mut total_bits = 0.0f64;
715        for i in 0..n {
716            let a = if swap { y[i] } else { x[i] };
717            let b = if swap { x[i] } else { y[i] };
718
719            let pa = self.pdf_buffer[a as usize];
720            total_bits -= pa.log2();
721            let logits = self
722                .model
723                .forward(&mut self.scratch, a as u32, &mut self.state);
724            softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
725
726            let pb = self.pdf_buffer[b as usize];
727            total_bits -= pb.log2();
728            let logits = self
729                .model
730                .forward(&mut self.scratch, b as u32, &mut self.state);
731            softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
732        }
733
734        Ok(total_bits / (n as f64))
735    }
736}
737
738// =============================================================================
739// Compression Statistics
740// =============================================================================
741
/// Statistics from a compression operation, as produced by
/// [`compress_with_stats`].
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Original size in bytes.
    pub original_size: usize,
    /// Compressed size in bytes (including header).
    pub compressed_size: usize,
    /// Compression ratio (original/compressed). Higher is better.
    pub ratio: f64,
    /// Bits per byte of original data. Lower is better
    /// (theoretical minimum: ~0, no compression: 8).
    pub bits_per_byte: f64,
    /// Time taken in seconds.
    pub time_seconds: f64,
    /// Throughput in bytes per second.
    pub throughput: f64,
}
758
759impl std::fmt::Display for CompressionStats {
760    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
761        write!(
762            f,
763            "{} bytes -> {} bytes | ratio={:.3} | bits/byte={:.3} | time={:.2}s | {:.0} B/s",
764            self.original_size,
765            self.compressed_size,
766            self.ratio,
767            self.bits_per_byte,
768            self.time_seconds,
769            self.throughput,
770        )
771    }
772}
773
774/// Compress data and return both the compressed output and statistics.
775///
776/// This is a convenience function that wraps `Compressor::compress` with timing.
777pub fn compress_with_stats(
778    compressor: &mut Compressor,
779    data: &[u8],
780    coder: CoderType,
781) -> Result<(Vec<u8>, CompressionStats)> {
782    let start = std::time::Instant::now();
783    let compressed = compressor.compress(data, coder)?;
784    let elapsed = start.elapsed().as_secs_f64();
785
786    let stats = CompressionStats {
787        original_size: data.len(),
788        compressed_size: compressed.len(),
789        ratio: data.len() as f64 / compressed.len() as f64,
790        bits_per_byte: (compressed.len() as f64 * 8.0) / data.len() as f64,
791        time_seconds: elapsed,
792        throughput: data.len() as f64 / elapsed,
793    };
794
795    Ok((compressed, stats))
796}
797
798// =============================================================================
799// Tests
800// =============================================================================
801
#[cfg(test)]
mod tests {
    use super::*;

    // Header serializes and deserializes losslessly, at exactly SIZE bytes.
    #[test]
    fn test_header_roundtrip() {
        let header = Header::new(CoderType::AC, 12345, 0xDEADBEEF);

        let mut buf = Vec::new();
        header.write(&mut buf).unwrap();

        assert_eq!(buf.len(), Header::SIZE);

        let mut cursor = Cursor::new(&buf);
        let read_header = Header::read(&mut cursor).unwrap();

        assert_eq!(read_header.magic, MAGIC);
        assert_eq!(read_header.version, VERSION);
        assert_eq!(read_header.coder, 0);
        assert_eq!(read_header.original_len, 12345);
        assert_eq!(read_header.crc32, 0xDEADBEEF);
    }

    // rANS is encoded as coder byte 1 and mapped back correctly.
    #[test]
    fn test_header_rans() {
        let header = Header::new(CoderType::RANS, 67890, 0xCAFEBABE);
        assert_eq!(header.coder, 1);
        assert_eq!(header.coder_type(), CoderType::RANS);
    }

    #[test]
    fn test_coder_type_display() {
        assert_eq!(format!("{}", CoderType::AC), "AC");
        assert_eq!(format!("{}", CoderType::RANS), "rANS");
    }

    #[test]
    fn test_crc32() {
        let data = b"Hello, World!";
        let c = crc32(data);
        assert_ne!(c, 0);
        // CRC32 should be deterministic
        assert_eq!(c, crc32(data));
    }

    #[test]
    fn test_crc32_different_data() {
        let c1 = crc32(b"Hello");
        let c2 = crc32(b"World");
        assert_ne!(c1, c2);
    }

    #[test]
    fn test_crc32_known_vector() {
        // Standard CRC-32 (ISO-HDLC) test vector.
        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
    }

    // A corrupted magic number must be rejected with a descriptive error.
    #[test]
    fn test_header_rejects_invalid_magic() {
        let mut buf = Vec::new();
        let header = Header::new(CoderType::AC, 1, 2);
        header.write(&mut buf).unwrap();
        // Corrupt magic.
        buf[0] ^= 0xFF;

        let mut cursor = Cursor::new(&buf);
        let err = Header::read(&mut cursor).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("Invalid magic number"));
    }
}