infotheory/backends/
calibration.rs1use crate::CalibrationContextKind;
2use crate::backends::text_context::{NeuralContextState, TextContextAnalyzer};
3
/// Online calibrator for a 256-symbol byte PDF: learns a per-(context,
/// probability-bucket) log-space bias table and applies it multiplicatively
/// to a base model's probabilities (see `apply_pdf` / `update`).
#[derive(Clone, Debug)]
pub struct CalibratorCore {
    // Streaming byte-context analyzer; `context_index` consumes its state
    // (byte classes, word/bracket/quote buckets, repeat features).
    analyzer: TextContextAnalyzer,
    // Which context signal selects the weight row.
    context: CalibrationContextKind,
    // Probability buckets per context row (raised to >= 2 in `new`).
    bins: usize,
    // Step size for the weight nudge in `update` (floored at 1e-6 in `new`).
    learning_rate: f64,
    // Symmetric clamp bound on every weight (floored at 1e-6 in `new`).
    bias_clip: f64,
    // Flat table of size context_cardinality(context) * bins; entry is a
    // log-space bias, applied as exp(weight) in `apply_pdf`.
    weights: Vec<f64>,
    // Context row chosen by the most recent `apply_pdf` call.
    last_context: usize,
    // Probability bucket chosen for each of the 256 symbols in the most
    // recent `apply_pdf` call; consumed by `update`.
    last_bins: Vec<usize>,
}
19
20impl CalibratorCore {
21 pub fn new(
23 context: CalibrationContextKind,
24 bins: usize,
25 learning_rate: f64,
26 bias_clip: f64,
27 ) -> Self {
28 let bins = bins.max(2);
29 let weights = vec![0.0; context_cardinality(context) * bins];
30 Self {
31 analyzer: TextContextAnalyzer::new(),
32 context,
33 bins,
34 learning_rate: learning_rate.max(1e-6),
35 bias_clip: bias_clip.max(1e-6),
36 weights,
37 last_context: 0,
38 last_bins: vec![0; 256],
39 }
40 }
41
42 pub fn apply_pdf(&mut self, base: &[f64], out: &mut [f64]) {
46 let ctx = context_index(self.context, self.analyzer.state());
47 self.last_context = ctx;
48 let offset = ctx * self.bins;
49 let mut sum = 0.0;
50 for i in 0..256 {
51 let p = base[i].clamp(1e-12, 1.0 - 1e-12);
52 let bin = probability_bin(p, self.bins);
53 self.last_bins[i] = bin;
54 let w = self.weights[offset + bin];
55 let adjusted = p * w.exp();
56 out[i] = adjusted;
57 sum += adjusted;
58 }
59 if !sum.is_finite() || sum <= 0.0 {
60 let u = 1.0 / 256.0;
61 out.fill(u);
62 return;
63 }
64 let inv = 1.0 / sum;
65 for value in out.iter_mut() {
66 *value *= inv;
67 }
68 }
69
70 pub fn update(&mut self, symbol: u8, calibrated_pdf: &[f64]) {
72 let idx = self.last_context * self.bins + self.last_bins[symbol as usize];
73 let q = calibrated_pdf[symbol as usize].clamp(1e-9, 1.0);
74 self.weights[idx] = (self.weights[idx] + self.learning_rate * (1.0 - q))
75 .clamp(-self.bias_clip, self.bias_clip);
76 self.analyzer.update(symbol);
77 }
78
79 pub fn reset_context(&mut self) {
81 self.analyzer = TextContextAnalyzer::new();
82 self.last_context = 0;
83 self.last_bins.fill(0);
84 }
85
86 pub fn update_context_only(&mut self, symbol: u8) {
88 self.analyzer.update(symbol);
89 }
90}
91
/// Number of weight rows allocated per context kind.
///
/// `context_index` must always return a value strictly below this; the
/// hashed kinds keep that invariant by using the same constant as their
/// hash modulus.
fn context_cardinality(kind: CalibrationContextKind) -> usize {
    match kind {
        // One shared row: no context conditioning at all.
        CalibrationContextKind::Global => 1,
        // Indexed directly by prev1_class — assumes class ids are < 8;
        // TODO confirm against TextContextAnalyzer.
        CalibrationContextKind::ByteClass => 8,
        // Hash of the textual features, modulo 256.
        CalibrationContextKind::Text => 256,
        // Hash of the repeat/run features, modulo 64.
        CalibrationContextKind::Repeat => 64,
        // Hash of combined text + repeat features, modulo 512.
        CalibrationContextKind::TextRepeat => 512,
    }
}
101
102fn context_index(kind: CalibrationContextKind, state: NeuralContextState) -> usize {
103 match kind {
104 CalibrationContextKind::Global => 0,
105 CalibrationContextKind::ByteClass => state.prev1_class as usize,
106 CalibrationContextKind::Text => hash_state(
107 &[
108 state.prev1_class,
109 state.prev2_class,
110 state.word_len_bucket,
111 state.prev_word_class,
112 state.bracket_bucket,
113 state.quote_flags,
114 state.utf8_left,
115 state.sentence_boundary as u8,
116 state.paragraph_break as u8,
117 ],
118 256,
119 ),
120 CalibrationContextKind::Repeat => hash_state(
121 &[
122 state.repeat_len_bucket,
123 state.copied_last_byte as u8,
124 (state.run_len.min(31) as u8),
125 ],
126 64,
127 ),
128 CalibrationContextKind::TextRepeat => hash_state(
129 &[
130 state.prev1_class,
131 state.word_len_bucket,
132 state.prev_word_class,
133 state.bracket_bucket,
134 state.quote_flags,
135 state.repeat_len_bucket,
136 state.copied_last_byte as u8,
137 state.paragraph_break as u8,
138 ],
139 512,
140 ),
141 }
142}
143
/// Maps a probability to one of `bins` buckets on a log-odds scale.
///
/// The logit is shifted and scaled so that logit = -24 lands at bucket 0
/// and logit = 0 (p = 0.5) at the top bucket; values outside that window
/// saturate at the edges.
/// NOTE(review): every p >= 0.5 collapses into the top bucket — confirm
/// whether the divisor was meant to be 48.0 to also resolve positive logits.
fn probability_bin(prob: f64, bins: usize) -> usize {
    let log_odds = (prob / (1.0 - prob)).ln();
    let unit = ((log_odds + 24.0) / 24.0).clamp(0.0, 1.0);
    let top_bucket = (bins - 1) as f64;
    (unit * top_bucket).round() as usize
}
149
/// Folds `values` into a bucket in `[0, modulo)` using a small xor /
/// rotate / multiply mixer (golden-ratio seed, odd multiplier).
/// Feature order matters: the same bytes in a different order hash
/// differently.
fn hash_state(values: &[u8], modulo: usize) -> usize {
    let mut acc: u32 = 0x9E37_79B9;
    for &byte in values {
        acc = (acc ^ u32::from(byte)).rotate_left(5).wrapping_mul(0x85EB_CA6B);
    }
    acc as usize % modulo
}