1#![allow(clippy::needless_range_loop)]
17
18use std::collections::HashMap;
19use std::fs::File;
20use std::io::{BufReader, BufWriter, Read, Write};
21
/// Number of outgoing edges each `SamState` stores inline (`small_ch` / `small_to`)
/// before further edges spill into the shared edge list.
const SAM_SMALL_MAX: usize = 4;
/// NUL-terminated magic byte string; presumably the header tag of the serialized
/// v5 format — its use is outside the visible part of the file (TODO confirm).
const MAGIC_V5: &[u8] = b"rosa_pb_v5\0";

/// Index of a state in `Sam::st`; signed so that `-1` can encode "no state".
type SamStateIx = i32;
/// Index of an edge in `Sam::ed`.
type SamEdgeIx = u32;
/// Index of a count node inside `LmNodes`.
type LmNodeIx = u32;

/// Sentinel state index meaning "no state" (e.g. the root's link).
const SAM_STATE_NONE: SamStateIx = -1;
/// Sentinel edge index terminating an overflow-edge list.
const SAM_EDGE_NONE: SamEdgeIx = u32::MAX;
/// Sentinel node index terminating a count-node list.
const LM_NODE_NONE: LmNodeIx = u32::MAX;
/// Value stored in `LmNodes::sym_lo` when the real symbol index lives in `sym_overflow`.
const LM_PACKED_SYM_OVERFLOW: u16 = u16::MAX;
/// Largest count that fits in `LmNodes::cnt_lo`; larger counts keep the excess in `cnt_overflow`.
const LM_PACKED_CNT_MAX: u16 = u16::MAX;

/// Size of the root transition cache (`Sam::root_to`): one slot per symbol value below 256.
const BYTE_ALPHA_N: usize = 256;
41
42#[inline(always)]
43fn state_ix(idx: usize) -> SamStateIx {
44 SamStateIx::try_from(idx).expect("rosa sam state index overflow")
45}
46
47#[inline(always)]
48fn state_usize(idx: SamStateIx) -> usize {
49 debug_assert!(idx >= 0, "negative rosa sam state index");
50 idx as usize
51}
52
53#[inline(always)]
54fn edge_ix(idx: usize) -> SamEdgeIx {
55 SamEdgeIx::try_from(idx).expect("rosa sam edge index overflow")
56}
57
58#[inline(always)]
59fn edge_usize(idx: SamEdgeIx) -> usize {
60 idx as usize
61}
62
63#[inline(always)]
64fn node_ix(idx: usize) -> LmNodeIx {
65 LmNodeIx::try_from(idx).expect("rosa lm node index overflow")
66}
67
68#[inline(always)]
69fn node_usize(idx: LmNodeIx) -> usize {
70 idx as usize
71}
72
/// Writes `len` to `w` as an 8-byte little-endian unsigned integer.
///
/// # Errors
/// Propagates any error returned by the writer.
#[inline(always)]
fn write_len64<W: Write>(w: &mut W, len: usize) -> std::io::Result<()> {
    let encoded: [u8; 8] = u64::to_le_bytes(len as u64);
    w.write_all(&encoded)
}
77
/// Reads an 8-byte little-endian unsigned integer from `r` and returns it as a `usize`.
///
/// # Errors
/// Propagates reader errors; returns `InvalidData` ("length overflow") when the
/// value does not fit in `usize` on this target.
#[inline(always)]
fn read_len64<R: Read>(r: &mut R) -> std::io::Result<usize> {
    let mut raw = [0u8; 8];
    r.read_exact(&mut raw)?;
    match usize::try_from(u64::from_le_bytes(raw)) {
        Ok(len) => Ok(len),
        Err(_) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "length overflow",
        )),
    }
}
85
/// Writes every element of `xs` to `w` in little-endian byte order.
///
/// On little-endian targets the slice is written in one call as raw bytes;
/// on big-endian targets each element is converted and written individually.
///
/// # Errors
/// Propagates any error returned by the writer.
#[inline(always)]
fn write_u32_slice_le<W: Write>(w: &mut W, xs: &[u32]) -> std::io::Result<()> {
    if cfg!(target_endian = "big") {
        for x in xs.iter() {
            w.write_all(&x.to_le_bytes())?;
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(xs);
    // SAFETY: `xs` is a valid, initialized slice occupying exactly `byte_len` bytes,
    // `u8` has alignment 1, and the byte view does not outlive the borrow of `xs`.
    let raw = unsafe { std::slice::from_raw_parts(xs.as_ptr() as *const u8, byte_len) };
    w.write_all(raw)
}
100
/// Writes every element of `xs` to `w` in little-endian byte order.
///
/// On little-endian targets the slice is written in one call as raw bytes;
/// on big-endian targets each element is converted and written individually.
///
/// # Errors
/// Propagates any error returned by the writer.
#[inline(always)]
fn write_i32_slice_le<W: Write>(w: &mut W, xs: &[i32]) -> std::io::Result<()> {
    if cfg!(target_endian = "big") {
        for x in xs.iter() {
            w.write_all(&x.to_le_bytes())?;
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(xs);
    // SAFETY: `xs` is a valid, initialized slice occupying exactly `byte_len` bytes,
    // `u8` has alignment 1, and the byte view does not outlive the borrow of `xs`.
    let raw = unsafe { std::slice::from_raw_parts(xs.as_ptr() as *const u8, byte_len) };
    w.write_all(raw)
}
115
/// Writes every element of `xs` to `w` in little-endian byte order.
///
/// On little-endian targets the slice is written in one call as raw bytes;
/// on big-endian targets each element is converted and written individually.
///
/// # Errors
/// Propagates any error returned by the writer.
#[inline(always)]
fn write_u64_slice_le<W: Write>(w: &mut W, xs: &[u64]) -> std::io::Result<()> {
    if cfg!(target_endian = "big") {
        for x in xs.iter() {
            w.write_all(&x.to_le_bytes())?;
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(xs);
    // SAFETY: `xs` is a valid, initialized slice occupying exactly `byte_len` bytes,
    // `u8` has alignment 1, and the byte view does not outlive the borrow of `xs`.
    let raw = unsafe { std::slice::from_raw_parts(xs.as_ptr() as *const u8, byte_len) };
    w.write_all(raw)
}
130
/// Fills `xs` with little-endian values read from `r`.
///
/// On little-endian targets the bytes are read straight into the slice's storage;
/// on big-endian targets each element is read and converted individually.
///
/// # Errors
/// Propagates reader errors (including a short read).
#[inline(always)]
fn read_u32_slice_le<R: Read>(r: &mut R, xs: &mut [u32]) -> std::io::Result<()> {
    if cfg!(target_endian = "big") {
        let mut word = [0u8; 4];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = u32::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(&*xs);
    // SAFETY: `xs` is exclusively borrowed and spans exactly `byte_len` bytes, `u8`
    // has alignment 1, and every bit pattern is a valid `u32`.
    let raw = unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut u8, byte_len) };
    r.read_exact(raw)
}
147
/// Fills `xs` with little-endian values read from `r`.
///
/// On little-endian targets the bytes are read straight into the slice's storage;
/// on big-endian targets each element is read and converted individually.
///
/// # Errors
/// Propagates reader errors (including a short read).
#[inline(always)]
fn read_i32_slice_le<R: Read>(r: &mut R, xs: &mut [i32]) -> std::io::Result<()> {
    if cfg!(target_endian = "big") {
        let mut word = [0u8; 4];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = i32::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(&*xs);
    // SAFETY: `xs` is exclusively borrowed and spans exactly `byte_len` bytes, `u8`
    // has alignment 1, and every bit pattern is a valid `i32`.
    let raw = unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut u8, byte_len) };
    r.read_exact(raw)
}
164
/// Fills `xs` with little-endian values read from `r`.
///
/// On little-endian targets the bytes are read straight into the slice's storage;
/// on big-endian targets each element is read and converted individually.
///
/// # Errors
/// Propagates reader errors (including a short read).
#[inline(always)]
fn read_u64_slice_le<R: Read>(r: &mut R, xs: &mut [u64]) -> std::io::Result<()> {
    if cfg!(target_endian = "big") {
        let mut word = [0u8; 8];
        for slot in xs.iter_mut() {
            r.read_exact(&mut word)?;
            *slot = u64::from_le_bytes(word);
        }
        return Ok(());
    }
    let byte_len = std::mem::size_of_val(&*xs);
    // SAFETY: `xs` is exclusively borrowed and spans exactly `byte_len` bytes, `u8`
    // has alignment 1, and every bit pattern is a valid `u64`.
    let raw = unsafe { std::slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut u8, byte_len) };
    r.read_exact(raw)
}
181
/// One state of the suffix automaton held in `Sam::st`.
#[derive(Clone, Copy, Default)]
struct SamState {
    /// Suffix link; `SAM_STATE_NONE` for the root.
    link: SamStateIx,
    /// Length value of the state (0 for the root, `len(last) + 1` for a freshly fed state).
    len: i32,
    /// Text position recorded for this state: the index of the symbol being fed when the
    /// state was created (`-1` for the root); `finalize_endpos` raises it to the maximum
    /// over the states that link here.
    endpos: i32,
    /// First entry of this state's overflow-edge list in `Sam::ed`, or `SAM_EDGE_NONE`.
    head: SamEdgeIx,

    /// Symbols of the edges stored inline; only the first `small_n` entries are valid.
    small_ch: [u32; SAM_SMALL_MAX],
    /// Targets of the inline edges, parallel to `small_ch`.
    small_to: [SamStateIx; SAM_SMALL_MAX],
    /// Number of inline edges in use (at most `SAM_SMALL_MAX`).
    small_n: u8,
}
193
/// An overflow edge: one node of a per-state singly linked list stored in `Sam::ed`.
#[derive(Clone, Copy, Default)]
struct SamEdge {
    /// Symbol labelling the edge.
    ch: u32,
    /// Target state of the edge.
    to: SamStateIx,
    /// Next edge of the same source state, or `SAM_EDGE_NONE` at the end of the list.
    next: SamEdgeIx,
}
200
/// Suffix automaton over a `u32` symbol stream, together with the text it was built from.
#[derive(Clone)]
struct Sam {
    /// All states; index 0 is the root.
    st: Vec<SamState>,
    /// Pool of overflow edges referenced from `SamState::head` / `SamEdge::next`.
    ed: Vec<SamEdge>,
    /// State the next `feed` extends from; reset to the root by `mark_boundary`.
    last: SamStateIx,
    /// Cache of the root's transitions for symbols below `BYTE_ALPHA_N`
    /// (`SAM_STATE_NONE` where the root has no such edge).
    root_to: [SamStateIx; BYTE_ALPHA_N],

    /// Every symbol fed so far, in order.
    text: Vec<u32>,
    /// Starts with the root (0); each `feed` appends the state created for that symbol.
    text_states: Vec<SamStateIx>,
    /// Parallel to `text`: 1 where `mark_boundary` was called right after that symbol, else 0.
    boundary_after: Vec<u8>,
}
212
213impl Default for Sam {
214 fn default() -> Self {
215 Self::new(0)
216 }
217}
218
219impl Sam {
220 fn new(expected_chars: usize) -> Self {
221 let mut s = Sam {
222 st: Vec::new(),
223 ed: Vec::new(),
224 last: 0,
225 root_to: [SAM_STATE_NONE; BYTE_ALPHA_N],
226 text: Vec::new(),
227 text_states: Vec::new(),
228 boundary_after: Vec::new(),
229 };
230
231 let st_cap = if expected_chars > 0 {
232 expected_chars * 2 + 16
233 } else {
234 1024
235 };
236 let ed_cap = if expected_chars > 0 {
237 expected_chars * 3 + 16
238 } else {
239 2048
240 };
241 let text_cap = if expected_chars > 0 {
242 expected_chars + 16
243 } else {
244 1024
245 };
246 s.st.reserve(st_cap);
247 s.ed.reserve(ed_cap);
248 s.text.reserve(text_cap);
249 s.text_states.reserve(text_cap);
250 s.boundary_after.reserve(text_cap);
251
252 let root = SamState {
253 link: SAM_STATE_NONE,
254 len: 0,
255 endpos: -1,
256 small_n: 0,
257 head: SAM_EDGE_NONE,
258 ..Default::default()
259 };
260 s.st.push(root);
261 s.text_states.push(0); s
263 }
264
265 #[inline(always)]
266 fn reserve_additional(&mut self, additional: usize) {
267 if additional == 0 {
268 return;
269 }
270 self.st
271 .reserve_exact(additional.saturating_mul(2).saturating_add(16));
272 self.ed
273 .reserve_exact(additional.saturating_mul(3).saturating_add(16));
274 let text_extra = additional.saturating_add(16);
275 self.text.reserve_exact(text_extra);
276 self.text_states.reserve_exact(text_extra);
277 self.boundary_after.reserve_exact(text_extra);
278 }
279
280 #[inline(always)]
281 fn get_edge(&self, v: SamStateIx, ch: u32) -> SamStateIx {
282 if v == 0 && ch < BYTE_ALPHA_N as u32 {
283 return self.root_to[ch as usize];
284 }
285 let st = unsafe { self.st.get_unchecked(state_usize(v)) };
286 for i in 0..(st.small_n as usize) {
287 if st.small_ch[i] == ch {
288 return st.small_to[i];
289 }
290 }
291 let mut ei = st.head;
292 while ei != SAM_EDGE_NONE {
293 let e = unsafe { self.ed.get_unchecked(edge_usize(ei)) };
294 if e.ch == ch {
295 return e.to;
296 }
297 ei = e.next;
298 }
299 SAM_STATE_NONE
300 }
301
302 #[inline(always)]
303 fn add_edge(&mut self, v: SamStateIx, ch: u32, to: SamStateIx) {
304 let idx = edge_ix(self.ed.len());
305 let head = self.st[state_usize(v)].head;
306 self.ed.push(SamEdge { ch, to, next: head });
307 self.st[state_usize(v)].head = idx;
308 }
309
310 #[inline(always)]
311 fn add_edge_absent(&mut self, v: SamStateIx, ch: u32, to: SamStateIx) {
312 let st = &mut self.st[state_usize(v)];
313 if (st.small_n as usize) < SAM_SMALL_MAX {
314 let i = st.small_n as usize;
315 st.small_n += 1;
316 st.small_ch[i] = ch;
317 st.small_to[i] = to;
318 } else {
319 self.add_edge(v, ch, to);
320 }
321 if v == 0 && ch < BYTE_ALPHA_N as u32 {
322 self.root_to[ch as usize] = to;
323 }
324 }
325
326 #[inline(always)]
327 fn replace_edge_to(
328 &mut self,
329 v: SamStateIx,
330 ch: u32,
331 old_to: SamStateIx,
332 new_to: SamStateIx,
333 ) -> bool {
334 {
335 let st = &mut self.st[state_usize(v)];
336 for i in 0..(st.small_n as usize) {
337 if st.small_ch[i] == ch && st.small_to[i] == old_to {
338 st.small_to[i] = new_to;
339 if v == 0 && ch < BYTE_ALPHA_N as u32 {
340 self.root_to[ch as usize] = new_to;
341 }
342 return true;
343 }
344 }
345 }
346 let mut ei = self.st[state_usize(v)].head;
347 while ei != SAM_EDGE_NONE {
348 let e = &mut self.ed[edge_usize(ei)];
349 if e.ch == ch && e.to == old_to {
350 e.to = new_to;
351 if v == 0 && ch < BYTE_ALPHA_N as u32 {
352 self.root_to[ch as usize] = new_to;
353 }
354 return true;
355 }
356 ei = e.next;
357 }
358 false
359 }
360
361 fn rebuild_root_cache(&mut self) {
362 self.root_to.fill(SAM_STATE_NONE);
363 if self.st.is_empty() {
364 return;
365 }
366 let root = self.st[0];
367 for i in 0..(root.small_n as usize) {
368 let ch = root.small_ch[i];
369 if ch < BYTE_ALPHA_N as u32 {
370 self.root_to[ch as usize] = root.small_to[i];
371 }
372 }
373 let mut ei = root.head;
374 while ei != SAM_EDGE_NONE {
375 let e = self.ed[edge_usize(ei)];
376 if e.ch < BYTE_ALPHA_N as u32 {
377 self.root_to[e.ch as usize] = e.to;
378 }
379 ei = e.next;
380 }
381 }
382
383 fn clone_overflow_edges(&mut self, src: SamStateIx, dst: SamStateIx) {
384 self.st[state_usize(dst)].head = SAM_EDGE_NONE;
385 let mut ei = self.st[state_usize(src)].head;
386 while ei != SAM_EDGE_NONE {
387 let e = self.ed[edge_usize(ei)];
388 self.add_edge(dst, e.ch, e.to);
389 ei = e.next;
390 }
391 }
392
    /// Appends symbol `ch` to the text and extends the automaton from `self.last`.
    ///
    /// Follows the usual suffix-automaton extension: create a new state `r`, add `ch`
    /// edges to `r` along the suffix-link chain until a state that already has a `ch`
    /// edge is found, then either link `r` to that edge's target or split the target by
    /// cloning it. Afterwards `last` is `r` and `r` is appended to `text_states`.
    fn feed(&mut self, ch: u32) {
        // Position of the new symbol in `text`.
        let i = self.text.len() as i32;
        self.text.push(ch);
        self.boundary_after.push(0);

        let g = self.last;
        let r = state_ix(self.st.len());
        // New state for the extended text; its link is fixed up below.
        let st_r = SamState {
            link: 0,
            len: self.st[state_usize(g)].len + 1,
            endpos: i,
            small_n: 0,
            head: SAM_EDGE_NONE,
            ..Default::default()
        };
        self.st.push(st_r);

        // Walk the suffix links from `g`, adding the missing `ch` edge to `r`.
        let mut p = g;
        let mut q;
        while p != SAM_STATE_NONE {
            q = self.get_edge(p, ch);
            if q != SAM_STATE_NONE {
                break;
            }
            self.add_edge_absent(p, ch, r);
            p = self.st[state_usize(p)].link;
        }

        if p == SAM_STATE_NONE {
            // No state on the chain had a `ch` edge: `r` links to the root.
            self.st[state_usize(r)].link = 0;
        } else {
            q = self.get_edge(p, ch);
            if self.st[state_usize(p)].len + 1 == self.st[state_usize(q)].len {
                self.st[state_usize(r)].link = q;
            } else {
                // Split: `u` is a copy of `q` (inline edges, link and endpos included)
                // with the shorter length.
                let u = state_ix(self.st.len());
                let mut st_u = self.st[state_usize(q)];
                st_u.len = self.st[state_usize(p)].len + 1;
                self.st.push(st_u);
                // The copied `head` still points into `q`'s list; give `u` its own list.
                self.clone_overflow_edges(q, u);
                // Redirect `ch` edges that pointed at `q` to `u`, up the link chain, until
                // a state whose `ch` edge does not point at `q` is reached.
                while p != SAM_STATE_NONE && self.replace_edge_to(p, ch, q, u) {
                    p = self.st[state_usize(p)].link;
                }
                self.st[state_usize(q)].link = u;
                self.st[state_usize(r)].link = u;
            }
        }

        self.last = r;
        self.text_states.push(r);

        // NOTE(review): `r` was created with `endpos == i`, so the condition is false on
        // the first check and this loop never executes as written — confirm whether the
        // walk was meant to start at `r`'s link (`finalize_endpos` does the propagation).
        let mut v = r;
        while v != SAM_STATE_NONE && self.st[state_usize(v)].endpos < i {
            self.st[state_usize(v)].endpos = i;
            v = self.st[state_usize(v)].link;
        }
    }
451
452 fn mark_boundary(&mut self) {
453 if !self.text.is_empty() {
454 let i = self.text.len() - 1;
455 self.boundary_after[i] = 1;
456 }
457 self.last = 0;
458 }
459
460 fn finalize_endpos(&mut self) {
461 let mut max_len: usize = 0;
462 for v in 0..self.st.len() {
463 let l = self.st[v].len as usize;
464 if l > max_len {
465 max_len = l;
466 }
467 }
468
469 let mut cnt = vec![0usize; max_len + 1];
470 for v in 0..self.st.len() {
471 cnt[self.st[v].len as usize] += 1;
472 }
473 let mut pos = vec![0usize; max_len + 1];
474 let mut acc = 0usize;
475 for l in 0..=max_len {
476 pos[l] = acc;
477 acc += cnt[l];
478 }
479 let mut order = vec![0u32; self.st.len()];
480 for v in 0..self.st.len() {
481 let l = self.st[v].len as usize;
482 let idx = pos[l];
483 order[idx] = v as u32;
484 pos[l] += 1;
485 }
486
487 for oi in (0..order.len()).rev() {
488 let v = order[oi] as usize;
489 let p = self.st[v].link;
490 if p >= 0 {
491 let p = p as usize;
492 if self.st[v].endpos > self.st[p].endpos {
493 self.st[p].endpos = self.st[v].endpos;
494 }
495 }
496 }
497 }
498
499 #[inline(always)]
500 fn advance(&self, mut v: SamStateIx, ch: u32) -> SamStateIx {
501 let mut to;
502 loop {
503 to = self.get_edge(v, ch);
504 if to != SAM_STATE_NONE {
505 return to;
506 }
507 v = self.st[state_usize(v)].link;
508 if v == SAM_STATE_NONE {
509 break;
510 }
511 }
512 to = self.get_edge(0, ch);
513 if to == SAM_STATE_NONE { 0 } else { to }
514 }
515
516 #[inline(always)]
517 fn predict_det(&self, v: SamStateIx) -> Option<u32> {
518 let mut u = v;
519 while u != SAM_STATE_NONE {
520 let st = unsafe { self.st.get_unchecked(state_usize(u)) };
521 let i = st.endpos;
522 let j = i + 1;
523 if st.len > 0 && j >= 0 && (j as usize) < self.text.len() {
524 if i >= 0
525 && (i as usize) < self.boundary_after.len()
526 && self.boundary_after[i as usize] != 0
527 {
528 u = st.link;
529 continue;
530 }
531 return Some(self.text[j as usize]);
532 }
533 u = st.link;
534 }
535 None
536 }
537
538 fn begin_tx(&self) -> SamTx {
540 SamTx {
541 old_last: self.last,
542 old_text_len: self.text.len(),
543 old_text_states_len: self.text_states.len(),
544 old_boundary_len: self.boundary_after.len(),
545 old_st_len: self.st.len(),
546 old_ed_len: self.ed.len(),
547 st_changes: Vec::new(),
548 ed_changes: Vec::new(),
549 }
550 }
551
552 fn rollback_tx(&mut self, tx: SamTx) {
553 for (idx, old) in tx.ed_changes.into_iter().rev() {
555 if idx < self.ed.len() {
556 self.ed[idx] = old;
557 }
558 }
559 for (idx, old) in tx.st_changes.into_iter().rev() {
560 if idx < self.st.len() {
561 self.st[idx] = old;
562 }
563 }
564
565 self.st.truncate(tx.old_st_len);
566 self.ed.truncate(tx.old_ed_len);
567 self.text.truncate(tx.old_text_len);
568 self.text_states.truncate(tx.old_text_states_len);
569 self.boundary_after.truncate(tx.old_boundary_len);
570 self.last = tx.old_last;
571 self.rebuild_root_cache();
572 }
573
574 #[inline(always)]
575 fn record_state_change(&self, tx: &mut SamTx, idx: usize) {
576 tx.st_changes.push((idx, self.st[idx]));
578 }
579
580 #[inline(always)]
581 fn record_edge_change(&self, tx: &mut SamTx, idx: usize) {
582 tx.ed_changes.push((idx, self.ed[idx]));
583 }
584
585 #[inline(always)]
586 fn add_edge_tx(&mut self, tx: &mut SamTx, v: SamStateIx, ch: u32, to: SamStateIx) {
587 let idx = edge_ix(self.ed.len());
588 let head = self.st[state_usize(v)].head;
589 self.ed.push(SamEdge { ch, to, next: head });
590 self.record_state_change(tx, state_usize(v));
591 self.st[state_usize(v)].head = idx;
592 }
593
594 #[inline(always)]
595 fn add_edge_absent_tx(&mut self, tx: &mut SamTx, v: SamStateIx, ch: u32, to: SamStateIx) {
596 let v_usize = state_usize(v);
597 let small_n = self.st[v_usize].small_n as usize;
598 if small_n < SAM_SMALL_MAX {
599 let i = small_n;
600 self.record_state_change(tx, v_usize);
601 let st = &mut self.st[v_usize];
602 st.small_ch[i] = ch;
603 st.small_to[i] = to;
604 st.small_n += 1;
605 if v == 0 && ch < BYTE_ALPHA_N as u32 {
606 self.root_to[ch as usize] = to;
607 }
608 } else {
609 self.add_edge_tx(tx, v, ch, to);
610 if v == 0 && ch < BYTE_ALPHA_N as u32 {
611 self.root_to[ch as usize] = to;
612 }
613 }
614 }
615
616 #[inline(always)]
617 fn replace_edge_to_tx(
618 &mut self,
619 tx: &mut SamTx,
620 v: SamStateIx,
621 ch: u32,
622 old_to: SamStateIx,
623 new_to: SamStateIx,
624 ) -> bool {
625 {
627 let st = &self.st[state_usize(v)];
628 for i in 0..(st.small_n as usize) {
629 if st.small_ch[i] == ch && st.small_to[i] == old_to {
630 self.record_state_change(tx, state_usize(v));
631 self.st[state_usize(v)].small_to[i] = new_to;
632 if v == 0 && ch < BYTE_ALPHA_N as u32 {
633 self.root_to[ch as usize] = new_to;
634 }
635 return true;
636 }
637 }
638 }
639 let mut ei = self.st[state_usize(v)].head;
641 while ei != SAM_EDGE_NONE {
642 let eidx = edge_usize(ei);
643 let e = self.ed[eidx];
644 if e.ch == ch && e.to == old_to {
645 self.record_edge_change(tx, eidx);
646 self.ed[eidx].to = new_to;
647 if v == 0 && ch < BYTE_ALPHA_N as u32 {
648 self.root_to[ch as usize] = new_to;
649 }
650 return true;
651 }
652 ei = e.next;
653 }
654 false
655 }
656
657 fn clone_overflow_edges_tx(&mut self, tx: &mut SamTx, src: SamStateIx, dst: SamStateIx) {
658 self.record_state_change(tx, state_usize(dst));
659 self.st[state_usize(dst)].head = SAM_EDGE_NONE;
660 let mut ei = self.st[state_usize(src)].head;
661 while ei != SAM_EDGE_NONE {
662 let e = self.ed[edge_usize(ei)];
663 self.add_edge_tx(tx, dst, e.ch, e.to);
664 ei = e.next;
665 }
666 }
667
    /// Transactional `feed`: appends `ch` and extends the automaton, logging every change
    /// to a pre-existing state or edge in `tx` so `rollback_tx` can undo it.
    ///
    /// Newly pushed states, edges and text entries are not logged individually; rollback
    /// removes them by truncating to the lengths captured in `begin_tx`.
    fn feed_tx(&mut self, tx: &mut SamTx, ch: u32) {
        // Position of the new symbol in `text`.
        let i = self.text.len() as i32;
        self.text.push(ch);
        self.boundary_after.push(0);

        let g = self.last;
        let r = state_ix(self.st.len());
        // New state for the extended text; its link is fixed up below.
        let st_r = SamState {
            link: 0,
            len: self.st[state_usize(g)].len + 1,
            endpos: i,
            small_n: 0,
            head: SAM_EDGE_NONE,
            ..Default::default()
        };
        self.st.push(st_r);

        // Walk the suffix links from `g`, adding the missing `ch` edge to `r`.
        let mut p = g;
        let mut q;
        while p != SAM_STATE_NONE {
            q = self.get_edge(p, ch);
            if q != SAM_STATE_NONE {
                break;
            }
            self.add_edge_absent_tx(tx, p, ch, r);
            p = self.st[state_usize(p)].link;
        }

        if p == SAM_STATE_NONE {
            // No state on the chain had a `ch` edge: `r` links to the root.
            self.st[state_usize(r)].link = 0;
        } else {
            q = self.get_edge(p, ch);
            if self.st[state_usize(p)].len + 1 == self.st[state_usize(q)].len {
                self.st[state_usize(r)].link = q;
            } else {
                // Split: `u` is a copy of `q` with the shorter length.
                let u = state_ix(self.st.len());
                let mut st_u = self.st[state_usize(q)];
                st_u.len = self.st[state_usize(p)].len + 1;
                self.st.push(st_u);
                self.clone_overflow_edges_tx(tx, q, u);
                // Redirect `ch` edges that pointed at `q` to `u`, up the link chain.
                while p != SAM_STATE_NONE && self.replace_edge_to_tx(tx, p, ch, q, u) {
                    p = self.st[state_usize(p)].link;
                }
                // `q` existed before this call, so snapshot it before relinking.
                self.record_state_change(tx, state_usize(q));
                self.st[state_usize(q)].link = u;
                self.st[state_usize(r)].link = u;
            }
        }

        self.last = r;
        self.text_states.push(r);

        // NOTE(review): `r` was created with `endpos == i`, so the condition is false on
        // the first check and this loop never executes as written — confirm whether the
        // walk was meant to start at `r`'s link.
        let mut v = r;
        while v != SAM_STATE_NONE && self.st[state_usize(v)].endpos < i {
            self.record_state_change(tx, state_usize(v));
            self.st[state_usize(v)].endpos = i;
            v = self.st[state_usize(v)].link;
        }
    }
730
731 fn mark_boundary_tx(&mut self, tx: &mut SamTx) {
732 if !self.text.is_empty() {
733 let i = self.text.len() - 1;
735 self.boundary_after[i] = 1;
736 }
737 self.last = 0;
739 let _ = tx;
740 }
741}
742
/// Undo record for a `Sam` transaction, created by `Sam::begin_tx` and consumed by
/// `Sam::rollback_tx`.
#[derive(Clone)]
struct SamTx {
    /// `Sam::last` at the start of the transaction.
    old_last: SamStateIx,
    /// Length of `Sam::text` at the start of the transaction.
    old_text_len: usize,
    /// Length of `Sam::text_states` at the start of the transaction.
    old_text_states_len: usize,
    /// Length of `Sam::boundary_after` at the start of the transaction.
    old_boundary_len: usize,
    /// Length of `Sam::st` at the start of the transaction.
    old_st_len: usize,
    /// Length of `Sam::ed` at the start of the transaction.
    old_ed_len: usize,
    /// `(index, previous contents)` snapshots of states, in the order they were logged.
    st_changes: Vec<(usize, SamState)>,
    /// `(index, previous contents)` snapshots of edges, in the order they were logged.
    ed_changes: Vec<(usize, SamEdge)>,
}
754
/// Per-state count summary of the language model (one entry of `LM::ls`).
///
/// A state that has seen exactly one distinct symbol keeps no count nodes: `head` is
/// `LM_NODE_NONE`, `types_t` is 1, and the symbol and its count live in `last_sym` /
/// `total_n` (see `LM::ls_is_implicit_single`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct LmState {
    /// First count node of this state's list in `LM::nodes`, or `LM_NODE_NONE`.
    head: LmNodeIx,
    /// Sum of all counts recorded for this state.
    total_n: u64,
    /// Number of distinct symbols recorded for this state.
    types_t: u32,

    /// Symbol most recently touched by `LM::inc` (the only symbol in the implicit case).
    last_sym: u32,
    /// Count node of `last_sym`, or `LM_NODE_NONE` when there is none.
    last_node: LmNodeIx,
}
764
/// Unpacked view of one count node: a `(symbol, count)` pair in a per-state linked list.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct CountNode {
    /// Index of the symbol in the model's alphabet.
    sym_idx: u32,
    /// Count recorded for the symbol.
    cnt: u64,
    /// Next node of the same list, or `LM_NODE_NONE` at the end.
    next: LmNodeIx,
}
771
/// Packed, column-wise storage for `CountNode`s.
///
/// Symbol indices and counts are kept as 16-bit values; values that do not fit are
/// completed from the side maps, keyed by node index.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct LmNodes {
    /// Symbol index per node, or `LM_PACKED_SYM_OVERFLOW` when the real value is in `sym_overflow`.
    sym_lo: Vec<u16>,
    /// Count per node, capped at `LM_PACKED_CNT_MAX`.
    cnt_lo: Vec<u16>,
    /// Next-node link per node.
    next: Vec<LmNodeIx>,
    /// 1 where the node's count has an excess stored in `cnt_overflow`, else 0.
    cnt_overflow_mask: Vec<u8>,
    /// Full symbol index for nodes whose `sym_lo` holds the overflow marker.
    sym_overflow: HashMap<u32, u32>,
    /// Amount by which a node's count exceeds `LM_PACKED_CNT_MAX`.
    cnt_overflow: HashMap<u32, u64>,
}
781
782impl LmNodes {
783 #[inline(always)]
784 fn len(&self) -> usize {
785 self.next.len()
786 }
787
788 #[inline(always)]
789 fn clear(&mut self) {
790 self.sym_lo.clear();
791 self.cnt_lo.clear();
792 self.next.clear();
793 self.cnt_overflow_mask.clear();
794 self.sym_overflow.clear();
795 self.cnt_overflow.clear();
796 }
797
798 #[inline(always)]
799 fn reserve_exact(&mut self, additional: usize) {
800 self.sym_lo.reserve_exact(additional);
801 self.cnt_lo.reserve_exact(additional);
802 self.next.reserve_exact(additional);
803 self.cnt_overflow_mask.reserve_exact(additional);
804 }
805
806 #[inline(always)]
807 fn truncate(&mut self, new_len: usize) {
808 self.sym_lo.truncate(new_len);
809 self.cnt_lo.truncate(new_len);
810 self.next.truncate(new_len);
811 self.cnt_overflow_mask.truncate(new_len);
812 self.sym_overflow.retain(|&k, _| (k as usize) < new_len);
813 self.cnt_overflow.retain(|&k, _| (k as usize) < new_len);
814 }
815
816 #[inline(always)]
817 fn resize(&mut self, new_len: usize, value: CountNode) {
818 if new_len <= self.len() {
819 self.truncate(new_len);
820 return;
821 }
822 while self.len() < new_len {
823 self.push(value);
824 }
825 }
826
827 #[inline(always)]
828 fn set_sym_idx(&mut self, idx: usize, sym_idx: u32) {
829 if sym_idx < LM_PACKED_SYM_OVERFLOW as u32 {
830 self.sym_lo[idx] = sym_idx as u16;
831 self.sym_overflow.remove(&(idx as u32));
832 } else {
833 self.sym_lo[idx] = LM_PACKED_SYM_OVERFLOW;
834 self.sym_overflow.insert(idx as u32, sym_idx);
835 }
836 }
837
838 #[inline(always)]
839 fn set_cnt(&mut self, idx: usize, cnt: u64) {
840 if cnt <= LM_PACKED_CNT_MAX as u64 {
841 self.cnt_lo[idx] = cnt as u16;
842 self.cnt_overflow.remove(&(idx as u32));
843 self.cnt_overflow_mask[idx] = 0;
844 } else {
845 self.cnt_lo[idx] = LM_PACKED_CNT_MAX;
846 self.cnt_overflow
847 .insert(idx as u32, cnt - LM_PACKED_CNT_MAX as u64);
848 self.cnt_overflow_mask[idx] = 1;
849 }
850 }
851
852 #[inline(always)]
853 fn push(&mut self, node: CountNode) {
854 let idx = self.len();
855 self.sym_lo.push(0);
856 self.cnt_lo.push(0);
857 self.next.push(node.next);
858 self.cnt_overflow_mask.push(0);
859 self.set_sym_idx(idx, node.sym_idx);
860 self.set_cnt(idx, node.cnt);
861 }
862
863 #[inline(always)]
864 fn get(&self, idx: usize) -> CountNode {
865 CountNode {
866 sym_idx: self.sym_idx(idx),
867 cnt: self.cnt(idx),
868 next: self.next[idx],
869 }
870 }
871
872 #[inline(always)]
873 fn set(&mut self, idx: usize, node: CountNode) {
874 self.next[idx] = node.next;
875 self.set_sym_idx(idx, node.sym_idx);
876 self.set_cnt(idx, node.cnt);
877 }
878
879 #[inline(always)]
880 fn sym_idx(&self, idx: usize) -> u32 {
881 if self.sym_lo[idx] == LM_PACKED_SYM_OVERFLOW {
882 self.sym_overflow
883 .get(&(idx as u32))
884 .copied()
885 .unwrap_or(LM_PACKED_SYM_OVERFLOW as u32)
886 } else {
887 self.sym_lo[idx] as u32
888 }
889 }
890
891 #[inline(always)]
892 fn cnt(&self, idx: usize) -> u64 {
893 if self.cnt_overflow_mask[idx] == 0 {
894 self.cnt_lo[idx] as u64
895 } else {
896 self.cnt_lo[idx] as u64 + self.cnt_overflow.get(&(idx as u32)).copied().unwrap_or(0)
897 }
898 }
899
900 #[inline(always)]
901 fn next(&self, idx: usize) -> LmNodeIx {
902 self.next[idx]
903 }
904
905 #[inline(always)]
906 fn add_cnt(&mut self, idx: usize, add: u64) {
907 let next = self.cnt(idx).saturating_add(add);
908 self.set_cnt(idx, next);
909 }
910}
911
/// Borrowing iterator over an `LmNodes` store, yielding nodes by value in index order.
struct LmNodesIter<'a> {
    /// Store being iterated.
    nodes: &'a LmNodes,
    /// Index of the next node to yield.
    idx: usize,
}
916
917impl<'a> Iterator for LmNodesIter<'a> {
918 type Item = CountNode;
919
920 fn next(&mut self) -> Option<Self::Item> {
921 if self.idx >= self.nodes.len() {
922 return None;
923 }
924 let out = self.nodes.get(self.idx);
925 self.idx += 1;
926 Some(out)
927 }
928}
929
930impl<'a> IntoIterator for &'a LmNodes {
931 type Item = CountNode;
932 type IntoIter = LmNodesIter<'a>;
933
934 fn into_iter(self) -> Self::IntoIter {
935 LmNodesIter {
936 nodes: self,
937 idx: 0,
938 }
939 }
940}
941
/// Count-based language model layered over the states of a `Sam`.
#[derive(Clone)]
struct LM {
    /// Distinct symbols of the training text in ascending order (`build_alphabet`).
    alphabet: Vec<u32>,
    /// Occurrence count of each alphabet entry, parallel to `alphabet`.
    unigram: Vec<u64>,
    /// Number of alphabet entries.
    alpha_n: u32,
    /// Sum of `unigram`.
    total_uni: u64,

    /// True when every symbol is below 256 and `byte_map` may be used for lookups.
    has_byte_map: bool,
    /// Symbol value → alphabet index, or `-1` for symbols not in the alphabet.
    byte_map: [i16; 256],

    /// One count summary per automaton state (sized to `sam.st.len()` by `build_counts`).
    ls: Vec<LmState>,
    /// Count nodes referenced from `ls`.
    nodes: LmNodes,
}
955
956impl Default for LM {
957 fn default() -> Self {
958 LM {
959 alphabet: Vec::new(),
960 unigram: Vec::new(),
961 alpha_n: 0,
962 total_uni: 0,
963 has_byte_map: false,
964 byte_map: [-1; 256],
965 ls: Vec::new(),
966 nodes: LmNodes::default(),
967 }
968 }
969}
970
971impl LM {
972 #[inline(always)]
973 fn ls_is_implicit_single(ls: &LmState) -> bool {
974 ls.head == LM_NODE_NONE && ls.types_t == 1 && ls.total_n > 0
975 }
976
977 #[inline(always)]
978 fn capped_start_state(&self, sam: &Sam, max_order: i64, mut v: SamStateIx) -> SamStateIx {
979 if max_order < 0 {
980 return v;
981 }
982 while v != SAM_STATE_NONE && (sam.st[state_usize(v)].len as i64) > max_order {
983 v = sam.st[state_usize(v)].link;
984 }
985 if v == SAM_STATE_NONE { 0 } else { v }
986 }
987
988 fn build_alphabet(&mut self, sam: &Sam) {
989 self.has_byte_map = false;
990 self.byte_map = [-1; 256];
991
992 let mut max_cp = 0u32;
993 for &v in &sam.text {
994 if v > max_cp {
995 max_cp = v;
996 }
997 }
998
999 if max_cp < 256 {
1000 let mut counts = [0u64; 256];
1001 for &v in &sam.text {
1002 counts[v as usize] += 1;
1003 }
1004 let mut uniq = 0usize;
1005 for c in 0..256 {
1006 if counts[c] != 0 {
1007 uniq += 1;
1008 }
1009 }
1010
1011 if uniq == 0 {
1012 self.alphabet = vec![b'\n' as u32];
1013 self.unigram = vec![1];
1014 self.alpha_n = 1;
1015 self.total_uni = 1;
1016 self.has_byte_map = true;
1017 self.byte_map[b'\n' as usize] = 0;
1018 return;
1019 }
1020
1021 self.alphabet = Vec::with_capacity(uniq);
1022 self.unigram = Vec::with_capacity(uniq);
1023 self.total_uni = 0;
1024 for c in 0..256u32 {
1025 let cnt = counts[c as usize];
1026 if cnt == 0 {
1027 continue;
1028 }
1029 self.alphabet.push(c);
1030 self.unigram.push(cnt);
1031 self.total_uni += cnt;
1032 }
1033 self.alpha_n = self.alphabet.len() as u32;
1034 self.has_byte_map = true;
1035 for (i, &c) in self.alphabet.iter().enumerate() {
1036 self.byte_map[c as usize] = i as i16;
1037 }
1038 return;
1039 }
1040
1041 let mut tmp = sam.text.clone();
1042 tmp.sort_unstable();
1043 tmp.dedup();
1044 if tmp.is_empty() {
1045 tmp.push(b'\n' as u32);
1046 }
1047 self.alphabet = tmp;
1048 self.alpha_n = self.alphabet.len() as u32;
1049 self.unigram = vec![0u64; self.alphabet.len()];
1050 self.total_uni = 0;
1051 for &ch in &sam.text {
1052 if let Ok(i) = self.alphabet.binary_search(&ch) {
1053 self.unigram[i] += 1;
1054 self.total_uni += 1;
1055 }
1056 }
1057 if self.total_uni == 0 {
1058 self.unigram[0] = 1;
1059 self.total_uni = 1;
1060 }
1061 }
1062
1063 #[inline(always)]
1064 fn find_sym(&self, ch: u32) -> i32 {
1065 if self.has_byte_map && ch < 256 {
1066 return self.byte_map[ch as usize] as i32;
1067 }
1068 match self.alphabet.binary_search(&ch) {
1069 Ok(i) => i as i32,
1070 Err(_) => -1,
1071 }
1072 }
1073
    /// Adds `add` occurrences of symbol `sym_idx` to the counts of state `state`.
    ///
    /// A state with a single distinct symbol is kept "implicit" (no count node; symbol in
    /// `last_sym`, count in `total_n`). The first different symbol materializes both
    /// symbols as nodes. After that, the last-touched node is tried first, then the list
    /// is scanned, and an unseen symbol is prepended as a new node.
    #[inline(always)]
    fn inc(&mut self, state: u32, sym_idx: u32, add: u64) {
        let ls = &mut self.ls[state as usize];
        if ls.head == LM_NODE_NONE {
            if ls.total_n == 0 {
                // Empty state: start the implicit single-symbol representation.
                ls.total_n = add;
                ls.types_t = 1;
                ls.last_sym = sym_idx;
                ls.last_node = LM_NODE_NONE;
                return;
            }
            if Self::ls_is_implicit_single(ls) {
                if ls.last_sym == sym_idx {
                    // Same symbol again: only the total grows.
                    ls.total_n += add;
                    ls.last_node = LM_NODE_NONE;
                    return;
                }
                // Second distinct symbol: turn the implicit entry into a real node and
                // put the new symbol's node in front of it.
                let old_sym = ls.last_sym;
                let old_cnt = ls.total_n;
                let old_idx = node_ix(self.nodes.len());
                self.nodes.push(CountNode {
                    sym_idx: old_sym,
                    cnt: old_cnt,
                    next: LM_NODE_NONE,
                });
                let new_idx = node_ix(self.nodes.len());
                self.nodes.push(CountNode {
                    sym_idx,
                    cnt: add,
                    next: old_idx,
                });
                ls.head = new_idx;
                ls.total_n = old_cnt + add;
                ls.types_t = 2;
                ls.last_node = new_idx;
                ls.last_sym = sym_idx;
                return;
            }
        }

        // Fast path: the most recently touched node is for the same symbol.
        let last = ls.last_node;
        if last != LM_NODE_NONE && self.nodes.sym_idx(node_usize(last)) == sym_idx {
            self.nodes.add_cnt(node_usize(last), add);
            ls.total_n += add;
            return;
        }

        // Scan the state's list for an existing node of this symbol.
        let mut ni = ls.head;
        while ni != LM_NODE_NONE {
            let idx = node_usize(ni);
            if self.nodes.sym_idx(idx) == sym_idx {
                self.nodes.add_cnt(idx, add);
                ls.total_n += add;
                ls.last_node = ni;
                ls.last_sym = sym_idx;
                return;
            }
            ni = self.nodes.next(idx);
        }

        // Symbol not seen in this state yet: prepend a new node.
        let idx = node_ix(self.nodes.len());
        self.nodes.push(CountNode {
            sym_idx,
            cnt: add,
            next: ls.head,
        });
        ls.head = idx;
        ls.total_n += add;
        ls.types_t += 1;
        ls.last_node = idx;
        ls.last_sym = sym_idx;
    }
1146
1147 #[inline(always)]
1148 fn reserve_for_stream(&mut self, additional: usize) {
1149 if additional == 0 {
1150 return;
1151 }
1152 self.ls
1153 .reserve_exact(additional.saturating_mul(2).saturating_add(16));
1154 self.nodes
1155 .reserve_exact(additional.saturating_mul(3).saturating_add(16));
1156 }
1157
1158 fn build_counts(&mut self, sam: &Sam, max_order: i64) {
1159 self.ls = vec![
1160 LmState {
1161 head: LM_NODE_NONE,
1162 last_node: LM_NODE_NONE,
1163 ..LmState::default()
1164 };
1165 sam.st.len()
1166 ];
1167 self.nodes.clear();
1168
1169 let mut seg_start = 0usize;
1170 while seg_start < sam.text.len() {
1171 let mut seg_end = seg_start;
1172 while seg_end < sam.text.len() {
1173 let b = sam.boundary_after[seg_end];
1174 seg_end += 1;
1175 if b != 0 {
1176 break;
1177 }
1178 }
1179 if seg_end - seg_start >= 2 {
1180 let mut v = 0;
1181 for i in seg_start..(seg_end - 1) {
1182 let ch = sam.text[i];
1183 v = sam.advance(v, ch);
1184 let mut ctx = v;
1185 if max_order >= 0 {
1186 while ctx != SAM_STATE_NONE
1187 && (sam.st[state_usize(ctx)].len as i64) > max_order
1188 {
1189 ctx = sam.st[state_usize(ctx)].link;
1190 }
1191 if ctx == SAM_STATE_NONE {
1192 ctx = 0;
1193 }
1194 }
1195 let nxt = sam.text[i + 1];
1196 let si = self.find_sym(nxt);
1197 if si >= 0 {
1198 self.inc(state_usize(ctx) as u32, si as u32, 1);
1199 }
1200 }
1201 }
1202 seg_start = seg_end;
1203 }
1204
1205 let mut max_len: usize = 0;
1207 for st in &sam.st {
1208 let l = st.len as usize;
1209 if l > max_len {
1210 max_len = l;
1211 }
1212 }
1213 let mut cnt = vec![0usize; max_len + 1];
1214 for st in &sam.st {
1215 cnt[st.len as usize] += 1;
1216 }
1217 let mut pos = vec![0usize; max_len + 1];
1218 let mut acc = 0usize;
1219 for l in 0..=max_len {
1220 pos[l] = acc;
1221 acc += cnt[l];
1222 }
1223 let mut order = vec![0u32; sam.st.len()];
1224 for (v, st) in sam.st.iter().enumerate() {
1225 let l = st.len as usize;
1226 let idx = pos[l];
1227 order[idx] = v as u32;
1228 pos[l] += 1;
1229 }
1230
1231 for oi in (0..order.len()).rev() {
1232 let v = order[oi] as usize;
1233 let p = sam.st[v].link;
1234 if p < 0 {
1235 continue;
1236 }
1237 let ls_v = self.ls[v];
1238 if ls_v.total_n == 0 {
1239 continue;
1240 }
1241 if Self::ls_is_implicit_single(&ls_v) {
1242 self.inc(state_usize(p) as u32, ls_v.last_sym, ls_v.total_n);
1243 continue;
1244 }
1245 let mut ni = ls_v.head;
1246 while ni != LM_NODE_NONE {
1247 let node = self.nodes.get(node_usize(ni));
1248 self.inc(state_usize(p) as u32, node.sym_idx, node.cnt);
1249 ni = node.next;
1250 }
1251 }
1252 }
1253
1254 fn prob_for_sym(&self, sam: &Sam, max_order: i64, v: SamStateIx, sym_idx: i32) -> f64 {
1257 if sym_idx < 0 {
1258 return 1.0 / (self.alpha_n.max(1) as f64);
1259 }
1260 let sym_idx = sym_idx as u32;
1261 let mut p_accum = 0.0f64;
1262 let mut residual = 1.0f64;
1263 let mut u = self.capped_start_state(sam, max_order, v);
1264
1265 while u != SAM_STATE_NONE {
1266 let ls = &self.ls[state_usize(u)];
1267 let n = ls.total_n;
1268 let t = ls.types_t;
1269 if n > 0 {
1270 let lam = if t > 0 {
1271 (n as f64) / ((n + (t as u64)) as f64)
1272 } else {
1273 1.0
1274 };
1275
1276 let scale = residual * lam;
1278
1279 let mut count_for_sym = 0u64;
1281 if Self::ls_is_implicit_single(ls) {
1282 if ls.last_sym == sym_idx {
1283 count_for_sym = n;
1284 }
1285 } else if ls.last_node != LM_NODE_NONE && ls.last_sym == sym_idx {
1286 count_for_sym = self.nodes.cnt(node_usize(ls.last_node));
1287 } else {
1288 let mut ni = ls.head;
1289 while ni != LM_NODE_NONE {
1290 let node = self.nodes.get(node_usize(ni));
1291 if node.sym_idx == sym_idx {
1292 count_for_sym = node.cnt;
1293 break;
1294 }
1295 ni = node.next;
1296 }
1297 }
1298
1299 if count_for_sym > 0 {
1300 p_accum += scale * (count_for_sym as f64 / n as f64);
1301 }
1302
1303 residual *= 1.0 - lam;
1304 }
1305 u = sam.st[state_usize(u)].link;
1306 }
1307
1308 if self.total_uni > 0 && residual > 0.0 {
1309 let p_uni = self.unigram[sym_idx as usize] as f64 / self.total_uni as f64;
1310 p_accum += residual * p_uni;
1311 } else if residual > 0.0 {
1312 p_accum += residual * (1.0 / self.alpha_n.max(1) as f64);
1313 }
1314
1315 p_accum.clamp(1e-12, 1.0)
1316 }
1317
1318 fn probs_for_state_raw(&self, sam: &Sam, max_order: i64, v: SamStateIx, out: &mut [f64]) {
1319 out.fill(0.0);
1320 let mut residual = 1.0f64;
1321 let mut u = self.capped_start_state(sam, max_order, v);
1322 while u != SAM_STATE_NONE {
1323 let ls = &self.ls[state_usize(u)];
1324 let n = ls.total_n;
1325 let t = ls.types_t;
1326 if n > 0 {
1327 let lam = if t > 0 {
1328 (n as f64) / ((n + (t as u64)) as f64)
1329 } else {
1330 1.0
1331 };
1332 let scale = residual * lam;
1333 let inv_n = 1.0 / (n as f64);
1334 if Self::ls_is_implicit_single(ls) {
1335 out[ls.last_sym as usize] += scale;
1336 } else {
1337 let mut ni = ls.head;
1338 while ni != LM_NODE_NONE {
1339 let node = self.nodes.get(node_usize(ni));
1340 out[node.sym_idx as usize] += scale * ((node.cnt as f64) * inv_n);
1341 ni = node.next;
1342 }
1343 }
1344 residual *= 1.0 - lam;
1345 }
1346 u = sam.st[state_usize(u)].link;
1347 }
1348
1349 if self.total_uni > 0 && residual > 0.0 {
1350 let inv = 1.0 / (self.total_uni as f64);
1351 for i in 0..(self.alpha_n as usize) {
1352 out[i] += residual * ((self.unigram[i] as f64) * inv);
1353 }
1354 }
1355 }
1356
1357 fn probs_for_state(&self, sam: &Sam, max_order: i64, v: SamStateIx, out: &mut [f64]) {
1358 self.probs_for_state_raw(sam, max_order, v, out);
1359 let mut s = 0.0;
1360 for i in 0..(self.alpha_n as usize) {
1361 s += out[i];
1362 }
1363 if s > 0.0 && s.is_finite() {
1364 if (s - 1.0).abs() <= 1e-12 {
1365 return;
1366 }
1367 let invs = 1.0 / s;
1368 for i in 0..(self.alpha_n as usize) {
1369 out[i] *= invs;
1370 }
1371 } else {
1372 let uprob = 1.0 / (self.alpha_n.max(1) as f64);
1373 for i in 0..(self.alpha_n as usize) {
1374 out[i] = uprob;
1375 }
1376 }
1377 }
1378
1379 #[inline(always)]
1380 fn inc_tx(&mut self, tx: &mut LmTx, state: u32, sym_idx: u32, add: u64) {
1381 let si = state as usize;
1382 tx.ls_changes.push((si, self.ls[si]));
1384
1385 let ls = &mut self.ls[si];
1386 if ls.head == LM_NODE_NONE {
1387 if ls.total_n == 0 {
1388 ls.total_n = add;
1389 ls.types_t = 1;
1390 ls.last_sym = sym_idx;
1391 ls.last_node = LM_NODE_NONE;
1392 return;
1393 }
1394 if Self::ls_is_implicit_single(ls) {
1395 if ls.last_sym == sym_idx {
1396 ls.total_n += add;
1397 ls.last_node = LM_NODE_NONE;
1398 return;
1399 }
1400 let old_sym = ls.last_sym;
1401 let old_cnt = ls.total_n;
1402 tx.old_nodes_len = tx.old_nodes_len.min(self.nodes.len());
1403 let old_idx = node_ix(self.nodes.len());
1404 self.nodes.push(CountNode {
1405 sym_idx: old_sym,
1406 cnt: old_cnt,
1407 next: LM_NODE_NONE,
1408 });
1409 let new_idx = node_ix(self.nodes.len());
1410 self.nodes.push(CountNode {
1411 sym_idx,
1412 cnt: add,
1413 next: old_idx,
1414 });
1415 ls.head = new_idx;
1416 ls.total_n = old_cnt + add;
1417 ls.types_t = 2;
1418 ls.last_node = new_idx;
1419 ls.last_sym = sym_idx;
1420 return;
1421 }
1422 }
1423
1424 let last = ls.last_node;
1425 if last != LM_NODE_NONE && self.nodes.sym_idx(node_usize(last)) == sym_idx {
1426 let ni = node_usize(last);
1427 tx.node_changes.push((ni, self.nodes.get(ni)));
1428 self.nodes.add_cnt(ni, add);
1429 ls.total_n += add;
1430 return;
1431 }
1432
1433 let mut ni = ls.head;
1434 while ni != LM_NODE_NONE {
1435 let idx = node_usize(ni);
1436 if self.nodes.sym_idx(idx) == sym_idx {
1437 tx.node_changes.push((idx, self.nodes.get(idx)));
1438 self.nodes.add_cnt(idx, add);
1439 ls.total_n += add;
1440 ls.last_node = ni;
1441 ls.last_sym = sym_idx;
1442 return;
1443 }
1444 ni = self.nodes.next(idx);
1445 }
1446
1447 let idx = node_ix(self.nodes.len());
1449 tx.old_nodes_len = tx.old_nodes_len.min(self.nodes.len());
1450 self.nodes.push(CountNode {
1451 sym_idx,
1452 cnt: add,
1453 next: ls.head,
1454 });
1455 ls.head = idx;
1456 ls.total_n += add;
1457 ls.types_t += 1;
1458 ls.last_node = idx;
1459 ls.last_sym = sym_idx;
1460 }
1461}
1462
1463#[derive(Clone)]
1464struct LmTx {
1465 old_ls_len: usize,
1466 old_nodes_len: usize,
1467 ls_changes: Vec<(usize, LmState)>,
1468 node_changes: Vec<(usize, CountNode)>,
1469 uni_delta: [u64; BYTE_ALPHA_N],
1471 total_uni_add: u64,
1472}
1473
/// Source of pseudo-random numbers used for sampling.
///
/// Values come from a byte buffer that is replayed cyclically when it holds at
/// least eight bytes, and from a two-shift xorshift state otherwise.
#[derive(Clone, Default)]
struct RngStream {
    /// Random bytes loaded from a file; shorter than 8 bytes means "unused".
    buf: Vec<u8>,
    /// Offset of the next byte to read from `buf`.
    pos: usize,
    /// Xorshift state used when `buf` is too short.
    xs: u64,
}

impl RngStream {
    /// Creates a stream.
    ///
    /// If the `ROSAPLUS_RNG_PATH` environment variable names a readable file of
    /// at least eight bytes, the file becomes the byte buffer and reading starts
    /// at offset `(seed * 8) mod len`. In every other case the xorshift state
    /// starts from a fixed constant and `seed` is not used.
    fn new(seed: u64) -> Self {
        let mut stream = RngStream {
            buf: Vec::new(),
            pos: 0,
            xs: 88172645463325252u64,
        };

        let path = match std::env::var("ROSAPLUS_RNG_PATH") {
            Ok(p) if !p.is_empty() => p,
            _ => return stream,
        };

        if let Ok(mut file) = File::open(path) {
            let mut bytes = Vec::new();
            if file.read_to_end(&mut bytes).is_ok() && bytes.len() >= 8 {
                stream.pos = ((seed.wrapping_mul(8)) as usize) % bytes.len();
                stream.buf = bytes;
            }
        }
        stream
    }

    /// Returns the next 64-bit value: eight little-endian bytes from the buffer
    /// (wrapping to its start), or one xorshift step when no buffer is loaded.
    #[inline(always)]
    fn next_u64(&mut self) -> u64 {
        let len = self.buf.len();
        if len < 8 {
            let mut x = self.xs;
            x ^= x << 7;
            x ^= x >> 9;
            self.xs = x;
            return x;
        }

        let mut bytes = [0u8; 8];
        for slot in bytes.iter_mut() {
            *slot = self.buf[self.pos];
            self.pos += 1;
            if self.pos >= len {
                self.pos = 0;
            }
        }
        u64::from_le_bytes(bytes)
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of `next_u64`.
    #[inline(always)]
    fn next_unit(&mut self) -> f64 {
        // 9007199254740992 = 2^53.
        const INV_2_POW_53: f64 = 1.0 / 9007199254740992.0;
        let top_bits = self.next_u64() >> 11;
        (top_bits as f64) * INV_2_POW_53
    }
}
1527
/// Reusable work buffers for sampling, kept between calls to avoid reallocating.
#[derive(Clone, Default)]
struct SampleScratch {
    /// Alphabet indices; always resized to exactly the alphabet size.
    idx: Vec<u32>,
    /// Per-candidate logits; only ever grown.
    logits: Vec<f64>,
    /// Per-candidate exponentials; grown together with `logits`.
    exps: Vec<f64>,
}

impl SampleScratch {
    /// Sizes the buffers: `idx` gets exactly `alpha_n` entries, while `logits`
    /// and `exps` are grown to `n` entries whenever `logits` is shorter than `n`.
    fn ensure(&mut self, alpha_n: usize, n: usize) {
        // Resizing to the current length is a no-op, so no length check is needed.
        self.idx.resize(alpha_n, 0);

        let needs_growth = self.logits.len() < n;
        if needs_growth {
            self.logits.resize(n, 0.0);
            self.exps.resize(n, 0.0);
        }
    }
}
1549
1550#[derive(Clone)]
1551pub struct RosaPlus {
1553 max_order: i64,
1554 use_eot: bool,
1555 eot: u32,
1556 seed: u64,
1557
1558 sam: Sam,
1559 lm: LM,
1560 lm_built: bool,
1561
1562 rng: RngStream,
1563 scratch: SampleScratch,
1564 dist: Vec<f64>,
1565}
1566
1567#[derive(Clone, Copy, Debug)]
1572pub struct RosaCheckpoint {
1573 sam_st_len: usize,
1574 sam_ed_len: usize,
1575 sam_text_len: usize,
1576 sam_text_states_len: usize,
1577 sam_boundary_after_len: usize,
1578 sam_last: SamStateIx,
1579}
1580
1581#[derive(Clone)]
1583pub struct RosaTx {
1584 sam: SamTx,
1585 lm: LmTx,
1586 seg_start: usize,
1587 seg_len: usize,
1588}
1589
1590impl RosaPlus {
1591 pub fn new(max_order: i64, use_eot: bool, eot_char: u8, seed: u64) -> Self {
1595 let sam = Sam::new(0);
1596 RosaPlus {
1597 max_order,
1598 use_eot,
1599 eot: eot_char as u32,
1600 seed,
1601 sam,
1602 lm: LM::default(),
1603 lm_built: false,
1604 rng: RngStream::new(seed),
1605 scratch: SampleScratch::default(),
1606 dist: Vec::new(),
1607 }
1608 }
1609
1610 pub fn train_example(&mut self, s: &[u8]) {
1612 if s.is_empty() {
1613 return;
1614 }
1615
1616 if self.sam.text.is_empty() {
1617 self.sam = Sam::new(s.len());
1618 }
1619
1620 for &b in s {
1621 self.sam.feed(b as u32);
1622 }
1623
1624 if self.use_eot {
1625 self.sam.feed(self.eot);
1626 }
1627
1628 self.sam.mark_boundary();
1629 self.lm_built = false;
1630 }
1631
1632 pub fn reserve_for_stream(&mut self, additional_bytes: usize) {
1636 self.sam.reserve_additional(additional_bytes);
1637 self.lm.reserve_for_stream(additional_bytes);
1638 self.dist.reserve(BYTE_ALPHA_N);
1639 }
1640
1641 pub fn build_lm(&mut self) {
1643 self.sam.finalize_endpos();
1644 self.lm = LM::default();
1645 self.lm.build_alphabet(&self.sam);
1646 let mo = if self.max_order < 0 {
1647 -1
1648 } else {
1649 self.max_order
1650 };
1651 self.lm.build_counts(&self.sam, mo);
1652 self.lm_built = true;
1653 self.dist.resize(self.lm.alpha_n as usize, 0.0);
1654 }
1655
1656 pub fn build_lm_no_finalize_endpos(&mut self) {
1663 self.lm = LM::default();
1664 self.lm.build_alphabet(&self.sam);
1665 let mo = if self.max_order < 0 {
1666 -1
1667 } else {
1668 self.max_order
1669 };
1670 self.lm.build_counts(&self.sam, mo);
1671 self.lm_built = true;
1672 self.dist.resize(self.lm.alpha_n as usize, 0.0);
1673 }
1674
1675 pub fn build_lm_full_bytes_no_finalize_endpos(&mut self) {
1679 self.lm = LM::default();
1681 self.lm.has_byte_map = true;
1682 self.lm.alpha_n = BYTE_ALPHA_N as u32;
1683 self.lm.alphabet = (0..BYTE_ALPHA_N as u32).collect();
1684 self.lm.byte_map = [-1; 256];
1685 for i in 0..256 {
1686 self.lm.byte_map[i] = i as i16;
1687 }
1688
1689 let mut counts = [0u64; 256];
1691 for &v in &self.sam.text {
1692 if v < 256 {
1693 counts[v as usize] += 1;
1694 }
1695 }
1696 self.lm.unigram = counts.to_vec();
1697 self.lm.total_uni = counts.iter().sum();
1698 if self.lm.total_uni == 0 {
1699 for i in 0..256 {
1700 self.lm.unigram[i] = 1;
1701 }
1702 self.lm.total_uni = 256;
1703 }
1704
1705 let mo = if self.max_order < 0 {
1707 -1
1708 } else {
1709 self.max_order
1710 };
1711 self.lm.build_counts(&self.sam, mo);
1712 self.lm_built = true;
1713 self.dist.resize(BYTE_ALPHA_N, 0.0);
1714 }
1715
1716 pub fn begin_tx(&mut self) -> RosaTx {
1718 let sam_tx = self.sam.begin_tx();
1719 let lm_tx = LmTx {
1720 old_ls_len: self.lm.ls.len(),
1721 old_nodes_len: self.lm.nodes.len(),
1722 ls_changes: Vec::new(),
1723 node_changes: Vec::new(),
1724 uni_delta: [0u64; BYTE_ALPHA_N],
1725 total_uni_add: 0,
1726 };
1727 RosaTx {
1728 sam: sam_tx,
1729 lm: lm_tx,
1730 seg_start: self.sam.text.len(),
1731 seg_len: 0,
1732 }
1733 }
1734
1735 pub fn train_example_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1737 self.train_example_tx_impl(tx, s, true);
1738 }
1739
1740 pub fn train_sequence_tx(&mut self, tx: &mut RosaTx, s: &[u8]) {
1742 self.train_example_tx_impl(tx, s, false);
1743 }
1744
1745 pub fn train_sequence(&mut self, s: &[u8]) {
1747 if s.is_empty() {
1748 return;
1749 }
1750
1751 if s.len() == 1 {
1752 self.train_byte(s[0]);
1753 return;
1754 }
1755
1756 if self.sam.text.is_empty() {
1757 self.sam = Sam::new(s.len());
1758 }
1759 self.reserve_for_stream(s.len());
1760 if !self.lm_built || !self.lm.has_byte_map || (self.lm.alpha_n as usize) != BYTE_ALPHA_N {
1761 self.build_lm_full_bytes_no_finalize_endpos();
1762 }
1763
1764 if self.lm.ls.len() < self.sam.st.len() {
1765 self.lm.ls.resize(
1766 self.sam.st.len(),
1767 LmState {
1768 head: LM_NODE_NONE,
1769 last_node: LM_NODE_NONE,
1770 ..LmState::default()
1771 },
1772 );
1773 }
1774
1775 let seg_start = self.sam.text.len();
1776 for &b in s {
1777 self.sam.feed(b as u32);
1778 self.lm.unigram[b as usize] += 1;
1779 self.lm.total_uni += 1;
1780 }
1781
1782 if self.lm.ls.len() < self.sam.st.len() {
1783 self.lm.ls.resize(
1784 self.sam.st.len(),
1785 LmState {
1786 head: LM_NODE_NONE,
1787 last_node: LM_NODE_NONE,
1788 ..LmState::default()
1789 },
1790 );
1791 }
1792
1793 let seg_end = self.sam.text.len();
1794 if seg_end.saturating_sub(seg_start) >= 1 {
1795 let mo = if self.max_order < 0 {
1796 -1
1797 } else {
1798 self.max_order
1799 };
1800 let mut start_i = seg_start;
1801 if seg_start > 0
1802 && self
1803 .sam
1804 .boundary_after
1805 .get(seg_start - 1)
1806 .copied()
1807 .unwrap_or(0)
1808 == 0
1809 {
1810 start_i = seg_start - 1;
1811 }
1812 for i in start_i..(seg_end - 1) {
1813 let mut ctx = self.sam.text_states[i + 1];
1814 if mo >= 0 {
1815 while ctx != SAM_STATE_NONE && (self.sam.st[state_usize(ctx)].len as i64) > mo {
1816 ctx = self.sam.st[state_usize(ctx)].link;
1817 }
1818 if ctx == SAM_STATE_NONE {
1819 ctx = 0;
1820 }
1821 }
1822 let nxt = self.sam.text[i + 1];
1823 let si = self.lm.find_sym(nxt);
1824 if si >= 0 {
1825 let mut u = ctx;
1826 while u != SAM_STATE_NONE {
1827 self.lm.inc(state_usize(u) as u32, si as u32, 1);
1828 u = self.sam.st[state_usize(u)].link;
1829 }
1830 }
1831 }
1832 }
1833
1834 self.lm_built = true;
1835 }
1836
1837 #[inline]
1839 pub fn train_byte(&mut self, b: u8) {
1840 if self.sam.text.is_empty() {
1841 self.sam = Sam::new(1);
1842 }
1843 if !self.lm_built || !self.lm.has_byte_map || (self.lm.alpha_n as usize) != BYTE_ALPHA_N {
1844 self.build_lm_full_bytes_no_finalize_endpos();
1845 }
1846
1847 self.sam.feed(b as u32);
1848 self.lm.unigram[b as usize] += 1;
1849 self.lm.total_uni += 1;
1850
1851 if self.lm.ls.len() < self.sam.st.len() {
1852 self.lm.ls.resize(
1853 self.sam.st.len(),
1854 LmState {
1855 head: LM_NODE_NONE,
1856 last_node: LM_NODE_NONE,
1857 ..LmState::default()
1858 },
1859 );
1860 }
1861
1862 let seg_end = self.sam.text.len();
1863 if seg_end > 1
1864 && self
1865 .sam
1866 .boundary_after
1867 .get(seg_end - 2)
1868 .copied()
1869 .unwrap_or(0)
1870 == 0
1871 {
1872 let mo = if self.max_order < 0 {
1873 -1
1874 } else {
1875 self.max_order
1876 };
1877 let mut ctx = self.sam.text_states[seg_end - 1];
1878 if mo >= 0 {
1879 while ctx != SAM_STATE_NONE && (self.sam.st[state_usize(ctx)].len as i64) > mo {
1880 ctx = self.sam.st[state_usize(ctx)].link;
1881 }
1882 if ctx == SAM_STATE_NONE {
1883 ctx = 0;
1884 }
1885 }
1886 let mut u = ctx;
1887 let si = b as u32;
1888 while u != SAM_STATE_NONE {
1889 self.lm.inc(state_usize(u) as u32, si, 1);
1890 u = self.sam.st[state_usize(u)].link;
1891 }
1892 }
1893
1894 self.lm_built = true;
1895 }
1896
1897 pub fn reset_conditioning_cursor(&mut self) {
1899 self.sam.last = 0;
1900 }
1901
1902 pub fn advance_conditioning_byte(&mut self, b: u8) {
1904 self.sam.last = self.sam.advance(self.sam.last, b as u32);
1905 }
1906
1907 fn train_example_tx_impl(&mut self, tx: &mut RosaTx, s: &[u8], mark_boundary: bool) {
1908 if s.is_empty() {
1909 return;
1910 }
1911
1912 if self.lm.ls.len() < self.sam.st.len() {
1914 self.lm.ls.resize(
1915 self.sam.st.len(),
1916 LmState {
1917 head: LM_NODE_NONE,
1918 last_node: LM_NODE_NONE,
1919 ..LmState::default()
1920 },
1921 );
1922 }
1923
1924 for &b in s {
1926 self.sam.feed_tx(&mut tx.sam, b as u32);
1927 tx.lm.uni_delta[b as usize] += 1;
1928 tx.lm.total_uni_add += 1;
1929 }
1930 if mark_boundary {
1931 self.sam.mark_boundary_tx(&mut tx.sam);
1932 }
1933
1934 if self.lm.ls.len() < self.sam.st.len() {
1937 self.lm.ls.resize(
1938 self.sam.st.len(),
1939 LmState {
1940 head: LM_NODE_NONE,
1941 last_node: LM_NODE_NONE,
1942 ..LmState::default()
1943 },
1944 );
1945 }
1946
1947 for i in 0..256 {
1949 if tx.lm.uni_delta[i] != 0 {
1950 self.lm.unigram[i] += tx.lm.uni_delta[i];
1951 }
1952 }
1953 self.lm.total_uni += tx.lm.total_uni_add;
1954
1955 let seg_start = tx.seg_start;
1957 let seg_end = self.sam.text.len();
1958 tx.seg_len = seg_end - seg_start;
1959 if tx.seg_len >= 1 {
1960 let mo = if self.max_order < 0 {
1961 -1
1962 } else {
1963 self.max_order
1964 };
1965 let mut start_i = seg_start;
1969 if !mark_boundary
1970 && seg_start > 0
1971 && self
1972 .sam
1973 .boundary_after
1974 .get(seg_start - 1)
1975 .copied()
1976 .unwrap_or(0)
1977 == 0
1978 {
1979 start_i = seg_start - 1;
1980 }
1981 for i in start_i..(seg_end - 1) {
1982 let mut ctx = self.sam.text_states[i + 1];
1984 if mo >= 0 {
1985 while ctx != SAM_STATE_NONE && (self.sam.st[state_usize(ctx)].len as i64) > mo {
1986 ctx = self.sam.st[state_usize(ctx)].link;
1987 }
1988 if ctx == SAM_STATE_NONE {
1989 ctx = 0;
1990 }
1991 }
1992 let nxt = self.sam.text[i + 1];
1993 let si = self.lm.find_sym(nxt);
1994 if si >= 0 {
1995 let mut u = ctx;
1996 while u != SAM_STATE_NONE {
1997 self.lm
1998 .inc_tx(&mut tx.lm, state_usize(u) as u32, si as u32, 1);
1999 u = self.sam.st[state_usize(u)].link;
2000 }
2001 }
2002 }
2003 }
2004
2005 self.lm_built = true;
2006 }
2007
2008 pub fn rollback_tx(&mut self, tx: RosaTx) {
2010 if self.lm.unigram.len() >= BYTE_ALPHA_N {
2013 for i in 0..BYTE_ALPHA_N {
2014 let d = tx.lm.uni_delta[i];
2015 if d != 0 {
2016 self.lm.unigram[i] = self.lm.unigram[i].saturating_sub(d);
2017 }
2018 }
2019 self.lm.total_uni = self.lm.total_uni.saturating_sub(tx.lm.total_uni_add);
2020 }
2021
2022 for (idx, old) in tx.lm.node_changes.into_iter().rev() {
2023 if idx < self.lm.nodes.len() {
2024 self.lm.nodes.set(idx, old);
2025 }
2026 }
2027 for (idx, old) in tx.lm.ls_changes.into_iter().rev() {
2028 if idx < self.lm.ls.len() {
2029 self.lm.ls[idx] = old;
2030 }
2031 }
2032 self.lm.nodes.truncate(tx.lm.old_nodes_len);
2033 self.lm.ls.truncate(tx.lm.old_ls_len);
2034
2035 self.sam.rollback_tx(tx.sam);
2037 }
2039
2040 #[inline(always)]
2042 pub fn ensure_lm_built_no_finalize_endpos(&mut self) {
2043 if !self.lm_built {
2044 self.build_lm_no_finalize_endpos();
2045 }
2046 }
2047
2048 fn predictive_entropy_rate_order(data: &[u8], max_order: i64, seed: u64) -> f64 {
2049 if data.len() < 2 {
2050 return 0.0;
2051 }
2052 let num_chunks = 16;
2053 let chunk_size = data.len().div_ceil(num_chunks);
2054 let mut total_log_prob = 0.0f64;
2055 let mut count = 0usize;
2056 let mut m = RosaPlus::new(max_order, false, 0, seed);
2057 m.sam = Sam::new(data.len());
2058 m.lm_built = false;
2059
2060 for i in 0..num_chunks {
2061 let start = i * chunk_size;
2062 let end = ((i + 1) * chunk_size).min(data.len());
2063 if start >= end {
2064 break;
2065 }
2066 let chunk = &data[start..end];
2067 if i > 0 {
2068 m.build_lm_no_finalize_endpos();
2069 let mut v = 0;
2070 for &b in chunk {
2071 let sym_idx = m.lm.find_sym(b as u32);
2072 let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
2073 total_log_prob += p.log2();
2074 count += 1;
2075 v = m.sam.advance(v, b as u32);
2076 }
2077 }
2078 for &b in chunk {
2079 m.sam.feed(b as u32);
2080 }
2081 }
2082
2083 if count == 0 {
2084 m.train_example(data);
2085 m.build_lm();
2086 m.cross_entropy(data)
2087 } else {
2088 -total_log_prob / (count as f64)
2089 }
2090 }
2091
2092 pub fn lm_alpha_n(&self) -> usize {
2094 if !self.lm_built {
2095 0
2096 } else {
2097 self.lm.alpha_n as usize
2098 }
2099 }
2100
2101 pub fn estimated_size_bytes(&self) -> usize {
2103 use std::mem::size_of;
2104
2105 let mut n = 0usize;
2106
2107 n = n.saturating_add(self.sam.st.len().saturating_mul(size_of::<SamState>()));
2108 n = n.saturating_add(self.sam.ed.len().saturating_mul(size_of::<SamEdge>()));
2109 n = n.saturating_add(self.sam.text.len().saturating_mul(size_of::<u32>()));
2110 n = n.saturating_add(
2111 self.sam
2112 .text_states
2113 .len()
2114 .saturating_mul(size_of::<SamStateIx>()),
2115 );
2116 n = n.saturating_add(size_of::<[SamStateIx; BYTE_ALPHA_N]>());
2117 n = n.saturating_add(
2118 self.sam
2119 .boundary_after
2120 .len()
2121 .saturating_mul(size_of::<u8>()),
2122 );
2123
2124 n = n.saturating_add(self.lm.alphabet.len().saturating_mul(size_of::<u32>()));
2125 n = n.saturating_add(self.lm.unigram.len().saturating_mul(size_of::<u64>()));
2126 n = n.saturating_add(self.lm.ls.len().saturating_mul(size_of::<LmState>()));
2127 n = n.saturating_add(self.lm.nodes.sym_lo.len().saturating_mul(size_of::<u16>()));
2128 n = n.saturating_add(self.lm.nodes.cnt_lo.len().saturating_mul(size_of::<u16>()));
2129 n = n.saturating_add(
2130 self.lm
2131 .nodes
2132 .next
2133 .len()
2134 .saturating_mul(size_of::<LmNodeIx>()),
2135 );
2136 n = n.saturating_add(
2137 self.lm
2138 .nodes
2139 .cnt_overflow_mask
2140 .len()
2141 .saturating_mul(size_of::<u8>()),
2142 );
2143 n = n.saturating_add(
2144 self.lm
2145 .nodes
2146 .sym_overflow
2147 .len()
2148 .saturating_mul(size_of::<u32>() + size_of::<u32>()),
2149 );
2150 n = n.saturating_add(
2151 self.lm
2152 .nodes
2153 .cnt_overflow
2154 .len()
2155 .saturating_mul(size_of::<u32>() + size_of::<u64>()),
2156 );
2157
2158 n = n.saturating_add(self.dist.len().saturating_mul(size_of::<f64>()));
2159 n = n.saturating_add(self.scratch.idx.len().saturating_mul(size_of::<u32>()));
2160 n = n.saturating_add(self.scratch.logits.len().saturating_mul(size_of::<f64>()));
2161 n = n.saturating_add(self.scratch.exps.len().saturating_mul(size_of::<f64>()));
2162 n = n.saturating_add(self.rng.buf.len().saturating_mul(size_of::<u8>()));
2163
2164 n
2165 }
2166
2167 pub fn shrink_aux_buffers(&mut self) {
2169 self.dist.shrink_to_fit();
2170 self.scratch.idx.shrink_to_fit();
2171 self.scratch.logits.shrink_to_fit();
2172 self.scratch.exps.shrink_to_fit();
2173 self.rng.buf.shrink_to_fit();
2174 }
2175
2176 pub fn fork_from_sam(&self) -> Self {
2182 Self {
2183 max_order: self.max_order,
2184 use_eot: self.use_eot,
2185 eot: self.eot,
2186 seed: self.seed,
2187
2188 sam: self.sam.clone(),
2189 lm: LM::default(),
2190 lm_built: false,
2191
2192 rng: RngStream::new(self.seed),
2193 scratch: SampleScratch::default(),
2194 dist: Vec::new(),
2195 }
2196 }
2197
2198 pub fn checkpoint(&self) -> RosaCheckpoint {
2204 RosaCheckpoint {
2205 sam_st_len: self.sam.st.len(),
2206 sam_ed_len: self.sam.ed.len(),
2207 sam_text_len: self.sam.text.len(),
2208 sam_text_states_len: self.sam.text_states.len(),
2209 sam_boundary_after_len: self.sam.boundary_after.len(),
2210 sam_last: self.sam.last,
2211 }
2212 }
2213
2214 pub fn restore(&mut self, ck: &RosaCheckpoint) {
2218 self.sam.st.truncate(ck.sam_st_len);
2219 self.sam.ed.truncate(ck.sam_ed_len);
2220 self.sam.text.truncate(ck.sam_text_len);
2221 self.sam.text_states.truncate(ck.sam_text_states_len);
2222 self.sam.boundary_after.truncate(ck.sam_boundary_after_len);
2223 self.sam.last = ck.sam_last;
2224 self.lm_built = false;
2225 }
2226
2227 #[inline(always)]
2228 fn sample(&mut self, temperature: f64, top_p: f64, top_k: i32) -> u32 {
2229 let dist = &self.dist;
2230 let alpha_n = self.lm.alpha_n as usize;
2231 self.scratch.ensure(alpha_n, alpha_n);
2232 for i in 0..alpha_n {
2233 self.scratch.idx[i] = i as u32;
2234 }
2235
2236 for i in 0..alpha_n {
2238 for j in (i + 1)..alpha_n {
2239 let ii = self.scratch.idx[i] as usize;
2240 let jj = self.scratch.idx[j] as usize;
2241 let pi = dist[ii];
2242 let pj = dist[jj];
2243 if pj > pi || (pj == pi && jj < ii) {
2244 self.scratch.idx.swap(i, j);
2245 }
2246 }
2247 }
2248
2249 let mut n = alpha_n;
2250 if top_k > 0 {
2251 let k = top_k as usize;
2252 if k < n {
2253 n = k;
2254 }
2255 }
2256
2257 if top_p > 0.0 && top_p < 1.0 {
2258 let mut cum = 0.0;
2259 let mut cut = 0usize;
2260 for i in 0..n {
2261 let si = self.scratch.idx[i] as usize;
2262 cum += dist[si];
2263 cut += 1;
2264 if cum >= top_p {
2265 break;
2266 }
2267 }
2268 n = if cut > 0 { cut } else { 1 };
2269 }
2270
2271 let temperature = if temperature <= 0.0 {
2272 1e-6
2273 } else {
2274 temperature
2275 };
2276
2277 self.scratch.ensure(alpha_n, n);
2278 let mut maxlog = -1e300f64;
2279 for i in 0..n {
2280 let si = self.scratch.idx[i] as usize;
2281 let mut p = dist[si];
2282 if p < 1e-12 {
2283 p = 1e-12;
2284 }
2285 let z = p.ln() / temperature;
2286 self.scratch.logits[i] = z;
2287 if z > maxlog {
2288 maxlog = z;
2289 }
2290 }
2291
2292 let mut zsum = 0.0;
2293 for i in 0..n {
2294 let e = (self.scratch.logits[i] - maxlog).exp();
2295 self.scratch.exps[i] = e;
2296 zsum += e;
2297 }
2298
2299 let r = self.rng.next_unit() * zsum;
2300 let mut cum = 0.0;
2301 let mut pick = 0usize;
2302 for i in 0..n {
2303 cum += self.scratch.exps[i];
2304 if cum > r {
2305 pick = i;
2306 break;
2307 }
2308 }
2309
2310 let sym = self.scratch.idx[pick] as usize;
2311 self.lm.alphabet[sym]
2312 }
2313
2314 pub fn generate(&mut self, prompt: &[u8], steps: i32) -> Option<Vec<u8>> {
2318 if !self.lm_built {
2319 return None;
2320 }
2321 let steps = steps.max(0) as usize;
2322
2323 let mut v = 0i32;
2324 for &b in prompt {
2325 v = self.sam.advance(v, b as u32);
2326 }
2327
2328 let mut out: Vec<u32> = Vec::with_capacity(steps);
2329
2330 for _ in 0..steps {
2331 let mut ch = self.sam.predict_det(v);
2332 if ch.is_none() {
2333 let mo = if self.max_order < 0 {
2334 -1
2335 } else {
2336 self.max_order
2337 };
2338 self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
2339 ch = Some(self.sample(0.7, 0.9, 0));
2340 }
2341 let ch = ch.unwrap();
2342 out.push(ch);
2343 if self.use_eot && ch == self.eot {
2344 break;
2345 }
2346 v = self.sam.advance(v, ch);
2347 }
2348
2349 Some(out.iter().map(|&c| c as u8).collect())
2350 }
2351
2352 pub fn get_distribution(&mut self, context: &[u8]) -> Vec<(u32, f64)> {
2358 if !self.lm_built {
2359 self.build_lm();
2360 }
2361
2362 let mut v = 0i32;
2364 for &b in context {
2365 v = self.sam.advance(v, b as u32);
2366 }
2367
2368 let mo = if self.max_order < 0 {
2370 -1
2371 } else {
2372 self.max_order
2373 };
2374 self.dist.resize(self.lm.alpha_n as usize, 0.0);
2375 self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
2376
2377 let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
2379 for i in 0..(self.lm.alpha_n as usize) {
2380 if self.dist[i] > 0.0 {
2381 result.push((self.lm.alphabet[i], self.dist[i]));
2382 }
2383 }
2384 result.sort_by_key(|&(cp, _)| cp);
2385 result
2386 }
2387
2388 pub fn predictive_entropy_rate(&mut self, data: &[u8]) -> f64 {
2392 if data.len() < 2 {
2393 return 0.0;
2394 }
2395 if self.max_order < 0 {
2396 let candidates: [i64; 8] = [0, 1, 2, 4, 8, 16, 32, 64];
2397 let mut best = f64::INFINITY;
2398 for &mo in &candidates {
2399 if mo as usize >= data.len() {
2400 continue;
2401 }
2402 let h = Self::predictive_entropy_rate_order(data, mo, self.seed);
2403 if h < best {
2404 best = h;
2405 }
2406 }
2407 if best.is_finite() {
2408 return best;
2409 }
2410 }
2411 Self::predictive_entropy_rate_order(data, self.max_order, self.seed)
2412 }
2413
2414 pub fn entropy_rate_cps(&mut self, cps: &[u32]) -> f64 {
2416 if cps.len() < 2 {
2417 return 0.0;
2418 }
2419
2420 self.sam = Sam::new(cps.len());
2421 self.lm_built = false;
2422
2423 let num_chunks = 16;
2424 let chunk_size = cps.len().div_ceil(num_chunks);
2425 let mut total_log_prob = 0.0f64;
2426 let mut count = 0usize;
2427
2428 for i in 0..num_chunks {
2429 let start = i * chunk_size;
2430 let end = ((i + 1) * chunk_size).min(cps.len());
2431 if start >= end {
2432 break;
2433 }
2434 let chunk = &cps[start..end];
2435 if i > 0 {
2436 self.build_lm_no_finalize_endpos();
2438 let mut v = self.sam.text_states[start];
2439 for &ch in chunk {
2440 let sym_idx = self.lm.find_sym(ch);
2441 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2442 total_log_prob += p.log2();
2443 count += 1;
2444 v = self.sam.advance(v, ch);
2445 }
2446 }
2447 for &ch in chunk {
2448 self.sam.feed(ch);
2449 }
2450 }
2451
2452 if count == 0 {
2453 self.build_lm();
2454 self.entropy_rate_plugin_cps(cps)
2455 } else {
2456 -total_log_prob / (count as f64)
2457 }
2458 }
2459
2460 #[allow(dead_code)]
2461 fn entropy_rate_plugin_bytes(&mut self, data: &[u8]) -> f64 {
2462 let mut v = 0i32;
2463 let mut total_log_prob = 0.0f64;
2464 let mut count = 0usize;
2465 for t in 0..(data.len() - 1) {
2466 v = self.sam.advance(v, data[t] as u32);
2467 let next_ch = data[t + 1] as u32;
2468 let sym_idx = self.lm.find_sym(next_ch);
2469 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2470 total_log_prob += p.log2();
2471 count += 1;
2472 }
2473 if count == 0 {
2474 0.0
2475 } else {
2476 -total_log_prob / (count as f64)
2477 }
2478 }
2479
2480 fn entropy_rate_plugin_cps(&mut self, cps: &[u32]) -> f64 {
2481 let mut v = 0i32;
2482 let mut total_log_prob = 0.0f64;
2483 let mut count = 0usize;
2484 for t in 0..(cps.len() - 1) {
2485 v = self.sam.advance(v, cps[t]);
2486 let next_ch = cps[t + 1];
2487 let sym_idx = self.lm.find_sym(next_ch);
2488 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2489 total_log_prob += p.log2();
2490 count += 1;
2491 }
2492 if count == 0 {
2493 0.0
2494 } else {
2495 -total_log_prob / (count as f64)
2496 }
2497 }
2498
2499 pub fn cross_entropy(&self, data: &[u8]) -> f64 {
2501 if !self.lm_built || data.is_empty() {
2502 return 0.0;
2503 }
2504 let mut total_log_prob = 0.0f64;
2505 let mut v = 0i32;
2506 for &b in data {
2507 let ch = b as u32;
2508 let sym_idx = self.lm.find_sym(ch);
2509 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2510 total_log_prob += p.log2();
2511 v = self.sam.advance(v, ch);
2512 }
2513 -total_log_prob / (data.len() as f64)
2514 }
2515
2516 pub fn cross_entropy_cps(&self, data: &[u32]) -> f64 {
2518 if !self.lm_built || data.is_empty() {
2519 return 0.0;
2520 }
2521 let mut total_log_prob = 0.0f64;
2522 let mut v = 0i32;
2523 for &ch in data {
2524 let sym_idx = self.lm.find_sym(ch);
2525 let p = self.lm.prob_for_sym(&self.sam, self.max_order, v, sym_idx);
2526 total_log_prob += p.log2();
2527 v = self.sam.advance(v, ch);
2528 }
2529 -total_log_prob / (data.len() as f64)
2530 }
2531
2532 pub fn marginal_distribution(&self) -> Vec<(u32, f64)> {
2535 if self.lm.total_uni == 0 {
2536 return Vec::new();
2537 }
2538
2539 let inv = 1.0 / (self.lm.total_uni as f64);
2540 let mut result = Vec::with_capacity(self.lm.alpha_n as usize);
2541 for i in 0..(self.lm.alpha_n as usize) {
2542 let p = (self.lm.unigram[i] as f64) * inv;
2543 if p > 0.0 {
2544 result.push((self.lm.alphabet[i], p));
2545 }
2546 }
2547 result.sort_by_key(|&(cp, _)| cp);
2548 result
2549 }
2550
2551 pub fn marginal_entropy(&self) -> f64 {
2554 if self.lm.total_uni == 0 {
2555 return 0.0;
2556 }
2557
2558 let inv = 1.0 / (self.lm.total_uni as f64);
2559 let mut h = 0.0f64;
2560 for i in 0..(self.lm.alpha_n as usize) {
2561 let p = (self.lm.unigram[i] as f64) * inv;
2562 if p > 0.0 {
2563 h -= p * p.log2();
2564 }
2565 }
2566 h
2567 }
2568
2569 pub fn save(&self, path: &str) -> std::io::Result<()> {
2571 if !self.lm_built {
2572 return Err(std::io::Error::other("LM not built"));
2573 }
2574
2575 if self.sam.text_states.len() != self.sam.text.len() + 1 {
2578 return Err(std::io::Error::other(
2579 "SAM text_states mismatch (expected text.len()+1)",
2580 ));
2581 }
2582 let mut f = BufWriter::with_capacity(1024 * 1024, File::create(path)?);
2583 f.write_all(MAGIC_V5)?;
2584 f.write_all(&self.max_order.to_le_bytes())?;
2585 f.write_all(&(self.use_eot as i32).to_le_bytes())?;
2586 f.write_all(&self.eot.to_le_bytes())?;
2587 f.write_all(&self.seed.to_le_bytes())?;
2588
2589 write_len64(&mut f, self.sam.st.len())?;
2591 write_len64(&mut f, self.sam.ed.len())?;
2592 write_len64(&mut f, self.sam.text.len())?;
2593 for st in &self.sam.st {
2594 f.write_all(&st.link.to_le_bytes())?;
2595 f.write_all(&st.len.to_le_bytes())?;
2596 f.write_all(&st.endpos.to_le_bytes())?;
2597 f.write_all(&(st.small_n as u32).to_le_bytes())?;
2598 for k in 0..(st.small_n as usize) {
2599 f.write_all(&st.small_ch[k].to_le_bytes())?;
2600 f.write_all(&st.small_to[k].to_le_bytes())?;
2601 }
2602 f.write_all(&st.head.to_le_bytes())?;
2603 }
2604 for e in &self.sam.ed {
2605 f.write_all(&e.ch.to_le_bytes())?;
2606 f.write_all(&e.to.to_le_bytes())?;
2607 f.write_all(&e.next.to_le_bytes())?;
2608 }
2609 write_u32_slice_le(&mut f, &self.sam.text)?;
2610 f.write_all(&self.sam.boundary_after)?;
2611
2612 f.write_all(&self.sam.last.to_le_bytes())?;
2614 write_len64(&mut f, self.sam.text_states.len())?;
2615 write_i32_slice_le(&mut f, &self.sam.text_states)?;
2616
2617 f.write_all(&self.lm.alpha_n.to_le_bytes())?;
2619 f.write_all(&self.lm.total_uni.to_le_bytes())?;
2620 write_len64(&mut f, self.lm.nodes.len())?;
2621 write_u32_slice_le(&mut f, &self.lm.alphabet)?;
2622 write_u64_slice_le(&mut f, &self.lm.unigram)?;
2623 for ls in &self.lm.ls {
2624 f.write_all(&ls.head.to_le_bytes())?;
2625 f.write_all(&ls.total_n.to_le_bytes())?;
2626 f.write_all(&ls.types_t.to_le_bytes())?;
2627 f.write_all(&ls.last_sym.to_le_bytes())?;
2628 f.write_all(&ls.last_node.to_le_bytes())?;
2629 }
2630 for n in &self.lm.nodes {
2631 f.write_all(&n.sym_idx.to_le_bytes())?;
2632 f.write_all(&n.cnt.to_le_bytes())?;
2633 f.write_all(&n.next.to_le_bytes())?;
2634 }
2635 f.flush()?;
2636 Ok(())
2637 }
2638
2639 pub fn load(path: &str) -> std::io::Result<Self> {
2641 let mut f = BufReader::with_capacity(1024 * 1024, File::open(path)?);
2642 let mut magic = vec![0u8; MAGIC_V5.len()];
2643 f.read_exact(&mut magic)?;
2644 if magic != MAGIC_V5 {
2645 return Err(std::io::Error::new(
2646 std::io::ErrorKind::InvalidData,
2647 "bad magic or unsupported ROSA+ model version",
2648 ));
2649 }
2650
2651 let mut b8 = [0u8; 8];
2652 let mut b4 = [0u8; 4];
2653
2654 f.read_exact(&mut b8)?;
2655 let max_order = i64::from_le_bytes(b8);
2656 f.read_exact(&mut b4)?;
2657 let use_eot = i32::from_le_bytes(b4) != 0;
2658 f.read_exact(&mut b4)?;
2659 let eot = u32::from_le_bytes(b4);
2660 f.read_exact(&mut b8)?;
2661 let seed = u64::from_le_bytes(b8);
2662
2663 let mut m = RosaPlus::new(max_order, use_eot, eot as u8, seed);
2664
2665 let st_n = read_len64(&mut f)?;
2667 let ed_n = read_len64(&mut f)?;
2668 let text_n = read_len64(&mut f)?;
2669
2670 m.sam = Sam::new(text_n);
2671 m.sam.st.resize(st_n, SamState::default());
2672 m.sam.ed.resize(ed_n, SamEdge::default());
2673 m.sam.text.resize(text_n, 0u32);
2674 m.sam.boundary_after.resize(text_n, 0u8);
2675
2676 for i in 0..st_n {
2677 f.read_exact(&mut b4)?;
2678 m.sam.st[i].link = i32::from_le_bytes(b4);
2679 f.read_exact(&mut b4)?;
2680 m.sam.st[i].len = i32::from_le_bytes(b4);
2681 f.read_exact(&mut b4)?;
2682 m.sam.st[i].endpos = i32::from_le_bytes(b4);
2683 f.read_exact(&mut b4)?;
2684 let sn = u32::from_le_bytes(b4) as usize;
2685 if sn > SAM_SMALL_MAX {
2686 return Err(std::io::Error::new(
2687 std::io::ErrorKind::InvalidData,
2688 "bad small_n",
2689 ));
2690 }
2691 m.sam.st[i].small_n = sn as u8;
2692 for k in 0..sn {
2693 f.read_exact(&mut b4)?;
2694 m.sam.st[i].small_ch[k] = u32::from_le_bytes(b4);
2695 f.read_exact(&mut b4)?;
2696 m.sam.st[i].small_to[k] = i32::from_le_bytes(b4);
2697 }
2698 f.read_exact(&mut b4)?;
2699 m.sam.st[i].head = u32::from_le_bytes(b4);
2700 }
2701 for i in 0..ed_n {
2702 f.read_exact(&mut b4)?;
2703 m.sam.ed[i].ch = u32::from_le_bytes(b4);
2704 f.read_exact(&mut b4)?;
2705 m.sam.ed[i].to = i32::from_le_bytes(b4);
2706 f.read_exact(&mut b4)?;
2707 m.sam.ed[i].next = u32::from_le_bytes(b4);
2708 }
2709 read_u32_slice_le(&mut f, &mut m.sam.text)?;
2710 f.read_exact(&mut m.sam.boundary_after)?;
2711
2712 f.read_exact(&mut b4)?;
2714 m.sam.last = i32::from_le_bytes(b4);
2715 let text_states_n = read_len64(&mut f)?;
2716 if text_states_n != text_n + 1 {
2717 return Err(std::io::Error::new(
2718 std::io::ErrorKind::InvalidData,
2719 "bad text_states len",
2720 ));
2721 }
2722 m.sam.text_states.resize(text_states_n, 0);
2723 read_i32_slice_le(&mut f, &mut m.sam.text_states)?;
2724 for &v in &m.sam.text_states {
2725 if v < 0 || state_usize(v) >= st_n {
2726 return Err(std::io::Error::new(
2727 std::io::ErrorKind::InvalidData,
2728 "bad text_states entry",
2729 ));
2730 }
2731 }
2732 if m.sam.last < 0 || state_usize(m.sam.last) >= st_n {
2733 return Err(std::io::Error::new(
2734 std::io::ErrorKind::InvalidData,
2735 "bad sam.last",
2736 ));
2737 }
2738 for st in &m.sam.st {
2739 if st.link != SAM_STATE_NONE && state_usize(st.link) >= st_n {
2740 return Err(std::io::Error::new(
2741 std::io::ErrorKind::InvalidData,
2742 "bad sam link",
2743 ));
2744 }
2745 for k in 0..(st.small_n as usize) {
2746 let to = st.small_to[k];
2747 if to < 0 || state_usize(to) >= st_n {
2748 return Err(std::io::Error::new(
2749 std::io::ErrorKind::InvalidData,
2750 "bad sam small edge",
2751 ));
2752 }
2753 }
2754 if st.head != SAM_EDGE_NONE && edge_usize(st.head) >= ed_n {
2755 return Err(std::io::Error::new(
2756 std::io::ErrorKind::InvalidData,
2757 "bad sam edge head",
2758 ));
2759 }
2760 }
2761 for edge in &m.sam.ed {
2762 if edge.to < 0 || state_usize(edge.to) >= st_n {
2763 return Err(std::io::Error::new(
2764 std::io::ErrorKind::InvalidData,
2765 "bad sam edge target",
2766 ));
2767 }
2768 if edge.next != SAM_EDGE_NONE && edge_usize(edge.next) >= ed_n {
2769 return Err(std::io::Error::new(
2770 std::io::ErrorKind::InvalidData,
2771 "bad sam edge next",
2772 ));
2773 }
2774 }
2775 m.sam.rebuild_root_cache();
2776
2777 f.read_exact(&mut b4)?;
2779 let alpha_n = u32::from_le_bytes(b4) as usize;
2780 f.read_exact(&mut b8)?;
2781 let total_uni = u64::from_le_bytes(b8);
2782 let nodes_n = read_len64(&mut f)?;
2783
2784 m.lm = LM::default();
2785 m.lm.alpha_n = alpha_n as u32;
2786 m.lm.total_uni = total_uni;
2787 m.lm.alphabet.resize(alpha_n, 0);
2788 m.lm.unigram.resize(alpha_n, 0);
2789 m.lm.ls = vec![
2790 LmState {
2791 head: LM_NODE_NONE,
2792 last_node: LM_NODE_NONE,
2793 ..LmState::default()
2794 };
2795 st_n
2796 ];
2797 m.lm.nodes.resize(nodes_n, CountNode::default());
2798
2799 read_u32_slice_le(&mut f, &mut m.lm.alphabet)?;
2800 read_u64_slice_le(&mut f, &mut m.lm.unigram)?;
2801 for i in 0..st_n {
2802 f.read_exact(&mut b4)?;
2803 m.lm.ls[i].head = u32::from_le_bytes(b4);
2804 f.read_exact(&mut b8)?;
2805 m.lm.ls[i].total_n = u64::from_le_bytes(b8);
2806 f.read_exact(&mut b4)?;
2807 m.lm.ls[i].types_t = u32::from_le_bytes(b4);
2808 f.read_exact(&mut b4)?;
2809 m.lm.ls[i].last_sym = u32::from_le_bytes(b4);
2810 f.read_exact(&mut b4)?;
2811 m.lm.ls[i].last_node = u32::from_le_bytes(b4);
2812 }
2813 for i in 0..nodes_n {
2814 f.read_exact(&mut b4)?;
2815 let sym_idx = u32::from_le_bytes(b4);
2816 f.read_exact(&mut b8)?;
2817 let cnt = u64::from_le_bytes(b8);
2818 f.read_exact(&mut b4)?;
2819 let next = u32::from_le_bytes(b4);
2820 m.lm.nodes.set(i, CountNode { sym_idx, cnt, next });
2821 }
2822 for ls in &m.lm.ls {
2823 if ls.head != LM_NODE_NONE && node_usize(ls.head) >= nodes_n {
2824 return Err(std::io::Error::new(
2825 std::io::ErrorKind::InvalidData,
2826 "bad lm head",
2827 ));
2828 }
2829 if ls.last_node != LM_NODE_NONE && node_usize(ls.last_node) >= nodes_n {
2830 return Err(std::io::Error::new(
2831 std::io::ErrorKind::InvalidData,
2832 "bad lm last_node",
2833 ));
2834 }
2835 }
2836 for node in &m.lm.nodes {
2837 if node.next != LM_NODE_NONE && node_usize(node.next) >= nodes_n {
2838 return Err(std::io::Error::new(
2839 std::io::ErrorKind::InvalidData,
2840 "bad lm next",
2841 ));
2842 }
2843 }
2844
2845 m.lm.has_byte_map = false;
2847 m.lm.byte_map = [-1; 256];
2848 let mut max_cp = 0u32;
2849 for &v in &m.lm.alphabet {
2850 if v > max_cp {
2851 max_cp = v;
2852 }
2853 }
2854 if max_cp < 256 {
2855 m.lm.has_byte_map = true;
2856 for (i, &c) in m.lm.alphabet.iter().enumerate() {
2857 m.lm.byte_map[c as usize] = i as i16;
2858 }
2859 }
2860
2861 m.lm_built = true;
2862 m.dist.resize(alpha_n, 0.0);
2863 Ok(m)
2864 }
2865
2866 pub fn prob_for_last(&mut self, sym: u32) -> f64 {
2868 if !self.lm_built {
2869 self.build_lm();
2870 }
2871 let v = self.sam.last;
2872 let sym_idx = self.lm.find_sym(sym);
2873 let mo = if self.max_order < 0 {
2874 -1
2875 } else {
2876 self.max_order
2877 };
2878 self.lm.prob_for_sym(&self.sam, mo, v, sym_idx)
2879 }
2880
2881 pub fn fill_probs_for_last_bytes(&mut self, out: &mut [f64]) {
2885 debug_assert!(out.len() >= 256);
2886 if !self.lm_built {
2887 self.build_lm();
2888 }
2889
2890 let v = self.sam.last;
2891 let mo = if self.max_order < 0 {
2892 -1
2893 } else {
2894 self.max_order
2895 };
2896 self.dist.resize(self.lm.alpha_n as usize, 0.0);
2897 self.lm.probs_for_state(&self.sam, mo, v, &mut self.dist);
2898
2899 if self.lm.has_byte_map
2900 && (self.lm.alpha_n as usize) == BYTE_ALPHA_N
2901 && self.lm.alphabet.len() == BYTE_ALPHA_N
2902 {
2903 out[..BYTE_ALPHA_N].copy_from_slice(&self.dist[..BYTE_ALPHA_N]);
2904 return;
2905 }
2906
2907 out[..BYTE_ALPHA_N].fill(0.0);
2908 let mut sum = 0.0;
2909 for (i, &cp) in self.lm.alphabet.iter().enumerate() {
2910 if cp < BYTE_ALPHA_N as u32 {
2911 let p = self.dist[i];
2912 out[cp as usize] = p;
2913 sum += p;
2914 }
2915 }
2916
2917 if sum.is_finite() && sum > 0.0 {
2918 if (sum - 1.0).abs() > 1e-12 {
2919 let inv = 1.0 / sum;
2920 for p in &mut out[..BYTE_ALPHA_N] {
2921 *p *= inv;
2922 }
2923 }
2924 } else {
2925 let u = 1.0 / BYTE_ALPHA_N as f64;
2926 for p in &mut out[..BYTE_ALPHA_N] {
2927 *p = u;
2928 }
2929 }
2930 }
2931}
2932
2933#[cfg(test)]
2934mod tests {
2935 use super::*;
2936 use std::fs;
2937 use std::path::PathBuf;
2938 use std::time::{SystemTime, UNIX_EPOCH};
2939
2940 fn temp_model_path(tag: &str) -> PathBuf {
2941 let nanos = SystemTime::now()
2942 .duration_since(UNIX_EPOCH)
2943 .expect("time went backwards")
2944 .as_nanos();
2945 std::env::temp_dir().join(format!(
2946 "infotheory_rosaplus_{tag}_{}_{}.bin",
2947 std::process::id(),
2948 nanos
2949 ))
2950 }
2951
2952 fn manual_chunked_entropy_rate_bytes(data: &[u8], max_order: i64, seed: u64) -> f64 {
2953 if data.len() < 2 {
2954 return 0.0;
2955 }
2956 let num_chunks = 16;
2957 let chunk_size = data.len().div_ceil(num_chunks);
2958 let mut total_log_prob = 0.0f64;
2959 let mut count = 0usize;
2960
2961 for i in 0..num_chunks {
2962 let start = i * chunk_size;
2963 let end = ((i + 1) * chunk_size).min(data.len());
2964 if start >= end {
2965 break;
2966 }
2967 if i == 0 {
2968 continue;
2969 }
2970
2971 let mut m = RosaPlus::new(max_order, false, 0, seed);
2972 m.train_example(&data[..start]);
2973 m.build_lm();
2974 let mut v = m.sam.last;
2975
2976 for &b in &data[start..end] {
2977 let sym_idx = m.lm.find_sym(b as u32);
2978 let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
2979 total_log_prob += p.log2();
2980 count += 1;
2981 v = m.sam.advance(v, b as u32);
2982 }
2983 }
2984
2985 if count == 0 {
2986 let mut m = RosaPlus::new(max_order, false, 0, seed);
2987 m.train_example(data);
2988 m.build_lm();
2989 m.cross_entropy(data)
2990 } else {
2991 -total_log_prob / (count as f64)
2992 }
2993 }
2994
2995 fn manual_chunked_entropy_rate_cps(data: &[u32], max_order: i64, seed: u64) -> f64 {
2996 if data.len() < 2 {
2997 return 0.0;
2998 }
2999 let mut m = RosaPlus::new(max_order, false, 0, seed);
3000 m.sam = Sam::new(data.len());
3001 m.lm_built = false;
3002
3003 let num_chunks = 16;
3004 let chunk_size = data.len().div_ceil(num_chunks);
3005 let mut total_log_prob = 0.0f64;
3006 let mut count = 0usize;
3007
3008 for i in 0..num_chunks {
3009 let start = i * chunk_size;
3010 let end = ((i + 1) * chunk_size).min(data.len());
3011 if start >= end {
3012 break;
3013 }
3014 let chunk = &data[start..end];
3015 if i > 0 {
3016 m.build_lm_no_finalize_endpos();
3017 let mut v = m.sam.text_states[start];
3018 for &ch in chunk {
3019 let sym_idx = m.lm.find_sym(ch);
3020 let p = m.lm.prob_for_sym(&m.sam, max_order, v, sym_idx);
3021 total_log_prob += p.log2();
3022 count += 1;
3023 v = m.sam.advance(v, ch);
3024 }
3025 }
3026 for &ch in chunk {
3027 m.sam.feed(ch);
3028 }
3029 }
3030
3031 if count == 0 {
3032 m.build_lm();
3033 m.entropy_rate_plugin_cps(data)
3034 } else {
3035 -total_log_prob / (count as f64)
3036 }
3037 }
3038
3039 fn prob_for_sym_reference(
3040 lm: &LM,
3041 sam: &Sam,
3042 max_order: i64,
3043 v: SamStateIx,
3044 sym_idx: i32,
3045 ) -> f64 {
3046 if sym_idx < 0 {
3047 return 1.0 / (lm.alpha_n.max(1) as f64);
3048 }
3049 let sym_idx = sym_idx as u32;
3050 let mut p_accum = 0.0f64;
3051 let mut residual = 1.0f64;
3052 let mut u = v;
3053
3054 while u != SAM_STATE_NONE {
3055 if !(max_order >= 0 && (sam.st[state_usize(u)].len as i64) > max_order) {
3056 let n = lm.ls[state_usize(u)].total_n;
3057 let t = lm.ls[state_usize(u)].types_t;
3058 if n > 0 {
3059 let lam = if t > 0 {
3060 (n as f64) / ((n + (t as u64)) as f64)
3061 } else {
3062 1.0
3063 };
3064 let scale = residual * lam;
3065 let mut count_for_sym = 0u64;
3066 let ls = &lm.ls[state_usize(u)];
3067 if LM::ls_is_implicit_single(ls) {
3068 if ls.last_sym == sym_idx {
3069 count_for_sym = n;
3070 }
3071 } else {
3072 let mut ni = ls.head;
3073 while ni != LM_NODE_NONE {
3074 let node = lm.nodes.get(node_usize(ni));
3075 if node.sym_idx == sym_idx {
3076 count_for_sym = node.cnt;
3077 break;
3078 }
3079 ni = node.next;
3080 }
3081 }
3082 if count_for_sym > 0 {
3083 p_accum += scale * (count_for_sym as f64 / n as f64);
3084 }
3085 residual *= 1.0 - lam;
3086 }
3087 }
3088 u = sam.st[state_usize(u)].link;
3089 }
3090
3091 if lm.total_uni > 0 && residual > 0.0 {
3092 let p_uni = lm.unigram[sym_idx as usize] as f64 / lm.total_uni as f64;
3093 p_accum += residual * p_uni;
3094 } else if residual > 0.0 {
3095 p_accum += residual * (1.0 / lm.alpha_n.max(1) as f64);
3096 }
3097
3098 p_accum.clamp(1e-12, 1.0)
3099 }
3100
3101 fn probs_for_state_reference(lm: &LM, sam: &Sam, max_order: i64, v: SamStateIx) -> Vec<f64> {
3102 let mut out = vec![0.0; lm.alpha_n as usize];
3103 let mut residual = 1.0f64;
3104 let mut u = v;
3105 while u != SAM_STATE_NONE {
3106 if !(max_order >= 0 && (sam.st[state_usize(u)].len as i64) > max_order) {
3107 let n = lm.ls[state_usize(u)].total_n;
3108 let t = lm.ls[state_usize(u)].types_t;
3109 if n > 0 {
3110 let lam = if t > 0 {
3111 (n as f64) / ((n + (t as u64)) as f64)
3112 } else {
3113 1.0
3114 };
3115 let scale = residual * lam;
3116 let inv_n = 1.0 / (n as f64);
3117 let ls = &lm.ls[state_usize(u)];
3118 if LM::ls_is_implicit_single(ls) {
3119 out[ls.last_sym as usize] += scale;
3120 } else {
3121 let mut ni = ls.head;
3122 while ni != LM_NODE_NONE {
3123 let node = lm.nodes.get(node_usize(ni));
3124 out[node.sym_idx as usize] += scale * ((node.cnt as f64) * inv_n);
3125 ni = node.next;
3126 }
3127 }
3128 residual *= 1.0 - lam;
3129 }
3130 }
3131 u = sam.st[state_usize(u)].link;
3132 }
3133
3134 if lm.total_uni > 0 && residual > 0.0 {
3135 let inv = 1.0 / (lm.total_uni as f64);
3136 for (slot, &count) in out.iter_mut().zip(lm.unigram.iter()) {
3137 *slot += residual * ((count as f64) * inv);
3138 }
3139 }
3140
3141 let sum: f64 = out.iter().sum();
3142 if sum > 0.0 {
3143 let inv = 1.0 / sum;
3144 for slot in &mut out {
3145 *slot *= inv;
3146 }
3147 } else {
3148 let uprob = 1.0 / (lm.alpha_n.max(1) as f64);
3149 out.fill(uprob);
3150 }
3151 out
3152 }
3153
3154 #[test]
3155 fn rosa_md_example_basic() {
3156 let x = b"ababa";
3158 let mut m = RosaPlus::new(1048576, false, 4, 0);
3159 m.train_example(x);
3160 m.build_lm();
3161 let out = m.generate(b"a", 10).unwrap();
3162 assert!(!out.is_empty());
3163 }
3164
3165 #[test]
3166 fn tx_rollback_restores_sam_and_unigram_counts() {
3167 let mut m = RosaPlus::new(4, false, 0, 123);
3168 m.train_example(b"hello");
3169 m.build_lm_full_bytes_no_finalize_endpos();
3170
3171 let base_text = m.sam.text.clone();
3172 let base_text_len = m.sam.text.len();
3173 let base_total_uni = m.lm.total_uni;
3174 assert!(base_text_len > 0);
3175
3176 let mut tx = m.begin_tx();
3177 m.train_example_tx(&mut tx, b"abc");
3178 assert_eq!(m.lm.total_uni, base_total_uni + 3);
3179 assert_eq!(m.sam.text.len(), base_text_len + 3);
3180
3181 m.rollback_tx(tx);
3182 assert_eq!(m.sam.text, base_text);
3183 assert_eq!(m.lm.total_uni, base_total_uni);
3184 }
3185
3186 #[test]
3187 fn train_sequence_matches_transactional_sequence_update() {
3188 let mut direct = RosaPlus::new(4, false, 0, 123);
3189 direct.build_lm_full_bytes_no_finalize_endpos();
3190 direct.reserve_for_stream(64);
3191 direct.train_sequence(b"abracadabra");
3192 direct.train_sequence(b" mississippi");
3193
3194 let mut tx_model = RosaPlus::new(4, false, 0, 123);
3195 tx_model.build_lm_full_bytes_no_finalize_endpos();
3196 tx_model.reserve_for_stream(64);
3197 let mut tx = tx_model.begin_tx();
3198 tx_model.train_sequence_tx(&mut tx, b"abracadabra");
3199 let mut tx = tx_model.begin_tx();
3200 tx_model.train_sequence_tx(&mut tx, b" mississippi");
3201
3202 assert_eq!(direct.sam.text, tx_model.sam.text);
3203 assert_eq!(direct.sam.text_states, tx_model.sam.text_states);
3204 assert_eq!(direct.sam.boundary_after, tx_model.sam.boundary_after);
3205 assert_eq!(direct.sam.last, tx_model.sam.last);
3206 assert_eq!(direct.lm.total_uni, tx_model.lm.total_uni);
3207 assert_eq!(direct.lm.unigram, tx_model.lm.unigram);
3208 assert_eq!(direct.lm.nodes, tx_model.lm.nodes);
3209 assert_eq!(direct.lm.ls, tx_model.lm.ls);
3210
3211 let mut direct_pdf = [0.0; BYTE_ALPHA_N];
3212 let mut tx_pdf = [0.0; BYTE_ALPHA_N];
3213 direct.fill_probs_for_last_bytes(&mut direct_pdf);
3214 tx_model.fill_probs_for_last_bytes(&mut tx_pdf);
3215 for idx in 0..BYTE_ALPHA_N {
3216 assert!((direct_pdf[idx] - tx_pdf[idx]).abs() < 1e-12);
3217 }
3218 }
3219
3220 #[test]
3221 fn repeated_single_byte_train_byte_matches_transactional_update() {
3222 let data = b"abracadabra mississippi";
3223
3224 let mut direct = RosaPlus::new(4, false, 0, 123);
3225 direct.build_lm_full_bytes_no_finalize_endpos();
3226 for &b in data {
3227 direct.train_byte(b);
3228 }
3229
3230 let mut tx_model = RosaPlus::new(4, false, 0, 123);
3231 tx_model.build_lm_full_bytes_no_finalize_endpos();
3232 for &b in data {
3233 let mut tx = tx_model.begin_tx();
3234 tx_model.train_sequence_tx(&mut tx, &[b]);
3235 }
3236
3237 assert_eq!(direct.sam.text, tx_model.sam.text);
3238 assert_eq!(direct.sam.text_states, tx_model.sam.text_states);
3239 assert_eq!(direct.sam.boundary_after, tx_model.sam.boundary_after);
3240 assert_eq!(direct.sam.last, tx_model.sam.last);
3241 assert_eq!(direct.lm.total_uni, tx_model.lm.total_uni);
3242 assert_eq!(direct.lm.unigram, tx_model.lm.unigram);
3243 assert_eq!(direct.lm.nodes, tx_model.lm.nodes);
3244 assert_eq!(direct.lm.ls, tx_model.lm.ls);
3245 }
3246
3247 #[test]
3248 fn max_order_capping_keeps_probability_semantics() {
3249 let mut m = RosaPlus::new(4, false, 0, 321);
3250 m.build_lm_full_bytes_no_finalize_endpos();
3251 m.train_sequence(b"abracadabra mississippi abracadabra abracadabra");
3252
3253 let v = m.sam.last;
3254 for &sym in b"a mz" {
3255 let sym_idx = m.lm.find_sym(sym as u32);
3256 let expected = prob_for_sym_reference(&m.lm, &m.sam, m.max_order, v, sym_idx);
3257 let got = m.lm.prob_for_sym(&m.sam, m.max_order, v, sym_idx);
3258 assert!(
3259 (got - expected).abs() < 1e-12,
3260 "sym={sym} got={got} expected={expected}"
3261 );
3262 }
3263
3264 let expected = probs_for_state_reference(&m.lm, &m.sam, m.max_order, v);
3265 let mut got = vec![0.0; m.lm.alpha_n as usize];
3266 m.lm.probs_for_state(&m.sam, m.max_order, v, &mut got);
3267 for idx in 0..got.len() {
3268 assert!(
3269 (got[idx] - expected[idx]).abs() < 1e-12,
3270 "idx={idx} got={} expected={}",
3271 got[idx],
3272 expected[idx]
3273 );
3274 }
3275 }
3276
3277 #[test]
3278 fn checkpoint_restore_reverts_append_only_buffers() {
3279 let mut m = RosaPlus::new(3, true, b'\n', 7);
3280 m.train_example(b"aaaa");
3281
3282 let ck = m.checkpoint();
3283 let base_text = m.sam.text.clone();
3284 let base_states = m.sam.text_states.clone();
3285 let base_boundary = m.sam.boundary_after.clone();
3286 let base_last = m.sam.last;
3287
3288 m.train_example(b"bbbb");
3289 assert_ne!(m.sam.text, base_text);
3290
3291 m.restore(&ck);
3292 assert_eq!(m.sam.text, base_text);
3293 assert_eq!(m.sam.text_states, base_states);
3294 assert_eq!(m.sam.boundary_after, base_boundary);
3295 assert_eq!(m.sam.last, base_last);
3296 assert!(!m.lm_built);
3297 }
3298
3299 #[test]
3300 fn predictive_entropy_rate_matches_chunked_reference_fixed_order() {
3301 let data = b"abracadabra abracadabra abracadabra";
3302 let seed = 11;
3303 let expected = manual_chunked_entropy_rate_bytes(data, 4, seed);
3304 let mut m = RosaPlus::new(4, false, 0, seed);
3305 let got = m.predictive_entropy_rate(data);
3306 assert!((got - expected).abs() < 1e-12);
3307 }
3308
3309 #[test]
3310 fn predictive_entropy_rate_uncapped_matches_candidate_search() {
3311 let data = b"the quick brown fox jumps over the lazy dog the quick brown fox";
3312 let seed = 29;
3313 let mut expected = f64::INFINITY;
3314 for &mo in &[0, 1, 2, 4, 8, 16, 32, 64] {
3315 if mo as usize >= data.len() {
3316 continue;
3317 }
3318 expected = expected.min(manual_chunked_entropy_rate_bytes(data, mo, seed));
3319 }
3320 let mut m = RosaPlus::new(-1, false, 0, seed);
3321 let got = m.predictive_entropy_rate(data);
3322 assert!((got - expected).abs() < 1e-12);
3323 }
3324
3325 #[test]
3326 fn entropy_rate_cps_matches_chunked_reference() {
3327 let data = [0u32, 7, 0, 42, 7, 42, 0, 7, 42, 42];
3328 let seed = 31;
3329 let expected = manual_chunked_entropy_rate_cps(&data, -1, seed);
3330 let mut m = RosaPlus::new(-1, false, 0, seed);
3331 let got = m.entropy_rate_cps(&data);
3332 assert!((got - expected).abs() < 1e-12);
3333 }
3334
3335 #[cfg(target_pointer_width = "64")]
3336 #[test]
3337 fn wide_index_helpers_preserve_large_indices() {
3338 let large = (i32::MAX as usize) + 17;
3339 assert_eq!(edge_usize(edge_ix(large)), large);
3340 assert_eq!(node_usize(node_ix(large)), large);
3341 }
3342
3343 #[test]
3344 fn save_load_roundtrip_preserves_state_and_probabilities() {
3345 let path = temp_model_path("roundtrip");
3346 let mut m = RosaPlus::new(8, true, b'\n', 1234);
3347 m.train_example(b"abracadabra");
3348 m.build_lm();
3349 let before_prob = m.prob_for_last(b'a' as u32);
3350 let before_size = m.estimated_size_bytes();
3351 let before_text = m.sam.text.clone();
3352 let before_states = m.sam.text_states.clone();
3353 let before_last = m.sam.last;
3354 let before_nodes = m.lm.nodes.len();
3355 let path_str = path.to_string_lossy().into_owned();
3356
3357 m.save(&path_str).expect("save failed");
3358 let mut loaded = RosaPlus::load(&path_str).expect("load failed");
3359 fs::remove_file(&path).expect("cleanup failed");
3360
3361 assert_eq!(loaded.max_order, m.max_order);
3362 assert_eq!(loaded.use_eot, m.use_eot);
3363 assert_eq!(loaded.eot, m.eot);
3364 assert_eq!(loaded.seed, m.seed);
3365 assert_eq!(loaded.sam.text, before_text);
3366 assert_eq!(loaded.sam.text_states, before_states);
3367 assert_eq!(loaded.sam.last, before_last);
3368 assert_eq!(loaded.lm.nodes.len(), before_nodes);
3369 assert_eq!(loaded.estimated_size_bytes(), before_size);
3370 assert!((loaded.prob_for_last(b'a' as u32) - before_prob).abs() < 1e-12);
3371 }
3372}