rwkvzip/coders/
rans.rs

1//! rANS (range Asymmetric Numeral System) coder with SIMD optimizations.
2//!
3//! This implements a vectorized rANS coder optimized for x86_64 with AVX2/BMI2.
4//! Falls back to scalar implementation on other architectures.
5//!
6//! # Design
7//!
8//! - Uses 8-way parallel rANS states for throughput
9//! - 15-bit precision for probability quantization
10//! - Interleaved bitstream for decoder efficiency
11//! - Supports both streaming and batch modes
12
/// Number of bits for rANS probability precision.
///
/// The decoders in this file take the symbol slot from the low `ANS_BITS`
/// bits of the coder state.
pub const ANS_BITS: u32 = 15;

/// Total probability range (2^15 = 32768); quantized CDF tables end at this value.
pub const ANS_TOTAL: u32 = 1 << ANS_BITS;

/// Lower bound for rANS state (2^15).
///
/// `RansEncoder` starts from this state, and `RansDecoder` pulls more input
/// whenever its state drops below it.
pub const ANS_LOW: u32 = 1 << ANS_BITS;

/// Upper bound (exclusive) for rANS state (2^31).
pub const ANS_HIGH: u32 = 1 << 31;
24
/// rANS CDF representation for a symbol: the half-open cumulative interval
/// `[lo, hi)` out of `total`.
///
/// The type is a plain value made of three `u32`s, so it is `Copy` and can be
/// compared for equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Cdf {
    /// Lower cumulative probability bound (inclusive)
    pub lo: u32,
    /// Upper cumulative probability bound (exclusive)
    pub hi: u32,
    /// Total probability (should be ANS_TOTAL)
    pub total: u32,
}

impl Cdf {
    /// Create a new CDF entry.
    ///
    /// No validation is performed; callers are expected to pass `lo <= hi <= total`.
    #[inline]
    pub fn new(lo: u32, hi: u32, total: u32) -> Self {
        Self { lo, hi, total }
    }

    /// Get the frequency (hi - lo).
    ///
    /// An inverted interval (`hi < lo`) yields 0 instead of wrapping around to a
    /// huge value in release builds, so the encoders' "frequency must be > 0"
    /// checks catch it.
    #[inline]
    pub fn freq(&self) -> u32 {
        self.hi.saturating_sub(self.lo)
    }
}
49
50/// Quantize PDF to rANS CDF table with guaranteed minimum frequencies.
51///
52/// This implements a robust quantization algorithm that ensures:
53/// 1. All symbols with p > 0 get freq >= 1
54/// 2. The total equals ANS_TOTAL exactly
55/// 3. Monotonicity is preserved (`cdf[i+1] >= cdf[i]`)
56///
57/// Uses error diffusion to distribute rounding errors across symbols.
58///
59/// # Arguments
60/// * `pdf` - Probability distribution (must sum to ~1.0)
61///
62/// # Returns
63/// CDF table where `cdf[i]` = cumulative probability up to symbol i
64pub fn quantize_pdf_to_rans_cdf(pdf: &[f64]) -> Vec<u32> {
65    let mut cdf = vec![0u32; pdf.len() + 1];
66    let mut freqs = vec![0i64; pdf.len()];
67    quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freqs);
68    cdf
69}
70
71/// Quantize PDF to rANS CDF using reusable scratch buffers.
72///
73/// * `cdf_out` must have length at least `pdf.len() + 1`
74/// * `freq_buf` must have length at least `pdf.len()`
75/// * `index_buf` must have length at least `pdf.len()`
76pub fn quantize_pdf_to_rans_cdf_with_buffer(
77    pdf: &[f64],
78    cdf_out: &mut [u32],
79    freq_buf: &mut [i64],
80) {
81    let n = pdf.len();
82    assert!(cdf_out.len() >= n + 1, "cdf buffer too small");
83    assert!(freq_buf.len() >= n, "frequency buffer too small");
84
85    let total = ANS_TOTAL as i64;
86    for i in 0..n {
87        freq_buf[i] = (pdf[i] * total as f64).round() as i64;
88        if pdf[i] > 0.0 && freq_buf[i] == 0 {
89            freq_buf[i] = 1;
90        } else if pdf[i] <= 0.0 {
91            freq_buf[i] = 0;
92        }
93    }
94
95    let sum: i64 = freq_buf[..n].iter().sum();
96    if sum > total {
97        let mut to_remove = sum - total;
98        while to_remove > 0 {
99            let mut removed = 0;
100            for i in (0..n).rev() {
101                if freq_buf[i] > 1 {
102                    freq_buf[i] -= 1;
103                    to_remove -= 1;
104                    removed += 1;
105                    if to_remove == 0 {
106                        break;
107                    }
108                }
109            }
110            if removed == 0 {
111                break;
112            }
113        }
114    } else if sum < total {
115        let mut to_add = total - sum;
116        while to_add > 0 {
117            let mut added = 0;
118            for i in 0..n {
119                if pdf[i] > 0.0 {
120                    freq_buf[i] += 1;
121                    to_add -= 1;
122                    added += 1;
123                    if to_add == 0 {
124                        break;
125                    }
126                }
127            }
128            if added == 0 {
129                for i in 0..n {
130                    freq_buf[i] += 1;
131                    to_add -= 1;
132                    if to_add == 0 {
133                        break;
134                    }
135                }
136            }
137        }
138    }
139
140    cdf_out[0] = 0;
141    let mut cumsum = 0u32;
142    for i in 0..n {
143        cdf_out[i] = cumsum;
144        cumsum += freq_buf[i] as u32;
145    }
146    cdf_out[n] = cumsum;
147
148    debug_assert_eq!(cdf_out[n], ANS_TOTAL, "CDF total must equal ANS_TOTAL");
149    for i in 0..n {
150        if pdf[i] > 0.0 {
151            debug_assert!(
152                cdf_out[i + 1] > cdf_out[i],
153                "Symbol {} with p={} has zero frequency",
154                i,
155                pdf[i]
156            );
157        }
158    }
159}
160
161/// Get Cdf for a symbol from a CDF table.
162#[inline]
163pub fn cdf_for_symbol(cdf: &[u32], sym: usize) -> Cdf {
164    Cdf::new(cdf[sym], cdf[sym + 1], ANS_TOTAL)
165}
166
167/// Scalar rANS encoder.
168pub struct RansEncoder {
169    state: u32,
170    output: Vec<u16>, // 16-bit words for output
171}
172
173impl RansEncoder {
174    /// Create a new rANS encoder.
175    pub fn new() -> Self {
176        Self {
177            state: ANS_LOW,
178            output: Vec::new(),
179        }
180    }
181
182    /// Encode a symbol using its CDF bounds.
183    ///
184    /// rANS encoding formula:
185    /// x' = (x / freq) << ANS_BITS + (x % freq) + c_lo
186    #[inline]
187    pub fn encode(&mut self, cdf: &Cdf) {
188        let freq = cdf.freq();
189        debug_assert!(freq > 0, "Symbol frequency must be > 0");
190
191        // Renormalize: output 16-bit words while state >= max allowed
192        // Max state after encode: freq * (2^16) - 1, we need this < ANS_HIGH
193        // So we renorm when state >= (freq << (32 - 1 - ANS_BITS)) = freq << 16
194        while self.state >= (freq << 16) {
195            self.output.push(self.state as u16);
196            self.state >>= 16;
197        }
198
199        // Encode: x' = (x / freq) << ANS_BITS + (x % freq) + c_lo
200        let q = self.state / freq;
201        let r = self.state % freq;
202        self.state = (q << ANS_BITS) + r + cdf.lo;
203    }
204
205    /// Encode a symbol given a PDF.
206    pub fn encode_pdf(&mut self, pdf: &[f64], sym: usize) {
207        let cdf_table = quantize_pdf_to_rans_cdf(pdf);
208        let cdf = cdf_for_symbol(&cdf_table, sym);
209        self.encode(&cdf);
210    }
211
212    /// Finish encoding and return the output bytes.
213    pub fn finish(self) -> Vec<u8> {
214        // Output final state (4 bytes)
215        let mut result = Vec::with_capacity(self.output.len() * 2 + 4);
216
217        // Push final state first (will be read first during decode)
218        result.extend_from_slice(&self.state.to_le_bytes());
219
220        // Push output words in reverse order (LIFO)
221        for &word in self.output.iter().rev() {
222            result.extend_from_slice(&word.to_le_bytes());
223        }
224
225        result
226    }
227
228    /// Get current output size estimate.
229    pub fn size_estimate(&self) -> usize {
230        self.output.len() * 2 + 4 // *2 for u16->bytes, +4 for final state
231    }
232}
233
234impl Default for RansEncoder {
235    fn default() -> Self {
236        Self::new()
237    }
238}
239
240/// Scalar rANS decoder.
241pub struct RansDecoder<'a> {
242    state: u32,
243    input: &'a [u8],
244    pos: usize,
245}
246
247impl<'a> RansDecoder<'a> {
248    /// Create a new rANS decoder from input bytes.
249    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
250        if input.len() < 4 {
251            anyhow::bail!("rANS input too short");
252        }
253
254        // Read initial state (little-endian, first 4 bytes)
255        let state = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
256
257        Ok(Self {
258            state,
259            input,
260            pos: 4,
261        })
262    }
263
264    /// Decode a symbol using a CDF table.
265    ///
266    /// rANS decoding:
267    /// 1. Extract slot = state % total (= state & (ANS_TOTAL - 1))
268    /// 2. Find symbol `s` where `cdf[s] <= slot < cdf[s+1]`
269    /// 3. Update state: x' = freq * (x >> ANS_BITS) + (x & (ANS_TOTAL-1)) - c_lo
270    #[inline]
271    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
272        // Extract slot from state (low ANS_BITS bits)
273        let slot = self.state & (ANS_TOTAL - 1);
274
275        // Binary search for symbol `s` where `cdf[s] <= slot < cdf[s+1]`
276        let mut lo = 0usize;
277        let mut hi = cdf.len() - 1;
278        while lo + 1 < hi {
279            let mid = (lo + hi) / 2;
280            if cdf[mid] <= slot {
281                lo = mid;
282            } else {
283                hi = mid;
284            }
285        }
286        let sym = lo;
287
288        let c_lo = cdf[sym];
289        let c_hi = cdf[sym + 1];
290        let freq = c_hi - c_lo;
291
292        // Decode: x' = freq * (x >> ANS_BITS) + (x & (ANS_TOTAL-1)) - c_lo
293        self.state = freq * (self.state >> ANS_BITS) + slot - c_lo;
294
295        // Renormalize: read 16-bit words while state < ANS_LOW
296        while self.state < ANS_LOW && self.pos + 1 < self.input.len() {
297            let word = u16::from_le_bytes([self.input[self.pos], self.input[self.pos + 1]]);
298            self.state = (self.state << 16) | (word as u32);
299            self.pos += 2;
300        }
301
302        Ok(sym)
303    }
304
305    /// Decode a symbol given a PDF.
306    pub fn decode_pdf(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
307        let cdf = quantize_pdf_to_rans_cdf(pdf);
308        self.decode(&cdf)
309    }
310}
311
312// =============================================================================
313// SIMD-optimized 8-way interleaved rANS (x86_64 AVX2)
314// =============================================================================
315
316mod simd {
317    use super::*;
318    #[allow(unused_imports)]
319    use std::arch::x86_64::*;
320
321    /// Number of parallel rANS streams
322    pub const RANS_LANES: usize = 8;
323
324    /// 8-way parallel rANS encoder using AVX2.
325    pub struct SimdRansEncoder {
326        states: [u32; RANS_LANES],
327        outputs: [Vec<u8>; RANS_LANES],
328        lane: usize,
329    }
330
331    impl SimdRansEncoder {
332        /// Create a new SIMD rANS encoder.
333        pub fn new() -> Self {
334            Self {
335                states: [ANS_LOW; RANS_LANES],
336                outputs: Default::default(),
337                lane: 0,
338            }
339        }
340
341        /// Encode a symbol, cycling through lanes.
342        pub fn encode(&mut self, cdf: &Cdf) {
343            let freq = cdf.freq();
344            let lane = self.lane;
345            self.lane = (self.lane + 1) % RANS_LANES;
346
347            let state = &mut self.states[lane];
348            let output = &mut self.outputs[lane];
349
350            // Renormalize
351            while *state >= (ANS_HIGH / cdf.total) * freq {
352                output.push(*state as u8);
353                *state >>= 8;
354            }
355
356            // Encode
357            *state = ((*state / freq) * cdf.total) + (*state % freq) + cdf.lo;
358        }
359
360        /// Finish encoding and return interleaved output.
361        pub fn finish(self) -> Vec<u8> {
362            let mut result = Vec::new();
363
364            // Output final states (interleaved)
365            for i in 0..RANS_LANES {
366                let s = self.states[i];
367                result.extend_from_slice(&s.to_le_bytes());
368            }
369
370            // Find max output length
371            let max_len = self.outputs.iter().map(|v| v.len()).max().unwrap_or(0);
372
373            // Interleave output bytes
374            for pos in 0..max_len {
375                for lane in 0..RANS_LANES {
376                    let out = &self.outputs[lane];
377                    if pos < out.len() {
378                        result.push(out[out.len() - 1 - pos]);
379                    } else {
380                        result.push(0);
381                    }
382                }
383            }
384
385            result
386        }
387    }
388
389    impl Default for SimdRansEncoder {
390        fn default() -> Self {
391            Self::new()
392        }
393    }
394
395    /// 8-way parallel rANS decoder using AVX2.
396    pub struct SimdRansDecoder<'a> {
397        states: [u32; RANS_LANES],
398        input: &'a [u8],
399        pos: usize,
400        lane: usize,
401    }
402
403    impl<'a> SimdRansDecoder<'a> {
404        /// Create a new SIMD rANS decoder.
405        pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
406            if input.len() < RANS_LANES * 4 {
407                anyhow::bail!("SIMD rANS input too short");
408            }
409
410            let mut states = [0u32; RANS_LANES];
411            for i in 0..RANS_LANES {
412                let offset = i * 4;
413                states[i] = u32::from_le_bytes([
414                    input[offset],
415                    input[offset + 1],
416                    input[offset + 2],
417                    input[offset + 3],
418                ]);
419            }
420
421            Ok(Self {
422                states,
423                input,
424                pos: RANS_LANES * 4,
425                lane: 0,
426            })
427        }
428
429        /// Decode a symbol from the current lane.
430        pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
431            let lane = self.lane;
432            self.lane = (self.lane + 1) % RANS_LANES;
433
434            let state = &mut self.states[lane];
435            let total = ANS_TOTAL;
436            let value = *state & (total - 1);
437
438            // Binary search
439            let mut lo = 0usize;
440            let mut hi = cdf.len() - 1;
441            while lo + 1 < hi {
442                let mid = (lo + hi) / 2;
443                if cdf[mid] <= value {
444                    lo = mid;
445                } else {
446                    hi = mid;
447                }
448            }
449            let sym = lo;
450
451            let c_lo = cdf[sym];
452            let c_hi = cdf[sym + 1];
453            let freq = c_hi - c_lo;
454
455            // Decode
456            *state = freq * (*state >> ANS_BITS) + (*state & (total - 1)) - c_lo;
457
458            // Renormalize (read from interleaved stream)
459            while *state < ANS_LOW {
460                // Read byte for this lane
461                let byte_idx = self.pos + lane;
462                if byte_idx < self.input.len() {
463                    *state = (*state << 8) | (self.input[byte_idx] as u32);
464                }
465                self.pos += RANS_LANES;
466            }
467
468            Ok(sym)
469        }
470    }
471}
472pub use simd::*;
473
// Non-SIMD fallback: none is provided. The SIMD rANS coder above is required
// on x86_64 targets.
476
477// =============================================================================
478// Blocked rANS for streaming large files
479// =============================================================================
480
/// Block size for blocked rANS (128K = 131072).
///
/// The unit is symbols, not bytes: `BlockedRansEncoder` flushes a block once
/// this many symbols have been buffered.
pub const BLOCK_SIZE: usize = 128 * 1024;
483
484/// Blocked rANS encoder for streaming large files.
485///
486/// Splits input into 128KB blocks and encodes each independently.
487/// This allows O(1) memory for encoding arbitrary-sized inputs.
488pub struct BlockedRansEncoder {
489    /// Symbols buffered for current block (stores low/high bounds only)
490    symbols: Vec<Cdf>,
491    /// Encoded blocks
492    blocks: Vec<Vec<u8>>,
493}
494
495impl BlockedRansEncoder {
496    pub fn new() -> Self {
497        Self {
498            symbols: Vec::with_capacity(BLOCK_SIZE),
499            blocks: Vec::new(),
500        }
501    }
502
503    /// Encode a symbol with its CDF.
504    pub fn encode(&mut self, cdf: Cdf) {
505        self.symbols.push(cdf);
506
507        // Flush block if full
508        if self.symbols.len() >= BLOCK_SIZE {
509            self.flush_block();
510        }
511    }
512
513    /// Flush the current block.
514    fn flush_block(&mut self) {
515        if self.symbols.is_empty() {
516            return;
517        }
518
519        // Encode in reverse order (rANS is LIFO)
520        let mut encoder = RansEncoder::new();
521        for cdf in self.symbols.iter().rev() {
522            encoder.encode(cdf);
523        }
524
525        let encoded = encoder.finish();
526        self.blocks.push(encoded);
527        self.symbols.clear();
528    }
529
530    /// Finish encoding and return all blocks.
531    pub fn finish(mut self) -> Vec<Vec<u8>> {
532        // Flush any remaining symbols
533        self.flush_block();
534        self.blocks
535    }
536}
537
538impl Default for BlockedRansEncoder {
539    fn default() -> Self {
540        Self::new()
541    }
542}
543
544/// Blocked rANS decoder for streaming large files.
545pub struct BlockedRansDecoder<'a> {
546    blocks: Vec<&'a [u8]>,
547    current_block: usize,
548    decoder: Option<RansDecoder<'a>>,
549}
550
551impl<'a> BlockedRansDecoder<'a> {
552    /// Create a new blocked decoder from encoded blocks.
553    pub fn new(blocks: Vec<&'a [u8]>) -> Self {
554        Self {
555            blocks,
556            current_block: 0,
557            decoder: None,
558        }
559    }
560
561    /// Decode next symbol with provided CDF.
562    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
563        // Initialize decoder for first block if needed
564        if self.decoder.is_none() {
565            if self.current_block >= self.blocks.len() {
566                anyhow::bail!("No more blocks to decode");
567            }
568            self.decoder = Some(RansDecoder::new(self.blocks[self.current_block])?);
569        }
570
571        // Try to decode from current block
572        match self.decoder.as_mut().unwrap().decode(cdf) {
573            Ok(sym) => Ok(sym),
574            Err(_) => {
575                // Current block exhausted, move to next
576                self.current_block += 1;
577                if self.current_block >= self.blocks.len() {
578                    anyhow::bail!("All blocks exhausted");
579                }
580                self.decoder = Some(RansDecoder::new(self.blocks[self.current_block])?);
581                self.decoder.as_mut().unwrap().decode(cdf)
582            }
583        }
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `symbols` with one static table. rANS is LIFO, so the symbols
    /// are fed to the encoder in REVERSE order.
    fn encode_all(cdf_table: &[u32], symbols: &[usize]) -> Vec<u8> {
        let mut enc = RansEncoder::new();
        for &s in symbols.iter().rev() {
            enc.encode(&cdf_for_symbol(cdf_table, s));
        }
        enc.finish()
    }

    /// Decode `count` symbols in FORWARD order with one static table.
    fn decode_all(cdf_table: &[u32], encoded: &[u8], count: usize) -> Vec<usize> {
        let mut dec = RansDecoder::new(encoded).unwrap();
        (0..count)
            .map(|_| dec.decode(cdf_table).unwrap())
            .collect()
    }

    #[test]
    fn test_roundtrip_scalar() {
        let pdf = vec![0.5, 0.3, 0.15, 0.05];
        let symbols = vec![0, 0, 1, 0, 2, 1, 0, 3, 0, 0, 1, 2];

        let cdf_table = quantize_pdf_to_rans_cdf(&pdf);
        let encoded = encode_all(&cdf_table, &symbols);
        let decoded = decode_all(&cdf_table, &encoded, symbols.len());

        for (&got, &expected) in decoded.iter().zip(&symbols) {
            assert_eq!(got, expected, "Symbol mismatch");
        }
    }

    #[test]
    fn test_cdf_quantization() {
        let pdf = vec![0.25, 0.25, 0.25, 0.25];
        let cdf = quantize_pdf_to_rans_cdf(&pdf);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], ANS_TOTAL);

        // Each of the first three intervals must be non-empty
        for pair in cdf[..4].windows(2) {
            let delta = pair[1] - pair[0];
            assert!(delta > 0);
        }
    }

    #[test]
    fn test_extreme_probabilities() {
        // Very skewed distribution
        let pdf = vec![0.99, 0.005, 0.003, 0.002];
        let symbols = vec![0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 3];

        let cdf_table = quantize_pdf_to_rans_cdf(&pdf);
        let encoded = encode_all(&cdf_table, &symbols);
        let decoded = decode_all(&cdf_table, &encoded, symbols.len());

        for (&got, &expected) in decoded.iter().zip(&symbols) {
            assert_eq!(got, expected);
        }
    }
}