infotheory/coders/
ac.rs

//! Arithmetic coder implementation for rwkvzip.
//!
//! This implements a binary arithmetic coder with 32-bit precision, optimized for
//! neural network probability distributions. The implementation is mathematically
//! rigorous to ensure lossless compression.
//!
//! # Information-Theoretic Properties
//!
//! - Uses base-2 arithmetic for bitstream output
//! - 32-bit precision prevents underflow for typical neural network distributions
//! - Integer CDF quantization uses a 30-bit total (2^30) to minimize quantization error
//! - A probability floor ensures no symbol has zero probability (critical for lossless coding)
13
14use std::io::Write;
15use wide::f32x8;
16use wide::f64x4;
17
/// Denominator of the quantized CDF: every integer CDF built here sums to 2^30.
///
/// This equals `BASE^(PRECISION - 2)`, the quarter point of the coder's range.
pub const CDF_TOTAL: u32 = 1u32 << 30;

/// Width, in bits, of the arithmetic coder's `low`/`high`/`code` registers.
const PRECISION: u32 = 32;

/// Radix of the emitted code stream (2 = one bit per output digit).
const BASE: u64 = 2;
26
27/// Returns the minimum probability floor for symbols.
28/// P_MIN = 2 * 2^(-(PRECISION-2)) = 2^(-(PRECISION-3)) = 2^(-29)
29#[inline]
30pub fn p_min() -> f64 {
31    // 2.0 * 2.0^(-(32-2)) = 2^(-29) ≈ 1.86e-9
32    2.0f64.powi(-(PRECISION as i32 - 3))
33}
34
35/// Compute softmax PDF with probability floor.
36///
37/// # Arguments
38/// * `logits` - Raw logits from the model
39/// * `vocab_size` - Size of the vocabulary (256 for byte-level)
40///
41/// # Returns
42/// Probability distribution with floor applied, normalized to sum to 1.
43pub fn softmax_pdf_floor(logits: &[f32], vocab_size: usize) -> Vec<f64> {
44    let mut result = vec![0f64; vocab_size];
45    softmax_pdf_floor_inplace(logits, vocab_size, &mut result);
46    result
47}
48
/// Softmax (no probability floor) written into a caller-provided buffer.
///
/// Only the first `vocab_size` entries of `logits` are read and only the first
/// `vocab_size` entries of `pdf_out` are written; `pdf_out.len()` must be at
/// least `vocab_size`. When the exponential sum is not a positive number
/// (e.g. a NaN logit), the output falls back to the uniform distribution.
///
/// # Panics
/// Panics if `logits` or `pdf_out` holds fewer than `vocab_size` elements.
pub fn softmax_pdf_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
    debug_assert!(pdf_out.len() >= vocab_size);

    let src = &logits[..vocab_size];
    let probs = &mut pdf_out[..vocab_size];

    // Subtract the largest logit before exponentiating for numerical stability.
    let peak = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut total = 0.0f64;
    for (slot, &logit) in probs.iter_mut().zip(src) {
        let e = f64::from(logit - peak).exp();
        *slot = e;
        total += e;
    }

    if total > 0.0 {
        let scale = total.recip();
        for p in probs.iter_mut() {
            *p *= scale;
        }
    } else {
        // `max(1)` keeps the division well-defined for an empty vocabulary.
        probs.fill(1.0 / (vocab_size.max(1) as f64));
    }
}
79
80/// In-place version of softmax_pdf_floor to avoid allocations.
81///
82/// # Arguments
83/// * `logits` - Raw logits from the model
84/// * `vocab_size` - Size of the vocabulary
85/// * `pdf_out` - Pre-allocated buffer for output PDF (length >= vocab_size)
86pub fn softmax_pdf_floor_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
87    // Fast path for byte-level vocab using portable SIMD (`wide`).
88    if vocab_size == 256 && logits.len() >= 256 && pdf_out.len() >= 256 {
89        softmax_pdf_floor_wide_256(logits, pdf_out);
90        return;
91    }
92
93    let p_min_val = p_min();
94
95    // Find max for numerical stability
96    let max = logits
97        .iter()
98        .take(vocab_size)
99        .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
100
101    // Compute exp(x - max) and sum (reuse pdf_out as temp buffer)
102    let mut sum = 0.0f64;
103    for i in 0..vocab_size {
104        let e = ((logits[i] - max) as f64).exp();
105        pdf_out[i] = e;
106        sum += e;
107    }
108
109    // Normalize and apply floor
110    for value in pdf_out.iter_mut().take(vocab_size) {
111        *value = (*value / sum).max(p_min_val);
112    }
113
114    // Re-normalize after floor application
115    let norm: f64 = pdf_out[..vocab_size].iter().sum();
116    for value in pdf_out.iter_mut().take(vocab_size) {
117        *value /= norm;
118    }
119}
120
121/// Portable SIMD softmax with probability floor for vocab_size=256.
122#[inline]
123fn softmax_pdf_floor_wide_256(logits: &[f32], pdf_out: &mut [f64]) {
124    const N: usize = 256;
125    debug_assert!(logits.len() >= N);
126    debug_assert!(pdf_out.len() >= N);
127    let p_min_val = p_min();
128
129    #[inline(always)]
130    unsafe fn load8(ptr: *const f32) -> f32x8 {
131        ptr.cast::<f32x8>().read_unaligned()
132    }
133
134    let mut max_v = f32x8::splat(f32::NEG_INFINITY);
135    for i in (0..N).step_by(8) {
136        let v = unsafe { load8(logits.as_ptr().add(i)) };
137        max_v = max_v.fast_max(v);
138    }
139    let mut max = f32::NEG_INFINITY;
140    for x in max_v.to_array() {
141        max = max.max(x);
142    }
143    let max_v = f32x8::splat(max);
144
145    let mut sum4 = f64x4::ZERO;
146    for (chunk_idx, out_chunk) in pdf_out[..N].chunks_exact_mut(8).enumerate() {
147        let i = chunk_idx * 8;
148        let centered = unsafe { load8(logits.as_ptr().add(i)) } - max_v;
149        let exp_vals = centered.exp().to_array();
150        let v0 = f64x4::new([
151            exp_vals[0] as f64,
152            exp_vals[1] as f64,
153            exp_vals[2] as f64,
154            exp_vals[3] as f64,
155        ]);
156        let v1 = f64x4::new([
157            exp_vals[4] as f64,
158            exp_vals[5] as f64,
159            exp_vals[6] as f64,
160            exp_vals[7] as f64,
161        ]);
162        sum4 += v0 + v1;
163
164        let lanes0 = v0.to_array();
165        let lanes1 = v1.to_array();
166        out_chunk[..4].copy_from_slice(&lanes0);
167        out_chunk[4..].copy_from_slice(&lanes1);
168    }
169
170    let sum_lanes = sum4.to_array();
171    let sum = sum_lanes[0] + sum_lanes[1] + sum_lanes[2] + sum_lanes[3];
172
173    let inv_sum = 1.0 / sum;
174    let mut norm4 = f64x4::ZERO;
175    let inv_sum4 = f64x4::splat(inv_sum);
176    let min4 = f64x4::splat(p_min_val);
177
178    for chunk in pdf_out[..N].chunks_exact_mut(4) {
179        let vals = f64x4::new([chunk[0], chunk[1], chunk[2], chunk[3]]);
180        let mut v = vals * inv_sum4;
181        v = v.max(min4);
182
183        let lanes = v.to_array();
184        chunk.copy_from_slice(&lanes);
185        norm4 += v;
186    }
187
188    let norm_lanes = norm4.to_array();
189    let norm = norm_lanes[0] + norm_lanes[1] + norm_lanes[2] + norm_lanes[3];
190
191    let inv_norm = 1.0 / norm;
192    let inv_norm4 = f64x4::splat(inv_norm);
193    for chunk in pdf_out[..N].chunks_exact_mut(4) {
194        let vals = f64x4::new([chunk[0], chunk[1], chunk[2], chunk[3]]);
195        let out = vals * inv_norm4;
196        let lanes = out.to_array();
197        chunk.copy_from_slice(&lanes);
198    }
199}
200
/// Compute softmax PDF without floor (for entropy calculation).
///
/// Reads the first `vocab_size` logits and returns a vector of `vocab_size`
/// probabilities. If the exponential sum is not a positive number — which
/// happens when a logit is NaN or the maximum is infinite — the uniform
/// distribution is returned, matching `softmax_pdf_inplace`, instead of a
/// vector of NaNs.
///
/// # Panics
/// Panics if `logits` holds fewer than `vocab_size` elements.
pub fn softmax_pdf(logits: &[f32], vocab_size: usize) -> Vec<f64> {
    let logits = &logits[..vocab_size];

    // Subtract the largest logit before exponentiating for numerical stability.
    let max = logits.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));

    // Single buffer: holds the exponentials first, then the normalized PDF.
    let mut pdf: Vec<f64> = Vec::with_capacity(vocab_size);
    let mut sum = 0.0f64;
    for &logit in logits {
        let e = ((logit - max) as f64).exp();
        sum += e;
        pdf.push(e);
    }

    if sum > 0.0 {
        for p in pdf.iter_mut() {
            *p /= sum;
        }
    } else {
        // `sum > 0.0` is false for NaN as well as for zero, so every degenerate
        // case lands here. `max(1)` avoids a division by zero for an empty vocab.
        pdf.fill(1.0 / (vocab_size.max(1) as f64));
    }
    pdf
}
228
229/// Quantize probability distribution to integer CDF.
230///
231/// The CDF is constructed to be monotonically non-decreasing with:
232/// - `cdf[0] = 0`
233/// - `cdf[vocab_size] = CDF_TOTAL`
234///
235/// # Arguments
236/// * `pdf` - Probability distribution (must sum to ~1.0)
237///
238/// # Returns
239/// Integer CDF with length vocab_size + 1
240pub fn quantize_pdf_to_cdf(pdf: &[f64]) -> Vec<u32> {
241    let mut cdf = vec![0u32; pdf.len() + 1];
242    quantize_pdf_to_cdf_inplace(pdf, &mut cdf);
243    cdf
244}
245
246/// Quantize PDF to integer CDF using a reusable output buffer.
247///
248/// `cdf_out` must have length at least `pdf.len() + 1`.
249#[inline]
250pub fn quantize_pdf_to_cdf_inplace(pdf: &[f64], cdf_out: &mut [u32]) {
251    let mut unused_freq = [];
252    super::quantize_pdf_to_integer_cdf_with_buffer(pdf, CDF_TOTAL, cdf_out, &mut unused_freq);
253}
254
255/// Quantize PDF to integer CDF using reusable output and frequency buffers.
256#[inline]
257pub fn quantize_pdf_to_cdf_with_buffer(pdf: &[f64], cdf_out: &mut [u32], freq_buf: &mut [i64]) {
258    super::quantize_pdf_to_integer_cdf_with_buffer(pdf, CDF_TOTAL, cdf_out, freq_buf);
259}
260
261/// Binary arithmetic encoder.
262pub struct ArithmeticEncoder<W: Write> {
263    b_to_pm1: u64,
264    b_to_pm2: u64,
265    mask: u64,
266    low: u64,
267    high: u64,
268    carry_run: u64,
269    out: W,
270    bit_buffer: u8,
271    bit_count: u8,
272    bytes_out: u64,
273}
274
275impl<W: Write> ArithmeticEncoder<W> {
276    /// Create a new arithmetic encoder.
277    pub fn new(out: W) -> Self {
278        let b_to_pm1 = BASE.pow(PRECISION - 1);
279        let b_to_pm2 = BASE.pow(PRECISION - 2);
280        let mask = BASE.pow(PRECISION) - 1;
281        Self {
282            b_to_pm1,
283            b_to_pm2,
284            mask,
285            low: 0,
286            high: mask,
287            carry_run: 0,
288            out,
289            bit_buffer: 0,
290            bit_count: 0,
291            bytes_out: 0,
292        }
293    }
294
295    #[inline]
296    fn write_byte(&mut self, byte: u8) -> anyhow::Result<()> {
297        self.out.write_all(&[byte])?;
298        self.bytes_out += 1;
299        Ok(())
300    }
301
302    #[inline]
303    fn put_bit_internal(&mut self, bit: u8) -> anyhow::Result<()> {
304        self.bit_buffer = (self.bit_buffer << 1) | (bit & 1);
305        self.bit_count += 1;
306        if self.bit_count == 8 {
307            let b = self.bit_buffer;
308            self.write_byte(b)?;
309            self.bit_buffer = 0;
310            self.bit_count = 0;
311        }
312        Ok(())
313    }
314
315    #[inline]
316    fn put_bit(&mut self, bit: u8) -> anyhow::Result<()> {
317        self.put_bit_internal(bit)?;
318        while self.carry_run > 0 {
319            self.put_bit_internal((!bit) & 1)?;
320            self.carry_run -= 1;
321        }
322        Ok(())
323    }
324
325    /// Encode a symbol using integer CDF bounds.
326    ///
327    /// # Arguments
328    /// * `c_lo` - Lower CDF bound (cumulative probability before symbol)
329    /// * `c_hi` - Upper CDF bound (cumulative probability including symbol)
330    /// * `total` - Total CDF range (should be CDF_TOTAL)
331    pub fn encode_counts(&mut self, c_lo: u64, c_hi: u64, total: u64) -> anyhow::Result<()> {
332        let range = (self.high - self.low + 1) as u128;
333        let total_u = total as u128;
334        let c_lo_u = c_lo as u128;
335        let c_hi_u = c_hi as u128;
336        let low_u = self.low as u128;
337
338        let new_low = low_u + (range * c_lo_u) / total_u;
339        let new_high = low_u + (range * c_hi_u) / total_u - 1;
340
341        self.low = (new_low & (self.mask as u128)) as u64;
342        self.high = (new_high & (self.mask as u128)) as u64;
343
344        loop {
345            if self.high < self.b_to_pm1 {
346                self.put_bit(0)?;
347            } else if self.low >= self.b_to_pm1 {
348                self.put_bit(1)?;
349                self.low -= self.b_to_pm1;
350                self.high -= self.b_to_pm1;
351            } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
352                self.carry_run += 1;
353                self.low -= self.b_to_pm2;
354                self.high -= self.b_to_pm2;
355            } else {
356                break;
357            }
358            self.low = (self.low << 1) & self.mask;
359            self.high = ((self.high << 1) & self.mask) | 1;
360        }
361        Ok(())
362    }
363
364    /// Encode a symbol given its PDF and symbol index.
365    ///
366    /// This is a convenience method that quantizes the PDF to CDF internally.
367    pub fn encode_symbol(&mut self, pdf: &[f64], sym: usize) -> anyhow::Result<()> {
368        let cdf = quantize_pdf_to_cdf(pdf);
369        let c_lo = cdf[sym] as u64;
370        let c_hi = cdf[sym + 1] as u64;
371        self.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)
372    }
373
374    /// Finish encoding and flush remaining bits.
375    ///
376    /// Returns the underlying writer.
377    pub fn finish(mut self) -> anyhow::Result<W> {
378        self.carry_run += 1;
379        if self.low < self.b_to_pm2 {
380            self.put_bit(0)?;
381        } else {
382            self.put_bit(1)?;
383        }
384        // Pad remaining bits
385        if self.bit_count > 0 {
386            let remaining = 8 - self.bit_count;
387            for _ in 0..remaining {
388                self.put_bit_internal(0)?;
389            }
390        }
391        Ok(self.out)
392    }
393
394    /// Get the number of bytes written so far.
395    #[inline]
396    pub fn bytes_written(&self) -> u64 {
397        self.bytes_out
398    }
399}
400
401/// Binary arithmetic decoder.
402pub struct ArithmeticDecoder<'a> {
403    b_to_pm1: u64,
404    b_to_pm2: u64,
405    mask: u64,
406    low: u64,
407    high: u64,
408    code: u64,
409    input: &'a [u8],
410    byte_pos: usize,
411    bit_pos: u8,
412}
413
414impl<'a> ArithmeticDecoder<'a> {
415    /// Create a new arithmetic decoder from input bytes.
416    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
417        let b_to_pm1 = BASE.pow(PRECISION - 1);
418        let b_to_pm2 = BASE.pow(PRECISION - 2);
419        let mask = BASE.pow(PRECISION) - 1;
420
421        let mut s = Self {
422            b_to_pm1,
423            b_to_pm2,
424            mask,
425            low: 0,
426            high: mask,
427            code: 0,
428            input,
429            byte_pos: 0,
430            bit_pos: 0,
431        };
432
433        // Initialize code register with first PRECISION bits
434        for _ in 0..PRECISION {
435            s.code = (s.code << 1) | (s.get_bit().unwrap_or(1) as u64);
436        }
437
438        Ok(s)
439    }
440
441    #[inline]
442    fn get_bit(&mut self) -> Option<u8> {
443        if self.byte_pos >= self.input.len() {
444            return None;
445        }
446        let byte = self.input[self.byte_pos];
447        let bit = (byte >> (7 - self.bit_pos)) & 1;
448        self.bit_pos += 1;
449        if self.bit_pos >= 8 {
450            self.bit_pos = 0;
451            self.byte_pos += 1;
452        }
453        Some(bit)
454    }
455
456    /// Decode a symbol using integer CDF.
457    ///
458    /// # Arguments
459    /// * `cdf` - Cumulative distribution function (length = vocab_size + 1)
460    /// * `total` - Total CDF range (should be CDF_TOTAL)
461    ///
462    /// # Returns
463    /// The decoded symbol index.
464    pub fn decode_symbol_counts(&mut self, cdf: &[u32], total: u32) -> anyhow::Result<usize> {
465        let total_u = total as u64;
466        let range = self.high - self.low + 1;
467        let value =
468            (((self.code - self.low + 1) as u128 * (total_u as u128)) - 1) / (range as u128);
469        let value_u = value as u32;
470
471        // Binary search for symbol `s` with `cdf[s] <= value < cdf[s+1]`
472        let mut lo = 0usize;
473        let mut hi = cdf.len() - 1;
474        while lo + 1 < hi {
475            let mid = (lo + hi) / 2;
476            if cdf[mid] <= value_u {
477                lo = mid;
478            } else {
479                hi = mid;
480            }
481        }
482        let s = lo;
483        let c_lo = cdf[s] as u64;
484        let c_hi = cdf[s + 1] as u64;
485
486        // Update range
487        let range = (self.high - self.low + 1) as u128;
488        let low_u = self.low as u128;
489        let total_u128 = total as u128;
490        let new_low = low_u + (range * (c_lo as u128)) / total_u128;
491        let new_high = low_u + (range * (c_hi as u128)) / total_u128 - 1;
492
493        self.low = new_low as u64;
494        self.high = new_high as u64;
495
496        // Renormalize
497        loop {
498            if self.high < self.b_to_pm1 {
499                // nothing
500            } else if self.low >= self.b_to_pm1 {
501                self.low -= self.b_to_pm1;
502                self.high -= self.b_to_pm1;
503                self.code -= self.b_to_pm1;
504            } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
505                self.low -= self.b_to_pm2;
506                self.high -= self.b_to_pm2;
507                self.code -= self.b_to_pm2;
508            } else {
509                break;
510            }
511            self.low = (self.low << 1) & self.mask;
512            self.high = ((self.high << 1) & self.mask) | 1;
513            self.code = ((self.code << 1) & self.mask) | (self.get_bit().unwrap_or(1) as u64);
514        }
515
516        Ok(s)
517    }
518
519    /// Decode a symbol given a PDF.
520    ///
521    /// This is a convenience method that quantizes the PDF to CDF internally.
522    pub fn decode_symbol(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
523        let cdf = quantize_pdf_to_cdf(pdf);
524        self.decode_symbol_counts(&cdf, CDF_TOTAL)
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `symbols` under a fixed `pdf`, decodes the result, and checks
    /// that every symbol comes back unchanged.
    fn assert_roundtrip(pdf: &[f64], symbols: &[usize]) {
        let mut sink = Vec::new();
        let mut encoder = ArithmeticEncoder::new(&mut sink);
        for &sym in symbols {
            encoder.encode_symbol(pdf, sym).unwrap();
        }
        let encoded = encoder.finish().unwrap().to_vec();

        let mut decoder = ArithmeticDecoder::new(&encoded).unwrap();
        for &expected in symbols {
            let got = decoder.decode_symbol(pdf).unwrap();
            assert_eq!(got, expected);
        }
    }

    /// Checks that every symbol owns a non-empty slice of the CDF.
    fn assert_every_symbol_has_width(cdf: &[u32]) {
        for (i, pair) in cdf.windows(2).enumerate() {
            assert!(pair[1] > pair[0], "symbol {i} has zero-width interval");
        }
    }

    #[test]
    fn test_roundtrip_uniform() {
        // Four equally likely symbols, each used twice.
        assert_roundtrip(&[0.25, 0.25, 0.25, 0.25], &[0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_roundtrip_skewed() {
        // One dominant symbol with a few rare ones mixed in.
        assert_roundtrip(&[0.7, 0.2, 0.05, 0.05], &[0, 0, 0, 1, 0, 2, 0, 3, 0, 0]);
    }

    #[test]
    fn test_softmax_pdf_floor() {
        let logits = [1.0f32, 2.0, 3.0, 4.0];
        let pdf = softmax_pdf_floor(&logits, 4);

        // The distribution must be normalized.
        let total: f64 = pdf.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);

        // No entry may fall below the floor.
        let floor = p_min();
        for &p in &pdf {
            assert!(p >= floor);
        }
    }

    #[test]
    fn test_cdf_monotonic() {
        let cdf = quantize_pdf_to_cdf(&[0.1, 0.2, 0.3, 0.4]);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], CDF_TOTAL);

        // Entries never decrease.
        for pair in cdf.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    #[test]
    fn test_cdf_positive_width_for_tiny_positive_tail() {
        let tail = 1e-18;
        let head = 1.0 - (255.0 * tail);
        let mut pdf = vec![tail; 256];
        pdf[0] = head;
        let mut cdf = vec![0u32; 257];
        let mut freq = vec![0i64; 256];
        quantize_pdf_to_cdf_with_buffer(&pdf, &mut cdf, &mut freq);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[256], CDF_TOTAL);
        assert_every_symbol_has_width(&cdf);
    }

    #[test]
    fn test_cdf_positive_width_when_mass_is_last_symbol() {
        let mut pdf = vec![0.0; 256];
        pdf[255] = 1.0;
        let mut cdf = vec![0u32; 257];
        let mut freq = vec![0i64; 256];
        quantize_pdf_to_cdf_with_buffer(&pdf, &mut cdf, &mut freq);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[255], 255);
        assert_eq!(cdf[256], CDF_TOTAL);
        assert_every_symbol_has_width(&cdf);
    }

    #[test]
    #[should_panic]
    fn test_softmax_floor_256_short_logits_panics_safely() {
        // With vocab_size=256 but only 255 logits, the SIMD fast path must be
        // skipped; the scalar code then fails with an ordinary bounds panic.
        let logits = vec![0.0f32; 255];
        let mut pdf_out = vec![0.0f64; 256];
        softmax_pdf_floor_inplace(&logits, 256, &mut pdf_out);
    }
}