1use std::io::Write;
15
/// Total count that quantized CDFs are scaled to (`2^30`); `quantize_pdf_to_cdf_inplace`
/// pins the last CDF entry to this value.
pub const CDF_TOTAL: u32 = 1 << 30;

/// Exponent applied to `BASE` to size the coder state: the encoder and decoder derive their
/// interval mask and thresholds from it, and `p_min` derives the probability floor from it.
const PRECISION: u32 = 32;

/// Radix that is raised to `PRECISION`-based powers when building the coder thresholds.
const BASE: u64 = 2;
24
25#[inline]
28pub fn p_min() -> f64 {
29 2.0f64.powi(-(PRECISION as i32 - 3))
31}
32
33pub fn softmax_pdf_floor(logits: &[f32], vocab_size: usize) -> Vec<f64> {
42 let mut result = vec![0f64; vocab_size];
43 softmax_pdf_floor_inplace(logits, vocab_size, &mut result);
44 result
45}
46
/// Writes the softmax of the first `vocab_size` logits into `pdf_out[..vocab_size]`.
///
/// The maximum logit is subtracted before exponentiating so that the largest term is
/// `exp(0)`. When the accumulated sum is not strictly positive (this includes a NaN sum),
/// every entry is set to the uniform probability `1 / max(vocab_size, 1)` instead.
///
/// # Panics
/// Panics if `logits` or `pdf_out` holds fewer than `vocab_size` elements.
pub fn softmax_pdf_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
    debug_assert!(pdf_out.len() >= vocab_size);

    let peak = logits
        .iter()
        .take(vocab_size)
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);

    // First pass: unnormalized weights plus their running total.
    let mut total = 0.0f64;
    for idx in 0..vocab_size {
        let weight = f64::from(logits[idx] - peak).exp();
        pdf_out[idx] = weight;
        total += weight;
    }

    // The loop above indexed every position below `vocab_size`, so this slice is in bounds.
    let filled = &mut pdf_out[..vocab_size];
    if total > 0.0 {
        let scale = 1.0 / total;
        filled.iter_mut().for_each(|p| *p *= scale);
    } else {
        let uniform = 1.0 / (vocab_size.max(1) as f64);
        filled.fill(uniform);
    }
}
74
75pub fn softmax_pdf_floor_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
82 #[cfg(target_arch = "x86_64")]
84 if vocab_size == 256 {
85 unsafe { softmax_pdf_floor_avx2(logits, pdf_out) };
86 return;
87 }
88
89 let p_min_val = p_min();
90
91 let max = logits
93 .iter()
94 .take(vocab_size)
95 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
96
97 let mut sum = 0.0f64;
99 for i in 0..vocab_size {
100 let e = ((logits[i] - max) as f64).exp();
101 pdf_out[i] = e;
102 sum += e;
103 }
104
105 for i in 0..vocab_size {
107 pdf_out[i] = (pdf_out[i] / sum).max(p_min_val);
108 }
109
110 let norm: f64 = pdf_out[..vocab_size].iter().sum();
112 for i in 0..vocab_size {
113 pdf_out[i] /= norm;
114 }
115}
116
117#[cfg(target_arch = "x86_64")]
119#[target_feature(enable = "avx2,fma")]
120unsafe fn softmax_pdf_floor_avx2(logits: &[f32], pdf_out: &mut [f64]) {
121 use std::arch::x86_64::*;
122
123 const N: usize = 256;
124 let p_min_val = p_min();
125
126 let mut max_v = _mm256_set1_ps(f32::NEG_INFINITY);
128 for i in (0..N).step_by(8) {
129 let v = _mm256_loadu_ps(logits.as_ptr().add(i));
130 max_v = _mm256_max_ps(max_v, v);
131 }
132 let hi = _mm256_extractf128_ps(max_v, 1);
134 let lo = _mm256_castps256_ps128(max_v);
135 let max128 = _mm_max_ps(hi, lo);
136 let max64 = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
137 let max32 = _mm_max_ss(max64, _mm_shuffle_ps(max64, max64, 0x55));
138 let max = _mm_cvtss_f32(max32);
139 let max_v = _mm256_set1_ps(max);
140
141 let mut sum0 = _mm256_setzero_pd();
144 let mut sum1 = _mm256_setzero_pd();
145 for i in (0..N).step_by(8) {
146 let v = _mm256_loadu_ps(logits.as_ptr().add(i));
147 let centered = _mm256_sub_ps(v, max_v);
148 let exp_vals = exp256_ps_fast(centered);
149
150 let lo4 = _mm256_castps256_ps128(exp_vals);
152 let hi4 = _mm256_extractf128_ps(exp_vals, 1);
153 let d_lo = _mm256_cvtps_pd(lo4);
154 let d_hi = _mm256_cvtps_pd(hi4);
155
156 _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i), d_lo);
158 _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i + 4), d_hi);
159 sum0 = _mm256_add_pd(sum0, d_lo);
160 sum1 = _mm256_add_pd(sum1, d_hi);
161 }
162
163 let sum01 = _mm256_add_pd(sum0, sum1);
165 let hi128 = _mm256_extractf128_pd(sum01, 1);
166 let lo128 = _mm256_castpd256_pd128(sum01);
167 let sum2 = _mm_add_pd(hi128, lo128);
168 let sum1_v = _mm_unpackhi_pd(sum2, sum2);
169 let sum_v = _mm_add_sd(sum2, sum1_v);
170 let sum = _mm_cvtsd_f64(sum_v);
171
172 let inv_sum = 1.0 / sum;
174 let p_min_v = _mm256_set1_pd(p_min_val);
175 let inv_sum_v = _mm256_set1_pd(inv_sum);
176
177 let mut new_sum0 = _mm256_setzero_pd();
178 let mut new_sum1 = _mm256_setzero_pd();
179 for i in (0..N).step_by(8) {
180 let v0 = _mm256_loadu_pd(pdf_out.as_ptr().add(i));
181 let v1 = _mm256_loadu_pd(pdf_out.as_ptr().add(i + 4));
182 let normed0 = _mm256_mul_pd(v0, inv_sum_v);
183 let normed1 = _mm256_mul_pd(v1, inv_sum_v);
184 let floored0 = _mm256_max_pd(normed0, p_min_v);
185 let floored1 = _mm256_max_pd(normed1, p_min_v);
186 _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i), floored0);
187 _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i + 4), floored1);
188 new_sum0 = _mm256_add_pd(new_sum0, floored0);
189 new_sum1 = _mm256_add_pd(new_sum1, floored1);
190 }
191
192 let ns01 = _mm256_add_pd(new_sum0, new_sum1);
194 let ns_hi = _mm256_extractf128_pd(ns01, 1);
195 let ns_lo = _mm256_castpd256_pd128(ns01);
196 let ns2 = _mm_add_pd(ns_hi, ns_lo);
197 let ns1 = _mm_unpackhi_pd(ns2, ns2);
198 let ns_v = _mm_add_sd(ns2, ns1);
199 let new_sum = _mm_cvtsd_f64(ns_v);
200
201 let inv_norm = 1.0 / new_sum;
203 let inv_norm_v = _mm256_set1_pd(inv_norm);
204 for i in (0..N).step_by(8) {
205 let v0 = _mm256_loadu_pd(pdf_out.as_ptr().add(i));
206 let v1 = _mm256_loadu_pd(pdf_out.as_ptr().add(i + 4));
207 let result0 = _mm256_mul_pd(v0, inv_norm_v);
208 let result1 = _mm256_mul_pd(v1, inv_norm_v);
209 _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i), result0);
210 _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i + 4), result1);
211 }
212}
213
/// Approximates `e^x` for each of the eight f32 lanes.
///
/// The input is clamped to `[-88, 88]`, multiplied by `log2(e)`, and split into a floor and
/// a fractional remainder. A cubic polynomial in the remainder is evaluated with fused
/// multiply-adds, and the floor is turned into a power of two by building the f32 exponent
/// bits directly (`(n + 127) << 23`).
///
/// # Safety
/// Uses AVX, AVX2 and FMA intrinsics; the caller must make sure the CPU supports them.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn exp256_ps_fast(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;

    // Clamp first so the exponent bits built below stay within the 8-bit exponent field.
    let upper = _mm256_set1_ps(88.0);
    let lower = _mm256_set1_ps(-88.0);
    let clamped = _mm256_max_ps(_mm256_min_ps(x, upper), lower);

    // e^x = 2^(x * log2 e); separate the integral and fractional parts of the exponent.
    let scaled = _mm256_mul_ps(clamped, _mm256_set1_ps(1.442695041));
    let whole = _mm256_floor_ps(scaled);
    let frac = _mm256_sub_ps(scaled, whole);

    // Horner evaluation of the cubic approximation of 2^frac.
    let mut acc = _mm256_fmadd_ps(
        frac,
        _mm256_set1_ps(0.0558263180532956),
        _mm256_set1_ps(0.240226506959101),
    );
    acc = _mm256_fmadd_ps(frac, acc, _mm256_set1_ps(0.693147180559945));
    acc = _mm256_fmadd_ps(frac, acc, _mm256_set1_ps(1.0));

    // 2^whole: bias the integer exponent and shift it into the f32 exponent field.
    let biased = _mm256_add_epi32(_mm256_cvtps_epi32(whole), _mm256_set1_epi32(127));
    let two_pow_whole = _mm256_castsi256_ps(_mm256_slli_epi32(biased, 23));

    _mm256_mul_ps(acc, two_pow_whole)
}
253
/// Returns the softmax of the first `vocab_size` logits as a freshly allocated vector.
///
/// The maximum logit is subtracted before exponentiating. When the sum of the exponentials
/// is not strictly positive — which includes a NaN sum, e.g. when every logit is `-inf` or a
/// logit is NaN — the result is the uniform distribution `1 / vocab_size`, matching the
/// fallback of `softmax_pdf_inplace`.
///
/// # Panics
/// Panics if `logits` holds fewer than `vocab_size` elements.
pub fn softmax_pdf(logits: &[f32], vocab_size: usize) -> Vec<f64> {
    let logits = &logits[..vocab_size];
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Exponentials are written straight into the result buffer and normalized in place.
    let mut pdf = vec![0f64; vocab_size];
    let mut sum = 0.0f64;
    for (slot, &logit) in pdf.iter_mut().zip(logits) {
        let e = ((logit - max) as f64).exp();
        *slot = e;
        sum += e;
    }

    // `sum > 0.0` is false for NaN as well as for zero, so both take the uniform fallback.
    if sum > 0.0 {
        for p in pdf.iter_mut() {
            *p /= sum;
        }
    } else {
        pdf.fill(1.0 / (vocab_size as f64));
    }
    pdf
}
281
282pub fn quantize_pdf_to_cdf(pdf: &[f64]) -> Vec<u32> {
294 let mut cdf = vec![0u32; pdf.len() + 1];
295 quantize_pdf_to_cdf_inplace(pdf, &mut cdf);
296 cdf
297}
298
299#[inline]
303pub fn quantize_pdf_to_cdf_inplace(pdf: &[f64], cdf_out: &mut [u32]) {
304 let n = pdf.len();
305 debug_assert!(cdf_out.len() >= n + 1, "cdf buffer too small");
306
307 unsafe {
308 *cdf_out.get_unchecked_mut(0) = 0;
309 let scale = CDF_TOTAL as f64;
310 let mut acc = 0.0f64;
311 let mut prev = 0u32;
312
313 for i in 0..n {
314 acc += *pdf.get_unchecked(i);
315 let v = (acc * scale) as u32;
316 let v = v.max(prev);
317 *cdf_out.get_unchecked_mut(i + 1) = v;
318 prev = v;
319 }
320 *cdf_out.get_unchecked_mut(n) = CDF_TOTAL;
321 }
322}
323
324pub struct ArithmeticEncoder<W: Write> {
326 b_to_pm1: u64,
327 b_to_pm2: u64,
328 mask: u64,
329 low: u64,
330 high: u64,
331 carry_run: u64,
332 out: W,
333 bit_buffer: u8,
334 bit_count: u8,
335 bytes_out: u64,
336}
337
338impl<W: Write> ArithmeticEncoder<W> {
339 pub fn new(out: W) -> Self {
341 let b_to_pm1 = BASE.pow(PRECISION - 1);
342 let b_to_pm2 = BASE.pow(PRECISION - 2);
343 let mask = BASE.pow(PRECISION) - 1;
344 Self {
345 b_to_pm1,
346 b_to_pm2,
347 mask,
348 low: 0,
349 high: mask,
350 carry_run: 0,
351 out,
352 bit_buffer: 0,
353 bit_count: 0,
354 bytes_out: 0,
355 }
356 }
357
358 #[inline]
359 fn write_byte(&mut self, byte: u8) -> anyhow::Result<()> {
360 self.out.write_all(&[byte])?;
361 self.bytes_out += 1;
362 Ok(())
363 }
364
365 #[inline]
366 fn put_bit_internal(&mut self, bit: u8) -> anyhow::Result<()> {
367 self.bit_buffer = (self.bit_buffer << 1) | (bit & 1);
368 self.bit_count += 1;
369 if self.bit_count == 8 {
370 let b = self.bit_buffer;
371 self.write_byte(b)?;
372 self.bit_buffer = 0;
373 self.bit_count = 0;
374 }
375 Ok(())
376 }
377
378 #[inline]
379 fn put_bit(&mut self, bit: u8) -> anyhow::Result<()> {
380 self.put_bit_internal(bit)?;
381 while self.carry_run > 0 {
382 self.put_bit_internal((!bit) & 1)?;
383 self.carry_run -= 1;
384 }
385 Ok(())
386 }
387
388 pub fn encode_counts(&mut self, c_lo: u64, c_hi: u64, total: u64) -> anyhow::Result<()> {
395 let range = (self.high - self.low + 1) as u128;
396 let total_u = total as u128;
397 let c_lo_u = c_lo as u128;
398 let c_hi_u = c_hi as u128;
399 let low_u = self.low as u128;
400
401 let new_low = low_u + (range * c_lo_u) / total_u;
402 let new_high = low_u + (range * c_hi_u) / total_u - 1;
403
404 self.low = (new_low & (self.mask as u128)) as u64;
405 self.high = (new_high & (self.mask as u128)) as u64;
406
407 loop {
408 if self.high < self.b_to_pm1 {
409 self.put_bit(0)?;
410 } else if self.low >= self.b_to_pm1 {
411 self.put_bit(1)?;
412 self.low -= self.b_to_pm1;
413 self.high -= self.b_to_pm1;
414 } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
415 self.carry_run += 1;
416 self.low -= self.b_to_pm2;
417 self.high -= self.b_to_pm2;
418 } else {
419 break;
420 }
421 self.low = (self.low << 1) & self.mask;
422 self.high = ((self.high << 1) & self.mask) | 1;
423 }
424 Ok(())
425 }
426
427 pub fn encode_symbol(&mut self, pdf: &[f64], sym: usize) -> anyhow::Result<()> {
431 let cdf = quantize_pdf_to_cdf(pdf);
432 let c_lo = cdf[sym] as u64;
433 let c_hi = cdf[sym + 1] as u64;
434 self.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)
435 }
436
437 pub fn finish(mut self) -> anyhow::Result<W> {
441 self.carry_run += 1;
442 if self.low < self.b_to_pm2 {
443 self.put_bit(0)?;
444 } else {
445 self.put_bit(1)?;
446 }
447 if self.bit_count > 0 {
449 let remaining = 8 - self.bit_count;
450 for _ in 0..remaining {
451 self.put_bit_internal(0)?;
452 }
453 }
454 Ok(self.out)
455 }
456
457 #[inline]
459 pub fn bytes_written(&self) -> u64 {
460 self.bytes_out
461 }
462}
463
464pub struct ArithmeticDecoder<'a> {
466 b_to_pm1: u64,
467 b_to_pm2: u64,
468 mask: u64,
469 low: u64,
470 high: u64,
471 code: u64,
472 input: &'a [u8],
473 byte_pos: usize,
474 bit_pos: u8,
475}
476
477impl<'a> ArithmeticDecoder<'a> {
478 pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
480 let b_to_pm1 = BASE.pow(PRECISION - 1);
481 let b_to_pm2 = BASE.pow(PRECISION - 2);
482 let mask = BASE.pow(PRECISION) - 1;
483
484 let mut s = Self {
485 b_to_pm1,
486 b_to_pm2,
487 mask,
488 low: 0,
489 high: mask,
490 code: 0,
491 input,
492 byte_pos: 0,
493 bit_pos: 0,
494 };
495
496 for _ in 0..PRECISION {
498 s.code = (s.code << 1) | (s.get_bit().unwrap_or(1) as u64);
499 }
500
501 Ok(s)
502 }
503
504 #[inline]
505 fn get_bit(&mut self) -> Option<u8> {
506 if self.byte_pos >= self.input.len() {
507 return None;
508 }
509 let byte = self.input[self.byte_pos];
510 let bit = (byte >> (7 - self.bit_pos)) & 1;
511 self.bit_pos += 1;
512 if self.bit_pos >= 8 {
513 self.bit_pos = 0;
514 self.byte_pos += 1;
515 }
516 Some(bit)
517 }
518
519 pub fn decode_symbol_counts(&mut self, cdf: &[u32], total: u32) -> anyhow::Result<usize> {
528 let total_u = total as u64;
529 let range = self.high - self.low + 1;
530 let value =
531 (((self.code - self.low + 1) as u128 * (total_u as u128)) - 1) / (range as u128);
532 let value_u = value as u32;
533
534 let mut lo = 0usize;
536 let mut hi = cdf.len() - 1;
537 while lo + 1 < hi {
538 let mid = (lo + hi) / 2;
539 if cdf[mid] <= value_u {
540 lo = mid;
541 } else {
542 hi = mid;
543 }
544 }
545 let s = lo;
546 let c_lo = cdf[s] as u64;
547 let c_hi = cdf[s + 1] as u64;
548
549 let range = (self.high - self.low + 1) as u128;
551 let low_u = self.low as u128;
552 let total_u128 = total as u128;
553 let new_low = low_u + (range * (c_lo as u128)) / total_u128;
554 let new_high = low_u + (range * (c_hi as u128)) / total_u128 - 1;
555
556 self.low = new_low as u64;
557 self.high = new_high as u64;
558
559 loop {
561 if self.high < self.b_to_pm1 {
562 } else if self.low >= self.b_to_pm1 {
564 self.low -= self.b_to_pm1;
565 self.high -= self.b_to_pm1;
566 self.code -= self.b_to_pm1;
567 } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
568 self.low -= self.b_to_pm2;
569 self.high -= self.b_to_pm2;
570 self.code -= self.b_to_pm2;
571 } else {
572 break;
573 }
574 self.low = (self.low << 1) & self.mask;
575 self.high = ((self.high << 1) & self.mask) | 1;
576 self.code = ((self.code << 1) & self.mask) | (self.get_bit().unwrap_or(1) as u64);
577 }
578
579 Ok(s)
580 }
581
582 pub fn decode_symbol(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
586 let cdf = quantize_pdf_to_cdf(pdf);
587 self.decode_symbol_counts(&cdf, CDF_TOTAL)
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `symbols` under a fixed `pdf`, decodes the resulting bytes, and checks that
    /// every symbol comes back unchanged.
    fn assert_roundtrip(pdf: &[f64], symbols: &[usize]) {
        let mut sink = Vec::new();
        let mut encoder = ArithmeticEncoder::new(&mut sink);
        for &sym in symbols {
            encoder.encode_symbol(pdf, sym).unwrap();
        }
        let encoded = encoder.finish().unwrap().to_vec();

        let mut decoder = ArithmeticDecoder::new(&encoded).unwrap();
        for &expected in symbols {
            let got = decoder.decode_symbol(pdf).unwrap();
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn test_roundtrip_uniform() {
        assert_roundtrip(&[0.25, 0.25, 0.25, 0.25], &[0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_roundtrip_skewed() {
        assert_roundtrip(&[0.7, 0.2, 0.05, 0.05], &[0, 0, 0, 1, 0, 2, 0, 3, 0, 0]);
    }

    #[test]
    fn test_softmax_pdf_floor() {
        let logits = [1.0f32, 2.0, 3.0, 4.0];
        let pdf = softmax_pdf_floor(&logits, logits.len());

        // The floored distribution still sums to one.
        let total: f64 = pdf.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);

        // No entry falls below the floor.
        let floor = p_min();
        assert!(pdf.iter().all(|&p| p >= floor));
    }

    #[test]
    fn test_cdf_monotonic() {
        let cdf = quantize_pdf_to_cdf(&[0.1, 0.2, 0.3, 0.4]);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], CDF_TOTAL);

        // Cumulative counts never decrease.
        assert!(cdf.windows(2).all(|pair| pair[1] >= pair[0]));
    }
}