infotheory/coders/
rans.rs

1//! rANS (range Asymmetric Numeral System) coder with an optional multi-lane path.
2//!
3//! The primary implementation is scalar and portable. On x86_64 builds, this
4//! module also exposes an 8-lane interleaved encoder/decoder API.
5//!
6//! # Design
7//!
8//! - Uses 8-way parallel rANS states for throughput
9//! - 15-bit precision for probability quantization
10//! - Interleaved bitstream for decoder efficiency
11//! - Supports both streaming and batch modes
12
/// Precision, in bits, of the quantized rANS probabilities.
pub const ANS_BITS: u32 = 15;

/// Sum of all quantized symbol frequencies (2^15 = 32768).
pub const ANS_TOTAL: u32 = 1u32 << ANS_BITS;

/// Lower bound for the rANS state (2^15).
pub const ANS_LOW: u32 = 1u32 << ANS_BITS;

/// Upper bound for the rANS state (2^31).
pub const ANS_HIGH: u32 = 0x8000_0000;
24
/// Cumulative-frequency interval `[lo, hi)` occupied by one symbol.
///
/// The type is a plain value made of three `u32`s, so it is `Copy` and can be
/// compared for equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Cdf {
    /// Lower cumulative probability bound (inclusive)
    pub lo: u32,
    /// Upper cumulative probability bound (exclusive)
    pub hi: u32,
    /// Total probability (should be ANS_TOTAL)
    pub total: u32,
}

impl Cdf {
    /// Create a new CDF entry from its bounds and the table total.
    #[inline]
    pub fn new(lo: u32, hi: u32, total: u32) -> Self {
        Self { lo, hi, total }
    }

    /// Get the frequency (`hi - lo`).
    ///
    /// The subtraction is unchecked beyond Rust's usual overflow rules, so
    /// callers must keep `hi >= lo`.
    #[inline]
    pub fn freq(&self) -> u32 {
        self.hi - self.lo
    }
}
49
50/// Quantize PDF to rANS CDF table with guaranteed minimum frequencies.
51///
52/// This implements a robust quantization algorithm that ensures:
53/// 1. All symbols with p > 0 get freq >= 1
54/// 2. The total equals ANS_TOTAL exactly
55/// 3. Monotonicity is preserved (`cdf[i+1] >= cdf[i]`)
56///
57/// Uses error diffusion to distribute rounding errors across symbols.
58///
59/// # Arguments
60/// * `pdf` - Probability distribution (must sum to ~1.0)
61///
62/// # Returns
63/// CDF table where `cdf[i]` = cumulative probability up to symbol i
64pub fn quantize_pdf_to_rans_cdf(pdf: &[f64]) -> Vec<u32> {
65    let mut cdf = vec![0u32; pdf.len() + 1];
66    let mut freqs = vec![0i64; pdf.len()];
67    quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freqs);
68    cdf
69}
70
71/// Quantize PDF to rANS CDF using reusable scratch buffers.
72///
73/// * `cdf_out` must have length at least `pdf.len() + 1`
74/// * `freq_buf` must have length at least `pdf.len()`
75/// * `index_buf` must have length at least `pdf.len()`
76pub fn quantize_pdf_to_rans_cdf_with_buffer(
77    pdf: &[f64],
78    cdf_out: &mut [u32],
79    freq_buf: &mut [i64],
80) {
81    let n = pdf.len();
82    super::quantize_pdf_to_integer_cdf_with_buffer(pdf, ANS_TOTAL, cdf_out, freq_buf);
83
84    debug_assert_eq!(cdf_out[n], ANS_TOTAL, "CDF total must equal ANS_TOTAL");
85    for i in 0..n {
86        if pdf[i] > 0.0 {
87            debug_assert!(
88                cdf_out[i + 1] > cdf_out[i],
89                "Symbol {} with p={} has zero frequency",
90                i,
91                pdf[i]
92            );
93        }
94    }
95}
96
97/// Get Cdf for a symbol from a CDF table.
98#[inline]
99pub fn cdf_for_symbol(cdf: &[u32], sym: usize) -> Cdf {
100    Cdf::new(cdf[sym], cdf[sym + 1], ANS_TOTAL)
101}
102
103/// Scalar rANS encoder.
104pub struct RansEncoder {
105    state: u32,
106    output: Vec<u16>, // 16-bit words for output
107}
108
109impl RansEncoder {
110    /// Create a new rANS encoder.
111    pub fn new() -> Self {
112        Self {
113            state: ANS_LOW,
114            output: Vec::new(),
115        }
116    }
117
118    /// Encode a symbol using its CDF bounds.
119    ///
120    /// rANS encoding formula:
121    /// x' = (x / freq) << ANS_BITS + (x % freq) + c_lo
122    #[inline]
123    pub fn encode(&mut self, cdf: &Cdf) {
124        let freq = cdf.freq();
125        debug_assert!(freq > 0, "Symbol frequency must be > 0");
126
127        // Renormalize: output 16-bit words while state >= max allowed
128        // Max state after encode: freq * (2^16) - 1, we need this < ANS_HIGH
129        // So we renorm when state >= (freq << (32 - 1 - ANS_BITS)) = freq << 16
130        while self.state >= (freq << 16) {
131            self.output.push(self.state as u16);
132            self.state >>= 16;
133        }
134
135        // Encode: x' = (x / freq) << ANS_BITS + (x % freq) + c_lo
136        let q = self.state / freq;
137        let r = self.state % freq;
138        self.state = (q << ANS_BITS) + r + cdf.lo;
139    }
140
141    /// Encode a symbol given a PDF.
142    pub fn encode_pdf(&mut self, pdf: &[f64], sym: usize) {
143        let cdf_table = quantize_pdf_to_rans_cdf(pdf);
144        let cdf = cdf_for_symbol(&cdf_table, sym);
145        self.encode(&cdf);
146    }
147
148    /// Finish encoding and return the output bytes.
149    pub fn finish(self) -> Vec<u8> {
150        // Output final state (4 bytes)
151        let mut result = Vec::with_capacity(self.output.len() * 2 + 4);
152
153        // Push final state first (will be read first during decode)
154        result.extend_from_slice(&self.state.to_le_bytes());
155
156        // Push output words in reverse order (LIFO)
157        for &word in self.output.iter().rev() {
158            result.extend_from_slice(&word.to_le_bytes());
159        }
160
161        result
162    }
163
164    /// Get current output size estimate.
165    pub fn size_estimate(&self) -> usize {
166        self.output.len() * 2 + 4 // *2 for u16->bytes, +4 for final state
167    }
168}
169
170impl Default for RansEncoder {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176/// Scalar rANS decoder.
177pub struct RansDecoder<'a> {
178    state: u32,
179    input: &'a [u8],
180    pos: usize,
181}
182
183impl<'a> RansDecoder<'a> {
184    /// Create a new rANS decoder from input bytes.
185    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
186        if input.len() < 4 {
187            anyhow::bail!("rANS input too short");
188        }
189
190        // Read initial state (little-endian, first 4 bytes)
191        let state = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
192
193        Ok(Self {
194            state,
195            input,
196            pos: 4,
197        })
198    }
199
200    /// Decode a symbol using a CDF table.
201    ///
202    /// rANS decoding:
203    /// 1. Extract slot = state % total (= state & (ANS_TOTAL - 1))
204    /// 2. Find symbol `s` where `cdf[s] <= slot < cdf[s+1]`
205    /// 3. Update state: x' = freq * (x >> ANS_BITS) + (x & (ANS_TOTAL-1)) - c_lo
206    #[inline]
207    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
208        // Extract slot from state (low ANS_BITS bits)
209        let slot = self.state & (ANS_TOTAL - 1);
210
211        // Binary search for symbol `s` where `cdf[s] <= slot < cdf[s+1]`
212        let mut lo = 0usize;
213        let mut hi = cdf.len() - 1;
214        while lo + 1 < hi {
215            let mid = (lo + hi) / 2;
216            if cdf[mid] <= slot {
217                lo = mid;
218            } else {
219                hi = mid;
220            }
221        }
222        let sym = lo;
223
224        let c_lo = cdf[sym];
225        let c_hi = cdf[sym + 1];
226        let freq = c_hi - c_lo;
227
228        // Decode: x' = freq * (x >> ANS_BITS) + (x & (ANS_TOTAL-1)) - c_lo
229        self.state = freq * (self.state >> ANS_BITS) + slot - c_lo;
230
231        // Renormalize: read 16-bit words while state < ANS_LOW
232        while self.state < ANS_LOW && self.pos + 1 < self.input.len() {
233            let word = u16::from_le_bytes([self.input[self.pos], self.input[self.pos + 1]]);
234            self.state = (self.state << 16) | (word as u32);
235            self.pos += 2;
236        }
237
238        Ok(sym)
239    }
240
241    /// Decode a symbol given a PDF.
242    pub fn decode_pdf(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
243        let cdf = quantize_pdf_to_rans_cdf(pdf);
244        self.decode(&cdf)
245    }
246}
247
248// =============================================================================
249// 8-way interleaved rANS API (x86_64 build target)
250// =============================================================================
251
252#[cfg(target_arch = "x86_64")]
253mod simd {
254    use super::*;
255
256    /// Number of parallel rANS streams
257    pub const RANS_LANES: usize = 8;
258
259    /// 8-way interleaved rANS encoder.
260    pub struct SimdRansEncoder {
261        states: [u32; RANS_LANES],
262        outputs: [Vec<u8>; RANS_LANES],
263        lane: usize,
264    }
265
266    impl SimdRansEncoder {
267        /// Create a new SIMD rANS encoder.
268        pub fn new() -> Self {
269            Self {
270                states: [ANS_LOW; RANS_LANES],
271                outputs: Default::default(),
272                lane: 0,
273            }
274        }
275
276        /// Encode a symbol, cycling through lanes.
277        pub fn encode(&mut self, cdf: &Cdf) {
278            let freq = cdf.freq();
279            let lane = self.lane;
280            self.lane = (self.lane + 1) % RANS_LANES;
281
282            let state = &mut self.states[lane];
283            let output = &mut self.outputs[lane];
284
285            // Renormalize
286            while *state >= (ANS_HIGH / cdf.total) * freq {
287                output.push(*state as u8);
288                *state >>= 8;
289            }
290
291            // Encode
292            *state = ((*state / freq) * cdf.total) + (*state % freq) + cdf.lo;
293        }
294
295        /// Finish encoding and return interleaved output.
296        pub fn finish(self) -> Vec<u8> {
297            let mut result = Vec::new();
298
299            // Output final states (interleaved)
300            for &s in self.states.iter().take(RANS_LANES) {
301                result.extend_from_slice(&s.to_le_bytes());
302            }
303
304            // Find max output length
305            let max_len = self.outputs.iter().map(|v| v.len()).max().unwrap_or(0);
306
307            // Interleave output bytes
308            for pos in 0..max_len {
309                for lane in 0..RANS_LANES {
310                    let out = &self.outputs[lane];
311                    if pos < out.len() {
312                        result.push(out[out.len() - 1 - pos]);
313                    } else {
314                        result.push(0);
315                    }
316                }
317            }
318
319            result
320        }
321    }
322
323    impl Default for SimdRansEncoder {
324        fn default() -> Self {
325            Self::new()
326        }
327    }
328
329    /// 8-way interleaved rANS decoder.
330    pub struct SimdRansDecoder<'a> {
331        states: [u32; RANS_LANES],
332        input: &'a [u8],
333        pos: usize,
334        lane: usize,
335    }
336
337    impl<'a> SimdRansDecoder<'a> {
338        /// Create a new SIMD rANS decoder.
339        pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
340            if input.len() < RANS_LANES * 4 {
341                anyhow::bail!("SIMD rANS input too short");
342            }
343
344            let mut states = [0u32; RANS_LANES];
345            for (i, state) in states.iter_mut().enumerate() {
346                let offset = i * 4;
347                *state = u32::from_le_bytes([
348                    input[offset],
349                    input[offset + 1],
350                    input[offset + 2],
351                    input[offset + 3],
352                ]);
353            }
354
355            Ok(Self {
356                states,
357                input,
358                pos: RANS_LANES * 4,
359                lane: 0,
360            })
361        }
362
363        /// Decode a symbol from the current lane.
364        pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
365            let lane = self.lane;
366            self.lane = (self.lane + 1) % RANS_LANES;
367
368            let state = &mut self.states[lane];
369            let total = ANS_TOTAL;
370            let value = *state & (total - 1);
371
372            // Binary search
373            let mut lo = 0usize;
374            let mut hi = cdf.len() - 1;
375            while lo + 1 < hi {
376                let mid = (lo + hi) / 2;
377                if cdf[mid] <= value {
378                    lo = mid;
379                } else {
380                    hi = mid;
381                }
382            }
383            let sym = lo;
384
385            let c_lo = cdf[sym];
386            let c_hi = cdf[sym + 1];
387            let freq = c_hi - c_lo;
388
389            // Decode
390            *state = freq * (*state >> ANS_BITS) + (*state & (total - 1)) - c_lo;
391
392            // Renormalize (read from interleaved stream)
393            while *state < ANS_LOW {
394                // Read byte for this lane
395                let byte_idx = self.pos + lane;
396                if byte_idx < self.input.len() {
397                    *state = (*state << 8) | (self.input[byte_idx] as u32);
398                }
399                self.pos += RANS_LANES;
400            }
401
402            Ok(sym)
403        }
404    }
405}
406#[cfg(target_arch = "x86_64")]
407/// SIMD lane-parallel rANS types on x86_64.
408pub use simd::*;
409
#[cfg(not(target_arch = "x86_64"))]
/// Lane count reported on non-x86_64 targets, where the lane API is backed
/// by a single scalar rANS stream.
pub const RANS_LANES: usize = 1;
413
#[cfg(not(target_arch = "x86_64"))]
/// Stand-in for the lane encoder on non-x86_64 targets.
///
/// Every call is forwarded to a scalar [`RansEncoder`], so the produced bytes
/// use the scalar stream layout.
pub struct SimdRansEncoder {
    inner: RansEncoder,
}

#[cfg(not(target_arch = "x86_64"))]
impl SimdRansEncoder {
    /// Build the wrapper around a fresh scalar encoder.
    pub fn new() -> Self {
        let inner = RansEncoder::new();
        Self { inner }
    }

    /// Forward one symbol to the scalar encoder.
    pub fn encode(&mut self, cdf: &Cdf) {
        RansEncoder::encode(&mut self.inner, cdf);
    }

    /// Consume the wrapper and return the scalar encoder's output bytes.
    pub fn finish(self) -> Vec<u8> {
        let Self { inner } = self;
        inner.finish()
    }
}

#[cfg(not(target_arch = "x86_64"))]
impl Default for SimdRansEncoder {
    fn default() -> Self {
        SimdRansEncoder::new()
    }
}
446
#[cfg(not(target_arch = "x86_64"))]
/// Stand-in for the lane decoder on non-x86_64 targets.
///
/// Every call is forwarded to a scalar [`RansDecoder`], so the input must use
/// the scalar stream layout.
pub struct SimdRansDecoder<'a> {
    inner: RansDecoder<'a>,
}

#[cfg(not(target_arch = "x86_64"))]
impl<'a> SimdRansDecoder<'a> {
    /// Build the wrapper around a scalar decoder for `input`.
    ///
    /// Propagates any error reported by [`RansDecoder::new`].
    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
        let inner = RansDecoder::new(input)?;
        Ok(Self { inner })
    }

    /// Forward one decode request to the scalar decoder.
    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
        let scalar = &mut self.inner;
        scalar.decode(cdf)
    }
}
467
468// =============================================================================
469// Blocked rANS for streaming large files
470// =============================================================================
471
/// Block size for blocked rANS, counted in symbols (128 Ki = 131072).
pub const BLOCK_SIZE: usize = 1 << 17;
474
475/// Blocked rANS encoder for streaming large files.
476///
477/// Splits input into 128KB blocks and encodes each independently.
478/// This allows O(1) memory for encoding arbitrary-sized inputs.
479pub struct BlockedRansEncoder {
480    /// Symbols buffered for current block (stores low/high bounds only)
481    symbols: Vec<Cdf>,
482    /// Encoded blocks
483    blocks: Vec<Vec<u8>>,
484}
485
486impl BlockedRansEncoder {
487    /// Create an empty blocked encoder.
488    pub fn new() -> Self {
489        Self {
490            symbols: Vec::with_capacity(BLOCK_SIZE),
491            blocks: Vec::new(),
492        }
493    }
494
495    /// Encode a symbol with its CDF.
496    pub fn encode(&mut self, cdf: Cdf) {
497        self.symbols.push(cdf);
498
499        // Flush block if full
500        if self.symbols.len() >= BLOCK_SIZE {
501            self.flush_block();
502        }
503    }
504
505    /// Flush the current block.
506    fn flush_block(&mut self) {
507        if self.symbols.is_empty() {
508            return;
509        }
510
511        // Encode in reverse order (rANS is LIFO)
512        let mut encoder = RansEncoder::new();
513        for cdf in self.symbols.iter().rev() {
514            encoder.encode(cdf);
515        }
516
517        let encoded = encoder.finish();
518        self.blocks.push(encoded);
519        self.symbols.clear();
520    }
521
522    /// Finish encoding and return all blocks.
523    pub fn finish(mut self) -> Vec<Vec<u8>> {
524        // Flush any remaining symbols
525        self.flush_block();
526        self.blocks
527    }
528}
529
530impl Default for BlockedRansEncoder {
531    fn default() -> Self {
532        Self::new()
533    }
534}
535
536/// Blocked rANS decoder for streaming large files.
537pub struct BlockedRansDecoder<'a> {
538    blocks: Vec<&'a [u8]>,
539    current_block: usize,
540    symbols_remaining_in_block: usize,
541    total_symbols: usize,
542    decoder: Option<RansDecoder<'a>>,
543}
544
545impl<'a> BlockedRansDecoder<'a> {
546    /// Create a new blocked decoder from encoded blocks.
547    pub fn new(blocks: Vec<&'a [u8]>, total_symbols: usize) -> anyhow::Result<Self> {
548        let expected_blocks = if total_symbols == 0 {
549            0
550        } else {
551            total_symbols.div_ceil(BLOCK_SIZE)
552        };
553        if blocks.len() != expected_blocks {
554            anyhow::bail!(
555                "blocked rANS expected {expected_blocks} blocks for {total_symbols} symbols, got {}",
556                blocks.len()
557            );
558        }
559        Ok(Self {
560            blocks,
561            current_block: 0,
562            symbols_remaining_in_block: 0,
563            total_symbols,
564            decoder: None,
565        })
566    }
567
568    #[inline]
569    fn open_block(&mut self, block_index: usize) -> anyhow::Result<()> {
570        if block_index >= self.blocks.len() {
571            anyhow::bail!("No more blocks to decode");
572        }
573        let consumed = block_index.saturating_mul(BLOCK_SIZE);
574        let remaining = self.total_symbols.saturating_sub(consumed);
575        self.current_block = block_index;
576        self.symbols_remaining_in_block = remaining.min(BLOCK_SIZE);
577        self.decoder = Some(RansDecoder::new(self.blocks[block_index])?);
578        Ok(())
579    }
580
581    /// Decode next symbol with provided CDF.
582    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
583        if self.symbols_remaining_in_block == 0 {
584            if self.decoder.is_some() {
585                self.open_block(self.current_block + 1)?;
586            } else {
587                self.open_block(0)?;
588            }
589        }
590
591        let sym = self
592            .decoder
593            .as_mut()
594            .expect("decoder initialized for current block")
595            .decode(cdf)?;
596        self.symbols_remaining_in_block = self.symbols_remaining_in_block.saturating_sub(1);
597        Ok(sym)
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `symbols` in reverse with the scalar coder and check that
    /// forward decoding returns them unchanged.
    fn assert_scalar_roundtrip(pdf: &[f64], symbols: &[usize]) {
        let cdf_table = quantize_pdf_to_rans_cdf(pdf);

        // Encode in REVERSE order (rANS is LIFO)
        let mut enc = RansEncoder::new();
        for &s in symbols.iter().rev() {
            let cdf = cdf_for_symbol(&cdf_table, s);
            enc.encode(&cdf);
        }
        let encoded = enc.finish();

        // Decode in FORWARD order
        let mut dec = RansDecoder::new(&encoded).unwrap();
        for (i, &expected) in symbols.iter().enumerate() {
            let got = dec.decode(&cdf_table).unwrap();
            assert_eq!(got, expected, "Symbol mismatch at index {i}");
        }
    }

    #[test]
    fn test_roundtrip_scalar() {
        let pdf = vec![0.5, 0.3, 0.15, 0.05];
        let symbols = vec![0, 0, 1, 0, 2, 1, 0, 3, 0, 0, 1, 2];
        assert_scalar_roundtrip(&pdf, &symbols);
    }

    #[test]
    fn test_cdf_quantization() {
        let pdf = vec![0.25, 0.25, 0.25, 0.25];
        let cdf = quantize_pdf_to_rans_cdf(&pdf);

        assert_eq!(cdf.len(), pdf.len() + 1);
        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], ANS_TOTAL);

        // Every symbol, including the last one, must own a non-empty interval.
        for bounds in cdf.windows(2) {
            assert!(bounds[1] > bounds[0]);
        }
    }

    #[test]
    fn test_extreme_probabilities() {
        // Very skewed distribution
        let pdf = vec![0.99, 0.005, 0.003, 0.002];
        let symbols = vec![0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 3];
        assert_scalar_roundtrip(&pdf, &symbols);
    }

    #[test]
    fn test_blocked_rans_roundtrip_across_block_boundary() {
        let pdf = vec![0.5, 0.25, 0.125, 0.125];
        let cdf = quantize_pdf_to_rans_cdf(&pdf);
        let symbols: Vec<usize> = (0..(BLOCK_SIZE + 17)).map(|i| i % pdf.len()).collect();

        let mut enc = BlockedRansEncoder::new();
        for &sym in &symbols {
            enc.encode(cdf_for_symbol(&cdf, sym));
        }
        let blocks = enc.finish();
        let block_refs: Vec<&[u8]> = blocks.iter().map(Vec::as_slice).collect();

        let mut dec = BlockedRansDecoder::new(block_refs, symbols.len()).unwrap();
        for (i, &expected) in symbols.iter().enumerate() {
            let got = dec.decode(&cdf).unwrap();
            assert_eq!(got, expected, "blocked rANS mismatch at index {i}");
        }
    }
}