rwkvzip/coders/
ac.rs

1//! Arithmetic Coder implementation for rwkvzip.
2//!
3//! This implements a binary arithmetic coder with 32-bit precision, optimized for
4//! neural network probability distributions. The implementation is mathematically
5//! rigorous to ensure lossless compression.
6//!
7//! # Information-Theoretic Properties
8//!
9//! - Uses base-2 arithmetic for bitstream output
10//! - 32-bit precision prevents underflow for typical neural network distributions
11//! - Integer CDF quantization uses 30-bit total (2^30) to minimize quantization error
12//! - Probability floor ensures no symbol has zero probability (critical for lossless)
13
14use std::io::Write;
15
/// Sum of every integer CDF produced by [`quantize_pdf_to_cdf`]: 2^30.
///
/// It must not exceed `2^(PRECISION - 2)`: after renormalisation the coder's
/// range is always larger than that, so every symbol with a non-zero count keeps
/// a non-empty coding interval.
pub const CDF_TOTAL: u32 = 1u32 << 30;

/// Width in bits of the coder's `low` / `high` / `code` registers.
const PRECISION: u32 = 32;

/// Radix of the emitted code digits (the coder emits single bits).
const BASE: u64 = 2;

/// Smallest probability the floored softmax assigns to any symbol.
///
/// Equals `2^-(PRECISION - 3)` = 2^-29 ≈ 1.86e-9, i.e. twice `2^-(PRECISION - 2)`.
/// Scaled by [`CDF_TOTAL`] this is exactly 2 counts.
#[inline]
pub fn p_min() -> f64 {
    // 2^29 is exactly representable, so the reciprocal is exactly 2^-29.
    1.0 / (1u64 << (PRECISION - 3)) as f64
}
32
33/// Compute softmax PDF with probability floor.
34///
35/// # Arguments
36/// * `logits` - Raw logits from the model
37/// * `vocab_size` - Size of the vocabulary (256 for byte-level)
38///
39/// # Returns
40/// Probability distribution with floor applied, normalized to sum to 1.
41pub fn softmax_pdf_floor(logits: &[f32], vocab_size: usize) -> Vec<f64> {
42    let mut result = vec![0f64; vocab_size];
43    softmax_pdf_floor_inplace(logits, vocab_size, &mut result);
44    result
45}
46
/// Plain softmax (no probability floor) written into a caller-provided buffer.
///
/// Writes `softmax(logits[..vocab_size])` into `pdf_out[..vocab_size]`. Any
/// further elements of `pdf_out` are left untouched. If the exponentiated logits
/// do not sum to a positive number, a uniform distribution is written instead.
/// This happens when every logit is `-inf`, or any logit is `+inf` or NaN.
///
/// # Panics
/// Panics if `logits` or `pdf_out` is shorter than `vocab_size`.
pub fn softmax_pdf_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
    debug_assert!(pdf_out.len() >= vocab_size);

    let scores = &logits[..vocab_size];
    let probs = &mut pdf_out[..vocab_size];

    // Subtract the maximum so every exponent is <= 0 and `exp` cannot overflow.
    let peak = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut total = 0.0f64;
    for (slot, &score) in probs.iter_mut().zip(scores) {
        let weight = f64::from(score - peak).exp();
        *slot = weight;
        total += weight;
    }

    if total > 0.0 {
        let scale = 1.0 / total;
        probs.iter_mut().for_each(|p| *p *= scale);
    } else {
        // NaN fails `total > 0.0` as well, so degenerate logits end up here.
        let uniform = 1.0 / (vocab_size.max(1) as f64);
        probs.fill(uniform);
    }
}
74
75/// In-place version of softmax_pdf_floor to avoid allocations.
76///
77/// # Arguments
78/// * `logits` - Raw logits from the model
79/// * `vocab_size` - Size of the vocabulary
80/// * `pdf_out` - Pre-allocated buffer for output PDF (length >= vocab_size)
81pub fn softmax_pdf_floor_inplace(logits: &[f32], vocab_size: usize, pdf_out: &mut [f64]) {
82    // Fast path for vocab_size=256 with AVX2
83    #[cfg(target_arch = "x86_64")]
84    if vocab_size == 256 {
85        unsafe { softmax_pdf_floor_avx2(logits, pdf_out) };
86        return;
87    }
88
89    let p_min_val = p_min();
90
91    // Find max for numerical stability
92    let max = logits
93        .iter()
94        .take(vocab_size)
95        .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
96
97    // Compute exp(x - max) and sum (reuse pdf_out as temp buffer)
98    let mut sum = 0.0f64;
99    for i in 0..vocab_size {
100        let e = ((logits[i] - max) as f64).exp();
101        pdf_out[i] = e;
102        sum += e;
103    }
104
105    // Normalize and apply floor
106    for i in 0..vocab_size {
107        pdf_out[i] = (pdf_out[i] / sum).max(p_min_val);
108    }
109
110    // Re-normalize after floor application
111    let norm: f64 = pdf_out[..vocab_size].iter().sum();
112    for i in 0..vocab_size {
113        pdf_out[i] /= norm;
114    }
115}
116
/// AVX2-optimized softmax with probability floor for vocab_size=256.
///
/// Performs the same steps as the scalar path of `softmax_pdf_floor_inplace` on
/// exactly 256 logits: subtract the max, exponentiate, normalise, clamp to
/// `p_min()`, then renormalise. The numerical differences are:
/// - `exp` is evaluated in `f32` by the approximation `exp256_ps_fast`, and the
///   results are only then widened to `f64`;
/// - sums are accumulated in eight independent `f64` lanes (`sum0`/`sum1`,
///   `new_sum0`/`new_sum1`) and reduced at the end, so rounding differs from a
///   sequential sum;
/// - normalisation multiplies by a reciprocal instead of dividing.
/// The output is therefore close to, but not bit-identical with, the scalar path.
///
/// # Safety
/// - The running CPU must support AVX2 and FMA; this function is compiled with
///   `#[target_feature(enable = "avx2,fma")]`.
/// - `logits` and `pdf_out` must each hold at least 256 elements. Both are
///   accessed through raw pointers at offsets `0..256` with no bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn softmax_pdf_floor_avx2(logits: &[f32], pdf_out: &mut [f64]) {
    use std::arch::x86_64::*;

    const N: usize = 256;
    let p_min_val = p_min();

    // Find max using AVX2 (8 f32 lanes per load, 32 loads).
    let mut max_v = _mm256_set1_ps(f32::NEG_INFINITY);
    for i in (0..N).step_by(8) {
        let v = _mm256_loadu_ps(logits.as_ptr().add(i));
        max_v = _mm256_max_ps(max_v, v);
    }
    // Horizontal max reduction: 8 lanes -> 4 (upper vs lower 128 bits)
    // -> 2 (movehl) -> 1 (lane 1 vs lane 0), then broadcast back to 8 lanes.
    let hi = _mm256_extractf128_ps(max_v, 1);
    let lo = _mm256_castps256_ps128(max_v);
    let max128 = _mm_max_ps(hi, lo);
    let max64 = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
    let max32 = _mm_max_ss(max64, _mm_shuffle_ps(max64, max64, 0x55));
    let max = _mm_cvtss_f32(max32);
    let max_v = _mm256_set1_ps(max);

    // Compute exp(x - max) and sum
    // Use f32 for exp computation, convert to f64 for accumulation
    let mut sum0 = _mm256_setzero_pd();
    let mut sum1 = _mm256_setzero_pd();
    for i in (0..N).step_by(8) {
        let v = _mm256_loadu_ps(logits.as_ptr().add(i));
        let centered = _mm256_sub_ps(v, max_v);
        let exp_vals = exp256_ps_fast(centered);

        // Convert f32 to f64 using AVX2: 8 floats -> 2x4 doubles
        let lo4 = _mm256_castps256_ps128(exp_vals);
        let hi4 = _mm256_extractf128_ps(exp_vals, 1);
        let d_lo = _mm256_cvtps_pd(lo4);
        let d_hi = _mm256_cvtps_pd(hi4);

        // Store the unnormalised values into pdf_out (used as scratch) and
        // accumulate per-lane partial sums.
        _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i), d_lo);
        _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i + 4), d_hi);
        sum0 = _mm256_add_pd(sum0, d_lo);
        sum1 = _mm256_add_pd(sum1, d_hi);
    }

    // Horizontal sum of sum0 and sum1: 4 lanes -> 2 -> 1.
    let sum01 = _mm256_add_pd(sum0, sum1);
    let hi128 = _mm256_extractf128_pd(sum01, 1);
    let lo128 = _mm256_castpd256_pd128(sum01);
    let sum2 = _mm_add_pd(hi128, lo128);
    let sum1_v = _mm_unpackhi_pd(sum2, sum2);
    let sum_v = _mm_add_sd(sum2, sum1_v);
    let sum = _mm_cvtsd_f64(sum_v);

    // Normalize (multiply by 1/sum) and clamp each probability to >= p_min.
    let inv_sum = 1.0 / sum;
    let p_min_v = _mm256_set1_pd(p_min_val);
    let inv_sum_v = _mm256_set1_pd(inv_sum);

    let mut new_sum0 = _mm256_setzero_pd();
    let mut new_sum1 = _mm256_setzero_pd();
    for i in (0..N).step_by(8) {
        let v0 = _mm256_loadu_pd(pdf_out.as_ptr().add(i));
        let v1 = _mm256_loadu_pd(pdf_out.as_ptr().add(i + 4));
        let normed0 = _mm256_mul_pd(v0, inv_sum_v);
        let normed1 = _mm256_mul_pd(v1, inv_sum_v);
        let floored0 = _mm256_max_pd(normed0, p_min_v);
        let floored1 = _mm256_max_pd(normed1, p_min_v);
        _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i), floored0);
        _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i + 4), floored1);
        new_sum0 = _mm256_add_pd(new_sum0, floored0);
        new_sum1 = _mm256_add_pd(new_sum1, floored1);
    }

    // Horizontal sum for new_sum (the floor makes it slightly >= 1).
    let ns01 = _mm256_add_pd(new_sum0, new_sum1);
    let ns_hi = _mm256_extractf128_pd(ns01, 1);
    let ns_lo = _mm256_castpd256_pd128(ns01);
    let ns2 = _mm_add_pd(ns_hi, ns_lo);
    let ns1 = _mm_unpackhi_pd(ns2, ns2);
    let ns_v = _mm_add_sd(ns2, ns1);
    let new_sum = _mm_cvtsd_f64(ns_v);

    // Re-normalize so the floored distribution sums to 1 again.
    let inv_norm = 1.0 / new_sum;
    let inv_norm_v = _mm256_set1_pd(inv_norm);
    for i in (0..N).step_by(8) {
        let v0 = _mm256_loadu_pd(pdf_out.as_ptr().add(i));
        let v1 = _mm256_loadu_pd(pdf_out.as_ptr().add(i + 4));
        let result0 = _mm256_mul_pd(v0, inv_norm_v);
        let result1 = _mm256_mul_pd(v1, inv_norm_v);
        _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i), result0);
        _mm256_storeu_pd(pdf_out.as_mut_ptr().add(i + 4), result1);
    }
}
213
/// Fast exp approximation for f32 (AVX2).
///
/// Evaluates `exp(x)` lane-wise as `2^(x * log2(e))`. The exponent is split into
/// an integer part `n = floor(x * log2(e))` and a fraction `f` in `[0, 1)`.
/// `2^f` is approximated by a cubic polynomial, and `2^n` is built directly in
/// the IEEE-754 exponent field.
///
/// Accuracy: the coefficients are the leading Taylor-series terms of `2^f`, so
/// the polynomial reaches only ≈1.9892 as `f → 1` (exact value 2). That is a
/// relative error of up to about 0.5%. Inputs are first clamped to `[-88, 88]`.
/// For inputs below roughly -87.3, `n` becomes -127, the biased exponent is 0,
/// and the result is `0.0`.
///
/// NOTE(review): this feeds the coding PDF via `softmax_pdf_floor_avx2`.
/// Changing any constant here changes the probabilities and therefore,
/// presumably, the produced bitstream; keep encoder and decoder builds in sync.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA. The body uses 256-bit
/// float/integer intrinsics and `_mm256_fmadd_ps`.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn exp256_ps_fast(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;

    // Clamp to avoid overflow/underflow of the biased exponent built below.
    let x = _mm256_max_ps(
        _mm256_min_ps(x, _mm256_set1_ps(88.0)),
        _mm256_set1_ps(-88.0),
    );

    // exp(x) = 2^(x * log2(e))
    let log2e = _mm256_set1_ps(1.442695041);
    let fx = _mm256_mul_ps(x, log2e);

    // Split into integer and fractional parts
    let fx_floor = _mm256_floor_ps(fx);
    let f = _mm256_sub_ps(fx, fx_floor);

    // Polynomial approximation for 2^f where f in [0, 1)
    // 2^f ≈ 1 + f*(0.693147 + f*(0.240226 + f*0.0558)), evaluated by Horner's
    // rule with fused multiply-adds.
    let c0 = _mm256_set1_ps(1.0);
    let c1 = _mm256_set1_ps(0.693147180559945);
    let c2 = _mm256_set1_ps(0.240226506959101);
    let c3 = _mm256_set1_ps(0.0558263180532956);

    let poly = _mm256_fmadd_ps(f, c3, c2);
    let poly = _mm256_fmadd_ps(f, poly, c1);
    let poly = _mm256_fmadd_ps(f, poly, c0);

    // Scale by 2^n using float bit manipulation: write n + 127 into the 8-bit
    // exponent field (bits 23..30) of an otherwise-zero f32.
    let n = _mm256_cvtps_epi32(fx_floor);
    let n = _mm256_add_epi32(n, _mm256_set1_epi32(127)); // Add exponent bias
    let n = _mm256_slli_epi32(n, 23); // Shift to exponent position
    let pow2n = _mm256_castsi256_ps(n);

    _mm256_mul_ps(poly, pow2n)
}
253
/// Compute softmax PDF without floor (for entropy calculation).
///
/// Returns `softmax(logits[..vocab_size])` as a new `Vec` of length
/// `vocab_size`. No probability floor is applied, so entries may be exactly 0.
///
/// If the exponentiated logits do not give a positive sum, a uniform
/// distribution is returned instead of NaNs, matching [`softmax_pdf_inplace`].
/// This happens when every logit is `-inf`, or any logit is `+inf` or NaN.
///
/// # Panics
/// Panics if `logits` is shorter than `vocab_size`.
pub fn softmax_pdf(logits: &[f32], vocab_size: usize) -> Vec<f64> {
    let logits = &logits[..vocab_size];

    // Subtract the max so every exponent is <= 0 and `exp` cannot overflow.
    let max = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

    // Single allocation: hold exp(x - max) first, then normalise in place.
    let mut pdf: Vec<f64> = logits.iter().map(|&l| ((l - max) as f64).exp()).collect();
    let sum: f64 = pdf.iter().sum();

    if sum > 0.0 {
        for p in &mut pdf {
            *p /= sum;
        }
    } else {
        // Degenerate logits give a NaN sum, which fails `sum > 0.0` and lands
        // here. A `sum <= 0.0` test would let NaN through.
        let uniform = 1.0 / (vocab_size.max(1) as f64);
        pdf.fill(uniform);
    }
    pdf
}
281
282/// Quantize probability distribution to integer CDF.
283///
284/// The CDF is constructed to be monotonically non-decreasing with:
285/// - `cdf[0] = 0`
286/// - `cdf[vocab_size] = CDF_TOTAL`
287///
288/// # Arguments
289/// * `pdf` - Probability distribution (must sum to ~1.0)
290///
291/// # Returns
292/// Integer CDF with length vocab_size + 1
293pub fn quantize_pdf_to_cdf(pdf: &[f64]) -> Vec<u32> {
294    let mut cdf = vec![0u32; pdf.len() + 1];
295    quantize_pdf_to_cdf_inplace(pdf, &mut cdf);
296    cdf
297}
298
299/// Quantize PDF to integer CDF using a reusable output buffer.
300///
301/// `cdf_out` must have length at least `pdf.len() + 1`.
302#[inline]
303pub fn quantize_pdf_to_cdf_inplace(pdf: &[f64], cdf_out: &mut [u32]) {
304    let n = pdf.len();
305    debug_assert!(cdf_out.len() >= n + 1, "cdf buffer too small");
306
307    unsafe {
308        *cdf_out.get_unchecked_mut(0) = 0;
309        let scale = CDF_TOTAL as f64;
310        let mut acc = 0.0f64;
311        let mut prev = 0u32;
312
313        for i in 0..n {
314            acc += *pdf.get_unchecked(i);
315            let v = (acc * scale) as u32;
316            let v = v.max(prev);
317            *cdf_out.get_unchecked_mut(i + 1) = v;
318            prev = v;
319        }
320        *cdf_out.get_unchecked_mut(n) = CDF_TOTAL;
321    }
322}
323
/// Binary arithmetic encoder.
///
/// Maintains a `PRECISION`-bit interval `[low, high]` and writes decided bits
/// MSB-first into `out`, packed into bytes. Bits that cannot be decided yet
/// because the interval straddles the midpoint are counted in `carry_run` and
/// emitted once the next bit is known.
pub struct ArithmeticEncoder<W: Write> {
    /// `2^(PRECISION-1)`: half of the coding range.
    b_to_pm1: u64,
    /// `2^(PRECISION-2)`: quarter of the coding range.
    b_to_pm2: u64,
    /// `2^PRECISION - 1`: truncates the registers to `PRECISION` bits.
    mask: u64,
    /// Inclusive lower bound of the current interval.
    low: u64,
    /// Inclusive upper bound of the current interval.
    high: u64,
    /// Number of deferred bits, each emitted as the complement of the next
    /// output bit.
    carry_run: u64,
    /// Destination for completed bytes.
    out: W,
    /// Bits collected for the current output byte (first bit ends up as MSB).
    bit_buffer: u8,
    /// Number of bits currently in `bit_buffer` (0..=7 between calls).
    bit_count: u8,
    /// Total number of bytes written to `out`.
    bytes_out: u64,
}
337
338impl<W: Write> ArithmeticEncoder<W> {
339    /// Create a new arithmetic encoder.
340    pub fn new(out: W) -> Self {
341        let b_to_pm1 = BASE.pow(PRECISION - 1);
342        let b_to_pm2 = BASE.pow(PRECISION - 2);
343        let mask = BASE.pow(PRECISION) - 1;
344        Self {
345            b_to_pm1,
346            b_to_pm2,
347            mask,
348            low: 0,
349            high: mask,
350            carry_run: 0,
351            out,
352            bit_buffer: 0,
353            bit_count: 0,
354            bytes_out: 0,
355        }
356    }
357
358    #[inline]
359    fn write_byte(&mut self, byte: u8) -> anyhow::Result<()> {
360        self.out.write_all(&[byte])?;
361        self.bytes_out += 1;
362        Ok(())
363    }
364
365    #[inline]
366    fn put_bit_internal(&mut self, bit: u8) -> anyhow::Result<()> {
367        self.bit_buffer = (self.bit_buffer << 1) | (bit & 1);
368        self.bit_count += 1;
369        if self.bit_count == 8 {
370            let b = self.bit_buffer;
371            self.write_byte(b)?;
372            self.bit_buffer = 0;
373            self.bit_count = 0;
374        }
375        Ok(())
376    }
377
378    #[inline]
379    fn put_bit(&mut self, bit: u8) -> anyhow::Result<()> {
380        self.put_bit_internal(bit)?;
381        while self.carry_run > 0 {
382            self.put_bit_internal((!bit) & 1)?;
383            self.carry_run -= 1;
384        }
385        Ok(())
386    }
387
388    /// Encode a symbol using integer CDF bounds.
389    ///
390    /// # Arguments
391    /// * `c_lo` - Lower CDF bound (cumulative probability before symbol)
392    /// * `c_hi` - Upper CDF bound (cumulative probability including symbol)
393    /// * `total` - Total CDF range (should be CDF_TOTAL)
394    pub fn encode_counts(&mut self, c_lo: u64, c_hi: u64, total: u64) -> anyhow::Result<()> {
395        let range = (self.high - self.low + 1) as u128;
396        let total_u = total as u128;
397        let c_lo_u = c_lo as u128;
398        let c_hi_u = c_hi as u128;
399        let low_u = self.low as u128;
400
401        let new_low = low_u + (range * c_lo_u) / total_u;
402        let new_high = low_u + (range * c_hi_u) / total_u - 1;
403
404        self.low = (new_low & (self.mask as u128)) as u64;
405        self.high = (new_high & (self.mask as u128)) as u64;
406
407        loop {
408            if self.high < self.b_to_pm1 {
409                self.put_bit(0)?;
410            } else if self.low >= self.b_to_pm1 {
411                self.put_bit(1)?;
412                self.low -= self.b_to_pm1;
413                self.high -= self.b_to_pm1;
414            } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
415                self.carry_run += 1;
416                self.low -= self.b_to_pm2;
417                self.high -= self.b_to_pm2;
418            } else {
419                break;
420            }
421            self.low = (self.low << 1) & self.mask;
422            self.high = ((self.high << 1) & self.mask) | 1;
423        }
424        Ok(())
425    }
426
427    /// Encode a symbol given its PDF and symbol index.
428    ///
429    /// This is a convenience method that quantizes the PDF to CDF internally.
430    pub fn encode_symbol(&mut self, pdf: &[f64], sym: usize) -> anyhow::Result<()> {
431        let cdf = quantize_pdf_to_cdf(pdf);
432        let c_lo = cdf[sym] as u64;
433        let c_hi = cdf[sym + 1] as u64;
434        self.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)
435    }
436
437    /// Finish encoding and flush remaining bits.
438    ///
439    /// Returns the underlying writer.
440    pub fn finish(mut self) -> anyhow::Result<W> {
441        self.carry_run += 1;
442        if self.low < self.b_to_pm2 {
443            self.put_bit(0)?;
444        } else {
445            self.put_bit(1)?;
446        }
447        // Pad remaining bits
448        if self.bit_count > 0 {
449            let remaining = 8 - self.bit_count;
450            for _ in 0..remaining {
451                self.put_bit_internal(0)?;
452            }
453        }
454        Ok(self.out)
455    }
456
457    /// Get the number of bytes written so far.
458    #[inline]
459    pub fn bytes_written(&self) -> u64 {
460        self.bytes_out
461    }
462}
463
/// Binary arithmetic decoder.
///
/// Mirrors [`ArithmeticEncoder`]. It tracks the same `[low, high]` interval plus
/// a `PRECISION`-bit window `code` of the input bitstream, reading bits
/// MSB-first from a borrowed byte slice.
pub struct ArithmeticDecoder<'a> {
    /// `2^(PRECISION-1)`: half of the coding range.
    b_to_pm1: u64,
    /// `2^(PRECISION-2)`: quarter of the coding range.
    b_to_pm2: u64,
    /// `2^PRECISION - 1`: truncates the shifted registers to `PRECISION` bits.
    mask: u64,
    /// Inclusive lower bound of the current interval.
    low: u64,
    /// Inclusive upper bound of the current interval.
    high: u64,
    /// Current `PRECISION`-bit window of the input bitstream.
    code: u64,
    /// Encoded bytes being decoded.
    input: &'a [u8],
    /// Index of the byte holding the next bit to read.
    byte_pos: usize,
    /// Offset of the next bit within `input[byte_pos]` (0 = most significant).
    bit_pos: u8,
}
476
477impl<'a> ArithmeticDecoder<'a> {
478    /// Create a new arithmetic decoder from input bytes.
479    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
480        let b_to_pm1 = BASE.pow(PRECISION - 1);
481        let b_to_pm2 = BASE.pow(PRECISION - 2);
482        let mask = BASE.pow(PRECISION) - 1;
483
484        let mut s = Self {
485            b_to_pm1,
486            b_to_pm2,
487            mask,
488            low: 0,
489            high: mask,
490            code: 0,
491            input,
492            byte_pos: 0,
493            bit_pos: 0,
494        };
495
496        // Initialize code register with first PRECISION bits
497        for _ in 0..PRECISION {
498            s.code = (s.code << 1) | (s.get_bit().unwrap_or(1) as u64);
499        }
500
501        Ok(s)
502    }
503
504    #[inline]
505    fn get_bit(&mut self) -> Option<u8> {
506        if self.byte_pos >= self.input.len() {
507            return None;
508        }
509        let byte = self.input[self.byte_pos];
510        let bit = (byte >> (7 - self.bit_pos)) & 1;
511        self.bit_pos += 1;
512        if self.bit_pos >= 8 {
513            self.bit_pos = 0;
514            self.byte_pos += 1;
515        }
516        Some(bit)
517    }
518
519    /// Decode a symbol using integer CDF.
520    ///
521    /// # Arguments
522    /// * `cdf` - Cumulative distribution function (length = vocab_size + 1)
523    /// * `total` - Total CDF range (should be CDF_TOTAL)
524    ///
525    /// # Returns
526    /// The decoded symbol index.
527    pub fn decode_symbol_counts(&mut self, cdf: &[u32], total: u32) -> anyhow::Result<usize> {
528        let total_u = total as u64;
529        let range = self.high - self.low + 1;
530        let value =
531            (((self.code - self.low + 1) as u128 * (total_u as u128)) - 1) / (range as u128);
532        let value_u = value as u32;
533
534        // Binary search for symbol `s` with `cdf[s] <= value < cdf[s+1]`
535        let mut lo = 0usize;
536        let mut hi = cdf.len() - 1;
537        while lo + 1 < hi {
538            let mid = (lo + hi) / 2;
539            if cdf[mid] <= value_u {
540                lo = mid;
541            } else {
542                hi = mid;
543            }
544        }
545        let s = lo;
546        let c_lo = cdf[s] as u64;
547        let c_hi = cdf[s + 1] as u64;
548
549        // Update range
550        let range = (self.high - self.low + 1) as u128;
551        let low_u = self.low as u128;
552        let total_u128 = total as u128;
553        let new_low = low_u + (range * (c_lo as u128)) / total_u128;
554        let new_high = low_u + (range * (c_hi as u128)) / total_u128 - 1;
555
556        self.low = new_low as u64;
557        self.high = new_high as u64;
558
559        // Renormalize
560        loop {
561            if self.high < self.b_to_pm1 {
562                // nothing
563            } else if self.low >= self.b_to_pm1 {
564                self.low -= self.b_to_pm1;
565                self.high -= self.b_to_pm1;
566                self.code -= self.b_to_pm1;
567            } else if self.low >= self.b_to_pm2 && self.high < self.b_to_pm2 * 3 {
568                self.low -= self.b_to_pm2;
569                self.high -= self.b_to_pm2;
570                self.code -= self.b_to_pm2;
571            } else {
572                break;
573            }
574            self.low = (self.low << 1) & self.mask;
575            self.high = ((self.high << 1) & self.mask) | 1;
576            self.code = ((self.code << 1) & self.mask) | (self.get_bit().unwrap_or(1) as u64);
577        }
578
579        Ok(s)
580    }
581
582    /// Decode a symbol given a PDF.
583    ///
584    /// This is a convenience method that quantizes the PDF to CDF internally.
585    pub fn decode_symbol(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
586        let cdf = quantize_pdf_to_cdf(pdf);
587        self.decode_symbol_counts(&cdf, CDF_TOTAL)
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `symbols` under a fixed `pdf`, decode them back, and assert that
    /// every symbol round-trips.
    fn assert_roundtrip(pdf: &[f64], symbols: &[usize]) {
        let mut enc = ArithmeticEncoder::new(Vec::new());
        for &s in symbols {
            enc.encode_symbol(pdf, s).unwrap();
        }
        let buf = enc.finish().unwrap();

        let mut dec = ArithmeticDecoder::new(&buf).unwrap();
        for (pos, &expected) in symbols.iter().enumerate() {
            let got = dec.decode_symbol(pdf).unwrap();
            assert_eq!(got, expected, "symbol mismatch at position {}", pos);
        }
    }

    #[test]
    fn test_roundtrip_uniform() {
        // Test with uniform distribution
        assert_roundtrip(&[0.25, 0.25, 0.25, 0.25], &[0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_roundtrip_skewed() {
        // Test with skewed distribution
        assert_roundtrip(&[0.7, 0.2, 0.05, 0.05], &[0, 0, 0, 1, 0, 2, 0, 3, 0, 0]);
    }

    #[test]
    fn test_roundtrip_long_stream_with_floored_symbols() {
        // Symbol 0 is near-certain and every other symbol is held up only by
        // the p_min floor. Mixing ~0-bit and ~29-bit symbols over a long stream
        // stresses renormalisation far more than the short tests above.
        // Vocab != 256, so the portable softmax path is used.
        const VOCAB: usize = 64;
        let logits: Vec<f32> = (0..VOCAB)
            .map(|i| if i == 0 { 20.0 } else { -(i as f32) })
            .collect();
        let pdf = softmax_pdf_floor(&logits, VOCAB);

        // Deterministic LCG so the symbol stream is reproducible.
        let mut state: u64 = 0x2545_F491_4F6C_DD1D;
        let symbols: Vec<usize> = (0..2000)
            .map(|k| {
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                if k % 7 == 0 {
                    (state >> 33) as usize % VOCAB
                } else {
                    0
                }
            })
            .collect();

        assert_roundtrip(&pdf, &symbols);
    }

    #[test]
    fn test_softmax_pdf_floor() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let pdf = softmax_pdf_floor(&logits, 4);

        // Check sum is ~1
        let sum: f64 = pdf.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Check all probabilities are >= floor
        let p_min_val = p_min();
        for &p in &pdf {
            assert!(p >= p_min_val);
        }
    }

    #[test]
    fn test_cdf_monotonic() {
        let pdf = vec![0.1, 0.2, 0.3, 0.4];
        let cdf = quantize_pdf_to_cdf(&pdf);

        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], CDF_TOTAL);

        // Check monotonicity
        for i in 1..cdf.len() {
            assert!(cdf[i] >= cdf[i - 1]);
        }
    }
}
668}