Skip to main content

infotheory/backends/
sequitur.rs

1use ahash::AHashMap;
2
/// Probability floor, aliased from the mixture module's default minimum.
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed by the PDF normalization helper — confirm.
const PDF_MIN: f64 = crate::mixture::DEFAULT_MIN_PROB;
/// Maximum number of most-recent raw bytes kept as fallback prediction context.
const RAW_FALLBACK_MAX: usize = 4;

/// Index of a node inside the `SequiturModel::nodes` arena.
type NodeIx = u32;
/// Index of a rule inside `SequiturModel::rules`; rule `0` is the start rule.
type RuleId = u32;
8
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
/// A grammar symbol stored in a rule body.
enum Symbol {
    /// A literal input byte.
    Terminal(u8),
    /// A reference to another rule, identified by its index in the rule table.
    NonTerminal(RuleId),
}
14
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
/// Payload of a linked-list node: either a rule's guard or an actual symbol.
enum NodeData {
    /// Sentinel node that closes a rule's circular list; carries the rule id.
    Guard(RuleId),
    /// Ordinary symbol node inside a rule body.
    Sym(Symbol),
}
20
#[derive(Clone, Copy, Debug)]
/// Doubly linked list node stored in the model's node arena.
///
/// Links are arena indices rather than pointers. For a rule's guard node,
/// `next` is the first body symbol and `prev` is the last one.
struct Node {
    /// Arena index of the preceding node.
    prev: NodeIx,
    /// Arena index of the following node.
    next: NodeIx,
    /// Guard marker or symbol carried by this node.
    data: NodeData,
}
27
#[derive(Clone, Debug)]
/// Bookkeeping entry for one grammar rule.
struct Rule {
    /// Arena index of the guard node that anchors this rule's body list.
    guard: NodeIx,
    /// Number of non-terminal nodes currently referring to this rule
    /// (incremented on insertion, decremented on node deletion).
    ref_count: u32,
    /// `false` once the rule has been inlined by `expand` and retired.
    active: bool,
}
34
#[derive(Clone, Debug, Default, Eq, PartialEq)]
/// Follower statistics for a single context: which bytes came next, and how often.
struct ContextFollowers {
    /// `(byte, occurrences)` pairs in first-seen order.
    counts: Vec<(u8, u64)>,
    /// Sum of all occurrence counts in `counts`.
    total: u64,
}

impl ContextFollowers {
    /// Record one more occurrence of `symbol` following this context.
    fn observe(&mut self, symbol: u8) {
        match self.counts.iter().position(|&(seen, _)| seen == symbol) {
            Some(slot) => self.counts[slot].1 += 1,
            None => self.counts.push((symbol, 1)),
        }
        self.total += 1;
    }

    /// Number of distinct follower bytes observed so far.
    fn distinct(&self) -> usize {
        self.counts.len()
    }
}
55
#[derive(Clone, Debug)]
/// One reversible mutation recorded in the undo journal.
///
/// Every variant stores the value that was in place *before* the change, so
/// replaying journal entries newest-first restores the earlier state.
enum UndoOp {
    /// A node's `prev` link was overwritten; `old` is the previous link.
    SetPrev {
        node: NodeIx,
        old: NodeIx,
    },
    /// A node's `next` link was overwritten; `old` is the previous link.
    SetNext {
        node: NodeIx,
        old: NodeIx,
    },
    /// A rule's reference count changed; `old` is the previous count.
    SetRuleRefCount {
        rule: RuleId,
        old: u32,
    },
    /// A rule's `active` flag changed; `old` is the previous flag.
    SetRuleActive {
        rule: RuleId,
        old: bool,
    },
    /// A digram index entry changed; `old` is `None` if the key was absent.
    SetDigram {
        key: u64,
        old: Option<NodeIx>,
    },
    /// A context's follower statistics changed; `old` is `None` if the
    /// context had no entry before.
    SetContextFollowers {
        key: Box<[u8]>,
        old: Option<ContextFollowers>,
    },
    /// A unigram count and the unigram total changed together.
    SetUnigramSymbol {
        symbol: u8,
        old_count: u64,
        old_total: u64,
    },
    /// The committed raw tail was replaced; `old` is the previous tail.
    SetCommittedRawTail {
        old: Vec<u8>,
    },
    /// The frozen (speculative) raw tail was replaced; `old` is the previous tail.
    SetFrozenRawTail {
        old: Vec<u8>,
    },
}
94
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
/// Opaque rollback marker for [`SequiturModel`] state.
///
/// Checkpoints are created with [`SequiturModel::checkpoint`] and later applied
/// with [`SequiturModel::restore`].
pub struct SequiturCheckpoint {
    /// Length of the undo journal when the checkpoint was taken.
    undo_len: usize,
    /// Length of the node arena when the checkpoint was taken.
    node_len: usize,
    /// Length of the rule table when the checkpoint was taken.
    rule_len: usize,
}
105
#[derive(Clone, Debug, Eq, PartialEq)]
/// A canonical Sequitur rule in a deterministic exported grammar.
///
/// Produced by [`SequiturModel::canonical_grammar`].
pub struct CanonicalRule {
    /// Canonical rule id (index within [`CanonicalGrammar::rules`]).
    pub id: usize,
    /// Right-hand side symbols for this rule, in body order.
    pub rhs: Vec<CanonicalSymbol>,
}
114
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
/// Symbol in a [`CanonicalRule`] right-hand side.
pub enum CanonicalSymbol {
    /// Terminal byte symbol.
    Terminal(u8),
    /// Reference to another canonical rule by canonical id
    /// (its index within [`CanonicalGrammar::rules`]).
    NonTerminal(usize),
}
123
#[derive(Clone, Debug, Eq, PartialEq)]
/// Deterministic snapshot of the currently learned Sequitur grammar.
///
/// Rule ids are assigned in the order rules are first reached from the start
/// rule, so the snapshot does not depend on internal rule numbering.
pub struct CanonicalGrammar {
    /// Reachable rules in preorder, with rule `0` as the start rule.
    pub rules: Vec<CanonicalRule>,
}
130
#[derive(Clone, Debug)]
/// Online Sequitur-based byte predictor with rollback support.
///
/// The model learns grammar structure from committed updates and exposes a
/// normalized next-byte distribution for coding and planning.
pub struct SequiturModel {
    /// Maximum number of trailing bytes taken from a rule expansion when it is
    /// used as a prediction context (always at least `2`).
    context_bytes: usize,
    /// Arena of linked-list nodes; `NodeIx` values index into it.
    nodes: Vec<Node>,
    /// Rule table; `RuleId` values index into it and rule `0` is the start rule.
    rules: Vec<Rule>,
    /// Index of the placeholder guard node that freshly allocated symbol nodes
    /// point at until they are linked into a rule.
    dummy: NodeIx,
    /// Digram index: packed symbol pair -> node holding the pair's left symbol.
    digrams: AHashMap<u64, NodeIx>,
    /// Follower statistics keyed by context bytes.
    followers: AHashMap<Box<[u8]>, ContextFollowers>,
    /// Per-byte occurrence counts over all committed symbols.
    unigram: [u64; 256],
    /// Sum of `unigram`.
    unigram_total: u64,
    /// Most recent committed bytes, at most `RAW_FALLBACK_MAX` of them.
    committed_raw_tail: Vec<u8>,
    /// Most recent speculative bytes, at most `RAW_FALLBACK_MAX` of them.
    frozen_raw_tail: Vec<u8>,
    /// Cached next-byte distribution; only meaningful while `pdf_valid` is set.
    pdf: [f64; 256],
    /// Whether `pdf` reflects the current state.
    pdf_valid: bool,
    /// Undo journal, replayed newest-first by `restore`.
    undo: Vec<UndoOp>,
    /// Whether mutations are currently being journaled.
    undo_enabled: bool,
}
152
153impl SequiturModel {
154    /// Create a new Sequitur model.
155    ///
156    /// `context_bytes` is clamped to at least `2` and controls the maximum tail
157    /// context length considered by the fallback context statistics.
158    pub fn new(context_bytes: usize) -> Self {
159        let context_bytes = context_bytes.max(2);
160        let mut nodes = Vec::with_capacity(16);
161        nodes.push(Node {
162            prev: 0,
163            next: 0,
164            data: NodeData::Guard(0),
165        });
166        nodes.push(Node {
167            prev: 1,
168            next: 1,
169            data: NodeData::Guard(0),
170        });
171        let rules = vec![Rule {
172            guard: 0,
173            ref_count: 0,
174            active: true,
175        }];
176        Self {
177            context_bytes,
178            nodes,
179            rules,
180            dummy: 1,
181            digrams: AHashMap::new(),
182            followers: AHashMap::new(),
183            unigram: [0; 256],
184            unigram_total: 0,
185            committed_raw_tail: Vec::with_capacity(RAW_FALLBACK_MAX),
186            frozen_raw_tail: Vec::with_capacity(RAW_FALLBACK_MAX),
187            pdf: [1.0 / 256.0; 256],
188            pdf_valid: false,
189            undo: Vec::new(),
190            undo_enabled: false,
191        }
192    }
193
194    /// Reserve internal capacity for an upcoming stream.
195    ///
196    /// This only affects allocation behavior and does not change model state.
197    pub fn reserve_for_stream(&mut self, additional_symbols: usize) {
198        self.nodes.reserve(additional_symbols.saturating_mul(2));
199        self.digrams.reserve(additional_symbols);
200    }
201
202    /// Prepare the model for a new stream.
203    ///
204    /// If `total_symbols` is provided, internal storage is pre-reserved. Any
205    /// speculative frozen tail is cleared.
206    pub fn begin_stream(&mut self, total_symbols: Option<u64>) {
207        if let Some(total) = total_symbols {
208            let reserve = usize::try_from(total).unwrap_or(usize::MAX / 4);
209            self.reserve_for_stream(reserve);
210        }
211        self.frozen_raw_tail.clear();
212        self.pdf_valid = false;
213    }
214
    /// Finish a stream.
    ///
    /// This is currently a no-op and exists for API symmetry with other
    /// predictors. No model state is touched.
    pub fn finish_stream(&mut self) {}
220
221    /// Capture a rollback point for subsequent speculative updates.
222    ///
223    /// Calling this enables undo journaling until checkpoints are cleared.
224    pub fn checkpoint(&mut self) -> SequiturCheckpoint {
225        self.undo_enabled = true;
226        SequiturCheckpoint {
227            undo_len: self.undo.len(),
228            node_len: self.nodes.len(),
229            rule_len: self.rules.len(),
230        }
231    }
232
233    /// Restore model state to a previously created checkpoint.
234    pub fn restore(&mut self, checkpoint: &SequiturCheckpoint) {
235        let saved = self.undo_enabled;
236        self.undo_enabled = false;
237        while self.undo.len() > checkpoint.undo_len {
238            let op = self.undo.pop().expect("undo underflow");
239            self.apply_undo(op);
240        }
241        self.nodes.truncate(checkpoint.node_len);
242        self.rules.truncate(checkpoint.rule_len);
243        self.undo_enabled = saved;
244        self.pdf_valid = false;
245    }
246
247    /// Drop all stored undo history and disable journaling.
248    pub fn clear_checkpoints(&mut self) {
249        self.undo.clear();
250        self.undo_enabled = false;
251    }
252
253    /// Clear speculative frozen updates without touching committed state.
254    pub fn reset_frozen(&mut self) {
255        self.frozen_raw_tail.clear();
256        self.pdf_valid = false;
257    }
258
259    /// Fill `out` with the current normalized next-byte probability mass.
260    pub fn fill_pdf(&mut self, out: &mut [f64; 256]) {
261        self.ensure_pdf();
262        out.copy_from_slice(&self.pdf);
263    }
264
    /// Return the current normalized next-byte probability mass.
    ///
    /// Takes `&mut self` because the cached distribution is recomputed on
    /// demand when it is stale.
    pub fn pdf(&mut self) -> &[f64; 256] {
        self.ensure_pdf();
        &self.pdf
    }
270
271    /// Return the natural-log probability of `symbol` under the current model.
272    ///
273    /// The symbol probability is clamped from below by `min_prob`.
274    pub fn log_prob(&mut self, symbol: u8, min_prob: f64) -> f64 {
275        self.ensure_pdf();
276        self.pdf[symbol as usize].max(min_prob).ln()
277    }
278
    /// Commit one observed symbol into grammar and context statistics.
    ///
    /// Any speculative frozen tail is cleared, and the cached PDF is
    /// invalidated.
    pub fn update(&mut self, symbol: u8) {
        // Statistics first: the contexts are derived from the grammar and raw
        // tails as they stand *before* this symbol is appended.
        self.observe_symbol_in_stats(symbol);
        self.append_terminal(symbol);
        self.record_committed_raw_tail(symbol);
        // A committed symbol supersedes speculative lookahead; the clear goes
        // through the journaled setter so it can be rolled back.
        if !self.frozen_raw_tail.is_empty() {
            self.record_frozen_raw_tail_inner(Vec::new());
        }
        self.pdf_valid = false;
    }
289
290    /// Apply a speculative symbol update used for lookahead.
291    ///
292    /// Frozen updates affect only the temporary raw-tail context and can be
293    /// cleared with [`Self::reset_frozen`]. They do not mutate committed grammar
294    /// structure or follower counts.
295    pub fn update_frozen(&mut self, symbol: u8) {
296        let mut next = self.frozen_raw_tail.clone();
297        next.push(symbol);
298        if next.len() > RAW_FALLBACK_MAX {
299            let drain = next.len() - RAW_FALLBACK_MAX;
300            next.drain(0..drain);
301        }
302        self.record_frozen_raw_tail_inner(next);
303        self.pdf_valid = false;
304    }
305
306    /// Decode the current start rule into the equivalent terminal byte stream.
307    pub fn decode(&self) -> Vec<u8> {
308        let mut out = Vec::new();
309        self.decode_rule(0, &mut out);
310        out
311    }
312
313    /// Export a deterministic canonical grammar snapshot.
314    ///
315    /// Rules are emitted in preorder from the start rule, and non-terminals are
316    /// rewritten to canonical rule ids.
317    pub fn canonical_grammar(&self) -> CanonicalGrammar {
318        let mut order = Vec::<RuleId>::new();
319        let mut seen = AHashMap::<RuleId, usize>::new();
320        self.collect_rule_preorder(0, &mut order, &mut seen);
321        let mut canonical = Vec::with_capacity(order.len());
322        for (idx, &rule_id) in order.iter().enumerate() {
323            let mut rhs = Vec::new();
324            let guard = self.rules[rule_id as usize].guard;
325            let mut node = self.nodes[guard as usize].next;
326            while node != guard {
327                match self.nodes[node as usize].data {
328                    NodeData::Guard(_) => unreachable!("guard in rule body"),
329                    NodeData::Sym(Symbol::Terminal(byte)) => {
330                        rhs.push(CanonicalSymbol::Terminal(byte))
331                    }
332                    NodeData::Sym(Symbol::NonTerminal(child)) => {
333                        let mapped = *seen
334                            .get(&child)
335                            .expect("canonical grammar missing child mapping");
336                        rhs.push(CanonicalSymbol::NonTerminal(mapped));
337                    }
338                }
339                node = self.nodes[node as usize].next;
340            }
341            canonical.push(CanonicalRule { id: idx, rhs });
342        }
343        CanonicalGrammar { rules: canonical }
344    }
345
346    /// Record a per-step predictive trace for `data`.
347    ///
348    /// At each position this captures the current PDF prefix of length
349    /// `alphabet_prefix` (capped at `256`) and then commits the observed byte.
350    pub fn predictive_trace(&mut self, data: &[u8], alphabet_prefix: usize) -> Vec<Vec<f64>> {
351        let mut trace = Vec::with_capacity(data.len());
352        let mut pdf = [0.0; 256];
353        for &byte in data {
354            self.fill_pdf(&mut pdf);
355            trace.push(pdf[..alphabet_prefix.min(256)].to_vec());
356            self.update(byte);
357        }
358        trace
359    }
360
    /// Test-only consistency check.
    ///
    /// Runs the rule-shape, rule-refcount and digram-uniqueness validators in
    /// that order and returns the first failure message, if any.
    #[cfg(test)]
    pub fn validate_invariants(&self) -> Result<(), String> {
        self.validate_rule_shapes()?;
        self.validate_rule_refcounts()?;
        self.validate_digram_uniqueness()?;
        Ok(())
    }
368
369    fn collect_rule_preorder(
370        &self,
371        rule_id: RuleId,
372        order: &mut Vec<RuleId>,
373        seen: &mut AHashMap<RuleId, usize>,
374    ) {
375        if seen.contains_key(&rule_id) {
376            return;
377        }
378        let idx = order.len();
379        seen.insert(rule_id, idx);
380        order.push(rule_id);
381        let guard = self.rules[rule_id as usize].guard;
382        let mut node = self.nodes[guard as usize].next;
383        while node != guard {
384            if let NodeData::Sym(Symbol::NonTerminal(child)) = self.nodes[node as usize].data {
385                if self.rules[child as usize].active {
386                    self.collect_rule_preorder(child, order, seen);
387                }
388            }
389            node = self.nodes[node as usize].next;
390        }
391    }
392
393    fn apply_undo(&mut self, op: UndoOp) {
394        match op {
395            UndoOp::SetPrev { node, old } => {
396                self.nodes[node as usize].prev = old;
397            }
398            UndoOp::SetNext { node, old } => {
399                self.nodes[node as usize].next = old;
400            }
401            UndoOp::SetRuleRefCount { rule, old } => {
402                self.rules[rule as usize].ref_count = old;
403            }
404            UndoOp::SetRuleActive { rule, old } => {
405                self.rules[rule as usize].active = old;
406            }
407            UndoOp::SetDigram { key, old } => {
408                if let Some(node) = old {
409                    self.digrams.insert(key, node);
410                } else {
411                    self.digrams.remove(&key);
412                }
413            }
414            UndoOp::SetContextFollowers { key, old } => {
415                if let Some(state) = old {
416                    self.followers.insert(key, state);
417                } else {
418                    self.followers.remove(key.as_ref());
419                }
420            }
421            UndoOp::SetUnigramSymbol {
422                symbol,
423                old_count,
424                old_total,
425            } => {
426                self.unigram[symbol as usize] = old_count;
427                self.unigram_total = old_total;
428            }
429            UndoOp::SetCommittedRawTail { old } => {
430                self.committed_raw_tail = old;
431            }
432            UndoOp::SetFrozenRawTail { old } => {
433                self.frozen_raw_tail = old;
434            }
435        }
436    }
437
438    fn push_undo(&mut self, op: UndoOp) {
439        if self.undo_enabled {
440            self.undo.push(op);
441        }
442    }
443
444    fn set_prev(&mut self, node: NodeIx, prev: NodeIx) {
445        let old = self.nodes[node as usize].prev;
446        if old != prev {
447            self.push_undo(UndoOp::SetPrev { node, old });
448            self.nodes[node as usize].prev = prev;
449        }
450    }
451
452    fn set_next(&mut self, node: NodeIx, next: NodeIx) {
453        let old = self.nodes[node as usize].next;
454        if old != next {
455            self.push_undo(UndoOp::SetNext { node, old });
456            self.nodes[node as usize].next = next;
457        }
458    }
459
460    fn set_rule_ref_count(&mut self, rule: RuleId, ref_count: u32) {
461        let old = self.rules[rule as usize].ref_count;
462        if old != ref_count {
463            self.push_undo(UndoOp::SetRuleRefCount { rule, old });
464            self.rules[rule as usize].ref_count = ref_count;
465        }
466    }
467
468    fn set_rule_active(&mut self, rule: RuleId, active: bool) {
469        let old = self.rules[rule as usize].active;
470        if old != active {
471            self.push_undo(UndoOp::SetRuleActive { rule, old });
472            self.rules[rule as usize].active = active;
473        }
474    }
475
476    fn set_digram(&mut self, key: u64, value: Option<NodeIx>) {
477        let old = self.digrams.get(&key).copied();
478        if old == value {
479            return;
480        }
481        self.push_undo(UndoOp::SetDigram { key, old });
482        if let Some(node) = value {
483            self.digrams.insert(key, node);
484        } else {
485            self.digrams.remove(&key);
486        }
487    }
488
489    fn record_context_followers(&mut self, key: &[u8], new_state: Option<ContextFollowers>) {
490        let boxed: Box<[u8]> = key.to_vec().into_boxed_slice();
491        let old = self.followers.get(boxed.as_ref()).cloned();
492        if old == new_state {
493            return;
494        }
495        self.push_undo(UndoOp::SetContextFollowers {
496            key: boxed.clone(),
497            old,
498        });
499        if let Some(state) = new_state {
500            self.followers.insert(boxed, state);
501        } else {
502            self.followers.remove(boxed.as_ref());
503        }
504    }
505
506    fn record_unigram(&mut self, symbol: u8, next_count: u64, next_total: u64) {
507        let old_count = self.unigram[symbol as usize];
508        let old_total = self.unigram_total;
509        if old_count == next_count && old_total == next_total {
510            return;
511        }
512        self.push_undo(UndoOp::SetUnigramSymbol {
513            symbol,
514            old_count,
515            old_total,
516        });
517        self.unigram[symbol as usize] = next_count;
518        self.unigram_total = next_total;
519    }
520
521    fn record_committed_raw_tail(&mut self, symbol: u8) {
522        let mut next = self.committed_raw_tail.clone();
523        next.push(symbol);
524        if next.len() > RAW_FALLBACK_MAX {
525            let drain = next.len() - RAW_FALLBACK_MAX;
526            next.drain(0..drain);
527        }
528        if next != self.committed_raw_tail {
529            self.push_undo(UndoOp::SetCommittedRawTail {
530                old: self.committed_raw_tail.clone(),
531            });
532            self.committed_raw_tail = next;
533        }
534    }
535
536    fn record_frozen_raw_tail_inner(&mut self, next: Vec<u8>) {
537        if next != self.frozen_raw_tail {
538            self.push_undo(UndoOp::SetFrozenRawTail {
539                old: self.frozen_raw_tail.clone(),
540            });
541            self.frozen_raw_tail = next;
542        }
543    }
544
545    fn ensure_pdf(&mut self) {
546        if self.pdf_valid {
547            return;
548        }
549
550        let denom = (self.unigram_total as f64) + 128.0;
551        for (idx, slot) in self.pdf.iter_mut().enumerate() {
552            *slot = ((self.unigram[idx] as f64) + 0.5) / denom;
553        }
554
555        let contexts = self.current_contexts();
556        let mut next = [0.0; 256];
557        for context in contexts {
558            let Some(stats) = self.followers.get(context.as_slice()) else {
559                continue;
560            };
561            let distinct = stats.distinct();
562            if stats.total == 0 || distinct == 0 {
563                continue;
564            }
565            let total = stats.total as f64;
566            let types = distinct as f64;
567            let escape = types / (total + types);
568            for i in 0..256 {
569                next[i] = self.pdf[i] * escape;
570            }
571            for &(symbol, count) in &stats.counts {
572                next[symbol as usize] += (count as f64) / (total + types);
573            }
574            self.pdf.copy_from_slice(&next);
575        }
576
577        normalize_pdf(&mut self.pdf);
578        self.pdf_valid = true;
579    }
580
581    fn current_contexts(&self) -> Vec<Vec<u8>> {
582        let mut out = Vec::<Vec<u8>>::new();
583        let raw_tail = self.effective_raw_tail();
584        for len in 1..=raw_tail.len().min(RAW_FALLBACK_MAX) {
585            let ctx = raw_tail[raw_tail.len() - len..].to_vec();
586            if !out.iter().any(|existing| existing == &ctx) {
587                out.push(ctx);
588            }
589        }
590
591        let mut rule_chain = Vec::<RuleId>::new();
592        rule_chain.push(0);
593        let mut current = 0u32;
594        loop {
595            let guard = self.rules[current as usize].guard;
596            let last = self.nodes[guard as usize].prev;
597            if last == guard {
598                break;
599            }
600            match self.nodes[last as usize].data {
601                NodeData::Sym(Symbol::NonTerminal(child)) if self.rules[child as usize].active => {
602                    rule_chain.push(child);
603                    current = child;
604                }
605                _ => break,
606            }
607        }
608
609        for &rule_id in &rule_chain {
610            let ctx = self.rule_tail_bytes(rule_id, self.context_bytes);
611            if !ctx.is_empty() && !out.iter().any(|existing| existing == &ctx) {
612                out.push(ctx);
613            }
614        }
615
616        out
617    }
618
619    fn effective_raw_tail(&self) -> Vec<u8> {
620        let mut out = Vec::with_capacity(RAW_FALLBACK_MAX);
621        let total = self.committed_raw_tail.len() + self.frozen_raw_tail.len();
622        let keep_from = total.saturating_sub(RAW_FALLBACK_MAX);
623        for (idx, &byte) in self
624            .committed_raw_tail
625            .iter()
626            .chain(self.frozen_raw_tail.iter())
627            .enumerate()
628        {
629            if idx >= keep_from {
630                out.push(byte);
631            }
632        }
633        out
634    }
635
636    fn observe_symbol_in_stats(&mut self, symbol: u8) {
637        let contexts = self.current_contexts();
638        for context in contexts {
639            let mut state = self
640                .followers
641                .get(context.as_slice())
642                .cloned()
643                .unwrap_or_default();
644            state.observe(symbol);
645            self.record_context_followers(&context, Some(state));
646        }
647        let old_count = self.unigram[symbol as usize];
648        let old_total = self.unigram_total;
649        self.record_unigram(symbol, old_count + 1, old_total + 1);
650    }
651
652    fn append_terminal(&mut self, byte: u8) {
653        let last = self.last_node_of_rule(0);
654        let _ = self.insert_after(last, Symbol::Terminal(byte));
655        let _ = self.check(last);
656    }
657
658    fn alloc_node(&mut self, data: NodeData) -> NodeIx {
659        let idx = self.nodes.len() as NodeIx;
660        let (prev, next) = match data {
661            NodeData::Guard(_) => (idx, idx),
662            NodeData::Sym(_) => (self.dummy, self.dummy),
663        };
664        self.nodes.push(Node { prev, next, data });
665        idx
666    }
667
668    fn new_rule(&mut self) -> RuleId {
669        let id = self.rules.len() as RuleId;
670        let guard = self.alloc_node(NodeData::Guard(id));
671        self.rules.push(Rule {
672            guard,
673            ref_count: 0,
674            active: true,
675        });
676        id
677    }
678
679    fn is_guard(&self, node: NodeIx) -> bool {
680        matches!(self.nodes[node as usize].data, NodeData::Guard(_))
681    }
682
683    fn symbol_of(&self, node: NodeIx) -> Symbol {
684        match self.nodes[node as usize].data {
685            NodeData::Sym(symbol) => symbol,
686            NodeData::Guard(_) => panic!("node_symbol called on guard node"),
687        }
688    }
689
690    fn symbol_maybe(&self, node: NodeIx) -> Option<Symbol> {
691        match self.nodes[node as usize].data {
692            NodeData::Sym(symbol) => Some(symbol),
693            NodeData::Guard(_) => None,
694        }
695    }
696
697    fn guard_rule(&self, node: NodeIx) -> Option<RuleId> {
698        match self.nodes[node as usize].data {
699            NodeData::Guard(rule) => Some(rule),
700            NodeData::Sym(_) => None,
701        }
702    }
703
704    fn first_node_of_rule(&self, rule: RuleId) -> NodeIx {
705        let guard = self.rules[rule as usize].guard;
706        self.nodes[guard as usize].next
707    }
708
709    fn last_node_of_rule(&self, rule: RuleId) -> NodeIx {
710        let guard = self.rules[rule as usize].guard;
711        self.nodes[guard as usize].prev
712    }
713
714    fn encode_symbol(symbol: Symbol) -> u32 {
715        match symbol {
716            Symbol::Terminal(byte) => byte as u32,
717            Symbol::NonTerminal(rule) => 256u32.wrapping_add(rule),
718        }
719    }
720
721    fn digram_key_from_symbols(left: Symbol, right: Symbol) -> u64 {
722        ((Self::encode_symbol(left) as u64) << 32) | (Self::encode_symbol(right) as u64)
723    }
724
725    fn digram_key_at(&self, node: NodeIx) -> Option<u64> {
726        if self.is_guard(node) {
727            return None;
728        }
729        let next = self.nodes[node as usize].next;
730        if self.is_guard(next) {
731            return None;
732        }
733        Some(Self::digram_key_from_symbols(
734            self.symbol_of(node),
735            self.symbol_of(next),
736        ))
737    }
738
    /// Make `right` the successor of `left`, keeping the digram index coherent.
    ///
    /// When `left` currently has a non-guard successor, the digram starting at
    /// `left` is about to be broken, so its index entry is removed (only if
    /// the entry points at `left`). Two further checks cover runs of three
    /// identical symbols, where one key describes two overlapping digrams:
    /// the index entry is re-pointed at the occurrence that remains.
    fn link(&mut self, left: NodeIx, right: NodeIx) {
        // Snapshot all four neighbours before any pointer is modified.
        let left_prev = self.nodes[left as usize].prev;
        let left_next = self.nodes[left as usize].next;
        let right_prev = self.nodes[right as usize].prev;
        let right_next = self.nodes[right as usize].next;

        // Freshly allocated symbol nodes and rule tails have a guard successor
        // and therefore no digram to retire.
        if !self.is_guard(left_next) {
            self.delete_digram(left);

            // `right` sits in the middle of three equal symbols: index the
            // digram that starts at `right`.
            match (
                self.symbol_maybe(right_prev),
                self.symbol_maybe(right),
                self.symbol_maybe(right_next),
            ) {
                (Some(sym1), Some(sym2), Some(sym3)) if sym1 == sym2 && sym2 == sym3 => {
                    self.set_digram(Self::digram_key_from_symbols(sym2, sym3), Some(right));
                }
                _ => {}
            }

            // `left` sits in the middle of three equal symbols: index the
            // digram that starts at `left`'s predecessor.
            match (
                self.symbol_maybe(left_prev),
                self.symbol_maybe(left),
                self.symbol_maybe(left_next),
            ) {
                (Some(sym1), Some(sym2), Some(sym3)) if sym1 == sym2 && sym2 == sym3 => {
                    self.set_digram(Self::digram_key_from_symbols(sym1, sym2), Some(left_prev));
                }
                _ => {}
            }
        }

        // Both pointer writes go through the journaled setters.
        self.set_next(left, right);
        self.set_prev(right, left);
    }
774
775    fn insert_after(&mut self, node: NodeIx, symbol: Symbol) -> NodeIx {
776        let new_node = self.alloc_node(NodeData::Sym(symbol));
777        let next = self.nodes[node as usize].next;
778        self.link(new_node, next);
779        self.link(node, new_node);
780        if let Symbol::NonTerminal(rule) = symbol {
781            let next_count = self.rules[rule as usize].ref_count.saturating_add(1);
782            self.set_rule_ref_count(rule, next_count);
783        }
784        new_node
785    }
786
787    fn delete_digram(&mut self, node: NodeIx) {
788        let Some(key) = self.digram_key_at(node) else {
789            return;
790        };
791        match self.digrams.get(&key).copied() {
792            Some(existing) if existing != node => {}
793            _ => self.set_digram(key, None),
794        }
795    }
796
797    fn check(&mut self, node: NodeIx) -> bool {
798        let Some(key) = self.digram_key_at(node) else {
799            return false;
800        };
801        let existing = self.digrams.get(&key).copied();
802        match existing {
803            None => {
804                self.set_digram(key, Some(node));
805                false
806            }
807            Some(other) => {
808                let other_next = self.nodes[other as usize].next;
809                let node_next = self.nodes[node as usize].next;
810                if node == other_next || other == node_next {
811                    false
812                } else {
813                    self.match_nodes(node, other);
814                    true
815                }
816            }
817        }
818    }
819
820    fn match_nodes(&mut self, ss: NodeIx, m: NodeIx) {
821        let m_prev = self.nodes[m as usize].prev;
822        let m_next = self.nodes[m as usize].next;
823        let m_next_next = self.nodes[m_next as usize].next;
824
825        let rule = if let Some(rule) = self.guard_rule(m_prev) {
826            if rule != 0 && self.is_guard(m_next_next) {
827                self.substitute(ss, rule);
828                rule
829            } else {
830                let rule = self.new_rule();
831                let ss2 = self.nodes[ss as usize].next;
832                let last = self.last_node_of_rule(rule);
833                let node1 = self.insert_after(last, self.symbol_of(ss));
834                let node2 = self.insert_after(node1, self.symbol_of(ss2));
835                self.substitute(m, rule);
836                self.substitute(ss, rule);
837                self.set_digram(
838                    Self::digram_key_from_symbols(self.symbol_of(node1), self.symbol_of(node2)),
839                    Some(node1),
840                );
841                rule
842            }
843        } else {
844            let rule = self.new_rule();
845            let ss2 = self.nodes[ss as usize].next;
846            let last = self.last_node_of_rule(rule);
847            let node1 = self.insert_after(last, self.symbol_of(ss));
848            let node2 = self.insert_after(node1, self.symbol_of(ss2));
849            self.substitute(m, rule);
850            self.substitute(ss, rule);
851            self.set_digram(
852                Self::digram_key_from_symbols(self.symbol_of(node1), self.symbol_of(node2)),
853                Some(node1),
854            );
855            rule
856        };
857
858        let first = self.first_node_of_rule(rule);
859        if let Symbol::NonTerminal(child) = self.symbol_of(first) {
860            if self.rules[child as usize].ref_count == 1 {
861                self.expand(first, child);
862            }
863        }
864    }
865
866    fn delete_node(&mut self, node: NodeIx) {
867        debug_assert!(!self.is_guard(node), "delete_node called on guard");
868        let prev = self.nodes[node as usize].prev;
869        let next = self.nodes[node as usize].next;
870        self.link(prev, next);
871        self.delete_digram(node);
872        if let Symbol::NonTerminal(rule) = self.symbol_of(node) {
873            let next_count = self.rules[rule as usize].ref_count.saturating_sub(1);
874            self.set_rule_ref_count(rule, next_count);
875        }
876    }
877
    /// Replace the two symbols following `node`'s predecessor with a single
    /// non-terminal referring to `rule`, then re-check digram uniqueness
    /// around the new symbol.
    fn substitute(&mut self, node: NodeIx, rule: RuleId) {
        let prev = self.nodes[node as usize].prev;
        // Both deletions are addressed through `prev`, since each deletion
        // changes which node follows it.
        let first = self.nodes[prev as usize].next;
        debug_assert!(!self.is_guard(first), "substitute first guard");
        self.delete_node(first);
        let second = self.nodes[prev as usize].next;
        debug_assert!(!self.is_guard(second), "substitute second guard");
        self.delete_node(second);
        let _ = self.insert_after(prev, Symbol::NonTerminal(rule));
        // Check the digram ending at the new symbol; the one starting at it is
        // only checked when the first check did not trigger a match.
        if !self.check(prev) {
            let next = self.nodes[prev as usize].next;
            let _ = self.check(next);
        }
    }
892
    /// Inline the body of `rule` in place of the non-terminal `node`, then
    /// retire the rule.
    ///
    /// The visible caller only invokes this when `rule` has a reference count
    /// of `1` and `node` is the first symbol of a rule body.
    ///
    /// # Panics
    ///
    /// Panics (via `symbol_of`) if the inlined rule's last node or the node
    /// following it is a guard.
    fn expand(&mut self, node: NodeIx, rule: RuleId) {
        // Remember the neighbours before `node` is unlinked.
        let left = self.nodes[node as usize].prev;
        let right = self.nodes[node as usize].next;
        self.delete_node(node);

        // Splice the rule body between the former neighbours of `node`.
        let first = self.first_node_of_rule(rule);
        let last = self.last_node_of_rule(rule);
        self.link(left, first);
        self.link(last, right);

        // Index the digram that now spans the right-hand seam.
        let next = self.nodes[last as usize].next;
        self.set_digram(
            Self::digram_key_from_symbols(self.symbol_of(last), self.symbol_of(next)),
            Some(last),
        );

        // Leave the retired rule with an empty, self-linked body and mark it inactive.
        let guard = self.rules[rule as usize].guard;
        self.link(guard, guard);
        self.set_rule_active(rule, false);
    }
913
914    fn decode_rule(&self, rule: RuleId, out: &mut Vec<u8>) {
915        let guard = self.rules[rule as usize].guard;
916        let mut node = self.nodes[guard as usize].next;
917        while node != guard {
918            match self.nodes[node as usize].data {
919                NodeData::Guard(_) => unreachable!("guard encountered in active rule body"),
920                NodeData::Sym(Symbol::Terminal(byte)) => out.push(byte),
921                NodeData::Sym(Symbol::NonTerminal(child)) => self.decode_rule(child, out),
922            }
923            node = self.nodes[node as usize].next;
924        }
925    }
926
927    fn rule_tail_bytes(&self, rule: RuleId, limit: usize) -> Vec<u8> {
928        let mut rev = Vec::with_capacity(limit);
929        self.collect_rule_tail_rev(rule, limit, &mut rev);
930        rev.reverse();
931        rev
932    }
933
934    fn collect_rule_tail_rev(&self, rule: RuleId, limit: usize, out_rev: &mut Vec<u8>) {
935        if out_rev.len() >= limit {
936            return;
937        }
938        let guard = self.rules[rule as usize].guard;
939        let mut node = self.nodes[guard as usize].prev;
940        while node != guard && out_rev.len() < limit {
941            match self.nodes[node as usize].data {
942                NodeData::Guard(_) => break,
943                NodeData::Sym(Symbol::Terminal(byte)) => out_rev.push(byte),
944                NodeData::Sym(Symbol::NonTerminal(child)) => {
945                    self.collect_rule_tail_rev(child, limit, out_rev);
946                }
947            }
948            node = self.nodes[node as usize].prev;
949        }
950    }
951
952    #[cfg(test)]
953    fn active_rule_ids(&self) -> Vec<RuleId> {
954        self.rules
955            .iter()
956            .enumerate()
957            .filter_map(|(idx, rule)| rule.active.then_some(idx as RuleId))
958            .collect()
959    }
960
961    #[cfg(test)]
962    fn rule_body_symbols(&self, rule: RuleId) -> Vec<Symbol> {
963        let guard = self.rules[rule as usize].guard;
964        let mut out = Vec::new();
965        let mut node = self.nodes[guard as usize].next;
966        while node != guard {
967            out.push(self.symbol_of(node));
968            node = self.nodes[node as usize].next;
969        }
970        out
971    }
972
973    #[cfg(test)]
974    fn validate_rule_shapes(&self) -> Result<(), String> {
975        for rule in self.active_rule_ids() {
976            if rule == 0 {
977                continue;
978            }
979            let len = self.rule_body_symbols(rule).len();
980            if len < 2 {
981                return Err(format!("rule {rule} has rhs length {len}, expected >= 2"));
982            }
983            if self.rules[rule as usize].ref_count < 2 {
984                return Err(format!(
985                    "rule {rule} has utility {}, expected >= 2",
986                    self.rules[rule as usize].ref_count
987                ));
988            }
989        }
990        Ok(())
991    }
992
993    #[cfg(test)]
994    fn validate_rule_refcounts(&self) -> Result<(), String> {
995        let mut counts = vec![0u32; self.rules.len()];
996        for rule in self.active_rule_ids() {
997            for sym in self.rule_body_symbols(rule) {
998                if let Symbol::NonTerminal(child) = sym {
999                    counts[child as usize] += 1;
1000                }
1001            }
1002        }
1003        for rule in self.active_rule_ids() {
1004            if rule == 0 {
1005                continue;
1006            }
1007            let actual = self.rules[rule as usize].ref_count;
1008            let expected = counts[rule as usize];
1009            if actual != expected {
1010                return Err(format!(
1011                    "rule {rule} ref_count mismatch: actual={actual}, expected={expected}"
1012                ));
1013            }
1014        }
1015        Ok(())
1016    }
1017
1018    #[cfg(test)]
1019    fn validate_digram_uniqueness(&self) -> Result<(), String> {
1020        let mut seen = AHashMap::<u64, NodeIx>::new();
1021        for rule in self.active_rule_ids() {
1022            let guard = self.rules[rule as usize].guard;
1023            let mut node = self.nodes[guard as usize].next;
1024            while node != guard {
1025                let next = self.nodes[node as usize].next;
1026                if next == guard {
1027                    break;
1028                }
1029                let key = Self::digram_key_from_symbols(self.symbol_of(node), self.symbol_of(next));
1030                if let Some(&other) = seen.get(&key) {
1031                    let other_next = self.nodes[other as usize].next;
1032                    if other_next != node && self.nodes[node as usize].next != other {
1033                        return Err(format!(
1034                            "duplicate non-overlapping digram for key {key}: {other} and {node}"
1035                        ));
1036                    }
1037                } else {
1038                    seen.insert(key, node);
1039                }
1040                node = next;
1041            }
1042        }
1043        Ok(())
1044    }
1045}
1046
1047fn normalize_pdf(pdf: &mut [f64; 256]) {
1048    let mut sum = 0.0f64;
1049    for value in pdf.iter_mut() {
1050        *value = if value.is_finite() {
1051            (*value).max(PDF_MIN)
1052        } else {
1053            PDF_MIN
1054        };
1055        sum += *value;
1056    }
1057    if !sum.is_finite() || sum <= 0.0 {
1058        pdf.fill(1.0 / 256.0);
1059        return;
1060    }
1061    let inv = 1.0 / sum;
1062    for value in pdf.iter_mut() {
1063        *value *= inv;
1064    }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model with the given context size and feeds it `data` byte by byte.
    fn train_model(data: &[u8], context_bytes: usize) -> SequiturModel {
        let mut model = SequiturModel::new(context_bytes);
        for byte in data.iter().copied() {
            model.update(byte);
        }
        model
    }

    /// Asks the model for its current next-byte distribution.
    fn snapshot_pdf(model: &mut SequiturModel) -> [f64; 256] {
        let mut pdf = [0.0; 256];
        model.fill_pdf(&mut pdf);
        pdf
    }

    /// Decoding the grammar must reproduce the training input exactly.
    #[test]
    fn sequitur_roundtrips_and_preserves_invariants() {
        let data = b"abcabcabcabcabc";
        let model = train_model(data, 64);
        assert_eq!(model.decode(), data);
        model.validate_invariants().unwrap();
    }

    /// The invariants are checked after every single update, not just at the end.
    #[test]
    fn sequitur_invariants_hold_after_each_step() {
        let mut model = SequiturModel::new(32);
        for byte in b"abracadabra abracadabra".iter().copied() {
            model.update(byte);
            model.validate_invariants().unwrap();
        }
    }

    /// Two models trained on the same input yield equal canonical grammars.
    #[test]
    fn sequitur_canonical_grammar_is_deterministic() {
        let input = b"abcabcabcabc";
        let model_a = train_model(input, 64);
        let model_b = train_model(input, 64);
        assert_eq!(model_a.canonical_grammar(), model_b.canonical_grammar());
    }

    /// The predicted distribution sums to one and has only finite, positive entries.
    #[test]
    fn sequitur_pdf_is_normalized() {
        let mut model = train_model(b"abababababa", 32);
        let pdf = snapshot_pdf(&mut model);
        let sum: f64 = pdf.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "sum={sum}");
        assert!(pdf.iter().all(|p| p.is_finite() && *p > 0.0));
    }

    /// Frozen updates leave the grammar and follower statistics untouched, and
    /// the distribution is unchanged once the frozen state is reset again.
    #[test]
    fn frozen_updates_do_not_mutate_learned_distribution_after_reset() {
        let mut model = train_model(b"banana banana banana", 32);
        let before = snapshot_pdf(&mut model);
        let grammar_before = model.canonical_grammar();
        let followers_before = model.followers.clone();

        model.reset_frozen();
        for byte in b"ZZZZ".iter().copied() {
            model.update_frozen(byte);
        }
        assert_eq!(grammar_before, model.canonical_grammar());
        assert_eq!(followers_before, model.followers);

        model.reset_frozen();
        let after = snapshot_pdf(&mut model);
        assert_eq!(before, after);
    }

    /// Restoring a checkpoint brings back the same grammar and distribution.
    #[test]
    fn checkpoint_restore_recovers_exact_state() {
        let mut model = train_model(b"mississippi", 32);
        let checkpoint = model.checkpoint();
        let grammar_before = model.canonical_grammar();
        let pdf_before = snapshot_pdf(&mut model);

        for byte in b" river".iter().copied() {
            model.update(byte);
        }
        model.restore(&checkpoint);
        model.clear_checkpoints();

        assert_eq!(grammar_before, model.canonical_grammar());
        let pdf_after = snapshot_pdf(&mut model);
        assert_eq!(pdf_before, pdf_after);
        model.validate_invariants().unwrap();
    }

    /// Highly repetitive byte patterns keep the invariants intact at every
    /// step; on failure the active rules are dumped alongside the error.
    #[test]
    fn sequitur_repetitive_binary_inputs_preserve_invariants() {
        let inputs: [&[u8]; 2] = [
            b"\x00\x00\x00\x00\x00\x00\x00\x00",
            b"\x00\x01\x00\x01\x00\x01\x00\x01",
        ];
        for data in inputs {
            let mut model = SequiturModel::new(32);
            for (idx, &byte) in data.iter().enumerate() {
                model.update(byte);
                match model.validate_invariants() {
                    Ok(()) => {}
                    Err(err) => {
                        let mut active = Vec::new();
                        for rule in model.active_rule_ids() {
                            active.push((
                                rule,
                                model.rules[rule as usize].ref_count,
                                model.rule_body_symbols(rule),
                            ));
                        }
                        panic!(
                            "invariants failed at step {idx} for {:?}: {err}; active={active:?}",
                            data
                        );
                    }
                }
            }
        }
    }
}