1use ahash::AHashMap;
2
/// Position of a byte inside the repeat detector's history buffer.
type RepeatPos = u32;
/// Sentinel stored where no position has been recorded.
const REPEAT_POS_NONE: RepeatPos = RepeatPos::MAX;
/// Lookup key for the repeat table (see `repeat_key`).
type RepeatKey = u32;

/// Narrows a `usize` history index to a [`RepeatPos`].
///
/// # Panics
///
/// Panics with `"text repeat position overflow"` when `pos` is greater than
/// or equal to [`REPEAT_POS_NONE`], because that value is reserved as the
/// "no position" sentinel and anything larger does not fit.
#[inline]
fn repeat_pos_from_usize(pos: usize) -> RepeatPos {
    match RepeatPos::try_from(pos) {
        // The sentinel itself must never be handed out as a real position.
        Ok(converted) if converted != REPEAT_POS_NONE => converted,
        _ => panic!("text repeat position overflow"),
    }
}
14
/// Snapshot of the per-byte context features.
///
/// The field descriptions below reflect how `TextContextAnalyzer::update`
/// maintains them in this file; other producers of this struct (if any) are
/// not visible here.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct NeuralContextState {
    /// Most recently observed byte.
    pub(crate) prev1: u8,
    /// Byte observed immediately before `prev1`.
    pub(crate) prev2: u8,
    /// `classify_byte` result for `prev1`.
    pub(crate) prev1_class: u8,
    /// `classify_byte` result for `prev2`.
    pub(crate) prev2_class: u8,
    /// Length of the run of identical bytes ending at `prev1`, capped at 255.
    pub(crate) run_len: u16,
    /// Remaining-byte counter produced by `utf8_left_after`.
    pub(crate) utf8_left: u8,
    /// Whether `prev1` is a word byte (see `is_word_byte`).
    pub(crate) in_word: bool,
    /// Bucketed length of the word in progress; 0 when outside a word.
    pub(crate) word_len_bucket: u8,
    /// Class of the most recently completed word (see `WordTracker::class`).
    pub(crate) prev_word_class: u8,
    /// Kind code (1..=4) of the innermost open bracket, or 0 when none is open.
    pub(crate) bracket_bucket: u8,
    /// Bit 0 toggles on every `"`, bit 1 toggles on every `'`.
    pub(crate) quote_flags: u8,
    /// Whether `prev1` is one of `.`, `!` or `?`.
    pub(crate) sentence_boundary: bool,
    /// Whether the stream currently ends with two or more `\n` bytes.
    pub(crate) paragraph_break: bool,
    /// Bucketed length of the current repeat match (see `bucket_repeat_len`).
    pub(crate) repeat_len_bucket: u8,
    /// Whether `prev1` equalled the repeat detector's prediction for it.
    pub(crate) copied_last_byte: bool,
    /// `true` once at least one byte has been observed.
    pub(crate) has_history: bool,
}

/// Alias for [`NeuralContextState`].
pub(crate) type NeuralHistoryState = NeuralContextState;
36
/// Accumulates the length and character kinds of the word being read.
#[derive(Clone, Copy, Debug, Default)]
struct WordTracker {
    /// Number of bytes observed since the last `clear` (saturating).
    len: u16,
    /// At least one ASCII lowercase letter was observed.
    saw_lower: bool,
    /// At least one ASCII uppercase letter was observed.
    saw_upper: bool,
    /// At least one ASCII digit was observed.
    saw_digit: bool,
    /// At least one byte outside the three categories above was observed.
    saw_other: bool,
}

impl WordTracker {
    /// Resets the tracker to its empty state.
    fn clear(&mut self) {
        *self = WordTracker::default();
    }

    /// Records one more byte of the current word.
    fn observe(&mut self, byte: u8) {
        self.len = self.len.saturating_add(1);
        let flag = if byte.is_ascii_lowercase() {
            &mut self.saw_lower
        } else if byte.is_ascii_uppercase() {
            &mut self.saw_upper
        } else if byte.is_ascii_digit() {
            &mut self.saw_digit
        } else {
            &mut self.saw_other
        };
        *flag = true;
    }

    /// Classifies the word observed so far.
    ///
    /// Returns `1` for all-lowercase, `2` for all-uppercase, `3` for
    /// all-digit, `4` for any other non-empty word and `0` when nothing has
    /// been observed.
    fn class(self) -> u8 {
        match (self.saw_lower, self.saw_upper, self.saw_digit, self.saw_other) {
            (false, false, true, false) => 3,
            (false, true, false, false) => 2,
            (true, false, false, false) => 1,
            _ if self.len > 0 => 4,
            _ => 0,
        }
    }
}
75
76#[derive(Clone, Debug, Default)]
77struct LocalRepeatState {
78 history: Vec<u8>,
79 table: AHashMap<RepeatKey, (RepeatPos, RepeatPos)>,
80 predicted: Option<u8>,
81 match_len: usize,
82 copied_last_byte: bool,
83}
84
85impl LocalRepeatState {
86 const MIN_LEN: usize = 4;
87 const MAX_LEN: usize = 255;
88
89 fn predict_from_history(&self) -> (Option<u8>, usize) {
90 if self.history.len() < Self::MIN_LEN {
91 return (None, 0);
92 }
93 let end = self.history.len() - 1;
94 let Some(key) = repeat_key(&self.history) else {
95 return (None, 0);
96 };
97 let Some(&(latest, previous)) = self.table.get(&key) else {
98 return (None, 0);
99 };
100 let end_pos = repeat_pos_from_usize(end);
101 let candidate_end = if latest == end_pos { previous } else { latest };
102 if candidate_end == REPEAT_POS_NONE {
103 return (None, 0);
104 }
105 let candidate_end = candidate_end as usize;
106 if candidate_end + 1 >= self.history.len() {
107 return (None, 0);
108 }
109 let mut matched = Self::MIN_LEN;
110 while matched < Self::MAX_LEN
111 && end >= matched
112 && candidate_end >= matched
113 && self.history[end - matched] == self.history[candidate_end - matched]
114 {
115 matched += 1;
116 }
117 (Some(self.history[candidate_end + 1]), matched)
118 }
119
120 fn repeat_len_bucket(&self) -> u8 {
121 bucket_repeat_len(self.match_len)
122 }
123
124 fn update(&mut self, symbol: u8) {
125 self.copied_last_byte = self.predicted == Some(symbol);
126 self.history.push(symbol);
127 if let Some(key) = repeat_key(&self.history) {
128 let end = repeat_pos_from_usize(self.history.len() - 1);
129 self.table
130 .entry(key)
131 .and_modify(|entry| {
132 entry.1 = entry.0;
133 entry.0 = end;
134 })
135 .or_insert((end, REPEAT_POS_NONE));
136 }
137 let (predicted, match_len) = self.predict_from_history();
138 self.predicted = predicted;
139 self.match_len = match_len;
140 }
141}
142
143#[derive(Clone, Debug, Default)]
144pub(crate) struct TextContextAnalyzer {
145 state: NeuralContextState,
146 word: WordTracker,
147 newline_run: u8,
148 bracket_stack: [u8; 8],
149 bracket_depth: usize,
150 repeat: LocalRepeatState,
151}
152
153impl TextContextAnalyzer {
154 pub(crate) fn new() -> Self {
155 Self::default()
156 }
157
158 pub(crate) fn state(&self) -> NeuralContextState {
159 self.state
160 }
161
162 pub(crate) fn update(&mut self, symbol: u8) {
163 let was_predicted = self.repeat.predicted;
164 self.repeat.update(symbol);
165
166 let byte_class = classify_byte(symbol);
167 if self.state.has_history && symbol == self.state.prev1 {
168 self.state.run_len = self.state.run_len.saturating_add(1).min(255);
169 } else {
170 self.state.run_len = 1;
171 }
172
173 let prev_in_word = self.state.in_word;
174 let is_word_byte = is_word_byte(symbol);
175 if is_word_byte {
176 if !prev_in_word {
177 self.word.clear();
178 }
179 self.word.observe(symbol);
180 } else if prev_in_word {
181 self.state.prev_word_class = self.word.class();
182 self.word.clear();
183 }
184 self.state.in_word = is_word_byte;
185 self.state.word_len_bucket = bucket_word_len(self.word.len);
186
187 self.update_structure(symbol);
188
189 self.state.prev2 = self.state.prev1;
190 self.state.prev2_class = self.state.prev1_class;
191 self.state.prev1 = symbol;
192 self.state.prev1_class = byte_class;
193 self.state.utf8_left = utf8_left_after(symbol, self.state.utf8_left);
194 self.state.repeat_len_bucket = self.repeat.repeat_len_bucket();
195 self.state.copied_last_byte = was_predicted == Some(symbol);
196 self.state.has_history = true;
197 }
198
199 fn update_structure(&mut self, symbol: u8) {
200 self.state.sentence_boundary = matches!(symbol, b'.' | b'!' | b'?');
201
202 if symbol == b'\n' {
203 self.newline_run = self.newline_run.saturating_add(1).min(3);
204 } else {
205 self.newline_run = 0;
206 }
207 self.state.paragraph_break = self.newline_run >= 2;
208
209 match symbol {
210 b'(' => self.push_bracket(1),
211 b'[' => self.push_bracket(2),
212 b'{' => self.push_bracket(3),
213 b'<' => self.push_bracket(4),
214 b')' => self.pop_bracket(1),
215 b']' => self.pop_bracket(2),
216 b'}' => self.pop_bracket(3),
217 b'>' => self.pop_bracket(4),
218 b'"' => self.state.quote_flags ^= 0x1,
219 b'\'' => self.state.quote_flags ^= 0x2,
220 _ => {}
221 }
222 self.state.bracket_bucket = if self.bracket_depth == 0 {
223 0
224 } else {
225 self.bracket_stack[self.bracket_depth - 1]
226 };
227 }
228
229 fn push_bracket(&mut self, bracket: u8) {
230 if self.bracket_depth < self.bracket_stack.len() {
231 self.bracket_stack[self.bracket_depth] = bracket;
232 self.bracket_depth += 1;
233 } else {
234 self.bracket_stack[self.bracket_stack.len() - 1] = bracket;
235 }
236 }
237
238 fn pop_bracket(&mut self, bracket: u8) {
239 if self.bracket_depth == 0 {
240 return;
241 }
242 if self.bracket_stack[self.bracket_depth - 1] == bracket {
243 self.bracket_depth -= 1;
244 return;
245 }
246 for idx in (0..self.bracket_depth).rev() {
247 if self.bracket_stack[idx] == bracket {
248 self.bracket_depth = idx;
249 return;
250 }
251 }
252 }
253}
254
/// Maps a byte to a coarse character class.
///
/// * `1` — ASCII letter
/// * `2` — ASCII digit
/// * `3` — space, tab, line feed or carriage return
/// * `4` — ASCII punctuation
/// * `5` — byte in `0xC0..=0xFF`
/// * `6` — byte in `0x80..=0xBF`
/// * `0` — anything else (remaining control bytes and `0x7F`)
#[inline]
pub(crate) fn classify_byte(byte: u8) -> u8 {
    if byte.is_ascii_alphabetic() {
        1
    } else if byte.is_ascii_digit() {
        2
    } else if matches!(byte, b' ' | b'\t' | b'\n' | b'\r') {
        3
    } else if byte.is_ascii_punctuation() {
        4
    } else if byte >= 0xC0 {
        5
    } else if byte >= 0x80 {
        6
    } else {
        0
    }
}
267
/// Returns `true` for bytes treated as part of a word: ASCII letters and
/// digits, `_`, and every byte with the high bit set.
#[inline]
fn is_word_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_' || byte >= 0x80
}
272
/// Quantises a word length into one of seven buckets.
///
/// Lengths `0`, `1` and `2` map to themselves; `3..=4` → 3, `5..=8` → 4,
/// `9..=16` → 5 and anything longer → 6.
#[inline]
fn bucket_word_len(len: u16) -> u8 {
    // (inclusive upper bound, bucket), in ascending order.
    const THRESHOLDS: [(u16, u8); 6] = [(0, 0), (1, 1), (2, 2), (4, 3), (8, 4), (16, 5)];
    THRESHOLDS
        .iter()
        .find(|&&(upper, _)| len <= upper)
        .map_or(6, |&(_, bucket)| bucket)
}
285
/// Quantises a repeat match length into one of eight buckets.
///
/// `0` → 0, `1..=3` → 1, `4..=5` → 2, `6..=8` → 3, `9..=12` → 4,
/// `13..=16` → 5, `17..=24` → 6 and anything longer → 7.
#[inline]
pub(crate) fn bucket_repeat_len(len: usize) -> u8 {
    // (inclusive upper bound, bucket), in ascending order.
    const THRESHOLDS: [(usize, u8); 7] =
        [(0, 0), (3, 1), (5, 2), (8, 3), (12, 4), (16, 5), (24, 6)];
    THRESHOLDS
        .iter()
        .find(|&&(upper, _)| len <= upper)
        .map_or(7, |&(_, bucket)| bucket)
}
299
/// Computes the remaining-byte counter after observing `symbol`.
///
/// A byte in `0x80..=0xBF` (UTF-8 continuation range) decrements `prev_left`,
/// saturating at zero. A byte in one of the multi-byte lead ranges sets the
/// counter to 1, 2 or 3. Every other byte resets it to 0.
#[inline]
fn utf8_left_after(symbol: u8, prev_left: u8) -> u8 {
    match symbol {
        0x80..=0xBF => prev_left.saturating_sub(1),
        0xC0..=0xDF => 1,
        0xE0..=0xEF => 2,
        0xF0..=0xF7 => 3,
        _ => 0,
    }
}
313
314#[inline]
315fn repeat_key(history: &[u8]) -> Option<RepeatKey> {
316 if history.len() < LocalRepeatState::MIN_LEN {
317 return None;
318 }
319 let n = history.len();
320 Some(u32::from_be_bytes([
321 history[n - 4],
322 history[n - 3],
323 history[n - 2],
324 history[n - 1],
325 ]))
326}