1use ahash::AHashMap;
2
/// Lower bound applied to every probability by `normalize_pdf`; shared with the mixture module.
const PDF_MIN: f64 = crate::mixture::DEFAULT_MIN_PROB;
/// Maximum number of trailing raw bytes kept in the committed and frozen tails, and therefore
/// the longest raw-byte context used for prediction.
const RAW_FALLBACK_MAX: usize = 4;

/// Index into `SequiturModel::nodes`.
type NodeIx = u32;
/// Index into `SequiturModel::rules`; rule 0 is the start rule that holds the whole input.
type RuleId = u32;
8
/// A grammar symbol: a literal byte or a reference to another rule.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
enum Symbol {
    /// A literal input byte.
    Terminal(u8),
    /// A reference to the rule with this id.
    NonTerminal(RuleId),
}
14
/// Payload of one node in the node arena.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum NodeData {
    /// Sentinel that closes the circular body list of the given rule. (The model's shared
    /// `dummy` node also carries `Guard(0)`.)
    Guard(RuleId),
    /// An ordinary symbol inside a rule body.
    Sym(Symbol),
}
20
/// One element of a rule's circular doubly-linked body list.
#[derive(Clone, Copy, Debug)]
struct Node {
    /// Previous node; for a guard this is the rule's last symbol.
    prev: NodeIx,
    /// Next node; for a guard this is the rule's first symbol.
    next: NodeIx,
    /// Guard marker or symbol carried by this node.
    data: NodeData,
}
27
/// Bookkeeping for one grammar rule.
#[derive(Clone, Debug)]
struct Rule {
    /// Guard node whose circular list holds the rule body.
    guard: NodeIx,
    /// Number of `NonTerminal` references to this rule in rule bodies.
    ref_count: u32,
    /// Cleared when the rule is inlined by `expand`; inactive rules are skipped by traversals.
    active: bool,
}
34
/// Follower statistics for one context: how often each byte was observed right after it.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct ContextFollowers {
    /// `(byte, count)` pairs in first-seen order; each byte appears at most once.
    counts: Vec<(u8, u64)>,
    /// Sum of all counts in `counts`.
    total: u64,
}
40
41impl ContextFollowers {
42 fn observe(&mut self, symbol: u8) {
43 if let Some((_, count)) = self.counts.iter_mut().find(|(s, _)| *s == symbol) {
44 *count += 1;
45 } else {
46 self.counts.push((symbol, 1));
47 }
48 self.total += 1;
49 }
50
51 fn distinct(&self) -> usize {
52 self.counts.len()
53 }
54}
55
/// One reversible mutation logged while undo recording is enabled; each variant stores the value
/// that `apply_undo` writes back.
#[derive(Clone, Debug)]
enum UndoOp {
    /// Restore `nodes[node].prev` to `old`.
    SetPrev {
        node: NodeIx,
        old: NodeIx,
    },
    /// Restore `nodes[node].next` to `old`.
    SetNext {
        node: NodeIx,
        old: NodeIx,
    },
    /// Restore `rules[rule].ref_count` to `old`.
    SetRuleRefCount {
        rule: RuleId,
        old: u32,
    },
    /// Restore `rules[rule].active` to `old`.
    SetRuleActive {
        rule: RuleId,
        old: bool,
    },
    /// Restore the digram index entry for `key`; `None` means the key was absent.
    SetDigram {
        key: u64,
        old: Option<NodeIx>,
    },
    /// Restore the follower statistics of context `key`; `None` means the key was absent.
    SetContextFollowers {
        key: Box<[u8]>,
        old: Option<ContextFollowers>,
    },
    /// Restore `unigram[symbol]` and `unigram_total`.
    SetUnigramSymbol {
        symbol: u8,
        old_count: u64,
        old_total: u64,
    },
    /// Restore the whole committed raw tail.
    SetCommittedRawTail {
        old: Vec<u8>,
    },
    /// Restore the whole frozen raw tail.
    SetFrozenRawTail {
        old: Vec<u8>,
    },
}
94
/// Opaque marker returned by `SequiturModel::checkpoint` and consumed by `SequiturModel::restore`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SequiturCheckpoint {
    /// Length of the undo log when the checkpoint was taken.
    undo_len: usize,
    /// Number of allocated nodes when the checkpoint was taken.
    node_len: usize,
    /// Number of allocated rules when the checkpoint was taken.
    rule_len: usize,
}
105
/// One rule of a [`CanonicalGrammar`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CanonicalRule {
    /// Pre-order index of this rule, counted from the start rule (id 0).
    pub id: usize,
    /// Right-hand side, with non-terminals renumbered to canonical ids.
    pub rhs: Vec<CanonicalSymbol>,
}
114
/// A symbol on the right-hand side of a [`CanonicalRule`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CanonicalSymbol {
    /// A literal byte.
    Terminal(u8),
    /// A reference to the canonical rule with this id.
    NonTerminal(usize),
}
123
/// Grammar snapshot with rules renumbered in pre-order, so it does not depend on the model's
/// internal rule ids.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CanonicalGrammar {
    /// Reachable rules; `rules[i].id == i`, and `rules[0]` is the start rule.
    pub rules: Vec<CanonicalRule>,
}
130
/// Byte-level predictive model built on an incrementally maintained Sequitur grammar.
///
/// Committed bytes (`update`) extend the grammar, the follower statistics and the unigram
/// counts. Frozen bytes (`update_frozen`) only extend a short raw tail used for context lookup.
/// Mutations can be rolled back with `checkpoint` / `restore`.
#[derive(Clone, Debug)]
pub struct SequiturModel {
    /// Maximum length of rule-tail contexts (at least 2).
    context_bytes: usize,
    /// Node arena that holds every rule's circular body list.
    nodes: Vec<Node>,
    /// Rule table indexed by `RuleId`; rule 0 is the start rule.
    rules: Vec<Rule>,
    /// Placeholder guard that freshly allocated symbol nodes point at before they are linked.
    dummy: NodeIx,
    /// Digram index: packed symbol pair -> node where the indexed occurrence starts.
    digrams: AHashMap<u64, NodeIx>,
    /// Follower statistics keyed by context bytes.
    followers: AHashMap<Box<[u8]>, ContextFollowers>,
    /// Per-byte counts of committed symbols.
    unigram: [u64; 256],
    /// Sum of `unigram`.
    unigram_total: u64,
    /// Last (at most `RAW_FALLBACK_MAX`) committed bytes.
    committed_raw_tail: Vec<u8>,
    /// Last (at most `RAW_FALLBACK_MAX`) frozen bytes since the last commit or reset.
    frozen_raw_tail: Vec<u8>,
    /// Cached next-symbol distribution; meaningful only while `pdf_valid` is set.
    pdf: [f64; 256],
    /// Whether `pdf` reflects the current state.
    pdf_valid: bool,
    /// Undo log, appended to while `undo_enabled` is set.
    undo: Vec<UndoOp>,
    /// Whether mutations are currently recorded in `undo`.
    undo_enabled: bool,
}
152
153impl SequiturModel {
154 pub fn new(context_bytes: usize) -> Self {
159 let context_bytes = context_bytes.max(2);
160 let mut nodes = Vec::with_capacity(16);
161 nodes.push(Node {
162 prev: 0,
163 next: 0,
164 data: NodeData::Guard(0),
165 });
166 nodes.push(Node {
167 prev: 1,
168 next: 1,
169 data: NodeData::Guard(0),
170 });
171 let rules = vec![Rule {
172 guard: 0,
173 ref_count: 0,
174 active: true,
175 }];
176 Self {
177 context_bytes,
178 nodes,
179 rules,
180 dummy: 1,
181 digrams: AHashMap::new(),
182 followers: AHashMap::new(),
183 unigram: [0; 256],
184 unigram_total: 0,
185 committed_raw_tail: Vec::with_capacity(RAW_FALLBACK_MAX),
186 frozen_raw_tail: Vec::with_capacity(RAW_FALLBACK_MAX),
187 pdf: [1.0 / 256.0; 256],
188 pdf_valid: false,
189 undo: Vec::new(),
190 undo_enabled: false,
191 }
192 }
193
194 pub fn reserve_for_stream(&mut self, additional_symbols: usize) {
198 self.nodes.reserve(additional_symbols.saturating_mul(2));
199 self.digrams.reserve(additional_symbols);
200 }
201
202 pub fn begin_stream(&mut self, total_symbols: Option<u64>) {
207 if let Some(total) = total_symbols {
208 let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
209 self.reserve_for_stream(reserve);
210 }
211 self.frozen_raw_tail.clear();
212 self.pdf_valid = false;
213 }
214
215 pub fn finish_stream(&mut self) {}
220
221 pub fn checkpoint(&mut self) -> SequiturCheckpoint {
225 self.undo_enabled = true;
226 SequiturCheckpoint {
227 undo_len: self.undo.len(),
228 node_len: self.nodes.len(),
229 rule_len: self.rules.len(),
230 }
231 }
232
233 pub fn restore(&mut self, checkpoint: &SequiturCheckpoint) {
235 let saved = self.undo_enabled;
236 self.undo_enabled = false;
237 while self.undo.len() > checkpoint.undo_len {
238 let op = self.undo.pop().expect("undo underflow");
239 self.apply_undo(op);
240 }
241 self.nodes.truncate(checkpoint.node_len);
242 self.rules.truncate(checkpoint.rule_len);
243 self.undo_enabled = saved;
244 self.pdf_valid = false;
245 }
246
247 pub fn clear_checkpoints(&mut self) {
249 self.undo.clear();
250 self.undo_enabled = false;
251 }
252
253 pub fn reset_frozen(&mut self) {
255 self.frozen_raw_tail.clear();
256 self.pdf_valid = false;
257 }
258
259 pub fn fill_pdf(&mut self, out: &mut [f64; 256]) {
261 self.ensure_pdf();
262 out.copy_from_slice(&self.pdf);
263 }
264
265 pub fn pdf(&mut self) -> &[f64; 256] {
267 self.ensure_pdf();
268 &self.pdf
269 }
270
271 pub fn log_prob(&mut self, symbol: u8, min_prob: f64) -> f64 {
275 self.ensure_pdf();
276 self.pdf[symbol as usize].max(min_prob).ln()
277 }
278
279 pub fn update(&mut self, symbol: u8) {
281 self.observe_symbol_in_stats(symbol);
282 self.append_terminal(symbol);
283 self.record_committed_raw_tail(symbol);
284 if !self.frozen_raw_tail.is_empty() {
285 self.record_frozen_raw_tail_inner(Vec::new());
286 }
287 self.pdf_valid = false;
288 }
289
290 pub fn update_frozen(&mut self, symbol: u8) {
296 let mut next = self.frozen_raw_tail.clone();
297 next.push(symbol);
298 if next.len() > RAW_FALLBACK_MAX {
299 let drain = next.len() - RAW_FALLBACK_MAX;
300 next.drain(0..drain);
301 }
302 self.record_frozen_raw_tail_inner(next);
303 self.pdf_valid = false;
304 }
305
306 pub fn decode(&self) -> Vec<u8> {
308 let mut out = Vec::new();
309 self.decode_rule(0, &mut out);
310 out
311 }
312
313 pub fn canonical_grammar(&self) -> CanonicalGrammar {
318 let mut order = Vec::<RuleId>::new();
319 let mut seen = AHashMap::<RuleId, usize>::new();
320 self.collect_rule_preorder(0, &mut order, &mut seen);
321 let mut canonical = Vec::with_capacity(order.len());
322 for (idx, &rule_id) in order.iter().enumerate() {
323 let mut rhs = Vec::new();
324 let guard = self.rules[rule_id as usize].guard;
325 let mut node = self.nodes[guard as usize].next;
326 while node != guard {
327 match self.nodes[node as usize].data {
328 NodeData::Guard(_) => unreachable!("guard in rule body"),
329 NodeData::Sym(Symbol::Terminal(byte)) => {
330 rhs.push(CanonicalSymbol::Terminal(byte))
331 }
332 NodeData::Sym(Symbol::NonTerminal(child)) => {
333 let mapped = *seen
334 .get(&child)
335 .expect("canonical grammar missing child mapping");
336 rhs.push(CanonicalSymbol::NonTerminal(mapped));
337 }
338 }
339 node = self.nodes[node as usize].next;
340 }
341 canonical.push(CanonicalRule { id: idx, rhs });
342 }
343 CanonicalGrammar { rules: canonical }
344 }
345
346 pub fn predictive_trace(&mut self, data: &[u8], alphabet_prefix: usize) -> Vec<Vec<f64>> {
351 let mut trace = Vec::with_capacity(data.len());
352 let mut pdf = [0.0; 256];
353 for &byte in data {
354 self.fill_pdf(&mut pdf);
355 trace.push(pdf[..alphabet_prefix.min(256)].to_vec());
356 self.update(byte);
357 }
358 trace
359 }
360
361 #[cfg(test)]
362 pub fn validate_invariants(&self) -> Result<(), String> {
363 self.validate_rule_shapes()?;
364 self.validate_rule_refcounts()?;
365 self.validate_digram_uniqueness()?;
366 Ok(())
367 }
368
369 fn collect_rule_preorder(
370 &self,
371 rule_id: RuleId,
372 order: &mut Vec<RuleId>,
373 seen: &mut AHashMap<RuleId, usize>,
374 ) {
375 if seen.contains_key(&rule_id) {
376 return;
377 }
378 let idx = order.len();
379 seen.insert(rule_id, idx);
380 order.push(rule_id);
381 let guard = self.rules[rule_id as usize].guard;
382 let mut node = self.nodes[guard as usize].next;
383 while node != guard {
384 if let NodeData::Sym(Symbol::NonTerminal(child)) = self.nodes[node as usize].data {
385 if self.rules[child as usize].active {
386 self.collect_rule_preorder(child, order, seen);
387 }
388 }
389 node = self.nodes[node as usize].next;
390 }
391 }
392
393 fn apply_undo(&mut self, op: UndoOp) {
394 match op {
395 UndoOp::SetPrev { node, old } => {
396 self.nodes[node as usize].prev = old;
397 }
398 UndoOp::SetNext { node, old } => {
399 self.nodes[node as usize].next = old;
400 }
401 UndoOp::SetRuleRefCount { rule, old } => {
402 self.rules[rule as usize].ref_count = old;
403 }
404 UndoOp::SetRuleActive { rule, old } => {
405 self.rules[rule as usize].active = old;
406 }
407 UndoOp::SetDigram { key, old } => {
408 if let Some(node) = old {
409 self.digrams.insert(key, node);
410 } else {
411 self.digrams.remove(&key);
412 }
413 }
414 UndoOp::SetContextFollowers { key, old } => {
415 if let Some(state) = old {
416 self.followers.insert(key, state);
417 } else {
418 self.followers.remove(key.as_ref());
419 }
420 }
421 UndoOp::SetUnigramSymbol {
422 symbol,
423 old_count,
424 old_total,
425 } => {
426 self.unigram[symbol as usize] = old_count;
427 self.unigram_total = old_total;
428 }
429 UndoOp::SetCommittedRawTail { old } => {
430 self.committed_raw_tail = old;
431 }
432 UndoOp::SetFrozenRawTail { old } => {
433 self.frozen_raw_tail = old;
434 }
435 }
436 }
437
438 fn push_undo(&mut self, op: UndoOp) {
439 if self.undo_enabled {
440 self.undo.push(op);
441 }
442 }
443
444 fn set_prev(&mut self, node: NodeIx, prev: NodeIx) {
445 let old = self.nodes[node as usize].prev;
446 if old != prev {
447 self.push_undo(UndoOp::SetPrev { node, old });
448 self.nodes[node as usize].prev = prev;
449 }
450 }
451
452 fn set_next(&mut self, node: NodeIx, next: NodeIx) {
453 let old = self.nodes[node as usize].next;
454 if old != next {
455 self.push_undo(UndoOp::SetNext { node, old });
456 self.nodes[node as usize].next = next;
457 }
458 }
459
460 fn set_rule_ref_count(&mut self, rule: RuleId, ref_count: u32) {
461 let old = self.rules[rule as usize].ref_count;
462 if old != ref_count {
463 self.push_undo(UndoOp::SetRuleRefCount { rule, old });
464 self.rules[rule as usize].ref_count = ref_count;
465 }
466 }
467
468 fn set_rule_active(&mut self, rule: RuleId, active: bool) {
469 let old = self.rules[rule as usize].active;
470 if old != active {
471 self.push_undo(UndoOp::SetRuleActive { rule, old });
472 self.rules[rule as usize].active = active;
473 }
474 }
475
476 fn set_digram(&mut self, key: u64, value: Option<NodeIx>) {
477 let old = self.digrams.get(&key).copied();
478 if old == value {
479 return;
480 }
481 self.push_undo(UndoOp::SetDigram { key, old });
482 if let Some(node) = value {
483 self.digrams.insert(key, node);
484 } else {
485 self.digrams.remove(&key);
486 }
487 }
488
489 fn record_context_followers(&mut self, key: &[u8], new_state: Option<ContextFollowers>) {
490 let boxed: Box<[u8]> = key.to_vec().into_boxed_slice();
491 let old = self.followers.get(boxed.as_ref()).cloned();
492 if old == new_state {
493 return;
494 }
495 self.push_undo(UndoOp::SetContextFollowers {
496 key: boxed.clone(),
497 old,
498 });
499 if let Some(state) = new_state {
500 self.followers.insert(boxed, state);
501 } else {
502 self.followers.remove(boxed.as_ref());
503 }
504 }
505
506 fn record_unigram(&mut self, symbol: u8, next_count: u64, next_total: u64) {
507 let old_count = self.unigram[symbol as usize];
508 let old_total = self.unigram_total;
509 if old_count == next_count && old_total == next_total {
510 return;
511 }
512 self.push_undo(UndoOp::SetUnigramSymbol {
513 symbol,
514 old_count,
515 old_total,
516 });
517 self.unigram[symbol as usize] = next_count;
518 self.unigram_total = next_total;
519 }
520
521 fn record_committed_raw_tail(&mut self, symbol: u8) {
522 let mut next = self.committed_raw_tail.clone();
523 next.push(symbol);
524 if next.len() > RAW_FALLBACK_MAX {
525 let drain = next.len() - RAW_FALLBACK_MAX;
526 next.drain(0..drain);
527 }
528 if next != self.committed_raw_tail {
529 self.push_undo(UndoOp::SetCommittedRawTail {
530 old: self.committed_raw_tail.clone(),
531 });
532 self.committed_raw_tail = next;
533 }
534 }
535
536 fn record_frozen_raw_tail_inner(&mut self, next: Vec<u8>) {
537 if next != self.frozen_raw_tail {
538 self.push_undo(UndoOp::SetFrozenRawTail {
539 old: self.frozen_raw_tail.clone(),
540 });
541 self.frozen_raw_tail = next;
542 }
543 }
544
545 fn ensure_pdf(&mut self) {
546 if self.pdf_valid {
547 return;
548 }
549
550 let denom = (self.unigram_total as f64) + 128.0;
551 for (idx, slot) in self.pdf.iter_mut().enumerate() {
552 *slot = ((self.unigram[idx] as f64) + 0.5) / denom;
553 }
554
555 let contexts = self.current_contexts();
556 let mut next = [0.0; 256];
557 for context in contexts {
558 let Some(stats) = self.followers.get(context.as_slice()) else {
559 continue;
560 };
561 let distinct = stats.distinct();
562 if stats.total == 0 || distinct == 0 {
563 continue;
564 }
565 let total = stats.total as f64;
566 let types = distinct as f64;
567 let escape = types / (total + types);
568 for i in 0..256 {
569 next[i] = self.pdf[i] * escape;
570 }
571 for &(symbol, count) in &stats.counts {
572 next[symbol as usize] += (count as f64) / (total + types);
573 }
574 self.pdf.copy_from_slice(&next);
575 }
576
577 normalize_pdf(&mut self.pdf);
578 self.pdf_valid = true;
579 }
580
581 fn current_contexts(&self) -> Vec<Vec<u8>> {
582 let mut out = Vec::<Vec<u8>>::new();
583 let raw_tail = self.effective_raw_tail();
584 for len in 1..=raw_tail.len().min(RAW_FALLBACK_MAX) {
585 let ctx = raw_tail[raw_tail.len() - len..].to_vec();
586 if !out.iter().any(|existing| existing == &ctx) {
587 out.push(ctx);
588 }
589 }
590
591 let mut rule_chain = Vec::<RuleId>::new();
592 rule_chain.push(0);
593 let mut current = 0u32;
594 loop {
595 let guard = self.rules[current as usize].guard;
596 let last = self.nodes[guard as usize].prev;
597 if last == guard {
598 break;
599 }
600 match self.nodes[last as usize].data {
601 NodeData::Sym(Symbol::NonTerminal(child)) if self.rules[child as usize].active => {
602 rule_chain.push(child);
603 current = child;
604 }
605 _ => break,
606 }
607 }
608
609 for &rule_id in &rule_chain {
610 let ctx = self.rule_tail_bytes(rule_id, self.context_bytes);
611 if !ctx.is_empty() && !out.iter().any(|existing| existing == &ctx) {
612 out.push(ctx);
613 }
614 }
615
616 out
617 }
618
619 fn effective_raw_tail(&self) -> Vec<u8> {
620 let mut out = Vec::with_capacity(RAW_FALLBACK_MAX);
621 let total = self.committed_raw_tail.len() + self.frozen_raw_tail.len();
622 let keep_from = total.saturating_sub(RAW_FALLBACK_MAX);
623 for (idx, &byte) in self
624 .committed_raw_tail
625 .iter()
626 .chain(self.frozen_raw_tail.iter())
627 .enumerate()
628 {
629 if idx >= keep_from {
630 out.push(byte);
631 }
632 }
633 out
634 }
635
636 fn observe_symbol_in_stats(&mut self, symbol: u8) {
637 let contexts = self.current_contexts();
638 for context in contexts {
639 let mut state = self
640 .followers
641 .get(context.as_slice())
642 .cloned()
643 .unwrap_or_default();
644 state.observe(symbol);
645 self.record_context_followers(&context, Some(state));
646 }
647 let old_count = self.unigram[symbol as usize];
648 let old_total = self.unigram_total;
649 self.record_unigram(symbol, old_count + 1, old_total + 1);
650 }
651
652 fn append_terminal(&mut self, byte: u8) {
653 let last = self.last_node_of_rule(0);
654 let _ = self.insert_after(last, Symbol::Terminal(byte));
655 let _ = self.check(last);
656 }
657
658 fn alloc_node(&mut self, data: NodeData) -> NodeIx {
659 let idx = self.nodes.len() as NodeIx;
660 let (prev, next) = match data {
661 NodeData::Guard(_) => (idx, idx),
662 NodeData::Sym(_) => (self.dummy, self.dummy),
663 };
664 self.nodes.push(Node { prev, next, data });
665 idx
666 }
667
668 fn new_rule(&mut self) -> RuleId {
669 let id = self.rules.len() as RuleId;
670 let guard = self.alloc_node(NodeData::Guard(id));
671 self.rules.push(Rule {
672 guard,
673 ref_count: 0,
674 active: true,
675 });
676 id
677 }
678
679 fn is_guard(&self, node: NodeIx) -> bool {
680 matches!(self.nodes[node as usize].data, NodeData::Guard(_))
681 }
682
683 fn symbol_of(&self, node: NodeIx) -> Symbol {
684 match self.nodes[node as usize].data {
685 NodeData::Sym(symbol) => symbol,
686 NodeData::Guard(_) => panic!("node_symbol called on guard node"),
687 }
688 }
689
690 fn symbol_maybe(&self, node: NodeIx) -> Option<Symbol> {
691 match self.nodes[node as usize].data {
692 NodeData::Sym(symbol) => Some(symbol),
693 NodeData::Guard(_) => None,
694 }
695 }
696
697 fn guard_rule(&self, node: NodeIx) -> Option<RuleId> {
698 match self.nodes[node as usize].data {
699 NodeData::Guard(rule) => Some(rule),
700 NodeData::Sym(_) => None,
701 }
702 }
703
704 fn first_node_of_rule(&self, rule: RuleId) -> NodeIx {
705 let guard = self.rules[rule as usize].guard;
706 self.nodes[guard as usize].next
707 }
708
709 fn last_node_of_rule(&self, rule: RuleId) -> NodeIx {
710 let guard = self.rules[rule as usize].guard;
711 self.nodes[guard as usize].prev
712 }
713
714 fn encode_symbol(symbol: Symbol) -> u32 {
715 match symbol {
716 Symbol::Terminal(byte) => byte as u32,
717 Symbol::NonTerminal(rule) => 256u32.wrapping_add(rule),
718 }
719 }
720
721 fn digram_key_from_symbols(left: Symbol, right: Symbol) -> u64 {
722 ((Self::encode_symbol(left) as u64) << 32) | (Self::encode_symbol(right) as u64)
723 }
724
725 fn digram_key_at(&self, node: NodeIx) -> Option<u64> {
726 if self.is_guard(node) {
727 return None;
728 }
729 let next = self.nodes[node as usize].next;
730 if self.is_guard(next) {
731 return None;
732 }
733 Some(Self::digram_key_from_symbols(
734 self.symbol_of(node),
735 self.symbol_of(next),
736 ))
737 }
738
    /// Makes `right` the successor of `left`, keeping the digram index consistent.
    ///
    /// When `left` currently has a symbol successor, the digram starting at `left` is about to
    /// disappear, so it is removed from the index first. Of two overlapping occurrences in a run
    /// of equal symbols, only one is indexed (`check` does not register an overlapping
    /// occurrence). The two triple checks therefore re-register the surviving digram when such a
    /// run is broken.
    fn link(&mut self, left: NodeIx, right: NodeIx) {
        // Snapshot all neighbours before any pointer changes.
        let left_prev = self.nodes[left as usize].prev;
        let left_next = self.nodes[left as usize].next;
        let right_prev = self.nodes[right as usize].prev;
        let right_next = self.nodes[right as usize].next;

        // A guard successor means `left` is the last symbol of its rule, or a detached node
        // pointing at the dummy guard. Either way there is no digram at `left` to drop.
        if !self.is_guard(left_next) {
            self.delete_digram(left);

            // `right` sits in the middle of three equal symbols: index (right, right_next) at `right`.
            match (
                self.symbol_maybe(right_prev),
                self.symbol_maybe(right),
                self.symbol_maybe(right_next),
            ) {
                (Some(sym1), Some(sym2), Some(sym3)) if sym1 == sym2 && sym2 == sym3 => {
                    self.set_digram(Self::digram_key_from_symbols(sym2, sym3), Some(right));
                }
                _ => {}
            }

            // `left` sits in the middle of three equal symbols: index (left_prev, left) at `left_prev`.
            match (
                self.symbol_maybe(left_prev),
                self.symbol_maybe(left),
                self.symbol_maybe(left_next),
            ) {
                (Some(sym1), Some(sym2), Some(sym3)) if sym1 == sym2 && sym2 == sym3 => {
                    self.set_digram(Self::digram_key_from_symbols(sym1, sym2), Some(left_prev));
                }
                _ => {}
            }
        }

        self.set_next(left, right);
        self.set_prev(right, left);
    }
774
775 fn insert_after(&mut self, node: NodeIx, symbol: Symbol) -> NodeIx {
776 let new_node = self.alloc_node(NodeData::Sym(symbol));
777 let next = self.nodes[node as usize].next;
778 self.link(new_node, next);
779 self.link(node, new_node);
780 if let Symbol::NonTerminal(rule) = symbol {
781 let next_count = self.rules[rule as usize].ref_count.saturating_add(1);
782 self.set_rule_ref_count(rule, next_count);
783 }
784 new_node
785 }
786
787 fn delete_digram(&mut self, node: NodeIx) {
788 let Some(key) = self.digram_key_at(node) else {
789 return;
790 };
791 match self.digrams.get(&key).copied() {
792 Some(existing) if existing != node => {}
793 _ => self.set_digram(key, None),
794 }
795 }
796
797 fn check(&mut self, node: NodeIx) -> bool {
798 let Some(key) = self.digram_key_at(node) else {
799 return false;
800 };
801 let existing = self.digrams.get(&key).copied();
802 match existing {
803 None => {
804 self.set_digram(key, Some(node));
805 false
806 }
807 Some(other) => {
808 let other_next = self.nodes[other as usize].next;
809 let node_next = self.nodes[node as usize].next;
810 if node == other_next || other == node_next {
811 false
812 } else {
813 self.match_nodes(node, other);
814 true
815 }
816 }
817 }
818 }
819
820 fn match_nodes(&mut self, ss: NodeIx, m: NodeIx) {
821 let m_prev = self.nodes[m as usize].prev;
822 let m_next = self.nodes[m as usize].next;
823 let m_next_next = self.nodes[m_next as usize].next;
824
825 let rule = if let Some(rule) = self.guard_rule(m_prev) {
826 if rule != 0 && self.is_guard(m_next_next) {
827 self.substitute(ss, rule);
828 rule
829 } else {
830 let rule = self.new_rule();
831 let ss2 = self.nodes[ss as usize].next;
832 let last = self.last_node_of_rule(rule);
833 let node1 = self.insert_after(last, self.symbol_of(ss));
834 let node2 = self.insert_after(node1, self.symbol_of(ss2));
835 self.substitute(m, rule);
836 self.substitute(ss, rule);
837 self.set_digram(
838 Self::digram_key_from_symbols(self.symbol_of(node1), self.symbol_of(node2)),
839 Some(node1),
840 );
841 rule
842 }
843 } else {
844 let rule = self.new_rule();
845 let ss2 = self.nodes[ss as usize].next;
846 let last = self.last_node_of_rule(rule);
847 let node1 = self.insert_after(last, self.symbol_of(ss));
848 let node2 = self.insert_after(node1, self.symbol_of(ss2));
849 self.substitute(m, rule);
850 self.substitute(ss, rule);
851 self.set_digram(
852 Self::digram_key_from_symbols(self.symbol_of(node1), self.symbol_of(node2)),
853 Some(node1),
854 );
855 rule
856 };
857
858 let first = self.first_node_of_rule(rule);
859 if let Symbol::NonTerminal(child) = self.symbol_of(first) {
860 if self.rules[child as usize].ref_count == 1 {
861 self.expand(first, child);
862 }
863 }
864 }
865
    /// Unlinks symbol node `node` from its rule and drops its digram from the index. For a
    /// non-terminal it also decrements the referenced rule's `ref_count`. The arena slot is not
    /// freed, and the node's own `prev`/`next` still point at its former neighbours.
    fn delete_node(&mut self, node: NodeIx) {
        debug_assert!(!self.is_guard(node), "delete_node called on guard");
        let prev = self.nodes[node as usize].prev;
        let next = self.nodes[node as usize].next;
        // `link` drops the digram that started at `prev`. `node.next` is untouched, so
        // `delete_digram(node)` below still sees the digram that started at `node`.
        self.link(prev, next);
        self.delete_digram(node);
        if let Symbol::NonTerminal(rule) = self.symbol_of(node) {
            let next_count = self.rules[rule as usize].ref_count.saturating_sub(1);
            self.set_rule_ref_count(rule, next_count);
        }
    }
877
    /// Replaces the two-symbol digram starting at `node` with a single `NonTerminal(rule)`. Then
    /// re-checks the digrams that now touch the new symbol.
    fn substitute(&mut self, node: NodeIx, rule: RuleId) {
        let prev = self.nodes[node as usize].prev;
        // Read the symbols through `prev`: deleting the first one relinks `prev.next`.
        let first = self.nodes[prev as usize].next;
        debug_assert!(!self.is_guard(first), "substitute first guard");
        self.delete_node(first);
        let second = self.nodes[prev as usize].next;
        debug_assert!(!self.is_guard(second), "substitute second guard");
        self.delete_node(second);
        let _ = self.insert_after(prev, Symbol::NonTerminal(rule));
        // If the digram ending at the new symbol caused a rewrite, the digram starting at it is
        // not checked here.
        if !self.check(prev) {
            let next = self.nodes[prev as usize].next;
            let _ = self.check(next);
        }
    }
892
    /// Inlines `rule` at `node`, a `NonTerminal(rule)` symbol.
    ///
    /// The rule's body nodes are moved between `node`'s former neighbours, and the digram at the
    /// new right seam is indexed. The rule's guard is then left as an empty list and the rule is
    /// marked inactive.
    fn expand(&mut self, node: NodeIx, rule: RuleId) {
        let left = self.nodes[node as usize].prev;
        let right = self.nodes[node as usize].next;
        self.delete_node(node);

        let first = self.first_node_of_rule(rule);
        let last = self.last_node_of_rule(rule);
        self.link(left, first);
        self.link(last, right);

        // NOTE(review): only the right-seam digram (last, right) is indexed; the left seam
        // (left, first) is not run through `check`. `symbol_of(next)` panics if `right` is a
        // guard, which the caller presumably rules out by inlining only the first symbol of a
        // rule of length >= 2. Verify.
        let next = self.nodes[last as usize].next;
        self.set_digram(
            Self::digram_key_from_symbols(self.symbol_of(last), self.symbol_of(next)),
            Some(last),
        );

        // Detach the (moved) body from the rule's guard and retire the rule.
        let guard = self.rules[rule as usize].guard;
        self.link(guard, guard);
        self.set_rule_active(rule, false);
    }
913
914 fn decode_rule(&self, rule: RuleId, out: &mut Vec<u8>) {
915 let guard = self.rules[rule as usize].guard;
916 let mut node = self.nodes[guard as usize].next;
917 while node != guard {
918 match self.nodes[node as usize].data {
919 NodeData::Guard(_) => unreachable!("guard encountered in active rule body"),
920 NodeData::Sym(Symbol::Terminal(byte)) => out.push(byte),
921 NodeData::Sym(Symbol::NonTerminal(child)) => self.decode_rule(child, out),
922 }
923 node = self.nodes[node as usize].next;
924 }
925 }
926
927 fn rule_tail_bytes(&self, rule: RuleId, limit: usize) -> Vec<u8> {
928 let mut rev = Vec::with_capacity(limit);
929 self.collect_rule_tail_rev(rule, limit, &mut rev);
930 rev.reverse();
931 rev
932 }
933
934 fn collect_rule_tail_rev(&self, rule: RuleId, limit: usize, out_rev: &mut Vec<u8>) {
935 if out_rev.len() >= limit {
936 return;
937 }
938 let guard = self.rules[rule as usize].guard;
939 let mut node = self.nodes[guard as usize].prev;
940 while node != guard && out_rev.len() < limit {
941 match self.nodes[node as usize].data {
942 NodeData::Guard(_) => break,
943 NodeData::Sym(Symbol::Terminal(byte)) => out_rev.push(byte),
944 NodeData::Sym(Symbol::NonTerminal(child)) => {
945 self.collect_rule_tail_rev(child, limit, out_rev);
946 }
947 }
948 node = self.nodes[node as usize].prev;
949 }
950 }
951
952 #[cfg(test)]
953 fn active_rule_ids(&self) -> Vec<RuleId> {
954 self.rules
955 .iter()
956 .enumerate()
957 .filter_map(|(idx, rule)| rule.active.then_some(idx as RuleId))
958 .collect()
959 }
960
961 #[cfg(test)]
962 fn rule_body_symbols(&self, rule: RuleId) -> Vec<Symbol> {
963 let guard = self.rules[rule as usize].guard;
964 let mut out = Vec::new();
965 let mut node = self.nodes[guard as usize].next;
966 while node != guard {
967 out.push(self.symbol_of(node));
968 node = self.nodes[node as usize].next;
969 }
970 out
971 }
972
973 #[cfg(test)]
974 fn validate_rule_shapes(&self) -> Result<(), String> {
975 for rule in self.active_rule_ids() {
976 if rule == 0 {
977 continue;
978 }
979 let len = self.rule_body_symbols(rule).len();
980 if len < 2 {
981 return Err(format!("rule {rule} has rhs length {len}, expected >= 2"));
982 }
983 if self.rules[rule as usize].ref_count < 2 {
984 return Err(format!(
985 "rule {rule} has utility {}, expected >= 2",
986 self.rules[rule as usize].ref_count
987 ));
988 }
989 }
990 Ok(())
991 }
992
993 #[cfg(test)]
994 fn validate_rule_refcounts(&self) -> Result<(), String> {
995 let mut counts = vec![0u32; self.rules.len()];
996 for rule in self.active_rule_ids() {
997 for sym in self.rule_body_symbols(rule) {
998 if let Symbol::NonTerminal(child) = sym {
999 counts[child as usize] += 1;
1000 }
1001 }
1002 }
1003 for rule in self.active_rule_ids() {
1004 if rule == 0 {
1005 continue;
1006 }
1007 let actual = self.rules[rule as usize].ref_count;
1008 let expected = counts[rule as usize];
1009 if actual != expected {
1010 return Err(format!(
1011 "rule {rule} ref_count mismatch: actual={actual}, expected={expected}"
1012 ));
1013 }
1014 }
1015 Ok(())
1016 }
1017
1018 #[cfg(test)]
1019 fn validate_digram_uniqueness(&self) -> Result<(), String> {
1020 let mut seen = AHashMap::<u64, NodeIx>::new();
1021 for rule in self.active_rule_ids() {
1022 let guard = self.rules[rule as usize].guard;
1023 let mut node = self.nodes[guard as usize].next;
1024 while node != guard {
1025 let next = self.nodes[node as usize].next;
1026 if next == guard {
1027 break;
1028 }
1029 let key = Self::digram_key_from_symbols(self.symbol_of(node), self.symbol_of(next));
1030 if let Some(&other) = seen.get(&key) {
1031 let other_next = self.nodes[other as usize].next;
1032 if other_next != node && self.nodes[node as usize].next != other {
1033 return Err(format!(
1034 "duplicate non-overlapping digram for key {key}: {other} and {node}"
1035 ));
1036 }
1037 } else {
1038 seen.insert(key, node);
1039 }
1040 node = next;
1041 }
1042 }
1043 Ok(())
1044 }
1045}
1046
1047fn normalize_pdf(pdf: &mut [f64; 256]) {
1048 let mut sum = 0.0f64;
1049 for value in pdf.iter_mut() {
1050 *value = if value.is_finite() {
1051 (*value).max(PDF_MIN)
1052 } else {
1053 PDF_MIN
1054 };
1055 sum += *value;
1056 }
1057 if !sum.is_finite() || sum <= 0.0 {
1058 pdf.fill(1.0 / 256.0);
1059 return;
1060 }
1061 let inv = 1.0 / sum;
1062 for value in pdf.iter_mut() {
1063 *value *= inv;
1064 }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model with the given context length and commits every byte of `data`.
    fn train_model(data: &[u8], context_bytes: usize) -> SequiturModel {
        let mut model = SequiturModel::new(context_bytes);
        data.iter().for_each(|&byte| model.update(byte));
        model
    }

    /// Returns a copy of the model's current next-symbol distribution.
    fn current_pdf(model: &mut SequiturModel) -> [f64; 256] {
        let mut pdf = [0.0; 256];
        model.fill_pdf(&mut pdf);
        pdf
    }

    #[test]
    fn sequitur_roundtrips_and_preserves_invariants() {
        let input = b"abcabcabcabcabc";
        let model = train_model(input, 64);
        assert_eq!(model.decode(), input);
        model.validate_invariants().unwrap();
    }

    #[test]
    fn sequitur_invariants_hold_after_each_step() {
        let mut model = SequiturModel::new(32);
        for byte in b"abracadabra abracadabra".iter().copied() {
            model.update(byte);
            model.validate_invariants().unwrap();
        }
    }

    #[test]
    fn sequitur_canonical_grammar_is_deterministic() {
        let input = b"abcabcabcabc";
        let first = train_model(input, 64).canonical_grammar();
        let second = train_model(input, 64).canonical_grammar();
        assert_eq!(first, second);
    }

    #[test]
    fn sequitur_pdf_is_normalized() {
        let mut model = train_model(b"abababababa", 32);
        let pdf = current_pdf(&mut model);
        let sum: f64 = pdf.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "sum={sum}");
        assert!(pdf.iter().all(|p| p.is_finite() && *p > 0.0));
    }

    #[test]
    fn frozen_updates_do_not_mutate_learned_distribution_after_reset() {
        let mut model = train_model(b"banana banana banana", 32);
        let pdf_before = current_pdf(&mut model);
        let grammar_before = model.canonical_grammar();
        let followers_before = model.followers.clone();

        model.reset_frozen();
        b"ZZZZ".iter().for_each(|&byte| model.update_frozen(byte));
        assert_eq!(grammar_before, model.canonical_grammar());
        assert_eq!(followers_before, model.followers);

        model.reset_frozen();
        let pdf_after = current_pdf(&mut model);
        assert_eq!(pdf_before, pdf_after);
    }

    #[test]
    fn checkpoint_restore_recovers_exact_state() {
        let mut model = train_model(b"mississippi", 32);
        let checkpoint = model.checkpoint();
        let grammar_before = model.canonical_grammar();
        let pdf_before = current_pdf(&mut model);

        b" river".iter().for_each(|&byte| model.update(byte));
        model.restore(&checkpoint);
        model.clear_checkpoints();

        assert_eq!(grammar_before, model.canonical_grammar());
        let pdf_after = current_pdf(&mut model);
        assert_eq!(pdf_before, pdf_after);
        model.validate_invariants().unwrap();
    }

    #[test]
    fn sequitur_repetitive_binary_inputs_preserve_invariants() {
        let inputs: [&[u8]; 2] = [
            b"\x00\x00\x00\x00\x00\x00\x00\x00",
            b"\x00\x01\x00\x01\x00\x01\x00\x01",
        ];
        for data in inputs {
            let mut model = SequiturModel::new(32);
            for (idx, &byte) in data.iter().enumerate() {
                model.update(byte);
                if let Err(err) = model.validate_invariants() {
                    // Dump every active rule so the failing grammar can be inspected.
                    let active: Vec<_> = model
                        .active_rule_ids()
                        .into_iter()
                        .map(|rule| {
                            (
                                rule,
                                model.rules[rule as usize].ref_count,
                                model.rule_body_symbols(rule),
                            )
                        })
                        .collect();
                    panic!(
                        "invariants failed at step {idx} for {:?}: {err}; active={active:?}",
                        data
                    );
                }
            }
        }
    }
}
1183}