/// Precision, in bits, of the quantized CDF tables consumed by the rANS coders.
pub const ANS_BITS: u32 = 15;

/// Sum of all symbol frequencies in a quantized CDF table (`2^ANS_BITS`).
pub const ANS_TOTAL: u32 = 2u32.pow(ANS_BITS);

/// Initial state of [`RansEncoder`] and the threshold below which [`RansDecoder`]
/// refills its state with another 16-bit word.
pub const ANS_LOW: u32 = ANS_TOTAL;

/// `2^31`; used by the interleaved (x86_64) coder to derive its renormalization bounds.
pub const ANS_HIGH: u32 = 0x8000_0000;
/// Cumulative-frequency interval `[lo, hi)` of one symbol inside a table whose
/// frequencies sum to `total`.
#[derive(Clone, Debug)]
pub struct Cdf {
    /// Cumulative frequency of every symbol preceding this one.
    pub lo: u32,
    /// `lo` plus this symbol's own frequency.
    pub hi: u32,
    /// Sum of all frequencies in the table this interval was taken from.
    pub total: u32,
}

impl Cdf {
    /// Builds an interval from raw bounds; no validation is performed.
    #[inline]
    pub fn new(lo: u32, hi: u32, total: u32) -> Self {
        Cdf { lo, hi, total }
    }

    /// Width of the interval, i.e. the symbol's quantized frequency.
    ///
    /// # Panics
    /// In debug builds, panics on underflow when `hi < lo`.
    #[inline]
    pub fn freq(&self) -> u32 {
        let Cdf { lo, hi, .. } = *self;
        hi - lo
    }
}
49
50pub fn quantize_pdf_to_rans_cdf(pdf: &[f64]) -> Vec<u32> {
65 let mut cdf = vec![0u32; pdf.len() + 1];
66 let mut freqs = vec![0i64; pdf.len()];
67 quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freqs);
68 cdf
69}
70
71pub fn quantize_pdf_to_rans_cdf_with_buffer(
77 pdf: &[f64],
78 cdf_out: &mut [u32],
79 freq_buf: &mut [i64],
80) {
81 let n = pdf.len();
82 super::quantize_pdf_to_integer_cdf_with_buffer(pdf, ANS_TOTAL, cdf_out, freq_buf);
83
84 debug_assert_eq!(cdf_out[n], ANS_TOTAL, "CDF total must equal ANS_TOTAL");
85 for i in 0..n {
86 if pdf[i] > 0.0 {
87 debug_assert!(
88 cdf_out[i + 1] > cdf_out[i],
89 "Symbol {} with p={} has zero frequency",
90 i,
91 pdf[i]
92 );
93 }
94 }
95}
96
97#[inline]
99pub fn cdf_for_symbol(cdf: &[u32], sym: usize) -> Cdf {
100 Cdf::new(cdf[sym], cdf[sym + 1], ANS_TOTAL)
101}
102
103pub struct RansEncoder {
105 state: u32,
106 output: Vec<u16>, }
108
109impl RansEncoder {
110 pub fn new() -> Self {
112 Self {
113 state: ANS_LOW,
114 output: Vec::new(),
115 }
116 }
117
118 #[inline]
123 pub fn encode(&mut self, cdf: &Cdf) {
124 let freq = cdf.freq();
125 debug_assert!(freq > 0, "Symbol frequency must be > 0");
126
127 while self.state >= (freq << 16) {
131 self.output.push(self.state as u16);
132 self.state >>= 16;
133 }
134
135 let q = self.state / freq;
137 let r = self.state % freq;
138 self.state = (q << ANS_BITS) + r + cdf.lo;
139 }
140
141 pub fn encode_pdf(&mut self, pdf: &[f64], sym: usize) {
143 let cdf_table = quantize_pdf_to_rans_cdf(pdf);
144 let cdf = cdf_for_symbol(&cdf_table, sym);
145 self.encode(&cdf);
146 }
147
148 pub fn finish(self) -> Vec<u8> {
150 let mut result = Vec::with_capacity(self.output.len() * 2 + 4);
152
153 result.extend_from_slice(&self.state.to_le_bytes());
155
156 for &word in self.output.iter().rev() {
158 result.extend_from_slice(&word.to_le_bytes());
159 }
160
161 result
162 }
163
164 pub fn size_estimate(&self) -> usize {
166 self.output.len() * 2 + 4 }
168}
169
170impl Default for RansEncoder {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
176pub struct RansDecoder<'a> {
178 state: u32,
179 input: &'a [u8],
180 pos: usize,
181}
182
183impl<'a> RansDecoder<'a> {
184 pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
186 if input.len() < 4 {
187 anyhow::bail!("rANS input too short");
188 }
189
190 let state = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
192
193 Ok(Self {
194 state,
195 input,
196 pos: 4,
197 })
198 }
199
200 #[inline]
207 pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
208 let slot = self.state & (ANS_TOTAL - 1);
210
211 let mut lo = 0usize;
213 let mut hi = cdf.len() - 1;
214 while lo + 1 < hi {
215 let mid = (lo + hi) / 2;
216 if cdf[mid] <= slot {
217 lo = mid;
218 } else {
219 hi = mid;
220 }
221 }
222 let sym = lo;
223
224 let c_lo = cdf[sym];
225 let c_hi = cdf[sym + 1];
226 let freq = c_hi - c_lo;
227
228 self.state = freq * (self.state >> ANS_BITS) + slot - c_lo;
230
231 while self.state < ANS_LOW && self.pos + 1 < self.input.len() {
233 let word = u16::from_le_bytes([self.input[self.pos], self.input[self.pos + 1]]);
234 self.state = (self.state << 16) | (word as u32);
235 self.pos += 2;
236 }
237
238 Ok(sym)
239 }
240
241 pub fn decode_pdf(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
243 let cdf = quantize_pdf_to_rans_cdf(pdf);
244 self.decode(&cdf)
245 }
246}
247
248#[cfg(target_arch = "x86_64")]
253mod simd {
254 use super::*;
255
256 pub const RANS_LANES: usize = 8;
258
259 pub struct SimdRansEncoder {
261 states: [u32; RANS_LANES],
262 outputs: [Vec<u8>; RANS_LANES],
263 lane: usize,
264 }
265
266 impl SimdRansEncoder {
267 pub fn new() -> Self {
269 Self {
270 states: [ANS_LOW; RANS_LANES],
271 outputs: Default::default(),
272 lane: 0,
273 }
274 }
275
276 pub fn encode(&mut self, cdf: &Cdf) {
278 let freq = cdf.freq();
279 let lane = self.lane;
280 self.lane = (self.lane + 1) % RANS_LANES;
281
282 let state = &mut self.states[lane];
283 let output = &mut self.outputs[lane];
284
285 while *state >= (ANS_HIGH / cdf.total) * freq {
287 output.push(*state as u8);
288 *state >>= 8;
289 }
290
291 *state = ((*state / freq) * cdf.total) + (*state % freq) + cdf.lo;
293 }
294
295 pub fn finish(self) -> Vec<u8> {
297 let mut result = Vec::new();
298
299 for &s in self.states.iter().take(RANS_LANES) {
301 result.extend_from_slice(&s.to_le_bytes());
302 }
303
304 let max_len = self.outputs.iter().map(|v| v.len()).max().unwrap_or(0);
306
307 for pos in 0..max_len {
309 for lane in 0..RANS_LANES {
310 let out = &self.outputs[lane];
311 if pos < out.len() {
312 result.push(out[out.len() - 1 - pos]);
313 } else {
314 result.push(0);
315 }
316 }
317 }
318
319 result
320 }
321 }
322
323 impl Default for SimdRansEncoder {
324 fn default() -> Self {
325 Self::new()
326 }
327 }
328
329 pub struct SimdRansDecoder<'a> {
331 states: [u32; RANS_LANES],
332 input: &'a [u8],
333 pos: usize,
334 lane: usize,
335 }
336
337 impl<'a> SimdRansDecoder<'a> {
338 pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
340 if input.len() < RANS_LANES * 4 {
341 anyhow::bail!("SIMD rANS input too short");
342 }
343
344 let mut states = [0u32; RANS_LANES];
345 for (i, state) in states.iter_mut().enumerate() {
346 let offset = i * 4;
347 *state = u32::from_le_bytes([
348 input[offset],
349 input[offset + 1],
350 input[offset + 2],
351 input[offset + 3],
352 ]);
353 }
354
355 Ok(Self {
356 states,
357 input,
358 pos: RANS_LANES * 4,
359 lane: 0,
360 })
361 }
362
363 pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
365 let lane = self.lane;
366 self.lane = (self.lane + 1) % RANS_LANES;
367
368 let state = &mut self.states[lane];
369 let total = ANS_TOTAL;
370 let value = *state & (total - 1);
371
372 let mut lo = 0usize;
374 let mut hi = cdf.len() - 1;
375 while lo + 1 < hi {
376 let mid = (lo + hi) / 2;
377 if cdf[mid] <= value {
378 lo = mid;
379 } else {
380 hi = mid;
381 }
382 }
383 let sym = lo;
384
385 let c_lo = cdf[sym];
386 let c_hi = cdf[sym + 1];
387 let freq = c_hi - c_lo;
388
389 *state = freq * (*state >> ANS_BITS) + (*state & (total - 1)) - c_lo;
391
392 while *state < ANS_LOW {
394 let byte_idx = self.pos + lane;
396 if byte_idx < self.input.len() {
397 *state = (*state << 8) | (self.input[byte_idx] as u32);
398 }
399 self.pos += RANS_LANES;
400 }
401
402 Ok(sym)
403 }
404 }
405}
406#[cfg(target_arch = "x86_64")]
407pub use simd::*;
409
/// Lane count of the fallback coder, which wraps the single-state scalar coder.
#[cfg(not(target_arch = "x86_64"))]
pub const RANS_LANES: usize = 1;

/// Fallback for targets other than x86_64: forwards every call to [`RansEncoder`].
///
/// NOTE(review): this emits the scalar stream format (one 32-bit state plus 16-bit words).
/// The x86_64 build instead emits an 8-lane byte-interleaved layout, so streams are not
/// portable across architectures.
#[cfg(not(target_arch = "x86_64"))]
pub struct SimdRansEncoder {
    /// Scalar encoder doing all the work.
    inner: RansEncoder,
}

#[cfg(not(target_arch = "x86_64"))]
impl SimdRansEncoder {
    /// Creates a fallback encoder around a fresh [`RansEncoder`].
    pub fn new() -> Self {
        let inner = RansEncoder::new();
        Self { inner }
    }

    /// Encodes one symbol via [`RansEncoder::encode`].
    pub fn encode(&mut self, cdf: &Cdf) {
        RansEncoder::encode(&mut self.inner, cdf)
    }

    /// Serializes the stream via [`RansEncoder::finish`].
    pub fn finish(self) -> Vec<u8> {
        let Self { inner } = self;
        inner.finish()
    }
}

#[cfg(not(target_arch = "x86_64"))]
impl Default for SimdRansEncoder {
    fn default() -> Self {
        SimdRansEncoder::new()
    }
}
446
/// Fallback for targets other than x86_64: forwards every call to [`RansDecoder`].
#[cfg(not(target_arch = "x86_64"))]
pub struct SimdRansDecoder<'a> {
    /// Scalar decoder doing all the work.
    inner: RansDecoder<'a>,
}

#[cfg(not(target_arch = "x86_64"))]
impl<'a> SimdRansDecoder<'a> {
    /// Creates a fallback decoder over `input`.
    ///
    /// # Errors
    /// Propagates any error from [`RansDecoder::new`].
    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
        RansDecoder::new(input).map(|inner| Self { inner })
    }

    /// Decodes one symbol via [`RansDecoder::decode`].
    ///
    /// # Errors
    /// Propagates any error from [`RansDecoder::decode`].
    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
        RansDecoder::decode(&mut self.inner, cdf)
    }
}
467
/// Number of symbols per independently encoded block (128 Ki = 131072).
pub const BLOCK_SIZE: usize = 1 << 17;
474
475pub struct BlockedRansEncoder {
480 symbols: Vec<Cdf>,
482 blocks: Vec<Vec<u8>>,
484}
485
486impl BlockedRansEncoder {
487 pub fn new() -> Self {
489 Self {
490 symbols: Vec::with_capacity(BLOCK_SIZE),
491 blocks: Vec::new(),
492 }
493 }
494
495 pub fn encode(&mut self, cdf: Cdf) {
497 self.symbols.push(cdf);
498
499 if self.symbols.len() >= BLOCK_SIZE {
501 self.flush_block();
502 }
503 }
504
505 fn flush_block(&mut self) {
507 if self.symbols.is_empty() {
508 return;
509 }
510
511 let mut encoder = RansEncoder::new();
513 for cdf in self.symbols.iter().rev() {
514 encoder.encode(cdf);
515 }
516
517 let encoded = encoder.finish();
518 self.blocks.push(encoded);
519 self.symbols.clear();
520 }
521
522 pub fn finish(mut self) -> Vec<Vec<u8>> {
524 self.flush_block();
526 self.blocks
527 }
528}
529
530impl Default for BlockedRansEncoder {
531 fn default() -> Self {
532 Self::new()
533 }
534}
535
536pub struct BlockedRansDecoder<'a> {
538 blocks: Vec<&'a [u8]>,
539 current_block: usize,
540 symbols_remaining_in_block: usize,
541 total_symbols: usize,
542 decoder: Option<RansDecoder<'a>>,
543}
544
545impl<'a> BlockedRansDecoder<'a> {
546 pub fn new(blocks: Vec<&'a [u8]>, total_symbols: usize) -> anyhow::Result<Self> {
548 let expected_blocks = if total_symbols == 0 {
549 0
550 } else {
551 total_symbols.div_ceil(BLOCK_SIZE)
552 };
553 if blocks.len() != expected_blocks {
554 anyhow::bail!(
555 "blocked rANS expected {expected_blocks} blocks for {total_symbols} symbols, got {}",
556 blocks.len()
557 );
558 }
559 Ok(Self {
560 blocks,
561 current_block: 0,
562 symbols_remaining_in_block: 0,
563 total_symbols,
564 decoder: None,
565 })
566 }
567
568 #[inline]
569 fn open_block(&mut self, block_index: usize) -> anyhow::Result<()> {
570 if block_index >= self.blocks.len() {
571 anyhow::bail!("No more blocks to decode");
572 }
573 let consumed = block_index.saturating_mul(BLOCK_SIZE);
574 let remaining = self.total_symbols.saturating_sub(consumed);
575 self.current_block = block_index;
576 self.symbols_remaining_in_block = remaining.min(BLOCK_SIZE);
577 self.decoder = Some(RansDecoder::new(self.blocks[block_index])?);
578 Ok(())
579 }
580
581 pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
583 if self.symbols_remaining_in_block == 0 {
584 if self.decoder.is_some() {
585 self.open_block(self.current_block + 1)?;
586 } else {
587 self.open_block(0)?;
588 }
589 }
590
591 let sym = self
592 .decoder
593 .as_mut()
594 .expect("decoder initialized for current block")
595 .decode(cdf)?;
596 self.symbols_remaining_in_block = self.symbols_remaining_in_block.saturating_sub(1);
597 Ok(sym)
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `symbols` with the scalar coder against one static table (in reverse, as
    /// rANS requires), decodes them again and checks each symbol comes back in order.
    fn assert_scalar_roundtrip(pdf: &[f64], symbols: &[usize]) {
        let cdf_table = quantize_pdf_to_rans_cdf(pdf);

        let mut enc = RansEncoder::new();
        for &s in symbols.iter().rev() {
            enc.encode(&cdf_for_symbol(&cdf_table, s));
        }
        let encoded = enc.finish();

        let mut dec = RansDecoder::new(&encoded).unwrap();
        for (i, &expected) in symbols.iter().enumerate() {
            let got = dec.decode(&cdf_table).unwrap();
            assert_eq!(got, expected, "Symbol mismatch at index {i}");
        }
    }

    #[test]
    fn test_roundtrip_scalar() {
        assert_scalar_roundtrip(
            &[0.5, 0.3, 0.15, 0.05],
            &[0, 0, 1, 0, 2, 1, 0, 3, 0, 0, 1, 2],
        );
    }

    #[test]
    fn test_cdf_quantization() {
        let pdf = vec![0.25, 0.25, 0.25, 0.25];
        let cdf = quantize_pdf_to_rans_cdf(&pdf);

        assert_eq!(cdf.len(), pdf.len() + 1);
        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], ANS_TOTAL);

        // Every symbol, including the last one (cdf[4] - cdf[3]), needs a non-zero frequency.
        for (i, pair) in cdf.windows(2).enumerate() {
            assert!(pair[1] > pair[0], "symbol {i} has zero frequency");
        }
    }

    #[test]
    fn test_extreme_probabilities() {
        assert_scalar_roundtrip(
            &[0.99, 0.005, 0.003, 0.002],
            &[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 3],
        );
    }

    #[test]
    fn test_blocked_rans_roundtrip_across_block_boundary() {
        let pdf = vec![0.5, 0.25, 0.125, 0.125];
        let cdf = quantize_pdf_to_rans_cdf(&pdf);
        let symbols: Vec<usize> = (0..(BLOCK_SIZE + 17)).map(|i| i % pdf.len()).collect();

        let mut enc = BlockedRansEncoder::new();
        for &sym in &symbols {
            enc.encode(cdf_for_symbol(&cdf, sym));
        }
        let blocks = enc.finish();
        assert_eq!(blocks.len(), 2);
        let block_refs: Vec<&[u8]> = blocks.iter().map(Vec::as_slice).collect();

        let mut dec = BlockedRansDecoder::new(block_refs, symbols.len()).unwrap();
        for (i, &expected) in symbols.iter().enumerate() {
            let got = dec.decode(&cdf).unwrap();
            assert_eq!(got, expected, "blocked rANS mismatch at symbol index {i}");
        }

        // All blocks are consumed; asking for more must fail rather than invent symbols.
        assert!(dec.decode(&cdf).is_err());
    }

    #[test]
    fn test_blocked_rans_empty_stream() {
        let cdf = quantize_pdf_to_rans_cdf(&[0.5, 0.5]);

        let blocks = BlockedRansEncoder::new().finish();
        assert!(blocks.is_empty());

        let mut dec = BlockedRansDecoder::new(Vec::new(), 0).unwrap();
        assert!(dec.decode(&cdf).is_err());
    }
}
684}