1use std::io::Write;
15use wide::f32x8;
16use wide::f64x4;
17
/// Total of every integer CDF produced by this module: `2^(PRECISION - 2)`, i.e. `2^30`.
///
/// This equals the coder's quarter-range. After renormalisation in the encoder/decoder,
/// `high - low + 1` always exceeds a quarter of the range, so every symbol with a non-zero
/// count out of this total keeps a non-empty interval.
pub const CDF_TOTAL: u32 = 1 << (PRECISION - 2);

/// Bit width of the arithmetic coder's `low` / `high` / `code` registers.
const PRECISION: u32 = 32;

/// Radix of the coder's registers; the coder emits binary digits.
const BASE: u64 = 2;
26
27#[inline]
30pub fn p_min() -> f64 {
31 2.0f64.powi(-(PRECISION as i32 - 3))
33}
34
35pub fn softmax_pdf_floor(logits: &[f32], vocab_size: usize) -> Vec<f64> {
44 let mut result = vec![0f64; vocab_size];
45 softmax_pdf_floor_inplace(logits, vocab_size, &mut result);
46 result
47}
48
/// Writes `softmax(logits[..vocab_size])` into `pdf_out[..vocab_size]`.
///
/// The maximum logit is subtracted before exponentiating (in `f64`). If the exponentials do
/// not sum to a positive number (the sum is NaN, e.g. when every logit is `-inf`, or the
/// vocabulary is empty), every entry is set to the uniform probability instead.
/// Entries of `pdf_out` past `vocab_size` are left untouched.
///
/// # Panics
/// Panics if `logits` or `pdf_out` holds fewer than `vocab_size` elements.
pub fn softmax_pdf_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
    debug_assert!(pdf_out.len() >= vocab_size);

    let peak = logits
        .iter()
        .take(vocab_size)
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);

    let logits = &logits[..vocab_size];
    let out = &mut pdf_out[..vocab_size];

    let mut total = 0.0f64;
    for (slot, &logit) in out.iter_mut().zip(logits) {
        let e = f64::from(logit - peak).exp();
        *slot = e;
        total += e;
    }

    if total > 0.0 {
        let scale = 1.0 / total;
        out.iter_mut().for_each(|p| *p *= scale);
    } else {
        let uniform = 1.0 / (vocab_size.max(1) as f64);
        out.iter_mut().for_each(|p| *p = uniform);
    }
}
79
80pub fn softmax_pdf_floor_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
87 if vocab_size == 256 && logits.len() >= 256 && pdf_out.len() >= 256 {
89 softmax_pdf_floor_wide_256(logits, pdf_out);
90 return;
91 }
92
93 let p_min_val = p_min();
94
95 let max = logits
97 .iter()
98 .take(vocab_size)
99 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
100
101 let mut sum = 0.0f64;
103 for i in 0..vocab_size {
104 let e = ((logits[i] - max) as f64).exp();
105 pdf_out[i] = e;
106 sum += e;
107 }
108
109 for value in pdf_out.iter_mut().take(vocab_size) {
111 *value = (*value / sum).max(p_min_val);
112 }
113
114 let norm: f64 = pdf_out[..vocab_size].iter().sum();
116 for value in pdf_out.iter_mut().take(vocab_size) {
117 *value /= norm;
118 }
119}
120
121#[inline]
123fn softmax_pdf_floor_wide_256(logits: &[f32], pdf_out: &mut [f64]) {
124 const N: usize = 256;
125 debug_assert!(logits.len() >= N);
126 debug_assert!(pdf_out.len() >= N);
127 let p_min_val = p_min();
128
129 #[inline(always)]
130 unsafe fn load8(ptr: *const f32) -> f32x8 {
131 ptr.cast::<f32x8>().read_unaligned()
132 }
133
134 let mut max_v = f32x8::splat(f32::NEG_INFINITY);
135 for i in (0..N).step_by(8) {
136 let v = unsafe { load8(logits.as_ptr().add(i)) };
137 max_v = max_v.fast_max(v);
138 }
139 let mut max = f32::NEG_INFINITY;
140 for x in max_v.to_array() {
141 max = max.max(x);
142 }
143 let max_v = f32x8::splat(max);
144
145 let mut sum4 = f64x4::ZERO;
146 for (chunk_idx, out_chunk) in pdf_out[..N].chunks_exact_mut(8).enumerate() {
147 let i = chunk_idx * 8;
148 let centered = unsafe { load8(logits.as_ptr().add(i)) } - max_v;
149 let exp_vals = centered.exp().to_array();
150 let v0 = f64x4::new([
151 exp_vals[0] as f64,
152 exp_vals[1] as f64,
153 exp_vals[2] as f64,
154 exp_vals[3] as f64,
155 ]);
156 let v1 = f64x4::new([
157 exp_vals[4] as f64,
158 exp_vals[5] as f64,
159 exp_vals[6] as f64,
160 exp_vals[7] as f64,
161 ]);
162 sum4 += v0 + v1;
163
164 let lanes0 = v0.to_array();
165 let lanes1 = v1.to_array();
166 out_chunk[..4].copy_from_slice(&lanes0);
167 out_chunk[4..].copy_from_slice(&lanes1);
168 }
169
170 let sum_lanes = sum4.to_array();
171 let sum = sum_lanes[0] + sum_lanes[1] + sum_lanes[2] + sum_lanes[3];
172
173 let inv_sum = 1.0 / sum;
174 let mut norm4 = f64x4::ZERO;
175 let inv_sum4 = f64x4::splat(inv_sum);
176 let min4 = f64x4::splat(p_min_val);
177
178 for chunk in pdf_out[..N].chunks_exact_mut(4) {
179 let vals = f64x4::new([chunk[0], chunk[1], chunk[2], chunk[3]]);
180 let mut v = vals * inv_sum4;
181 v = v.max(min4);
182
183 let lanes = v.to_array();
184 chunk.copy_from_slice(&lanes);
185 norm4 += v;
186 }
187
188 let norm_lanes = norm4.to_array();
189 let norm = norm_lanes[0] + norm_lanes[1] + norm_lanes[2] + norm_lanes[3];
190
191 let inv_norm = 1.0 / norm;
192 let inv_norm4 = f64x4::splat(inv_norm);
193 for chunk in pdf_out[..N].chunks_exact_mut(4) {
194 let vals = f64x4::new([chunk[0], chunk[1], chunk[2], chunk[3]]);
195 let out = vals * inv_norm4;
196 let lanes = out.to_array();
197 chunk.copy_from_slice(&lanes);
198 }
199}
200
/// Returns `softmax(logits[..vocab_size])` as a freshly allocated vector.
///
/// The maximum logit is subtracted before exponentiating (in `f64`). If the exponentials do
/// not sum to a positive number, the uniform distribution is returned instead, matching
/// `softmax_pdf_inplace`. This includes a NaN sum, e.g. when every logit is `-inf` or one
/// is `+inf`/NaN.
///
/// # Panics
/// Panics if `logits` holds fewer than `vocab_size` elements.
pub fn softmax_pdf(logits: &[f32], vocab_size: usize) -> Vec<f64> {
    let max = logits
        .iter()
        .take(vocab_size)
        .fold(f32::NEG_INFINITY, |a, &b| a.max(b));

    // One buffer: store the exponentials, then normalise them in place.
    let mut pdf = Vec::with_capacity(vocab_size);
    let mut sum = 0.0f64;
    for &logit in &logits[..vocab_size] {
        let e = ((logit - max) as f64).exp();
        pdf.push(e);
        sum += e;
    }

    // `sum > 0.0` is false for NaN as well as for zero, so both take the uniform fallback.
    // The previous `sum <= 0.0` test let a NaN sum through and produced an all-NaN pdf.
    if sum > 0.0 {
        for p in &mut pdf {
            *p /= sum;
        }
        pdf
    } else {
        vec![1.0 / (vocab_size as f64); vocab_size]
    }
}
228
229pub fn quantize_pdf_to_cdf(pdf: &[f64]) -> Vec<u32> {
241 let mut cdf = vec![0u32; pdf.len() + 1];
242 quantize_pdf_to_cdf_inplace(pdf, &mut cdf);
243 cdf
244}
245
246#[inline]
250pub fn quantize_pdf_to_cdf_inplace(pdf: &[f64], cdf_out: &mut [u32]) {
251 let mut unused_freq = [];
252 super::quantize_pdf_to_integer_cdf_with_buffer(pdf, CDF_TOTAL, cdf_out, &mut unused_freq);
253}
254
255#[inline]
257pub fn quantize_pdf_to_cdf_with_buffer(pdf: &[f64], cdf_out: &mut [u32], freq_buf: &mut [i64]) {
258 super::quantize_pdf_to_integer_cdf_with_buffer(pdf, CDF_TOTAL, cdf_out, freq_buf);
259}
260
261pub struct ArithmeticEncoder<W: Write> {
263 b_to_pm1: u64,
264 b_to_pm2: u64,
265 mask: u64,
266 low: u64,
267 high: u64,
268 carry_run: u64,
269 out: W,
270 bit_buffer: u8,
271 bit_count: u8,
272 bytes_out: u64,
273}
274
275impl<W: Write> ArithmeticEncoder<W> {
276 pub fn new(out: W) -> Self {
278 let b_to_pm1 = BASE.pow(PRECISION - 1);
279 let b_to_pm2 = BASE.pow(PRECISION - 2);
280 let mask = BASE.pow(PRECISION) - 1;
281 Self {
282 b_to_pm1,
283 b_to_pm2,
284 mask,
285 low: 0,
286 high: mask,
287 carry_run: 0,
288 out,
289 bit_buffer: 0,
290 bit_count: 0,
291 bytes_out: 0,
292 }
293 }
294
295 #[inline]
296 fn write_byte(&mut self, byte: u8) -> anyhow::Result<()> {
297 self.out.write_all(&[byte])?;
298 self.bytes_out += 1;
299 Ok(())
300 }
301
302 #[inline]
303 fn put_bit_internal(&mut self, bit: u8) -> anyhow::Result<()> {
304 self.bit_buffer = (self.bit_buffer << 1) | (bit & 1);
305 self.bit_count += 1;
306 if self.bit_count == 8 {
307 let b = self.bit_buffer;
308 self.write_byte(b)?;
309 self.bit_buffer = 0;
310 self.bit_count = 0;
311 }
312 Ok(())
313 }
314
315 #[inline]
316 fn put_bit(&mut self, bit: u8) -> anyhow::Result<()> {
317 self.put_bit_internal(bit)?;
318 while self.carry_run > 0 {
319 self.put_bit_internal((!bit) & 1)?;
320 self.carry_run -= 1;
321 }
322 Ok(())
323 }
324
325 pub fn encode_counts(&mut self, c_lo: u64, c_hi: u64, total: u64) -> anyhow::Result<()> {
332 let range = (self.high - self.low + 1) as u128;
333 let total_u = total as u128;
334 let c_lo_u = c_lo as u128;
335 let c_hi_u = c_hi as u128;
336 let low_u = self.low as u128;
337
338 let new_low = low_u + (range * c_lo_u) / total_u;
339 let new_high = low_u + (range * c_hi_u) / total_u - 1;
340
341 self.low = (new_low & (self.mask as u128)) as u64;
342 self.high = (new_high & (self.mask as u128)) as u64;
343
344 loop {
345 if self.high < self.b_to_pm1 {
346 self.put_bit(0)?;
347 } else if self.low >= self.b_to_pm1 {
348 self.put_bit(1)?;
349 self.low -= self.b_to_pm1;
350 self.high -= self.b_to_pm1;
351 } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
352 self.carry_run += 1;
353 self.low -= self.b_to_pm2;
354 self.high -= self.b_to_pm2;
355 } else {
356 break;
357 }
358 self.low = (self.low << 1) & self.mask;
359 self.high = ((self.high << 1) & self.mask) | 1;
360 }
361 Ok(())
362 }
363
364 pub fn encode_symbol(&mut self, pdf: &[f64], sym: usize) -> anyhow::Result<()> {
368 let cdf = quantize_pdf_to_cdf(pdf);
369 let c_lo = cdf[sym] as u64;
370 let c_hi = cdf[sym + 1] as u64;
371 self.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)
372 }
373
374 pub fn finish(mut self) -> anyhow::Result<W> {
378 self.carry_run += 1;
379 if self.low < self.b_to_pm2 {
380 self.put_bit(0)?;
381 } else {
382 self.put_bit(1)?;
383 }
384 if self.bit_count > 0 {
386 let remaining = 8 - self.bit_count;
387 for _ in 0..remaining {
388 self.put_bit_internal(0)?;
389 }
390 }
391 Ok(self.out)
392 }
393
394 #[inline]
396 pub fn bytes_written(&self) -> u64 {
397 self.bytes_out
398 }
399}
400
401pub struct ArithmeticDecoder<'a> {
403 b_to_pm1: u64,
404 b_to_pm2: u64,
405 mask: u64,
406 low: u64,
407 high: u64,
408 code: u64,
409 input: &'a [u8],
410 byte_pos: usize,
411 bit_pos: u8,
412}
413
414impl<'a> ArithmeticDecoder<'a> {
415 pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
417 let b_to_pm1 = BASE.pow(PRECISION - 1);
418 let b_to_pm2 = BASE.pow(PRECISION - 2);
419 let mask = BASE.pow(PRECISION) - 1;
420
421 let mut s = Self {
422 b_to_pm1,
423 b_to_pm2,
424 mask,
425 low: 0,
426 high: mask,
427 code: 0,
428 input,
429 byte_pos: 0,
430 bit_pos: 0,
431 };
432
433 for _ in 0..PRECISION {
435 s.code = (s.code << 1) | (s.get_bit().unwrap_or(1) as u64);
436 }
437
438 Ok(s)
439 }
440
441 #[inline]
442 fn get_bit(&mut self) -> Option<u8> {
443 if self.byte_pos >= self.input.len() {
444 return None;
445 }
446 let byte = self.input[self.byte_pos];
447 let bit = (byte >> (7 - self.bit_pos)) & 1;
448 self.bit_pos += 1;
449 if self.bit_pos >= 8 {
450 self.bit_pos = 0;
451 self.byte_pos += 1;
452 }
453 Some(bit)
454 }
455
456 pub fn decode_symbol_counts(&mut self, cdf: &[u32], total: u32) -> anyhow::Result<usize> {
465 let total_u = total as u64;
466 let range = self.high - self.low + 1;
467 let value =
468 (((self.code - self.low + 1) as u128 * (total_u as u128)) - 1) / (range as u128);
469 let value_u = value as u32;
470
471 let mut lo = 0usize;
473 let mut hi = cdf.len() - 1;
474 while lo + 1 < hi {
475 let mid = (lo + hi) / 2;
476 if cdf[mid] <= value_u {
477 lo = mid;
478 } else {
479 hi = mid;
480 }
481 }
482 let s = lo;
483 let c_lo = cdf[s] as u64;
484 let c_hi = cdf[s + 1] as u64;
485
486 let range = (self.high - self.low + 1) as u128;
488 let low_u = self.low as u128;
489 let total_u128 = total as u128;
490 let new_low = low_u + (range * (c_lo as u128)) / total_u128;
491 let new_high = low_u + (range * (c_hi as u128)) / total_u128 - 1;
492
493 self.low = new_low as u64;
494 self.high = new_high as u64;
495
496 loop {
498 if self.high < self.b_to_pm1 {
499 } else if self.low >= self.b_to_pm1 {
501 self.low -= self.b_to_pm1;
502 self.high -= self.b_to_pm1;
503 self.code -= self.b_to_pm1;
504 } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
505 self.low -= self.b_to_pm2;
506 self.high -= self.b_to_pm2;
507 self.code -= self.b_to_pm2;
508 } else {
509 break;
510 }
511 self.low = (self.low << 1) & self.mask;
512 self.high = ((self.high << 1) & self.mask) | 1;
513 self.code = ((self.code << 1) & self.mask) | (self.get_bit().unwrap_or(1) as u64);
514 }
515
516 Ok(s)
517 }
518
519 pub fn decode_symbol(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
523 let cdf = quantize_pdf_to_cdf(pdf);
524 self.decode_symbol_counts(&cdf, CDF_TOTAL)
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `symbols` under a fixed `pdf`, decodes them back, and checks they match.
    fn assert_roundtrip(pdf: &[f64], symbols: &[usize]) {
        let mut encoded = Vec::<u8>::new();
        let mut enc = ArithmeticEncoder::new(&mut encoded);
        for &sym in symbols {
            enc.encode_symbol(pdf, sym).unwrap();
        }
        enc.finish().unwrap();

        let mut dec = ArithmeticDecoder::new(&encoded).unwrap();
        for (pos, &want) in symbols.iter().enumerate() {
            let got = dec.decode_symbol(pdf).unwrap();
            assert_eq!(got, want, "mismatch at position {pos}");
        }
    }

    /// Quantises a 256-symbol pdf through the buffered entry point.
    fn quantize_256(pdf: &[f64]) -> Vec<u32> {
        let mut cdf = vec![0u32; 257];
        let mut freq = vec![0i64; 256];
        quantize_pdf_to_cdf_with_buffer(pdf, &mut cdf, &mut freq);
        cdf
    }

    /// Asserts every symbol received a strictly positive interval.
    fn assert_all_positive_width(cdf: &[u32]) {
        for (i, pair) in cdf.windows(2).enumerate() {
            assert!(pair[1] > pair[0], "symbol {i} has zero-width interval");
        }
    }

    #[test]
    fn test_roundtrip_uniform() {
        assert_roundtrip(&[0.25; 4], &[0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_roundtrip_skewed() {
        assert_roundtrip(&[0.7, 0.2, 0.05, 0.05], &[0, 0, 0, 1, 0, 2, 0, 3, 0, 0]);
    }

    #[test]
    fn test_softmax_pdf_floor() {
        let pdf = softmax_pdf_floor(&[1.0, 2.0, 3.0, 4.0], 4);

        let total: f64 = pdf.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);

        let floor = p_min();
        assert!(pdf.iter().all(|&p| p >= floor));
    }

    #[test]
    fn test_cdf_monotonic() {
        let cdf = quantize_pdf_to_cdf(&[0.1, 0.2, 0.3, 0.4]);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], CDF_TOTAL);
        assert!(cdf.windows(2).all(|pair| pair[1] >= pair[0]));
    }

    #[test]
    fn test_cdf_positive_width_for_tiny_positive_tail() {
        let tail = 1e-18;
        let mut pdf = vec![tail; 256];
        pdf[0] = 1.0 - (255.0 * tail);

        let cdf = quantize_256(&pdf);
        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[256], CDF_TOTAL);
        assert_all_positive_width(&cdf);
    }

    #[test]
    fn test_cdf_positive_width_when_mass_is_last_symbol() {
        let mut pdf = vec![0.0; 256];
        pdf[255] = 1.0;

        let cdf = quantize_256(&pdf);
        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[255], 255);
        assert_eq!(cdf[256], CDF_TOTAL);
        assert_all_positive_width(&cdf);
    }

    #[test]
    #[should_panic]
    fn test_softmax_floor_256_short_logits_panics_safely() {
        // 255 logits for a 256-symbol vocabulary must panic on a bounds check
        // rather than read past the end of the slice.
        let logits = [0.0f32; 255];
        let mut pdf_out = [0.0f64; 256];
        softmax_pdf_floor_inplace(&logits, 256, &mut pdf_out);
    }
}