/// Precision, in bits, of the quantized probability tables.
pub const ANS_BITS: u32 = 15;

/// Sum of the frequencies in every CDF table: `2^ANS_BITS`.
pub const ANS_TOTAL: u32 = 1u32 << ANS_BITS;

/// Lower bound of the scalar coder's state interval: the state a fresh
/// `RansEncoder` starts from, and the threshold below which `RansDecoder`
/// refills its state with another 16-bit word.
pub const ANS_LOW: u32 = 1u32 << ANS_BITS;

/// `2^31`; used by the lane-interleaved coder in `simd` to compute its
/// renormalization threshold.
pub const ANS_HIGH: u32 = 0x8000_0000;
/// One symbol's slice `[lo, hi)` of a cumulative frequency table whose
/// frequencies sum to `total`.
#[derive(Clone, Debug)]
pub struct Cdf {
    /// Cumulative frequency where the symbol's range starts (inclusive).
    pub lo: u32,
    /// Cumulative frequency where the symbol's range ends (exclusive).
    pub hi: u32,
    /// Sum of all frequencies in the table this range was taken from.
    pub total: u32,
}

impl Cdf {
    /// Builds a range from its bounds and the table total.
    #[inline]
    pub fn new(lo: u32, hi: u32, total: u32) -> Self {
        Cdf { total, hi, lo }
    }

    /// Width of the range, i.e. the symbol's quantized frequency.
    #[inline]
    pub fn freq(&self) -> u32 {
        let Cdf { lo, hi, .. } = *self;
        hi - lo
    }
}
50pub fn quantize_pdf_to_rans_cdf(pdf: &[f64]) -> Vec<u32> {
65 let mut cdf = vec![0u32; pdf.len() + 1];
66 let mut freqs = vec![0i64; pdf.len()];
67 quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut cdf, &mut freqs);
68 cdf
69}
70
71pub fn quantize_pdf_to_rans_cdf_with_buffer(
77 pdf: &[f64],
78 cdf_out: &mut [u32],
79 freq_buf: &mut [i64],
80) {
81 let n = pdf.len();
82 assert!(cdf_out.len() >= n + 1, "cdf buffer too small");
83 assert!(freq_buf.len() >= n, "frequency buffer too small");
84
85 let total = ANS_TOTAL as i64;
86 for i in 0..n {
87 freq_buf[i] = (pdf[i] * total as f64).round() as i64;
88 if pdf[i] > 0.0 && freq_buf[i] == 0 {
89 freq_buf[i] = 1;
90 } else if pdf[i] <= 0.0 {
91 freq_buf[i] = 0;
92 }
93 }
94
95 let sum: i64 = freq_buf[..n].iter().sum();
96 if sum > total {
97 let mut to_remove = sum - total;
98 while to_remove > 0 {
99 let mut removed = 0;
100 for i in (0..n).rev() {
101 if freq_buf[i] > 1 {
102 freq_buf[i] -= 1;
103 to_remove -= 1;
104 removed += 1;
105 if to_remove == 0 {
106 break;
107 }
108 }
109 }
110 if removed == 0 {
111 break;
112 }
113 }
114 } else if sum < total {
115 let mut to_add = total - sum;
116 while to_add > 0 {
117 let mut added = 0;
118 for i in 0..n {
119 if pdf[i] > 0.0 {
120 freq_buf[i] += 1;
121 to_add -= 1;
122 added += 1;
123 if to_add == 0 {
124 break;
125 }
126 }
127 }
128 if added == 0 {
129 for i in 0..n {
130 freq_buf[i] += 1;
131 to_add -= 1;
132 if to_add == 0 {
133 break;
134 }
135 }
136 }
137 }
138 }
139
140 cdf_out[0] = 0;
141 let mut cumsum = 0u32;
142 for i in 0..n {
143 cdf_out[i] = cumsum;
144 cumsum += freq_buf[i] as u32;
145 }
146 cdf_out[n] = cumsum;
147
148 debug_assert_eq!(cdf_out[n], ANS_TOTAL, "CDF total must equal ANS_TOTAL");
149 for i in 0..n {
150 if pdf[i] > 0.0 {
151 debug_assert!(
152 cdf_out[i + 1] > cdf_out[i],
153 "Symbol {} with p={} has zero frequency",
154 i,
155 pdf[i]
156 );
157 }
158 }
159}
160
161#[inline]
163pub fn cdf_for_symbol(cdf: &[u32], sym: usize) -> Cdf {
164 Cdf::new(cdf[sym], cdf[sym + 1], ANS_TOTAL)
165}
166
/// Scalar rANS encoder with a 32-bit state that renormalizes 16 bits at a time.
pub struct RansEncoder {
    /// Current coder state.
    state: u32,
    /// 16-bit words shifted out of the state, in emission order.
    output: Vec<u16>,
}
173impl RansEncoder {
174 pub fn new() -> Self {
176 Self {
177 state: ANS_LOW,
178 output: Vec::new(),
179 }
180 }
181
182 #[inline]
187 pub fn encode(&mut self, cdf: &Cdf) {
188 let freq = cdf.freq();
189 debug_assert!(freq > 0, "Symbol frequency must be > 0");
190
191 while self.state >= (freq << 16) {
195 self.output.push(self.state as u16);
196 self.state >>= 16;
197 }
198
199 let q = self.state / freq;
201 let r = self.state % freq;
202 self.state = (q << ANS_BITS) + r + cdf.lo;
203 }
204
205 pub fn encode_pdf(&mut self, pdf: &[f64], sym: usize) {
207 let cdf_table = quantize_pdf_to_rans_cdf(pdf);
208 let cdf = cdf_for_symbol(&cdf_table, sym);
209 self.encode(&cdf);
210 }
211
212 pub fn finish(self) -> Vec<u8> {
214 let mut result = Vec::with_capacity(self.output.len() * 2 + 4);
216
217 result.extend_from_slice(&self.state.to_le_bytes());
219
220 for &word in self.output.iter().rev() {
222 result.extend_from_slice(&word.to_le_bytes());
223 }
224
225 result
226 }
227
228 pub fn size_estimate(&self) -> usize {
230 self.output.len() * 2 + 4 }
232}
233
234impl Default for RansEncoder {
235 fn default() -> Self {
236 Self::new()
237 }
238}
239
/// Scalar rANS decoder that reads the byte stream produced by
/// `RansEncoder::finish`.
pub struct RansDecoder<'a> {
    /// Current coder state.
    state: u32,
    /// Complete encoded stream, including the 4-byte state header.
    input: &'a [u8],
    /// Byte offset of the next 16-bit word to read from `input`.
    pos: usize,
}
247impl<'a> RansDecoder<'a> {
248 pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
250 if input.len() < 4 {
251 anyhow::bail!("rANS input too short");
252 }
253
254 let state = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
256
257 Ok(Self {
258 state,
259 input,
260 pos: 4,
261 })
262 }
263
264 #[inline]
271 pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
272 let slot = self.state & (ANS_TOTAL - 1);
274
275 let mut lo = 0usize;
277 let mut hi = cdf.len() - 1;
278 while lo + 1 < hi {
279 let mid = (lo + hi) / 2;
280 if cdf[mid] <= slot {
281 lo = mid;
282 } else {
283 hi = mid;
284 }
285 }
286 let sym = lo;
287
288 let c_lo = cdf[sym];
289 let c_hi = cdf[sym + 1];
290 let freq = c_hi - c_lo;
291
292 self.state = freq * (self.state >> ANS_BITS) + slot - c_lo;
294
295 while self.state < ANS_LOW && self.pos + 1 < self.input.len() {
297 let word = u16::from_le_bytes([self.input[self.pos], self.input[self.pos + 1]]);
298 self.state = (self.state << 16) | (word as u32);
299 self.pos += 2;
300 }
301
302 Ok(sym)
303 }
304
305 pub fn decode_pdf(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
307 let cdf = quantize_pdf_to_rans_cdf(pdf);
308 self.decode(&cdf)
309 }
310}
311
312mod simd {
317 use super::*;
318 #[allow(unused_imports)]
319 use std::arch::x86_64::*;
320
321 pub const RANS_LANES: usize = 8;
323
324 pub struct SimdRansEncoder {
326 states: [u32; RANS_LANES],
327 outputs: [Vec<u8>; RANS_LANES],
328 lane: usize,
329 }
330
331 impl SimdRansEncoder {
332 pub fn new() -> Self {
334 Self {
335 states: [ANS_LOW; RANS_LANES],
336 outputs: Default::default(),
337 lane: 0,
338 }
339 }
340
341 pub fn encode(&mut self, cdf: &Cdf) {
343 let freq = cdf.freq();
344 let lane = self.lane;
345 self.lane = (self.lane + 1) % RANS_LANES;
346
347 let state = &mut self.states[lane];
348 let output = &mut self.outputs[lane];
349
350 while *state >= (ANS_HIGH / cdf.total) * freq {
352 output.push(*state as u8);
353 *state >>= 8;
354 }
355
356 *state = ((*state / freq) * cdf.total) + (*state % freq) + cdf.lo;
358 }
359
360 pub fn finish(self) -> Vec<u8> {
362 let mut result = Vec::new();
363
364 for i in 0..RANS_LANES {
366 let s = self.states[i];
367 result.extend_from_slice(&s.to_le_bytes());
368 }
369
370 let max_len = self.outputs.iter().map(|v| v.len()).max().unwrap_or(0);
372
373 for pos in 0..max_len {
375 for lane in 0..RANS_LANES {
376 let out = &self.outputs[lane];
377 if pos < out.len() {
378 result.push(out[out.len() - 1 - pos]);
379 } else {
380 result.push(0);
381 }
382 }
383 }
384
385 result
386 }
387 }
388
389 impl Default for SimdRansEncoder {
390 fn default() -> Self {
391 Self::new()
392 }
393 }
394
395 pub struct SimdRansDecoder<'a> {
397 states: [u32; RANS_LANES],
398 input: &'a [u8],
399 pos: usize,
400 lane: usize,
401 }
402
403 impl<'a> SimdRansDecoder<'a> {
404 pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
406 if input.len() < RANS_LANES * 4 {
407 anyhow::bail!("SIMD rANS input too short");
408 }
409
410 let mut states = [0u32; RANS_LANES];
411 for i in 0..RANS_LANES {
412 let offset = i * 4;
413 states[i] = u32::from_le_bytes([
414 input[offset],
415 input[offset + 1],
416 input[offset + 2],
417 input[offset + 3],
418 ]);
419 }
420
421 Ok(Self {
422 states,
423 input,
424 pos: RANS_LANES * 4,
425 lane: 0,
426 })
427 }
428
429 pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
431 let lane = self.lane;
432 self.lane = (self.lane + 1) % RANS_LANES;
433
434 let state = &mut self.states[lane];
435 let total = ANS_TOTAL;
436 let value = *state & (total - 1);
437
438 let mut lo = 0usize;
440 let mut hi = cdf.len() - 1;
441 while lo + 1 < hi {
442 let mid = (lo + hi) / 2;
443 if cdf[mid] <= value {
444 lo = mid;
445 } else {
446 hi = mid;
447 }
448 }
449 let sym = lo;
450
451 let c_lo = cdf[sym];
452 let c_hi = cdf[sym + 1];
453 let freq = c_hi - c_lo;
454
455 *state = freq * (*state >> ANS_BITS) + (*state & (total - 1)) - c_lo;
457
458 while *state < ANS_LOW {
460 let byte_idx = self.pos + lane;
462 if byte_idx < self.input.len() {
463 *state = (*state << 8) | (self.input[byte_idx] as u32);
464 }
465 self.pos += RANS_LANES;
466 }
467
468 Ok(sym)
469 }
470 }
471}
472pub use simd::*;
473
474pub const BLOCK_SIZE: usize = 128 * 1024;
483
484pub struct BlockedRansEncoder {
489 symbols: Vec<Cdf>,
491 blocks: Vec<Vec<u8>>,
493}
494
495impl BlockedRansEncoder {
496 pub fn new() -> Self {
497 Self {
498 symbols: Vec::with_capacity(BLOCK_SIZE),
499 blocks: Vec::new(),
500 }
501 }
502
503 pub fn encode(&mut self, cdf: Cdf) {
505 self.symbols.push(cdf);
506
507 if self.symbols.len() >= BLOCK_SIZE {
509 self.flush_block();
510 }
511 }
512
513 fn flush_block(&mut self) {
515 if self.symbols.is_empty() {
516 return;
517 }
518
519 let mut encoder = RansEncoder::new();
521 for cdf in self.symbols.iter().rev() {
522 encoder.encode(cdf);
523 }
524
525 let encoded = encoder.finish();
526 self.blocks.push(encoded);
527 self.symbols.clear();
528 }
529
530 pub fn finish(mut self) -> Vec<Vec<u8>> {
532 self.flush_block();
534 self.blocks
535 }
536}
537
538impl Default for BlockedRansEncoder {
539 fn default() -> Self {
540 Self::new()
541 }
542}
543
544pub struct BlockedRansDecoder<'a> {
546 blocks: Vec<&'a [u8]>,
547 current_block: usize,
548 decoder: Option<RansDecoder<'a>>,
549}
550
551impl<'a> BlockedRansDecoder<'a> {
552 pub fn new(blocks: Vec<&'a [u8]>) -> Self {
554 Self {
555 blocks,
556 current_block: 0,
557 decoder: None,
558 }
559 }
560
561 pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
563 if self.decoder.is_none() {
565 if self.current_block >= self.blocks.len() {
566 anyhow::bail!("No more blocks to decode");
567 }
568 self.decoder = Some(RansDecoder::new(self.blocks[self.current_block])?);
569 }
570
571 match self.decoder.as_mut().unwrap().decode(cdf) {
573 Ok(sym) => Ok(sym),
574 Err(_) => {
575 self.current_block += 1;
577 if self.current_block >= self.blocks.len() {
578 anyhow::bail!("All blocks exhausted");
579 }
580 self.decoder = Some(RansDecoder::new(self.blocks[self.current_block])?);
581 self.decoder.as_mut().unwrap().decode(cdf)
582 }
583 }
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_roundtrip_scalar() {
        let pdf = [0.5, 0.3, 0.15, 0.05];
        let symbols: [usize; 12] = [0, 0, 1, 0, 2, 1, 0, 3, 0, 0, 1, 2];
        let table = quantize_pdf_to_rans_cdf(&pdf);

        // rANS pops symbols in the opposite order to the one they were pushed in.
        let mut encoder = RansEncoder::new();
        symbols
            .iter()
            .rev()
            .map(|&s| cdf_for_symbol(&table, s))
            .for_each(|range| encoder.encode(&range));
        let bytes = encoder.finish();

        let mut decoder = RansDecoder::new(&bytes).unwrap();
        for &expected in symbols.iter() {
            let got = decoder.decode(&table).unwrap();
            assert_eq!(got, expected, "Symbol mismatch");
        }
    }

    #[test]
    fn test_cdf_quantization() {
        let cdf = quantize_pdf_to_rans_cdf(&[0.25; 4]);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], ANS_TOTAL);

        // Each of the first three symbols must own a non-empty range.
        for pair in cdf[..4].windows(2) {
            let delta = pair[1] - pair[0];
            assert!(delta > 0);
        }
    }

    #[test]
    fn test_extreme_probabilities() {
        let pdf = [0.99, 0.005, 0.003, 0.002];
        let symbols: [usize; 11] = [0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 3];
        let table = quantize_pdf_to_rans_cdf(&pdf);

        let mut encoder = RansEncoder::new();
        for range in symbols.iter().rev().map(|&s| cdf_for_symbol(&table, s)) {
            encoder.encode(&range);
        }
        let bytes = encoder.finish();

        let mut decoder = RansDecoder::new(&bytes).unwrap();
        for &expected in symbols.iter() {
            let got = decoder.decode(&table).unwrap();
            assert_eq!(got, expected);
        }
    }
}
650}