infotheory/backends/text_context.rs

1use ahash::AHashMap;
2
/// Hash-table key of the repeat model: the four most recent bytes packed big-endian.
type RepeatKey = u32;
/// Index into the repeat history of the last byte of a 4-byte window.
type RepeatPos = u32;
/// Sentinel meaning "no position recorded"; real positions must stay below it.
const REPEAT_POS_NONE: RepeatPos = RepeatPos::MAX;
6
7#[inline]
8fn repeat_pos_from_usize(pos: usize) -> RepeatPos {
9    if pos >= REPEAT_POS_NONE as usize {
10        panic!("text repeat position overflow");
11    }
12    pos as RepeatPos
13}
14
/// Per-byte text context features maintained by [`TextContextAnalyzer`].
///
/// `Default` yields the empty state used before any byte has been observed
/// (`has_history == false`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct NeuralContextState {
    /// Most recently observed byte.
    pub(crate) prev1: u8,
    /// Byte observed immediately before `prev1`.
    pub(crate) prev2: u8,
    /// `classify_byte` class of `prev1`.
    pub(crate) prev1_class: u8,
    /// `classify_byte` class of `prev2`.
    pub(crate) prev2_class: u8,
    /// Length of the run of identical bytes ending at `prev1`
    /// (1 for a byte that differs from its predecessor, capped at 255).
    pub(crate) run_len: u16,
    /// UTF-8 continuation bytes still expected after `prev1` (0-3).
    pub(crate) utf8_left: u8,
    /// Whether `prev1` is a word byte (ASCII alphanumeric, `_`, or non-ASCII).
    pub(crate) in_word: bool,
    /// Bucketed length of the word in progress (0 when not inside a word).
    pub(crate) word_len_bucket: u8,
    /// Class of the most recently completed word:
    /// 0 none, 1 lowercase, 2 uppercase, 3 digits, 4 mixed/other.
    pub(crate) prev_word_class: u8,
    /// Innermost open bracket: 0 none, 1 `(`, 2 `[`, 3 `{`, 4 `<`.
    pub(crate) bracket_bucket: u8,
    /// Quote parity bits: bit 0 flips on every `"`, bit 1 on every `'`.
    pub(crate) quote_flags: u8,
    /// Whether `prev1` is `.`, `!` or `?`.
    pub(crate) sentence_boundary: bool,
    /// Whether a blank line (two or more line feeds in succession) has just been
    /// seen; see `TextContextAnalyzer::update_structure` for the exact rule.
    pub(crate) paragraph_break: bool,
    /// `bucket_repeat_len` of the repeat model's current match length (0-7).
    pub(crate) repeat_len_bucket: u8,
    /// Whether `prev1` equals the byte the repeat model predicted before it arrived.
    pub(crate) copied_last_byte: bool,
    /// Whether at least one byte has been observed.
    pub(crate) has_history: bool,
}

/// Alias of [`NeuralContextState`].
// NOTE(review): presumably kept so callers can refer to the state by its
// "history" role; confirm whether both names are still needed.
pub(crate) type NeuralHistoryState = NeuralContextState;
36
/// Character-class summary of the word currently being scanned.
#[derive(Clone, Copy, Debug, Default)]
struct WordTracker {
    /// Bytes observed in the current word (saturating).
    len: u16,
    /// At least one ASCII lowercase letter was seen.
    saw_lower: bool,
    /// At least one ASCII uppercase letter was seen.
    saw_upper: bool,
    /// At least one ASCII digit was seen.
    saw_digit: bool,
    /// Some other byte (e.g. `_` or a non-ASCII byte) was seen.
    saw_other: bool,
}

impl WordTracker {
    /// Returns to the empty-word state.
    fn clear(&mut self) {
        *self = WordTracker::default();
    }

    /// Records one more byte of the current word.
    fn observe(&mut self, byte: u8) {
        let flag = match byte {
            b'a'..=b'z' => &mut self.saw_lower,
            b'A'..=b'Z' => &mut self.saw_upper,
            b'0'..=b'9' => &mut self.saw_digit,
            _ => &mut self.saw_other,
        };
        *flag = true;
        self.len = self.len.saturating_add(1);
    }

    /// Classifies the word: 1 = only lowercase, 2 = only uppercase,
    /// 3 = only digits, 4 = any other non-empty mix, 0 = empty.
    fn class(self) -> u8 {
        match (self.saw_lower, self.saw_upper, self.saw_digit, self.saw_other) {
            (false, false, true, false) => 3,
            (false, true, false, false) => 2,
            (true, false, false, false) => 1,
            _ if self.len > 0 => 4,
            _ => 0,
        }
    }
}
75
76#[derive(Clone, Debug, Default)]
77struct LocalRepeatState {
78    history: Vec<u8>,
79    table: AHashMap<RepeatKey, (RepeatPos, RepeatPos)>,
80    predicted: Option<u8>,
81    match_len: usize,
82    copied_last_byte: bool,
83}
84
85impl LocalRepeatState {
86    const MIN_LEN: usize = 4;
87    const MAX_LEN: usize = 255;
88
89    fn predict_from_history(&self) -> (Option<u8>, usize) {
90        if self.history.len() < Self::MIN_LEN {
91            return (None, 0);
92        }
93        let end = self.history.len() - 1;
94        let Some(key) = repeat_key(&self.history) else {
95            return (None, 0);
96        };
97        let Some(&(latest, previous)) = self.table.get(&key) else {
98            return (None, 0);
99        };
100        let end_pos = repeat_pos_from_usize(end);
101        let candidate_end = if latest == end_pos { previous } else { latest };
102        if candidate_end == REPEAT_POS_NONE {
103            return (None, 0);
104        }
105        let candidate_end = candidate_end as usize;
106        if candidate_end + 1 >= self.history.len() {
107            return (None, 0);
108        }
109        let mut matched = Self::MIN_LEN;
110        while matched < Self::MAX_LEN
111            && end >= matched
112            && candidate_end >= matched
113            && self.history[end - matched] == self.history[candidate_end - matched]
114        {
115            matched += 1;
116        }
117        (Some(self.history[candidate_end + 1]), matched)
118    }
119
120    fn repeat_len_bucket(&self) -> u8 {
121        bucket_repeat_len(self.match_len)
122    }
123
124    fn update(&mut self, symbol: u8) {
125        self.copied_last_byte = self.predicted == Some(symbol);
126        self.history.push(symbol);
127        if let Some(key) = repeat_key(&self.history) {
128            let end = repeat_pos_from_usize(self.history.len() - 1);
129            self.table
130                .entry(key)
131                .and_modify(|entry| {
132                    entry.1 = entry.0;
133                    entry.0 = end;
134                })
135                .or_insert((end, REPEAT_POS_NONE));
136        }
137        let (predicted, match_len) = self.predict_from_history();
138        self.predicted = predicted;
139        self.match_len = match_len;
140    }
141}
142
143#[derive(Clone, Debug, Default)]
144pub(crate) struct TextContextAnalyzer {
145    state: NeuralContextState,
146    word: WordTracker,
147    newline_run: u8,
148    bracket_stack: [u8; 8],
149    bracket_depth: usize,
150    repeat: LocalRepeatState,
151}
152
153impl TextContextAnalyzer {
154    pub(crate) fn new() -> Self {
155        Self::default()
156    }
157
158    pub(crate) fn state(&self) -> NeuralContextState {
159        self.state
160    }
161
162    pub(crate) fn update(&mut self, symbol: u8) {
163        let was_predicted = self.repeat.predicted;
164        self.repeat.update(symbol);
165
166        let byte_class = classify_byte(symbol);
167        if self.state.has_history && symbol == self.state.prev1 {
168            self.state.run_len = self.state.run_len.saturating_add(1).min(255);
169        } else {
170            self.state.run_len = 1;
171        }
172
173        let prev_in_word = self.state.in_word;
174        let is_word_byte = is_word_byte(symbol);
175        if is_word_byte {
176            if !prev_in_word {
177                self.word.clear();
178            }
179            self.word.observe(symbol);
180        } else if prev_in_word {
181            self.state.prev_word_class = self.word.class();
182            self.word.clear();
183        }
184        self.state.in_word = is_word_byte;
185        self.state.word_len_bucket = bucket_word_len(self.word.len);
186
187        self.update_structure(symbol);
188
189        self.state.prev2 = self.state.prev1;
190        self.state.prev2_class = self.state.prev1_class;
191        self.state.prev1 = symbol;
192        self.state.prev1_class = byte_class;
193        self.state.utf8_left = utf8_left_after(symbol, self.state.utf8_left);
194        self.state.repeat_len_bucket = self.repeat.repeat_len_bucket();
195        self.state.copied_last_byte = was_predicted == Some(symbol);
196        self.state.has_history = true;
197    }
198
199    fn update_structure(&mut self, symbol: u8) {
200        self.state.sentence_boundary = matches!(symbol, b'.' | b'!' | b'?');
201
202        if symbol == b'\n' {
203            self.newline_run = self.newline_run.saturating_add(1).min(3);
204        } else {
205            self.newline_run = 0;
206        }
207        self.state.paragraph_break = self.newline_run >= 2;
208
209        match symbol {
210            b'(' => self.push_bracket(1),
211            b'[' => self.push_bracket(2),
212            b'{' => self.push_bracket(3),
213            b'<' => self.push_bracket(4),
214            b')' => self.pop_bracket(1),
215            b']' => self.pop_bracket(2),
216            b'}' => self.pop_bracket(3),
217            b'>' => self.pop_bracket(4),
218            b'"' => self.state.quote_flags ^= 0x1,
219            b'\'' => self.state.quote_flags ^= 0x2,
220            _ => {}
221        }
222        self.state.bracket_bucket = if self.bracket_depth == 0 {
223            0
224        } else {
225            self.bracket_stack[self.bracket_depth - 1]
226        };
227    }
228
229    fn push_bracket(&mut self, bracket: u8) {
230        if self.bracket_depth < self.bracket_stack.len() {
231            self.bracket_stack[self.bracket_depth] = bracket;
232            self.bracket_depth += 1;
233        } else {
234            self.bracket_stack[self.bracket_stack.len() - 1] = bracket;
235        }
236    }
237
238    fn pop_bracket(&mut self, bracket: u8) {
239        if self.bracket_depth == 0 {
240            return;
241        }
242        if self.bracket_stack[self.bracket_depth - 1] == bracket {
243            self.bracket_depth -= 1;
244            return;
245        }
246        for idx in (0..self.bracket_depth).rev() {
247            if self.bracket_stack[idx] == bracket {
248                self.bracket_depth = idx;
249                return;
250            }
251        }
252    }
253}
254
/// Maps a byte to a coarse class.
///
/// The classes are:
/// 1 ASCII letter, 2 ASCII digit, 3 space/tab/LF/CR, 4 ASCII punctuation,
/// 5 byte in `0xC0..=0xFF`, 6 byte in `0x80..=0xBF`, 0 anything else
/// (control bytes, DEL, form feed, ...).
#[inline]
pub(crate) fn classify_byte(byte: u8) -> u8 {
    if byte.is_ascii_alphabetic() {
        1
    } else if byte.is_ascii_digit() {
        2
    } else if matches!(byte, b' ' | b'\t' | b'\n' | b'\r') {
        3
    } else if byte.is_ascii_punctuation() {
        4
    } else if byte >= 0xC0 {
        5
    } else if byte >= 0x80 {
        6
    } else {
        0
    }
}
267
/// Returns whether `byte` belongs to a word.
///
/// Word bytes are ASCII letters and digits, `_`, and any non-ASCII byte. All
/// bytes of a multi-byte UTF-8 character therefore stay inside the word.
#[inline]
fn is_word_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_' || !byte.is_ascii()
}
272
/// Buckets a word length into 0..=6.
///
/// The buckets are: 0, 1, 2, 3-4, 5-8, 9-16, and 17 or more.
#[inline]
fn bucket_word_len(len: u16) -> u8 {
    // Inclusive upper bound of each bucket, in ascending order; lengths past
    // the last bound land in the final bucket (index 6).
    const UPPER_BOUNDS: [u16; 6] = [0, 1, 2, 4, 8, 16];
    UPPER_BOUNDS.partition_point(|&bound| bound < len) as u8
}
285
/// Buckets a repeat-match length into 0..=7.
///
/// The buckets are: 0, 1-3, 4-5, 6-8, 9-12, 13-16, 17-24, and 25 or more.
#[inline]
pub(crate) fn bucket_repeat_len(len: usize) -> u8 {
    // Inclusive upper bound of each bucket, in ascending order; lengths past
    // the last bound land in the final bucket (index 7).
    const UPPER_BOUNDS: [usize; 7] = [0, 3, 5, 8, 12, 16, 24];
    UPPER_BOUNDS.partition_point(|&bound| bound < len) as u8
}
299
/// Updates the number of UTF-8 continuation bytes still expected after `symbol`.
///
/// The rules are:
/// - A continuation byte (`0x80..=0xBF`) consumes one pending slot, never going below zero.
/// - A lead byte sets the count to 1, 2 or 3 for `0xC0..=0xDF`, `0xE0..=0xEF`
///   or `0xF0..=0xF7` respectively.
/// - Any other byte resets the count to 0.
#[inline]
fn utf8_left_after(symbol: u8, prev_left: u8) -> u8 {
    match symbol {
        0x80..=0xBF => prev_left.saturating_sub(1),
        0xC0..=0xDF => 1,
        0xE0..=0xEF => 2,
        0xF0..=0xF7 => 3,
        _ => 0,
    }
}
313
314#[inline]
315fn repeat_key(history: &[u8]) -> Option<RepeatKey> {
316    if history.len() < LocalRepeatState::MIN_LEN {
317        return None;
318    }
319    let n = history.len();
320    Some(u32::from_be_bytes([
321        history[n - 4],
322        history[n - 3],
323        history[n - 2],
324        history[n - 1],
325    ]))
326}