1use anyhow::{bail, Result};
15use std::io::{Cursor, Read, Write};
16use std::path::Path;
17use std::sync::Arc;
18
19pub mod coders;
20pub mod rwkv7;
21
22use coders::{
23 quantize_pdf_to_cdf_inplace, quantize_pdf_to_rans_cdf_with_buffer, softmax_pdf_floor_inplace,
24 softmax_pdf_inplace, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder,
25 BlockedRansEncoder, Cdf, ANS_TOTAL, CDF_TOTAL,
26};
27
28pub use rwkv7::{Config, Model, ScratchBuffers, State};
29
/// File signature placed at the start of every compressed stream.
///
/// `Header::write` emits it little-endian, so the on-disk bytes are `b"GTPZ"`.
pub const MAGIC: u32 = 0x5a50_5447;

/// Container format version stamped by `Header::new`; `Header::read` rejects
/// any stream whose version byte is greater than this value.
pub const VERSION: u8 = 2;

/// 256, i.e. the number of distinct `u8` values.
/// NOTE(review): presumably the byte-level vocabulary size; `Compressor` sizes
/// its buffers from the model config instead — confirm the two always agree.
pub const VOCAB_SIZE: usize = 256;
44
/// Entropy coder used for the payload of a compressed stream.
///
/// `Header::new` stores `AC` as coder byte `0` and `RANS` as coder byte `1`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CoderType {
    /// Arithmetic coding; this is the default variant.
    #[default]
    AC,
    /// Blocked rANS coding.
    RANS,
}
63
/// `Write` sink that discards every byte and only tallies how many were offered.
///
/// Used to measure the size of a compressed stream without materializing it.
struct CountingWriter {
    /// Running total of accepted bytes; saturates at `u64::MAX`.
    n: u64,
}

impl CountingWriter {
    /// Creates a writer whose tally starts at zero.
    #[inline]
    fn new() -> Self {
        CountingWriter { n: 0 }
    }

    /// Total number of bytes handed to `write` so far.
    #[inline]
    fn bytes_written(&self) -> u64 {
        self.n
    }
}

impl Write for CountingWriter {
    /// Accepts the whole buffer without storing it; never fails.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        self.n = u64::saturating_add(self.n, accepted as u64);
        Ok(accepted)
    }

    /// Nothing is buffered, so flushing is a no-op.
    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
93
94impl std::fmt::Display for CoderType {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 CoderType::AC => write!(f, "AC"),
98 CoderType::RANS => write!(f, "rANS"),
99 }
100 }
101}
102
/// Fixed-size header that precedes the coded payload of every stream.
///
/// Fields are listed in the order `Header::write` serializes them; all
/// multi-byte fields are written little-endian.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// File signature; `Header::read` requires it to equal `MAGIC`.
    pub magic: u32,
    /// Container format version.
    pub version: u8,
    /// Raw coder identifier byte (`0` = AC, `1` = rANS as written by `Header::new`).
    pub coder: u8,
    /// Length in bytes of the uncompressed data.
    pub original_len: u64,
    /// CRC32 of the uncompressed data, checked after decompression.
    pub crc32: u32,
}
128
129impl Header {
130 pub const SIZE: usize = 4 + 1 + 1 + 8 + 4; pub fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
135 Self {
136 magic: MAGIC,
137 version: VERSION,
138 coder: match coder {
139 CoderType::AC => 0,
140 CoderType::RANS => 1,
141 },
142 original_len,
143 crc32,
144 }
145 }
146
147 pub fn write<W: Write>(&self, w: &mut W) -> Result<()> {
149 w.write_all(&self.magic.to_le_bytes())?;
150 w.write_all(&[self.version])?;
151 w.write_all(&[self.coder])?;
152 w.write_all(&self.original_len.to_le_bytes())?;
153 w.write_all(&self.crc32.to_le_bytes())?;
154 Ok(())
155 }
156
157 pub fn read<R: Read>(r: &mut R) -> Result<Self> {
159 let mut buf4 = [0u8; 4];
160 let mut buf8 = [0u8; 8];
161 let mut buf1 = [0u8; 1];
162
163 r.read_exact(&mut buf4)?;
164 let magic = u32::from_le_bytes(buf4);
165 if magic != MAGIC {
166 bail!(
167 "Invalid magic number: expected 0x{:08X}, got 0x{:08X}",
168 MAGIC,
169 magic
170 );
171 }
172
173 r.read_exact(&mut buf1)?;
174 let version = buf1[0];
175 if version > VERSION {
176 bail!(
177 "Unsupported version: {} (max supported: {})",
178 version,
179 VERSION
180 );
181 }
182
183 r.read_exact(&mut buf1)?;
184 let coder = buf1[0];
185
186 r.read_exact(&mut buf8)?;
187 let original_len = u64::from_le_bytes(buf8);
188
189 r.read_exact(&mut buf4)?;
190 let crc32 = u32::from_le_bytes(buf4);
191
192 Ok(Self {
193 magic,
194 version,
195 coder,
196 original_len,
197 crc32,
198 })
199 }
200
201 pub fn coder_type(&self) -> CoderType {
203 match self.coder {
204 0 => CoderType::AC,
205 _ => CoderType::RANS,
206 }
207 }
208}
209
210pub fn crc32(data: &[u8]) -> u32 {
218 let mut hasher = crc32fast::Hasher::new();
219 hasher.update(data);
220 hasher.finalize()
221}
222
223pub struct Compressor {
232 pub model: Arc<Model>,
234 pub state: State,
236 pub scratch: ScratchBuffers,
237 pub pdf_buffer: Vec<f64>,
239 pub cdf_buffer_ac: Vec<u32>,
241 pub cdf_buffer_rans: Vec<u32>,
243 pub rans_freq_buffer: Vec<i64>,
245}
246
247impl Clone for Compressor {
248 fn clone(&self) -> Self {
249 let mut cloned = Self::new_from_model(self.model.clone());
250 cloned.state = self.state.clone();
251 cloned.pdf_buffer.clone_from(&self.pdf_buffer);
252 cloned.cdf_buffer_ac.clone_from(&self.cdf_buffer_ac);
253 cloned.cdf_buffer_rans.clone_from(&self.cdf_buffer_rans);
254 cloned.rans_freq_buffer.clone_from(&self.rans_freq_buffer);
255 cloned
256 }
257}
258
259impl Compressor {
260 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
268 let model = Arc::new(Model::load(model_path)?);
269 Ok(Self::new_from_model(model))
270 }
271
272 pub fn load_model<P: AsRef<Path>>(model_path: P) -> Result<Arc<Model>> {
273 Ok(Arc::new(Model::load(model_path)?))
274 }
275
276 pub fn new_from_model(model: Arc<Model>) -> Self {
277 let state = model.new_state();
278 let vocab_size = model.config().vocab_size;
279 let scratch = ScratchBuffers::new(model.config());
280 Self {
281 model,
282 state,
283 scratch,
284 pdf_buffer: vec![0.0f64; vocab_size],
285 cdf_buffer_ac: vec![0u32; vocab_size + 1],
286 cdf_buffer_rans: vec![0u32; vocab_size + 1],
287 rans_freq_buffer: vec![0i64; vocab_size],
288 }
289 }
290
291 pub fn reset(&mut self) {
296 self.state.reset();
297 }
298
299 pub fn vocab_size(&self) -> usize {
301 self.model.config().vocab_size
302 }
303
304 pub fn compress(&mut self, data: &[u8], coder: CoderType) -> Result<Vec<u8>> {
313 let mut output = Vec::new();
314 self.compress_into(data, coder, &mut output)?;
315 Ok(output)
316 }
317
318 pub fn compress_into<W: Write>(
319 &mut self,
320 data: &[u8],
321 coder: CoderType,
322 w: &mut W,
323 ) -> Result<()> {
324 self.state.reset();
325
326 let checksum = crc32(data);
327 let header = Header::new(coder, data.len() as u64, checksum);
328 header.write(w)?;
329
330 match coder {
331 CoderType::AC => self.compress_ac(data, w)?,
332 CoderType::RANS => self.compress_rans(data, w)?,
333 }
334
335 Ok(())
336 }
337
338 pub fn compress_chain_into<W: Write>(
339 &mut self,
340 parts: &[&[u8]],
341 coder: CoderType,
342 w: &mut W,
343 ) -> Result<()> {
344 self.state.reset();
345
346 let mut total_len: u64 = 0;
347 let mut hasher = crc32fast::Hasher::new();
348 for p in parts {
349 total_len = total_len.saturating_add(p.len() as u64);
350 hasher.update(p);
351 }
352 let checksum = hasher.finalize();
353
354 let header = Header::new(coder, total_len, checksum);
355 header.write(w)?;
356
357 let it = parts.iter().flat_map(|p| p.iter().copied());
358 match coder {
359 CoderType::AC => self.compress_ac_iter(it, w)?,
360 CoderType::RANS => self.compress_rans_iter(it, w)?,
361 }
362
363 Ok(())
364 }
365
366 pub fn compress_size(&mut self, data: &[u8], coder: CoderType) -> Result<u64> {
367 let mut w = CountingWriter::new();
368 self.compress_into(data, coder, &mut w)?;
369 Ok(w.bytes_written())
370 }
371
372 pub fn compress_size_chain(&mut self, parts: &[&[u8]], coder: CoderType) -> Result<u64> {
373 let mut w = CountingWriter::new();
374 self.compress_chain_into(parts, coder, &mut w)?;
375 Ok(w.bytes_written())
376 }
377
378 fn compress_ac<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
380 self.compress_ac_iter(data.iter().copied(), output)
381 }
382
383 fn compress_ac_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
384 where
385 I: IntoIterator<Item = u8>,
386 {
387 let mut encoder = ArithmeticEncoder::new(output);
388 let vocab_size = self.vocab_size();
389
390 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
392 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
393
394 for byte in data {
395 quantize_pdf_to_cdf_inplace(&self.pdf_buffer, &mut self.cdf_buffer_ac);
396 let sym = byte as usize;
397 let c_lo = self.cdf_buffer_ac[sym] as u64;
398 let c_hi = self.cdf_buffer_ac[sym + 1] as u64;
399 encoder.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)?;
400
401 let logits = self
403 .model
404 .forward(&mut self.scratch, byte as u32, &mut self.state);
405 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
406 }
407
408 let _ = encoder.finish()?;
409 Ok(())
410 }
411
412 fn compress_rans<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
414 self.compress_rans_iter(data.iter().copied(), output)
415 }
416
417 fn compress_rans_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
418 where
419 I: IntoIterator<Item = u8>,
420 {
421 let vocab_size = self.vocab_size();
422
423 let mut encoder = BlockedRansEncoder::new();
425
426 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
428 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
429
430 for byte in data {
431 quantize_pdf_to_rans_cdf_with_buffer(
432 &self.pdf_buffer,
433 &mut self.cdf_buffer_rans,
434 &mut self.rans_freq_buffer,
435 );
436 let sym = byte as usize;
437 let cdf = Cdf::new(
438 self.cdf_buffer_rans[sym],
439 self.cdf_buffer_rans[sym + 1],
440 ANS_TOTAL,
441 );
442 encoder.encode(cdf);
443
444 let logits = self
446 .model
447 .forward(&mut self.scratch, byte as u32, &mut self.state);
448 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
449 }
450
451 let blocks = encoder.finish();
453
454 output.write_all(&(blocks.len() as u32).to_le_bytes())?;
456
457 for block in &blocks {
459 output.write_all(&(block.len() as u32).to_le_bytes())?;
460 output.write_all(block)?;
461 }
462
463 Ok(())
464 }
465
466 pub fn decompress(&mut self, data: &[u8]) -> Result<Vec<u8>> {
474 let mut cursor = Cursor::new(data);
475 let header = Header::read(&mut cursor)?;
476
477 self.state.reset();
478
479 let compressed = &data[Header::SIZE..];
480 let result = match header.coder_type() {
481 CoderType::AC => self.decompress_ac(compressed, header.original_len as usize)?,
482 CoderType::RANS => self.decompress_rans(compressed, header.original_len as usize)?,
483 };
484
485 let actual_crc = crc32(&result);
487 if actual_crc != header.crc32 {
488 bail!(
489 "CRC32 mismatch: expected 0x{:08X}, got 0x{:08X}",
490 header.crc32,
491 actual_crc
492 );
493 }
494
495 Ok(result)
496 }
497
498 fn decompress_ac(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
500 let mut decoder = ArithmeticDecoder::new(compressed)?;
501 let vocab_size = self.vocab_size();
502
503 let mut result = Vec::with_capacity(original_len);
504
505 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
507 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
508
509 for _ in 0..original_len {
510 quantize_pdf_to_cdf_inplace(&self.pdf_buffer, &mut self.cdf_buffer_ac);
511 let sym = decoder.decode_symbol_counts(&self.cdf_buffer_ac, CDF_TOTAL)?;
512 result.push(sym as u8);
513
514 let logits = self
516 .model
517 .forward(&mut self.scratch, sym as u32, &mut self.state);
518 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
519 }
520
521 Ok(result)
522 }
523
524 fn decompress_rans(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
526 if compressed.len() < 4 {
528 bail!("rANS data too short");
529 }
530 let block_count =
531 u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]])
532 as usize;
533
534 let mut blocks = Vec::with_capacity(block_count);
536 let mut pos = 4;
537
538 for _ in 0..block_count {
539 if pos + 4 > compressed.len() {
540 bail!("Truncated block header");
541 }
542 let block_len = u32::from_le_bytes([
543 compressed[pos],
544 compressed[pos + 1],
545 compressed[pos + 2],
546 compressed[pos + 3],
547 ]) as usize;
548 pos += 4;
549
550 if pos + block_len > compressed.len() {
551 bail!("Truncated block data");
552 }
553 blocks.push(&compressed[pos..pos + block_len]);
554 pos += block_len;
555 }
556
557 let mut decoder = BlockedRansDecoder::new(blocks);
559 let vocab_size = self.vocab_size();
560 let mut result = Vec::with_capacity(original_len);
561
562 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
564 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
565
566 for _ in 0..original_len {
567 quantize_pdf_to_rans_cdf_with_buffer(
568 &self.pdf_buffer,
569 &mut self.cdf_buffer_rans,
570 &mut self.rans_freq_buffer,
571 );
572 let sym = decoder.decode(&self.cdf_buffer_rans)?;
573 result.push(sym as u8);
574
575 let logits = self
577 .model
578 .forward(&mut self.scratch, sym as u32, &mut self.state);
579 softmax_pdf_floor_inplace(logits, vocab_size, &mut self.pdf_buffer);
580 }
581
582 Ok(result)
583 }
584
585 pub fn cross_entropy(&mut self, data: &[u8]) -> Result<f64> {
596 if data.is_empty() {
597 return Ok(0.0);
598 }
599
600 self.state.reset();
601 let vocab_size = self.vocab_size();
602
603 let mut total_bits = 0.0f64;
604
605 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
607 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
608
609 for &byte in data {
610 let p = self.pdf_buffer[byte as usize];
611 total_bits -= p.log2();
612 let logits = self
613 .model
614 .forward(&mut self.scratch, byte as u32, &mut self.state);
615 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
616 }
617
618 Ok(total_bits / (data.len() as f64))
619 }
620
621 pub fn cross_entropy_conditional_chain(
622 &mut self,
623 prefix_parts: &[&[u8]],
624 data: &[u8],
625 ) -> Result<f64> {
626 if data.is_empty() {
627 return Ok(0.0);
628 }
629
630 self.state.reset();
631 let vocab_size = self.vocab_size();
632
633 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
634 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
635
636 for p in prefix_parts {
637 for &byte in *p {
638 let logits = self
639 .model
640 .forward(&mut self.scratch, byte as u32, &mut self.state);
641 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
642 }
643 }
644
645 let mut total_bits = 0.0f64;
646 for &byte in data {
647 let p = self.pdf_buffer[byte as usize];
648 total_bits -= p.log2();
649 let logits = self
650 .model
651 .forward(&mut self.scratch, byte as u32, &mut self.state);
652 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
653 }
654
655 Ok(total_bits / (data.len() as f64))
656 }
657
658 pub fn cross_entropy_conditional(&mut self, prefix: &[u8], data: &[u8]) -> Result<f64> {
659 if data.is_empty() {
660 return Ok(0.0);
661 }
662
663 self.state.reset();
664 let vocab_size = self.vocab_size();
665
666 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
668 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
669
670 for &byte in prefix {
672 let logits = self
673 .model
674 .forward(&mut self.scratch, byte as u32, &mut self.state);
675 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
676 }
677
678 let mut total_bits = 0.0f64;
679 for &byte in data {
680 let p = self.pdf_buffer[byte as usize];
681 total_bits -= p.log2();
682 let logits = self
683 .model
684 .forward(&mut self.scratch, byte as u32, &mut self.state);
685 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
686 }
687
688 Ok(total_bits / (data.len() as f64))
689 }
690
691 pub fn joint_cross_entropy_aligned_min(&mut self, x: &[u8], y: &[u8]) -> Result<f64> {
692 let n = x.len().min(y.len());
693 if n == 0 {
694 return Ok(0.0);
695 }
696
697 let h_xy = self.joint_cross_entropy_aligned_order(x, y, false)?;
698 let h_yx = self.joint_cross_entropy_aligned_order(x, y, true)?;
699 Ok(h_xy.min(h_yx))
700 }
701
702 fn joint_cross_entropy_aligned_order(&mut self, x: &[u8], y: &[u8], swap: bool) -> Result<f64> {
703 let n = x.len().min(y.len());
704 if n == 0 {
705 return Ok(0.0);
706 }
707
708 self.state.reset();
709 let vocab_size = self.vocab_size();
710
711 let logits = self.model.forward(&mut self.scratch, 0, &mut self.state);
712 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
713
714 let mut total_bits = 0.0f64;
715 for i in 0..n {
716 let a = if swap { y[i] } else { x[i] };
717 let b = if swap { x[i] } else { y[i] };
718
719 let pa = self.pdf_buffer[a as usize];
720 total_bits -= pa.log2();
721 let logits = self
722 .model
723 .forward(&mut self.scratch, a as u32, &mut self.state);
724 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
725
726 let pb = self.pdf_buffer[b as usize];
727 total_bits -= pb.log2();
728 let logits = self
729 .model
730 .forward(&mut self.scratch, b as u32, &mut self.state);
731 softmax_pdf_inplace(logits, vocab_size, &mut self.pdf_buffer);
732 }
733
734 Ok(total_bits / (n as f64))
735 }
736}
737
/// Summary of one compression run, as produced by `compress_with_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Input size in bytes.
    pub original_size: usize,
    /// Output size in bytes.
    pub compressed_size: usize,
    /// `original_size / compressed_size`.
    pub ratio: f64,
    /// Compressed bits spent per input byte.
    pub bits_per_byte: f64,
    /// Wall-clock duration of the run, in seconds.
    pub time_seconds: f64,
    /// Input bytes processed per second.
    pub throughput: f64,
}

impl std::fmt::Display for CompressionStats {
    /// Formats the statistics as a single human-readable line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            original_size,
            compressed_size,
            ratio,
            bits_per_byte,
            time_seconds,
            throughput,
        } = self;

        f.write_fmt(format_args!(
            "{} bytes -> {} bytes | ratio={:.3} | bits/byte={:.3} | time={:.2}s | {:.0} B/s",
            original_size, compressed_size, ratio, bits_per_byte, time_seconds, throughput,
        ))
    }
}
773
774pub fn compress_with_stats(
778 compressor: &mut Compressor,
779 data: &[u8],
780 coder: CoderType,
781) -> Result<(Vec<u8>, CompressionStats)> {
782 let start = std::time::Instant::now();
783 let compressed = compressor.compress(data, coder)?;
784 let elapsed = start.elapsed().as_secs_f64();
785
786 let stats = CompressionStats {
787 original_size: data.len(),
788 compressed_size: compressed.len(),
789 ratio: data.len() as f64 / compressed.len() as f64,
790 bits_per_byte: (compressed.len() as f64 * 8.0) / data.len() as f64,
791 time_seconds: elapsed,
792 throughput: data.len() as f64 / elapsed,
793 };
794
795 Ok((compressed, stats))
796}
797
#[cfg(test)]
mod tests {
    use super::*;

    /// A header written to a buffer reads back with every field intact.
    #[test]
    fn test_header_roundtrip() {
        let original = Header::new(CoderType::AC, 12345, 0xDEADBEEF);

        let mut bytes = Vec::new();
        original.write(&mut bytes).unwrap();
        assert_eq!(bytes.len(), Header::SIZE);

        let mut reader = Cursor::new(&bytes);
        let parsed = Header::read(&mut reader).unwrap();

        assert_eq!(parsed.magic, MAGIC);
        assert_eq!(parsed.version, VERSION);
        assert_eq!(parsed.coder, 0);
        assert_eq!(parsed.original_len, 12345);
        assert_eq!(parsed.crc32, 0xDEADBEEF);
    }

    /// The rANS coder is stored as byte 1 and maps back to `CoderType::RANS`.
    #[test]
    fn test_header_rans() {
        let rans_header = Header::new(CoderType::RANS, 67890, 0xCAFEBABE);
        assert_eq!(rans_header.coder, 1);
        assert_eq!(rans_header.coder_type(), CoderType::RANS);
    }

    /// `Display` yields the short coder names.
    #[test]
    fn test_coder_type_display() {
        assert_eq!(CoderType::AC.to_string(), "AC");
        assert_eq!(CoderType::RANS.to_string(), "rANS");
    }

    /// The checksum is non-zero for this input and stable across calls.
    #[test]
    fn test_crc32() {
        let payload = b"Hello, World!";
        let first = crc32(payload);
        assert_ne!(first, 0);
        assert_eq!(first, crc32(payload));
    }

    /// Different inputs produce different checksums.
    #[test]
    fn test_crc32_different_data() {
        assert_ne!(crc32(b"Hello"), crc32(b"World"));
    }

    /// Standard CRC-32 check value for the ASCII digits "123456789".
    #[test]
    fn test_crc32_known_vector() {
        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
    }

    /// Corrupting the first magic byte makes `Header::read` fail with a clear message.
    #[test]
    fn test_header_rejects_invalid_magic() {
        let mut bytes = Vec::new();
        Header::new(CoderType::AC, 1, 2).write(&mut bytes).unwrap();
        bytes[0] ^= 0xFF;

        let mut reader = Cursor::new(&bytes);
        let err = Header::read(&mut reader).unwrap_err();
        let rendered = format!("{err:#}");
        assert!(rendered.contains("Invalid magic number"));
    }
}