rosaplus/
lib.rs

1//! # ROSA: Rapid Online Suffix Automaton
2//! A high-performance predictive language model for entropy rate estimation.
3//!
4//! ROSA uses a **Suffix Automaton** (SAM) to efficiently find the longest matching context
5//! for each symbol in a sequence. It then applies **Witten-Bell smoothing** to estimate
6//! the conditional probability `P(x_t | x_{<t})`.
7//!
8//! This allows for accurate estimation of:
9//! *   Entropy Rate `Ĥ(X)`
10//! *   Cross-Entropy Rate `Ĥ(P, Q)`
11//! *   Joint Entropy Rate `Ĥ(X, Y)` (via aligned pair symbols)
12//!
13//! The implementation is optimized for speed and memory efficiency, using a compact
14//! graph representation for the automaton.
15
16#![allow(clippy::needless_range_loop)]
17
18use std::fs::File;
19use std::io::{BufReader, BufWriter, Read, Write};
20
/// Number of transitions a `SamState` stores inline; further transitions of
/// the same state spill into the shared overflow-edge list (`Sam::ed`).
const SAM_SMALL_MAX: usize = 4;

/// Magic header identifying the on-disk model format.
///
/// NOTE: bump the version suffix whenever the on-disk format changes.
/// v4 adds serialization of `sam.last` and `sam.text_states` (required for
/// reversible conditional updates).
const MAGIC: &[u8] = b"rosa_pb_v4\0";

/// Alphabet size of the optional fixed 256-symbol (byte) LM build/update path.
///
/// This crate is used byte-wise by infotheory; a fixed byte alphabet enables
/// fast incremental conditional updates.
const BYTE_ALPHA_N: usize = 256;
29
/// Writes every element of `xs` to `w` as 4 little-endian bytes.
///
/// On little-endian targets the slice's memory already has that layout, so it
/// is emitted with a single `write_all`; otherwise each element is converted
/// with `to_le_bytes` first.
///
/// # Errors
/// Returns the first error reported by `w.write_all`.
#[inline(always)]
fn write_u32_slice_le<W: Write>(w: &mut W, xs: &[u32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter().try_for_each(|x| w.write_all(&x.to_le_bytes()));
    }
    // SAFETY: `xs` is a live, fully initialized `[u32]`. Viewing its
    // `size_of_val(xs)` bytes as `[u8]` is valid: `u8` has alignment 1, every
    // byte of an initialized `u32` is initialized, and the view borrows `xs`
    // immutably for no longer than `xs` itself.
    let raw: &[u8] =
        unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), std::mem::size_of_val(xs)) };
    w.write_all(raw)
}
44
/// Writes every element of `xs` to `w` as 4 little-endian two's-complement bytes.
///
/// On little-endian targets the slice's memory is emitted directly in one
/// `write_all`; otherwise each element goes through `to_le_bytes`.
///
/// # Errors
/// Returns the first error reported by `w.write_all`.
#[inline(always)]
fn write_i32_slice_le<W: Write>(w: &mut W, xs: &[i32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter().try_for_each(|x| w.write_all(&x.to_le_bytes()));
    }
    // SAFETY: `xs` is a live, fully initialized `[i32]`; reinterpreting its
    // `size_of_val(xs)` bytes as `[u8]` is sound (alignment 1, all bytes
    // initialized) and the view does not outlive the borrow of `xs`.
    let raw: &[u8] =
        unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), std::mem::size_of_val(xs)) };
    w.write_all(raw)
}
59
/// Writes every element of `xs` to `w` as 8 little-endian bytes.
///
/// On little-endian targets the slice's memory is emitted directly in one
/// `write_all`; otherwise each element goes through `to_le_bytes`.
///
/// # Errors
/// Returns the first error reported by `w.write_all`.
#[inline(always)]
fn write_u64_slice_le<W: Write>(w: &mut W, xs: &[u64]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        return xs.iter().try_for_each(|x| w.write_all(&x.to_le_bytes()));
    }
    // SAFETY: `xs` is a live, fully initialized `[u64]`; reinterpreting its
    // `size_of_val(xs)` bytes as `[u8]` is sound (alignment 1, all bytes
    // initialized) and the view does not outlive the borrow of `xs`.
    let raw: &[u8] =
        unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), std::mem::size_of_val(xs)) };
    w.write_all(raw)
}
74
/// Fills `xs` with `u32` values decoded from little-endian bytes read from `r`.
///
/// Exactly `4 * xs.len()` bytes are consumed on success. On little-endian
/// targets the bytes are read straight into the slice's memory.
///
/// # Errors
/// Propagates the error from `read_exact` (e.g. `UnexpectedEof` on short
/// input); `xs` may then be partially overwritten.
#[inline(always)]
fn read_u32_slice_le<R: Read>(r: &mut R, xs: &mut [u32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        let mut word = [0u8; 4];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = u32::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(&*xs);
    // SAFETY: `xs` is a live, exclusively borrowed `[u32]` spanning `byte_len`
    // bytes. `u8` has alignment 1, and any byte pattern written through the
    // view is a valid `u32`, so exposing it as `&mut [u8]` is sound. The view
    // is dropped before `xs` is used again.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
91
/// Fills `xs` with `i32` values decoded from little-endian bytes read from `r`.
///
/// Exactly `4 * xs.len()` bytes are consumed on success. On little-endian
/// targets the bytes are read straight into the slice's memory.
///
/// # Errors
/// Propagates the error from `read_exact`; `xs` may then be partially overwritten.
#[inline(always)]
fn read_i32_slice_le<R: Read>(r: &mut R, xs: &mut [i32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        let mut word = [0u8; 4];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = i32::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(&*xs);
    // SAFETY: `xs` is a live, exclusively borrowed `[i32]` spanning `byte_len`
    // bytes; `u8` has alignment 1 and every byte pattern is a valid `i32`, so
    // the `&mut [u8]` view is sound and ends before `xs` is used again.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
108
/// Fills `xs` with `u64` values decoded from little-endian bytes read from `r`.
///
/// Exactly `8 * xs.len()` bytes are consumed on success. On little-endian
/// targets the bytes are read straight into the slice's memory.
///
/// # Errors
/// Propagates the error from `read_exact`; `xs` may then be partially overwritten.
#[inline(always)]
fn read_u64_slice_le<R: Read>(r: &mut R, xs: &mut [u64]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        let mut word = [0u8; 8];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = u64::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(&*xs);
    // SAFETY: `xs` is a live, exclusively borrowed `[u64]` spanning `byte_len`
    // bytes; `u8` has alignment 1 and every byte pattern is a valid `u64`, so
    // the `&mut [u8]` view is sound and ends before `xs` is used again.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
125
126#[derive(Clone, Copy, Default)]
127struct SamState {
128    link: i32,
129    len: i32,
130    endpos: i32,
131    head: i32,
132
133    small_ch: [u32; SAM_SMALL_MAX],
134    small_to: [i32; SAM_SMALL_MAX],
135    small_n: u8,
136}
137
138#[derive(Clone, Copy, Default)]
139struct SamEdge {
140    ch: u32,
141    to: i32,
142    next: i32,
143}
144
145#[derive(Clone, Default)]
146struct Sam {
147    st: Vec<SamState>,
148    ed: Vec<SamEdge>,
149    last: i32,
150
151    text: Vec<u32>,
152    text_states: Vec<i32>,
153    boundary_after: Vec<u8>,
154}
155
156impl Sam {
157    fn new(expected_chars: usize) -> Self {
158        let mut s = Sam {
159            st: Vec::new(),
160            ed: Vec::new(),
161            last: 0,
162            text: Vec::new(),
163            text_states: Vec::new(),
164            boundary_after: Vec::new(),
165        };
166
167        let st_cap = if expected_chars > 0 {
168            expected_chars * 2 + 16
169        } else {
170            1024
171        };
172        let ed_cap = if expected_chars > 0 {
173            expected_chars * 3 + 16
174        } else {
175            2048
176        };
177        let text_cap = if expected_chars > 0 {
178            expected_chars + 16
179        } else {
180            1024
181        };
182        s.st.reserve(st_cap);
183        s.ed.reserve(ed_cap);
184        s.text.reserve(text_cap);
185        s.text_states.reserve(text_cap);
186        s.boundary_after.reserve(text_cap);
187
188        let mut root = SamState::default();
189        root.link = -1;
190        root.len = 0;
191        root.endpos = -1;
192        root.small_n = 0;
193        root.head = -1;
194        s.st.push(root);
195        s.text_states.push(0); // Root state for empty context
196        s
197    }
198
199    #[inline(always)]
200    fn get_edge(&self, v: i32, ch: u32) -> i32 {
201        let st = unsafe { self.st.get_unchecked(v as usize) };
202        for i in 0..(st.small_n as usize) {
203            if st.small_ch[i] == ch {
204                return st.small_to[i];
205            }
206        }
207        let mut ei = st.head;
208        while ei != -1 {
209            let e = unsafe { self.ed.get_unchecked(ei as usize) };
210            if e.ch == ch {
211                return e.to;
212            }
213            ei = e.next;
214        }
215        -1
216    }
217
218    #[inline(always)]
219    fn add_edge(&mut self, v: i32, ch: u32, to: i32) {
220        let idx = self.ed.len() as i32;
221        let head = self.st[v as usize].head;
222        self.ed.push(SamEdge { ch, to, next: head });
223        self.st[v as usize].head = idx;
224    }
225
226    #[inline(always)]
227    fn add_edge_absent(&mut self, v: i32, ch: u32, to: i32) {
228        let st = &mut self.st[v as usize];
229        if (st.small_n as usize) < SAM_SMALL_MAX {
230            let i = st.small_n as usize;
231            st.small_n += 1;
232            st.small_ch[i] = ch;
233            st.small_to[i] = to;
234        } else {
235            self.add_edge(v, ch, to);
236        }
237    }
238
239    #[inline(always)]
240    fn replace_edge_to(&mut self, v: i32, ch: u32, old_to: i32, new_to: i32) -> bool {
241        {
242            let st = &mut self.st[v as usize];
243            for i in 0..(st.small_n as usize) {
244                if st.small_ch[i] == ch && st.small_to[i] == old_to {
245                    st.small_to[i] = new_to;
246                    return true;
247                }
248            }
249        }
250        let mut ei = self.st[v as usize].head;
251        while ei != -1 {
252            let e = &mut self.ed[ei as usize];
253            if e.ch == ch && e.to == old_to {
254                e.to = new_to;
255                return true;
256            }
257            ei = e.next;
258        }
259        false
260    }
261
262    fn clone_overflow_edges(&mut self, src: i32, dst: i32) {
263        self.st[dst as usize].head = -1;
264        let mut ei = self.st[src as usize].head;
265        while ei != -1 {
266            let e = self.ed[ei as usize];
267            self.add_edge(dst, e.ch, e.to);
268            ei = e.next;
269        }
270    }
271
272    fn feed(&mut self, ch: u32) {
273        let i = self.text.len() as i32;
274        self.text.push(ch);
275        self.boundary_after.push(0);
276
277        let g = self.last;
278        let r = self.st.len() as i32;
279        let mut st_r = SamState::default();
280        st_r.link = 0;
281        st_r.len = self.st[g as usize].len + 1;
282        st_r.endpos = i;
283        st_r.small_n = 0;
284        st_r.head = -1;
285        self.st.push(st_r);
286
287        let mut p = g;
288        let mut q;
289        while p != -1 {
290            q = self.get_edge(p, ch);
291            if q != -1 {
292                break;
293            }
294            self.add_edge_absent(p, ch, r);
295            p = self.st[p as usize].link;
296        }
297
298        if p == -1 {
299            self.st[r as usize].link = 0;
300        } else {
301            q = self.get_edge(p, ch);
302            if self.st[p as usize].len + 1 == self.st[q as usize].len {
303                self.st[r as usize].link = q;
304            } else {
305                let u = self.st.len() as i32;
306                let mut st_u = self.st[q as usize];
307                st_u.len = self.st[p as usize].len + 1;
308                self.st.push(st_u);
309                self.clone_overflow_edges(q, u);
310                while p != -1 && self.replace_edge_to(p, ch, q, u) {
311                    p = self.st[p as usize].link;
312                }
313                self.st[q as usize].link = u;
314                self.st[r as usize].link = u;
315            }
316        }
317
318        self.last = r;
319        self.text_states.push(r);
320
321        // Maintain rightmost endpos online (ROSA deterministic predictor).
322        let mut v = r;
323        while v != -1 && self.st[v as usize].endpos < i {
324            self.st[v as usize].endpos = i;
325            v = self.st[v as usize].link;
326        }
327    }
328
329    fn mark_boundary(&mut self) {
330        if !self.text.is_empty() {
331            let i = self.text.len() - 1;
332            self.boundary_after[i] = 1;
333        }
334        self.last = 0;
335    }
336
337    fn finalize_endpos(&mut self) {
338        let mut max_len: usize = 0;
339        for v in 0..self.st.len() {
340            let l = self.st[v].len as usize;
341            if l > max_len {
342                max_len = l;
343            }
344        }
345
346        let mut cnt = vec![0usize; max_len + 1];
347        for v in 0..self.st.len() {
348            cnt[self.st[v].len as usize] += 1;
349        }
350        let mut pos = vec![0usize; max_len + 1];
351        let mut acc = 0usize;
352        for l in 0..=max_len {
353            pos[l] = acc;
354            acc += cnt[l];
355        }
356        let mut order = vec![0u32; self.st.len()];
357        for v in 0..self.st.len() {
358            let l = self.st[v].len as usize;
359            let idx = pos[l];
360            order[idx] = v as u32;
361            pos[l] += 1;
362        }
363
364        for oi in (0..order.len()).rev() {
365            let v = order[oi] as usize;
366            let p = self.st[v].link;
367            if p >= 0 {
368                let p = p as usize;
369                if self.st[v].endpos > self.st[p].endpos {
370                    self.st[p].endpos = self.st[v].endpos;
371                }
372            }
373        }
374    }
375
376    #[inline(always)]
377    fn advance(&self, mut v: i32, ch: u32) -> i32 {
378        let mut to;
379        loop {
380            to = self.get_edge(v, ch);
381            if to != -1 {
382                return to;
383            }
384            v = self.st[v as usize].link;
385            if v == -1 {
386                break;
387            }
388        }
389        to = self.get_edge(0, ch);
390        if to == -1 { 0 } else { to }
391    }
392
393    #[inline(always)]
394    fn predict_det(&self, v: i32) -> Option<u32> {
395        let mut u = v;
396        while u != -1 {
397            let st = unsafe { self.st.get_unchecked(u as usize) };
398            let i = st.endpos;
399            let j = i + 1;
400            if st.len > 0 && j >= 0 && (j as usize) < self.text.len() {
401                if i >= 0
402                    && (i as usize) < self.boundary_after.len()
403                    && self.boundary_after[i as usize] != 0
404                {
405                    u = st.link;
406                    continue;
407                }
408                return Some(self.text[j as usize]);
409            }
410            u = st.link;
411        }
412        None
413    }
414
415    // ===== Transactional (undo-log) support =====
416    fn begin_tx(&self) -> SamTx {
417        SamTx {
418            old_last: self.last,
419            old_text_len: self.text.len(),
420            old_text_states_len: self.text_states.len(),
421            old_boundary_len: self.boundary_after.len(),
422            old_st_len: self.st.len(),
423            old_ed_len: self.ed.len(),
424            st_changes: Vec::new(),
425            ed_changes: Vec::new(),
426        }
427    }
428
429    fn rollback_tx(&mut self, tx: SamTx) {
430        // Restore mutated entries (reverse order is fine even with duplicates).
431        for (idx, old) in tx.ed_changes.into_iter().rev() {
432            if idx < self.ed.len() {
433                self.ed[idx] = old;
434            }
435        }
436        for (idx, old) in tx.st_changes.into_iter().rev() {
437            if idx < self.st.len() {
438                self.st[idx] = old;
439            }
440        }
441
442        self.st.truncate(tx.old_st_len);
443        self.ed.truncate(tx.old_ed_len);
444        self.text.truncate(tx.old_text_len);
445        self.text_states.truncate(tx.old_text_states_len);
446        self.boundary_after.truncate(tx.old_boundary_len);
447        self.last = tx.old_last;
448    }
449
450    #[inline(always)]
451    fn record_state_change(&self, tx: &mut SamTx, idx: usize) {
452        // Duplicates are OK; rollback applies in reverse.
453        tx.st_changes.push((idx, self.st[idx]));
454    }
455
456    #[inline(always)]
457    fn record_edge_change(&self, tx: &mut SamTx, idx: usize) {
458        tx.ed_changes.push((idx, self.ed[idx]));
459    }
460
461    #[inline(always)]
462    fn add_edge_tx(&mut self, tx: &mut SamTx, v: i32, ch: u32, to: i32) {
463        let idx = self.ed.len() as i32;
464        let head = self.st[v as usize].head;
465        self.ed.push(SamEdge { ch, to, next: head });
466        self.record_state_change(tx, v as usize);
467        self.st[v as usize].head = idx;
468    }
469
470    #[inline(always)]
471    fn add_edge_absent_tx(&mut self, tx: &mut SamTx, v: i32, ch: u32, to: i32) {
472        let v_usize = v as usize;
473        let small_n = self.st[v_usize].small_n as usize;
474        if small_n < SAM_SMALL_MAX {
475            let i = small_n;
476            self.record_state_change(tx, v_usize);
477            let st = &mut self.st[v_usize];
478            st.small_ch[i] = ch;
479            st.small_to[i] = to;
480            st.small_n += 1;
481        } else {
482            self.add_edge_tx(tx, v, ch, to);
483        }
484    }
485
486    #[inline(always)]
487    fn replace_edge_to_tx(
488        &mut self,
489        tx: &mut SamTx,
490        v: i32,
491        ch: u32,
492        old_to: i32,
493        new_to: i32,
494    ) -> bool {
495        // small edges
496        {
497            let st = &self.st[v as usize];
498            for i in 0..(st.small_n as usize) {
499                if st.small_ch[i] == ch && st.small_to[i] == old_to {
500                    self.record_state_change(tx, v as usize);
501                    self.st[v as usize].small_to[i] = new_to;
502                    return true;
503                }
504            }
505        }
506        // overflow edges
507        let mut ei = self.st[v as usize].head;
508        while ei != -1 {
509            let eidx = ei as usize;
510            let e = self.ed[eidx];
511            if e.ch == ch && e.to == old_to {
512                self.record_edge_change(tx, eidx);
513                self.ed[eidx].to = new_to;
514                return true;
515            }
516            ei = e.next;
517        }
518        false
519    }
520
521    fn clone_overflow_edges_tx(&mut self, tx: &mut SamTx, src: i32, dst: i32) {
522        self.record_state_change(tx, dst as usize);
523        self.st[dst as usize].head = -1;
524        let mut ei = self.st[src as usize].head;
525        while ei != -1 {
526            let e = self.ed[ei as usize];
527            self.add_edge_tx(tx, dst, e.ch, e.to);
528            ei = e.next;
529        }
530    }
531
532    fn feed_tx(&mut self, tx: &mut SamTx, ch: u32) {
533        let i = self.text.len() as i32;
534        self.text.push(ch);
535        self.boundary_after.push(0);
536
537        let g = self.last;
538        let r = self.st.len() as i32;
539        let mut st_r = SamState::default();
540        st_r.link = 0;
541        st_r.len = self.st[g as usize].len + 1;
542        st_r.endpos = i;
543        st_r.small_n = 0;
544        st_r.head = -1;
545        self.st.push(st_r);
546
547        let mut p = g;
548        let mut q;
549        while p != -1 {
550            q = self.get_edge(p, ch);
551            if q != -1 {
552                break;
553            }
554            self.add_edge_absent_tx(tx, p, ch, r);
555            p = self.st[p as usize].link;
556        }
557
558        if p == -1 {
559            // link of r is in newly appended state; safe.
560            self.st[r as usize].link = 0;
561        } else {
562            q = self.get_edge(p, ch);
563            if self.st[p as usize].len + 1 == self.st[q as usize].len {
564                self.st[r as usize].link = q;
565            } else {
566                let u = self.st.len() as i32;
567                let mut st_u = self.st[q as usize];
568                st_u.len = self.st[p as usize].len + 1;
569                self.st.push(st_u);
570                self.clone_overflow_edges_tx(tx, q, u);
571                while p != -1 && self.replace_edge_to_tx(tx, p, ch, q, u) {
572                    p = self.st[p as usize].link;
573                }
574                // q is an existing state; record before mutation.
575                self.record_state_change(tx, q as usize);
576                self.st[q as usize].link = u;
577                self.st[r as usize].link = u;
578            }
579        }
580
581        self.last = r;
582        self.text_states.push(r);
583
584        // Maintain rightmost endpos online (ROSA deterministic predictor).
585        let mut v = r;
586        while v != -1 && self.st[v as usize].endpos < i {
587            self.record_state_change(tx, v as usize);
588            self.st[v as usize].endpos = i;
589            v = self.st[v as usize].link;
590        }
591    }
592
593    fn mark_boundary_tx(&mut self, tx: &mut SamTx) {
594        if !self.text.is_empty() {
595            // boundary_after is truncated on rollback, so no need to log.
596            let i = self.text.len() - 1;
597            self.boundary_after[i] = 1;
598        }
599        // last is restored on rollback.
600        self.last = 0;
601        let _ = tx;
602    }
603}
604
605#[derive(Clone)]
606struct SamTx {
607    old_last: i32,
608    old_text_len: usize,
609    old_text_states_len: usize,
610    old_boundary_len: usize,
611    old_st_len: usize,
612    old_ed_len: usize,
613    st_changes: Vec<(usize, SamState)>,
614    ed_changes: Vec<(usize, SamEdge)>,
615}
616
/// Per-SAM-state next-symbol statistics for the LM.
///
/// Note: the derived default has `head == 0` and `last_node == 0`, not the -1
/// "no node" sentinel; `LM::build_counts` initializes both to -1 explicitly.
#[derive(Clone, Copy, Default)]
struct LmState {
    /// Index into `LM::nodes` of the first count node, or -1 for none.
    head: i32,
    /// Sum of all counts stored in this state's node list.
    total_n: u64,
    /// Number of distinct symbols (nodes) in this state's list.
    types_t: u32,

    /// Alphabet index of the most recently incremented symbol.
    last_sym: u32,
    /// Node index of the most recently incremented symbol (fast-path cache).
    last_node: i32,
}

/// One `(symbol, count)` entry of a state's singly linked count list.
#[derive(Clone, Copy, Default)]
struct CountNode {
    /// Index into `LM::alphabet`.
    sym_idx: u32,
    /// Occurrence count.
    cnt: u64,
    /// Next node of the same state, or -1 at the end of the list.
    next: i32,
}

/// Witten–Bell smoothed next-symbol model over the states of a `Sam`.
#[derive(Clone)]
struct LM {
    /// Distinct symbols in ascending order (searched with `binary_search`).
    alphabet: Vec<u32>,
    /// Unigram count of each alphabet entry, parallel to `alphabet`.
    unigram: Vec<u64>,
    /// `alphabet.len()` as `u32`.
    alpha_n: u32,
    /// Sum of `unigram`.
    total_uni: u64,

    /// Whether `byte_map` is valid for lookups of symbols < 256.
    has_byte_map: bool,
    /// Symbol (< 256) -> alphabet index, or -1 when absent.
    byte_map: [i16; 256],

    /// Per-SAM-state statistics, indexed by state index.
    ls: Vec<LmState>,
    /// Pool of count nodes referenced by `LmState::head` / `CountNode::next`.
    nodes: Vec<CountNode>,
}

impl Default for LM {
    /// An empty model: no alphabet, no counts, and an invalid byte map
    /// (`has_byte_map == false`, every `byte_map` entry -1).
    fn default() -> Self {
        let unmapped = [-1i16; 256];
        Self {
            ls: Vec::new(),
            nodes: Vec::new(),
            byte_map: unmapped,
            has_byte_map: false,
            total_uni: 0,
            alpha_n: 0,
            unigram: Vec::new(),
            alphabet: Vec::new(),
        }
    }
}
662
663impl LM {
664    fn build_alphabet(&mut self, sam: &Sam) {
665        self.has_byte_map = false;
666        self.byte_map = [-1; 256];
667
668        let mut max_cp = 0u32;
669        for &v in &sam.text {
670            if v > max_cp {
671                max_cp = v;
672            }
673        }
674
675        if max_cp < 256 {
676            let mut counts = [0u64; 256];
677            for &v in &sam.text {
678                counts[v as usize] += 1;
679            }
680            let mut uniq = 0usize;
681            for c in 0..256 {
682                if counts[c] != 0 {
683                    uniq += 1;
684                }
685            }
686
687            if uniq == 0 {
688                self.alphabet = vec![b'\n' as u32];
689                self.unigram = vec![1];
690                self.alpha_n = 1;
691                self.total_uni = 1;
692                self.has_byte_map = true;
693                self.byte_map[b'\n' as usize] = 0;
694                return;
695            }
696
697            self.alphabet = Vec::with_capacity(uniq);
698            self.unigram = Vec::with_capacity(uniq);
699            self.total_uni = 0;
700            for c in 0..256u32 {
701                let cnt = counts[c as usize];
702                if cnt == 0 {
703                    continue;
704                }
705                self.alphabet.push(c);
706                self.unigram.push(cnt);
707                self.total_uni += cnt;
708            }
709            self.alpha_n = self.alphabet.len() as u32;
710            self.has_byte_map = true;
711            for (i, &c) in self.alphabet.iter().enumerate() {
712                self.byte_map[c as usize] = i as i16;
713            }
714            return;
715        }
716
717        let mut tmp = sam.text.clone();
718        tmp.sort_unstable();
719        tmp.dedup();
720        if tmp.is_empty() {
721            tmp.push(b'\n' as u32);
722        }
723        self.alphabet = tmp;
724        self.alpha_n = self.alphabet.len() as u32;
725        self.unigram = vec![0u64; self.alphabet.len()];
726        self.total_uni = 0;
727        for &ch in &sam.text {
728            if let Ok(i) = self.alphabet.binary_search(&ch) {
729                self.unigram[i] += 1;
730                self.total_uni += 1;
731            }
732        }
733        if self.total_uni == 0 {
734            self.unigram[0] = 1;
735            self.total_uni = 1;
736        }
737    }
738
739    #[inline(always)]
740    fn find_sym(&self, ch: u32) -> i32 {
741        if self.has_byte_map && ch < 256 {
742            return self.byte_map[ch as usize] as i32;
743        }
744        match self.alphabet.binary_search(&ch) {
745            Ok(i) => i as i32,
746            Err(_) => -1,
747        }
748    }
749
750    #[inline(always)]
751    fn inc(&mut self, state: u32, sym_idx: u32, add: u64) {
752        let ls = &mut self.ls[state as usize];
753        let last = ls.last_node;
754        if last != -1 && self.nodes[last as usize].sym_idx == sym_idx {
755            self.nodes[last as usize].cnt += add;
756            ls.total_n += add;
757            return;
758        }
759
760        let mut ni = ls.head;
761        while ni != -1 {
762            let node = &mut self.nodes[ni as usize];
763            if node.sym_idx == sym_idx {
764                node.cnt += add;
765                ls.total_n += add;
766                ls.last_node = ni;
767                ls.last_sym = sym_idx;
768                return;
769            }
770            ni = node.next;
771        }
772
773        let idx = self.nodes.len() as i32;
774        self.nodes.push(CountNode {
775            sym_idx,
776            cnt: add,
777            next: ls.head,
778        });
779        ls.head = idx;
780        ls.total_n += add;
781        ls.types_t += 1;
782        ls.last_node = idx;
783        ls.last_sym = sym_idx;
784    }
785
786    fn build_counts(&mut self, sam: &Sam, max_order: i64) {
787        self.ls = vec![
788            LmState {
789                head: -1,
790                last_node: -1,
791                ..LmState::default()
792            };
793            sam.st.len()
794        ];
795        self.nodes.clear();
796
797        let mut seg_start = 0usize;
798        while seg_start < sam.text.len() {
799            let mut seg_end = seg_start;
800            while seg_end < sam.text.len() {
801                let b = sam.boundary_after[seg_end];
802                seg_end += 1;
803                if b != 0 {
804                    break;
805                }
806            }
807            if seg_end - seg_start >= 2 {
808                let mut v = 0i32;
809                for i in seg_start..(seg_end - 1) {
810                    let ch = sam.text[i];
811                    v = sam.advance(v, ch);
812                    let mut ctx = v;
813                    if max_order >= 0 {
814                        while ctx != -1 && (sam.st[ctx as usize].len as i64) > max_order {
815                            ctx = sam.st[ctx as usize].link;
816                        }
817                        if ctx == -1 {
818                            ctx = 0;
819                        }
820                    }
821                    let nxt = sam.text[i + 1];
822                    let si = self.find_sym(nxt);
823                    if si >= 0 {
824                        self.inc(ctx as u32, si as u32, 1);
825                    }
826                }
827            }
828            seg_start = seg_end;
829        }
830
831        // propagate up suffix links (counting sort by len)
832        let mut max_len: usize = 0;
833        for st in &sam.st {
834            let l = st.len as usize;
835            if l > max_len {
836                max_len = l;
837            }
838        }
839        let mut cnt = vec![0usize; max_len + 1];
840        for st in &sam.st {
841            cnt[st.len as usize] += 1;
842        }
843        let mut pos = vec![0usize; max_len + 1];
844        let mut acc = 0usize;
845        for l in 0..=max_len {
846            pos[l] = acc;
847            acc += cnt[l];
848        }
849        let mut order = vec![0u32; sam.st.len()];
850        for (v, st) in sam.st.iter().enumerate() {
851            let l = st.len as usize;
852            let idx = pos[l];
853            order[idx] = v as u32;
854            pos[l] += 1;
855        }
856
857        for oi in (0..order.len()).rev() {
858            let v = order[oi] as usize;
859            let p = sam.st[v].link;
860            if p < 0 {
861                continue;
862            }
863            if self.ls[v].total_n == 0 {
864                continue;
865            }
866            let mut ni = self.ls[v].head;
867            while ni != -1 {
868                let node = self.nodes[ni as usize];
869                self.inc(p as u32, node.sym_idx, node.cnt);
870                ni = node.next;
871            }
872        }
873    }
874
875    /// Efficient pointwise probability Estimation of a single symbol.
876    /// Avoids allocating and writing to a dense distribution array.
877    fn prob_for_sym(&self, sam: &Sam, max_order: i64, v: i32, sym_idx: i32) -> f64 {
878        if sym_idx < 0 {
879            return 1.0 / (self.alpha_n.max(1) as f64);
880        }
881        let sym_idx = sym_idx as u32;
882        let mut p_accum = 0.0f64;
883        let mut residual = 1.0f64;
884        let mut u = v;
885
886        while u != -1 {
887            if !(max_order >= 0 && (sam.st[u as usize].len as i64) > max_order) {
888                let n = self.ls[u as usize].total_n;
889                let t = self.ls[u as usize].types_t;
890                if n > 0 {
891                    let lam = if t > 0 {
892                        (n as f64) / ((n + (t as u64)) as f64)
893                    } else {
894                        1.0
895                    };
896
897                    // Total probability mass from this state
898                    let scale = residual * lam;
899
900                    // Probability of specifically sym_idx in this state
901                    let mut count_for_sym = 0u64;
902                    let mut ni = self.ls[u as usize].head;
903                    while ni != -1 {
904                        let node = self.nodes[ni as usize];
905                        if node.sym_idx == sym_idx {
906                            count_for_sym = node.cnt;
907                            break;
908                        }
909                        ni = node.next;
910                    }
911
912                    if count_for_sym > 0 {
913                        p_accum += scale * (count_for_sym as f64 / n as f64);
914                    }
915
916                    residual *= 1.0 - lam;
917                }
918            }
919            u = sam.st[u as usize].link;
920        }
921
922        if self.total_uni > 0 && residual > 0.0 {
923            let p_uni = self.unigram[sym_idx as usize] as f64 / self.total_uni as f64;
924            p_accum += residual * p_uni;
925        } else if residual > 0.0 {
926            p_accum += residual * (1.0 / self.alpha_n.max(1) as f64);
927        }
928
929        p_accum.clamp(1e-12, 1.0)
930    }
931
932    fn probs_for_state(&self, sam: &Sam, max_order: i64, v: i32, out: &mut [f64]) {
933        out.fill(0.0);
934        let mut residual = 1.0f64;
935        let mut u = v;
936        while u != -1 {
937            if !(max_order >= 0 && (sam.st[u as usize].len as i64) > max_order) {
938                let n = self.ls[u as usize].total_n;
939                let t = self.ls[u as usize].types_t;
940                if n > 0 {
941                    let lam = if t > 0 {
942                        (n as f64) / ((n + (t as u64)) as f64)
943                    } else {
944                        1.0
945                    };
946                    let scale = residual * lam;
947                    let inv_n = 1.0 / (n as f64);
948                    let mut ni = self.ls[u as usize].head;
949                    while ni != -1 {
950                        let node = self.nodes[ni as usize];
951                        out[node.sym_idx as usize] += scale * ((node.cnt as f64) * inv_n);
952                        ni = node.next;
953                    }
954                    residual *= 1.0 - lam;
955                }
956            }
957            u = sam.st[u as usize].link;
958        }
959
960        if self.total_uni > 0 && residual > 0.0 {
961            let inv = 1.0 / (self.total_uni as f64);
962            for i in 0..(self.alpha_n as usize) {
963                out[i] += residual * ((self.unigram[i] as f64) * inv);
964            }
965        }
966
967        let mut s = 0.0;
968        for i in 0..(self.alpha_n as usize) {
969            s += out[i];
970        }
971        if s > 0.0 {
972            let invs = 1.0 / s;
973            for i in 0..(self.alpha_n as usize) {
974                out[i] *= invs;
975            }
976        } else {
977            let uprob = 1.0 / (self.alpha_n.max(1) as f64);
978            for i in 0..(self.alpha_n as usize) {
979                out[i] = uprob;
980            }
981        }
982    }
983
    /// Add `add` occurrences of symbol index `sym_idx` to the count list of SAM state
    /// `state`, logging every value it overwrites in `tx` so the change can be undone.
    ///
    /// The state's `LmState` is snapshotted before any modification. The target node is
    /// looked up first via the cached `last_node`, then by scanning the state's linked list
    /// starting at `head`. If no node for `sym_idx` exists, a new `CountNode` is appended to
    /// `self.nodes`, linked in at the front of the list, and `types_t` is incremented.
    #[inline(always)]
    fn inc_tx(&mut self, tx: &mut LmTx, state: u32, sym_idx: u32, add: u64) {
        let si = state as usize;
        // record old LmState
        tx.ls_changes.push((si, self.ls[si]));

        let ls = &mut self.ls[si];
        let last = ls.last_node;
        // Fast path: the most recently touched node of this state is the target.
        if last != -1 && self.nodes[last as usize].sym_idx == sym_idx {
            let ni = last as usize;
            // Snapshot the node before bumping its count.
            tx.node_changes.push((ni, self.nodes[ni]));
            self.nodes[ni].cnt += add;
            ls.total_n += add;
            return;
        }

        // Slow path: linear scan of this state's node list.
        let mut ni = ls.head;
        while ni != -1 {
            let idx = ni as usize;
            if self.nodes[idx].sym_idx == sym_idx {
                tx.node_changes.push((idx, self.nodes[idx]));
                self.nodes[idx].cnt += add;
                ls.total_n += add;
                // Remember this node so the next increment can take the fast path.
                ls.last_node = ni;
                ls.last_sym = sym_idx;
                return;
            }
            ni = self.nodes[idx].next;
        }

        // New node
        let idx = self.nodes.len() as i32;
        // Make sure rollback truncates at or before this node. `begin_tx` already captured
        // `nodes.len()`, so this normally leaves `old_nodes_len` unchanged.
        tx.old_nodes_len = tx.old_nodes_len.min(self.nodes.len());
        self.nodes.push(CountNode {
            sym_idx,
            cnt: add,
            next: ls.head,
        });
        ls.head = idx;
        ls.total_n += add;
        // A previously unseen continuation: one more distinct type for Witten-Bell.
        ls.types_t += 1;
        ls.last_node = idx;
        ls.last_sym = sym_idx;
    }
1028}
1029
/// Undo log for the language-model part of a `RosaTx`.
///
/// Holds what `RosaPlus::rollback_tx` needs to restore the LM:
/// - buffer lengths at `begin_tx`, used for truncation;
/// - snapshots of entries taken before they were modified;
/// - the unigram increments to subtract.
#[derive(Clone)]
struct LmTx {
    // `lm.ls.len()` captured at `begin_tx`; `ls` is truncated back to this on rollback.
    old_ls_len: usize,
    // `lm.nodes.len()` captured at `begin_tx`; later nodes are dropped on rollback.
    old_nodes_len: usize,
    // (index, previous value) for each `LmState` touched by `inc_tx`; replayed in reverse.
    ls_changes: Vec<(usize, LmState)>,
    // (index, previous value) for each `CountNode` whose count was bumped; replayed in reverse.
    node_changes: Vec<(usize, CountNode)>,
    // Per-byte unigram increments accumulated by `train_*_tx`; subtracted on rollback.
    uni_delta: [u64; BYTE_ALPHA_N],
    // Total of the unigram increments above; subtracted from `total_uni` on rollback.
    total_uni_add: u64,
}
1040
/// Pseudo-random number source used for sampling.
///
/// If the `ROSAPLUS_RNG_PATH` environment variable names a readable file of at least
/// 8 bytes, numbers are read from that file as a cyclic little-endian byte stream. The
/// start offset is derived from the seed. Otherwise a two-shift xorshift generator with
/// a fixed initial state is used; the seed does not affect that fallback.
#[derive(Clone, Default)]
struct RngStream {
    /// Bytes loaded from `ROSAPLUS_RNG_PATH` (empty when unused).
    buf: Vec<u8>,
    /// Next read offset into `buf`.
    pos: usize,
    /// xorshift state, used while `buf` holds fewer than 8 bytes.
    xs: u64,
}

impl RngStream {
    /// Fixed initial xorshift state.
    const XORSHIFT_INIT: u64 = 88172645463325252;

    fn new(seed: u64) -> Self {
        let file_bytes = std::env::var("ROSAPLUS_RNG_PATH")
            .ok()
            .filter(|path| !path.is_empty())
            .and_then(|path| File::open(path).ok())
            .and_then(|mut file| {
                let mut bytes = Vec::new();
                match file.read_to_end(&mut bytes) {
                    Ok(_) if bytes.len() >= 8 => Some(bytes),
                    _ => None,
                }
            });

        match file_bytes {
            Some(bytes) => {
                let start = (seed.wrapping_mul(8) as usize) % bytes.len();
                RngStream {
                    buf: bytes,
                    pos: start,
                    xs: Self::XORSHIFT_INIT,
                }
            }
            None => RngStream {
                buf: Vec::new(),
                pos: 0,
                xs: Self::XORSHIFT_INIT,
            },
        }
    }

    /// Next 64 random bits: eight bytes from the file stream (wrapping around), or one
    /// xorshift step when no stream is loaded.
    #[inline(always)]
    fn next_u64(&mut self) -> u64 {
        let len = self.buf.len();
        if len >= 8 {
            let mut bytes = [0u8; 8];
            for byte in bytes.iter_mut() {
                *byte = self.buf[self.pos];
                self.pos = if self.pos + 1 >= len { 0 } else { self.pos + 1 };
            }
            u64::from_le_bytes(bytes)
        } else {
            let mut x = self.xs;
            x ^= x << 7;
            x ^= x >> 9;
            self.xs = x;
            x
        }
    }

    /// Uniform value in `[0, 1)` built from the top 53 bits of `next_u64`.
    #[inline(always)]
    fn next_unit(&mut self) -> f64 {
        let top53 = self.next_u64() >> 11;
        (top53 as f64) * (1.0 / 9007199254740992.0)
    }
}
1095
// Helpers for debugging/printing byte sequences could be added here if needed; the former
// `utf8_decode_lossy`/`utf8_encode` helpers were removed because the model is now byte-wise.
1098
/// Reusable buffers for `RosaPlus::sample`, kept between calls to avoid reallocating.
#[derive(Clone, Default)]
struct SampleScratch {
    /// Symbol indices, reordered by probability during sampling (length `alpha_n`).
    idx: Vec<u32>,
    /// Temperature-scaled log-probabilities of the candidate symbols.
    logits: Vec<f64>,
    /// Max-shifted exponentials of `logits`.
    exps: Vec<f64>,
}

impl SampleScratch {
    /// Make `idx` exactly `alpha_n` long, and grow `logits`/`exps` to at least `n` entries.
    ///
    /// `logits` and `exps` never shrink. Growth is decided by the length of `logits` alone,
    /// and both are then resized together.
    fn ensure(&mut self, alpha_n: usize, n: usize) {
        // `resize` to the current length is a no-op, so no length check is needed here.
        self.idx.resize(alpha_n, 0);

        if n <= self.logits.len() {
            return;
        }
        self.logits.resize(n, 0.0);
        self.exps.resize(n, 0.0);
    }
}
1117
/// ROSA+ model: a suffix automaton over the training text plus a Witten-Bell smoothed
/// next-symbol language model derived from it.
#[derive(Clone)]
pub struct RosaPlus {
    /// Maximum context length used by the LM; a negative value disables the limit.
    max_order: i64,
    /// When true, `train_example` feeds `eot` after every example.
    use_eot: bool,
    /// End-of-text symbol (the `u8` passed to `new`, widened to `u32`).
    eot: u32,
    /// Seed for the sampling RNG; also forwarded to the per-order models built by
    /// `predictive_entropy_rate`.
    seed: u64,

    /// Suffix automaton over all training text.
    sam: Sam,
    /// Language-model counts derived from `sam`.
    lm: LM,
    /// Whether `lm` is usable; cleared by `train_example`, `restore` and `entropy_rate_cps`.
    lm_built: bool,

    /// Random source for `sample`.
    rng: RngStream,
    /// Reusable buffers for `sample`.
    scratch: SampleScratch,
    /// Next-symbol distribution buffer (resized to `lm.alpha_n` when the LM is built).
    dist: Vec<f64>,
}
1133
/// A lightweight snapshot of the append-only internal SAM buffers.
///
/// Restoring to a checkpoint is O(1) (via truncation) and is meant to support
/// repeated evaluation of different continuations from the same base training state.
/// Only buffer lengths and `sam.last` are recorded; `RosaPlus::restore` truncates to them
/// and marks the LM as stale.
#[derive(Clone, Copy, Debug)]
pub struct RosaCheckpoint {
    /// `sam.st.len()` at capture time.
    sam_st_len: usize,
    /// `sam.ed.len()` at capture time.
    sam_ed_len: usize,
    /// `sam.text.len()` at capture time.
    sam_text_len: usize,
    /// `sam.text_states.len()` at capture time.
    sam_text_states_len: usize,
    /// `sam.boundary_after.len()` at capture time.
    sam_boundary_after_len: usize,
    /// `sam.last` at capture time; reinstated verbatim on restore.
    sam_last: i32,
}
1147
/// Transaction object used to roll back a temporary conditional update.
///
/// Obtained from `RosaPlus::begin_tx` and passed to `train_example_tx` /
/// `train_sequence_tx`. `RosaPlus::rollback_tx` consumes it.
#[derive(Clone)]
pub struct RosaTx {
    /// Undo log for the suffix automaton.
    sam: SamTx,
    /// Undo log for the language-model counts.
    lm: LmTx,
    /// `sam.text.len()` when the transaction began.
    seg_start: usize,
    /// Length of text appended since `seg_start`; updated by each `train_*_tx` call.
    seg_len: usize,
}
1156
1157impl RosaPlus {
1158    pub fn new(max_order: i64, use_eot: bool, eot_char: u8, seed: u64) -> Self {
1159        let sam = Sam::new(0);
1160        RosaPlus {
1161            max_order,
1162            use_eot,
1163            eot: eot_char as u32,
1164            seed,
1165            sam,
1166            lm: LM::default(),
1167            lm_built: false,
1168            rng: RngStream::new(seed),
1169            scratch: SampleScratch::default(),
1170            dist: Vec::new(),
1171        }
1172    }
1173
1174    pub fn train_example(&mut self, s: &[u8]) {
1175        if s.is_empty() {
1176            return;
1177        }
1178
1179        if self.sam.text.is_empty() {
1180            self.sam = Sam::new(s.len());
1181        }
1182
1183        for &b in s {
1184            self.sam.feed(b as u32);
1185        }
1186
1187        if self.use_eot {
1188            self.sam.feed(self.eot);
1189        }
1190
1191        self.sam.mark_boundary();
1192        self.lm_built = false;
1193    }
1194
1195    pub fn build_lm(&mut self) {
1196        self.sam.finalize_endpos();
1197        self.lm = LM::default();
1198        self.lm.build_alphabet(&self.sam);
1199        let mo = if self.max_order < 0 {
1200            -1
1201        } else {
1202            self.max_order
1203        };
1204        self.lm.build_counts(&self.sam, mo);
1205        self.lm_built = true;
1206        self.dist.resize(self.lm.alpha_n as usize, 0.0);
1207    }
1208
1209    /// Build the language model without mutating SAM `endpos`.
1210    ///
1211    /// This is useful when you want to reuse a trained SAM as a stable base state
1212    /// (e.g. universal-prior conditioning) and need cheap checkpoint/restore via truncation.
1213    ///
1214    /// Note: entropy/cross-entropy estimation does not require `endpos` finalization.
1215    pub fn build_lm_no_finalize_endpos(&mut self) {
1216        self.lm = LM::default();
1217        self.lm.build_alphabet(&self.sam);
1218        let mo = if self.max_order < 0 {
1219            -1
1220        } else {
1221            self.max_order
1222        };
1223        self.lm.build_counts(&self.sam, mo);
1224        self.lm_built = true;
1225        self.dist.resize(self.lm.alpha_n as usize, 0.0);
1226    }
1227
1228    /// Build an LM with a fixed byte alphabet of size 256.
1229    ///
1230    /// This avoids alphabet growth issues and enables fast incremental updates.
1231    pub fn build_lm_full_bytes_no_finalize_endpos(&mut self) {
1232        // Fixed alphabet
1233        self.lm = LM::default();
1234        self.lm.has_byte_map = true;
1235        self.lm.alpha_n = BYTE_ALPHA_N as u32;
1236        self.lm.alphabet = (0..BYTE_ALPHA_N as u32).collect();
1237        self.lm.byte_map = [-1; 256];
1238        for i in 0..256 {
1239            self.lm.byte_map[i] = i as i16;
1240        }
1241
1242        // Unigram counts
1243        let mut counts = [0u64; 256];
1244        for &v in &self.sam.text {
1245            if v < 256 {
1246                counts[v as usize] += 1;
1247            }
1248        }
1249        self.lm.unigram = counts.to_vec();
1250        self.lm.total_uni = counts.iter().sum();
1251        if self.lm.total_uni == 0 {
1252            for i in 0..256 {
1253                self.lm.unigram[i] = 1;
1254            }
1255            self.lm.total_uni = 256;
1256        }
1257
1258        // Counts
1259        let mo = if self.max_order < 0 {
1260            -1
1261        } else {
1262            self.max_order
1263        };
1264        self.lm.build_counts(&self.sam, mo);
1265        self.lm_built = true;
1266        self.dist.resize(BYTE_ALPHA_N, 0.0);
1267    }
1268
1269    /// Begin a reversible conditional update transaction.
1270    pub fn begin_tx(&mut self) -> RosaTx {
1271        let sam_tx = self.sam.begin_tx();
1272        let lm_tx = LmTx {
1273            old_ls_len: self.lm.ls.len(),
1274            old_nodes_len: self.lm.nodes.len(),
1275            ls_changes: Vec::new(),
1276            node_changes: Vec::new(),
1277            uni_delta: [0u64; BYTE_ALPHA_N],
1278            total_uni_add: 0,
1279        };
1280        RosaTx {
1281            sam: sam_tx,
1282            lm: lm_tx,
1283            seg_start: self.sam.text.len(),
1284            seg_len: 0,
1285        }
1286    }
1287
1288    /// Apply a training example and update LM counts incrementally (byte alphabet must be full 256).
1289    pub fn train_example_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1290        self.train_example_tx_impl(tx, s, true);
1291    }
1292
1293    /// Apply a sequential update without inserting a boundary (continuous stream).
1294    pub fn train_sequence_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1295        self.train_example_tx_impl(tx, s, false);
1296    }
1297
1298    fn train_example_tx_impl(&mut self, tx: &mut RosaTx, s: &[u8], mark_boundary: bool) {
1299        if s.is_empty() {
1300            return;
1301        }
1302
1303        // Ensure LS has entries for current states.
1304        if self.lm.ls.len() < self.sam.st.len() {
1305            self.lm.ls.resize(
1306                self.sam.st.len(),
1307                LmState {
1308                    head: -1,
1309                    last_node: -1,
1310                    ..LmState::default()
1311                },
1312            );
1313        }
1314
1315        // Feed all bytes (SAM structure changes are logged).
1316        for &b in s {
1317            self.sam.feed_tx(&mut tx.sam, b as u32);
1318            tx.lm.uni_delta[b as usize] += 1;
1319            tx.lm.total_uni_add += 1;
1320        }
1321        if mark_boundary {
1322            self.sam.mark_boundary_tx(&mut tx.sam);
1323        }
1324
1325        // LM must be built for scoring; we keep it built and update counts incrementally.
1326        // Extend ls for any new SAM states created by feeding.
1327        if self.lm.ls.len() < self.sam.st.len() {
1328            self.lm.ls.resize(
1329                self.sam.st.len(),
1330                LmState {
1331                    head: -1,
1332                    last_node: -1,
1333                    ..LmState::default()
1334                },
1335            );
1336        }
1337
1338        // Update unigram counts (fixed 256 alphabet assumed).
1339        for i in 0..256 {
1340            if tx.lm.uni_delta[i] != 0 {
1341                self.lm.unigram[i] += tx.lm.uni_delta[i];
1342            }
1343        }
1344        self.lm.total_uni += tx.lm.total_uni_add;
1345
1346        // Update conditional counts for the new segment only.
1347        let seg_start = tx.seg_start;
1348        let seg_end = self.sam.text.len();
1349        tx.seg_len = seg_end - seg_start;
1350        if tx.seg_len >= 1 {
1351            let mo = if self.max_order < 0 {
1352                -1
1353            } else {
1354                self.max_order
1355            };
1356            // For continuous streams, include the cross-boundary transition from the
1357            // previous symbol into the first new symbol. For segmented examples,
1358            // respect boundary markers and skip that transition.
1359            let mut start_i = seg_start;
1360            if !mark_boundary
1361                && seg_start > 0
1362                && self.sam.boundary_after.get(seg_start - 1).copied().unwrap_or(0) == 0
1363            {
1364                start_i = seg_start - 1;
1365            }
1366            for i in start_i..(seg_end - 1) {
1367                // ctx state after consuming sam.text[i] within its segment
1368                let mut ctx = self.sam.text_states[i + 1];
1369                if mo >= 0 {
1370                    while ctx != -1 && (self.sam.st[ctx as usize].len as i64) > mo {
1371                        ctx = self.sam.st[ctx as usize].link;
1372                    }
1373                    if ctx == -1 {
1374                        ctx = 0;
1375                    }
1376                }
1377                let nxt = self.sam.text[i + 1];
1378                let si = self.lm.find_sym(nxt);
1379                if si >= 0 {
1380                    let mut u = ctx;
1381                    while u != -1 {
1382                        self.lm.inc_tx(&mut tx.lm, u as u32, si as u32, 1);
1383                        u = self.sam.st[u as usize].link;
1384                    }
1385                }
1386            }
1387        }
1388
1389        self.lm_built = true;
1390    }
1391
1392    /// Roll back a transaction, restoring the model to the exact state at begin_tx.
1393    pub fn rollback_tx(&mut self, tx: RosaTx) {
1394        // Restore LM changes
1395        // Unigram rollback
1396        if self.lm.unigram.len() >= BYTE_ALPHA_N {
1397            for i in 0..BYTE_ALPHA_N {
1398                let d = tx.lm.uni_delta[i];
1399                if d != 0 {
1400                    self.lm.unigram[i] = self.lm.unigram[i].saturating_sub(d);
1401                }
1402            }
1403            self.lm.total_uni = self.lm.total_uni.saturating_sub(tx.lm.total_uni_add);
1404        }
1405
1406        for (idx, old) in tx.lm.node_changes.into_iter().rev() {
1407            if idx < self.lm.nodes.len() {
1408                self.lm.nodes[idx] = old;
1409            }
1410        }
1411        for (idx, old) in tx.lm.ls_changes.into_iter().rev() {
1412            if idx < self.lm.ls.len() {
1413                self.lm.ls[idx] = old;
1414            }
1415        }
1416        self.lm.nodes.truncate(tx.lm.old_nodes_len);
1417        self.lm.ls.truncate(tx.lm.old_ls_len);
1418
1419        // Restore SAM
1420        self.sam.rollback_tx(tx.sam);
1421        // lm_built remains true if it was true before; safe to keep true.
1422    }
1423
1424    /// Ensure the LM is built (without mutating SAM endpos).
1425    #[inline(always)]
1426    pub fn ensure_lm_built_no_finalize_endpos(&mut self) {
1427        if !self.lm_built {
1428            self.build_lm_no_finalize_endpos();
1429        }
1430    }
1431
1432    fn predictive_entropy_rate_order(data: &[u8], max_order: i64, seed: u64) -> f64 {
1433        if data.len() < 2 {
1434            return 0.0;
1435        }
1436        let num_chunks = 16;
1437        let chunk_size = (data.len() + num_chunks - 1) / num_chunks;
1438        let mut total_log_prob = 0.0f64;
1439        let mut count = 0usize;
1440
1441        for i in 0..num_chunks {
1442            let start = i * chunk_size;
1443            let end = ((i + 1) * chunk_size).min(data.len());
1444            if start >= end {
1445                break;
1446            }
1447            if i == 0 {
1448                continue;
1449            }
1450
1451            let mut m = RosaPlus::new(max_order, false, 0, seed);
1452            m.train_example(&data[..start]);
1453            m.build_lm();
1454            let mut v = m.sam.last;
1455
1456            for &b in &data[start..end] {
1457                let sym_idx = m.lm.find_sym(b as u32);
1458                let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
1459                total_log_prob += p.log2();
1460                count += 1;
1461                v = m.sam.advance(v, b as u32);
1462            }
1463        }
1464
1465        if count == 0 {
1466            let mut m = RosaPlus::new(max_order, false, 0, seed);
1467            m.train_example(data);
1468            m.build_lm();
1469            m.cross_entropy(data)
1470        } else {
1471            -total_log_prob / (count as f64)
1472        }
1473    }
1474
1475    /// Current LM alphabet size (0 if LM not built).
1476    pub fn lm_alpha_n(&self) -> usize {
1477        if !self.lm_built {
1478            0
1479        } else {
1480            self.lm.alpha_n as usize
1481        }
1482    }
1483
1484    pub fn estimated_size_bytes(&self) -> usize {
1485        use std::mem::size_of;
1486
1487        let mut n = 0usize;
1488
1489        n = n.saturating_add(self.sam.st.len().saturating_mul(size_of::<SamState>()));
1490        n = n.saturating_add(self.sam.ed.len().saturating_mul(size_of::<SamEdge>()));
1491        n = n.saturating_add(self.sam.text.len().saturating_mul(size_of::<u32>()));
1492        n = n.saturating_add(self.sam.text_states.len().saturating_mul(size_of::<i32>()));
1493        n = n.saturating_add(
1494            self.sam
1495                .boundary_after
1496                .len()
1497                .saturating_mul(size_of::<u8>()),
1498        );
1499
1500        n = n.saturating_add(self.lm.alphabet.len().saturating_mul(size_of::<u32>()));
1501        n = n.saturating_add(self.lm.unigram.len().saturating_mul(size_of::<u64>()));
1502        n = n.saturating_add(self.lm.ls.len().saturating_mul(size_of::<LmState>()));
1503        n = n.saturating_add(self.lm.nodes.len().saturating_mul(size_of::<CountNode>()));
1504
1505        n = n.saturating_add(self.dist.len().saturating_mul(size_of::<f64>()));
1506        n = n.saturating_add(self.scratch.idx.len().saturating_mul(size_of::<u32>()));
1507        n = n.saturating_add(self.scratch.logits.len().saturating_mul(size_of::<f64>()));
1508        n = n.saturating_add(self.scratch.exps.len().saturating_mul(size_of::<f64>()));
1509        n = n.saturating_add(self.rng.buf.len().saturating_mul(size_of::<u8>()));
1510
1511        n
1512    }
1513
1514    pub fn shrink_aux_buffers(&mut self) {
1515        self.dist.shrink_to_fit();
1516        self.scratch.idx.shrink_to_fit();
1517        self.scratch.logits.shrink_to_fit();
1518        self.scratch.exps.shrink_to_fit();
1519        self.rng.buf.shrink_to_fit();
1520    }
1521
1522    /// Create a new model that shares the same trained SAM state but resets LM-related buffers.
1523    ///
1524    /// This is substantially cheaper than cloning the full `RosaPlus` (which includes LM counts,
1525    /// node tables, and distribution buffers) and is safe for workflows that want to start from
1526    /// a fixed base training text (e.g. a universal prior) and then add candidate-specific text.
1527    pub fn fork_from_sam(&self) -> Self {
1528        Self {
1529            max_order: self.max_order,
1530            use_eot: self.use_eot,
1531            eot: self.eot,
1532            seed: self.seed,
1533
1534            sam: self.sam.clone(),
1535            lm: LM::default(),
1536            lm_built: false,
1537
1538            rng: RngStream::new(self.seed),
1539            scratch: SampleScratch::default(),
1540            dist: Vec::new(),
1541        }
1542    }
1543
1544    /// A checkpoint that allows restoring the ROSA model back to a previous trained state
1545    /// by truncating append-only internal buffers.
1546    ///
1547    /// Intended for workflows that repeatedly evaluate different continuations from the same base
1548    /// training text (e.g. universal-prior conditioned scoring).
1549    pub fn checkpoint(&self) -> RosaCheckpoint {
1550        RosaCheckpoint {
1551            sam_st_len: self.sam.st.len(),
1552            sam_ed_len: self.sam.ed.len(),
1553            sam_text_len: self.sam.text.len(),
1554            sam_text_states_len: self.sam.text_states.len(),
1555            sam_boundary_after_len: self.sam.boundary_after.len(),
1556            sam_last: self.sam.last,
1557        }
1558    }
1559
1560    /// Restore the model to a previously captured checkpoint.
1561    ///
1562    /// This invalidates the LM; callers should rebuild it before scoring.
1563    pub fn restore(&mut self, ck: &RosaCheckpoint) {
1564        self.sam.st.truncate(ck.sam_st_len);
1565        self.sam.ed.truncate(ck.sam_ed_len);
1566        self.sam.text.truncate(ck.sam_text_len);
1567        self.sam.text_states.truncate(ck.sam_text_states_len);
1568        self.sam.boundary_after.truncate(ck.sam_boundary_after_len);
1569        self.sam.last = ck.sam_last;
1570        self.lm_built = false;
1571    }
1572
1573    #[inline(always)]
1574    fn sample(&mut self, temperature: f64, top_p: f64, top_k: i32) -> u32 {
1575        let dist = &self.dist;
1576        let alpha_n = self.lm.alpha_n as usize;
1577        self.scratch.ensure(alpha_n, alpha_n);
1578        for i in 0..alpha_n {
1579            self.scratch.idx[i] = i as u32;
1580        }
1581
1582        // O(n^2) sort by dist desc then idx asc (matches C).
1583        for i in 0..alpha_n {
1584            for j in (i + 1)..alpha_n {
1585                let ii = self.scratch.idx[i] as usize;
1586                let jj = self.scratch.idx[j] as usize;
1587                let pi = dist[ii];
1588                let pj = dist[jj];
1589                if pj > pi || (pj == pi && jj < ii) {
1590                    self.scratch.idx.swap(i, j);
1591                }
1592            }
1593        }
1594
1595        let mut n = alpha_n;
1596        if top_k > 0 {
1597            let k = top_k as usize;
1598            if k < n {
1599                n = k;
1600            }
1601        }
1602
1603        if top_p > 0.0 && top_p < 1.0 {
1604            let mut cum = 0.0;
1605            let mut cut = 0usize;
1606            for i in 0..n {
1607                let si = self.scratch.idx[i] as usize;
1608                cum += dist[si];
1609                cut += 1;
1610                if cum >= top_p {
1611                    break;
1612                }
1613            }
1614            n = if cut > 0 { cut } else { 1 };
1615        }
1616
1617        let temperature = if temperature <= 0.0 {
1618            1e-6
1619        } else {
1620            temperature
1621        };
1622
1623        self.scratch.ensure(alpha_n, n);
1624        let mut maxlog = -1e300f64;
1625        for i in 0..n {
1626            let si = self.scratch.idx[i] as usize;
1627            let mut p = dist[si];
1628            if p < 1e-12 {
1629                p = 1e-12;
1630            }
1631            let z = p.ln() / temperature;
1632            self.scratch.logits[i] = z;
1633            if z > maxlog {
1634                maxlog = z;
1635            }
1636        }
1637
1638        let mut zsum = 0.0;
1639        for i in 0..n {
1640            let e = (self.scratch.logits[i] - maxlog).exp();
1641            self.scratch.exps[i] = e;
1642            zsum += e;
1643        }
1644
1645        let r = self.rng.next_unit() * zsum;
1646        let mut cum = 0.0;
1647        let mut pick = 0usize;
1648        for i in 0..n {
1649            cum += self.scratch.exps[i];
1650            if cum > r {
1651                pick = i;
1652                break;
1653            }
1654        }
1655
1656        let sym = self.scratch.idx[pick] as usize;
1657        self.lm.alphabet[sym]
1658    }
1659
1660    pub fn generate(&mut self, prompt: &[u8], steps: i32) -> Option<Vec<u8>> {
1661        if !self.lm_built {
1662            return None;
1663        }
1664        let steps = steps.max(0) as usize;
1665
1666        let mut v = 0i32;
1667        for &b in prompt {
1668            v = self.sam.advance(v, b as u32);
1669        }
1670
1671        let mut out: Vec<u32> = Vec::with_capacity(steps);
1672
1673        for _ in 0..steps {
1674            let mut ch = self.sam.predict_det(v);
1675            if ch.is_none() {
1676                let mo = if self.max_order < 0 {
1677                    -1
1678                } else {
1679                    self.max_order
1680                };
1681                self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
1682                ch = Some(self.sample(0.7, 0.9, 0));
1683            }
1684            let ch = ch.unwrap();
1685            out.push(ch);
1686            if self.use_eot && ch == self.eot {
1687                break;
1688            }
1689            v = self.sam.advance(v, ch);
1690        }
1691
1692        Some(out.iter().map(|&c| c as u8).collect())
1693    }
1694
1695    // ========== Entropy Estimation API ==========
1696
1697    /// Returns the probability distribution for the next symbol given a context.
1698    /// Output: Vec of (codepoint, probability) pairs, sorted by codepoint.
1699    /// Builds the LM if not already built.
1700    pub fn get_distribution(&mut self, context: &[u8]) -> Vec<(u32, f64)> {
1701        if !self.lm_built {
1702            self.build_lm();
1703        }
1704
1705        // Advance through context to get SAM state
1706        let mut v = 0i32;
1707        for &b in context {
1708            v = self.sam.advance(v, b as u32);
1709        }
1710
1711        // Get probability distribution at this state
1712        let mo = if self.max_order < 0 {
1713            -1
1714        } else {
1715            self.max_order
1716        };
1717        self.dist.resize(self.lm.alpha_n as usize, 0.0);
1718        self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
1719
1720        // Build output as (codepoint, probability) pairs
1721        let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
1722        for i in 0..(self.lm.alpha_n as usize) {
1723            if self.dist[i] > 0.0 {
1724                result.push((self.lm.alphabet[i], self.dist[i]));
1725            }
1726        }
1727        result.sort_by_key(|&(cp, _)| cp);
1728        result
1729    }
1730
1731    /// Compute the predictive entropy rate (bits per symbol) of the given data.
1732    ///
1733    /// Uses chunked prequential scoring (train on past chunks, score next chunk).
1734    pub fn predictive_entropy_rate(&mut self, data: &[u8]) -> f64 {
1735        if data.len() < 2 {
1736            return 0.0;
1737        }
1738        if self.max_order < 0 {
1739            let candidates: [i64; 8] = [0, 1, 2, 4, 8, 16, 32, 64];
1740            let mut best = f64::INFINITY;
1741            for &mo in &candidates {
1742                if mo as usize >= data.len() {
1743                    continue;
1744                }
1745                let h = Self::predictive_entropy_rate_order(data, mo, self.seed);
1746                if h < best {
1747                    best = h;
1748                }
1749            }
1750            if best.is_finite() {
1751                return best;
1752            }
1753        }
1754        Self::predictive_entropy_rate_order(data, self.max_order, self.seed)
1755    }
1756
1757    pub fn entropy_rate_cps(&mut self, cps: &[u32]) -> f64 {
1758        if cps.len() < 2 {
1759            return 0.0;
1760        }
1761
1762        self.sam = Sam::new(cps.len());
1763        self.lm_built = false;
1764
1765        let num_chunks = 16;
1766        let chunk_size = (cps.len() + num_chunks - 1) / num_chunks;
1767        let mut total_log_prob = 0.0f64;
1768        let mut count = 0usize;
1769
1770        for i in 0..num_chunks {
1771            let start = i * chunk_size;
1772            let end = ((i + 1) * chunk_size).min(cps.len());
1773            if start >= end {
1774                break;
1775            }
1776            let chunk = &cps[start..end];
1777            if i > 0 {
1778                // Avoid endpos finalization since we continue mutating the SAM across chunks.
1779                self.build_lm_no_finalize_endpos();
1780                let mut v = self.sam.text_states[start];
1781                for &ch in chunk {
1782                    let sym_idx = self.lm.find_sym(ch);
1783                    let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1784                    total_log_prob += p.log2();
1785                    count += 1;
1786                    v = self.sam.advance(v, ch);
1787                }
1788            }
1789            for &ch in chunk {
1790                self.sam.feed(ch);
1791            }
1792        }
1793
1794        if count == 0 {
1795            self.build_lm();
1796            self.entropy_rate_plugin_cps(cps)
1797        } else {
1798            -total_log_prob / (count as f64)
1799        }
1800    }
1801
1802    #[allow(dead_code)]
1803    fn entropy_rate_plugin_bytes(&mut self, data: &[u8]) -> f64 {
1804        let mut v = 0i32;
1805        let mut total_log_prob = 0.0f64;
1806        let mut count = 0usize;
1807        for t in 0..(data.len() - 1) {
1808            v = self.sam.advance(v, data[t] as u32);
1809            let next_ch = data[t + 1] as u32;
1810            let sym_idx = self.lm.find_sym(next_ch);
1811            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1812            total_log_prob += p.log2();
1813            count += 1;
1814        }
1815        if count == 0 {
1816            0.0
1817        } else {
1818            -total_log_prob / (count as f64)
1819        }
1820    }
1821
1822    pub fn cross_entropy(&self, data: &[u8]) -> f64 {
1823        if !self.lm_built || data.is_empty() {
1824            return 0.0;
1825        }
1826        let mut total_log_prob = 0.0f64;
1827        let mut v = 0i32;
1828        for &b in data {
1829            let ch = b as u32;
1830            let sym_idx = self.lm.find_sym(ch);
1831            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1832            total_log_prob += p.log2();
1833            v = self.sam.advance(v, ch);
1834        }
1835        -total_log_prob / (data.len() as f64)
1836    }
1837
1838    pub fn cross_entropy_cps(&self, data: &[u32]) -> f64 {
1839        if !self.lm_built || data.is_empty() {
1840            return 0.0;
1841        }
1842        let mut total_log_prob = 0.0f64;
1843        let mut v = 0i32;
1844        for &ch in data {
1845            let sym_idx = self.lm.find_sym(ch);
1846            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1847            total_log_prob += p.log2();
1848            v = self.sam.advance(v, ch);
1849        }
1850        -total_log_prob / (data.len() as f64)
1851    }
1852
1853    fn entropy_rate_plugin_cps(&mut self, cps: &[u32]) -> f64 {
1854        let mut v = 0i32;
1855        let mut total_log_prob = 0.0f64;
1856        let mut count = 0usize;
1857        for t in 0..(cps.len() - 1) {
1858            v = self.sam.advance(v, cps[t]);
1859            let next_ch = cps[t + 1];
1860            let sym_idx = self.lm.find_sym(next_ch);
1861            let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1862            total_log_prob += p.log2();
1863            count += 1;
1864        }
1865        if count == 0 {
1866            0.0
1867        } else {
1868            -total_log_prob / (count as f64)
1869        }
1870    }
1871
1872    /// Returns the marginal (unigram) distribution over the training data.
1873    /// Output: Vec of (codepoint, probability) pairs, sorted by codepoint.
1874    pub fn marginal_distribution(&self) -> Vec<(u32, f64)> {
1875        if self.lm.total_uni == 0 {
1876            return Vec::new();
1877        }
1878
1879        let inv = 1.0 / (self.lm.total_uni as f64);
1880        let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
1881        for i in 0..(self.lm.alpha_n as usize) {
1882            let p = (self.lm.unigram[i] as f64) * inv;
1883            if p > 0.0 {
1884                result.push((self.lm.alphabet[i], p));
1885            }
1886        }
1887        result.sort_by_key(|&(cp, _)| cp);
1888        result
1889    }
1890
1891    /// Compute the marginal entropy H(X) from the unigram distribution.
1892    /// Returns bits per symbol.
1893    pub fn marginal_entropy(&self) -> f64 {
1894        if self.lm.total_uni == 0 {
1895            return 0.0;
1896        }
1897
1898        let inv = 1.0 / (self.lm.total_uni as f64);
1899        let mut h = 0.0f64;
1900        for i in 0..(self.lm.alpha_n as usize) {
1901            let p = (self.lm.unigram[i] as f64) * inv;
1902            if p > 0.0 {
1903                h -= p * p.log2();
1904            }
1905        }
1906        h
1907    }
1908
1909    pub fn save(&self, path: &str) -> std::io::Result<()> {
1910        if !self.lm_built {
1911            return Err(std::io::Error::new(
1912                std::io::ErrorKind::Other,
1913                "LM not built",
1914            ));
1915        }
1916
1917        // Transactional conditional updates require a valid prefix-state trace.
1918        // If this invariant is violated, the loaded model would be unusable.
1919        if self.sam.text_states.len() != self.sam.text.len() + 1 {
1920            return Err(std::io::Error::new(
1921                std::io::ErrorKind::Other,
1922                "SAM text_states mismatch (expected text.len()+1)",
1923            ));
1924        }
1925        let mut f = BufWriter::with_capacity(1024 * 1024, File::create(path)?);
1926        f.write_all(MAGIC)?;
1927        f.write_all(&self.max_order.to_le_bytes())?;
1928        f.write_all(&(self.use_eot as i32).to_le_bytes())?;
1929        f.write_all(&self.eot.to_le_bytes())?;
1930        f.write_all(&self.seed.to_le_bytes())?;
1931
1932        // SAM
1933        f.write_all(&(self.sam.st.len() as u32).to_le_bytes())?;
1934        f.write_all(&(self.sam.ed.len() as u32).to_le_bytes())?;
1935        f.write_all(&(self.sam.text.len() as u32).to_le_bytes())?;
1936        for st in &self.sam.st {
1937            f.write_all(&st.link.to_le_bytes())?;
1938            f.write_all(&st.len.to_le_bytes())?;
1939            f.write_all(&st.endpos.to_le_bytes())?;
1940            f.write_all(&(st.small_n as u32).to_le_bytes())?;
1941            for k in 0..(st.small_n as usize) {
1942                f.write_all(&st.small_ch[k].to_le_bytes())?;
1943                f.write_all(&st.small_to[k].to_le_bytes())?;
1944            }
1945            f.write_all(&st.head.to_le_bytes())?;
1946        }
1947        for e in &self.sam.ed {
1948            f.write_all(&e.ch.to_le_bytes())?;
1949            f.write_all(&e.to.to_le_bytes())?;
1950            f.write_all(&e.next.to_le_bytes())?;
1951        }
1952        write_u32_slice_le(&mut f, &self.sam.text)?;
1953        f.write_all(&self.sam.boundary_after)?;
1954
1955        // Persist SAM cursor + prefix trace.
1956        f.write_all(&self.sam.last.to_le_bytes())?;
1957        f.write_all(&(self.sam.text_states.len() as u32).to_le_bytes())?;
1958        write_i32_slice_le(&mut f, &self.sam.text_states)?;
1959
1960        // LM
1961        f.write_all(&self.lm.alpha_n.to_le_bytes())?;
1962        f.write_all(&self.lm.total_uni.to_le_bytes())?;
1963        f.write_all(&(self.lm.nodes.len() as u32).to_le_bytes())?;
1964        write_u32_slice_le(&mut f, &self.lm.alphabet)?;
1965        write_u64_slice_le(&mut f, &self.lm.unigram)?;
1966        for ls in &self.lm.ls {
1967            f.write_all(&ls.head.to_le_bytes())?;
1968            f.write_all(&ls.total_n.to_le_bytes())?;
1969            f.write_all(&ls.types_t.to_le_bytes())?;
1970        }
1971        for n in &self.lm.nodes {
1972            f.write_all(&n.sym_idx.to_le_bytes())?;
1973            f.write_all(&n.cnt.to_le_bytes())?;
1974            f.write_all(&n.next.to_le_bytes())?;
1975        }
1976        f.flush()?;
1977        Ok(())
1978    }
1979
    /// Deserialize a model previously written by `save` (format `rosa_pb_v4`).
    ///
    /// Reads the header (`max_order`, `use_eot`, `eot`, `seed`), rebuilds the
    /// SAM (states, edges, text, boundary flags, `last`, prefix-state trace)
    /// and the LM (alphabet, unigram counts, per-state summaries, count nodes),
    /// then regenerates the derived `byte_map` lookup and marks the LM as built.
    /// Fields must be read in exactly the order `save` writes them.
    ///
    /// # Errors
    /// Returns `InvalidData` for a wrong magic, an inline-transition count above
    /// `SAM_SMALL_MAX`, a prefix-state trace whose length is not `text_n + 1`,
    /// or a trace entry / `sam.last` outside `0..st_n`. Short reads and other
    /// I/O failures propagate from the underlying reader.
    ///
    /// NOTE(review): element counts come straight from the file and size the
    /// buffers before any cross-checking, and the link fields (`link`,
    /// `small_to`, `head`, edge `to`/`next`, LM `head`, node `next`/`sym_idx`)
    /// are not range-checked here — a corrupt or untrusted file may trigger very
    /// large allocations or later out-of-bounds panics. Trailing bytes after the
    /// last node are ignored.
    pub fn load(path: &str) -> std::io::Result<Self> {
        let mut f = BufReader::with_capacity(1024 * 1024, File::open(path)?);
        let mut magic = vec![0u8; MAGIC.len()];
        f.read_exact(&mut magic)?;
        if magic != MAGIC {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "bad magic",
            ));
        }

        // Reusable scratch buffers for fixed-width little-endian fields.
        let mut b8 = [0u8; 8];
        let mut b4 = [0u8; 4];

        f.read_exact(&mut b8)?;
        let max_order = i64::from_le_bytes(b8);
        f.read_exact(&mut b4)?;
        let use_eot = i32::from_le_bytes(b4) != 0;
        f.read_exact(&mut b4)?;
        let eot = u32::from_le_bytes(b4);
        f.read_exact(&mut b8)?;
        let seed = u64::from_le_bytes(b8);

        // NOTE(review): `eot` occupies 4 bytes on disk but is narrowed to `u8`
        // here; a stored value above 255 would be silently truncated — confirm.
        let mut m = RosaPlus::new(max_order, use_eot, eot as u8, seed);

        // SAM
        f.read_exact(&mut b4)?;
        let st_n = u32::from_le_bytes(b4) as usize;
        f.read_exact(&mut b4)?;
        let ed_n = u32::from_le_bytes(b4) as usize;
        f.read_exact(&mut b4)?;
        let text_n = u32::from_le_bytes(b4) as usize;

        // Replace the SAM with a fresh one built via `Sam::new(text_n)`, then
        // pre-size every buffer so the records below can be filled by index.
        // `boundary_after` has no stored length: it is one byte per text symbol.
        m.sam = Sam::new(text_n);
        m.sam.st.resize(st_n, SamState::default());
        m.sam.ed.resize(ed_n, SamEdge::default());
        m.sam.text.resize(text_n, 0u32);
        m.sam.boundary_after.resize(text_n, 0u8);

        // Per-state record: link, len, endpos, inline-transition count,
        // that many (ch, to) pairs, then head.
        for i in 0..st_n {
            f.read_exact(&mut b4)?;
            m.sam.st[i].link = i32::from_le_bytes(b4);
            f.read_exact(&mut b4)?;
            m.sam.st[i].len = i32::from_le_bytes(b4);
            f.read_exact(&mut b4)?;
            m.sam.st[i].endpos = i32::from_le_bytes(b4);
            f.read_exact(&mut b4)?;
            let sn = u32::from_le_bytes(b4) as usize;
            // Guard the fixed-size inline arrays before indexing them.
            if sn > SAM_SMALL_MAX {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "bad small_n",
                ));
            }
            m.sam.st[i].small_n = sn as u8;
            for k in 0..sn {
                f.read_exact(&mut b4)?;
                m.sam.st[i].small_ch[k] = u32::from_le_bytes(b4);
                f.read_exact(&mut b4)?;
                m.sam.st[i].small_to[k] = i32::from_le_bytes(b4);
            }
            f.read_exact(&mut b4)?;
            m.sam.st[i].head = i32::from_le_bytes(b4);
        }
        // Edge record: ch, to, next.
        for i in 0..ed_n {
            f.read_exact(&mut b4)?;
            m.sam.ed[i].ch = u32::from_le_bytes(b4);
            f.read_exact(&mut b4)?;
            m.sam.ed[i].to = i32::from_le_bytes(b4);
            f.read_exact(&mut b4)?;
            m.sam.ed[i].next = i32::from_le_bytes(b4);
        }
        read_u32_slice_le(&mut f, &mut m.sam.text)?;
        f.read_exact(&mut m.sam.boundary_after)?;

        // SAM cursor + prefix trace.
        f.read_exact(&mut b4)?;
        m.sam.last = i32::from_le_bytes(b4);
        f.read_exact(&mut b4)?;
        let text_states_n = u32::from_le_bytes(b4) as usize;
        // The trace holds one state per text prefix, including the empty one.
        if text_states_n != text_n + 1 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "bad text_states len",
            ));
        }
        m.sam.text_states.resize(text_states_n, 0);
        read_i32_slice_le(&mut f, &mut m.sam.text_states)?;
        // Every trace entry and the cursor must name an existing SAM state.
        for &v in &m.sam.text_states {
            if v < 0 || (v as usize) >= st_n {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "bad text_states entry",
                ));
            }
        }
        if m.sam.last < 0 || (m.sam.last as usize) >= st_n {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "bad sam.last",
            ));
        }

        // LM
        f.read_exact(&mut b4)?;
        let alpha_n = u32::from_le_bytes(b4) as usize;
        f.read_exact(&mut b8)?;
        let total_uni = u64::from_le_bytes(b8);
        f.read_exact(&mut b4)?;
        let nodes_n = u32::from_le_bytes(b4) as usize;

        // Alphabet and unigram tables are sized by `alpha_n`; the per-state LM
        // table has one entry per SAM state (its length is not stored).
        m.lm = LM::default();
        m.lm.alpha_n = alpha_n as u32;
        m.lm.total_uni = total_uni;
        m.lm.alphabet.resize(alpha_n, 0);
        m.lm.unigram.resize(alpha_n, 0);
        m.lm.ls = vec![
            LmState {
                head: -1,
                last_node: -1,
                ..LmState::default()
            };
            st_n
        ];
        m.lm.nodes.resize(nodes_n, CountNode::default());

        read_u32_slice_le(&mut f, &mut m.lm.alphabet)?;
        read_u64_slice_le(&mut f, &mut m.lm.unigram)?;
        // Per-state summary: head, total_n, types_t. `last_node`/`last_sym`
        // are not part of the file (`save` does not write them) and are reset.
        for i in 0..st_n {
            f.read_exact(&mut b4)?;
            m.lm.ls[i].head = i32::from_le_bytes(b4);
            f.read_exact(&mut b8)?;
            m.lm.ls[i].total_n = u64::from_le_bytes(b8);
            f.read_exact(&mut b4)?;
            m.lm.ls[i].types_t = u32::from_le_bytes(b4);
            m.lm.ls[i].last_node = -1;
            m.lm.ls[i].last_sym = 0;
        }
        // Count node: sym_idx, cnt, next.
        for i in 0..nodes_n {
            f.read_exact(&mut b4)?;
            m.lm.nodes[i].sym_idx = u32::from_le_bytes(b4);
            f.read_exact(&mut b8)?;
            m.lm.nodes[i].cnt = u64::from_le_bytes(b8);
            f.read_exact(&mut b4)?;
            m.lm.nodes[i].next = i32::from_le_bytes(b4);
        }

        // rebuild byte_map for lookups
        // The direct byte -> alphabet-index table is only enabled when every
        // alphabet codepoint is below 256; otherwise it stays all -1 and
        // `has_byte_map` remains false.
        m.lm.has_byte_map = false;
        m.lm.byte_map = [-1; 256];
        let mut max_cp = 0u32;
        for &v in &m.lm.alphabet {
            if v > max_cp {
                max_cp = v;
            }
        }
        if max_cp < 256 {
            m.lm.has_byte_map = true;
            for (i, &c) in m.lm.alphabet.iter().enumerate() {
                m.lm.byte_map[c as usize] = i as i16;
            }
        }

        m.lm_built = true;
        // One `dist` slot per alphabet symbol (presumably a scratch
        // distribution buffer — confirm against its other uses).
        m.dist.resize(alpha_n, 0.0);
        Ok(m)
    }
2147
2148    pub fn prob_for_last(&mut self, sym: u32) -> f64 {
2149        if !self.lm_built {
2150            self.build_lm();
2151        }
2152        let v = self.sam.last;
2153        let sym_idx = self.lm.find_sym(sym);
2154        let mo = if self.max_order < 0 {
2155            -1
2156        } else {
2157            self.max_order
2158        };
2159        self.lm.prob_for_sym(&self.sam, mo, v, sym_idx)
2160    }
2161}
2162
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test from rosa.md: after training on "ababa", generating from the
    /// prompt "a" must produce a non-empty continuation.
    #[test]
    fn rosa_md_example_basic() {
        let corpus = b"ababa";
        let mut model = RosaPlus::new(1048576, false, 4, 0);
        model.train_example(corpus);
        model.build_lm();
        let generated = model.generate(b"a", 10).unwrap();
        assert!(!generated.is_empty());
    }

    /// Training inside a transaction grows the SAM text and the unigram total
    /// by the example length; rolling back restores both exactly.
    #[test]
    fn tx_rollback_restores_sam_and_unigram_counts() {
        let mut model = RosaPlus::new(4, false, 0, 123);
        model.train_example(b"hello");
        model.build_lm_full_bytes_no_finalize_endpos();

        let text_before = model.sam.text.clone();
        let uni_before = model.lm.total_uni;
        assert!(!text_before.is_empty());

        let mut tx = model.begin_tx();
        model.train_example_tx(&mut tx, b"abc");
        assert_eq!(model.lm.total_uni, uni_before + 3);
        assert_eq!(model.sam.text.len(), text_before.len() + 3);

        model.rollback_tx(tx);
        assert_eq!(model.sam.text, text_before);
        assert_eq!(model.lm.total_uni, uni_before);
    }

    /// Restoring a checkpoint truncates the SAM's append-only buffers back to
    /// their snapshot contents and leaves the LM marked as not built.
    #[test]
    fn checkpoint_restore_reverts_append_only_buffers() {
        let mut model = RosaPlus::new(3, true, b'\n', 7);
        model.train_example(b"aaaa");

        let snapshot = model.checkpoint();
        let (text0, states0, boundary0, last0) = (
            model.sam.text.clone(),
            model.sam.text_states.clone(),
            model.sam.boundary_after.clone(),
            model.sam.last,
        );

        model.train_example(b"bbbb");
        assert_ne!(model.sam.text, text0);

        model.restore(&snapshot);
        assert_eq!(model.sam.text, text0);
        assert_eq!(model.sam.text_states, states0);
        assert_eq!(model.sam.boundary_after, boundary0);
        assert_eq!(model.sam.last, last0);
        assert!(!model.lm_built);
    }
}