1#![allow(clippy::needless_range_loop)]
17
18use std::fs::File;
19use std::io::{BufReader, BufWriter, Read, Write};
20
/// Number of transitions stored inline in each `SamState` before spilling
/// into the overflow edge list.
const SAM_SMALL_MAX: usize = 4;
/// Format tag (NUL-terminated byte string).
/// NOTE(review): not referenced in the code visible here; presumably used by save/load code elsewhere.
const MAGIC: &[u8] = b"rosa_pb_v4\0";

/// Size of the full byte alphabet (every `u8` value).
const BYTE_ALPHA_N: usize = 1 << 8;
29
/// Writes every element of `xs` to `w` as a 4-byte little-endian integer.
///
/// On little-endian targets the slice's memory is written in one call;
/// otherwise each value is converted and written individually.
#[inline(always)]
fn write_u32_slice_le<W: Write>(w: &mut W, xs: &[u32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        for value in xs.iter().copied() {
            w.write_all(&value.to_le_bytes())?;
        }
        return Ok(());
    }
    let byte_len = xs.len() * std::mem::size_of::<u32>();
    // SAFETY: the pointer and length describe exactly the memory of `xs`,
    // which stays borrowed for the lifetime of `raw`; `u8` has alignment 1
    // and every byte of a `u32` is initialized.
    let raw: &[u8] = unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), byte_len) };
    w.write_all(raw)
}
44
/// Writes every element of `xs` to `w` as a 4-byte little-endian
/// two's-complement integer.
#[inline(always)]
fn write_i32_slice_le<W: Write>(w: &mut W, xs: &[i32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        for value in xs.iter().copied() {
            w.write_all(&value.to_le_bytes())?;
        }
        return Ok(());
    }
    let byte_len = xs.len() * std::mem::size_of::<i32>();
    // SAFETY: the pointer and length cover exactly the initialized memory of
    // `xs`, borrowed for the lifetime of `raw`; `u8` has alignment 1.
    let raw: &[u8] = unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), byte_len) };
    w.write_all(raw)
}
59
/// Writes every element of `xs` to `w` as an 8-byte little-endian integer.
#[inline(always)]
fn write_u64_slice_le<W: Write>(w: &mut W, xs: &[u64]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        for value in xs.iter().copied() {
            w.write_all(&value.to_le_bytes())?;
        }
        return Ok(());
    }
    let byte_len = xs.len() * std::mem::size_of::<u64>();
    // SAFETY: the pointer and length cover exactly the initialized memory of
    // `xs`, borrowed for the lifetime of `raw`; `u8` has alignment 1.
    let raw: &[u8] = unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<u8>(), byte_len) };
    w.write_all(raw)
}
74
/// Fills `xs` with 4-byte little-endian integers read from `r`.
///
/// Fails with the underlying `read_exact` error (e.g. `UnexpectedEof`) if
/// `r` runs out of data; `xs` may then be partially overwritten.
#[inline(always)]
fn read_u32_slice_le<R: Read>(r: &mut R, xs: &mut [u32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        let mut word = [0u8; 4];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = u32::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = xs.len() * std::mem::size_of::<u32>();
    // SAFETY: the pointer and length cover exactly the memory of `xs`, which
    // is exclusively borrowed for the lifetime of `raw`; any byte pattern is a
    // valid `u32`, so arbitrary writes through `raw` are sound.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
91
/// Fills `xs` with 4-byte little-endian two's-complement integers read
/// from `r`; propagates `read_exact` errors.
#[inline(always)]
fn read_i32_slice_le<R: Read>(r: &mut R, xs: &mut [i32]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        let mut word = [0u8; 4];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = i32::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = xs.len() * std::mem::size_of::<i32>();
    // SAFETY: the pointer and length cover exactly the memory of `xs`, which
    // is exclusively borrowed for the lifetime of `raw`; any byte pattern is a
    // valid `i32`.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
108
/// Fills `xs` with 8-byte little-endian integers read from `r`; propagates
/// `read_exact` errors.
#[inline(always)]
fn read_u64_slice_le<R: Read>(r: &mut R, xs: &mut [u64]) -> std::io::Result<()> {
    if !cfg!(target_endian = "little") {
        let mut word = [0u8; 8];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = u64::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = xs.len() * std::mem::size_of::<u64>();
    // SAFETY: the pointer and length cover exactly the memory of `xs`, which
    // is exclusively borrowed for the lifetime of `raw`; any byte pattern is a
    // valid `u64`.
    let raw: &mut [u8] =
        unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr().cast::<u8>(), byte_len) };
    r.read_exact(raw)
}
125
126#[derive(Clone, Copy, Default)]
127struct SamState {
128 link: i32,
129 len: i32,
130 endpos: i32,
131 head: i32,
132
133 small_ch: [u32; SAM_SMALL_MAX],
134 small_to: [i32; SAM_SMALL_MAX],
135 small_n: u8,
136}
137
/// Overflow transition of a SAM state: a singly linked list node in `Sam::ed`.
#[derive(Default, Clone, Copy)]
struct SamEdge {
    /// Transition symbol.
    ch: u32,
    /// Target state index.
    to: i32,
    /// Index of the next edge of the same state, or -1 at the end of the list.
    next: i32,
}
144
145#[derive(Clone, Default)]
146struct Sam {
147 st: Vec<SamState>,
148 ed: Vec<SamEdge>,
149 last: i32,
150
151 text: Vec<u32>,
152 text_states: Vec<i32>,
153 boundary_after: Vec<u8>,
154}
155
impl Sam {
    /// Creates an automaton holding only the root state (index 0).
    ///
    /// `expected_chars` is a capacity hint: when non-zero, the state, edge and
    /// text buffers are pre-reserved proportionally to it; when zero, fixed
    /// default capacities are used.
    fn new(expected_chars: usize) -> Self {
        let mut s = Sam {
            st: Vec::new(),
            ed: Vec::new(),
            last: 0,
            text: Vec::new(),
            text_states: Vec::new(),
            boundary_after: Vec::new(),
        };

        let st_cap = if expected_chars > 0 {
            expected_chars * 2 + 16
        } else {
            1024
        };
        let ed_cap = if expected_chars > 0 {
            expected_chars * 3 + 16
        } else {
            2048
        };
        let text_cap = if expected_chars > 0 {
            expected_chars + 16
        } else {
            1024
        };
        s.st.reserve(st_cap);
        s.ed.reserve(ed_cap);
        s.text.reserve(text_cap);
        s.text_states.reserve(text_cap);
        s.boundary_after.reserve(text_cap);

        // Root: no suffix link, empty string, no end position, no transitions.
        let mut root = SamState::default();
        root.link = -1;
        root.len = 0;
        root.endpos = -1;
        root.small_n = 0;
        root.head = -1;
        s.st.push(root);
        // text_states[0] is the root; `feed` appends one entry per symbol.
        s.text_states.push(0);
        s
    }

    /// Returns the target of the transition from state `v` on `ch`, or -1.
    ///
    /// Checks the inline transitions first, then walks the overflow list.
    #[inline(always)]
    fn get_edge(&self, v: i32, ch: u32) -> i32 {
        // SAFETY (unchecked): `v` must be a valid, non-negative state index.
        // NOTE(review): confirm every caller guarantees this; nothing checks it here.
        let st = unsafe { self.st.get_unchecked(v as usize) };
        for i in 0..(st.small_n as usize) {
            if st.small_ch[i] == ch {
                return st.small_to[i];
            }
        }
        let mut ei = st.head;
        while ei != -1 {
            // SAFETY (unchecked): relies on every `head`/`next` link pointing into `ed`.
            let e = unsafe { self.ed.get_unchecked(ei as usize) };
            if e.ch == ch {
                return e.to;
            }
            ei = e.next;
        }
        -1
    }

    /// Prepends an overflow edge `v --ch--> to` to `v`'s edge list.
    /// Does not check for an existing transition on `ch`.
    #[inline(always)]
    fn add_edge(&mut self, v: i32, ch: u32, to: i32) {
        let idx = self.ed.len() as i32;
        let head = self.st[v as usize].head;
        self.ed.push(SamEdge { ch, to, next: head });
        self.st[v as usize].head = idx;
    }

    /// Adds a transition the caller knows is absent: inline while there is
    /// room (`SAM_SMALL_MAX`), otherwise as an overflow edge.
    #[inline(always)]
    fn add_edge_absent(&mut self, v: i32, ch: u32, to: i32) {
        let st = &mut self.st[v as usize];
        if (st.small_n as usize) < SAM_SMALL_MAX {
            let i = st.small_n as usize;
            st.small_n += 1;
            st.small_ch[i] = ch;
            st.small_to[i] = to;
        } else {
            self.add_edge(v, ch, to);
        }
    }

    /// Redirects the transition `v --ch--> old_to` to `new_to`.
    /// Returns `true` if such a transition was found and changed.
    #[inline(always)]
    fn replace_edge_to(&mut self, v: i32, ch: u32, old_to: i32, new_to: i32) -> bool {
        {
            let st = &mut self.st[v as usize];
            for i in 0..(st.small_n as usize) {
                if st.small_ch[i] == ch && st.small_to[i] == old_to {
                    st.small_to[i] = new_to;
                    return true;
                }
            }
        }
        let mut ei = self.st[v as usize].head;
        while ei != -1 {
            let e = &mut self.ed[ei as usize];
            if e.ch == ch && e.to == old_to {
                e.to = new_to;
                return true;
            }
            ei = e.next;
        }
        false
    }

    /// Gives `dst` its own copy of `src`'s overflow edges.
    ///
    /// Used after copying a `SamState` by value: the inline arrays are copied
    /// with it, but `head` would otherwise share `src`'s list.
    fn clone_overflow_edges(&mut self, src: i32, dst: i32) {
        self.st[dst as usize].head = -1;
        let mut ei = self.st[src as usize].head;
        while ei != -1 {
            let e = self.ed[ei as usize];
            self.add_edge(dst, e.ch, e.to);
            ei = e.next;
        }
    }

    /// Appends `ch` to the text and extends the automaton from `last`
    /// (online suffix-automaton construction with state cloning).
    ///
    /// Also appends a 0 to `boundary_after`, records the new state in
    /// `text_states`, and sets the new state's `endpos` to this position.
    fn feed(&mut self, ch: u32) {
        let i = self.text.len() as i32;
        self.text.push(ch);
        self.boundary_after.push(0);

        // New state r for the extended string.
        let g = self.last;
        let r = self.st.len() as i32;
        let mut st_r = SamState::default();
        st_r.link = 0;
        st_r.len = self.st[g as usize].len + 1;
        st_r.endpos = i;
        st_r.small_n = 0;
        st_r.head = -1;
        self.st.push(st_r);

        // Walk suffix links adding `ch`-transitions to r until one already exists.
        let mut p = g;
        let mut q;
        while p != -1 {
            q = self.get_edge(p, ch);
            if q != -1 {
                break;
            }
            self.add_edge_absent(p, ch, r);
            p = self.st[p as usize].link;
        }

        if p == -1 {
            self.st[r as usize].link = 0;
        } else {
            q = self.get_edge(p, ch);
            if self.st[p as usize].len + 1 == self.st[q as usize].len {
                self.st[r as usize].link = q;
            } else {
                // Clone q into u with len(p) + 1, redirect p's chain from q to u,
                // and hang both q and r below u.
                let u = self.st.len() as i32;
                let mut st_u = self.st[q as usize];
                st_u.len = self.st[p as usize].len + 1;
                self.st.push(st_u);
                self.clone_overflow_edges(q, u);
                while p != -1 && self.replace_edge_to(p, ch, q, u) {
                    p = self.st[p as usize].link;
                }
                self.st[q as usize].link = u;
                self.st[r as usize].link = u;
            }
        }

        self.last = r;
        self.text_states.push(r);

        // NOTE(review): st[r].endpos was set to `i` above, so this condition is
        // false on the first iteration and the loop never runs; ancestors'
        // endpos only catch up in `finalize_endpos`. If online propagation was
        // intended it would need to start at st[r].link (and would then walk
        // the whole suffix-link chain per symbol) — confirm intent.
        let mut v = r;
        while v != -1 && self.st[v as usize].endpos < i {
            self.st[v as usize].endpos = i;
            v = self.st[v as usize].link;
        }
    }

    /// Ends the current example: flags the last text position in
    /// `boundary_after` and resets `last` to the root, so the next symbol
    /// starts a fresh path (the text buffer itself keeps growing).
    fn mark_boundary(&mut self) {
        if !self.text.is_empty() {
            let i = self.text.len() - 1;
            self.boundary_after[i] = 1;
        }
        self.last = 0;
    }

    /// Sets every state's `endpos` to the maximum over itself and its
    /// suffix-link descendants.
    ///
    /// States are counting-sorted by `len` and processed longest first, so
    /// each child is final before it is folded into its parent.
    fn finalize_endpos(&mut self) {
        let mut max_len: usize = 0;
        for v in 0..self.st.len() {
            let l = self.st[v].len as usize;
            if l > max_len {
                max_len = l;
            }
        }

        let mut cnt = vec![0usize; max_len + 1];
        for v in 0..self.st.len() {
            cnt[self.st[v].len as usize] += 1;
        }
        let mut pos = vec![0usize; max_len + 1];
        let mut acc = 0usize;
        for l in 0..=max_len {
            pos[l] = acc;
            acc += cnt[l];
        }
        let mut order = vec![0u32; self.st.len()];
        for v in 0..self.st.len() {
            let l = self.st[v].len as usize;
            let idx = pos[l];
            order[idx] = v as u32;
            pos[l] += 1;
        }

        for oi in (0..order.len()).rev() {
            let v = order[oi] as usize;
            let p = self.st[v].link;
            if p >= 0 {
                let p = p as usize;
                if self.st[v].endpos > self.st[p].endpos {
                    self.st[p].endpos = self.st[v].endpos;
                }
            }
        }
    }

    /// Follows the transition on `ch` from `v`, falling back along suffix
    /// links; returns the root (0) if no state on the chain has one.
    /// The final root lookup repeats the check the loop already made at the root.
    #[inline(always)]
    fn advance(&self, mut v: i32, ch: u32) -> i32 {
        let mut to;
        loop {
            to = self.get_edge(v, ch);
            if to != -1 {
                return to;
            }
            v = self.st[v as usize].link;
            if v == -1 {
                break;
            }
        }
        to = self.get_edge(0, ch);
        if to == -1 { 0 } else { to }
    }

    /// Deterministic next-symbol guess for context `v`.
    ///
    /// Walks up the suffix links and returns `text[endpos + 1]` for the first
    /// state with `len > 0` whose `endpos + 1` lies inside the text and whose
    /// `endpos` is not the last position of an example. Returns `None` if no
    /// state qualifies.
    #[inline(always)]
    fn predict_det(&self, v: i32) -> Option<u32> {
        let mut u = v;
        while u != -1 {
            // SAFETY (unchecked): `u` is `v` or a suffix link, assumed valid.
            let st = unsafe { self.st.get_unchecked(u as usize) };
            let i = st.endpos;
            let j = i + 1;
            if st.len > 0 && j >= 0 && (j as usize) < self.text.len() {
                // Do not predict across an example boundary.
                if i >= 0
                    && (i as usize) < self.boundary_after.len()
                    && self.boundary_after[i as usize] != 0
                {
                    u = st.link;
                    continue;
                }
                return Some(self.text[j as usize]);
            }
            u = st.link;
        }
        None
    }

    /// Starts a transaction: snapshots `last` and all buffer lengths; the
    /// change logs start empty.
    fn begin_tx(&self) -> SamTx {
        SamTx {
            old_last: self.last,
            old_text_len: self.text.len(),
            old_text_states_len: self.text_states.len(),
            old_boundary_len: self.boundary_after.len(),
            old_st_len: self.st.len(),
            old_ed_len: self.ed.len(),
            st_changes: Vec::new(),
            ed_changes: Vec::new(),
        }
    }

    /// Undoes everything recorded in `tx`.
    ///
    /// Logged edges and states are restored in reverse order (so the oldest
    /// snapshot of an index wins), then every buffer is truncated back to its
    /// recorded length and `last` is restored.
    /// NOTE(review): `boundary_after` writes are not logged; they are undone only
    /// by truncation, i.e. when the flagged position was appended inside `tx`.
    fn rollback_tx(&mut self, tx: SamTx) {
        for (idx, old) in tx.ed_changes.into_iter().rev() {
            if idx < self.ed.len() {
                self.ed[idx] = old;
            }
        }
        for (idx, old) in tx.st_changes.into_iter().rev() {
            if idx < self.st.len() {
                self.st[idx] = old;
            }
        }

        self.st.truncate(tx.old_st_len);
        self.ed.truncate(tx.old_ed_len);
        self.text.truncate(tx.old_text_len);
        self.text_states.truncate(tx.old_text_states_len);
        self.boundary_after.truncate(tx.old_boundary_len);
        self.last = tx.old_last;
    }

    /// Logs the current value of `st[idx]`; call before mutating it.
    #[inline(always)]
    fn record_state_change(&self, tx: &mut SamTx, idx: usize) {
        tx.st_changes.push((idx, self.st[idx]));
    }

    /// Logs the current value of `ed[idx]`; call before mutating it.
    #[inline(always)]
    fn record_edge_change(&self, tx: &mut SamTx, idx: usize) {
        tx.ed_changes.push((idx, self.ed[idx]));
    }

    /// Transactional `add_edge`: logs `st[v]` before updating its `head`.
    /// The pushed edge itself is removed by truncation on rollback.
    #[inline(always)]
    fn add_edge_tx(&mut self, tx: &mut SamTx, v: i32, ch: u32, to: i32) {
        let idx = self.ed.len() as i32;
        let head = self.st[v as usize].head;
        self.ed.push(SamEdge { ch, to, next: head });
        self.record_state_change(tx, v as usize);
        self.st[v as usize].head = idx;
    }

    /// Transactional `add_edge_absent`: logs `st[v]` before an inline insert;
    /// otherwise defers to `add_edge_tx`.
    #[inline(always)]
    fn add_edge_absent_tx(&mut self, tx: &mut SamTx, v: i32, ch: u32, to: i32) {
        let v_usize = v as usize;
        let small_n = self.st[v_usize].small_n as usize;
        if small_n < SAM_SMALL_MAX {
            let i = small_n;
            self.record_state_change(tx, v_usize);
            let st = &mut self.st[v_usize];
            st.small_ch[i] = ch;
            st.small_to[i] = to;
            st.small_n += 1;
        } else {
            self.add_edge_tx(tx, v, ch, to);
        }
    }

    /// Transactional `replace_edge_to`: logs the state (inline hit) or the
    /// edge (overflow hit) before redirecting it.
    #[inline(always)]
    fn replace_edge_to_tx(
        &mut self,
        tx: &mut SamTx,
        v: i32,
        ch: u32,
        old_to: i32,
        new_to: i32,
    ) -> bool {
        {
            let st = &self.st[v as usize];
            for i in 0..(st.small_n as usize) {
                if st.small_ch[i] == ch && st.small_to[i] == old_to {
                    self.record_state_change(tx, v as usize);
                    self.st[v as usize].small_to[i] = new_to;
                    return true;
                }
            }
        }
        let mut ei = self.st[v as usize].head;
        while ei != -1 {
            let eidx = ei as usize;
            let e = self.ed[eidx];
            if e.ch == ch && e.to == old_to {
                self.record_edge_change(tx, eidx);
                self.ed[eidx].to = new_to;
                return true;
            }
            ei = e.next;
        }
        false
    }

    /// Transactional `clone_overflow_edges`.
    fn clone_overflow_edges_tx(&mut self, tx: &mut SamTx, src: i32, dst: i32) {
        self.record_state_change(tx, dst as usize);
        self.st[dst as usize].head = -1;
        let mut ei = self.st[src as usize].head;
        while ei != -1 {
            let e = self.ed[ei as usize];
            self.add_edge_tx(tx, dst, e.ch, e.to);
            ei = e.next;
        }
    }

    /// Transactional `feed`: same construction, but every pre-existing state
    /// or edge is logged in `tx` before it is mutated. Writes to freshly pushed
    /// states (`r`, `u`) need no log because rollback truncates them away.
    fn feed_tx(&mut self, tx: &mut SamTx, ch: u32) {
        let i = self.text.len() as i32;
        self.text.push(ch);
        self.boundary_after.push(0);

        let g = self.last;
        let r = self.st.len() as i32;
        let mut st_r = SamState::default();
        st_r.link = 0;
        st_r.len = self.st[g as usize].len + 1;
        st_r.endpos = i;
        st_r.small_n = 0;
        st_r.head = -1;
        self.st.push(st_r);

        let mut p = g;
        let mut q;
        while p != -1 {
            q = self.get_edge(p, ch);
            if q != -1 {
                break;
            }
            self.add_edge_absent_tx(tx, p, ch, r);
            p = self.st[p as usize].link;
        }

        if p == -1 {
            self.st[r as usize].link = 0;
        } else {
            q = self.get_edge(p, ch);
            if self.st[p as usize].len + 1 == self.st[q as usize].len {
                self.st[r as usize].link = q;
            } else {
                let u = self.st.len() as i32;
                let mut st_u = self.st[q as usize];
                st_u.len = self.st[p as usize].len + 1;
                self.st.push(st_u);
                self.clone_overflow_edges_tx(tx, q, u);
                while p != -1 && self.replace_edge_to_tx(tx, p, ch, q, u) {
                    p = self.st[p as usize].link;
                }
                // q pre-exists, so its link change must be logged.
                self.record_state_change(tx, q as usize);
                self.st[q as usize].link = u;
                self.st[r as usize].link = u;
            }
        }

        self.last = r;
        self.text_states.push(r);

        // NOTE(review): as in `feed`, st[r].endpos == i here, so this loop never
        // executes — confirm whether propagation from st[r].link was intended.
        let mut v = r;
        while v != -1 && self.st[v as usize].endpos < i {
            self.record_state_change(tx, v as usize);
            self.st[v as usize].endpos = i;
            v = self.st[v as usize].link;
        }
    }

    /// Transactional `mark_boundary`. Logs nothing: `last` is restored from
    /// `old_last`, and the boundary flag is removed by truncation when its
    /// position was appended in this transaction.
    fn mark_boundary_tx(&mut self, tx: &mut SamTx) {
        if !self.text.is_empty() {
            let i = self.text.len() - 1;
            self.boundary_after[i] = 1;
        }
        self.last = 0;
        let _ = tx;
    }
}
604
605#[derive(Clone)]
606struct SamTx {
607 old_last: i32,
608 old_text_len: usize,
609 old_text_states_len: usize,
610 old_boundary_len: usize,
611 old_st_len: usize,
612 old_ed_len: usize,
613 st_changes: Vec<(usize, SamState)>,
614 ed_changes: Vec<(usize, SamEdge)>,
615}
616
/// Continuation counts for one SAM state (context).
#[derive(Default, Clone, Copy)]
struct LmState {
    /// Head of this context's count list in `LM::nodes`, or -1 if empty.
    head: i32,
    /// Sum of all counts in the list.
    total_n: u64,
    /// Number of distinct symbols in the list.
    types_t: u32,

    /// Symbol of the most recently incremented node (lookup cache).
    last_sym: u32,
    /// Index of the most recently incremented node, or -1.
    last_node: i32,
}
626
/// One `(symbol, count)` entry of a context's singly linked count list.
#[derive(Default, Clone, Copy)]
struct CountNode {
    /// Alphabet index of the symbol.
    sym_idx: u32,
    /// Occurrence count.
    cnt: u64,
    /// Next node in the same context's list, or -1.
    next: i32,
}
633
634#[derive(Clone)]
635struct LM {
636 alphabet: Vec<u32>,
637 unigram: Vec<u64>,
638 alpha_n: u32,
639 total_uni: u64,
640
641 has_byte_map: bool,
642 byte_map: [i16; 256],
643
644 ls: Vec<LmState>,
645 nodes: Vec<CountNode>,
646}
647
648impl Default for LM {
649 fn default() -> Self {
650 LM {
651 alphabet: Vec::new(),
652 unigram: Vec::new(),
653 alpha_n: 0,
654 total_uni: 0,
655 has_byte_map: false,
656 byte_map: [-1; 256],
657 ls: Vec::new(),
658 nodes: Vec::new(),
659 }
660 }
661}
662
impl LM {
    /// Derives the alphabet and unigram counts from `sam.text`.
    ///
    /// If every symbol is < 256, bytes are counted in a 256-entry table, only
    /// present bytes are kept (ascending), and `byte_map` is filled. An empty
    /// text yields the single-symbol alphabet `'\n'` with count 1. Otherwise
    /// the alphabet is the sorted, de-duplicated text and `byte_map` stays unused.
    fn build_alphabet(&mut self, sam: &Sam) {
        self.has_byte_map = false;
        self.byte_map = [-1; 256];

        let mut max_cp = 0u32;
        for &v in &sam.text {
            if v > max_cp {
                max_cp = v;
            }
        }

        if max_cp < 256 {
            let mut counts = [0u64; 256];
            for &v in &sam.text {
                counts[v as usize] += 1;
            }
            let mut uniq = 0usize;
            for c in 0..256 {
                if counts[c] != 0 {
                    uniq += 1;
                }
            }

            // No text at all: fall back to a one-symbol alphabet.
            if uniq == 0 {
                self.alphabet = vec![b'\n' as u32];
                self.unigram = vec![1];
                self.alpha_n = 1;
                self.total_uni = 1;
                self.has_byte_map = true;
                self.byte_map[b'\n' as usize] = 0;
                return;
            }

            self.alphabet = Vec::with_capacity(uniq);
            self.unigram = Vec::with_capacity(uniq);
            self.total_uni = 0;
            for c in 0..256u32 {
                let cnt = counts[c as usize];
                if cnt == 0 {
                    continue;
                }
                self.alphabet.push(c);
                self.unigram.push(cnt);
                self.total_uni += cnt;
            }
            self.alpha_n = self.alphabet.len() as u32;
            self.has_byte_map = true;
            for (i, &c) in self.alphabet.iter().enumerate() {
                self.byte_map[c as usize] = i as i16;
            }
            return;
        }

        // General path: symbols >= 256 are present.
        let mut tmp = sam.text.clone();
        tmp.sort_unstable();
        tmp.dedup();
        if tmp.is_empty() {
            tmp.push(b'\n' as u32);
        }
        self.alphabet = tmp;
        self.alpha_n = self.alphabet.len() as u32;
        self.unigram = vec![0u64; self.alphabet.len()];
        self.total_uni = 0;
        for &ch in &sam.text {
            if let Ok(i) = self.alphabet.binary_search(&ch) {
                self.unigram[i] += 1;
                self.total_uni += 1;
            }
        }
        // Defensive: keep the unigram distribution non-degenerate.
        if self.total_uni == 0 {
            self.unigram[0] = 1;
            self.total_uni = 1;
        }
    }

    /// Maps a raw symbol to its alphabet index, or -1 if it is not in the
    /// alphabet. Uses the byte table when available.
    #[inline(always)]
    fn find_sym(&self, ch: u32) -> i32 {
        if self.has_byte_map && ch < 256 {
            return self.byte_map[ch as usize] as i32;
        }
        match self.alphabet.binary_search(&ch) {
            Ok(i) => i as i32,
            Err(_) => -1,
        }
    }

    /// Adds `add` to the count of `sym_idx` in context `state`.
    ///
    /// Tries the cached `last_node` first, then walks the context's list, and
    /// finally prepends a new node (incrementing `types_t`).
    #[inline(always)]
    fn inc(&mut self, state: u32, sym_idx: u32, add: u64) {
        let ls = &mut self.ls[state as usize];
        let last = ls.last_node;
        if last != -1 && self.nodes[last as usize].sym_idx == sym_idx {
            self.nodes[last as usize].cnt += add;
            ls.total_n += add;
            return;
        }

        let mut ni = ls.head;
        while ni != -1 {
            let node = &mut self.nodes[ni as usize];
            if node.sym_idx == sym_idx {
                node.cnt += add;
                ls.total_n += add;
                ls.last_node = ni;
                ls.last_sym = sym_idx;
                return;
            }
            ni = node.next;
        }

        let idx = self.nodes.len() as i32;
        self.nodes.push(CountNode {
            sym_idx,
            cnt: add,
            next: ls.head,
        });
        ls.head = idx;
        ls.total_n += add;
        ls.types_t += 1;
        ls.last_node = idx;
        ls.last_sym = sym_idx;
    }

    /// Rebuilds all per-state continuation counts from scratch.
    ///
    /// Pass 1: each example segment (delimited by `boundary_after`) is re-walked
    /// through the SAM from the root; `text[i + 1]` is counted in the state
    /// reached after `text[i]`, moved up the suffix links until its `len` is
    /// at most `max_order` (when `max_order >= 0`).
    /// Pass 2: states are counting-sorted by `len` and, longest first, each
    /// state's counts are added into its suffix-link parent, so a context ends
    /// up holding the counts of all its descendants as well.
    fn build_counts(&mut self, sam: &Sam, max_order: i64) {
        self.ls = vec![
            LmState {
                head: -1,
                last_node: -1,
                ..LmState::default()
            };
            sam.st.len()
        ];
        self.nodes.clear();

        let mut seg_start = 0usize;
        while seg_start < sam.text.len() {
            // seg_end is one past the position flagged as an example end
            // (or the end of the text).
            let mut seg_end = seg_start;
            while seg_end < sam.text.len() {
                let b = sam.boundary_after[seg_end];
                seg_end += 1;
                if b != 0 {
                    break;
                }
            }
            if seg_end - seg_start >= 2 {
                let mut v = 0i32;
                for i in seg_start..(seg_end - 1) {
                    let ch = sam.text[i];
                    v = sam.advance(v, ch);
                    let mut ctx = v;
                    if max_order >= 0 {
                        while ctx != -1 && (sam.st[ctx as usize].len as i64) > max_order {
                            ctx = sam.st[ctx as usize].link;
                        }
                        if ctx == -1 {
                            ctx = 0;
                        }
                    }
                    let nxt = sam.text[i + 1];
                    let si = self.find_sym(nxt);
                    if si >= 0 {
                        self.inc(ctx as u32, si as u32, 1);
                    }
                }
            }
            seg_start = seg_end;
        }

        // Counting sort of states by len.
        let mut max_len: usize = 0;
        for st in &sam.st {
            let l = st.len as usize;
            if l > max_len {
                max_len = l;
            }
        }
        let mut cnt = vec![0usize; max_len + 1];
        for st in &sam.st {
            cnt[st.len as usize] += 1;
        }
        let mut pos = vec![0usize; max_len + 1];
        let mut acc = 0usize;
        for l in 0..=max_len {
            pos[l] = acc;
            acc += cnt[l];
        }
        let mut order = vec![0u32; sam.st.len()];
        for (v, st) in sam.st.iter().enumerate() {
            let l = st.len as usize;
            let idx = pos[l];
            order[idx] = v as u32;
            pos[l] += 1;
        }

        // Longest first: a child's counts are complete before being pushed up.
        for oi in (0..order.len()).rev() {
            let v = order[oi] as usize;
            let p = sam.st[v].link;
            if p < 0 {
                continue;
            }
            if self.ls[v].total_n == 0 {
                continue;
            }
            let mut ni = self.ls[v].head;
            while ni != -1 {
                let node = self.nodes[ni as usize];
                self.inc(p as u32, node.sym_idx, node.cnt);
                ni = node.next;
            }
        }
    }

    /// Interpolated probability of `sym_idx` after context state `v`.
    ///
    /// Walks from `v` up the suffix links, skipping states longer than
    /// `max_order` (when non-negative). Each state with counts contributes
    /// `residual * lam * cnt / n` with `lam = n / (n + t)` (n = total count,
    /// t = distinct symbols; a Witten–Bell-style weight), and `residual` shrinks
    /// by `1 - lam`. Leftover mass goes to the unigram distribution (or uniform
    /// if it is empty). The result is clamped to `[1e-12, 1.0]`. A negative
    /// `sym_idx` (symbol outside the alphabet) gets `1 / alpha_n`.
    fn prob_for_sym(&self, sam: &Sam, max_order: i64, v: i32, sym_idx: i32) -> f64 {
        if sym_idx < 0 {
            return 1.0 / (self.alpha_n.max(1) as f64);
        }
        let sym_idx = sym_idx as u32;
        let mut p_accum = 0.0f64;
        let mut residual = 1.0f64;
        let mut u = v;

        while u != -1 {
            if !(max_order >= 0 && (sam.st[u as usize].len as i64) > max_order) {
                let n = self.ls[u as usize].total_n;
                let t = self.ls[u as usize].types_t;
                if n > 0 {
                    let lam = if t > 0 {
                        (n as f64) / ((n + (t as u64)) as f64)
                    } else {
                        1.0
                    };

                    let scale = residual * lam;

                    let mut count_for_sym = 0u64;
                    let mut ni = self.ls[u as usize].head;
                    while ni != -1 {
                        let node = self.nodes[ni as usize];
                        if node.sym_idx == sym_idx {
                            count_for_sym = node.cnt;
                            break;
                        }
                        ni = node.next;
                    }

                    if count_for_sym > 0 {
                        p_accum += scale * (count_for_sym as f64 / n as f64);
                    }

                    residual *= 1.0 - lam;
                }
            }
            u = sam.st[u as usize].link;
        }

        if self.total_uni > 0 && residual > 0.0 {
            let p_uni = self.unigram[sym_idx as usize] as f64 / self.total_uni as f64;
            p_accum += residual * p_uni;
        } else if residual > 0.0 {
            p_accum += residual * (1.0 / self.alpha_n.max(1) as f64);
        }

        p_accum.clamp(1e-12, 1.0)
    }

    /// Fills `out[..alpha_n]` with the full interpolated distribution after
    /// state `v` (same mixing as `prob_for_sym`), then normalizes it; if the
    /// total is zero, the distribution becomes uniform.
    ///
    /// `out` must have at least `alpha_n` entries; the whole slice is zeroed first.
    fn probs_for_state(&self, sam: &Sam, max_order: i64, v: i32, out: &mut [f64]) {
        out.fill(0.0);
        let mut residual = 1.0f64;
        let mut u = v;
        while u != -1 {
            if !(max_order >= 0 && (sam.st[u as usize].len as i64) > max_order) {
                let n = self.ls[u as usize].total_n;
                let t = self.ls[u as usize].types_t;
                if n > 0 {
                    let lam = if t > 0 {
                        (n as f64) / ((n + (t as u64)) as f64)
                    } else {
                        1.0
                    };
                    let scale = residual * lam;
                    let inv_n = 1.0 / (n as f64);
                    let mut ni = self.ls[u as usize].head;
                    while ni != -1 {
                        let node = self.nodes[ni as usize];
                        out[node.sym_idx as usize] += scale * ((node.cnt as f64) * inv_n);
                        ni = node.next;
                    }
                    residual *= 1.0 - lam;
                }
            }
            u = sam.st[u as usize].link;
        }

        if self.total_uni > 0 && residual > 0.0 {
            let inv = 1.0 / (self.total_uni as f64);
            for i in 0..(self.alpha_n as usize) {
                out[i] += residual * ((self.unigram[i] as f64) * inv);
            }
        }

        let mut s = 0.0;
        for i in 0..(self.alpha_n as usize) {
            s += out[i];
        }
        if s > 0.0 {
            let invs = 1.0 / s;
            for i in 0..(self.alpha_n as usize) {
                out[i] *= invs;
            }
        } else {
            let uprob = 1.0 / (self.alpha_n.max(1) as f64);
            for i in 0..(self.alpha_n as usize) {
                out[i] = uprob;
            }
        }
    }

    /// Transactional `inc`: logs `ls[state]` (on every call) and any existing
    /// node before mutating it. A newly pushed node needs no log because
    /// rollback truncates `nodes`.
    #[inline(always)]
    fn inc_tx(&mut self, tx: &mut LmTx, state: u32, sym_idx: u32, add: u64) {
        let si = state as usize;
        tx.ls_changes.push((si, self.ls[si]));

        let ls = &mut self.ls[si];
        let last = ls.last_node;
        if last != -1 && self.nodes[last as usize].sym_idx == sym_idx {
            let ni = last as usize;
            tx.node_changes.push((ni, self.nodes[ni]));
            self.nodes[ni].cnt += add;
            ls.total_n += add;
            return;
        }

        let mut ni = ls.head;
        while ni != -1 {
            let idx = ni as usize;
            if self.nodes[idx].sym_idx == sym_idx {
                tx.node_changes.push((idx, self.nodes[idx]));
                self.nodes[idx].cnt += add;
                ls.total_n += add;
                ls.last_node = ni;
                ls.last_sym = sym_idx;
                return;
            }
            ni = self.nodes[idx].next;
        }

        let idx = self.nodes.len() as i32;
        // Keeps the truncation point at or below the first node created in this
        // transaction (a no-op when it was captured by begin_tx).
        tx.old_nodes_len = tx.old_nodes_len.min(self.nodes.len());
        self.nodes.push(CountNode {
            sym_idx,
            cnt: add,
            next: ls.head,
        });
        ls.head = idx;
        ls.total_n += add;
        ls.types_t += 1;
        ls.last_node = idx;
        ls.last_sym = sym_idx;
    }
}
1029
1030#[derive(Clone)]
1031struct LmTx {
1032 old_ls_len: usize,
1033 old_nodes_len: usize,
1034 ls_changes: Vec<(usize, LmState)>,
1035 node_changes: Vec<(usize, CountNode)>,
1036 uni_delta: [u64; BYTE_ALPHA_N],
1038 total_uni_add: u64,
1039}
1040
/// Source of random 64-bit words: replays a byte buffer when one is loaded,
/// otherwise falls back to an xorshift generator.
#[derive(Default, Clone)]
struct RngStream {
    /// Replay buffer (used only when it holds at least 8 bytes).
    buf: Vec<u8>,
    /// Read offset into `buf`; wraps around.
    pos: usize,
    /// Xorshift state for the fallback generator.
    xs: u64,
}
1047
1048impl RngStream {
1049 fn new(seed: u64) -> Self {
1050 let mut r = RngStream {
1051 buf: Vec::new(),
1052 pos: 0,
1053 xs: 88172645463325252u64,
1054 };
1055 if let Ok(path) = std::env::var("ROSAPLUS_RNG_PATH") {
1056 if !path.is_empty() {
1057 if let Ok(mut f) = File::open(path) {
1058 let mut b = Vec::new();
1059 if f.read_to_end(&mut b).is_ok() && b.len() >= 8 {
1060 let n = b.len();
1061 r.pos = ((seed.wrapping_mul(8)) as usize) % n;
1062 r.buf = b;
1063 }
1064 }
1065 }
1066 }
1067 r
1068 }
1069
1070 #[inline(always)]
1071 fn next_u64(&mut self) -> u64 {
1072 if self.buf.len() < 8 {
1073 self.xs ^= self.xs << 7;
1074 self.xs ^= self.xs >> 9;
1075 return self.xs;
1076 }
1077 let n = self.buf.len();
1078 let mut b = [0u8; 8];
1079 for i in 0..8 {
1080 b[i] = self.buf[self.pos];
1081 self.pos += 1;
1082 if self.pos >= n {
1083 self.pos = 0;
1084 }
1085 }
1086 u64::from_le_bytes(b)
1087 }
1088
1089 #[inline(always)]
1090 fn next_unit(&mut self) -> f64 {
1091 let x = self.next_u64();
1092 ((x >> 11) as f64) * (1.0 / 9007199254740992.0)
1093 }
1094}
1095
/// Reusable scratch buffers (presumably for sampling code outside this view).
#[derive(Default, Clone)]
struct SampleScratch {
    /// Index buffer, kept at exactly the alphabet size by `ensure`.
    idx: Vec<u32>,
    /// Logit buffer; grows on demand.
    logits: Vec<f64>,
    /// Buffer grown in lockstep with `logits`.
    exps: Vec<f64>,
}
1105
1106impl SampleScratch {
1107 fn ensure(&mut self, alpha_n: usize, n: usize) {
1108 if self.idx.len() != alpha_n {
1109 self.idx.resize(alpha_n, 0);
1110 }
1111 if self.logits.len() < n {
1112 self.logits.resize(n, 0.0);
1113 self.exps.resize(n, 0.0);
1114 }
1115 }
1116}
1117
1118#[derive(Clone)]
1119pub struct RosaPlus {
1120 max_order: i64,
1121 use_eot: bool,
1122 eot: u32,
1123 seed: u64,
1124
1125 sam: Sam,
1126 lm: LM,
1127 lm_built: bool,
1128
1129 rng: RngStream,
1130 scratch: SampleScratch,
1131 dist: Vec<f64>,
1132}
1133
/// Snapshot of the SAM's buffer lengths and `last` state.
/// NOTE(review): not constructed in the code visible here.
#[derive(Clone, Copy, Debug)]
pub struct RosaCheckpoint {
    sam_st_len: usize,
    sam_ed_len: usize,
    sam_text_len: usize,
    sam_text_states_len: usize,
    sam_boundary_after_len: usize,
    sam_last: i32,
}
1147
1148#[derive(Clone)]
1150pub struct RosaTx {
1151 sam: SamTx,
1152 lm: LmTx,
1153 seg_start: usize,
1154 seg_len: usize,
1155}
1156
1157impl RosaPlus {
1158 pub fn new(max_order: i64, use_eot: bool, eot_char: u8, seed: u64) -> Self {
1159 let sam = Sam::new(0);
1160 RosaPlus {
1161 max_order,
1162 use_eot,
1163 eot: eot_char as u32,
1164 seed,
1165 sam,
1166 lm: LM::default(),
1167 lm_built: false,
1168 rng: RngStream::new(seed),
1169 scratch: SampleScratch::default(),
1170 dist: Vec::new(),
1171 }
1172 }
1173
1174 pub fn train_example(&mut self, s: &[u8]) {
1175 if s.is_empty() {
1176 return;
1177 }
1178
1179 if self.sam.text.is_empty() {
1180 self.sam = Sam::new(s.len());
1181 }
1182
1183 for &b in s {
1184 self.sam.feed(b as u32);
1185 }
1186
1187 if self.use_eot {
1188 self.sam.feed(self.eot);
1189 }
1190
1191 self.sam.mark_boundary();
1192 self.lm_built = false;
1193 }
1194
1195 pub fn build_lm(&mut self) {
1196 self.sam.finalize_endpos();
1197 self.lm = LM::default();
1198 self.lm.build_alphabet(&self.sam);
1199 let mo = if self.max_order < 0 {
1200 -1
1201 } else {
1202 self.max_order
1203 };
1204 self.lm.build_counts(&self.sam, mo);
1205 self.lm_built = true;
1206 self.dist.resize(self.lm.alpha_n as usize, 0.0);
1207 }
1208
1209 pub fn build_lm_no_finalize_endpos(&mut self) {
1216 self.lm = LM::default();
1217 self.lm.build_alphabet(&self.sam);
1218 let mo = if self.max_order < 0 {
1219 -1
1220 } else {
1221 self.max_order
1222 };
1223 self.lm.build_counts(&self.sam, mo);
1224 self.lm_built = true;
1225 self.dist.resize(self.lm.alpha_n as usize, 0.0);
1226 }
1227
1228 pub fn build_lm_full_bytes_no_finalize_endpos(&mut self) {
1232 self.lm = LM::default();
1234 self.lm.has_byte_map = true;
1235 self.lm.alpha_n = BYTE_ALPHA_N as u32;
1236 self.lm.alphabet = (0..BYTE_ALPHA_N as u32).collect();
1237 self.lm.byte_map = [-1; 256];
1238 for i in 0..256 {
1239 self.lm.byte_map[i] = i as i16;
1240 }
1241
1242 let mut counts = [0u64; 256];
1244 for &v in &self.sam.text {
1245 if v < 256 {
1246 counts[v as usize] += 1;
1247 }
1248 }
1249 self.lm.unigram = counts.to_vec();
1250 self.lm.total_uni = counts.iter().sum();
1251 if self.lm.total_uni == 0 {
1252 for i in 0..256 {
1253 self.lm.unigram[i] = 1;
1254 }
1255 self.lm.total_uni = 256;
1256 }
1257
1258 let mo = if self.max_order < 0 {
1260 -1
1261 } else {
1262 self.max_order
1263 };
1264 self.lm.build_counts(&self.sam, mo);
1265 self.lm_built = true;
1266 self.dist.resize(BYTE_ALPHA_N, 0.0);
1267 }
1268
1269 pub fn begin_tx(&mut self) -> RosaTx {
1271 let sam_tx = self.sam.begin_tx();
1272 let lm_tx = LmTx {
1273 old_ls_len: self.lm.ls.len(),
1274 old_nodes_len: self.lm.nodes.len(),
1275 ls_changes: Vec::new(),
1276 node_changes: Vec::new(),
1277 uni_delta: [0u64; BYTE_ALPHA_N],
1278 total_uni_add: 0,
1279 };
1280 RosaTx {
1281 sam: sam_tx,
1282 lm: lm_tx,
1283 seg_start: self.sam.text.len(),
1284 seg_len: 0,
1285 }
1286 }
1287
1288 pub fn train_example_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1290 self.train_example_tx_impl(tx, s, true);
1291 }
1292
1293 pub fn train_sequence_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1295 self.train_example_tx_impl(tx, s, false);
1296 }
1297
1298 fn train_example_tx_impl(&mut self, tx: &mut RosaTx, s: &[u8], mark_boundary: bool) {
1299 if s.is_empty() {
1300 return;
1301 }
1302
1303 if self.lm.ls.len() < self.sam.st.len() {
1305 self.lm.ls.resize(
1306 self.sam.st.len(),
1307 LmState {
1308 head: -1,
1309 last_node: -1,
1310 ..LmState::default()
1311 },
1312 );
1313 }
1314
1315 for &b in s {
1317 self.sam.feed_tx(&mut tx.sam, b as u32);
1318 tx.lm.uni_delta[b as usize] += 1;
1319 tx.lm.total_uni_add += 1;
1320 }
1321 if mark_boundary {
1322 self.sam.mark_boundary_tx(&mut tx.sam);
1323 }
1324
1325 if self.lm.ls.len() < self.sam.st.len() {
1328 self.lm.ls.resize(
1329 self.sam.st.len(),
1330 LmState {
1331 head: -1,
1332 last_node: -1,
1333 ..LmState::default()
1334 },
1335 );
1336 }
1337
1338 for i in 0..256 {
1340 if tx.lm.uni_delta[i] != 0 {
1341 self.lm.unigram[i] += tx.lm.uni_delta[i];
1342 }
1343 }
1344 self.lm.total_uni += tx.lm.total_uni_add;
1345
1346 let seg_start = tx.seg_start;
1348 let seg_end = self.sam.text.len();
1349 tx.seg_len = seg_end - seg_start;
1350 if tx.seg_len >= 1 {
1351 let mo = if self.max_order < 0 {
1352 -1
1353 } else {
1354 self.max_order
1355 };
1356 let mut start_i = seg_start;
1360 if !mark_boundary
1361 && seg_start > 0
1362 && self.sam.boundary_after.get(seg_start - 1).copied().unwrap_or(0) == 0
1363 {
1364 start_i = seg_start - 1;
1365 }
1366 for i in start_i..(seg_end - 1) {
1367 let mut ctx = self.sam.text_states[i + 1];
1369 if mo >= 0 {
1370 while ctx != -1 && (self.sam.st[ctx as usize].len as i64) > mo {
1371 ctx = self.sam.st[ctx as usize].link;
1372 }
1373 if ctx == -1 {
1374 ctx = 0;
1375 }
1376 }
1377 let nxt = self.sam.text[i + 1];
1378 let si = self.lm.find_sym(nxt);
1379 if si >= 0 {
1380 let mut u = ctx;
1381 while u != -1 {
1382 self.lm.inc_tx(&mut tx.lm, u as u32, si as u32, 1);
1383 u = self.sam.st[u as usize].link;
1384 }
1385 }
1386 }
1387 }
1388
1389 self.lm_built = true;
1390 }
1391
1392 pub fn rollback_tx(&mut self, tx: RosaTx) {
1394 if self.lm.unigram.len() >= BYTE_ALPHA_N {
1397 for i in 0..BYTE_ALPHA_N {
1398 let d = tx.lm.uni_delta[i];
1399 if d != 0 {
1400 self.lm.unigram[i] = self.lm.unigram[i].saturating_sub(d);
1401 }
1402 }
1403 self.lm.total_uni = self.lm.total_uni.saturating_sub(tx.lm.total_uni_add);
1404 }
1405
1406 for (idx, old) in tx.lm.node_changes.into_iter().rev() {
1407 if idx < self.lm.nodes.len() {
1408 self.lm.nodes[idx] = old;
1409 }
1410 }
1411 for (idx, old) in tx.lm.ls_changes.into_iter().rev() {
1412 if idx < self.lm.ls.len() {
1413 self.lm.ls[idx] = old;
1414 }
1415 }
1416 self.lm.nodes.truncate(tx.lm.old_nodes_len);
1417 self.lm.ls.truncate(tx.lm.old_ls_len);
1418
1419 self.sam.rollback_tx(tx.sam);
1421 }
1423
1424 #[inline(always)]
1426 pub fn ensure_lm_built_no_finalize_endpos(&mut self) {
1427 if !self.lm_built {
1428 self.build_lm_no_finalize_endpos();
1429 }
1430 }
1431
1432 fn predictive_entropy_rate_order(data: &[u8], max_order: i64, seed: u64) -> f64 {
1433 if data.len() < 2 {
1434 return 0.0;
1435 }
1436 let num_chunks = 16;
1437 let chunk_size = (data.len() + num_chunks - 1) / num_chunks;
1438 let mut total_log_prob = 0.0f64;
1439 let mut count = 0usize;
1440
1441 for i in 0..num_chunks {
1442 let start = i * chunk_size;
1443 let end = ((i + 1) * chunk_size).min(data.len());
1444 if start >= end {
1445 break;
1446 }
1447 if i == 0 {
1448 continue;
1449 }
1450
1451 let mut m = RosaPlus::new(max_order, false, 0, seed);
1452 m.train_example(&data[..start]);
1453 m.build_lm();
1454 let mut v = m.sam.last;
1455
1456 for &b in &data[start..end] {
1457 let sym_idx = m.lm.find_sym(b as u32);
1458 let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
1459 total_log_prob += p.log2();
1460 count += 1;
1461 v = m.sam.advance(v, b as u32);
1462 }
1463 }
1464
1465 if count == 0 {
1466 let mut m = RosaPlus::new(max_order, false, 0, seed);
1467 m.train_example(data);
1468 m.build_lm();
1469 m.cross_entropy(data)
1470 } else {
1471 -total_log_prob / (count as f64)
1472 }
1473 }
1474
1475 pub fn lm_alpha_n(&self) -> usize {
1477 if !self.lm_built {
1478 0
1479 } else {
1480 self.lm.alpha_n as usize
1481 }
1482 }
1483
1484 pub fn estimated_size_bytes(&self) -> usize {
1485 use std::mem::size_of;
1486
1487 let mut n = 0usize;
1488
1489 n = n.saturating_add(self.sam.st.len().saturating_mul(size_of::<SamState>()));
1490 n = n.saturating_add(self.sam.ed.len().saturating_mul(size_of::<SamEdge>()));
1491 n = n.saturating_add(self.sam.text.len().saturating_mul(size_of::<u32>()));
1492 n = n.saturating_add(self.sam.text_states.len().saturating_mul(size_of::<i32>()));
1493 n = n.saturating_add(
1494 self.sam
1495 .boundary_after
1496 .len()
1497 .saturating_mul(size_of::<u8>()),
1498 );
1499
1500 n = n.saturating_add(self.lm.alphabet.len().saturating_mul(size_of::<u32>()));
1501 n = n.saturating_add(self.lm.unigram.len().saturating_mul(size_of::<u64>()));
1502 n = n.saturating_add(self.lm.ls.len().saturating_mul(size_of::<LmState>()));
1503 n = n.saturating_add(self.lm.nodes.len().saturating_mul(size_of::<CountNode>()));
1504
1505 n = n.saturating_add(self.dist.len().saturating_mul(size_of::<f64>()));
1506 n = n.saturating_add(self.scratch.idx.len().saturating_mul(size_of::<u32>()));
1507 n = n.saturating_add(self.scratch.logits.len().saturating_mul(size_of::<f64>()));
1508 n = n.saturating_add(self.scratch.exps.len().saturating_mul(size_of::<f64>()));
1509 n = n.saturating_add(self.rng.buf.len().saturating_mul(size_of::<u8>()));
1510
1511 n
1512 }
1513
1514 pub fn shrink_aux_buffers(&mut self) {
1515 self.dist.shrink_to_fit();
1516 self.scratch.idx.shrink_to_fit();
1517 self.scratch.logits.shrink_to_fit();
1518 self.scratch.exps.shrink_to_fit();
1519 self.rng.buf.shrink_to_fit();
1520 }
1521
1522 pub fn fork_from_sam(&self) -> Self {
1528 Self {
1529 max_order: self.max_order,
1530 use_eot: self.use_eot,
1531 eot: self.eot,
1532 seed: self.seed,
1533
1534 sam: self.sam.clone(),
1535 lm: LM::default(),
1536 lm_built: false,
1537
1538 rng: RngStream::new(self.seed),
1539 scratch: SampleScratch::default(),
1540 dist: Vec::new(),
1541 }
1542 }
1543
1544 pub fn checkpoint(&self) -> RosaCheckpoint {
1550 RosaCheckpoint {
1551 sam_st_len: self.sam.st.len(),
1552 sam_ed_len: self.sam.ed.len(),
1553 sam_text_len: self.sam.text.len(),
1554 sam_text_states_len: self.sam.text_states.len(),
1555 sam_boundary_after_len: self.sam.boundary_after.len(),
1556 sam_last: self.sam.last,
1557 }
1558 }
1559
1560 pub fn restore(&mut self, ck: &RosaCheckpoint) {
1564 self.sam.st.truncate(ck.sam_st_len);
1565 self.sam.ed.truncate(ck.sam_ed_len);
1566 self.sam.text.truncate(ck.sam_text_len);
1567 self.sam.text_states.truncate(ck.sam_text_states_len);
1568 self.sam.boundary_after.truncate(ck.sam_boundary_after_len);
1569 self.sam.last = ck.sam_last;
1570 self.lm_built = false;
1571 }
1572
1573 #[inline(always)]
1574 fn sample(&mut self, temperature: f64, top_p: f64, top_k: i32) -> u32 {
1575 let dist = &self.dist;
1576 let alpha_n = self.lm.alpha_n as usize;
1577 self.scratch.ensure(alpha_n, alpha_n);
1578 for i in 0..alpha_n {
1579 self.scratch.idx[i] = i as u32;
1580 }
1581
1582 for i in 0..alpha_n {
1584 for j in (i + 1)..alpha_n {
1585 let ii = self.scratch.idx[i] as usize;
1586 let jj = self.scratch.idx[j] as usize;
1587 let pi = dist[ii];
1588 let pj = dist[jj];
1589 if pj > pi || (pj == pi && jj < ii) {
1590 self.scratch.idx.swap(i, j);
1591 }
1592 }
1593 }
1594
1595 let mut n = alpha_n;
1596 if top_k > 0 {
1597 let k = top_k as usize;
1598 if k < n {
1599 n = k;
1600 }
1601 }
1602
1603 if top_p > 0.0 && top_p < 1.0 {
1604 let mut cum = 0.0;
1605 let mut cut = 0usize;
1606 for i in 0..n {
1607 let si = self.scratch.idx[i] as usize;
1608 cum += dist[si];
1609 cut += 1;
1610 if cum >= top_p {
1611 break;
1612 }
1613 }
1614 n = if cut > 0 { cut } else { 1 };
1615 }
1616
1617 let temperature = if temperature <= 0.0 {
1618 1e-6
1619 } else {
1620 temperature
1621 };
1622
1623 self.scratch.ensure(alpha_n, n);
1624 let mut maxlog = -1e300f64;
1625 for i in 0..n {
1626 let si = self.scratch.idx[i] as usize;
1627 let mut p = dist[si];
1628 if p < 1e-12 {
1629 p = 1e-12;
1630 }
1631 let z = p.ln() / temperature;
1632 self.scratch.logits[i] = z;
1633 if z > maxlog {
1634 maxlog = z;
1635 }
1636 }
1637
1638 let mut zsum = 0.0;
1639 for i in 0..n {
1640 let e = (self.scratch.logits[i] - maxlog).exp();
1641 self.scratch.exps[i] = e;
1642 zsum += e;
1643 }
1644
1645 let r = self.rng.next_unit() * zsum;
1646 let mut cum = 0.0;
1647 let mut pick = 0usize;
1648 for i in 0..n {
1649 cum += self.scratch.exps[i];
1650 if cum > r {
1651 pick = i;
1652 break;
1653 }
1654 }
1655
1656 let sym = self.scratch.idx[pick] as usize;
1657 self.lm.alphabet[sym]
1658 }
1659
1660 pub fn generate(&mut self, prompt: &[u8], steps: i32) -> Option<Vec<u8>> {
1661 if !self.lm_built {
1662 return None;
1663 }
1664 let steps = steps.max(0) as usize;
1665
1666 let mut v = 0i32;
1667 for &b in prompt {
1668 v = self.sam.advance(v, b as u32);
1669 }
1670
1671 let mut out: Vec<u32> = Vec::with_capacity(steps);
1672
1673 for _ in 0..steps {
1674 let mut ch = self.sam.predict_det(v);
1675 if ch.is_none() {
1676 let mo = if self.max_order < 0 {
1677 -1
1678 } else {
1679 self.max_order
1680 };
1681 self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
1682 ch = Some(self.sample(0.7, 0.9, 0));
1683 }
1684 let ch = ch.unwrap();
1685 out.push(ch);
1686 if self.use_eot && ch == self.eot {
1687 break;
1688 }
1689 v = self.sam.advance(v, ch);
1690 }
1691
1692 Some(out.iter().map(|&c| c as u8).collect())
1693 }
1694
1695 pub fn get_distribution(&mut self, context: &[u8]) -> Vec<(u32, f64)> {
1701 if !self.lm_built {
1702 self.build_lm();
1703 }
1704
1705 let mut v = 0i32;
1707 for &b in context {
1708 v = self.sam.advance(v, b as u32);
1709 }
1710
1711 let mo = if self.max_order < 0 {
1713 -1
1714 } else {
1715 self.max_order
1716 };
1717 self.dist.resize(self.lm.alpha_n as usize, 0.0);
1718 self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
1719
1720 let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
1722 for i in 0..(self.lm.alpha_n as usize) {
1723 if self.dist[i] > 0.0 {
1724 result.push((self.lm.alphabet[i], self.dist[i]));
1725 }
1726 }
1727 result.sort_by_key(|&(cp, _)| cp);
1728 result
1729 }
1730
1731 pub fn predictive_entropy_rate(&mut self, data: &[u8]) -> f64 {
1735 if data.len() < 2 {
1736 return 0.0;
1737 }
1738 if self.max_order < 0 {
1739 let candidates: [i64; 8] = [0, 1, 2, 4, 8, 16, 32, 64];
1740 let mut best = f64::INFINITY;
1741 for &mo in &candidates {
1742 if mo as usize >= data.len() {
1743 continue;
1744 }
1745 let h = Self::predictive_entropy_rate_order(data, mo, self.seed);
1746 if h < best {
1747 best = h;
1748 }
1749 }
1750 if best.is_finite() {
1751 return best;
1752 }
1753 }
1754 Self::predictive_entropy_rate_order(data, self.max_order, self.seed)
1755 }
1756
    /// Chunked predictive entropy estimate (bits per symbol) of the code-point
    /// sequence `cps`.
    ///
    /// Side effect: this discards the current SAM, replacing `self.sam` with a
    /// fresh `Sam::new(cps.len())`, and clears `lm_built`. It then feeds `cps`
    /// into that SAM chunk by chunk. NOTE(review): afterwards the model
    /// reflects `cps`, not whatever it was trained on before. Whether the LM
    /// ends up marked as built depends on `build_lm_no_finalize_endpos` /
    /// `build_lm`; confirm.
    ///
    /// `cps` is split into 16 equal chunks. Before each chunk except the first
    /// is fed, the LM is rebuilt on everything fed so far. The chunk is then
    /// scored starting from `text_states[start]`, the SAM state after the prefix.
    /// Returns the mean negative log2-probability of the scored symbols, and
    /// `0.0` for inputs shorter than two symbols.
    pub fn entropy_rate_cps(&mut self, cps: &[u32]) -> f64 {
        if cps.len() < 2 {
            return 0.0;
        }

        // Start from an empty automaton sized for `cps`; the old SAM is dropped.
        self.sam = Sam::new(cps.len());
        self.lm_built = false;

        let num_chunks = 16;
        let chunk_size = (cps.len() + num_chunks - 1) / num_chunks;
        let mut total_log_prob = 0.0f64;
        let mut count = 0usize;

        for i in 0..num_chunks {
            let start = i * chunk_size;
            let end = ((i + 1) * chunk_size).min(cps.len());
            if start >= end {
                break;
            }
            let chunk = &cps[start..end];
            if i > 0 {
                // Score this chunk with an LM built only on the preceding symbols.
                self.build_lm_no_finalize_endpos();
                let mut v = self.sam.text_states[start];
                for &ch in chunk {
                    let sym_idx = self.lm.find_sym(ch);
                    let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
                    total_log_prob += p.log2();
                    count += 1;
                    v = self.sam.advance(v, ch);
                }
            }
            // Only after scoring is the chunk added to the training text.
            for &ch in chunk {
                self.sam.feed(ch);
            }
        }

        // Fallback when nothing was scored: in-sample plug-in estimate.
        if count == 0 {
            self.build_lm();
            self.entropy_rate_plugin_cps(cps)
        } else {
            -total_log_prob / (count as f64)
        }
    }
1801
1802 #[allow(dead_code)]
1803 fn entropy_rate_plugin_bytes(&mut self, data: &[u8]) -> f64 {
1804 let mut v = 0i32;
1805 let mut total_log_prob = 0.0f64;
1806 let mut count = 0usize;
1807 for t in 0..(data.len() - 1) {
1808 v = self.sam.advance(v, data[t] as u32);
1809 let next_ch = data[t + 1] as u32;
1810 let sym_idx = self.lm.find_sym(next_ch);
1811 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1812 total_log_prob += p.log2();
1813 count += 1;
1814 }
1815 if count == 0 {
1816 0.0
1817 } else {
1818 -total_log_prob / (count as f64)
1819 }
1820 }
1821
1822 pub fn cross_entropy(&self, data: &[u8]) -> f64 {
1823 if !self.lm_built || data.is_empty() {
1824 return 0.0;
1825 }
1826 let mut total_log_prob = 0.0f64;
1827 let mut v = 0i32;
1828 for &b in data {
1829 let ch = b as u32;
1830 let sym_idx = self.lm.find_sym(ch);
1831 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1832 total_log_prob += p.log2();
1833 v = self.sam.advance(v, ch);
1834 }
1835 -total_log_prob / (data.len() as f64)
1836 }
1837
1838 pub fn cross_entropy_cps(&self, data: &[u32]) -> f64 {
1839 if !self.lm_built || data.is_empty() {
1840 return 0.0;
1841 }
1842 let mut total_log_prob = 0.0f64;
1843 let mut v = 0i32;
1844 for &ch in data {
1845 let sym_idx = self.lm.find_sym(ch);
1846 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1847 total_log_prob += p.log2();
1848 v = self.sam.advance(v, ch);
1849 }
1850 -total_log_prob / (data.len() as f64)
1851 }
1852
1853 fn entropy_rate_plugin_cps(&mut self, cps: &[u32]) -> f64 {
1854 let mut v = 0i32;
1855 let mut total_log_prob = 0.0f64;
1856 let mut count = 0usize;
1857 for t in 0..(cps.len() - 1) {
1858 v = self.sam.advance(v, cps[t]);
1859 let next_ch = cps[t + 1];
1860 let sym_idx = self.lm.find_sym(next_ch);
1861 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
1862 total_log_prob += p.log2();
1863 count += 1;
1864 }
1865 if count == 0 {
1866 0.0
1867 } else {
1868 -total_log_prob / (count as f64)
1869 }
1870 }
1871
1872 pub fn marginal_distribution(&self) -> Vec<(u32, f64)> {
1875 if self.lm.total_uni == 0 {
1876 return Vec::new();
1877 }
1878
1879 let inv = 1.0 / (self.lm.total_uni as f64);
1880 let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
1881 for i in 0..(self.lm.alpha_n as usize) {
1882 let p = (self.lm.unigram[i] as f64) * inv;
1883 if p > 0.0 {
1884 result.push((self.lm.alphabet[i], p));
1885 }
1886 }
1887 result.sort_by_key(|&(cp, _)| cp);
1888 result
1889 }
1890
1891 pub fn marginal_entropy(&self) -> f64 {
1894 if self.lm.total_uni == 0 {
1895 return 0.0;
1896 }
1897
1898 let inv = 1.0 / (self.lm.total_uni as f64);
1899 let mut h = 0.0f64;
1900 for i in 0..(self.lm.alpha_n as usize) {
1901 let p = (self.lm.unigram[i] as f64) * inv;
1902 if p > 0.0 {
1903 h -= p * p.log2();
1904 }
1905 }
1906 h
1907 }
1908
1909 pub fn save(&self, path: &str) -> std::io::Result<()> {
1910 if !self.lm_built {
1911 return Err(std::io::Error::new(
1912 std::io::ErrorKind::Other,
1913 "LM not built",
1914 ));
1915 }
1916
1917 if self.sam.text_states.len() != self.sam.text.len() + 1 {
1920 return Err(std::io::Error::new(
1921 std::io::ErrorKind::Other,
1922 "SAM text_states mismatch (expected text.len()+1)",
1923 ));
1924 }
1925 let mut f = BufWriter::with_capacity(1024 * 1024, File::create(path)?);
1926 f.write_all(MAGIC)?;
1927 f.write_all(&self.max_order.to_le_bytes())?;
1928 f.write_all(&(self.use_eot as i32).to_le_bytes())?;
1929 f.write_all(&self.eot.to_le_bytes())?;
1930 f.write_all(&self.seed.to_le_bytes())?;
1931
1932 f.write_all(&(self.sam.st.len() as u32).to_le_bytes())?;
1934 f.write_all(&(self.sam.ed.len() as u32).to_le_bytes())?;
1935 f.write_all(&(self.sam.text.len() as u32).to_le_bytes())?;
1936 for st in &self.sam.st {
1937 f.write_all(&st.link.to_le_bytes())?;
1938 f.write_all(&st.len.to_le_bytes())?;
1939 f.write_all(&st.endpos.to_le_bytes())?;
1940 f.write_all(&(st.small_n as u32).to_le_bytes())?;
1941 for k in 0..(st.small_n as usize) {
1942 f.write_all(&st.small_ch[k].to_le_bytes())?;
1943 f.write_all(&st.small_to[k].to_le_bytes())?;
1944 }
1945 f.write_all(&st.head.to_le_bytes())?;
1946 }
1947 for e in &self.sam.ed {
1948 f.write_all(&e.ch.to_le_bytes())?;
1949 f.write_all(&e.to.to_le_bytes())?;
1950 f.write_all(&e.next.to_le_bytes())?;
1951 }
1952 write_u32_slice_le(&mut f, &self.sam.text)?;
1953 f.write_all(&self.sam.boundary_after)?;
1954
1955 f.write_all(&self.sam.last.to_le_bytes())?;
1957 f.write_all(&(self.sam.text_states.len() as u32).to_le_bytes())?;
1958 write_i32_slice_le(&mut f, &self.sam.text_states)?;
1959
1960 f.write_all(&self.lm.alpha_n.to_le_bytes())?;
1962 f.write_all(&self.lm.total_uni.to_le_bytes())?;
1963 f.write_all(&(self.lm.nodes.len() as u32).to_le_bytes())?;
1964 write_u32_slice_le(&mut f, &self.lm.alphabet)?;
1965 write_u64_slice_le(&mut f, &self.lm.unigram)?;
1966 for ls in &self.lm.ls {
1967 f.write_all(&ls.head.to_le_bytes())?;
1968 f.write_all(&ls.total_n.to_le_bytes())?;
1969 f.write_all(&ls.types_t.to_le_bytes())?;
1970 }
1971 for n in &self.lm.nodes {
1972 f.write_all(&n.sym_idx.to_le_bytes())?;
1973 f.write_all(&n.cnt.to_le_bytes())?;
1974 f.write_all(&n.next.to_le_bytes())?;
1975 }
1976 f.flush()?;
1977 Ok(())
1978 }
1979
1980 pub fn load(path: &str) -> std::io::Result<Self> {
1981 let mut f = BufReader::with_capacity(1024 * 1024, File::open(path)?);
1982 let mut magic = vec![0u8; MAGIC.len()];
1983 f.read_exact(&mut magic)?;
1984 if magic != MAGIC {
1985 return Err(std::io::Error::new(
1986 std::io::ErrorKind::InvalidData,
1987 "bad magic",
1988 ));
1989 }
1990
1991 let mut b8 = [0u8; 8];
1992 let mut b4 = [0u8; 4];
1993
1994 f.read_exact(&mut b8)?;
1995 let max_order = i64::from_le_bytes(b8);
1996 f.read_exact(&mut b4)?;
1997 let use_eot = i32::from_le_bytes(b4) != 0;
1998 f.read_exact(&mut b4)?;
1999 let eot = u32::from_le_bytes(b4);
2000 f.read_exact(&mut b8)?;
2001 let seed = u64::from_le_bytes(b8);
2002
2003 let mut m = RosaPlus::new(max_order, use_eot, eot as u8, seed);
2004
2005 f.read_exact(&mut b4)?;
2007 let st_n = u32::from_le_bytes(b4) as usize;
2008 f.read_exact(&mut b4)?;
2009 let ed_n = u32::from_le_bytes(b4) as usize;
2010 f.read_exact(&mut b4)?;
2011 let text_n = u32::from_le_bytes(b4) as usize;
2012
2013 m.sam = Sam::new(text_n);
2014 m.sam.st.resize(st_n, SamState::default());
2015 m.sam.ed.resize(ed_n, SamEdge::default());
2016 m.sam.text.resize(text_n, 0u32);
2017 m.sam.boundary_after.resize(text_n, 0u8);
2018
2019 for i in 0..st_n {
2020 f.read_exact(&mut b4)?;
2021 m.sam.st[i].link = i32::from_le_bytes(b4);
2022 f.read_exact(&mut b4)?;
2023 m.sam.st[i].len = i32::from_le_bytes(b4);
2024 f.read_exact(&mut b4)?;
2025 m.sam.st[i].endpos = i32::from_le_bytes(b4);
2026 f.read_exact(&mut b4)?;
2027 let sn = u32::from_le_bytes(b4) as usize;
2028 if sn > SAM_SMALL_MAX {
2029 return Err(std::io::Error::new(
2030 std::io::ErrorKind::InvalidData,
2031 "bad small_n",
2032 ));
2033 }
2034 m.sam.st[i].small_n = sn as u8;
2035 for k in 0..sn {
2036 f.read_exact(&mut b4)?;
2037 m.sam.st[i].small_ch[k] = u32::from_le_bytes(b4);
2038 f.read_exact(&mut b4)?;
2039 m.sam.st[i].small_to[k] = i32::from_le_bytes(b4);
2040 }
2041 f.read_exact(&mut b4)?;
2042 m.sam.st[i].head = i32::from_le_bytes(b4);
2043 }
2044 for i in 0..ed_n {
2045 f.read_exact(&mut b4)?;
2046 m.sam.ed[i].ch = u32::from_le_bytes(b4);
2047 f.read_exact(&mut b4)?;
2048 m.sam.ed[i].to = i32::from_le_bytes(b4);
2049 f.read_exact(&mut b4)?;
2050 m.sam.ed[i].next = i32::from_le_bytes(b4);
2051 }
2052 read_u32_slice_le(&mut f, &mut m.sam.text)?;
2053 f.read_exact(&mut m.sam.boundary_after)?;
2054
2055 f.read_exact(&mut b4)?;
2057 m.sam.last = i32::from_le_bytes(b4);
2058 f.read_exact(&mut b4)?;
2059 let text_states_n = u32::from_le_bytes(b4) as usize;
2060 if text_states_n != text_n + 1 {
2061 return Err(std::io::Error::new(
2062 std::io::ErrorKind::InvalidData,
2063 "bad text_states len",
2064 ));
2065 }
2066 m.sam.text_states.resize(text_states_n, 0);
2067 read_i32_slice_le(&mut f, &mut m.sam.text_states)?;
2068 for &v in &m.sam.text_states {
2069 if v < 0 || (v as usize) >= st_n {
2070 return Err(std::io::Error::new(
2071 std::io::ErrorKind::InvalidData,
2072 "bad text_states entry",
2073 ));
2074 }
2075 }
2076 if m.sam.last < 0 || (m.sam.last as usize) >= st_n {
2077 return Err(std::io::Error::new(
2078 std::io::ErrorKind::InvalidData,
2079 "bad sam.last",
2080 ));
2081 }
2082
2083 f.read_exact(&mut b4)?;
2085 let alpha_n = u32::from_le_bytes(b4) as usize;
2086 f.read_exact(&mut b8)?;
2087 let total_uni = u64::from_le_bytes(b8);
2088 f.read_exact(&mut b4)?;
2089 let nodes_n = u32::from_le_bytes(b4) as usize;
2090
2091 m.lm = LM::default();
2092 m.lm.alpha_n = alpha_n as u32;
2093 m.lm.total_uni = total_uni;
2094 m.lm.alphabet.resize(alpha_n, 0);
2095 m.lm.unigram.resize(alpha_n, 0);
2096 m.lm.ls = vec![
2097 LmState {
2098 head: -1,
2099 last_node: -1,
2100 ..LmState::default()
2101 };
2102 st_n
2103 ];
2104 m.lm.nodes.resize(nodes_n, CountNode::default());
2105
2106 read_u32_slice_le(&mut f, &mut m.lm.alphabet)?;
2107 read_u64_slice_le(&mut f, &mut m.lm.unigram)?;
2108 for i in 0..st_n {
2109 f.read_exact(&mut b4)?;
2110 m.lm.ls[i].head = i32::from_le_bytes(b4);
2111 f.read_exact(&mut b8)?;
2112 m.lm.ls[i].total_n = u64::from_le_bytes(b8);
2113 f.read_exact(&mut b4)?;
2114 m.lm.ls[i].types_t = u32::from_le_bytes(b4);
2115 m.lm.ls[i].last_node = -1;
2116 m.lm.ls[i].last_sym = 0;
2117 }
2118 for i in 0..nodes_n {
2119 f.read_exact(&mut b4)?;
2120 m.lm.nodes[i].sym_idx = u32::from_le_bytes(b4);
2121 f.read_exact(&mut b8)?;
2122 m.lm.nodes[i].cnt = u64::from_le_bytes(b8);
2123 f.read_exact(&mut b4)?;
2124 m.lm.nodes[i].next = i32::from_le_bytes(b4);
2125 }
2126
2127 m.lm.has_byte_map = false;
2129 m.lm.byte_map = [-1; 256];
2130 let mut max_cp = 0u32;
2131 for &v in &m.lm.alphabet {
2132 if v > max_cp {
2133 max_cp = v;
2134 }
2135 }
2136 if max_cp < 256 {
2137 m.lm.has_byte_map = true;
2138 for (i, &c) in m.lm.alphabet.iter().enumerate() {
2139 m.lm.byte_map[c as usize] = i as i16;
2140 }
2141 }
2142
2143 m.lm_built = true;
2144 m.dist.resize(alpha_n, 0.0);
2145 Ok(m)
2146 }
2147
2148 pub fn prob_for_last(&mut self, sym: u32) -> f64 {
2149 if !self.lm_built {
2150 self.build_lm();
2151 }
2152 let v = self.sam.last;
2153 let sym_idx = self.lm.find_sym(sym);
2154 let mo = if self.max_order < 0 {
2155 -1
2156 } else {
2157 self.max_order
2158 };
2159 self.lm.prob_for_sym(&self.sam, mo, v, sym_idx)
2160 }
2161}
2162
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a model trained on "ababa" generates non-empty output for prompt "a".
    #[test]
    fn rosa_md_example_basic() {
        let mut model = RosaPlus::new(1048576, false, 4, 0);
        model.train_example(b"ababa");
        model.build_lm();
        let generated = model
            .generate(b"a", 10)
            .expect("generate returns Some once the LM is built");
        assert!(!generated.is_empty());
    }

    /// Rolling back a transaction restores the SAM text and the unigram total.
    #[test]
    fn tx_rollback_restores_sam_and_unigram_counts() {
        let mut model = RosaPlus::new(4, false, 0, 123);
        model.train_example(b"hello");
        model.build_lm_full_bytes_no_finalize_endpos();

        let text_before = model.sam.text.clone();
        let uni_before = model.lm.total_uni;
        assert!(!text_before.is_empty());

        let mut tx = model.begin_tx();
        model.train_example_tx(&mut tx, b"abc");
        assert_eq!(model.lm.total_uni, uni_before + 3);
        assert_eq!(model.sam.text.len(), text_before.len() + 3);

        model.rollback_tx(tx);
        assert_eq!(model.sam.text, text_before);
        assert_eq!(model.lm.total_uni, uni_before);
    }

    /// `restore` brings the append-only SAM buffers back to the checkpoint and clears `lm_built`.
    #[test]
    fn checkpoint_restore_reverts_append_only_buffers() {
        let mut model = RosaPlus::new(3, true, b'\n', 7);
        model.train_example(b"aaaa");

        let ck = model.checkpoint();
        let text_then = model.sam.text.clone();
        let states_then = model.sam.text_states.clone();
        let boundary_then = model.sam.boundary_after.clone();
        let last_then = model.sam.last;

        model.train_example(b"bbbb");
        assert_ne!(model.sam.text, text_then);

        model.restore(&ck);
        assert_eq!(model.sam.text, text_then);
        assert_eq!(model.sam.text_states, states_then);
        assert_eq!(model.sam.boundary_after, boundary_then);
        assert_eq!(model.sam.last, last_then);
        assert!(!model.lm_built);
    }
}