//! infotheory/backends/calibration.rs

1use crate::CalibrationContextKind;
2use crate::backends::text_context::{NeuralContextState, TextContextAnalyzer};
3
4#[derive(Clone, Debug)]
5/// Lightweight online calibrator that rescales a base 256-way PDF by context/bin.
6///
7/// The calibrator keeps per-context logits over probability bins and applies an
8/// exponential tilt `p' ∝ p * exp(w_bin)` followed by normalization.
9pub struct CalibratorCore {
10    analyzer: TextContextAnalyzer,
11    context: CalibrationContextKind,
12    bins: usize,
13    learning_rate: f64,
14    bias_clip: f64,
15    weights: Vec<f64>,
16    last_context: usize,
17    last_bins: Vec<usize>,
18}
19
20impl CalibratorCore {
21    /// Create a calibrator with bounded bin count and stable learning parameters.
22    pub fn new(
23        context: CalibrationContextKind,
24        bins: usize,
25        learning_rate: f64,
26        bias_clip: f64,
27    ) -> Self {
28        let bins = bins.max(2);
29        let weights = vec![0.0; context_cardinality(context) * bins];
30        Self {
31            analyzer: TextContextAnalyzer::new(),
32            context,
33            bins,
34            learning_rate: learning_rate.max(1e-6),
35            bias_clip: bias_clip.max(1e-6),
36            weights,
37            last_context: 0,
38            last_bins: vec![0; 256],
39        }
40    }
41
42    /// Apply the learned calibration transform to `base`, writing a normalized PDF to `out`.
43    ///
44    /// `base`/`out` are expected to be 256-byte distributions.
45    pub fn apply_pdf(&mut self, base: &[f64], out: &mut [f64]) {
46        let ctx = context_index(self.context, self.analyzer.state());
47        self.last_context = ctx;
48        let offset = ctx * self.bins;
49        let mut sum = 0.0;
50        for i in 0..256 {
51            let p = base[i].clamp(1e-12, 1.0 - 1e-12);
52            let bin = probability_bin(p, self.bins);
53            self.last_bins[i] = bin;
54            let w = self.weights[offset + bin];
55            let adjusted = p * w.exp();
56            out[i] = adjusted;
57            sum += adjusted;
58        }
59        if !sum.is_finite() || sum <= 0.0 {
60            let u = 1.0 / 256.0;
61            out.fill(u);
62            return;
63        }
64        let inv = 1.0 / sum;
65        for value in out.iter_mut() {
66            *value *= inv;
67        }
68    }
69
70    /// Update the active bin weight from the observed symbol and calibrated distribution.
71    pub fn update(&mut self, symbol: u8, calibrated_pdf: &[f64]) {
72        let idx = self.last_context * self.bins + self.last_bins[symbol as usize];
73        let q = calibrated_pdf[symbol as usize].clamp(1e-9, 1.0);
74        self.weights[idx] = (self.weights[idx] + self.learning_rate * (1.0 - q))
75            .clamp(-self.bias_clip, self.bias_clip);
76        self.analyzer.update(symbol);
77    }
78
79    /// Reset only the dynamic context state while preserving fitted weights.
80    pub fn reset_context(&mut self) {
81        self.analyzer = TextContextAnalyzer::new();
82        self.last_context = 0;
83        self.last_bins.fill(0);
84    }
85
86    /// Advance context state without updating fitted calibration weights.
87    pub fn update_context_only(&mut self, symbol: u8) {
88        self.analyzer.update(symbol);
89    }
90}
91
92fn context_cardinality(kind: CalibrationContextKind) -> usize {
93    match kind {
94        CalibrationContextKind::Global => 1,
95        CalibrationContextKind::ByteClass => 8,
96        CalibrationContextKind::Text => 256,
97        CalibrationContextKind::Repeat => 64,
98        CalibrationContextKind::TextRepeat => 512,
99    }
100}
101
102fn context_index(kind: CalibrationContextKind, state: NeuralContextState) -> usize {
103    match kind {
104        CalibrationContextKind::Global => 0,
105        CalibrationContextKind::ByteClass => state.prev1_class as usize,
106        CalibrationContextKind::Text => hash_state(
107            &[
108                state.prev1_class,
109                state.prev2_class,
110                state.word_len_bucket,
111                state.prev_word_class,
112                state.bracket_bucket,
113                state.quote_flags,
114                state.utf8_left,
115                state.sentence_boundary as u8,
116                state.paragraph_break as u8,
117            ],
118            256,
119        ),
120        CalibrationContextKind::Repeat => hash_state(
121            &[
122                state.repeat_len_bucket,
123                state.copied_last_byte as u8,
124                (state.run_len.min(31) as u8),
125            ],
126            64,
127        ),
128        CalibrationContextKind::TextRepeat => hash_state(
129            &[
130                state.prev1_class,
131                state.word_len_bucket,
132                state.prev_word_class,
133                state.bracket_bucket,
134                state.quote_flags,
135                state.repeat_len_bucket,
136                state.copied_last_byte as u8,
137                state.paragraph_break as u8,
138            ],
139            512,
140        ),
141    }
142}
143
/// Map a probability to one of `bins` buckets on the logit scale.
///
/// NOTE(review): any logit >= 0 (i.e. prob >= 0.5) collapses into the top
/// bucket, so resolution is concentrated on logits in [-24, 0] — presumably
/// intentional for 256-way PDFs where most symbol probabilities are small;
/// confirm with the calibration design notes.
fn probability_bin(prob: f64, bins: usize) -> usize {
    let top_bin = (bins - 1) as f64;
    let odds = prob / (1.0 - prob);
    let unit = ((odds.ln() + 24.0) / 24.0).clamp(0.0, 1.0);
    (unit * top_bin).round() as usize
}
149
/// Fold a small feature vector into a slot index below `modulo`.
///
/// Uses a simple xor/rotate/multiply mixer seeded with the 32-bit golden-ratio
/// constant; deterministic for a given feature sequence.
fn hash_state(values: &[u8], modulo: usize) -> usize {
    let mixed = values.iter().fold(0x9E37_79B9u32, |acc, &byte| {
        (acc ^ byte as u32).rotate_left(5).wrapping_mul(0x85EB_CA6B)
    });
    (mixed as usize) % modulo
}