infotheory/backends/
rosaplus.rs

1//! # ROSA: Rapid Online Suffix Automaton
2//! A high-performance predictive language model for entropy rate estimation.
3//!
4//! ROSA uses a **Suffix Automaton** (SAM) to efficiently find the longest matching context
5//! for each symbol in a sequence. It then applies **Witten-Bell smoothing** to estimate
6//! the conditional probability `P(x_t | x_{<t})`.
7//!
8//! This allows for accurate estimation of:
9//! *   Entropy Rate `Ĥ(X)`
10//! *   Cross-Entropy Rate `Ĥ(P, Q)`
11//! *   Joint Entropy Rate `Ĥ(X, Y)` (via aligned pair symbols)
12//!
13//! The implementation is optimized for speed and memory efficiency, using a compact
14//! graph representation for the automaton.
15
16#![allow(clippy::needless_range_loop)]
17
18use std::collections::HashMap;
19use std::fs::File;
20use std::io::{BufReader, BufWriter, Read, Write};
21
/// Number of outgoing transitions stored inline in each SAM state before further
/// transitions spill into the shared overflow edge list.
const SAM_SMALL_MAX: usize = 4;
// NOTE: bump when on-disk format changes.
// v5 promotes SAM/LM overflow-link indices to unsigned 32-bit space for enwik-scale byte models
// and intentionally drops older model compatibility.
const MAGIC_V5: &[u8] = b"rosa_pb_v5\0";

/// Index into the SAM state arena; signed so that `-1` can encode "no state".
type SamStateIx = i32;
/// Index into the SAM overflow edge arena.
type SamEdgeIx = u32;
/// Index into the LM count-node arena.
type LmNodeIx = u32;

/// "No state" sentinel (suffix link of the root, missing transition).
const SAM_STATE_NONE: SamStateIx = -1;
/// Terminator of a state's overflow edge list.
const SAM_EDGE_NONE: SamEdgeIx = SamEdgeIx::MAX;
/// "No node" sentinel for LM count-node links.
const LM_NODE_NONE: LmNodeIx = LmNodeIx::MAX;
/// Marker stored in the packed 16-bit symbol slot when the real value lives in an overflow map.
const LM_PACKED_SYM_OVERFLOW: u16 = u16::MAX;
/// Largest count representable in the packed 16-bit count slot alone.
const LM_PACKED_CNT_MAX: u16 = u16::MAX;

// This crate is used byte-wise by infotheory; for fast incremental conditional updates we
// support an optional fixed 256-byte alphabet LM build/update path.
const BYTE_ALPHA_N: usize = 256;
41
42#[inline(always)]
43fn state_ix(idx: usize) -> SamStateIx {
44    SamStateIx::try_from(idx).expect("rosa sam state index overflow")
45}
46
47#[inline(always)]
48fn state_usize(idx: SamStateIx) -> usize {
49    debug_assert!(idx >= 0, "negative rosa sam state index");
50    idx as usize
51}
52
53#[inline(always)]
54fn edge_ix(idx: usize) -> SamEdgeIx {
55    SamEdgeIx::try_from(idx).expect("rosa sam edge index overflow")
56}
57
58#[inline(always)]
59fn edge_usize(idx: SamEdgeIx) -> usize {
60    idx as usize
61}
62
63#[inline(always)]
64fn node_ix(idx: usize) -> LmNodeIx {
65    LmNodeIx::try_from(idx).expect("rosa lm node index overflow")
66}
67
68#[inline(always)]
69fn node_usize(idx: LmNodeIx) -> usize {
70    idx as usize
71}
72
/// Writes `len` as an 8-byte little-endian `u64`.
#[inline(always)]
fn write_len64<W: Write>(w: &mut W, len: usize) -> std::io::Result<()> {
    let encoded = (len as u64).to_le_bytes();
    w.write_all(&encoded)
}

/// Reads an 8-byte little-endian `u64` length.
///
/// # Errors
/// Propagates read errors; returns `InvalidData` ("length overflow") when the stored
/// value does not fit in `usize` on this target.
#[inline(always)]
fn read_len64<R: Read>(r: &mut R) -> std::io::Result<usize> {
    let mut raw = [0u8; 8];
    r.read_exact(&mut raw)?;
    let value = u64::from_le_bytes(raw);
    usize::try_from(value).map_err(|_| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "length overflow")
    })
}
85
/// Writes `xs` as consecutive little-endian `u32`s.
///
/// On little-endian targets the in-memory layout already equals the on-disk layout, so
/// the slice is written as one contiguous byte view; otherwise each element is encoded.
#[inline(always)]
fn write_u32_slice_le<W: Write>(w: &mut W, xs: &[u32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter().try_for_each(|x| w.write_all(&x.to_le_bytes()));
    }
    // SAFETY: `xs` is a live, initialized slice; viewing its `size_of_val(xs)` bytes as `u8`
    // (alignment 1) for the duration of this borrow is sound.
    let raw: &[u8] =
        unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), std::mem::size_of_val(xs)) };
    w.write_all(raw)
}

/// Writes `xs` as consecutive little-endian `i32`s (same strategy as `write_u32_slice_le`).
#[inline(always)]
fn write_i32_slice_le<W: Write>(w: &mut W, xs: &[i32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter().try_for_each(|x| w.write_all(&x.to_le_bytes()));
    }
    // SAFETY: see `write_u32_slice_le`; `i32` has no padding and every byte is initialized.
    let raw: &[u8] =
        unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), std::mem::size_of_val(xs)) };
    w.write_all(raw)
}

/// Writes `xs` as consecutive little-endian `u64`s (same strategy as `write_u32_slice_le`).
#[inline(always)]
fn write_u64_slice_le<W: Write>(w: &mut W, xs: &[u64]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter().try_for_each(|x| w.write_all(&x.to_le_bytes()));
    }
    // SAFETY: see `write_u32_slice_le`; `u64` has no padding and every byte is initialized.
    let raw: &[u8] =
        unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), std::mem::size_of_val(xs)) };
    w.write_all(raw)
}

/// Fills `xs` from consecutive little-endian `u32`s.
///
/// # Errors
/// Propagates `read_exact` errors (e.g. `UnexpectedEof`).
#[inline(always)]
fn read_u32_slice_le<R: Read>(r: &mut R, xs: &mut [u32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter_mut().try_for_each(|x| -> std::io::Result<()> {
            let mut word = [0u8; 4];
            r.read_exact(&mut word)?;
            *x = u32::from_le_bytes(word);
            Ok(())
        });
    }
    let byte_len = std::mem::size_of_val(xs);
    // SAFETY: the byte view covers exactly the memory of `xs`, which is uniquely borrowed here;
    // every bit pattern is a valid `u32`, so arbitrary bytes written through it are fine.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}

/// Fills `xs` from consecutive little-endian `i32`s.
///
/// # Errors
/// Propagates `read_exact` errors (e.g. `UnexpectedEof`).
#[inline(always)]
fn read_i32_slice_le<R: Read>(r: &mut R, xs: &mut [i32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter_mut().try_for_each(|x| -> std::io::Result<()> {
            let mut word = [0u8; 4];
            r.read_exact(&mut word)?;
            *x = i32::from_le_bytes(word);
            Ok(())
        });
    }
    let byte_len = std::mem::size_of_val(xs);
    // SAFETY: see `read_u32_slice_le`; every bit pattern is a valid `i32`.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}

/// Fills `xs` from consecutive little-endian `u64`s.
///
/// # Errors
/// Propagates `read_exact` errors (e.g. `UnexpectedEof`).
#[inline(always)]
fn read_u64_slice_le<R: Read>(r: &mut R, xs: &mut [u64]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter_mut().try_for_each(|x| -> std::io::Result<()> {
            let mut word = [0u8; 8];
            r.read_exact(&mut word)?;
            *x = u64::from_le_bytes(word);
            Ok(())
        });
    }
    let byte_len = std::mem::size_of_val(xs);
    // SAFETY: see `read_u32_slice_le`; every bit pattern is a valid `u64`.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
181
182#[derive(Clone, Copy, Default)]
183struct SamState {
184    link: SamStateIx,
185    len: i32,
186    endpos: i32,
187    head: SamEdgeIx,
188
189    small_ch: [u32; SAM_SMALL_MAX],
190    small_to: [SamStateIx; SAM_SMALL_MAX],
191    small_n: u8,
192}
193
194#[derive(Clone, Copy, Default)]
195struct SamEdge {
196    ch: u32,
197    to: SamStateIx,
198    next: SamEdgeIx,
199}
200
/// Online suffix automaton over `u32` symbols, together with the fed text and the
/// per-position metadata used for ROSA-style deterministic prediction.
#[derive(Clone)]
struct Sam {
    /// All states; index 0 is the root pushed by `Sam::new`.
    st: Vec<SamState>,
    /// Overflow transition arena: per-state singly linked lists (via `SamEdge::next`),
    /// used once a state's `SAM_SMALL_MAX` inline slots are full.
    ed: Vec<SamEdge>,
    /// State created by the most recent `feed` (the root after `new` or `mark_boundary`);
    /// the next `feed` extends from here.
    last: SamStateIx,
    /// Cache of the root's transitions on symbols below `BYTE_ALPHA_N`
    /// (`SAM_STATE_NONE` when absent); `get_edge` answers root lookups from it.
    root_to: [SamStateIx; BYTE_ALPHA_N],

    /// Every symbol fed so far, in order.
    text: Vec<u32>,
    /// `text_states[0]` is the root; `feed` appends the newly created state after each symbol.
    text_states: Vec<SamStateIx>,
    /// `boundary_after[i] != 0` marks a segment boundary after `text[i]` (set by
    /// `mark_boundary`); `predict_det` skips occurrences ending at such positions.
    boundary_after: Vec<u8>,
}
212
213impl Default for Sam {
214    fn default() -> Self {
215        Self::new(0)
216    }
217}
218
219impl Sam {
220    fn new(expected_chars: usize) -> Self {
221        let mut s = Sam {
222            st: Vec::new(),
223            ed: Vec::new(),
224            last: 0,
225            root_to: [SAM_STATE_NONE; BYTE_ALPHA_N],
226            text: Vec::new(),
227            text_states: Vec::new(),
228            boundary_after: Vec::new(),
229        };
230
231        let st_cap = if expected_chars > 0 {
232            expected_chars * 2 + 16
233        } else {
234            1024
235        };
236        let ed_cap = if expected_chars > 0 {
237            expected_chars * 3 + 16
238        } else {
239            2048
240        };
241        let text_cap = if expected_chars > 0 {
242            expected_chars + 16
243        } else {
244            1024
245        };
246        s.st.reserve(st_cap);
247        s.ed.reserve(ed_cap);
248        s.text.reserve(text_cap);
249        s.text_states.reserve(text_cap);
250        s.boundary_after.reserve(text_cap);
251
252        let root = SamState {
253            link: SAM_STATE_NONE,
254            len: 0,
255            endpos: -1,
256            small_n: 0,
257            head: SAM_EDGE_NONE,
258            ..Default::default()
259        };
260        s.st.push(root);
261        s.text_states.push(0); // Root state for empty context
262        s
263    }
264
265    #[inline(always)]
266    fn reserve_additional(&mut self, additional: usize) {
267        if additional == 0 {
268            return;
269        }
270        self.st
271            .reserve_exact(additional.saturating_mul(2).saturating_add(16));
272        self.ed
273            .reserve_exact(additional.saturating_mul(3).saturating_add(16));
274        let text_extra = additional.saturating_add(16);
275        self.text.reserve_exact(text_extra);
276        self.text_states.reserve_exact(text_extra);
277        self.boundary_after.reserve_exact(text_extra);
278    }
279
280    #[inline(always)]
281    fn get_edge(&self, v: SamStateIx, ch: u32) -> SamStateIx {
282        if v == 0 && ch < BYTE_ALPHA_N as u32 {
283            return self.root_to[ch as usize];
284        }
285        let st = unsafe { self.st.get_unchecked(state_usize(v)) };
286        for i in 0..(st.small_n as usize) {
287            if st.small_ch[i] == ch {
288                return st.small_to[i];
289            }
290        }
291        let mut ei = st.head;
292        while ei != SAM_EDGE_NONE {
293            let e = unsafe { self.ed.get_unchecked(edge_usize(ei)) };
294            if e.ch == ch {
295                return e.to;
296            }
297            ei = e.next;
298        }
299        SAM_STATE_NONE
300    }
301
302    #[inline(always)]
303    fn add_edge(&mut self, v: SamStateIx, ch: u32, to: SamStateIx) {
304        let idx = edge_ix(self.ed.len());
305        let head = self.st[state_usize(v)].head;
306        self.ed.push(SamEdge { ch, to, next: head });
307        self.st[state_usize(v)].head = idx;
308    }
309
310    #[inline(always)]
311    fn add_edge_absent(&mut self, v: SamStateIx, ch: u32, to: SamStateIx) {
312        let st = &mut self.st[state_usize(v)];
313        if (st.small_n as usize) < SAM_SMALL_MAX {
314            let i = st.small_n as usize;
315            st.small_n += 1;
316            st.small_ch[i] = ch;
317            st.small_to[i] = to;
318        } else {
319            self.add_edge(v, ch, to);
320        }
321        if v == 0 && ch < BYTE_ALPHA_N as u32 {
322            self.root_to[ch as usize] = to;
323        }
324    }
325
326    #[inline(always)]
327    fn replace_edge_to(
328        &mut self,
329        v: SamStateIx,
330        ch: u32,
331        old_to: SamStateIx,
332        new_to: SamStateIx,
333    ) -> bool {
334        {
335            let st = &mut self.st[state_usize(v)];
336            for i in 0..(st.small_n as usize) {
337                if st.small_ch[i] == ch && st.small_to[i] == old_to {
338                    st.small_to[i] = new_to;
339                    if v == 0 && ch < BYTE_ALPHA_N as u32 {
340                        self.root_to[ch as usize] = new_to;
341                    }
342                    return true;
343                }
344            }
345        }
346        let mut ei = self.st[state_usize(v)].head;
347        while ei != SAM_EDGE_NONE {
348            let e = &mut self.ed[edge_usize(ei)];
349            if e.ch == ch && e.to == old_to {
350                e.to = new_to;
351                if v == 0 && ch < BYTE_ALPHA_N as u32 {
352                    self.root_to[ch as usize] = new_to;
353                }
354                return true;
355            }
356            ei = e.next;
357        }
358        false
359    }
360
361    fn rebuild_root_cache(&mut self) {
362        self.root_to.fill(SAM_STATE_NONE);
363        if self.st.is_empty() {
364            return;
365        }
366        let root = self.st[0];
367        for i in 0..(root.small_n as usize) {
368            let ch = root.small_ch[i];
369            if ch < BYTE_ALPHA_N as u32 {
370                self.root_to[ch as usize] = root.small_to[i];
371            }
372        }
373        let mut ei = root.head;
374        while ei != SAM_EDGE_NONE {
375            let e = self.ed[edge_usize(ei)];
376            if e.ch < BYTE_ALPHA_N as u32 {
377                self.root_to[e.ch as usize] = e.to;
378            }
379            ei = e.next;
380        }
381    }
382
383    fn clone_overflow_edges(&mut self, src: SamStateIx, dst: SamStateIx) {
384        self.st[state_usize(dst)].head = SAM_EDGE_NONE;
385        let mut ei = self.st[state_usize(src)].head;
386        while ei != SAM_EDGE_NONE {
387            let e = self.ed[edge_usize(ei)];
388            self.add_edge(dst, e.ch, e.to);
389            ei = e.next;
390        }
391    }
392
393    fn feed(&mut self, ch: u32) {
394        let i = self.text.len() as i32;
395        self.text.push(ch);
396        self.boundary_after.push(0);
397
398        let g = self.last;
399        let r = state_ix(self.st.len());
400        let st_r = SamState {
401            link: 0,
402            len: self.st[state_usize(g)].len + 1,
403            endpos: i,
404            small_n: 0,
405            head: SAM_EDGE_NONE,
406            ..Default::default()
407        };
408        self.st.push(st_r);
409
410        let mut p = g;
411        let mut q;
412        while p != SAM_STATE_NONE {
413            q = self.get_edge(p, ch);
414            if q != SAM_STATE_NONE {
415                break;
416            }
417            self.add_edge_absent(p, ch, r);
418            p = self.st[state_usize(p)].link;
419        }
420
421        if p == SAM_STATE_NONE {
422            self.st[state_usize(r)].link = 0;
423        } else {
424            q = self.get_edge(p, ch);
425            if self.st[state_usize(p)].len + 1 == self.st[state_usize(q)].len {
426                self.st[state_usize(r)].link = q;
427            } else {
428                let u = state_ix(self.st.len());
429                let mut st_u = self.st[state_usize(q)];
430                st_u.len = self.st[state_usize(p)].len + 1;
431                self.st.push(st_u);
432                self.clone_overflow_edges(q, u);
433                while p != SAM_STATE_NONE && self.replace_edge_to(p, ch, q, u) {
434                    p = self.st[state_usize(p)].link;
435                }
436                self.st[state_usize(q)].link = u;
437                self.st[state_usize(r)].link = u;
438            }
439        }
440
441        self.last = r;
442        self.text_states.push(r);
443
444        // Maintain rightmost endpos online (ROSA deterministic predictor).
445        let mut v = r;
446        while v != SAM_STATE_NONE && self.st[state_usize(v)].endpos < i {
447            self.st[state_usize(v)].endpos = i;
448            v = self.st[state_usize(v)].link;
449        }
450    }
451
452    fn mark_boundary(&mut self) {
453        if !self.text.is_empty() {
454            let i = self.text.len() - 1;
455            self.boundary_after[i] = 1;
456        }
457        self.last = 0;
458    }
459
460    fn finalize_endpos(&mut self) {
461        let mut max_len: usize = 0;
462        for v in 0..self.st.len() {
463            let l = self.st[v].len as usize;
464            if l > max_len {
465                max_len = l;
466            }
467        }
468
469        let mut cnt = vec![0usize; max_len + 1];
470        for v in 0..self.st.len() {
471            cnt[self.st[v].len as usize] += 1;
472        }
473        let mut pos = vec![0usize; max_len + 1];
474        let mut acc = 0usize;
475        for l in 0..=max_len {
476            pos[l] = acc;
477            acc += cnt[l];
478        }
479        let mut order = vec![0u32; self.st.len()];
480        for v in 0..self.st.len() {
481            let l = self.st[v].len as usize;
482            let idx = pos[l];
483            order[idx] = v as u32;
484            pos[l] += 1;
485        }
486
487        for oi in (0..order.len()).rev() {
488            let v = order[oi] as usize;
489            let p = self.st[v].link;
490            if p >= 0 {
491                let p = p as usize;
492                if self.st[v].endpos > self.st[p].endpos {
493                    self.st[p].endpos = self.st[v].endpos;
494                }
495            }
496        }
497    }
498
499    #[inline(always)]
500    fn advance(&self, mut v: SamStateIx, ch: u32) -> SamStateIx {
501        let mut to;
502        loop {
503            to = self.get_edge(v, ch);
504            if to != SAM_STATE_NONE {
505                return to;
506            }
507            v = self.st[state_usize(v)].link;
508            if v == SAM_STATE_NONE {
509                break;
510            }
511        }
512        to = self.get_edge(0, ch);
513        if to == SAM_STATE_NONE { 0 } else { to }
514    }
515
516    #[inline(always)]
517    fn predict_det(&self, v: SamStateIx) -> Option<u32> {
518        let mut u = v;
519        while u != SAM_STATE_NONE {
520            let st = unsafe { self.st.get_unchecked(state_usize(u)) };
521            let i = st.endpos;
522            let j = i + 1;
523            if st.len > 0 && j >= 0 && (j as usize) < self.text.len() {
524                if i >= 0
525                    && (i as usize) < self.boundary_after.len()
526                    && self.boundary_after[i as usize] != 0
527                {
528                    u = st.link;
529                    continue;
530                }
531                return Some(self.text[j as usize]);
532            }
533            u = st.link;
534        }
535        None
536    }
537
538    // ===== Transactional (undo-log) support =====
539    fn begin_tx(&self) -> SamTx {
540        SamTx {
541            old_last: self.last,
542            old_text_len: self.text.len(),
543            old_text_states_len: self.text_states.len(),
544            old_boundary_len: self.boundary_after.len(),
545            old_st_len: self.st.len(),
546            old_ed_len: self.ed.len(),
547            st_changes: Vec::new(),
548            ed_changes: Vec::new(),
549        }
550    }
551
552    fn rollback_tx(&mut self, tx: SamTx) {
553        // Restore mutated entries (reverse order is fine even with duplicates).
554        for (idx, old) in tx.ed_changes.into_iter().rev() {
555            if idx < self.ed.len() {
556                self.ed[idx] = old;
557            }
558        }
559        for (idx, old) in tx.st_changes.into_iter().rev() {
560            if idx < self.st.len() {
561                self.st[idx] = old;
562            }
563        }
564
565        self.st.truncate(tx.old_st_len);
566        self.ed.truncate(tx.old_ed_len);
567        self.text.truncate(tx.old_text_len);
568        self.text_states.truncate(tx.old_text_states_len);
569        self.boundary_after.truncate(tx.old_boundary_len);
570        self.last = tx.old_last;
571        self.rebuild_root_cache();
572    }
573
574    #[inline(always)]
575    fn record_state_change(&self, tx: &mut SamTx, idx: usize) {
576        // Duplicates are OK; rollback applies in reverse.
577        tx.st_changes.push((idx, self.st[idx]));
578    }
579
580    #[inline(always)]
581    fn record_edge_change(&self, tx: &mut SamTx, idx: usize) {
582        tx.ed_changes.push((idx, self.ed[idx]));
583    }
584
585    #[inline(always)]
586    fn add_edge_tx(&mut self, tx: &mut SamTx, v: SamStateIx, ch: u32, to: SamStateIx) {
587        let idx = edge_ix(self.ed.len());
588        let head = self.st[state_usize(v)].head;
589        self.ed.push(SamEdge { ch, to, next: head });
590        self.record_state_change(tx, state_usize(v));
591        self.st[state_usize(v)].head = idx;
592    }
593
594    #[inline(always)]
595    fn add_edge_absent_tx(&mut self, tx: &mut SamTx, v: SamStateIx, ch: u32, to: SamStateIx) {
596        let v_usize = state_usize(v);
597        let small_n = self.st[v_usize].small_n as usize;
598        if small_n < SAM_SMALL_MAX {
599            let i = small_n;
600            self.record_state_change(tx, v_usize);
601            let st = &mut self.st[v_usize];
602            st.small_ch[i] = ch;
603            st.small_to[i] = to;
604            st.small_n += 1;
605            if v == 0 && ch < BYTE_ALPHA_N as u32 {
606                self.root_to[ch as usize] = to;
607            }
608        } else {
609            self.add_edge_tx(tx, v, ch, to);
610            if v == 0 && ch < BYTE_ALPHA_N as u32 {
611                self.root_to[ch as usize] = to;
612            }
613        }
614    }
615
616    #[inline(always)]
617    fn replace_edge_to_tx(
618        &mut self,
619        tx: &mut SamTx,
620        v: SamStateIx,
621        ch: u32,
622        old_to: SamStateIx,
623        new_to: SamStateIx,
624    ) -> bool {
625        // small edges
626        {
627            let st = &self.st[state_usize(v)];
628            for i in 0..(st.small_n as usize) {
629                if st.small_ch[i] == ch && st.small_to[i] == old_to {
630                    self.record_state_change(tx, state_usize(v));
631                    self.st[state_usize(v)].small_to[i] = new_to;
632                    if v == 0 && ch < BYTE_ALPHA_N as u32 {
633                        self.root_to[ch as usize] = new_to;
634                    }
635                    return true;
636                }
637            }
638        }
639        // overflow edges
640        let mut ei = self.st[state_usize(v)].head;
641        while ei != SAM_EDGE_NONE {
642            let eidx = edge_usize(ei);
643            let e = self.ed[eidx];
644            if e.ch == ch && e.to == old_to {
645                self.record_edge_change(tx, eidx);
646                self.ed[eidx].to = new_to;
647                if v == 0 && ch < BYTE_ALPHA_N as u32 {
648                    self.root_to[ch as usize] = new_to;
649                }
650                return true;
651            }
652            ei = e.next;
653        }
654        false
655    }
656
657    fn clone_overflow_edges_tx(&mut self, tx: &mut SamTx, src: SamStateIx, dst: SamStateIx) {
658        self.record_state_change(tx, state_usize(dst));
659        self.st[state_usize(dst)].head = SAM_EDGE_NONE;
660        let mut ei = self.st[state_usize(src)].head;
661        while ei != SAM_EDGE_NONE {
662            let e = self.ed[edge_usize(ei)];
663            self.add_edge_tx(tx, dst, e.ch, e.to);
664            ei = e.next;
665        }
666    }
667
668    fn feed_tx(&mut self, tx: &mut SamTx, ch: u32) {
669        let i = self.text.len() as i32;
670        self.text.push(ch);
671        self.boundary_after.push(0);
672
673        let g = self.last;
674        let r = state_ix(self.st.len());
675        let st_r = SamState {
676            link: 0,
677            len: self.st[state_usize(g)].len + 1,
678            endpos: i,
679            small_n: 0,
680            head: SAM_EDGE_NONE,
681            ..Default::default()
682        };
683        self.st.push(st_r);
684
685        let mut p = g;
686        let mut q;
687        while p != SAM_STATE_NONE {
688            q = self.get_edge(p, ch);
689            if q != SAM_STATE_NONE {
690                break;
691            }
692            self.add_edge_absent_tx(tx, p, ch, r);
693            p = self.st[state_usize(p)].link;
694        }
695
696        if p == SAM_STATE_NONE {
697            // link of r is in newly appended state; safe.
698            self.st[state_usize(r)].link = 0;
699        } else {
700            q = self.get_edge(p, ch);
701            if self.st[state_usize(p)].len + 1 == self.st[state_usize(q)].len {
702                self.st[state_usize(r)].link = q;
703            } else {
704                let u = state_ix(self.st.len());
705                let mut st_u = self.st[state_usize(q)];
706                st_u.len = self.st[state_usize(p)].len + 1;
707                self.st.push(st_u);
708                self.clone_overflow_edges_tx(tx, q, u);
709                while p != SAM_STATE_NONE && self.replace_edge_to_tx(tx, p, ch, q, u) {
710                    p = self.st[state_usize(p)].link;
711                }
712                // q is an existing state; record before mutation.
713                self.record_state_change(tx, state_usize(q));
714                self.st[state_usize(q)].link = u;
715                self.st[state_usize(r)].link = u;
716            }
717        }
718
719        self.last = r;
720        self.text_states.push(r);
721
722        // Maintain rightmost endpos online (ROSA deterministic predictor).
723        let mut v = r;
724        while v != SAM_STATE_NONE && self.st[state_usize(v)].endpos < i {
725            self.record_state_change(tx, state_usize(v));
726            self.st[state_usize(v)].endpos = i;
727            v = self.st[state_usize(v)].link;
728        }
729    }
730
731    fn mark_boundary_tx(&mut self, tx: &mut SamTx) {
732        if !self.text.is_empty() {
733            // boundary_after is truncated on rollback, so no need to log.
734            let i = self.text.len() - 1;
735            self.boundary_after[i] = 1;
736        }
737        // last is restored on rollback.
738        self.last = 0;
739        let _ = tx;
740    }
741}
742
743#[derive(Clone)]
744struct SamTx {
745    old_last: SamStateIx,
746    old_text_len: usize,
747    old_text_states_len: usize,
748    old_boundary_len: usize,
749    old_st_len: usize,
750    old_ed_len: usize,
751    st_changes: Vec<(usize, SamState)>,
752    ed_changes: Vec<(usize, SamEdge)>,
753}
754
/// Count statistics for one LM context.
// NOTE(review): field meanings below are inferred from names and from
// `LM::ls_is_implicit_single` (head == LM_NODE_NONE && types_t == 1 && total_n > 0);
// confirm against the LM build/update code.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct LmState {
    /// First `CountNode` of this context's list in `LmNodes` (`LM_NODE_NONE` when absent).
    head: LmNodeIx,
    /// Total observations in this context (presumably the Witten-Bell `n` — verify).
    total_n: u64,
    /// Number of distinct symbols observed (presumably the Witten-Bell `T` — verify).
    types_t: u32,

    // Presumably a one-entry cache of the most recently updated symbol and its node — verify.
    last_sym: u32,
    last_node: LmNodeIx,
}
764
/// Unpacked view of one entry of `LmNodes` (see `LmNodes::get` / `LmNodes::set`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct CountNode {
    /// Symbol index (presumably into `LM::alphabet` — verify).
    sym_idx: u32,
    /// Occurrence count.
    cnt: u64,
    /// Next node in the same list (presumably `LM_NODE_NONE` terminates it — verify).
    next: LmNodeIx,
}
771
/// Structure-of-arrays storage for `CountNode`s with 16-bit packing.
///
/// A `sym_idx` below `LM_PACKED_SYM_OVERFLOW` lives in `sym_lo`; larger values store the
/// marker there and the real value in `sym_overflow`. Counts up to `LM_PACKED_CNT_MAX` live
/// in `cnt_lo`; larger counts store `LM_PACKED_CNT_MAX` in `cnt_lo`, set
/// `cnt_overflow_mask`, and keep the excess in `cnt_overflow`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct LmNodes {
    sym_lo: Vec<u16>,
    cnt_lo: Vec<u16>,
    next: Vec<LmNodeIx>,
    cnt_overflow_mask: Vec<u8>,
    // Node index -> full sym_idx, for nodes whose sym_lo holds the overflow marker.
    sym_overflow: HashMap<u32, u32>,
    // Node index -> (cnt - LM_PACKED_CNT_MAX), for nodes whose mask byte is 1.
    cnt_overflow: HashMap<u32, u64>,
}
781
782impl LmNodes {
783    #[inline(always)]
784    fn len(&self) -> usize {
785        self.next.len()
786    }
787
788    #[inline(always)]
789    fn clear(&mut self) {
790        self.sym_lo.clear();
791        self.cnt_lo.clear();
792        self.next.clear();
793        self.cnt_overflow_mask.clear();
794        self.sym_overflow.clear();
795        self.cnt_overflow.clear();
796    }
797
798    #[inline(always)]
799    fn reserve_exact(&mut self, additional: usize) {
800        self.sym_lo.reserve_exact(additional);
801        self.cnt_lo.reserve_exact(additional);
802        self.next.reserve_exact(additional);
803        self.cnt_overflow_mask.reserve_exact(additional);
804    }
805
806    #[inline(always)]
807    fn truncate(&mut self, new_len: usize) {
808        self.sym_lo.truncate(new_len);
809        self.cnt_lo.truncate(new_len);
810        self.next.truncate(new_len);
811        self.cnt_overflow_mask.truncate(new_len);
812        self.sym_overflow.retain(|&k, _| (k as usize) < new_len);
813        self.cnt_overflow.retain(|&k, _| (k as usize) < new_len);
814    }
815
816    #[inline(always)]
817    fn resize(&mut self, new_len: usize, value: CountNode) {
818        if new_len <= self.len() {
819            self.truncate(new_len);
820            return;
821        }
822        while self.len() < new_len {
823            self.push(value);
824        }
825    }
826
827    #[inline(always)]
828    fn set_sym_idx(&mut self, idx: usize, sym_idx: u32) {
829        if sym_idx < LM_PACKED_SYM_OVERFLOW as u32 {
830            self.sym_lo[idx] = sym_idx as u16;
831            self.sym_overflow.remove(&(idx as u32));
832        } else {
833            self.sym_lo[idx] = LM_PACKED_SYM_OVERFLOW;
834            self.sym_overflow.insert(idx as u32, sym_idx);
835        }
836    }
837
838    #[inline(always)]
839    fn set_cnt(&mut self, idx: usize, cnt: u64) {
840        if cnt <= LM_PACKED_CNT_MAX as u64 {
841            self.cnt_lo[idx] = cnt as u16;
842            self.cnt_overflow.remove(&(idx as u32));
843            self.cnt_overflow_mask[idx] = 0;
844        } else {
845            self.cnt_lo[idx] = LM_PACKED_CNT_MAX;
846            self.cnt_overflow
847                .insert(idx as u32, cnt - LM_PACKED_CNT_MAX as u64);
848            self.cnt_overflow_mask[idx] = 1;
849        }
850    }
851
852    #[inline(always)]
853    fn push(&mut self, node: CountNode) {
854        let idx = self.len();
855        self.sym_lo.push(0);
856        self.cnt_lo.push(0);
857        self.next.push(node.next);
858        self.cnt_overflow_mask.push(0);
859        self.set_sym_idx(idx, node.sym_idx);
860        self.set_cnt(idx, node.cnt);
861    }
862
863    #[inline(always)]
864    fn get(&self, idx: usize) -> CountNode {
865        CountNode {
866            sym_idx: self.sym_idx(idx),
867            cnt: self.cnt(idx),
868            next: self.next[idx],
869        }
870    }
871
872    #[inline(always)]
873    fn set(&mut self, idx: usize, node: CountNode) {
874        self.next[idx] = node.next;
875        self.set_sym_idx(idx, node.sym_idx);
876        self.set_cnt(idx, node.cnt);
877    }
878
879    #[inline(always)]
880    fn sym_idx(&self, idx: usize) -> u32 {
881        if self.sym_lo[idx] == LM_PACKED_SYM_OVERFLOW {
882            self.sym_overflow
883                .get(&(idx as u32))
884                .copied()
885                .unwrap_or(LM_PACKED_SYM_OVERFLOW as u32)
886        } else {
887            self.sym_lo[idx] as u32
888        }
889    }
890
891    #[inline(always)]
892    fn cnt(&self, idx: usize) -> u64 {
893        if self.cnt_overflow_mask[idx] == 0 {
894            self.cnt_lo[idx] as u64
895        } else {
896            self.cnt_lo[idx] as u64 + self.cnt_overflow.get(&(idx as u32)).copied().unwrap_or(0)
897        }
898    }
899
900    #[inline(always)]
901    fn next(&self, idx: usize) -> LmNodeIx {
902        self.next[idx]
903    }
904
905    #[inline(always)]
906    fn add_cnt(&mut self, idx: usize, add: u64) {
907        let next = self.cnt(idx).saturating_add(add);
908        self.set_cnt(idx, next);
909    }
910}
911
912struct LmNodesIter<'a> {
913    nodes: &'a LmNodes,
914    idx: usize,
915}
916
917impl<'a> Iterator for LmNodesIter<'a> {
918    type Item = CountNode;
919
920    fn next(&mut self) -> Option<Self::Item> {
921        if self.idx >= self.nodes.len() {
922            return None;
923        }
924        let out = self.nodes.get(self.idx);
925        self.idx += 1;
926        Some(out)
927    }
928}
929
930impl<'a> IntoIterator for &'a LmNodes {
931    type Item = CountNode;
932    type IntoIter = LmNodesIter<'a>;
933
934    fn into_iter(self) -> Self::IntoIter {
935        LmNodesIter {
936            nodes: self,
937            idx: 0,
938        }
939    }
940}
941
/// Count-based language model over the symbols of a `Sam`'s text.
#[derive(Clone)]
struct LM {
    /// Distinct symbols in ascending order (as produced by `build_alphabet`).
    alphabet: Vec<u32>,
    /// `unigram[k]` is the number of occurrences of `alphabet[k]`.
    unigram: Vec<u64>,
    /// `alphabet.len()` as `u32`.
    alpha_n: u32,
    /// Sum of `unigram`.
    total_uni: u64,

    /// Set by `build_alphabet` when every symbol is below 256 and `byte_map` is populated.
    has_byte_map: bool,
    /// Byte value -> index into `alphabet`, or `-1` when the byte is not in the alphabet.
    byte_map: [i16; 256],

    // NOTE(review): presumably one LmState per SAM state — confirm against the LM builder.
    ls: Vec<LmState>,
    /// Arena holding the count nodes referenced from `ls`.
    nodes: LmNodes,
}
955
956impl Default for LM {
957    fn default() -> Self {
958        LM {
959            alphabet: Vec::new(),
960            unigram: Vec::new(),
961            alpha_n: 0,
962            total_uni: 0,
963            has_byte_map: false,
964            byte_map: [-1; 256],
965            ls: Vec::new(),
966            nodes: LmNodes::default(),
967        }
968    }
969}
970
impl LM {
    /// Returns `true` when `ls` stores exactly one observed symbol inline (in `last_sym`,
    /// with its count in `total_n`) and has not allocated any count nodes yet.
    #[inline(always)]
    fn ls_is_implicit_single(ls: &LmState) -> bool {
        ls.head == LM_NODE_NONE && ls.types_t == 1 && ls.total_n > 0
    }

    /// Shortens context state `v` so its `len` is at most `max_order`.
    ///
    /// Follows suffix links while the state is too long; if the chain runs out, state 0
    /// (the state every segment is replayed from in `build_counts`) is returned. A negative
    /// `max_order` means "no cap" and returns `v` unchanged.
    #[inline(always)]
    fn capped_start_state(&self, sam: &Sam, max_order: i64, mut v: SamStateIx) -> SamStateIx {
        if max_order < 0 {
            return v;
        }
        while v != SAM_STATE_NONE && (sam.st[state_usize(v)].len as i64) > max_order {
            v = sam.st[state_usize(v)].link;
        }
        if v == SAM_STATE_NONE { 0 } else { v }
    }

    /// Builds the sorted alphabet and the unigram counts from `sam.text`.
    ///
    /// When every symbol is below 256, a byte fast path histograms into a fixed table and
    /// fills `byte_map` so `find_sym` can skip the binary search. An empty text yields a
    /// one-symbol placeholder alphabet (`'\n'` with count 1).
    fn build_alphabet(&mut self, sam: &Sam) {
        self.has_byte_map = false;
        self.byte_map = [-1; 256];

        let mut max_cp = 0u32;
        for &v in &sam.text {
            if v > max_cp {
                max_cp = v;
            }
        }

        if max_cp < 256 {
            // Byte fast path.
            let mut counts = [0u64; 256];
            for &v in &sam.text {
                counts[v as usize] += 1;
            }
            let mut uniq = 0usize;
            for c in 0..256 {
                if counts[c] != 0 {
                    uniq += 1;
                }
            }

            if uniq == 0 {
                // Empty text: placeholder alphabet so the model is never zero-sized.
                self.alphabet = vec![b'\n' as u32];
                self.unigram = vec![1];
                self.alpha_n = 1;
                self.total_uni = 1;
                self.has_byte_map = true;
                self.byte_map[b'\n' as usize] = 0;
                return;
            }

            // Scanning bytes in ascending order keeps `alphabet` sorted for binary_search.
            self.alphabet = Vec::with_capacity(uniq);
            self.unigram = Vec::with_capacity(uniq);
            self.total_uni = 0;
            for c in 0..256u32 {
                let cnt = counts[c as usize];
                if cnt == 0 {
                    continue;
                }
                self.alphabet.push(c);
                self.unigram.push(cnt);
                self.total_uni += cnt;
            }
            self.alpha_n = self.alphabet.len() as u32;
            self.has_byte_map = true;
            for (i, &c) in self.alphabet.iter().enumerate() {
                self.byte_map[c as usize] = i as i16;
            }
            return;
        }

        // General path: sorted, deduplicated symbols, looked up via binary search.
        let mut tmp = sam.text.clone();
        tmp.sort_unstable();
        tmp.dedup();
        // Defensive: an empty text already took the byte path above (max_cp == 0).
        if tmp.is_empty() {
            tmp.push(b'\n' as u32);
        }
        self.alphabet = tmp;
        self.alpha_n = self.alphabet.len() as u32;
        self.unigram = vec![0u64; self.alphabet.len()];
        self.total_uni = 0;
        for &ch in &sam.text {
            if let Ok(i) = self.alphabet.binary_search(&ch) {
                self.unigram[i] += 1;
                self.total_uni += 1;
            }
        }
        if self.total_uni == 0 {
            self.unigram[0] = 1;
            self.total_uni = 1;
        }
    }

    /// Maps raw symbol `ch` to its alphabet index, or -1 if it is not in the alphabet.
    /// Symbols below 256 use `byte_map` when it is populated.
    #[inline(always)]
    fn find_sym(&self, ch: u32) -> i32 {
        if self.has_byte_map && ch < 256 {
            return self.byte_map[ch as usize] as i32;
        }
        match self.alphabet.binary_search(&ch) {
            Ok(i) => i as i32,
            Err(_) => -1,
        }
    }

    /// Adds `add` occurrences of alphabet symbol `sym_idx` to the next-symbol counts of
    /// LM state `state`.
    ///
    /// A state's counts are in one of three shapes:
    /// * empty (`total_n == 0`);
    /// * an implicit single symbol stored inline (see `ls_is_implicit_single`);
    /// * a singly linked list of `CountNode`s in `self.nodes` starting at `head`.
    ///
    /// `last_node`/`last_sym` cache the most recently touched node so repeated increments
    /// of the same symbol skip the list walk.
    #[inline(always)]
    fn inc(&mut self, state: u32, sym_idx: u32, add: u64) {
        let ls = &mut self.ls[state as usize];
        if ls.head == LM_NODE_NONE {
            if ls.total_n == 0 {
                // First observation: keep it inline, no node allocation.
                ls.total_n = add;
                ls.types_t = 1;
                ls.last_sym = sym_idx;
                ls.last_node = LM_NODE_NONE;
                return;
            }
            if Self::ls_is_implicit_single(ls) {
                if ls.last_sym == sym_idx {
                    ls.total_n += add;
                    ls.last_node = LM_NODE_NONE;
                    return;
                }
                // Second distinct symbol: promote the inline symbol to a two-node list
                // with the new symbol at the head.
                let old_sym = ls.last_sym;
                let old_cnt = ls.total_n;
                let old_idx = node_ix(self.nodes.len());
                self.nodes.push(CountNode {
                    sym_idx: old_sym,
                    cnt: old_cnt,
                    next: LM_NODE_NONE,
                });
                let new_idx = node_ix(self.nodes.len());
                self.nodes.push(CountNode {
                    sym_idx,
                    cnt: add,
                    next: old_idx,
                });
                ls.head = new_idx;
                ls.total_n = old_cnt + add;
                ls.types_t = 2;
                ls.last_node = new_idx;
                ls.last_sym = sym_idx;
                return;
            }
        }

        // Fast path: same symbol as the cached node.
        let last = ls.last_node;
        if last != LM_NODE_NONE && self.nodes.sym_idx(node_usize(last)) == sym_idx {
            self.nodes.add_cnt(node_usize(last), add);
            ls.total_n += add;
            return;
        }

        // Slow path: linear scan of this state's list.
        let mut ni = ls.head;
        while ni != LM_NODE_NONE {
            let idx = node_usize(ni);
            if self.nodes.sym_idx(idx) == sym_idx {
                self.nodes.add_cnt(idx, add);
                ls.total_n += add;
                ls.last_node = ni;
                ls.last_sym = sym_idx;
                return;
            }
            ni = self.nodes.next(idx);
        }

        // Unseen symbol: prepend a new node.
        let idx = node_ix(self.nodes.len());
        self.nodes.push(CountNode {
            sym_idx,
            cnt: add,
            next: ls.head,
        });
        ls.head = idx;
        ls.total_n += add;
        ls.types_t += 1;
        ls.last_node = idx;
        ls.last_sym = sym_idx;
    }

    /// Pre-reserves `ls` and `nodes` capacity ahead of streaming `additional` more symbols.
    ///
    /// The 2x / 3x (+16) multipliers are headroom heuristics. Presumably `ls` is sized for
    /// the SAM adding up to ~2 states per symbol; TODO confirm.
    #[inline(always)]
    fn reserve_for_stream(&mut self, additional: usize) {
        if additional == 0 {
            return;
        }
        self.ls
            .reserve_exact(additional.saturating_mul(2).saturating_add(16));
        self.nodes
            .reserve_exact(additional.saturating_mul(3).saturating_add(16));
    }

    /// Rebuilds all per-state next-symbol counts from `sam.text`.
    ///
    /// Each segment is replayed from state 0. A segment is a run of text ending at a
    /// position whose `boundary_after` is non-zero. Every (context, next symbol) pair inside
    /// a segment is counted at the context state. When `max_order >= 0`, the context is
    /// first capped to `max_order` symbols. Pairs never span a segment boundary.
    ///
    /// Counts are then pushed up the suffix-link tree. States are visited in decreasing
    /// `len` order (counting sort), so a state is complete before it is added to its link.
    /// Afterwards every state's counts include those of all states below it.
    fn build_counts(&mut self, sam: &Sam, max_order: i64) {
        self.ls = vec![
            LmState {
                head: LM_NODE_NONE,
                last_node: LM_NODE_NONE,
                ..LmState::default()
            };
            sam.st.len()
        ];
        self.nodes.clear();

        let mut seg_start = 0usize;
        while seg_start < sam.text.len() {
            // Find the end (exclusive) of the segment starting at `seg_start`.
            let mut seg_end = seg_start;
            while seg_end < sam.text.len() {
                let b = sam.boundary_after[seg_end];
                seg_end += 1;
                if b != 0 {
                    break;
                }
            }
            if seg_end - seg_start >= 2 {
                let mut v = 0;
                for i in seg_start..(seg_end - 1) {
                    let ch = sam.text[i];
                    v = sam.advance(v, ch);
                    let mut ctx = v;
                    if max_order >= 0 {
                        while ctx != SAM_STATE_NONE
                            && (sam.st[state_usize(ctx)].len as i64) > max_order
                        {
                            ctx = sam.st[state_usize(ctx)].link;
                        }
                        if ctx == SAM_STATE_NONE {
                            ctx = 0;
                        }
                    }
                    let nxt = sam.text[i + 1];
                    let si = self.find_sym(nxt);
                    if si >= 0 {
                        self.inc(state_usize(ctx) as u32, si as u32, 1);
                    }
                }
            }
            seg_start = seg_end;
        }

        // propagate up suffix links (counting sort by len)
        let mut max_len: usize = 0;
        for st in &sam.st {
            let l = st.len as usize;
            if l > max_len {
                max_len = l;
            }
        }
        let mut cnt = vec![0usize; max_len + 1];
        for st in &sam.st {
            cnt[st.len as usize] += 1;
        }
        let mut pos = vec![0usize; max_len + 1];
        let mut acc = 0usize;
        for l in 0..=max_len {
            pos[l] = acc;
            acc += cnt[l];
        }
        // `order` lists state indices sorted by ascending `len`.
        let mut order = vec![0u32; sam.st.len()];
        for (v, st) in sam.st.iter().enumerate() {
            let l = st.len as usize;
            let idx = pos[l];
            order[idx] = v as u32;
            pos[l] += 1;
        }

        // Longest states first; a suffix link always points to a strictly shorter state.
        for oi in (0..order.len()).rev() {
            let v = order[oi] as usize;
            let p = sam.st[v].link;
            if p < 0 {
                continue;
            }
            // Copy the header: `inc` below mutates `self.ls`.
            let ls_v = self.ls[v];
            if ls_v.total_n == 0 {
                continue;
            }
            if Self::ls_is_implicit_single(&ls_v) {
                self.inc(state_usize(p) as u32, ls_v.last_sym, ls_v.total_n);
                continue;
            }
            let mut ni = ls_v.head;
            while ni != LM_NODE_NONE {
                let node = self.nodes.get(node_usize(ni));
                self.inc(state_usize(p) as u32, node.sym_idx, node.cnt);
                ni = node.next;
            }
        }
    }

    /// Smoothed probability of alphabet symbol `sym_idx` after context state `v`, computed
    /// without materializing the full distribution.
    ///
    /// Witten–Bell interpolation down the suffix-link chain, starting at `v` capped to
    /// `max_order`. At each state with `n` observations of `t` distinct types, a fraction
    /// `lam = n / (n + t)` of the remaining mass goes to that state's relative frequency.
    /// The other `1 - lam` is passed to the next shorter context. Mass left after the chain
    /// goes to the unigram distribution, or to uniform when `total_uni == 0`.
    ///
    /// The result is clamped to `[1e-12, 1]`. A negative `sym_idx` (symbol not in the
    /// alphabet) yields `1 / alpha_n`.
    fn prob_for_sym(&self, sam: &Sam, max_order: i64, v: SamStateIx, sym_idx: i32) -> f64 {
        if sym_idx < 0 {
            return 1.0 / (self.alpha_n.max(1) as f64);
        }
        let sym_idx = sym_idx as u32;
        let mut p_accum = 0.0f64;
        let mut residual = 1.0f64;
        let mut u = self.capped_start_state(sam, max_order, v);

        while u != SAM_STATE_NONE {
            let ls = &self.ls[state_usize(u)];
            let n = ls.total_n;
            let t = ls.types_t;
            if n > 0 {
                let lam = if t > 0 {
                    (n as f64) / ((n + (t as u64)) as f64)
                } else {
                    1.0
                };

                // Total probability mass from this state
                let scale = residual * lam;

                // Probability of specifically sym_idx in this state
                let mut count_for_sym = 0u64;
                if Self::ls_is_implicit_single(ls) {
                    if ls.last_sym == sym_idx {
                        count_for_sym = n;
                    }
                } else if ls.last_node != LM_NODE_NONE && ls.last_sym == sym_idx {
                    // Cached node holds this symbol: read its count without walking the list.
                    count_for_sym = self.nodes.cnt(node_usize(ls.last_node));
                } else {
                    let mut ni = ls.head;
                    while ni != LM_NODE_NONE {
                        let node = self.nodes.get(node_usize(ni));
                        if node.sym_idx == sym_idx {
                            count_for_sym = node.cnt;
                            break;
                        }
                        ni = node.next;
                    }
                }

                if count_for_sym > 0 {
                    p_accum += scale * (count_for_sym as f64 / n as f64);
                }

                residual *= 1.0 - lam;
            }
            u = sam.st[state_usize(u)].link;
        }

        if self.total_uni > 0 && residual > 0.0 {
            let p_uni = self.unigram[sym_idx as usize] as f64 / self.total_uni as f64;
            p_accum += residual * p_uni;
        } else if residual > 0.0 {
            p_accum += residual * (1.0 / self.alpha_n.max(1) as f64);
        }

        p_accum.clamp(1e-12, 1.0)
    }

    /// Dense counterpart of `prob_for_sym`. Writes the smoothed distribution for context
    /// `v` into `out`, indexed by alphabet index.
    ///
    /// `out` is zeroed first and must hold at least `alpha_n` entries. The result is not
    /// renormalized, and no uniform fallback is added when `total_uni == 0`.
    fn probs_for_state_raw(&self, sam: &Sam, max_order: i64, v: SamStateIx, out: &mut [f64]) {
        out.fill(0.0);
        let mut residual = 1.0f64;
        let mut u = self.capped_start_state(sam, max_order, v);
        while u != SAM_STATE_NONE {
            let ls = &self.ls[state_usize(u)];
            let n = ls.total_n;
            let t = ls.types_t;
            if n > 0 {
                let lam = if t > 0 {
                    (n as f64) / ((n + (t as u64)) as f64)
                } else {
                    1.0
                };
                let scale = residual * lam;
                let inv_n = 1.0 / (n as f64);
                if Self::ls_is_implicit_single(ls) {
                    // Sole symbol: its relative frequency is exactly 1.
                    out[ls.last_sym as usize] += scale;
                } else {
                    let mut ni = ls.head;
                    while ni != LM_NODE_NONE {
                        let node = self.nodes.get(node_usize(ni));
                        out[node.sym_idx as usize] += scale * ((node.cnt as f64) * inv_n);
                        ni = node.next;
                    }
                }
                residual *= 1.0 - lam;
            }
            u = sam.st[state_usize(u)].link;
        }

        if self.total_uni > 0 && residual > 0.0 {
            let inv = 1.0 / (self.total_uni as f64);
            for i in 0..(self.alpha_n as usize) {
                out[i] += residual * ((self.unigram[i] as f64) * inv);
            }
        }
    }

    /// Like `probs_for_state_raw`, then renormalizes the first `alpha_n` entries to sum to 1.
    ///
    /// Renormalization is skipped when the sum is already within 1e-12 of 1. The result
    /// falls back to uniform when the sum is zero or non-finite.
    fn probs_for_state(&self, sam: &Sam, max_order: i64, v: SamStateIx, out: &mut [f64]) {
        self.probs_for_state_raw(sam, max_order, v, out);
        let mut s = 0.0;
        for i in 0..(self.alpha_n as usize) {
            s += out[i];
        }
        if s > 0.0 && s.is_finite() {
            if (s - 1.0).abs() <= 1e-12 {
                return;
            }
            let invs = 1.0 / s;
            for i in 0..(self.alpha_n as usize) {
                out[i] *= invs;
            }
        } else {
            let uprob = 1.0 / (self.alpha_n.max(1) as f64);
            for i in 0..(self.alpha_n as usize) {
                out[i] = uprob;
            }
        }
    }

    /// Transactional variant of `inc`: applies the same count update but first records
    /// undo information in `tx`.
    ///
    /// `tx` receives:
    /// * the state's previous `LmState`, recorded on every call;
    /// * the previous value of any existing node whose count changes;
    /// * the lowest `nodes` length seen before new nodes are appended (`old_nodes_len`).
    #[inline(always)]
    fn inc_tx(&mut self, tx: &mut LmTx, state: u32, sym_idx: u32, add: u64) {
        let si = state as usize;
        // record old LmState
        tx.ls_changes.push((si, self.ls[si]));

        let ls = &mut self.ls[si];
        if ls.head == LM_NODE_NONE {
            if ls.total_n == 0 {
                ls.total_n = add;
                ls.types_t = 1;
                ls.last_sym = sym_idx;
                ls.last_node = LM_NODE_NONE;
                return;
            }
            if Self::ls_is_implicit_single(ls) {
                if ls.last_sym == sym_idx {
                    ls.total_n += add;
                    ls.last_node = LM_NODE_NONE;
                    return;
                }
                let old_sym = ls.last_sym;
                let old_cnt = ls.total_n;
                // Nodes appended below are new in this transaction; remember where they start.
                tx.old_nodes_len = tx.old_nodes_len.min(self.nodes.len());
                let old_idx = node_ix(self.nodes.len());
                self.nodes.push(CountNode {
                    sym_idx: old_sym,
                    cnt: old_cnt,
                    next: LM_NODE_NONE,
                });
                let new_idx = node_ix(self.nodes.len());
                self.nodes.push(CountNode {
                    sym_idx,
                    cnt: add,
                    next: old_idx,
                });
                ls.head = new_idx;
                ls.total_n = old_cnt + add;
                ls.types_t = 2;
                ls.last_node = new_idx;
                ls.last_sym = sym_idx;
                return;
            }
        }

        let last = ls.last_node;
        if last != LM_NODE_NONE && self.nodes.sym_idx(node_usize(last)) == sym_idx {
            let ni = node_usize(last);
            // Save the node before mutating it.
            tx.node_changes.push((ni, self.nodes.get(ni)));
            self.nodes.add_cnt(ni, add);
            ls.total_n += add;
            return;
        }

        let mut ni = ls.head;
        while ni != LM_NODE_NONE {
            let idx = node_usize(ni);
            if self.nodes.sym_idx(idx) == sym_idx {
                tx.node_changes.push((idx, self.nodes.get(idx)));
                self.nodes.add_cnt(idx, add);
                ls.total_n += add;
                ls.last_node = ni;
                ls.last_sym = sym_idx;
                return;
            }
            ni = self.nodes.next(idx);
        }

        // New node
        let idx = node_ix(self.nodes.len());
        tx.old_nodes_len = tx.old_nodes_len.min(self.nodes.len());
        self.nodes.push(CountNode {
            sym_idx,
            cnt: add,
            next: ls.head,
        });
        ls.head = idx;
        ls.total_n += add;
        ls.types_t += 1;
        ls.last_node = idx;
        ls.last_sym = sym_idx;
    }
}
1462
/// Undo log for transactional LM count updates (filled by `LM::inc_tx` and the
/// transactional training path).
#[derive(Clone)]
struct LmTx {
    /// `ls.len()` when the transaction began.
    old_ls_len: usize,
    /// `nodes.len()` before the first node appended during the transaction.
    old_nodes_len: usize,
    /// `(state index, previous LmState)`, recorded before every `inc_tx` on that state.
    ls_changes: Vec<(usize, LmState)>,
    /// `(node index, previous CountNode)`, recorded before an existing node's count changes.
    node_changes: Vec<(usize, CountNode)>,
    /// Per-byte unigram counts added during the transaction.
    uni_delta: [u64; BYTE_ALPHA_N],
    /// Total unigram count added during the transaction.
    total_uni_add: u64,
}
1473
1474#[derive(Clone, Default)]
1475struct RngStream {
1476    buf: Vec<u8>,
1477    pos: usize,
1478    xs: u64,
1479}
1480
1481impl RngStream {
1482    fn new(seed: u64) -> Self {
1483        let mut r = RngStream {
1484            buf: Vec::new(),
1485            pos: 0,
1486            xs: 88172645463325252u64,
1487        };
1488        if let Ok(path) = std::env::var("ROSAPLUS_RNG_PATH")
1489            && !path.is_empty()
1490            && let Ok(mut f) = File::open(path)
1491        {
1492            let mut b = Vec::new();
1493            if f.read_to_end(&mut b).is_ok() && b.len() >= 8 {
1494                let n = b.len();
1495                r.pos = ((seed.wrapping_mul(8)) as usize) % n;
1496                r.buf = b;
1497            }
1498        }
1499        r
1500    }
1501
1502    #[inline(always)]
1503    fn next_u64(&mut self) -> u64 {
1504        if self.buf.len() < 8 {
1505            self.xs ^= self.xs << 7;
1506            self.xs ^= self.xs >> 9;
1507            return self.xs;
1508        }
1509        let n = self.buf.len();
1510        let mut b = [0u8; 8];
1511        for i in 0..8 {
1512            b[i] = self.buf[self.pos];
1513            self.pos += 1;
1514            if self.pos >= n {
1515                self.pos = 0;
1516            }
1517        }
1518        u64::from_le_bytes(b)
1519    }
1520
1521    #[inline(always)]
1522    fn next_unit(&mut self) -> f64 {
1523        let x = self.next_u64();
1524        ((x >> 11) as f64) * (1.0 / 9007199254740992.0)
1525    }
1526}
1527
// Debug helpers for printing byte sequences can go here if needed. The former
// `utf8_decode_lossy`/`utf8_encode` helpers were removed because the model works byte-wise.
1530
/// Scratch buffers kept across calls so they are not reallocated each time
/// (presumably used by the sampling code).
#[derive(Clone, Default)]
struct SampleScratch {
    /// Exactly `alpha_n` entries after `ensure`.
    idx: Vec<u32>,
    /// At least `n` entries after `ensure`.
    logits: Vec<f64>,
    /// Grown in lockstep with `logits`.
    exps: Vec<f64>,
}

impl SampleScratch {
    /// Sizes the buffers for `alpha_n` symbols and `n` candidates.
    ///
    /// `idx` is set to exactly `alpha_n` entries, grown with zeros or truncated.
    /// `logits`/`exps` only ever grow: if `logits` is shorter than `n`, both are
    /// resized to `n`.
    fn ensure(&mut self, alpha_n: usize, n: usize) {
        // Resizing to the current length is a no-op, so no length check is needed.
        self.idx.resize(alpha_n, 0);
        let too_short = n > self.logits.len();
        if too_short {
            self.exps.resize(n, 0.0);
            self.logits.resize(n, 0.0);
        }
    }
}
1549
#[derive(Clone)]
/// ROSA+ predictive model with optional transactional updates.
///
/// Pairs a suffix automaton over all training text with a count-based language model
/// derived from it.
pub struct RosaPlus {
    /// Context-length cap for counting and scoring; `< 0` means uncapped (adaptive order).
    max_order: i64,
    /// Whether `train_example` appends `eot` after each example.
    use_eot: bool,
    /// End-of-text symbol appended when `use_eot` is set.
    eot: u32,
    /// Seed given to `rng` at construction.
    seed: u64,

    /// Suffix automaton over all training text.
    sam: Sam,
    /// Next-symbol counts derived from `sam`.
    lm: LM,
    /// Whether `lm` reflects `sam`; `train_example` clears it.
    lm_built: bool,

    /// Pseudo-random stream (presumably for sampling).
    rng: RngStream,
    /// Reusable scratch buffers (presumably for sampling).
    scratch: SampleScratch,
    /// Distribution buffer sized to the LM alphabet by the LM builders.
    dist: Vec<f64>,
}
1566
/// A lightweight snapshot of the append-only internal SAM buffers.
///
/// Restoring to a checkpoint is O(1) (via truncation) and is meant to support
/// repeated evaluation of different continuations from the same base training state.
#[derive(Clone, Copy, Debug)]
pub struct RosaCheckpoint {
    /// Length of the SAM state buffer (`st`) at snapshot time.
    sam_st_len: usize,
    /// Length of the SAM edge buffer (`ed`) at snapshot time.
    sam_ed_len: usize,
    /// Length of the SAM text at snapshot time.
    sam_text_len: usize,
    /// Length of the SAM `text_states` buffer at snapshot time.
    sam_text_states_len: usize,
    /// Length of the SAM `boundary_after` buffer at snapshot time.
    sam_boundary_after_len: usize,
    /// SAM `last` cursor at snapshot time.
    sam_last: SamStateIx,
}
1580
/// Transaction object used to roll back a temporary conditional update.
///
/// Created by `RosaPlus::begin_tx` and filled in by the transactional training methods.
#[derive(Clone)]
pub struct RosaTx {
    /// Undo log for SAM structure changes.
    sam: SamTx,
    /// Undo log for LM count changes.
    lm: LmTx,
    /// SAM text length when the transaction began (start of the new segment).
    seg_start: usize,
    /// Number of symbols the transactional update appended (0 until one runs).
    seg_len: usize,
}
1589
1590impl RosaPlus {
1591    /// Create a new ROSA+ model.
1592    ///
1593    /// `max_order < 0` enables adaptive order selection in predictive scoring.
1594    pub fn new(max_order: i64, use_eot: bool, eot_char: u8, seed: u64) -> Self {
1595        let sam = Sam::new(0);
1596        RosaPlus {
1597            max_order,
1598            use_eot,
1599            eot: eot_char as u32,
1600            seed,
1601            sam,
1602            lm: LM::default(),
1603            lm_built: false,
1604            rng: RngStream::new(seed),
1605            scratch: SampleScratch::default(),
1606            dist: Vec::new(),
1607        }
1608    }
1609
1610    /// Train on one byte sequence, optionally appending EOT marker.
1611    pub fn train_example(&mut self, s: &[u8]) {
1612        if s.is_empty() {
1613            return;
1614        }
1615
1616        if self.sam.text.is_empty() {
1617            self.sam = Sam::new(s.len());
1618        }
1619
1620        for &b in s {
1621            self.sam.feed(b as u32);
1622        }
1623
1624        if self.use_eot {
1625            self.sam.feed(self.eot);
1626        }
1627
1628        self.sam.mark_boundary();
1629        self.lm_built = false;
1630    }
1631
1632    /// Reserve append-only buffers for a future byte stream.
1633    ///
1634    /// This keeps compression-side updates from repeatedly growing the same vectors.
1635    pub fn reserve_for_stream(&mut self, additional_bytes: usize) {
1636        self.sam.reserve_additional(additional_bytes);
1637        self.lm.reserve_for_stream(additional_bytes);
1638        self.dist.reserve(BYTE_ALPHA_N);
1639    }
1640
1641    /// Build the language model from current SAM state.
1642    pub fn build_lm(&mut self) {
1643        self.sam.finalize_endpos();
1644        self.lm = LM::default();
1645        self.lm.build_alphabet(&self.sam);
1646        let mo = if self.max_order < 0 {
1647            -1
1648        } else {
1649            self.max_order
1650        };
1651        self.lm.build_counts(&self.sam, mo);
1652        self.lm_built = true;
1653        self.dist.resize(self.lm.alpha_n as usize, 0.0);
1654    }
1655
1656    /// Build the language model without mutating SAM `endpos`.
1657    ///
1658    /// This is useful when you want to reuse a trained SAM as a stable base state
1659    /// (e.g. universal-prior conditioning) and need cheap checkpoint/restore via truncation.
1660    ///
1661    /// Note: entropy/cross-entropy estimation does not require `endpos` finalization.
1662    pub fn build_lm_no_finalize_endpos(&mut self) {
1663        self.lm = LM::default();
1664        self.lm.build_alphabet(&self.sam);
1665        let mo = if self.max_order < 0 {
1666            -1
1667        } else {
1668            self.max_order
1669        };
1670        self.lm.build_counts(&self.sam, mo);
1671        self.lm_built = true;
1672        self.dist.resize(self.lm.alpha_n as usize, 0.0);
1673    }
1674
1675    /// Build an LM with a fixed byte alphabet of size 256.
1676    ///
1677    /// This avoids alphabet growth issues and enables fast incremental updates.
1678    pub fn build_lm_full_bytes_no_finalize_endpos(&mut self) {
1679        // Fixed alphabet
1680        self.lm = LM::default();
1681        self.lm.has_byte_map = true;
1682        self.lm.alpha_n = BYTE_ALPHA_N as u32;
1683        self.lm.alphabet = (0..BYTE_ALPHA_N as u32).collect();
1684        self.lm.byte_map = [-1; 256];
1685        for i in 0..256 {
1686            self.lm.byte_map[i] = i as i16;
1687        }
1688
1689        // Unigram counts
1690        let mut counts = [0u64; 256];
1691        for &v in &self.sam.text {
1692            if v < 256 {
1693                counts[v as usize] += 1;
1694            }
1695        }
1696        self.lm.unigram = counts.to_vec();
1697        self.lm.total_uni = counts.iter().sum();
1698        if self.lm.total_uni == 0 {
1699            for i in 0..256 {
1700                self.lm.unigram[i] = 1;
1701            }
1702            self.lm.total_uni = 256;
1703        }
1704
1705        // Counts
1706        let mo = if self.max_order < 0 {
1707            -1
1708        } else {
1709            self.max_order
1710        };
1711        self.lm.build_counts(&self.sam, mo);
1712        self.lm_built = true;
1713        self.dist.resize(BYTE_ALPHA_N, 0.0);
1714    }
1715
1716    /// Begin a reversible conditional update transaction.
1717    pub fn begin_tx(&mut self) -> RosaTx {
1718        let sam_tx = self.sam.begin_tx();
1719        let lm_tx = LmTx {
1720            old_ls_len: self.lm.ls.len(),
1721            old_nodes_len: self.lm.nodes.len(),
1722            ls_changes: Vec::new(),
1723            node_changes: Vec::new(),
1724            uni_delta: [0u64; BYTE_ALPHA_N],
1725            total_uni_add: 0,
1726        };
1727        RosaTx {
1728            sam: sam_tx,
1729            lm: lm_tx,
1730            seg_start: self.sam.text.len(),
1731            seg_len: 0,
1732        }
1733    }
1734
1735    /// Apply a training example and update LM counts incrementally (byte alphabet must be full 256).
1736    pub fn train_example_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1737        self.train_example_tx_impl(tx, s, true);
1738    }
1739
1740    /// Apply a sequential update without inserting a boundary (continuous stream).
1741    pub fn train_sequence_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1742        self.train_example_tx_impl(tx, s, false);
1743    }
1744
1745    /// Apply a sequential byte-stream update without rollback bookkeeping.
1746    pub fn train_sequence(&mut self, s: &[u8]) {
1747        if s.is_empty() {
1748            return;
1749        }
1750
1751        if s.len() == 1 {
1752            self.train_byte(s[0]);
1753            return;
1754        }
1755
1756        if self.sam.text.is_empty() {
1757            self.sam = Sam::new(s.len());
1758        }
1759        self.reserve_for_stream(s.len());
1760        if !self.lm_built || !self.lm.has_byte_map || (self.lm.alpha_n as usize) != BYTE_ALPHA_N {
1761            self.build_lm_full_bytes_no_finalize_endpos();
1762        }
1763
1764        if self.lm.ls.len() < self.sam.st.len() {
1765            self.lm.ls.resize(
1766                self.sam.st.len(),
1767                LmState {
1768                    head: LM_NODE_NONE,
1769                    last_node: LM_NODE_NONE,
1770                    ..LmState::default()
1771                },
1772            );
1773        }
1774
1775        let seg_start = self.sam.text.len();
1776        for &b in s {
1777            self.sam.feed(b as u32);
1778            self.lm.unigram[b as usize] += 1;
1779            self.lm.total_uni += 1;
1780        }
1781
1782        if self.lm.ls.len() < self.sam.st.len() {
1783            self.lm.ls.resize(
1784                self.sam.st.len(),
1785                LmState {
1786                    head: LM_NODE_NONE,
1787                    last_node: LM_NODE_NONE,
1788                    ..LmState::default()
1789                },
1790            );
1791        }
1792
1793        let seg_end = self.sam.text.len();
1794        if seg_end.saturating_sub(seg_start) >= 1 {
1795            let mo = if self.max_order < 0 {
1796                -1
1797            } else {
1798                self.max_order
1799            };
1800            let mut start_i = seg_start;
1801            if seg_start > 0
1802                && self
1803                    .sam
1804                    .boundary_after
1805                    .get(seg_start - 1)
1806                    .copied()
1807                    .unwrap_or(0)
1808                    == 0
1809            {
1810                start_i = seg_start - 1;
1811            }
1812            for i in start_i..(seg_end - 1) {
1813                let mut ctx = self.sam.text_states[i + 1];
1814                if mo >= 0 {
1815                    while ctx != SAM_STATE_NONE && (self.sam.st[state_usize(ctx)].len as i64) > mo {
1816                        ctx = self.sam.st[state_usize(ctx)].link;
1817                    }
1818                    if ctx == SAM_STATE_NONE {
1819                        ctx = 0;
1820                    }
1821                }
1822                let nxt = self.sam.text[i + 1];
1823                let si = self.lm.find_sym(nxt);
1824                if si >= 0 {
1825                    let mut u = ctx;
1826                    while u != SAM_STATE_NONE {
1827                        self.lm.inc(state_usize(u) as u32, si as u32, 1);
1828                        u = self.sam.st[state_usize(u)].link;
1829                    }
1830                }
1831            }
1832        }
1833
1834        self.lm_built = true;
1835    }
1836
1837    /// Apply a single byte sequential update without rollback bookkeeping.
1838    #[inline]
1839    pub fn train_byte(&mut self, b: u8) {
1840        if self.sam.text.is_empty() {
1841            self.sam = Sam::new(1);
1842        }
1843        if !self.lm_built || !self.lm.has_byte_map || (self.lm.alpha_n as usize) != BYTE_ALPHA_N {
1844            self.build_lm_full_bytes_no_finalize_endpos();
1845        }
1846
1847        self.sam.feed(b as u32);
1848        self.lm.unigram[b as usize] += 1;
1849        self.lm.total_uni += 1;
1850
1851        if self.lm.ls.len() < self.sam.st.len() {
1852            self.lm.ls.resize(
1853                self.sam.st.len(),
1854                LmState {
1855                    head: LM_NODE_NONE,
1856                    last_node: LM_NODE_NONE,
1857                    ..LmState::default()
1858                },
1859            );
1860        }
1861
1862        let seg_end = self.sam.text.len();
1863        if seg_end > 1
1864            && self
1865                .sam
1866                .boundary_after
1867                .get(seg_end - 2)
1868                .copied()
1869                .unwrap_or(0)
1870                == 0
1871        {
1872            let mo = if self.max_order < 0 {
1873                -1
1874            } else {
1875                self.max_order
1876            };
1877            let mut ctx = self.sam.text_states[seg_end - 1];
1878            if mo >= 0 {
1879                while ctx != SAM_STATE_NONE && (self.sam.st[state_usize(ctx)].len as i64) > mo {
1880                    ctx = self.sam.st[state_usize(ctx)].link;
1881                }
1882                if ctx == SAM_STATE_NONE {
1883                    ctx = 0;
1884                }
1885            }
1886            let mut u = ctx;
1887            let si = b as u32;
1888            while u != SAM_STATE_NONE {
1889                self.lm.inc(state_usize(u) as u32, si, 1);
1890                u = self.sam.st[state_usize(u)].link;
1891            }
1892        }
1893
1894        self.lm_built = true;
1895    }
1896
1897    /// Reset only the predictive cursor while preserving the trained SAM/LM.
1898    pub fn reset_conditioning_cursor(&mut self) {
1899        self.sam.last = 0;
1900    }
1901
1902    /// Advance only the predictive cursor without mutating fitted counts.
1903    pub fn advance_conditioning_byte(&mut self, b: u8) {
1904        self.sam.last = self.sam.advance(self.sam.last, b as u32);
1905    }
1906
1907    fn train_example_tx_impl(&mut self, tx: &mut RosaTx, s: &[u8], mark_boundary: bool) {
1908        if s.is_empty() {
1909            return;
1910        }
1911
1912        // Ensure LS has entries for current states.
1913        if self.lm.ls.len() < self.sam.st.len() {
1914            self.lm.ls.resize(
1915                self.sam.st.len(),
1916                LmState {
1917                    head: LM_NODE_NONE,
1918                    last_node: LM_NODE_NONE,
1919                    ..LmState::default()
1920                },
1921            );
1922        }
1923
1924        // Feed all bytes (SAM structure changes are logged).
1925        for &b in s {
1926            self.sam.feed_tx(&mut tx.sam, b as u32);
1927            tx.lm.uni_delta[b as usize] += 1;
1928            tx.lm.total_uni_add += 1;
1929        }
1930        if mark_boundary {
1931            self.sam.mark_boundary_tx(&mut tx.sam);
1932        }
1933
1934        // LM must be built for scoring; we keep it built and update counts incrementally.
1935        // Extend ls for any new SAM states created by feeding.
1936        if self.lm.ls.len() < self.sam.st.len() {
1937            self.lm.ls.resize(
1938                self.sam.st.len(),
1939                LmState {
1940                    head: LM_NODE_NONE,
1941                    last_node: LM_NODE_NONE,
1942                    ..LmState::default()
1943                },
1944            );
1945        }
1946
1947        // Update unigram counts (fixed 256 alphabet assumed).
1948        for i in 0..256 {
1949            if tx.lm.uni_delta[i] != 0 {
1950                self.lm.unigram[i] += tx.lm.uni_delta[i];
1951            }
1952        }
1953        self.lm.total_uni += tx.lm.total_uni_add;
1954
1955        // Update conditional counts for the new segment only.
1956        let seg_start = tx.seg_start;
1957        let seg_end = self.sam.text.len();
1958        tx.seg_len = seg_end - seg_start;
1959        if tx.seg_len >= 1 {
1960            let mo = if self.max_order < 0 {
1961                -1
1962            } else {
1963                self.max_order
1964            };
1965            // For continuous streams, include the cross-boundary transition from the
1966            // previous symbol into the first new symbol. For segmented examples,
1967            // respect boundary markers and skip that transition.
1968            let mut start_i = seg_start;
1969            if !mark_boundary
1970                && seg_start > 0
1971                && self
1972                    .sam
1973                    .boundary_after
1974                    .get(seg_start - 1)
1975                    .copied()
1976                    .unwrap_or(0)
1977                    == 0
1978            {
1979                start_i = seg_start - 1;
1980            }
1981            for i in start_i..(seg_end - 1) {
1982                // ctx state after consuming sam.text[i] within its segment
1983                let mut ctx = self.sam.text_states[i + 1];
1984                if mo >= 0 {
1985                    while ctx != SAM_STATE_NONE && (self.sam.st[state_usize(ctx)].len as i64) > mo {
1986                        ctx = self.sam.st[state_usize(ctx)].link;
1987                    }
1988                    if ctx == SAM_STATE_NONE {
1989                        ctx = 0;
1990                    }
1991                }
1992                let nxt = self.sam.text[i + 1];
1993                let si = self.lm.find_sym(nxt);
1994                if si >= 0 {
1995                    let mut u = ctx;
1996                    while u != SAM_STATE_NONE {
1997                        self.lm
1998                            .inc_tx(&mut tx.lm, state_usize(u) as u32, si as u32, 1);
1999                        u = self.sam.st[state_usize(u)].link;
2000                    }
2001                }
2002            }
2003        }
2004
2005        self.lm_built = true;
2006    }
2007
2008    /// Roll back a transaction, restoring the model to the exact state at begin_tx.
2009    pub fn rollback_tx(&mut self, tx: RosaTx) {
2010        // Restore LM changes
2011        // Unigram rollback
2012        if self.lm.unigram.len() >= BYTE_ALPHA_N {
2013            for i in 0..BYTE_ALPHA_N {
2014                let d = tx.lm.uni_delta[i];
2015                if d != 0 {
2016                    self.lm.unigram[i] = self.lm.unigram[i].saturating_sub(d);
2017                }
2018            }
2019            self.lm.total_uni = self.lm.total_uni.saturating_sub(tx.lm.total_uni_add);
2020        }
2021
2022        for (idx, old) in tx.lm.node_changes.into_iter().rev() {
2023            if idx < self.lm.nodes.len() {
2024                self.lm.nodes.set(idx, old);
2025            }
2026        }
2027        for (idx, old) in tx.lm.ls_changes.into_iter().rev() {
2028            if idx < self.lm.ls.len() {
2029                self.lm.ls[idx] = old;
2030            }
2031        }
2032        self.lm.nodes.truncate(tx.lm.old_nodes_len);
2033        self.lm.ls.truncate(tx.lm.old_ls_len);
2034
2035        // Restore SAM
2036        self.sam.rollback_tx(tx.sam);
2037        // lm_built remains true if it was true before; safe to keep true.
2038    }
2039
2040    /// Ensure the LM is built (without mutating SAM endpos).
2041    #[inline(always)]
2042    pub fn ensure_lm_built_no_finalize_endpos(&mut self) {
2043        if !self.lm_built {
2044            self.build_lm_no_finalize_endpos();
2045        }
2046    }
2047
2048    fn predictive_entropy_rate_order(data: &[u8], max_order: i64, seed: u64) -> f64 {
2049        if data.len() < 2 {
2050            return 0.0;
2051        }
2052        let num_chunks = 16;
2053        let chunk_size = data.len().div_ceil(num_chunks);
2054        let mut total_log_prob = 0.0f64;
2055        let mut count = 0usize;
2056        let mut m = RosaPlus::new(max_order, false, 0, seed);
2057        m.sam = Sam::new(data.len());
2058        m.lm_built = false;
2059
2060        for i in 0..num_chunks {
2061            let start = i * chunk_size;
2062            let end = ((i + 1) * chunk_size).min(data.len());
2063            if start >= end {
2064                break;
2065            }
2066            let chunk = &data[start..end];
2067            if i > 0 {
2068                m.build_lm_no_finalize_endpos();
2069                let mut v = 0;
2070                for &b in chunk {
2071                    let sym_idx = m.lm.find_sym(b as u32);
2072                    let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
2073                    total_log_prob += p.log2();
2074                    count += 1;
2075                    v = m.sam.advance(v, b as u32);
2076                }
2077            }
2078            for &b in chunk {
2079                m.sam.feed(b as u32);
2080            }
2081        }
2082
2083        if count == 0 {
2084            m.train_example(data);
2085            m.build_lm();
2086            m.cross_entropy(data)
2087        } else {
2088            -total_log_prob / (count as f64)
2089        }
2090    }
2091
2092    /// Current LM alphabet size (0 if LM not built).
2093    pub fn lm_alpha_n(&self) -> usize {
2094        if !self.lm_built {
2095            0
2096        } else {
2097            self.lm.alpha_n as usize
2098        }
2099    }
2100
2101    /// Approximate in-memory footprint of major model buffers.
2102    pub fn estimated_size_bytes(&self) -> usize {
2103        use std::mem::size_of;
2104
2105        let mut n = 0usize;
2106
2107        n = n.saturating_add(self.sam.st.len().saturating_mul(size_of::<SamState>()));
2108        n = n.saturating_add(self.sam.ed.len().saturating_mul(size_of::<SamEdge>()));
2109        n = n.saturating_add(self.sam.text.len().saturating_mul(size_of::<u32>()));
2110        n = n.saturating_add(
2111            self.sam
2112                .text_states
2113                .len()
2114                .saturating_mul(size_of::<SamStateIx>()),
2115        );
2116        n = n.saturating_add(size_of::<[SamStateIx; BYTE_ALPHA_N]>());
2117        n = n.saturating_add(
2118            self.sam
2119                .boundary_after
2120                .len()
2121                .saturating_mul(size_of::<u8>()),
2122        );
2123
2124        n = n.saturating_add(self.lm.alphabet.len().saturating_mul(size_of::<u32>()));
2125        n = n.saturating_add(self.lm.unigram.len().saturating_mul(size_of::<u64>()));
2126        n = n.saturating_add(self.lm.ls.len().saturating_mul(size_of::<LmState>()));
2127        n = n.saturating_add(self.lm.nodes.sym_lo.len().saturating_mul(size_of::<u16>()));
2128        n = n.saturating_add(self.lm.nodes.cnt_lo.len().saturating_mul(size_of::<u16>()));
2129        n = n.saturating_add(
2130            self.lm
2131                .nodes
2132                .next
2133                .len()
2134                .saturating_mul(size_of::<LmNodeIx>()),
2135        );
2136        n = n.saturating_add(
2137            self.lm
2138                .nodes
2139                .cnt_overflow_mask
2140                .len()
2141                .saturating_mul(size_of::<u8>()),
2142        );
2143        n = n.saturating_add(
2144            self.lm
2145                .nodes
2146                .sym_overflow
2147                .len()
2148                .saturating_mul(size_of::<u32>() + size_of::<u32>()),
2149        );
2150        n = n.saturating_add(
2151            self.lm
2152                .nodes
2153                .cnt_overflow
2154                .len()
2155                .saturating_mul(size_of::<u32>() + size_of::<u64>()),
2156        );
2157
2158        n = n.saturating_add(self.dist.len().saturating_mul(size_of::<f64>()));
2159        n = n.saturating_add(self.scratch.idx.len().saturating_mul(size_of::<u32>()));
2160        n = n.saturating_add(self.scratch.logits.len().saturating_mul(size_of::<f64>()));
2161        n = n.saturating_add(self.scratch.exps.len().saturating_mul(size_of::<f64>()));
2162        n = n.saturating_add(self.rng.buf.len().saturating_mul(size_of::<u8>()));
2163
2164        n
2165    }
2166
2167    /// Shrink auxiliary scratch buffers to fit current usage.
2168    pub fn shrink_aux_buffers(&mut self) {
2169        self.dist.shrink_to_fit();
2170        self.scratch.idx.shrink_to_fit();
2171        self.scratch.logits.shrink_to_fit();
2172        self.scratch.exps.shrink_to_fit();
2173        self.rng.buf.shrink_to_fit();
2174    }
2175
2176    /// Create a new model that shares the same trained SAM state but resets LM-related buffers.
2177    ///
2178    /// This is substantially cheaper than cloning the full `RosaPlus` (which includes LM counts,
2179    /// node tables, and distribution buffers) and is safe for workflows that want to start from
2180    /// a fixed base training text (e.g. a universal prior) and then add candidate-specific text.
2181    pub fn fork_from_sam(&self) -> Self {
2182        Self {
2183            max_order: self.max_order,
2184            use_eot: self.use_eot,
2185            eot: self.eot,
2186            seed: self.seed,
2187
2188            sam: self.sam.clone(),
2189            lm: LM::default(),
2190            lm_built: false,
2191
2192            rng: RngStream::new(self.seed),
2193            scratch: SampleScratch::default(),
2194            dist: Vec::new(),
2195        }
2196    }
2197
2198    /// A checkpoint that allows restoring the ROSA model back to a previous trained state
2199    /// by truncating append-only internal buffers.
2200    ///
2201    /// Intended for workflows that repeatedly evaluate different continuations from the same base
2202    /// training text (e.g. universal-prior conditioned scoring).
2203    pub fn checkpoint(&self) -> RosaCheckpoint {
2204        RosaCheckpoint {
2205            sam_st_len: self.sam.st.len(),
2206            sam_ed_len: self.sam.ed.len(),
2207            sam_text_len: self.sam.text.len(),
2208            sam_text_states_len: self.sam.text_states.len(),
2209            sam_boundary_after_len: self.sam.boundary_after.len(),
2210            sam_last: self.sam.last,
2211        }
2212    }
2213
2214    /// Restore the model to a previously captured checkpoint.
2215    ///
2216    /// This invalidates the LM; callers should rebuild it before scoring.
2217    pub fn restore(&mut self, ck: &RosaCheckpoint) {
2218        self.sam.st.truncate(ck.sam_st_len);
2219        self.sam.ed.truncate(ck.sam_ed_len);
2220        self.sam.text.truncate(ck.sam_text_len);
2221        self.sam.text_states.truncate(ck.sam_text_states_len);
2222        self.sam.boundary_after.truncate(ck.sam_boundary_after_len);
2223        self.sam.last = ck.sam_last;
2224        self.lm_built = false;
2225    }
2226
    /// Draw one symbol from `self.dist` using top-k / top-p filtering and a temperature.
    ///
    /// `self.dist` is indexed by LM alphabet index (`0..alpha_n`). NOTE(review): callers are
    /// assumed to have filled it (e.g. via `probs_for_state`) with at least `alpha_n` entries.
    ///
    /// * `temperature` – logits are `ln(p) / temperature`; values `<= 0` are replaced by `1e-6`.
    /// * `top_p` – nucleus cutoff, applied only when strictly between 0 and 1.
    /// * `top_k` – when positive, keep at most `top_k` candidates.
    ///
    /// Returns `self.lm.alphabet[idx]` for the chosen index. Draws exactly one value from
    /// `self.rng`.
    #[inline(always)]
    fn sample(&mut self, temperature: f64, top_p: f64, top_k: i32) -> u32 {
        let dist = &self.dist;
        let alpha_n = self.lm.alpha_n as usize;
        self.scratch.ensure(alpha_n, alpha_n);
        // Start from the identity permutation of alphabet indices.
        for i in 0..alpha_n {
            self.scratch.idx[i] = i as u32;
        }

        // O(n^2) sort by dist desc then idx asc (matches C).
        // The index tie-break makes the resulting order deterministic.
        for i in 0..alpha_n {
            for j in (i + 1)..alpha_n {
                let ii = self.scratch.idx[i] as usize;
                let jj = self.scratch.idx[j] as usize;
                let pi = dist[ii];
                let pj = dist[jj];
                if pj > pi || (pj == pi && jj < ii) {
                    self.scratch.idx.swap(i, j);
                }
            }
        }

        // Top-k: keep only the k most probable candidates.
        let mut n = alpha_n;
        if top_k > 0 {
            let k = top_k as usize;
            if k < n {
                n = k;
            }
        }

        // Top-p: keep the shortest prefix whose cumulative mass reaches `top_p`
        // (at least one candidate).
        if top_p > 0.0 && top_p < 1.0 {
            let mut cum = 0.0;
            let mut cut = 0usize;
            for i in 0..n {
                let si = self.scratch.idx[i] as usize;
                cum += dist[si];
                cut += 1;
                if cum >= top_p {
                    break;
                }
            }
            n = if cut > 0 { cut } else { 1 };
        }

        // Guard against division by zero / negative temperatures.
        let temperature = if temperature <= 0.0 {
            1e-6
        } else {
            temperature
        };

        self.scratch.ensure(alpha_n, n);
        let mut maxlog = -1e300f64;
        for i in 0..n {
            let si = self.scratch.idx[i] as usize;
            let mut p = dist[si];
            // Floor tiny probabilities so ln() stays finite.
            if p < 1e-12 {
                p = 1e-12;
            }
            let z = p.ln() / temperature;
            self.scratch.logits[i] = z;
            if z > maxlog {
                maxlog = z;
            }
        }

        // Softmax weights, shifted by the max logit for numerical stability.
        let mut zsum = 0.0;
        for i in 0..n {
            let e = (self.scratch.logits[i] - maxlog).exp();
            self.scratch.exps[i] = e;
            zsum += e;
        }

        // Inverse-CDF draw; falls back to the top candidate (pick = 0) if no
        // cumulative weight exceeds `r`.
        let r = self.rng.next_unit() * zsum;
        let mut cum = 0.0;
        let mut pick = 0usize;
        for i in 0..n {
            cum += self.scratch.exps[i];
            if cum > r {
                pick = i;
                break;
            }
        }

        let sym = self.scratch.idx[pick] as usize;
        self.lm.alphabet[sym]
    }
2313
2314    /// Generate continuation bytes from a prompt.
2315    ///
2316    /// Returns `None` if LM is not built yet.
2317    pub fn generate(&mut self, prompt: &[u8], steps: i32) -> Option<Vec<u8>> {
2318        if !self.lm_built {
2319            return None;
2320        }
2321        let steps = steps.max(0) as usize;
2322
2323        let mut v = 0i32;
2324        for &b in prompt {
2325            v = self.sam.advance(v, b as u32);
2326        }
2327
2328        let mut out: Vec<u32> = Vec::with_capacity(steps);
2329
2330        for _ in 0..steps {
2331            let mut ch = self.sam.predict_det(v);
2332            if ch.is_none() {
2333                let mo = if self.max_order < 0 {
2334                    -1
2335                } else {
2336                    self.max_order
2337                };
2338                self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
2339                ch = Some(self.sample(0.7, 0.9, 0));
2340            }
2341            let ch = ch.unwrap();
2342            out.push(ch);
2343            if self.use_eot && ch == self.eot {
2344                break;
2345            }
2346            v = self.sam.advance(v, ch);
2347        }
2348
2349        Some(out.iter().map(|&c| c as u8).collect())
2350    }
2351
2352    // ========== Entropy Estimation API ==========
2353
2354    /// Returns the probability distribution for the next symbol given a context.
2355    /// Output: Vec of (codepoint, probability) pairs, sorted by codepoint.
2356    /// Builds the LM if not already built.
2357    pub fn get_distribution(&mut self, context: &[u8]) -> Vec<(u32, f64)> {
2358        if !self.lm_built {
2359            self.build_lm();
2360        }
2361
2362        // Advance through context to get SAM state
2363        let mut v = 0i32;
2364        for &b in context {
2365            v = self.sam.advance(v, b as u32);
2366        }
2367
2368        // Get probability distribution at this state
2369        let mo = if self.max_order < 0 {
2370            -1
2371        } else {
2372            self.max_order
2373        };
2374        self.dist.resize(self.lm.alpha_n as usize, 0.0);
2375        self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
2376
2377        // Build output as (codepoint, probability) pairs
2378        let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
2379        for i in 0..(self.lm.alpha_n as usize) {
2380            if self.dist[i] > 0.0 {
2381                result.push((self.lm.alphabet[i], self.dist[i]));
2382            }
2383        }
2384        result.sort_by_key(|&(cp, _)| cp);
2385        result
2386    }
2387
2388    /// Compute the predictive entropy rate (bits per symbol) of the given data.
2389    ///
2390    /// Uses chunked prequential scoring (train on past chunks, score next chunk).
2391    pub fn predictive_entropy_rate(&mut self, data: &[u8]) -> f64 {
2392        if data.len() < 2 {
2393            return 0.0;
2394        }
2395        if self.max_order < 0 {
2396            let candidates: [i64; 8] = [0, 1, 2, 4, 8, 16, 32, 64];
2397            let mut best = f64::INFINITY;
2398            for &mo in &candidates {
2399                if mo as usize >= data.len() {
2400                    continue;
2401                }
2402                let h = Self::predictive_entropy_rate_order(data, mo, self.seed);
2403                if h < best {
2404                    best = h;
2405                }
2406            }
2407            if best.is_finite() {
2408                return best;
2409            }
2410        }
2411        Self::predictive_entropy_rate_order(data, self.max_order, self.seed)
2412    }
2413
2414    /// Predictive entropy rate on codepoint streams.
2415    pub fn entropy_rate_cps(&mut self, cps: &[u32]) -> f64 {
2416        if cps.len() < 2 {
2417            return 0.0;
2418        }
2419
2420        self.sam = Sam::new(cps.len());
2421        self.lm_built = false;
2422
2423        let num_chunks = 16;
2424        let chunk_size = cps.len().div_ceil(num_chunks);
2425        let mut total_log_prob = 0.0f64;
2426        let mut count = 0usize;
2427
2428        for i in 0..num_chunks {
2429            let start = i * chunk_size;
2430            let end = ((i + 1) * chunk_size).min(cps.len());
2431            if start >= end {
2432                break;
2433            }
2434            let chunk = &cps[start..end];
2435            if i > 0 {
2436                // Avoid endpos finalization since we continue mutating the SAM across chunks.
2437                self.build_lm_no_finalize_endpos();
2438                let mut v = self.sam.text_states[start];
2439                for &ch in chunk {
2440                    let sym_idx = self.lm.find_sym(ch);
2441                    let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2442                    total_log_prob += p.log2();
2443                    count += 1;
2444                    v = self.sam.advance(v, ch);
2445                }
2446            }
2447            for &ch in chunk {
2448                self.sam.feed(ch);
2449            }
2450        }
2451
2452        if count == 0 {
2453            self.build_lm();
2454            self.entropy_rate_plugin_cps(cps)
2455        } else {
2456            -total_log_prob / (count as f64)
2457        }
2458    }
2459
2460    #[allow(dead_code)]
2461    fn entropy_rate_plugin_bytes(&mut self, data: &[u8]) -> f64 {
2462        let mut v = 0i32;
2463        let mut total_log_prob = 0.0f64;
2464        let mut count = 0usize;
2465        for t in 0..(data.len() - 1) {
2466            v = self.sam.advance(v, data[t] as u32);
2467            let next_ch = data[t + 1] as u32;
2468            let sym_idx = self.lm.find_sym(next_ch);
2469            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2470            total_log_prob += p.log2();
2471            count += 1;
2472        }
2473        if count == 0 {
2474            0.0
2475        } else {
2476            -total_log_prob / (count as f64)
2477        }
2478    }
2479
2480    fn entropy_rate_plugin_cps(&mut self, cps: &[u32]) -> f64 {
2481        let mut v = 0i32;
2482        let mut total_log_prob = 0.0f64;
2483        let mut count = 0usize;
2484        for t in 0..(cps.len() - 1) {
2485            v = self.sam.advance(v, cps[t]);
2486            let next_ch = cps[t + 1];
2487            let sym_idx = self.lm.find_sym(next_ch);
2488            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2489            total_log_prob += p.log2();
2490            count += 1;
2491        }
2492        if count == 0 {
2493            0.0
2494        } else {
2495            -total_log_prob / (count as f64)
2496        }
2497    }
2498
2499    /// Cross entropy of byte data under current LM state.
2500    pub fn cross_entropy(&self, data: &[u8]) -> f64 {
2501        if !self.lm_built || data.is_empty() {
2502            return 0.0;
2503        }
2504        let mut total_log_prob = 0.0f64;
2505        let mut v = 0i32;
2506        for &b in data {
2507            let ch = b as u32;
2508            let sym_idx = self.lm.find_sym(ch);
2509            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2510            total_log_prob += p.log2();
2511            v = self.sam.advance(v, ch);
2512        }
2513        -total_log_prob / (data.len() as f64)
2514    }
2515
2516    /// Cross entropy of codepoint data under current LM state.
2517    pub fn cross_entropy_cps(&self, data: &[u32]) -> f64 {
2518        if !self.lm_built || data.is_empty() {
2519            return 0.0;
2520        }
2521        let mut total_log_prob = 0.0f64;
2522        let mut v = 0i32;
2523        for &ch in data {
2524            let sym_idx = self.lm.find_sym(ch);
2525            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2526            total_log_prob += p.log2();
2527            v = self.sam.advance(v, ch);
2528        }
2529        -total_log_prob / (data.len() as f64)
2530    }
2531
2532    /// Returns the marginal (unigram) distribution over the training data.
2533    /// Output: Vec of (codepoint, probability) pairs, sorted by codepoint.
2534    pub fn marginal_distribution(&self) -> Vec<(u32, f64)> {
2535        if self.lm.total_uni == 0 {
2536            return Vec::new();
2537        }
2538
2539        let inv = 1.0 / (self.lm.total_uni as f64);
2540        let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
2541        for i in 0..(self.lm.alpha_n as usize) {
2542            let p = (self.lm.unigram[i] as f64) * inv;
2543            if p > 0.0 {
2544                result.push((self.lm.alphabet[i], p));
2545            }
2546        }
2547        result.sort_by_key(|&(cp, _)| cp);
2548        result
2549    }
2550
2551    /// Compute the marginal entropy H(X) from the unigram distribution.
2552    /// Returns bits per symbol.
2553    pub fn marginal_entropy(&self) -> f64 {
2554        if self.lm.total_uni == 0 {
2555            return 0.0;
2556        }
2557
2558        let inv = 1.0 / (self.lm.total_uni as f64);
2559        let mut h = 0.0f64;
2560        for i in 0..(self.lm.alpha_n as usize) {
2561            let p = (self.lm.unigram[i] as f64) * inv;
2562            if p > 0.0 {
2563                h -= p * p.log2();
2564            }
2565        }
2566        h
2567    }
2568
    /// Persist trained SAM+LM state to disk.
    ///
    /// Writes a little-endian binary file at `path`, starting with `MAGIC_V5`, in the order:
    /// 1. header (`max_order`, `use_eot` as i32, `eot`, `seed`);
    /// 2. SAM state/edge/text counts, states, edges, text and boundary flags;
    /// 3. the SAM cursor and prefix-state trace;
    /// 4. the LM alphabet size, total unigram count, node count, alphabet, unigram counts,
    ///    per-state LM entries and LM nodes.
    ///
    /// The field order and widths written here must stay in lockstep with `load`.
    ///
    /// # Errors
    /// Returns an error in three cases:
    /// * the LM has not been built;
    /// * `text_states` does not have `text.len() + 1` entries;
    /// * any I/O operation fails.
    ///
    /// The file is created (truncating any existing file) before the first write, so a
    /// failure part-way through leaves a partially written file behind.
    pub fn save(&self, path: &str) -> std::io::Result<()> {
        if !self.lm_built {
            return Err(std::io::Error::other("LM not built"));
        }

        // Transactional conditional updates require a valid prefix-state trace.
        // If this invariant is violated, the loaded model would be unusable.
        if self.sam.text_states.len() != self.sam.text.len() + 1 {
            return Err(std::io::Error::other(
                "SAM text_states mismatch (expected text.len()+1)",
            ));
        }
        let mut f = BufWriter::with_capacity(1024 * 1024, File::create(path)?);
        // Header.
        f.write_all(MAGIC_V5)?;
        f.write_all(&self.max_order.to_le_bytes())?;
        f.write_all(&(self.use_eot as i32).to_le_bytes())?;
        f.write_all(&self.eot.to_le_bytes())?;
        f.write_all(&self.seed.to_le_bytes())?;

        // SAM
        write_len64(&mut f, self.sam.st.len())?;
        write_len64(&mut f, self.sam.ed.len())?;
        write_len64(&mut f, self.sam.text.len())?;
        // Per state: link, len, endpos, small_n (as u32), small_n (ch, to) pairs, head.
        for st in &self.sam.st {
            f.write_all(&st.link.to_le_bytes())?;
            f.write_all(&st.len.to_le_bytes())?;
            f.write_all(&st.endpos.to_le_bytes())?;
            f.write_all(&(st.small_n as u32).to_le_bytes())?;
            for k in 0..(st.small_n as usize) {
                f.write_all(&st.small_ch[k].to_le_bytes())?;
                f.write_all(&st.small_to[k].to_le_bytes())?;
            }
            f.write_all(&st.head.to_le_bytes())?;
        }
        // Per overflow edge: ch, to, next.
        for e in &self.sam.ed {
            f.write_all(&e.ch.to_le_bytes())?;
            f.write_all(&e.to.to_le_bytes())?;
            f.write_all(&e.next.to_le_bytes())?;
        }
        write_u32_slice_le(&mut f, &self.sam.text)?;
        // Raw bytes, no length prefix (the reader sizes it from the text count).
        f.write_all(&self.sam.boundary_after)?;

        // Persist SAM cursor + prefix trace.
        f.write_all(&self.sam.last.to_le_bytes())?;
        write_len64(&mut f, self.sam.text_states.len())?;
        write_i32_slice_le(&mut f, &self.sam.text_states)?;

        // LM
        f.write_all(&self.lm.alpha_n.to_le_bytes())?;
        f.write_all(&self.lm.total_uni.to_le_bytes())?;
        write_len64(&mut f, self.lm.nodes.len())?;
        // NOTE(review): no explicit counts are written for `alphabet`, `unigram` or `ls`.
        // The reader must infer them (presumably from `alpha_n` and the SAM state count).
        // Confirm that `alphabet.len() == alpha_n` and `ls.len() == sam.st.len()` hold here.
        write_u32_slice_le(&mut f, &self.lm.alphabet)?;
        write_u64_slice_le(&mut f, &self.lm.unigram)?;
        for ls in &self.lm.ls {
            f.write_all(&ls.head.to_le_bytes())?;
            f.write_all(&ls.total_n.to_le_bytes())?;
            f.write_all(&ls.types_t.to_le_bytes())?;
            f.write_all(&ls.last_sym.to_le_bytes())?;
            f.write_all(&ls.last_node.to_le_bytes())?;
        }
        // Nodes are written in unpacked form (sym_idx, cnt, next), presumably decoded by
        // the node table's iterator from its packed storage.
        for n in &self.lm.nodes {
            f.write_all(&n.sym_idx.to_le_bytes())?;
            f.write_all(&n.cnt.to_le_bytes())?;
            f.write_all(&n.next.to_le_bytes())?;
        }
        // Flush explicitly so buffered write errors are reported rather than lost on drop.
        f.flush()?;
        Ok(())
    }
2638
2639    /// Load a previously saved ROSA+ model from disk.
2640    pub fn load(path: &str) -> std::io::Result<Self> {
2641        let mut f = BufReader::with_capacity(1024 * 1024, File::open(path)?);
2642        let mut magic = vec![0u8; MAGIC_V5.len()];
2643        f.read_exact(&mut magic)?;
2644        if magic != MAGIC_V5 {
2645            return Err(std::io::Error::new(
2646                std::io::ErrorKind::InvalidData,
2647                "bad magic or unsupported ROSA+ model version",
2648            ));
2649        }
2650
2651        let mut b8 = [0u8; 8];
2652        let mut b4 = [0u8; 4];
2653
2654        f.read_exact(&mut b8)?;
2655        let max_order = i64::from_le_bytes(b8);
2656        f.read_exact(&mut b4)?;
2657        let use_eot = i32::from_le_bytes(b4) != 0;
2658        f.read_exact(&mut b4)?;
2659        let eot = u32::from_le_bytes(b4);
2660        f.read_exact(&mut b8)?;
2661        let seed = u64::from_le_bytes(b8);
2662
2663        let mut m = RosaPlus::new(max_order, use_eot, eot as u8, seed);
2664
2665        // SAM
2666        let st_n = read_len64(&mut f)?;
2667        let ed_n = read_len64(&mut f)?;
2668        let text_n = read_len64(&mut f)?;
2669
2670        m.sam = Sam::new(text_n);
2671        m.sam.st.resize(st_n, SamState::default());
2672        m.sam.ed.resize(ed_n, SamEdge::default());
2673        m.sam.text.resize(text_n, 0u32);
2674        m.sam.boundary_after.resize(text_n, 0u8);
2675
2676        for i in 0..st_n {
2677            f.read_exact(&mut b4)?;
2678            m.sam.st[i].link = i32::from_le_bytes(b4);
2679            f.read_exact(&mut b4)?;
2680            m.sam.st[i].len = i32::from_le_bytes(b4);
2681            f.read_exact(&mut b4)?;
2682            m.sam.st[i].endpos = i32::from_le_bytes(b4);
2683            f.read_exact(&mut b4)?;
2684            let sn = u32::from_le_bytes(b4) as usize;
2685            if sn > SAM_SMALL_MAX {
2686                return Err(std::io::Error::new(
2687                    std::io::ErrorKind::InvalidData,
2688                    "bad small_n",
2689                ));
2690            }
2691            m.sam.st[i].small_n = sn as u8;
2692            for k in 0..sn {
2693                f.read_exact(&mut b4)?;
2694                m.sam.st[i].small_ch[k] = u32::from_le_bytes(b4);
2695                f.read_exact(&mut b4)?;
2696                m.sam.st[i].small_to[k] = i32::from_le_bytes(b4);
2697            }
2698            f.read_exact(&mut b4)?;
2699            m.sam.st[i].head = u32::from_le_bytes(b4);
2700        }
2701        for i in 0..ed_n {
2702            f.read_exact(&mut b4)?;
2703            m.sam.ed[i].ch = u32::from_le_bytes(b4);
2704            f.read_exact(&mut b4)?;
2705            m.sam.ed[i].to = i32::from_le_bytes(b4);
2706            f.read_exact(&mut b4)?;
2707            m.sam.ed[i].next = u32::from_le_bytes(b4);
2708        }
2709        read_u32_slice_le(&mut f, &mut m.sam.text)?;
2710        f.read_exact(&mut m.sam.boundary_after)?;
2711
2712        // SAM cursor + prefix trace.
2713        f.read_exact(&mut b4)?;
2714        m.sam.last = i32::from_le_bytes(b4);
2715        let text_states_n = read_len64(&mut f)?;
2716        if text_states_n != text_n + 1 {
2717            return Err(std::io::Error::new(
2718                std::io::ErrorKind::InvalidData,
2719                "bad text_states len",
2720            ));
2721        }
2722        m.sam.text_states.resize(text_states_n, 0);
2723        read_i32_slice_le(&mut f, &mut m.sam.text_states)?;
2724        for &v in &m.sam.text_states {
2725            if v < 0 || state_usize(v) >= st_n {
2726                return Err(std::io::Error::new(
2727                    std::io::ErrorKind::InvalidData,
2728                    "bad text_states entry",
2729                ));
2730            }
2731        }
2732        if m.sam.last < 0 || state_usize(m.sam.last) >= st_n {
2733            return Err(std::io::Error::new(
2734                std::io::ErrorKind::InvalidData,
2735                "bad sam.last",
2736            ));
2737        }
2738        for st in &m.sam.st {
2739            if st.link != SAM_STATE_NONE && state_usize(st.link) >= st_n {
2740                return Err(std::io::Error::new(
2741                    std::io::ErrorKind::InvalidData,
2742                    "bad sam link",
2743                ));
2744            }
2745            for k in 0..(st.small_n as usize) {
2746                let to = st.small_to[k];
2747                if to < 0 || state_usize(to) >= st_n {
2748                    return Err(std::io::Error::new(
2749                        std::io::ErrorKind::InvalidData,
2750                        "bad sam small edge",
2751                    ));
2752                }
2753            }
2754            if st.head != SAM_EDGE_NONE && edge_usize(st.head) >= ed_n {
2755                return Err(std::io::Error::new(
2756                    std::io::ErrorKind::InvalidData,
2757                    "bad sam edge head",
2758                ));
2759            }
2760        }
2761        for edge in &m.sam.ed {
2762            if edge.to < 0 || state_usize(edge.to) >= st_n {
2763                return Err(std::io::Error::new(
2764                    std::io::ErrorKind::InvalidData,
2765                    "bad sam edge target",
2766                ));
2767            }
2768            if edge.next != SAM_EDGE_NONE && edge_usize(edge.next) >= ed_n {
2769                return Err(std::io::Error::new(
2770                    std::io::ErrorKind::InvalidData,
2771                    "bad sam edge next",
2772                ));
2773            }
2774        }
2775        m.sam.rebuild_root_cache();
2776
2777        // LM
2778        f.read_exact(&mut b4)?;
2779        let alpha_n = u32::from_le_bytes(b4) as usize;
2780        f.read_exact(&mut b8)?;
2781        let total_uni = u64::from_le_bytes(b8);
2782        let nodes_n = read_len64(&mut f)?;
2783
2784        m.lm = LM::default();
2785        m.lm.alpha_n = alpha_n as u32;
2786        m.lm.total_uni = total_uni;
2787        m.lm.alphabet.resize(alpha_n, 0);
2788        m.lm.unigram.resize(alpha_n, 0);
2789        m.lm.ls = vec![
2790            LmState {
2791                head: LM_NODE_NONE,
2792                last_node: LM_NODE_NONE,
2793                ..LmState::default()
2794            };
2795            st_n
2796        ];
2797        m.lm.nodes.resize(nodes_n, CountNode::default());
2798
2799        read_u32_slice_le(&mut f, &mut m.lm.alphabet)?;
2800        read_u64_slice_le(&mut f, &mut m.lm.unigram)?;
2801        for i in 0..st_n {
2802            f.read_exact(&mut b4)?;
2803            m.lm.ls[i].head = u32::from_le_bytes(b4);
2804            f.read_exact(&mut b8)?;
2805            m.lm.ls[i].total_n = u64::from_le_bytes(b8);
2806            f.read_exact(&mut b4)?;
2807            m.lm.ls[i].types_t = u32::from_le_bytes(b4);
2808            f.read_exact(&mut b4)?;
2809            m.lm.ls[i].last_sym = u32::from_le_bytes(b4);
2810            f.read_exact(&mut b4)?;
2811            m.lm.ls[i].last_node = u32::from_le_bytes(b4);
2812        }
2813        for i in 0..nodes_n {
2814            f.read_exact(&mut b4)?;
2815            let sym_idx = u32::from_le_bytes(b4);
2816            f.read_exact(&mut b8)?;
2817            let cnt = u64::from_le_bytes(b8);
2818            f.read_exact(&mut b4)?;
2819            let next = u32::from_le_bytes(b4);
2820            m.lm.nodes.set(i, CountNode { sym_idx, cnt, next });
2821        }
2822        for ls in &m.lm.ls {
2823            if ls.head != LM_NODE_NONE && node_usize(ls.head) >= nodes_n {
2824                return Err(std::io::Error::new(
2825                    std::io::ErrorKind::InvalidData,
2826                    "bad lm head",
2827                ));
2828            }
2829            if ls.last_node != LM_NODE_NONE && node_usize(ls.last_node) >= nodes_n {
2830                return Err(std::io::Error::new(
2831                    std::io::ErrorKind::InvalidData,
2832                    "bad lm last_node",
2833                ));
2834            }
2835        }
2836        for node in &m.lm.nodes {
2837            if node.next != LM_NODE_NONE && node_usize(node.next) >= nodes_n {
2838                return Err(std::io::Error::new(
2839                    std::io::ErrorKind::InvalidData,
2840                    "bad lm next",
2841                ));
2842            }
2843        }
2844
2845        // rebuild byte_map for lookups
2846        m.lm.has_byte_map = false;
2847        m.lm.byte_map = [-1; 256];
2848        let mut max_cp = 0u32;
2849        for &v in &m.lm.alphabet {
2850            if v > max_cp {
2851                max_cp = v;
2852            }
2853        }
2854        if max_cp < 256 {
2855            m.lm.has_byte_map = true;
2856            for (i, &c) in m.lm.alphabet.iter().enumerate() {
2857                m.lm.byte_map[c as usize] = i as i16;
2858            }
2859        }
2860
2861        m.lm_built = true;
2862        m.dist.resize(alpha_n, 0.0);
2863        Ok(m)
2864    }
2865
2866    /// Probability of `sym` from current SAM cursor (`sam.last`).
2867    pub fn prob_for_last(&mut self, sym: u32) -> f64 {
2868        if !self.lm_built {
2869            self.build_lm();
2870        }
2871        let v = self.sam.last;
2872        let sym_idx = self.lm.find_sym(sym);
2873        let mo = if self.max_order < 0 {
2874            -1
2875        } else {
2876            self.max_order
2877        };
2878        self.lm.prob_for_sym(&self.sam, mo, v, sym_idx)
2879    }
2880
2881    /// Fill a dense byte-wise probability vector for the current SAM cursor (`sam.last`).
2882    ///
2883    /// `out` must have length at least 256. The output is normalized.
2884    pub fn fill_probs_for_last_bytes(&mut self, out: &mut [f64]) {
2885        debug_assert!(out.len() >= 256);
2886        if !self.lm_built {
2887            self.build_lm();
2888        }
2889
2890        let v = self.sam.last;
2891        let mo = if self.max_order < 0 {
2892            -1
2893        } else {
2894            self.max_order
2895        };
2896        self.dist.resize(self.lm.alpha_n as usize, 0.0);
2897        self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
2898
2899        if self.lm.has_byte_map
2900            && (self.lm.alpha_n as usize) == BYTE_ALPHA_N
2901            && self.lm.alphabet.len() == BYTE_ALPHA_N
2902        {
2903            out[..BYTE_ALPHA_N].copy_from_slice(&self.dist[..BYTE_ALPHA_N]);
2904            return;
2905        }
2906
2907        out[..BYTE_ALPHA_N].fill(0.0);
2908        let mut sum = 0.0;
2909        for (i, &cp) in self.lm.alphabet.iter().enumerate() {
2910            if cp < BYTE_ALPHA_N as u32 {
2911                let p = self.dist[i];
2912                out[cp as usize] = p;
2913                sum += p;
2914            }
2915        }
2916
2917        if sum.is_finite() && sum > 0.0 {
2918            if (sum - 1.0).abs() > 1e-12 {
2919                let inv = 1.0 / sum;
2920                for p in &mut out[..BYTE_ALPHA_N] {
2921                    *p *= inv;
2922                }
2923            }
2924        } else {
2925            let u = 1.0 / BYTE_ALPHA_N as f64;
2926            for p in &mut out[..BYTE_ALPHA_N] {
2927                *p = u;
2928            }
2929        }
2930    }
2931}
2932
// Unit tests for the ROSA+ model. They check transactional vs direct updates,
// checkpoint/rollback, save/load round-trips, and agreement of the optimized LM probability
// routines and entropy-rate estimators with straightforward reference implementations
// defined below.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a path in the system temp directory for a throwaway model file.
    ///
    /// The file name combines `tag`, the current process id and the current time in
    /// nanoseconds, so concurrent test runs are unlikely to collide.
    fn temp_model_path(tag: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "infotheory_rosaplus_{tag}_{}_{}.bin",
            std::process::id(),
            nanos
        ))
    }

    /// Reference chunked predictive entropy rate (bits/symbol) for a byte sequence.
    ///
    /// Splits `data` into 16 chunks of `ceil(len / 16)` bytes. For every chunk after the
    /// first, a fresh model is trained on the prefix `data[..start]`. Each byte of the chunk
    /// is then scored from the SAM cursor, and the cursor is advanced with `sam.advance`
    /// without training on the chunk itself. Returns the mean of `-log2 p` over all scored
    /// bytes, or `0.0` when `data` has fewer than 2 bytes.
    ///
    /// Used as the expected value for `predictive_entropy_rate`.
    fn manual_chunked_entropy_rate_bytes(data: &[u8], max_order: i64, seed: u64) -> f64 {
        if data.len() < 2 {
            return 0.0;
        }
        let num_chunks = 16;
        let chunk_size = data.len().div_ceil(num_chunks);
        let mut total_log_prob = 0.0f64;
        let mut count = 0usize;

        for i in 0..num_chunks {
            let start = i * chunk_size;
            let end = ((i + 1) * chunk_size).min(data.len());
            if start >= end {
                break;
            }
            // Chunk 0 has no prefix to train on, so it is never scored.
            if i == 0 {
                continue;
            }

            let mut m = RosaPlus::new(max_order, false, 0, seed);
            m.train_example(&data[..start]);
            m.build_lm();
            let mut v = m.sam.last;

            for &b in &data[start..end] {
                let sym_idx = m.lm.find_sym(b as u32);
                let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
                total_log_prob += p.log2();
                count += 1;
                v = m.sam.advance(v, b as u32);
            }
        }

        // Safety net: for len >= 2, chunk 1 always starts at ceil(len/16) < len, so at least
        // one byte is scored and this branch appears unreachable.
        if count == 0 {
            let mut m = RosaPlus::new(max_order, false, 0, seed);
            m.train_example(data);
            m.build_lm();
            m.cross_entropy(data)
        } else {
            -total_log_prob / (count as f64)
        }
    }

    /// Reference chunked entropy rate (bits/symbol) for a code-point sequence.
    ///
    /// Unlike the byte variant, a single model grows incrementally. Before each chunk after
    /// the first, the LM is rebuilt (without end-position finalization) over everything fed
    /// so far. The chunk is scored starting from `sam.text_states[start]`, and only then is
    /// it fed into the SAM. Returns the mean of `-log2 p`, or `0.0` for fewer than 2 symbols.
    ///
    /// Used as the expected value for `entropy_rate_cps`.
    fn manual_chunked_entropy_rate_cps(data: &[u32], max_order: i64, seed: u64) -> f64 {
        if data.len() < 2 {
            return 0.0;
        }
        let mut m = RosaPlus::new(max_order, false, 0, seed);
        m.sam = Sam::new(data.len());
        m.lm_built = false;

        let num_chunks = 16;
        let chunk_size = data.len().div_ceil(num_chunks);
        let mut total_log_prob = 0.0f64;
        let mut count = 0usize;

        for i in 0..num_chunks {
            let start = i * chunk_size;
            let end = ((i + 1) * chunk_size).min(data.len());
            if start >= end {
                break;
            }
            let chunk = &data[start..end];
            if i > 0 {
                m.build_lm_no_finalize_endpos();
                // NOTE(review): presumably the SAM state reached after the first `start`
                // symbols (`text_states` holds text.len() + 1 entries) — confirm.
                let mut v = m.sam.text_states[start];
                for &ch in chunk {
                    let sym_idx = m.lm.find_sym(ch);
                    let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
                    total_log_prob += p.log2();
                    count += 1;
                    v = m.sam.advance(v, ch);
                }
            }
            // Only after scoring does the chunk become part of the training text.
            for &ch in chunk {
                m.sam.feed(ch);
            }
        }

        if count == 0 {
            m.build_lm();
            m.entropy_rate_plugin_cps(data)
        } else {
            -total_log_prob / (count as f64)
        }
    }

    /// Unoptimized reference for `LM::prob_for_sym` (interpolated Witten-Bell smoothing).
    ///
    /// Walks the suffix-link chain starting at `v`. When `max_order >= 0`, states longer than
    /// `max_order` are skipped. Each state with `n > 0` observations and `t` distinct
    /// continuations contributes `residual * lam * count(sym) / n`, where
    /// `lam = n / (n + t)` (or 1 when `t == 0`); the residual then shrinks by `1 - lam`.
    /// The remaining residual goes to the unigram distribution, or to uniform `1/alpha_n`
    /// when no unigram counts exist. The result is clamped to `[1e-12, 1]`. An unknown
    /// symbol (`sym_idx < 0`) gets uniform `1/alpha_n`.
    fn prob_for_sym_reference(
        lm: &LM,
        sam: &Sam,
        max_order: i64,
        v: SamStateIx,
        sym_idx: i32,
    ) -> f64 {
        if sym_idx < 0 {
            return 1.0 / (lm.alpha_n.max(1) as f64);
        }
        let sym_idx = sym_idx as u32;
        let mut p_accum = 0.0f64;
        let mut residual = 1.0f64;
        let mut u = v;

        while u != SAM_STATE_NONE {
            if !(max_order >= 0 && (sam.st[state_usize(u)].len as i64) > max_order) {
                let n = lm.ls[state_usize(u)].total_n;
                let t = lm.ls[state_usize(u)].types_t;
                if n > 0 {
                    let lam = if t > 0 {
                        (n as f64) / ((n + (t as u64)) as f64)
                    } else {
                        1.0
                    };
                    let scale = residual * lam;
                    let mut count_for_sym = 0u64;
                    let ls = &lm.ls[state_usize(u)];
                    // A state flagged as implicit-single carries its only continuation in
                    // `last_sym` (with count `n`) instead of a node list.
                    if LM::ls_is_implicit_single(ls) {
                        if ls.last_sym == sym_idx {
                            count_for_sym = n;
                        }
                    } else {
                        let mut ni = ls.head;
                        while ni != LM_NODE_NONE {
                            let node = lm.nodes.get(node_usize(ni));
                            if node.sym_idx == sym_idx {
                                count_for_sym = node.cnt;
                                break;
                            }
                            ni = node.next;
                        }
                    }
                    if count_for_sym > 0 {
                        p_accum += scale * (count_for_sym as f64 / n as f64);
                    }
                    residual *= 1.0 - lam;
                }
            }
            u = sam.st[state_usize(u)].link;
        }

        if lm.total_uni > 0 && residual > 0.0 {
            let p_uni = lm.unigram[sym_idx as usize] as f64 / lm.total_uni as f64;
            p_accum += residual * p_uni;
        } else if residual > 0.0 {
            p_accum += residual * (1.0 / lm.alpha_n.max(1) as f64);
        }

        p_accum.clamp(1e-12, 1.0)
    }

    /// Unoptimized reference for `LM::probs_for_state`: the dense version of
    /// `prob_for_sym_reference`, producing one probability per alphabet index.
    ///
    /// It uses the same suffix-chain walk and Witten-Bell weights, then unigram backoff for
    /// the leftover residual. Unlike the scalar reference, the residual is not spread
    /// uniformly when `total_uni == 0`; the vector is simply renormalized. If every entry
    /// is zero, the result is uniform `1/alpha_n`.
    fn probs_for_state_reference(lm: &LM, sam: &Sam, max_order: i64, v: SamStateIx) -> Vec<f64> {
        let mut out = vec![0.0; lm.alpha_n as usize];
        let mut residual = 1.0f64;
        let mut u = v;
        while u != SAM_STATE_NONE {
            if !(max_order >= 0 && (sam.st[state_usize(u)].len as i64) > max_order) {
                let n = lm.ls[state_usize(u)].total_n;
                let t = lm.ls[state_usize(u)].types_t;
                if n > 0 {
                    let lam = if t > 0 {
                        (n as f64) / ((n + (t as u64)) as f64)
                    } else {
                        1.0
                    };
                    let scale = residual * lam;
                    let inv_n = 1.0 / (n as f64);
                    let ls = &lm.ls[state_usize(u)];
                    if LM::ls_is_implicit_single(ls) {
                        out[ls.last_sym as usize] += scale;
                    } else {
                        let mut ni = ls.head;
                        while ni != LM_NODE_NONE {
                            let node = lm.nodes.get(node_usize(ni));
                            out[node.sym_idx as usize] += scale * ((node.cnt as f64) * inv_n);
                            ni = node.next;
                        }
                    }
                    residual *= 1.0 - lam;
                }
            }
            u = sam.st[state_usize(u)].link;
        }

        if lm.total_uni > 0 && residual > 0.0 {
            let inv = 1.0 / (lm.total_uni as f64);
            for (slot, &count) in out.iter_mut().zip(lm.unigram.iter()) {
                *slot += residual * ((count as f64) * inv);
            }
        }

        let sum: f64 = out.iter().sum();
        if sum > 0.0 {
            let inv = 1.0 / sum;
            for slot in &mut out {
                *slot *= inv;
            }
        } else {
            let uprob = 1.0 / (lm.alpha_n.max(1) as f64);
            out.fill(uprob);
        }
        out
    }

    /// Smoke test: after training on "ababa", generating from the prompt "a" yields output.
    #[test]
    fn rosa_md_example_basic() {
        // From rosa.md: ROSA predicts next token of best previous match.
        let x = b"ababa";
        let mut m = RosaPlus::new(1048576, false, 4, 0);
        m.train_example(x);
        m.build_lm();
        let out = m.generate(b"a", 10).unwrap();
        assert!(!out.is_empty());
    }

    /// Training inside a transaction grows the text and unigram total; `rollback_tx` restores
    /// both to their pre-transaction values.
    #[test]
    fn tx_rollback_restores_sam_and_unigram_counts() {
        let mut m = RosaPlus::new(4, false, 0, 123);
        m.train_example(b"hello");
        m.build_lm_full_bytes_no_finalize_endpos();

        let base_text = m.sam.text.clone();
        let base_text_len = m.sam.text.len();
        let base_total_uni = m.lm.total_uni;
        assert!(base_text_len > 0);

        let mut tx = m.begin_tx();
        m.train_example_tx(&mut tx, b"abc");
        assert_eq!(m.lm.total_uni, base_total_uni + 3);
        assert_eq!(m.sam.text.len(), base_text_len + 3);

        m.rollback_tx(tx);
        assert_eq!(m.sam.text, base_text);
        assert_eq!(m.lm.total_uni, base_total_uni);
    }

    /// `train_sequence` and `train_sequence_tx` (one transaction per call) must leave the SAM
    /// and LM in identical states and yield the same next-byte distribution.
    #[test]
    fn train_sequence_matches_transactional_sequence_update() {
        let mut direct = RosaPlus::new(4, false, 0, 123);
        direct.build_lm_full_bytes_no_finalize_endpos();
        direct.reserve_for_stream(64);
        direct.train_sequence(b"abracadabra");
        direct.train_sequence(b" mississippi");

        let mut tx_model = RosaPlus::new(4, false, 0, 123);
        tx_model.build_lm_full_bytes_no_finalize_endpos();
        tx_model.reserve_for_stream(64);
        let mut tx = tx_model.begin_tx();
        tx_model.train_sequence_tx(&mut tx, b"abracadabra");
        let mut tx = tx_model.begin_tx();
        tx_model.train_sequence_tx(&mut tx, b" mississippi");

        assert_eq!(direct.sam.text, tx_model.sam.text);
        assert_eq!(direct.sam.text_states, tx_model.sam.text_states);
        assert_eq!(direct.sam.boundary_after, tx_model.sam.boundary_after);
        assert_eq!(direct.sam.last, tx_model.sam.last);
        assert_eq!(direct.lm.total_uni, tx_model.lm.total_uni);
        assert_eq!(direct.lm.unigram, tx_model.lm.unigram);
        assert_eq!(direct.lm.nodes, tx_model.lm.nodes);
        assert_eq!(direct.lm.ls, tx_model.lm.ls);

        let mut direct_pdf = [0.0; BYTE_ALPHA_N];
        let mut tx_pdf = [0.0; BYTE_ALPHA_N];
        direct.fill_probs_for_last_bytes(&mut direct_pdf);
        tx_model.fill_probs_for_last_bytes(&mut tx_pdf);
        for idx in 0..BYTE_ALPHA_N {
            assert!((direct_pdf[idx] - tx_pdf[idx]).abs() < 1e-12);
        }
    }

    /// Feeding bytes one at a time via `train_byte` must match one single-byte transaction per
    /// byte via `train_sequence_tx`, field for field.
    #[test]
    fn repeated_single_byte_train_byte_matches_transactional_update() {
        let data = b"abracadabra mississippi";

        let mut direct = RosaPlus::new(4, false, 0, 123);
        direct.build_lm_full_bytes_no_finalize_endpos();
        for &b in data {
            direct.train_byte(b);
        }

        let mut tx_model = RosaPlus::new(4, false, 0, 123);
        tx_model.build_lm_full_bytes_no_finalize_endpos();
        for &b in data {
            let mut tx = tx_model.begin_tx();
            tx_model.train_sequence_tx(&mut tx, &[b]);
        }

        assert_eq!(direct.sam.text, tx_model.sam.text);
        assert_eq!(direct.sam.text_states, tx_model.sam.text_states);
        assert_eq!(direct.sam.boundary_after, tx_model.sam.boundary_after);
        assert_eq!(direct.sam.last, tx_model.sam.last);
        assert_eq!(direct.lm.total_uni, tx_model.lm.total_uni);
        assert_eq!(direct.lm.unigram, tx_model.lm.unigram);
        assert_eq!(direct.lm.nodes, tx_model.lm.nodes);
        assert_eq!(direct.lm.ls, tx_model.lm.ls);
    }

    /// With `max_order = 4`, the optimized `prob_for_sym` / `probs_for_state` must agree
    /// with the reference implementations to within 1e-12, including for a symbol ('z')
    /// that never occurs in the training text.
    #[test]
    fn max_order_capping_keeps_probability_semantics() {
        let mut m = RosaPlus::new(4, false, 0, 321);
        m.build_lm_full_bytes_no_finalize_endpos();
        m.train_sequence(b"abracadabra mississippi abracadabra abracadabra");

        let v = m.sam.last;
        for &sym in b"a mz" {
            let sym_idx = m.lm.find_sym(sym as u32);
            let expected = prob_for_sym_reference(&m.lm, &m.sam, m.max_order, v, sym_idx);
            let got = m.lm.prob_for_sym(&m.sam, m.max_order, v, sym_idx);
            assert!(
                (got - expected).abs() < 1e-12,
                "sym={sym} got={got} expected={expected}"
            );
        }

        let expected = probs_for_state_reference(&m.lm, &m.sam, m.max_order, v);
        let mut got = vec![0.0; m.lm.alpha_n as usize];
        m.lm.probs_for_state(&m.sam, m.max_order, v, &mut got);
        for idx in 0..got.len() {
            assert!(
                (got[idx] - expected[idx]).abs() < 1e-12,
                "idx={idx} got={} expected={}",
                got[idx],
                expected[idx]
            );
        }
    }

    /// `restore(&checkpoint)` rewinds the SAM text, prefix states, boundary flags and cursor,
    /// and marks the LM as needing a rebuild.
    #[test]
    fn checkpoint_restore_reverts_append_only_buffers() {
        let mut m = RosaPlus::new(3, true, b'\n', 7);
        m.train_example(b"aaaa");

        let ck = m.checkpoint();
        let base_text = m.sam.text.clone();
        let base_states = m.sam.text_states.clone();
        let base_boundary = m.sam.boundary_after.clone();
        let base_last = m.sam.last;

        m.train_example(b"bbbb");
        assert_ne!(m.sam.text, base_text);

        m.restore(&ck);
        assert_eq!(m.sam.text, base_text);
        assert_eq!(m.sam.text_states, base_states);
        assert_eq!(m.sam.boundary_after, base_boundary);
        assert_eq!(m.sam.last, base_last);
        assert!(!m.lm_built);
    }

    /// With a fixed order, `predictive_entropy_rate` equals the chunked byte reference.
    #[test]
    fn predictive_entropy_rate_matches_chunked_reference_fixed_order() {
        let data = b"abracadabra abracadabra abracadabra";
        let seed = 11;
        let expected = manual_chunked_entropy_rate_bytes(data, 4, seed);
        let mut m = RosaPlus::new(4, false, 0, seed);
        let got = m.predictive_entropy_rate(data);
        assert!((got - expected).abs() < 1e-12);
    }

    /// With `max_order = -1`, `predictive_entropy_rate` is expected to return the minimum of
    /// the chunked reference over the candidate orders {0, 1, 2, 4, 8, 16, 32, 64}, skipping
    /// orders that are not shorter than the input.
    #[test]
    fn predictive_entropy_rate_uncapped_matches_candidate_search() {
        let data = b"the quick brown fox jumps over the lazy dog the quick brown fox";
        let seed = 29;
        let mut expected = f64::INFINITY;
        for &mo in &[0, 1, 2, 4, 8, 16, 32, 64] {
            if mo as usize >= data.len() {
                continue;
            }
            expected = expected.min(manual_chunked_entropy_rate_bytes(data, mo, seed));
        }
        let mut m = RosaPlus::new(-1, false, 0, seed);
        let got = m.predictive_entropy_rate(data);
        assert!((got - expected).abs() < 1e-12);
    }

    /// `entropy_rate_cps` on a small non-byte code-point sequence equals the chunked
    /// code-point reference.
    #[test]
    fn entropy_rate_cps_matches_chunked_reference() {
        let data = [0u32, 7, 0, 42, 7, 42, 0, 7, 42, 42];
        let seed = 31;
        let expected = manual_chunked_entropy_rate_cps(&data, -1, seed);
        let mut m = RosaPlus::new(-1, false, 0, seed);
        let got = m.entropy_rate_cps(&data);
        assert!((got - expected).abs() < 1e-12);
    }

    /// On 64-bit targets, the edge and LM-node index helpers round-trip an index above
    /// `i32::MAX` (they use unsigned 32-bit storage).
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn wide_index_helpers_preserve_large_indices() {
        let large = (i32::MAX as usize) + 17;
        assert_eq!(edge_usize(edge_ix(large)), large);
        assert_eq!(node_usize(node_ix(large)), large);
    }

    /// Saving then loading a trained model preserves the config fields, SAM buffers, cursor,
    /// LM node count, estimated size and a sample probability.
    #[test]
    fn save_load_roundtrip_preserves_state_and_probabilities() {
        let path = temp_model_path("roundtrip");
        let mut m = RosaPlus::new(8, true, b'\n', 1234);
        m.train_example(b"abracadabra");
        m.build_lm();
        let before_prob = m.prob_for_last(b'a' as u32);
        let before_size = m.estimated_size_bytes();
        let before_text = m.sam.text.clone();
        let before_states = m.sam.text_states.clone();
        let before_last = m.sam.last;
        let before_nodes = m.lm.nodes.len();
        let path_str = path.to_string_lossy().into_owned();

        m.save(&path_str).expect("save failed");
        // NOTE(review): if `load` fails, the temp file is not removed (cleanup runs only after
        // a successful load).
        let mut loaded = RosaPlus::load(&path_str).expect("load failed");
        fs::remove_file(&path).expect("cleanup failed");

        assert_eq!(loaded.max_order, m.max_order);
        assert_eq!(loaded.use_eot, m.use_eot);
        assert_eq!(loaded.eot, m.eot);
        assert_eq!(loaded.seed, m.seed);
        assert_eq!(loaded.sam.text, before_text);
        assert_eq!(loaded.sam.text_states, before_states);
        assert_eq!(loaded.sam.last, before_last);
        assert_eq!(loaded.lm.nodes.len(), before_nodes);
        assert_eq!(loaded.estimated_size_bytes(), before_size);
        assert!((loaded.prob_for_last(b'a' as u32) - before_prob).abs() < 1e-12);
    }
}