use std::cell::RefCell;
use std::f64;
use std::mem::size_of;

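/// A binary symbol; `false` is bit 0 and `true` is bit 1.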
type Symbol = bool;

#[inline(always)]
fn ensure_log_caches(log_int: &mut Vec<f64>, log_half: &mut Vec<f64>, upto: usize) {
    if upto < log_int.len() {
        return;
    }
    let start = log_int.len();
    log_int.reserve(upto + 1 - start);
    log_half.reserve(upto + 1 - start);
    for n in start..=upto {
        if n == 0 {
            log_int.push(f64::NEG_INFINITY);
        } else {
            log_int.push((n as f64).ln());
        }
        log_half.push((n as f64 + 0.5).ln());
    }
}

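/// Thread-local cache of `ln(n)` and `ln(n + 0.5)` tables shared by every CTW
/// instance on the thread, so sequential KT updates become table lookups
/// instead of `ln` calls.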
#[derive(Default)]
struct SharedLogCache {
    log_int: Vec<f64>,
    log_half: Vec<f64>,
}

impl SharedLogCache {
    fn new() -> Self {
        Self {
            log_int: vec![f64::NEG_INFINITY],
            log_half: vec![(0.5f64).ln()],
        }
    }

    #[inline(always)]
    fn ensure(&mut self, upto: usize) {
        ensure_log_caches(&mut self.log_int, &mut self.log_half, upto);
    }

    #[inline(always)]
    fn memory_usage(&self) -> usize {
        self.log_int.capacity() * size_of::<f64>() + self.log_half.capacity() * size_of::<f64>()
    }
}

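// One lazily grown cache per thread; callers go through `with_shared_log_cache`,
// which extends the tables on demand and lends out immutable views.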
thread_local! {
    static CTW_SHARED_LOG_CACHE: RefCell<SharedLogCache> =
        RefCell::new(SharedLogCache::new());
}

#[inline]
fn with_shared_log_cache<R>(upto: usize, f: impl FnOnce(&[f64], &[f64]) -> R) -> R {
    CTW_SHARED_LOG_CACHE.with(|cache_cell| {
        let mut cache = cache_cell.borrow_mut();
        cache.ensure(upto);
        f(&cache.log_int, &cache.log_half)
    })
}

#[inline]
fn shared_log_cache_memory_usage() -> usize {
    CTW_SHARED_LOG_CACHE.with(|cache_cell| cache_cell.borrow().memory_usage())
}

#[cfg(test)]
#[inline]
fn shared_log_cache_lens() -> (usize, usize) {
    CTW_SHARED_LOG_CACHE.with(|cache_cell| {
        let cache = cache_cell.borrow();
        (cache.log_int.len(), cache.log_half.len())
    })
}

#[inline(always)]
fn history_symbol(history: &[Symbol], depth: usize) -> Symbol {
    let idx = history.len().wrapping_sub(depth + 1);
    if depth < history.len() {
        unsafe { *history.get_unchecked(idx) }
    } else {
        false
    }
}

#[inline(always)]
unsafe fn history_at_or_zero(history_ptr: *const Symbol, history_len: isize, idx: isize) -> Symbol {
    if idx >= 0 && idx < history_len {
        *history_ptr.add(idx as usize)
    } else {
        false
    }
}

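// Layout of the packed references and segment metadata below: a `ChildRef`
// stores a 31-bit node or segment index, with the top bit set for segments and
// `u32::MAX` reserved for "no child". Segment `meta` words pack a 2-bit payload
// mode (bits 30..32) together with a 30-bit length.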
const INDEX_BITS: u32 = 31;
const INDEX_LIMIT: usize = 1usize << INDEX_BITS;
const CHILD_SEGMENT_TAG: u32 = 1u32 << INDEX_BITS;
const CHILD_INDEX_MASK: u32 = CHILD_SEGMENT_TAG - 1;
const SEG_META_MODE_SHIFT: u32 = 30;
const SEG_META_MODE_MASK: u32 = 0b11 << SEG_META_MODE_SHIFT;
const SEG_LEN_MASK: u32 = !SEG_META_MODE_MASK;
const SEG_MODE_EXACT: u32 = 0 << SEG_META_MODE_SHIFT;
const SEG_MODE_HISTORY: u32 = 1 << SEG_META_MODE_SHIFT;
const SEG_MODE_HISTORY_INVERT: u32 = 2 << SEG_META_MODE_SHIFT;
const SEG_MODE_CONST: u32 = 3 << SEG_META_MODE_SHIFT;
const SEG_EXACT_MAX_LEN: u32 = 64;

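/// Arena index of a `CtNode`, limited to 31 bits so it can be packed into a
/// `ChildRef`.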
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NodeIndex(u32);

impl NodeIndex {
    #[cold]
    #[inline(never)]
    fn overflow() -> ! {
        panic!("ctw node index overflow");
    }

    #[inline(always)]
    fn from_usize(idx: usize) -> Self {
        if idx >= INDEX_LIMIT {
            Self::overflow();
        }
        Self(idx as u32)
    }

    #[inline(always)]
    fn get(self) -> usize {
        self.0 as usize
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct SegmentIndex(u32);

impl SegmentIndex {
    #[cold]
    #[inline(never)]
    fn overflow() -> ! {
        panic!("ctw segment index overflow");
    }

    #[inline(always)]
    fn from_usize(idx: usize) -> Self {
        if idx >= INDEX_LIMIT {
            Self::overflow();
        }
        Self(idx as u32)
    }

    #[inline(always)]
    fn get(self) -> usize {
        self.0 as usize
    }
}

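/// Tagged reference to a child: either a node index, a segment index (high bit
/// set), or `NONE` (`u32::MAX`).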
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ChildRef(u32);

impl ChildRef {
    const NONE: ChildRef = ChildRef(u32::MAX);

    #[inline(always)]
    fn from_node(idx: NodeIndex) -> Self {
        debug_assert!(idx.0 < CHILD_SEGMENT_TAG);
        Self(idx.0)
    }

    #[inline(always)]
    fn from_segment(idx: SegmentIndex) -> Self {
        debug_assert!(idx.0 < CHILD_SEGMENT_TAG);
        Self(CHILD_SEGMENT_TAG | idx.0)
    }

    #[inline(always)]
    fn is_none(self) -> bool {
        self.0 == u32::MAX
    }

    #[inline(always)]
    fn is_some(self) -> bool {
        self.0 != u32::MAX
    }

    #[inline(always)]
    fn as_node(self) -> Option<NodeIndex> {
        if self.is_none() || (self.0 & CHILD_SEGMENT_TAG) != 0 {
            None
        } else {
            Some(NodeIndex(self.0))
        }
    }

    #[inline(always)]
    fn as_segment(self) -> Option<SegmentIndex> {
        if self.is_none() || (self.0 & CHILD_SEGMENT_TAG) == 0 {
            None
        } else {
            Some(SegmentIndex(self.0 & CHILD_INDEX_MASK))
        }
    }
}

impl Default for ChildRef {
    fn default() -> Self {
        Self::NONE
    }
}

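/// Compressed description of the edges along a run of single-child nodes. The
/// edges are stored either as explicit bits (`EXACT`, at most 64 of them), as a
/// window into the symbol history located by an anchor index (`HISTORY`,
/// optionally inverted), or as one repeated bit (`CONST`).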
#[derive(Clone, Copy, Debug, Default)]
struct SegmentPayload {
    repr_lo: u32,
    repr_hi: u32,
    meta: u32,
}

impl SegmentPayload {
    #[inline(always)]
    fn exact(bits: u64, len: u32) -> Self {
        debug_assert!(len <= SEG_EXACT_MAX_LEN);
        debug_assert!(len <= SEG_LEN_MASK);
        Self {
            repr_lo: bits as u32,
            repr_hi: (bits >> 32) as u32,
            meta: SEG_MODE_EXACT | len,
        }
    }

    #[inline(always)]
    fn history(anchor: u32, len: u32, invert: bool) -> Self {
        debug_assert!(len <= SEG_LEN_MASK);
        Self {
            repr_lo: anchor,
            repr_hi: 0,
            meta: if invert {
                SEG_MODE_HISTORY_INVERT | len
            } else {
                SEG_MODE_HISTORY | len
            },
        }
    }

    #[inline(always)]
    fn constant(bit: bool, len: u32) -> Self {
        debug_assert!(len <= SEG_LEN_MASK);
        Self {
            repr_lo: bit as u32,
            repr_hi: 0,
            meta: SEG_MODE_CONST | len,
        }
    }

    #[inline(always)]
    fn len(self) -> u32 {
        self.meta & SEG_LEN_MASK
    }

    #[inline(always)]
    fn set_len(&mut self, len: u32) {
        debug_assert!(len <= SEG_LEN_MASK);
        self.meta = (self.meta & SEG_META_MODE_MASK) | len;
    }

    #[inline(always)]
    fn mode(self) -> u32 {
        self.meta & SEG_META_MODE_MASK
    }

    #[inline(always)]
    fn is_exact(self) -> bool {
        self.mode() == SEG_MODE_EXACT
    }

    #[inline(always)]
    fn exact_bits(self) -> u64 {
        (self.repr_lo as u64) | ((self.repr_hi as u64) << 32)
    }

    #[inline(always)]
    fn anchor_or_const(self) -> u32 {
        self.repr_lo
    }

    #[inline(always)]
    fn const_bit(self) -> bool {
        (self.repr_lo & 1) != 0
    }

    #[inline(always)]
    fn prepend_exact(self, edge: usize) -> Option<Self> {
        if !self.is_exact() || self.len() >= SEG_EXACT_MAX_LEN {
            return None;
        }
        let len = self.len() + 1;
        let bits = ((edge as u64) & 1) | (self.exact_bits() << 1);
        Some(Self::exact(bits, len))
    }

    #[inline(always)]
    fn prefix(self, len: u32) -> Self {
        debug_assert!(len <= self.len());
        match self.mode() {
            SEG_MODE_EXACT => Self::exact(self.exact_bits() & low_bits_mask_u64(len), len),
            SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => Self {
                meta: (self.meta & SEG_META_MODE_MASK) | len,
                ..self
            },
            SEG_MODE_CONST => Self::constant(self.const_bit(), len),
            _ => unreachable!("invalid ctw segment payload mode"),
        }
    }

    #[inline(always)]
    fn suffix_after(self, skip: u32) -> Self {
        debug_assert!(skip <= self.len());
        let new_len = self.len() - skip;
        match self.mode() {
            SEG_MODE_EXACT => Self::exact(self.exact_bits() >> skip, new_len),
            SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => Self {
                repr_lo: self
                    .anchor_or_const()
                    .checked_sub(skip)
                    .expect("ctw history segment anchor underflow"),
                meta: (self.meta & SEG_META_MODE_MASK) | new_len,
                ..self
            },
            SEG_MODE_CONST => Self::constant(self.const_bit(), new_len),
            _ => unreachable!("invalid ctw segment payload mode"),
        }
    }

    #[inline(always)]
    fn from_path(history: &[Symbol], depth: usize, len: u32) -> Option<Self> {
        if len > SEG_EXACT_MAX_LEN {
            return None;
        }
        Some(Self::exact(
            path_bits_from_history(history, depth, len as usize),
            len,
        ))
    }
}

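/// Per-depth scratch recorded while walking the context path: the counts and
/// KT log-probability seen at that level plus the sibling link off the path.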
#[derive(Clone, Copy, Debug, Default)]
struct LevelState {
    symbol_count: [u32; 2],
    log_prob_kt: f64,
    sibling: ChildRef,
}

#[cfg(test)]
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
struct PredictEntry {
    symbol_count: [u32; 2],
    log_prob_kt: f64,
    log_prob_weighted: f64,
    sibling_weight: f64,
    has_sibling: bool,
}

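/// A deferred pointer edit queued while splitting the context path: either a
/// node child slot to detach, or a segment to truncate to `new_len`.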
#[derive(Clone, Copy, Debug)]
enum Detach {
    NodeChild { node: NodeIndex, edge: usize },
    SegmentNext { segment: SegmentIndex, new_len: u32 },
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExistingSource {
    None,
    Node(NodeIndex),
    Segment(SegmentIndex, u32),
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PreparedEnd {
    MaxDepth,
    MissingAtRoot,
    MissingAfterCurrent,
    MismatchAtCurrentSegment,
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct PreparedStep {
    source: ExistingSource,
    counts: [u32; 2],
    kt_log_prob: f64,
    span: u32,
    sibling_weight: f64,
    has_sibling: u8,
}

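/// A branching context-tree node: two tagged child links plus the KT and
/// weighted log-probabilities and per-symbol counts for its context.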
#[derive(Clone, Copy, Debug)]
pub struct CtNode {
    children: [ChildRef; 2],
    log_prob_kt: f64,
    log_prob_weighted: f64,
    symbol_count: [u32; 2],
}

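/// A collapsed chain of single-child nodes that all share the same counts and
/// KT log-probability; `head_log_prob_weighted` caches the weighted probability
/// as seen from the first node of the chain.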
#[derive(Clone, Copy, Debug)]
struct CtSegment {
    tail: ChildRef,
    log_prob_kt: f64,
    head_log_prob_weighted: f64,
    symbol_count: [u32; 2],
    payload: SegmentPayload,
}

impl Default for CtSegment {
    fn default() -> Self {
        Self {
            tail: ChildRef::NONE,
            log_prob_kt: 0.0,
            head_log_prob_weighted: 0.0,
            symbol_count: [0, 0],
            payload: SegmentPayload::default(),
        }
    }
}

impl CtSegment {
    #[inline(always)]
    fn len(self) -> u32 {
        self.payload.len()
    }

    #[inline(always)]
    fn set_len(&mut self, len: u32) {
        self.payload.set_len(len);
    }
}

#[inline(always)]
fn low_bits_mask_u64(len: u32) -> u64 {
    if len >= 64 {
        u64::MAX
    } else {
        (1u64 << len) - 1
    }
}

#[inline(always)]
fn path_bits_from_history(history: &[Symbol], depth: usize, len: usize) -> u64 {
    let history_len = history.len();
    let available = history_len.saturating_sub(depth).min(len);
    if available == 0 {
        return 0;
    }

    let mut bits = 0u64;
    let mut hist_idx = history_len - depth - 1;
    for offset in 0..available {
        bits |= (unsafe { *history.get_unchecked(hist_idx) } as u64) << offset;
        if hist_idx == 0 {
            break;
        }
        hist_idx -= 1;
    }
    bits
}

#[inline(always)]
fn shift_path_bits(path_bits: u64, consumed: usize) -> u64 {
    if consumed >= 64 {
        0
    } else {
        path_bits >> consumed
    }
}

#[inline(always)]
fn first_exact_segment_mismatch(
    exact_bits: u64,
    path_bits: u64,
    comparable_len: usize,
) -> Option<(usize, bool, bool)> {
    if comparable_len == 0 {
        return None;
    }

    let diff = (exact_bits ^ path_bits) & low_bits_mask_u64(comparable_len as u32);
    if diff == 0 {
        None
    } else {
        let offset = diff.trailing_zeros() as usize;
        Some((
            offset,
            ((path_bits >> offset) & 1) != 0,
            ((exact_bits >> offset) & 1) != 0,
        ))
    }
}

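/// Krichevsky-Trofimov predictive probability of the next symbol:
/// `(n_sym + 1/2) / (n_0 + n_1 + 1)`.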
#[inline(always)]
fn predict_ratio_kt(counts: [u32; 2], sym_idx: usize) -> f64 {
    let total = (counts[0] + counts[1]) as f64;
    let sym_count = counts[sym_idx] as f64;
    (sym_count + 0.5) / (total + 1.0)
}

#[inline(always)]
fn predict_ratio_kt_one(counts: [u32; 2]) -> f64 {
    let total = (counts[0] + counts[1]) as f64;
    let sym_count = counts[1] as f64;
    (sym_count + 0.5) / (total + 1.0)
}

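/// CTW mixture at an internal node, in log space:
/// `P_w = (P_kt + P_w0 * P_w1) / 2`, evaluated stably via `ln_1p` around the
/// larger of the two terms and clamped to at most 0.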
#[inline(always)]
fn update_weighted_log_prob_non_leaf(kt_log_prob: f64, log_prob_w0: f64, log_prob_w1: f64) -> f64 {
    let child_log_prob = log_prob_w0 + log_prob_w1;
    let delta = child_log_prob - kt_log_prob;
    let log_prob_weighted = if delta >= 0.0 {
        child_log_prob + (-delta).exp().ln_1p() - std::f64::consts::LN_2
    } else {
        kt_log_prob + delta.exp().ln_1p() - std::f64::consts::LN_2
    };
    clamp_log_prob(log_prob_weighted)
}

#[inline(always)]
fn update_weighted_log_prob(
    kt_log_prob: f64,
    log_prob_w0: f64,
    log_prob_w1: f64,
    is_leaf: bool,
) -> f64 {
    if is_leaf {
        clamp_log_prob(kt_log_prob)
    } else {
        update_weighted_log_prob_non_leaf(kt_log_prob, log_prob_w0, log_prob_w1)
    }
}

#[inline(always)]
fn clamp_log_prob(log_prob: f64) -> f64 {
    if log_prob > 1.0e-10 { 0.0 } else { log_prob }
}

#[inline(always)]
fn logsumexp_pair(lhs: f64, rhs: f64) -> f64 {
    if lhs == f64::NEG_INFINITY {
        return rhs;
    }
    if rhs == f64::NEG_INFINITY {
        return lhs;
    }
    let pivot = lhs.max(rhs);
    pivot + ((lhs - pivot).exp() + (rhs - pivot).exp()).ln()
}

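/// Weighted log-probability of a collapsed chain of `len` single-child nodes
/// that share one KT state. With `alpha = 2^-len`, the chain mixes to
/// `P = (1 - alpha) * P_kt + alpha * P_continuation`.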
#[inline(always)]
fn unary_chain_log_weight(kt_log_prob: f64, continuation_log_prob: f64, len: u32) -> f64 {
    debug_assert!(len > 0);
    if kt_log_prob.to_bits() == continuation_log_prob.to_bits() {
        return kt_log_prob;
    }
    let log_alpha = -(len as f64) * std::f64::consts::LN_2;
    let alpha = log_alpha.exp();
    let log_kt_mass = kt_log_prob + (-alpha).ln_1p();
    let log_cont_mass = continuation_log_prob + log_alpha;
    clamp_log_prob(logsumexp_pair(log_kt_mass, log_cont_mass))
}

#[inline(always)]
fn combined_weight_ratio_internal(
    kt_log_prob: f64,
    counts: [u32; 2],
    path_child_log_prob: f64,
    sibling_log_prob: f64,
    child_ratio: f64,
    sym_idx: usize,
) -> (f64, f64) {
    let kt_ratio = predict_ratio_kt(counts, sym_idx);
    let child_log_prob = path_child_log_prob + sibling_log_prob;
    let delta = child_log_prob - kt_log_prob;
    if delta >= 0.0 {
        let x = (-delta).exp();
        (
            clamp_log_prob(child_log_prob + x.ln_1p() - std::f64::consts::LN_2),
            (kt_ratio * x + child_ratio) / (1.0 + x),
        )
    } else {
        let x = delta.exp();
        (
            clamp_log_prob(kt_log_prob + x.ln_1p() - std::f64::consts::LN_2),
            (kt_ratio + x * child_ratio) / (1.0 + x),
        )
    }
}

#[inline(always)]
fn combined_weight_ratio_internal_one(
    kt_log_prob: f64,
    counts: [u32; 2],
    path_child_log_prob: f64,
    sibling_log_prob: f64,
    child_ratio: f64,
) -> (f64, f64) {
    let kt_ratio = predict_ratio_kt_one(counts);
    let child_log_prob = path_child_log_prob + sibling_log_prob;
    let delta = child_log_prob - kt_log_prob;
    if delta >= 0.0 {
        let x = (-delta).exp();
        (
            clamp_log_prob(child_log_prob + x.ln_1p() - std::f64::consts::LN_2),
            (kt_ratio * x + child_ratio) / (1.0 + x),
        )
    } else {
        let x = delta.exp();
        (
            clamp_log_prob(kt_log_prob + x.ln_1p() - std::f64::consts::LN_2),
            (kt_ratio + x * child_ratio) / (1.0 + x),
        )
    }
}

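/// Same chain mixture as `unary_chain_log_weight`, but with `alpha = 2^-len`
/// and its logarithms supplied from precomputed per-length tables.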
#[inline(always)]
fn unary_chain_log_weight_precomputed(
    kt_log_prob: f64,
    continuation_log_prob: f64,
    alpha: f64,
    log_alpha: f64,
    log_one_minus_alpha: f64,
) -> f64 {
    if kt_log_prob.to_bits() == continuation_log_prob.to_bits() {
        return clamp_log_prob(kt_log_prob);
    }

    let delta = continuation_log_prob - kt_log_prob;
    let log_prob_weighted = if delta >= 0.0 {
        let x = ((1.0 - alpha) / alpha) * (-delta).exp();
        continuation_log_prob + log_alpha + x.ln_1p()
    } else {
        let x = (alpha / (1.0 - alpha)) * delta.exp();
        kt_log_prob + log_one_minus_alpha + x.ln_1p()
    };
    clamp_log_prob(log_prob_weighted)
}

#[inline(always)]
fn unary_chain_ratio_transform_precomputed(
    kt_log_prob: f64,
    counts: [u32; 2],
    continuation_log_prob: f64,
    continuation_ratio: f64,
    alpha: f64,
    log_alpha: f64,
    log_one_minus_alpha: f64,
    sym_idx: usize,
) -> (f64, f64) {
    let kt_ratio = predict_ratio_kt(counts, sym_idx);
    if kt_log_prob.to_bits() == continuation_log_prob.to_bits()
        && kt_ratio.to_bits() == continuation_ratio.to_bits()
    {
        return (clamp_log_prob(kt_log_prob), kt_ratio);
    }

    let delta = continuation_log_prob - kt_log_prob;
    if delta >= 0.0 {
        let x = ((1.0 - alpha) / alpha) * (-delta).exp();
        (
            clamp_log_prob(continuation_log_prob + log_alpha + x.ln_1p()),
            (kt_ratio * x + continuation_ratio) / (1.0 + x),
        )
    } else {
        let x = (alpha / (1.0 - alpha)) * delta.exp();
        (
            clamp_log_prob(kt_log_prob + log_one_minus_alpha + x.ln_1p()),
            (kt_ratio + x * continuation_ratio) / (1.0 + x),
        )
    }
}

#[inline(always)]
fn unary_chain_ratio_transform_precomputed_one(
    kt_log_prob: f64,
    counts: [u32; 2],
    continuation_log_prob: f64,
    continuation_ratio: f64,
    alpha: f64,
    log_alpha: f64,
    log_one_minus_alpha: f64,
) -> (f64, f64) {
    let kt_ratio = predict_ratio_kt_one(counts);
    if kt_log_prob.to_bits() == continuation_log_prob.to_bits()
        && kt_ratio.to_bits() == continuation_ratio.to_bits()
    {
        return (clamp_log_prob(kt_log_prob), kt_ratio);
    }

    let delta = continuation_log_prob - kt_log_prob;
    if delta >= 0.0 {
        let x = ((1.0 - alpha) / alpha) * (-delta).exp();
        (
            clamp_log_prob(continuation_log_prob + log_alpha + x.ln_1p()),
            (kt_ratio * x + continuation_ratio) / (1.0 + x),
        )
    } else {
        let x = (alpha / (1.0 - alpha)) * delta.exp();
        (
            clamp_log_prob(kt_log_prob + log_one_minus_alpha + x.ln_1p()),
            (kt_ratio + x * continuation_ratio) / (1.0 + x),
        )
    }
}

#[inline(always)]
fn predict_ratio_internal(
    kt_log_prob: f64,
    counts: [u32; 2],
    path_child_log_prob: f64,
    sibling_log_prob: f64,
    child_ratio: f64,
    sym_idx: usize,
) -> f64 {
    let kt_ratio = predict_ratio_kt(counts, sym_idx);
    let delta = path_child_log_prob + sibling_log_prob - kt_log_prob;
    if delta >= 0.0 {
        let inv_rho = (-delta).exp();
        (kt_ratio * inv_rho + child_ratio) / (1.0 + inv_rho)
    } else {
        let rho = delta.exp();
        (kt_ratio + rho * child_ratio) / (1.0 + rho)
    }
}

#[inline(always)]
fn predict_ratio_internal_one(
    kt_log_prob: f64,
    counts: [u32; 2],
    path_child_log_prob: f64,
    sibling_log_prob: f64,
    child_ratio: f64,
) -> f64 {
    let kt_ratio = predict_ratio_kt_one(counts);
    let delta = path_child_log_prob + sibling_log_prob - kt_log_prob;
    if delta >= 0.0 {
        let inv_rho = (-delta).exp();
        (kt_ratio * inv_rho + child_ratio) / (1.0 + inv_rho)
    } else {
        let rho = delta.exp();
        (kt_ratio + rho * child_ratio) / (1.0 + rho)
    }
}

#[inline(always)]
fn path_edge_at_depth(history: &[Symbol], history_len: usize, depth: usize) -> bool {
    if depth < history_len {
        history[history_len - depth - 1]
    } else {
        false
    }
}

#[inline(always)]
fn segment_edge_from_parts(
    segment: CtSegment,
    offset: usize,
    history: &[Symbol],
    history_len: usize,
) -> bool {
    match segment.payload.mode() {
        SEG_MODE_EXACT => ((segment.payload.exact_bits() >> offset) & 1) != 0,
        SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => {
            if segment.payload.anchor_or_const() as usize >= offset {
                let hist_idx = segment.payload.anchor_or_const() as usize - offset;
                if hist_idx < history_len {
                    let raw = history[hist_idx];
                    if segment.payload.mode() == SEG_MODE_HISTORY_INVERT {
                        !raw
                    } else {
                        raw
                    }
                } else {
                    segment.payload.mode() == SEG_MODE_HISTORY_INVERT
                }
            } else {
                segment.payload.mode() == SEG_MODE_HISTORY_INVERT
            }
        }
        SEG_MODE_CONST => segment.payload.const_bit(),
        _ => unreachable!("invalid ctw segment payload mode"),
    }
}

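/// First position in `[0, comparable_len)` where the segment's stored edges
/// diverge from the current context path, returned together with the path edge
/// and the stored edge at that offset; `None` if the span matches.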
#[inline(always)]
fn first_segment_mismatch(
    segment: CtSegment,
    depth: usize,
    history: &[Symbol],
    comparable_len: usize,
) -> Option<(usize, bool, bool)> {
    if comparable_len == 0 {
        return None;
    }

    match segment.payload.mode() {
        SEG_MODE_EXACT => first_exact_segment_mismatch(
            segment.payload.exact_bits(),
            path_bits_from_history(history, depth, comparable_len),
            comparable_len,
        ),
        SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => {
            let history_ptr = history.as_ptr();
            let history_len = history.len() as isize;
            let mut path_hist_idx = history_len - depth as isize - 1;
            let mut seg_hist_idx = segment.payload.anchor_or_const() as isize;
            let invert = segment.payload.mode() == SEG_MODE_HISTORY_INVERT;
            for offset in 0..comparable_len {
                let path_edge =
                    unsafe { history_at_or_zero(history_ptr, history_len, path_hist_idx) };
                let existing_raw =
                    unsafe { history_at_or_zero(history_ptr, history_len, seg_hist_idx) };
                let existing_edge = if invert { !existing_raw } else { existing_raw };
                if existing_edge != path_edge {
                    return Some((offset, path_edge, existing_edge));
                }
                path_hist_idx -= 1;
                seg_hist_idx -= 1;
            }
            None
        }
        SEG_MODE_CONST => {
            let history_ptr = history.as_ptr();
            let history_len = history.len() as isize;
            let mut path_hist_idx = history_len - depth as isize - 1;
            let existing_edge = segment.payload.const_bit();
            for offset in 0..comparable_len {
                let path_edge =
                    unsafe { history_at_or_zero(history_ptr, history_len, path_hist_idx) };
                if existing_edge != path_edge {
                    return Some((offset, path_edge, existing_edge));
                }
                path_hist_idx -= 1;
            }
            None
        }
        _ => unreachable!("invalid ctw segment payload mode"),
    }
}

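/// Sequential KT update against the shared log tables:
/// `ln P_kt(x^{n+1}) = ln P_kt(x^n) + ln(n_sym + 1/2) - ln(n + 1)`.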
#[inline]
fn apply_update_to_state_raw(
    log_int: &[f64],
    log_half: &[f64],
    symbol_count: &mut [u32; 2],
    log_prob_kt: &mut f64,
    sym_idx: usize,
) {
    let total_before = (symbol_count[0] + symbol_count[1]) as usize;
    let sym_before = symbol_count[sym_idx] as usize;
    debug_assert!(sym_before <= total_before);
    debug_assert!(sym_before < log_half.len());
    debug_assert!(total_before + 1 < log_int.len());
    let log_half_before = unsafe { *log_half.get_unchecked(sym_before) };
    let log_total_after = unsafe { *log_int.get_unchecked(total_before + 1) };
    *log_prob_kt += log_half_before - log_total_after;
    if *log_prob_kt > 1.0e-10 {
        *log_prob_kt = 0.0;
    }
    symbol_count[sym_idx] = symbol_count[sym_idx]
        .checked_add(1)
        .expect("ctw symbol count overflow");
}

#[inline]
fn apply_revert_to_state_raw(
    log_int: &[f64],
    log_half: &[f64],
    symbol_count: &mut [u32; 2],
    log_prob_kt: &mut f64,
    sym_idx: usize,
) {
    let total = (symbol_count[0] + symbol_count[1]) as usize;
    let sym_count = symbol_count[sym_idx] as usize;
    if sym_count > 0 && total > 0 {
        debug_assert!(sym_count - 1 < log_half.len());
        debug_assert!(total < log_int.len());
        let log_half_before = unsafe { *log_half.get_unchecked(sym_count - 1) };
        let log_total = unsafe { *log_int.get_unchecked(total) };
        *log_prob_kt -= log_half_before - log_total;
        symbol_count[sym_idx] -= 1;
    }
    if *log_prob_kt > 1.0e-10 {
        *log_prob_kt = 0.0;
    }
}

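/// Arena allocator for context-tree nodes and compressed segments. Free lists
/// let detached subtrees be recycled without returning memory to the global
/// allocator.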
#[derive(Clone, Debug)]
pub struct CtArena {
    nodes: Vec<CtNode>,
    segments: Vec<CtSegment>,
    free_nodes: Vec<NodeIndex>,
    free_segments: Vec<SegmentIndex>,
}

impl CtArena {
    pub fn new() -> Self {
        Self {
            nodes: Vec::with_capacity(1024),
            segments: Vec::with_capacity(1024),
            free_nodes: Vec::new(),
            free_segments: Vec::new(),
        }
    }

    pub fn with_capacity(cap: usize) -> Self {
        Self {
            nodes: Vec::with_capacity(cap),
            segments: Vec::with_capacity(cap / 4 + 1),
            free_nodes: Vec::new(),
            free_segments: Vec::new(),
        }
    }

    #[inline]
    pub fn reserve_exact(&mut self, additional: usize) {
        self.nodes.reserve_exact(additional);
        self.segments.reserve_exact(additional / 4 + 1);
    }

    #[inline(always)]
    fn reset_node_slot(&mut self, idx: NodeIndex) {
        self.nodes[idx.get()] = CtNode {
            children: [ChildRef::NONE, ChildRef::NONE],
            log_prob_kt: 0.0,
            log_prob_weighted: 0.0,
            symbol_count: [0, 0],
        };
    }

    #[inline(always)]
    fn reset_segment_slot(&mut self, idx: SegmentIndex) {
        self.segments[idx.get()] = CtSegment::default();
    }

    #[inline(always)]
    fn alloc_node(&mut self) -> NodeIndex {
        if let Some(idx) = self.free_nodes.pop() {
            self.reset_node_slot(idx);
            idx
        } else {
            let idx = NodeIndex::from_usize(self.nodes.len());
            self.nodes.push(CtNode {
                children: [ChildRef::NONE, ChildRef::NONE],
                log_prob_kt: 0.0,
                log_prob_weighted: 0.0,
                symbol_count: [0, 0],
            });
            idx
        }
    }

    #[inline(always)]
    fn alloc_node_with_state(&mut self, symbol_count: [u32; 2], log_prob_kt: f64) -> NodeIndex {
        let idx = self.alloc_node();
        self.nodes[idx.get()].symbol_count = symbol_count;
        self.nodes[idx.get()].log_prob_kt = log_prob_kt;
        idx
    }

    #[inline(always)]
    fn free_node(&mut self, idx: NodeIndex) {
        self.free_nodes.push(idx);
    }

    #[inline(always)]
    fn alloc_segment(&mut self) -> SegmentIndex {
        if let Some(idx) = self.free_segments.pop() {
            self.reset_segment_slot(idx);
            idx
        } else {
            let idx = SegmentIndex::from_usize(self.segments.len());
            self.segments.push(CtSegment::default());
            idx
        }
    }

    #[inline(always)]
    fn free_segment(&mut self, idx: SegmentIndex) {
        self.reset_segment_slot(idx);
        self.free_segments.push(idx);
    }

    pub fn clear(&mut self) {
        self.nodes.clear();
        self.segments.clear();
        self.free_nodes.clear();
        self.free_segments.clear();
    }

    #[inline(always)]
    fn child(&self, parent_idx: NodeIndex, child_idx: usize) -> ChildRef {
        debug_assert!(parent_idx.get() < self.nodes.len());
        debug_assert!(child_idx < 2);
        unsafe {
            *self
                .nodes
                .get_unchecked(parent_idx.get())
                .children
                .get_unchecked(child_idx)
        }
    }

    #[inline(always)]
    fn set_child(&mut self, parent_idx: NodeIndex, child_idx: usize, child: ChildRef) {
        debug_assert!(parent_idx.get() < self.nodes.len());
        debug_assert!(child_idx < 2);
        unsafe {
            *self
                .nodes
                .get_unchecked_mut(parent_idx.get())
                .children
                .get_unchecked_mut(child_idx) = child;
        }
    }

    #[inline(always)]
    fn set_segment_tail(&mut self, segment_idx: SegmentIndex, child: ChildRef) {
        self.segments[segment_idx.get()].tail = child;
    }

    #[inline(always)]
    fn counts(&self, idx: NodeIndex) -> [u32; 2] {
        self.nodes[idx.get()].symbol_count
    }

    #[inline(always)]
    fn visits(&self, idx: NodeIndex) -> u32 {
        let counts = self.nodes[idx.get()].symbol_count;
        counts[0] + counts[1]
    }

    #[inline(always)]
    fn segment_symbol_count(&self, segment_idx: SegmentIndex) -> [u32; 2] {
        self.segments[segment_idx.get()].symbol_count
    }

    #[inline(always)]
    fn segment_log_prob_kt(&self, segment_idx: SegmentIndex) -> f64 {
        self.segments[segment_idx.get()].log_prob_kt
    }

    #[inline(always)]
    fn segment_len(&self, segment_idx: SegmentIndex) -> u32 {
        self.segments[segment_idx.get()].len()
    }

    #[inline(always)]
    fn segment_has_child(&self, segment_idx: SegmentIndex, offset: u32) -> bool {
        let segment = self.segments[segment_idx.get()];
        offset + 1 < segment.len() || segment.tail.is_some()
    }

    #[inline(always)]
    fn log_prob_weighted(&self, idx: NodeIndex) -> f64 {
        self.nodes[idx.get()].log_prob_weighted
    }

    #[inline(always)]
    fn log_prob_kt(&self, idx: NodeIndex) -> f64 {
        self.nodes[idx.get()].log_prob_kt
    }

    #[inline(always)]
    unsafe fn child_ref_weighted_unchecked(&self, child: ChildRef) -> f64 {
        if child.is_none() {
            return 0.0;
        }

        let raw = child.0;
        if (raw & CHILD_SEGMENT_TAG) == 0 {
            debug_assert!((raw as usize) < self.nodes.len());
            self.nodes.get_unchecked(raw as usize).log_prob_weighted
        } else {
            let idx = (raw & CHILD_INDEX_MASK) as usize;
            debug_assert!(idx < self.segments.len());
            self.segments.get_unchecked(idx).head_log_prob_weighted
        }
    }

    #[inline(always)]
    fn child_ref_weighted(&self, child: ChildRef) -> f64 {
        unsafe { self.child_ref_weighted_unchecked(child) }
    }

    #[inline(always)]
    fn singleton_segment_payload(&self, edge: usize) -> SegmentPayload {
        SegmentPayload::exact((edge & 1) as u64, 1)
    }

    #[inline(always)]
    fn segment_edge(&self, segment_idx: SegmentIndex, offset: u32, history: &[Symbol]) -> usize {
        let segment = self.segments[segment_idx.get()];
        segment_edge_from_parts(segment, offset as usize, history, history.len()) as usize
    }

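    /// Weighted log-probability of the sub-chain starting `offset` nodes into
    /// the segment, falling back to the tail's weight once the chain is
    /// exhausted.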
    fn segment_suffix_weight(&self, segment_idx: SegmentIndex, offset: u32) -> f64 {
        let segment = self.segments[segment_idx.get()];
        if offset >= segment.len() {
            return self.child_ref_weighted(segment.tail);
        }
        if segment.tail.is_none() {
            return segment.log_prob_kt;
        }
        let remaining = segment.len() - offset;
        unary_chain_log_weight(
            segment.log_prob_kt,
            self.child_ref_weighted(segment.tail),
            remaining,
        )
    }

    #[inline(always)]
    fn segment_continuation_weight(&self, segment_idx: SegmentIndex, offset: u32) -> f64 {
        let segment = self.segments[segment_idx.get()];
        if offset + 1 < segment.len() {
            self.segment_suffix_weight(segment_idx, offset + 1)
        } else {
            self.child_ref_weighted(segment.tail)
        }
    }

    fn recompute_segment_head(&mut self, segment_idx: SegmentIndex) {
        let segment = self.segments[segment_idx.get()];
        let head = if segment.tail.is_some() {
            unary_chain_log_weight(
                segment.log_prob_kt,
                self.child_ref_weighted(segment.tail),
                segment.len(),
            )
        } else {
            segment.log_prob_kt
        };
        self.segments[segment_idx.get()].head_log_prob_weighted = head;
    }

    fn recompute_node_weight(&mut self, idx: NodeIndex) {
        let slot = idx.get();
        debug_assert!(slot < self.nodes.len());
        let node = unsafe { *self.nodes.get_unchecked(slot) };
        let [left, right] = node.children;
        let weighted = if left.is_none() && right.is_none() {
            clamp_log_prob(node.log_prob_kt)
        } else {
            let w0 = unsafe { self.child_ref_weighted_unchecked(left) };
            let w1 = unsafe { self.child_ref_weighted_unchecked(right) };
            update_weighted_log_prob_non_leaf(node.log_prob_kt, w0, w1)
        };
        unsafe {
            self.nodes.get_unchecked_mut(slot).log_prob_weighted = weighted;
        }
    }

    fn alloc_segment_with_parts(
        &mut self,
        symbol_count: [u32; 2],
        log_prob_kt: f64,
        tail: ChildRef,
        payload: SegmentPayload,
    ) -> SegmentIndex {
        let segment_idx = self.alloc_segment();
        self.segments[segment_idx.get()] = CtSegment {
            tail,
            log_prob_kt,
            head_log_prob_weighted: 0.0,
            symbol_count,
            payload,
        };
        if payload.len() == 1 && tail.is_none() {
            self.segments[segment_idx.get()].head_log_prob_weighted = log_prob_kt;
        } else {
            self.recompute_segment_head(segment_idx);
        }
        segment_idx
    }

    fn detach_segment_continuation(
        &mut self,
        segment_idx: SegmentIndex,
        offset: u32,
        detaches: &mut Vec<Detach>,
    ) -> ChildRef {
        let segment = self.segments[segment_idx.get()];
        if offset + 1 < segment.len() {
            let suffix = self.alloc_segment_with_parts(
                segment.symbol_count,
                segment.log_prob_kt,
                segment.tail,
                segment.payload.suffix_after(offset + 1),
            );
            detaches.push(Detach::SegmentNext {
                segment: segment_idx,
                new_len: offset + 1,
            });
            ChildRef::from_segment(suffix)
        } else {
            let tail = segment.tail;
            if tail.is_some() {
                detaches.push(Detach::SegmentNext {
                    segment: segment_idx,
                    new_len: segment.len(),
                });
            }
            tail
        }
    }

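    /// Prepend one edge in front of `child`. When `child` is a segment whose
    /// state matches, the segment is grown in place (extending an exact bit
    /// pattern, or re-deriving one from the history if `allow_history_pattern`
    /// is set); otherwise a fresh single-edge segment is allocated.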
    fn prepend_or_alloc_segment(
        &mut self,
        history: &[Symbol],
        depth: usize,
        symbol_count: [u32; 2],
        log_prob_kt: f64,
        child: ChildRef,
        edge: usize,
        allow_history_pattern: bool,
    ) -> ChildRef {
        let singleton_payload = self.singleton_segment_payload(edge);

        if let Some(segment_idx) = child.as_segment() {
            let segment = self.segments[segment_idx.get()];
            let same_state = segment.symbol_count == symbol_count
                && segment.log_prob_kt.to_bits() == log_prob_kt.to_bits();
            if same_state && segment.tail == child {
                let segment = &mut self.segments[segment_idx.get()];
                let extended_payload = if segment.payload.is_exact() {
                    segment.payload.prepend_exact(edge)
                } else if allow_history_pattern {
                    let path_payload =
                        SegmentPayload::from_path(history, depth, segment.len().saturating_add(1));
                    path_payload.filter(|payload| {
                        let mut matches = true;
                        for offset in 0..segment.len() as usize {
                            let seg_edge =
                                segment_edge_from_parts(*segment, offset, history, history.len());
                            let payload_edge = ((payload.exact_bits() >> (offset + 1)) & 1) != 0;
                            if seg_edge != payload_edge {
                                matches = false;
                                break;
                            }
                        }
                        matches
                    })
                } else {
                    None
                };
                if let Some(payload) = extended_payload {
                    let old_head = segment.head_log_prob_weighted;
                    segment.payload = payload;
                    segment.head_log_prob_weighted =
                        update_weighted_log_prob(log_prob_kt, old_head, 0.0, false);
                    return ChildRef::from_segment(segment_idx);
                }
            }
        }

        let segment_idx =
            self.alloc_segment_with_parts(symbol_count, log_prob_kt, child, singleton_payload);
        ChildRef::from_segment(segment_idx)
    }

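    /// Return an entire subtree (nodes and segments) to the free lists, using
    /// an explicit stack instead of recursion so deep paths cannot overflow.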
    fn free_child_ref(&mut self, child: ChildRef) {
        let mut stack = Vec::with_capacity(16);
        if child.is_some() {
            stack.push(child);
        }
        while let Some(next) = stack.pop() {
            if let Some(node_idx) = next.as_node() {
                let children = self.nodes[node_idx.get()].children;
                if children[0].is_some() {
                    stack.push(children[0]);
                }
                if children[1].is_some() {
                    stack.push(children[1]);
                }
                self.free_node(node_idx);
            } else if let Some(segment_idx) = next.as_segment() {
                let tail = self.segments[segment_idx.get()].tail;
                if tail.is_some() {
                    stack.push(tail);
                }
                self.free_segment(segment_idx);
            }
        }
    }

    pub fn memory_usage(&self) -> usize {
        self.nodes.capacity() * size_of::<CtNode>()
            + self.segments.capacity() * size_of::<CtSegment>()
            + self.free_nodes.capacity() * size_of::<NodeIndex>()
            + self.free_segments.capacity() * size_of::<SegmentIndex>()
    }
}

impl Default for CtArena {
    fn default() -> Self {
        Self::new()
    }
}

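/// The context-tree weighting engine: an arena-backed tree of depth
/// `max_depth`, precomputed per-length segment mixing constants, and scratch
/// buffers for the prepared-update fast paths.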
#[derive(Clone)]
struct CtEngine {
    arena: CtArena,
    root: NodeIndex,
    max_depth: usize,
    segment_alpha: Vec<f64>,
    segment_log_alpha: Vec<f64>,
    segment_log_one_minus_alpha: Vec<f64>,
    levels: Vec<LevelState>,
    detaches: Vec<Detach>,
    prepared_steps: Vec<PreparedStep>,
    prepared_levels: usize,
    prepared_end: PreparedEnd,
}

impl CtEngine {
    const RESERVE_MIN_NODES: usize = 4 * 1024;
    const RESERVE_MAX_NODES: usize = 1 << 18;
    const HOT_PREFIX_DEPTH: usize = 10;

    fn new(depth: usize) -> Self {
        let mut arena = CtArena::with_capacity(1024.min(1 << depth.min(16)));
        let root = arena.alloc_node();
        let mut segment_alpha = Vec::with_capacity(depth + 1);
        let mut segment_log_alpha = Vec::with_capacity(depth + 1);
        let mut segment_log_one_minus_alpha = Vec::with_capacity(depth + 1);
        segment_alpha.push(1.0);
        segment_log_alpha.push(0.0);
        segment_log_one_minus_alpha.push(f64::NEG_INFINITY);
        let mut alpha = 1.0f64;
        for len in 1..=depth {
            alpha *= 0.5;
            segment_alpha.push(alpha);
            segment_log_alpha.push(-(len as f64) * std::f64::consts::LN_2);
            segment_log_one_minus_alpha.push((-alpha).ln_1p());
        }
        Self {
            arena,
            root,
            max_depth: depth,
            segment_alpha,
            segment_log_alpha,
            segment_log_one_minus_alpha,
            levels: vec![LevelState::default(); depth],
            detaches: Vec::with_capacity(depth),
            prepared_steps: Vec::with_capacity(depth),
            prepared_levels: 0,
            prepared_end: PreparedEnd::MaxDepth,
        }
    }

    #[inline(always)]
    fn root_visits(&self) -> usize {
        self.arena.visits(self.root) as usize
    }

    #[inline(always)]
    fn hot_prefix_depth(&self) -> usize {
        self.max_depth.min(Self::HOT_PREFIX_DEPTH)
    }

    fn clear(&mut self) {
        self.arena.clear();
        self.root = self.arena.alloc_node();
        self.levels.fill(LevelState::default());
        self.detaches.clear();
        self.prepared_steps.clear();
        self.prepared_levels = 0;
        self.prepared_end = PreparedEnd::MaxDepth;
    }

    #[inline]
    fn reserve_for_symbols(&mut self, total_symbols: usize) {
        if total_symbols == 0 {
            return;
        }

        let depth_scale = self.max_depth.saturating_add(1);
        let reserve_nodes = total_symbols
            .saturating_div(depth_scale)
            .clamp(Self::RESERVE_MIN_NODES, Self::RESERVE_MAX_NODES);
        let free_nodes = self
            .arena
            .nodes
            .capacity()
            .saturating_sub(self.arena.nodes.len());
        if reserve_nodes > free_nodes {
            self.arena.reserve_exact(reserve_nodes - free_nodes);
        }
    }

    #[inline]
    fn get_log_block_probability(&self) -> f64 {
        self.arena.log_prob_weighted(self.root)
    }

    #[inline]
    fn with_logs<R>(&mut self, upto: usize, f: impl FnOnce(&mut Self, &[f64], &[f64]) -> R) -> R {
        with_shared_log_cache(upto, |log_int, log_half| f(self, log_int, log_half))
    }

    #[inline]
    fn log_cache_memory_usage(&self) -> usize {
        shared_log_cache_memory_usage()
    }

    #[inline(always)]
    fn segment_constants(&self, len: u32) -> (f64, f64, f64) {
        let idx = len as usize;
        debug_assert!(idx < self.segment_alpha.len());
        debug_assert!(idx < self.segment_log_alpha.len());
        debug_assert!(idx < self.segment_log_one_minus_alpha.len());
        (
            unsafe { *self.segment_alpha.get_unchecked(idx) },
            unsafe { *self.segment_log_alpha.get_unchecked(idx) },
            unsafe { *self.segment_log_one_minus_alpha.get_unchecked(idx) },
        )
    }

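    /// Materialise a brand-new path from `depth` down to `max_depth` for the
    /// current history, seeded with one observed symbol: a single exact-bits
    /// segment when the remaining length fits in 64 bits, otherwise a
    /// history-window segment chained onto a constant-zero segment.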
    fn build_missing_segment_path(
        &mut self,
        depth: usize,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        let log_prob_kt = singleton_log_prob_kt;
        let total_len = self.max_depth - depth + 1;

        if let Some(payload) = SegmentPayload::from_path(history, depth, total_len as u32) {
            let segment =
                self.arena
                    .alloc_segment_with_parts(counts, log_prob_kt, ChildRef::NONE, payload);
            return ChildRef::from_segment(segment);
        }

        let history_nodes = if depth < history.len() {
            (self.max_depth.min(history.len() - 1) - depth) + 1
        } else {
            0
        };
        let const_nodes = total_len - history_nodes;

        let mut built = ChildRef::NONE;
        if const_nodes > 0 {
            let const_segment = self.arena.alloc_segment_with_parts(
                counts,
                log_prob_kt,
                ChildRef::NONE,
                SegmentPayload::constant(false, const_nodes as u32),
            );
            built = ChildRef::from_segment(const_segment);
        }
        if history_nodes > 0 {
            let history_segment = self.arena.alloc_segment_with_parts(
                counts,
                log_prob_kt,
                built,
                SegmentPayload::history(
                    (history.len() - depth - 1) as u32,
                    history_nodes as u32,
                    false,
                ),
            );
            built = ChildRef::from_segment(history_segment);
        }
        built
    }

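    /// Like `build_missing_segment_path`, but keeps the first
    /// `HOT_PREFIX_DEPTH` levels as real nodes (shallow contexts are the most
    /// likely to branch later) and compresses only the colder suffix into
    /// segments.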
    fn build_missing_path(
        &mut self,
        depth: usize,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let hot_prefix_depth = self.hot_prefix_depth();
        if depth > hot_prefix_depth {
            return self.build_missing_segment_path(depth, history, sym_idx, singleton_log_prob_kt);
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        let mut built = if hot_prefix_depth < self.max_depth {
            self.build_missing_segment_path(
                hot_prefix_depth + 1,
                history,
                sym_idx,
                singleton_log_prob_kt,
            )
        } else {
            ChildRef::NONE
        };

        for node_depth in (depth..=hot_prefix_depth).rev() {
            let node = self
                .arena
                .alloc_node_with_state(counts, singleton_log_prob_kt);
            if node_depth < self.max_depth {
                let edge = history_symbol(history, node_depth) as usize;
                self.arena.set_child(node, edge, built);
            }
            self.arena.recompute_node_weight(node);
            built = ChildRef::from_node(node);
        }
        built
    }

    #[inline(always)]
    fn build_missing_segment_path_exact_bits(
        &mut self,
        depth: usize,
        path_bits: u64,
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        debug_assert!(self.max_depth <= SEG_EXACT_MAX_LEN as usize);
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        let total_len = self.max_depth - depth + 1;
        let payload = SegmentPayload::exact(
            path_bits & low_bits_mask_u64(total_len as u32),
            total_len as u32,
        );
        let segment = self.arena.alloc_segment_with_parts(
            counts,
            singleton_log_prob_kt,
            ChildRef::NONE,
            payload,
        );
        ChildRef::from_segment(segment)
    }

    #[inline(always)]
    fn build_missing_path_exact_bits(
        &mut self,
        depth: usize,
        path_bits: u64,
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        debug_assert!(self.max_depth <= SEG_EXACT_MAX_LEN as usize);
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let hot_prefix_depth = self.hot_prefix_depth();
        if depth > hot_prefix_depth {
            return self.build_missing_segment_path_exact_bits(
                depth,
                path_bits,
                sym_idx,
                singleton_log_prob_kt,
            );
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        let mut built = if hot_prefix_depth < self.max_depth {
            self.build_missing_segment_path_exact_bits(
                hot_prefix_depth + 1,
                shift_path_bits(path_bits, hot_prefix_depth + 1 - depth),
                sym_idx,
                singleton_log_prob_kt,
            )
        } else {
            ChildRef::NONE
        };

        for node_depth in (depth..=hot_prefix_depth).rev() {
            let node = self
                .arena
                .alloc_node_with_state(counts, singleton_log_prob_kt);
            if node_depth < self.max_depth {
                let edge = ((path_bits >> (node_depth - depth)) & 1) as usize;
                self.arena.set_child(node, edge, built);
            }
            self.arena.recompute_node_weight(node);
            built = ChildRef::from_node(node);
        }
        built
    }

    #[inline(always)]
    fn child_to_existing_source(child: ChildRef) -> Option<ExistingSource> {
        if let Some(node) = child.as_node() {
            Some(ExistingSource::Node(node))
        } else if let Some(segment) = child.as_segment() {
            Some(ExistingSource::Segment(segment, 0))
        } else {
            None
        }
    }

    #[inline(always)]
    fn update_source_state(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        source: ExistingSource,
        sym_idx: usize,
    ) {
        match source {
            ExistingSource::Node(node_idx) => {
                let slot = node_idx.get();
                let mut counts = self.arena.nodes[slot].symbol_count;
                let mut log_prob_kt = self.arena.nodes[slot].log_prob_kt;
                apply_update_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                self.arena.nodes[slot].symbol_count = counts;
                self.arena.nodes[slot].log_prob_kt = log_prob_kt;
            }
            ExistingSource::Segment(segment_idx, _) => {
                let slot = segment_idx.get();
                let mut counts = self.arena.segments[slot].symbol_count;
                let mut log_prob_kt = self.arena.segments[slot].log_prob_kt;
                apply_update_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                self.arena.segments[slot].symbol_count = counts;
                self.arena.segments[slot].log_prob_kt = log_prob_kt;
            }
            ExistingSource::None => unreachable!("prepared update should never visit None"),
        }
    }

    #[inline(always)]
    fn recompute_source_weight(&mut self, source: ExistingSource) {
        match source {
            ExistingSource::Node(node_idx) => self.arena.recompute_node_weight(node_idx),
            ExistingSource::Segment(segment_idx, _) => self.recompute_segment_head(segment_idx),
            ExistingSource::None => unreachable!("prepared update should never visit None"),
        }
    }

    #[inline(always)]
    fn recompute_segment_head(&mut self, segment_idx: SegmentIndex) {
        let segment = self.arena.segments[segment_idx.get()];
        let head = if segment.tail.is_some() {
            let (alpha, log_alpha, log_one_minus_alpha) = self.segment_constants(segment.len());
            unary_chain_log_weight_precomputed(
                segment.log_prob_kt,
                self.arena.child_ref_weighted(segment.tail),
                alpha,
                log_alpha,
                log_one_minus_alpha,
            )
        } else {
            segment.log_prob_kt
        };
        self.arena.segments[segment_idx.get()].head_log_prob_weighted = head;
    }

    fn attach_missing_after_prepared_path(
        &mut self,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) {
        let Some(last_step) = self.prepared_steps.last().copied() else {
            return;
        };
        let depth = self.prepared_levels;
        match last_step.source {
            ExistingSource::Node(node_idx) => {
                debug_assert!(depth < self.max_depth);
                let path_edge = history_symbol(history, depth) as usize;
                debug_assert!(self.arena.child(node_idx, path_edge).is_none());
                let new_child =
                    self.build_missing_path(depth + 1, history, sym_idx, singleton_log_prob_kt);
                self.arena.set_child(node_idx, path_edge, new_child);
            }
            ExistingSource::Segment(segment_idx, offset) => {
                debug_assert!(depth < self.max_depth);
                debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
                debug_assert!(self.arena.segments[segment_idx.get()].tail.is_none());
                let new_tail =
                    self.build_missing_path(depth + 1, history, sym_idx, singleton_log_prob_kt);
                self.arena.set_segment_tail(segment_idx, new_tail);
            }
            ExistingSource::None => unreachable!("prepared path should never end in None source"),
        }
    }

    fn replace_prepared_child(
        &mut self,
        history: &[Symbol],
        step_index: usize,
        current_start_depth: usize,
        new_child: ChildRef,
    ) {
        if step_index == 0 {
            let root_edge = history_symbol(history, 0) as usize;
            self.arena.set_child(self.root, root_edge, new_child);
            return;
        }

        match self.prepared_steps[step_index - 1].source {
            ExistingSource::Node(node_idx) => {
                let edge = history_symbol(history, current_start_depth - 1) as usize;
                self.arena.set_child(node_idx, edge, new_child);
            }
            ExistingSource::Segment(segment_idx, offset) => {
                debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
                self.arena.set_segment_tail(segment_idx, new_child);
            }
            ExistingSource::None => unreachable!("prepared path should never parent from None"),
        }
    }

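    /// Slow path for a prepared update whose final step diverged inside a
    /// segment: split the segment at the mismatch, graft in a branching node
    /// that holds both the old continuation and a freshly built path, then
    /// refresh the weights of the steps above it bottom-up.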
    fn update_prepared_mismatch(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        let last_index = self.prepared_steps.len() - 1;
        for idx in 0..last_index {
            self.update_source_state(log_int, log_half, self.prepared_steps[idx].source, sym_idx);
        }

        let last_step = self.prepared_steps[last_index];
        let ExistingSource::Segment(segment_idx, offset_u32) = last_step.source else {
            unreachable!("prepared segment mismatch must end at a segment");
        };

        let original = self.arena.segments[segment_idx.get()];
        let offset = offset_u32 as usize;
        let seg_len = original.len() as usize;
        let history_len = history.len();
        let current_start_depth = self.prepared_levels - last_step.span as usize + 1;
        let node_depth = current_start_depth + offset;
        let path_edge = path_edge_at_depth(history, history_len, node_depth);
        let existing_edge = segment_edge_from_parts(original, offset, history, history_len);
        debug_assert_ne!(path_edge, existing_edge);

        let old_continuation = if offset + 1 < seg_len {
            if offset == 0 {
                let segment = &mut self.arena.segments[segment_idx.get()];
                segment.payload = original.payload.suffix_after(1);
                segment.tail = original.tail;
                segment.symbol_count = original.symbol_count;
                segment.log_prob_kt = original.log_prob_kt;
                self.recompute_segment_head(segment_idx);
                ChildRef::from_segment(segment_idx)
            } else {
                ChildRef::from_segment(self.arena.alloc_segment_with_parts(
                    original.symbol_count,
                    original.log_prob_kt,
                    original.tail,
                    original.payload.suffix_after(offset as u32 + 1),
                ))
            }
        } else {
            original.tail
        };

        let new_tail =
            self.build_missing_path(node_depth + 1, history, sym_idx, singleton_log_prob_kt);
        let mut updated_counts = original.symbol_count;
        let mut updated_log_prob_kt = original.log_prob_kt;
        apply_update_to_state_raw(
            log_int,
            log_half,
            &mut updated_counts,
            &mut updated_log_prob_kt,
            sym_idx,
        );

        let branch = self
            .arena
            .alloc_node_with_state(updated_counts, updated_log_prob_kt);
        self.arena
            .set_child(branch, existing_edge as usize, old_continuation);
        self.arena.set_child(branch, path_edge as usize, new_tail);
        self.arena.recompute_node_weight(branch);

        if offset == 0 {
            if seg_len == 1 {
                self.arena.free_segment(segment_idx);
            }
            self.replace_prepared_child(
                history,
                last_index,
                current_start_depth,
                ChildRef::from_node(branch),
            );
        } else {
            let segment = &mut self.arena.segments[segment_idx.get()];
            segment.payload = original.payload.prefix(offset as u32);
            segment.tail = ChildRef::from_node(branch);
            segment.symbol_count = updated_counts;
            segment.log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
        }

        for idx in (0..last_index).rev() {
            self.recompute_source_weight(self.prepared_steps[idx].source);
        }

        let root_edge = history_symbol(history, 0) as usize;
        self.arena.child(self.root, root_edge)
    }

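    /// Fast path for a prepared update that matched the cached walk end-to-end:
    /// attach any missing suffix, then replay the KT updates bottom-up,
    /// rebuilding each weighted probability from the cached sibling weights.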
    fn update_prepared_cached_path(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) {
        debug_assert!(!self.prepared_steps.is_empty());
        debug_assert!(matches!(
            self.prepared_end,
            PreparedEnd::MaxDepth | PreparedEnd::MissingAfterCurrent
        ));

        if self.prepared_end == PreparedEnd::MissingAfterCurrent {
            self.attach_missing_after_prepared_path(history, sym_idx, singleton_log_prob_kt);
        }

        let last_index = self.prepared_steps.len() - 1;
        let mut child_weight = if self.prepared_end == PreparedEnd::MissingAfterCurrent {
            let last_step = self.prepared_steps[last_index];
            match last_step.source {
                ExistingSource::Node(node_idx) => {
                    let depth = self.prepared_levels;
                    let edge = history_symbol(history, depth) as usize;
                    self.arena
                        .child_ref_weighted(self.arena.child(node_idx, edge))
                }
                ExistingSource::Segment(segment_idx, offset) => {
                    debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
                    self.arena
                        .child_ref_weighted(self.arena.segments[segment_idx.get()].tail)
                }
                ExistingSource::None => unreachable!("prepared path should never end in None"),
            }
        } else {
            0.0
        };

        for idx in (0..=last_index).rev() {
            let step = self.prepared_steps[idx];
            match step.source {
                ExistingSource::Node(node_idx) => {
                    let mut counts = step.counts;
                    let mut log_prob_kt = step.kt_log_prob;
                    apply_update_to_state_raw(
                        log_int,
                        log_half,
                        &mut counts,
                        &mut log_prob_kt,
                        sym_idx,
                    );
                    let weighted =
                        if idx == last_index && self.prepared_end == PreparedEnd::MaxDepth {
                            debug_assert_eq!(step.has_sibling, 0);
                            clamp_log_prob(log_prob_kt)
                        } else {
                            update_weighted_log_prob(
                                log_prob_kt,
                                child_weight,
                                step.sibling_weight,
                                false,
                            )
                        };
                    let slot = node_idx.get();
                    self.arena.nodes[slot].symbol_count = counts;
                    self.arena.nodes[slot].log_prob_kt = log_prob_kt;
                    self.arena.nodes[slot].log_prob_weighted = weighted;
                    child_weight = weighted;
                }
                ExistingSource::Segment(segment_idx, offset) => {
                    let mut counts = step.counts;
                    let mut log_prob_kt = step.kt_log_prob;
                    apply_update_to_state_raw(
                        log_int,
                        log_half,
                        &mut counts,
                        &mut log_prob_kt,
                        sym_idx,
                    );
                    let slot = segment_idx.get();
                    let weighted =
                        if idx == last_index && self.prepared_end == PreparedEnd::MaxDepth {
                            debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
                            debug_assert!(self.arena.segments[slot].tail.is_none());
                            clamp_log_prob(log_prob_kt)
                        } else {
                            let (alpha, log_alpha, log_one_minus_alpha) =
                                self.segment_constants(self.arena.segments[slot].len());
                            unary_chain_log_weight_precomputed(
                                log_prob_kt,
                                child_weight,
                                alpha,
                                log_alpha,
                                log_one_minus_alpha,
                            )
                        };
                    self.arena.segments[slot].symbol_count = counts;
                    self.arena.segments[slot].log_prob_kt = log_prob_kt;
                    self.arena.segments[slot].head_log_prob_weighted = weighted;
                    child_weight = weighted;
                }
                ExistingSource::None => unreachable!("prepared update should never visit None"),
            }
        }
    }

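    /// Recursive single-symbol update along the context path: descends through
    /// nodes and segments, splits a segment at the first mismatching edge,
    /// builds missing suffixes on demand, and recomputes weights on the way
    /// back up.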
    fn update_child_fast(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        child: ChildRef,
        depth: usize,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        if depth > self.max_depth {
            return child;
        }
        if child.is_none() {
            return self.build_missing_path(depth, history, sym_idx, singleton_log_prob_kt);
        }

        if let Some(node_idx) = child.as_node() {
            if depth < self.max_depth {
                let path_edge = history_symbol(history, depth) as usize;
                let next = self.arena.child(node_idx, path_edge);
                let updated = self.update_child_fast(
                    log_int,
                    log_half,
                    next,
                    depth + 1,
                    history,
                    sym_idx,
                    singleton_log_prob_kt,
                );
                if updated != next {
                    self.arena.set_child(node_idx, path_edge, updated);
                }
            }
            let mut counts = self.arena.nodes[node_idx.get()].symbol_count;
            let mut log_prob_kt = self.arena.nodes[node_idx.get()].log_prob_kt;
            apply_update_to_state_raw(log_int, log_half, &mut counts, &mut log_prob_kt, sym_idx);
            self.arena.nodes[node_idx.get()].symbol_count = counts;
            self.arena.nodes[node_idx.get()].log_prob_kt = log_prob_kt;
            self.arena.recompute_node_weight(node_idx);
            return ChildRef::from_node(node_idx);
        }

        let segment_idx = child.as_segment().unwrap();
        let original = self.arena.segments[segment_idx.get()];
        let seg_len = original.len() as usize;
        let mut updated_counts = original.symbol_count;
        let mut updated_log_prob_kt = original.log_prob_kt;
        apply_update_to_state_raw(
            log_int,
            log_half,
            &mut updated_counts,
            &mut updated_log_prob_kt,
            sym_idx,
        );

        let depth_budget = self.max_depth.saturating_sub(depth);
        let comparable_len = if original.tail.is_none() {
            seg_len.saturating_sub(1)
        } else {
            seg_len
        }
        .min(depth_budget);
        let mismatch = first_segment_mismatch(original, depth, history, comparable_len).map(
            |(offset, path_edge, existing_edge)| (offset, depth + offset, path_edge, existing_edge),
        );

        if let Some((offset, node_depth, path_edge, existing_edge)) = mismatch {
            let old_continuation = if offset + 1 < seg_len {
                if offset == 0 {
                    let segment = &mut self.arena.segments[segment_idx.get()];
                    segment.payload = original.payload.suffix_after(1);
                    segment.tail = original.tail;
                    segment.symbol_count = original.symbol_count;
                    segment.log_prob_kt = original.log_prob_kt;
                    self.recompute_segment_head(segment_idx);
                    ChildRef::from_segment(segment_idx)
                } else {
                    ChildRef::from_segment(self.arena.alloc_segment_with_parts(
                        original.symbol_count,
                        original.log_prob_kt,
                        original.tail,
                        original.payload.suffix_after(offset as u32 + 1),
                    ))
                }
            } else {
                original.tail
            };

            let new_tail =
                self.build_missing_path(node_depth + 1, history, sym_idx, singleton_log_prob_kt);
            let branch = self
                .arena
                .alloc_node_with_state(updated_counts, updated_log_prob_kt);
            self.arena
                .set_child(branch, existing_edge as usize, old_continuation);
            self.arena.set_child(branch, path_edge as usize, new_tail);
            self.arena.recompute_node_weight(branch);

            if offset == 0 {
                if offset + 1 >= seg_len {
                    self.arena.free_segment(segment_idx);
                }
                return ChildRef::from_node(branch);
            }

            let segment = &mut self.arena.segments[segment_idx.get()];
            segment.payload = original.payload.prefix(offset as u32);
            segment.tail = ChildRef::from_node(branch);
            segment.symbol_count = updated_counts;
            segment.log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        if depth_budget < seg_len {
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        if original.tail.is_none() {
            let new_tail =
                self.build_missing_path(depth + seg_len, history, sym_idx, singleton_log_prob_kt);
            self.arena.segments[segment_idx.get()].tail = new_tail;
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        let tail = original.tail;
        let updated_tail = self.update_child_fast(
            log_int,
            log_half,
            tail,
            depth + seg_len,
            history,
            sym_idx,
            singleton_log_prob_kt,
        );
        if updated_tail != tail {
            self.arena.set_segment_tail(segment_idx, updated_tail);
        }
        self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
        self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
        self.recompute_segment_head(segment_idx);
        ChildRef::from_segment(segment_idx)
    }

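    /// Exact-bits variant of `update_child_fast`, valid only while
    /// `max_depth <= SEG_EXACT_MAX_LEN` so the remaining path fits in a `u64`.
    /// Segment comparisons become bit arithmetic on `path_bits`; segments whose
    /// payload is not exact fall back to the general routine.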
    fn update_child_fast_exact(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        child: ChildRef,
        depth: usize,
        history: &[Symbol],
        path_bits: u64,
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        debug_assert!(self.max_depth <= SEG_EXACT_MAX_LEN as usize);
        if depth > self.max_depth {
            return child;
        }
        if child.is_none() {
            return self.build_missing_path_exact_bits(
                depth,
                path_bits,
                sym_idx,
                singleton_log_prob_kt,
            );
        }

        if let Some(node_idx) = child.as_node() {
            if depth < self.max_depth {
                let path_edge = (path_bits & 1) as usize;
                let next = self.arena.child(node_idx, path_edge);
                let updated = self.update_child_fast_exact(
                    log_int,
                    log_half,
                    next,
                    depth + 1,
                    history,
                    shift_path_bits(path_bits, 1),
                    sym_idx,
                    singleton_log_prob_kt,
                );
                if updated != next {
                    self.arena.set_child(node_idx, path_edge, updated);
                }
            }
            let mut counts = self.arena.nodes[node_idx.get()].symbol_count;
            let mut log_prob_kt = self.arena.nodes[node_idx.get()].log_prob_kt;
            apply_update_to_state_raw(log_int, log_half, &mut counts, &mut log_prob_kt, sym_idx);
            self.arena.nodes[node_idx.get()].symbol_count = counts;
            self.arena.nodes[node_idx.get()].log_prob_kt = log_prob_kt;
            self.arena.recompute_node_weight(node_idx);
            return ChildRef::from_node(node_idx);
        }

        let segment_idx = child.as_segment().unwrap();
        let original = self.arena.segments[segment_idx.get()];
        if !original.payload.is_exact() {
            return self.update_child_fast(
                log_int,
                log_half,
                child,
                depth,
                history,
                sym_idx,
                singleton_log_prob_kt,
            );
        }

        let seg_len = original.len() as usize;
        let mut updated_counts = original.symbol_count;
        let mut updated_log_prob_kt = original.log_prob_kt;
        apply_update_to_state_raw(
            log_int,
            log_half,
            &mut updated_counts,
            &mut updated_log_prob_kt,
            sym_idx,
        );

        let depth_budget = self.max_depth.saturating_sub(depth);
        let comparable_len = if original.tail.is_none() {
            seg_len.saturating_sub(1)
        } else {
            seg_len
        }
        .min(depth_budget);
        let mismatch =
            first_exact_segment_mismatch(original.payload.exact_bits(), path_bits, comparable_len)
                .map(|(offset, path_edge, existing_edge)| {
                    (offset, depth + offset, path_edge, existing_edge)
                });

        if let Some((offset, node_depth, path_edge, existing_edge)) = mismatch {
            let old_continuation = if offset + 1 < seg_len {
                if offset == 0 {
                    let segment = &mut self.arena.segments[segment_idx.get()];
                    segment.payload = original.payload.suffix_after(1);
                    segment.tail = original.tail;
                    segment.symbol_count = original.symbol_count;
                    segment.log_prob_kt = original.log_prob_kt;
                    self.recompute_segment_head(segment_idx);
                    ChildRef::from_segment(segment_idx)
                } else {
                    ChildRef::from_segment(self.arena.alloc_segment_with_parts(
                        original.symbol_count,
                        original.log_prob_kt,
                        original.tail,
                        original.payload.suffix_after(offset as u32 + 1),
                    ))
                }
            } else {
                original.tail
            };

            let new_tail = self.build_missing_path_exact_bits(
                node_depth + 1,
                shift_path_bits(path_bits, offset + 1),
                sym_idx,
                singleton_log_prob_kt,
            );
            let branch = self
                .arena
                .alloc_node_with_state(updated_counts, updated_log_prob_kt);
            self.arena
                .set_child(branch, existing_edge as usize, old_continuation);
            self.arena.set_child(branch, path_edge as usize, new_tail);
            self.arena.recompute_node_weight(branch);

            if offset == 0 {
                if offset + 1 >= seg_len {
                    self.arena.free_segment(segment_idx);
                }
                return ChildRef::from_node(branch);
            }

            let segment = &mut self.arena.segments[segment_idx.get()];
            segment.payload = original.payload.prefix(offset as u32);
            segment.tail = ChildRef::from_node(branch);
            segment.symbol_count = updated_counts;
            segment.log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        if depth_budget < seg_len {
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        if original.tail.is_none() {
            let new_tail = self.build_missing_path_exact_bits(
                depth + seg_len,
                shift_path_bits(path_bits, seg_len),
                sym_idx,
                singleton_log_prob_kt,
            );
            self.arena.segments[segment_idx.get()].tail = new_tail;
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        let tail = original.tail;
        let updated_tail = self.update_child_fast_exact(
            log_int,
            log_half,
            tail,
            depth + seg_len,
            history,
            shift_path_bits(path_bits, seg_len),
            sym_idx,
            singleton_log_prob_kt,
        );
        if updated_tail != tail {
            self.arena.set_segment_tail(segment_idx, updated_tail);
        }
        self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
        self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
        self.recompute_segment_head(segment_idx);
        ChildRef::from_segment(segment_idx)
    }

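    /// Dispatches the depth-1 update: the exact-bits walk when the whole tree
    /// depth fits in `SEG_EXACT_MAX_LEN` packed path bits, the general
    /// history-slice walk otherwise.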
    #[inline(always)]
    fn update_root_child(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        child: ChildRef,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        if self.max_depth <= SEG_EXACT_MAX_LEN as usize {
            let path_bits = path_bits_from_history(history, 1, self.max_depth);
            self.update_child_fast_exact(
                log_int,
                log_half,
                child,
                1,
                history,
                path_bits,
                sym_idx,
                singleton_log_prob_kt,
            )
        } else {
            self.update_child_fast(
                log_int,
                log_half,
                child,
                1,
                history,
                sym_idx,
                singleton_log_prob_kt,
            )
        }
    }

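    /// Snapshot pass used by `revert`: walks the current context path, records
    /// each depth's counts and KT log-probability into `self.levels`, and
    /// queues sibling detaches so `rebuild_path_subtree` can reattach them.
    /// Returns the old root child so the caller can free it afterwards.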
    fn collect_existing_levels(&mut self, history: &[Symbol]) -> ChildRef {
        if self.max_depth == 0 {
            self.detaches.clear();
            return ChildRef::NONE;
        }

        self.detaches.clear();
        self.levels.fill(LevelState::default());

        let root_edge = history_symbol(history, 0) as usize;
        let old_child = self.arena.child(self.root, root_edge);
        let mut source = if let Some(node) = old_child.as_node() {
            ExistingSource::Node(node)
        } else if let Some(segment) = old_child.as_segment() {
            ExistingSource::Segment(segment, 0)
        } else {
            ExistingSource::None
        };

        for depth in 1..=self.max_depth {
            let slot = depth - 1;
            self.levels[slot] = LevelState::default();

            match source {
                ExistingSource::None => {}
                ExistingSource::Node(node_idx) => {
                    self.levels[slot].symbol_count = self.arena.counts(node_idx);
                    self.levels[slot].log_prob_kt = self.arena.log_prob_kt(node_idx);
                    if depth < self.max_depth {
                        let path_edge = history_symbol(history, depth) as usize;
                        let sibling_edge = path_edge ^ 1;
                        let sibling = self.arena.child(node_idx, sibling_edge);
                        self.levels[slot].sibling = sibling;
                        if sibling.is_some() {
                            self.detaches.push(Detach::NodeChild {
                                node: node_idx,
                                edge: sibling_edge,
                            });
                        }
                        let next = self.arena.child(node_idx, path_edge);
                        source = if let Some(next_node) = next.as_node() {
                            ExistingSource::Node(next_node)
                        } else if let Some(next_segment) = next.as_segment() {
                            ExistingSource::Segment(next_segment, 0)
                        } else {
                            ExistingSource::None
                        };
                    }
                }
                ExistingSource::Segment(segment_idx, offset) => {
                    self.levels[slot].symbol_count = self.arena.segment_symbol_count(segment_idx);
                    self.levels[slot].log_prob_kt = self.arena.segment_log_prob_kt(segment_idx);
                    if depth < self.max_depth {
                        let path_edge = history_symbol(history, depth) as usize;
                        if self.arena.segment_has_child(segment_idx, offset) {
                            let existing_edge =
                                self.arena.segment_edge(segment_idx, offset, history);
                            if path_edge == existing_edge {
                                let seg_len = self.arena.segment_len(segment_idx);
                                if offset + 1 < seg_len {
                                    source = ExistingSource::Segment(segment_idx, offset + 1);
                                } else {
                                    let tail = self.arena.segments[segment_idx.get()].tail;
                                    source = if let Some(next_node) = tail.as_node() {
                                        ExistingSource::Node(next_node)
                                    } else if let Some(next_segment) = tail.as_segment() {
                                        ExistingSource::Segment(next_segment, 0)
                                    } else {
                                        ExistingSource::None
                                    };
                                }
                            } else {
                                let continuation = self.arena.detach_segment_continuation(
                                    segment_idx,
                                    offset,
                                    &mut self.detaches,
                                );
                                self.levels[slot].sibling = continuation;
                                source = ExistingSource::None;
                            }
                        } else {
                            source = ExistingSource::None;
                        }
                    }
                }
            }
        }

        old_child
    }

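    /// Rebuilds the reverted context path bottom-up from `self.levels`. Levels
    /// whose visit count dropped to zero vanish; a level keeps a real node when
    /// it lies inside the hot prefix or still has both a path child and a
    /// sibling, and otherwise collapses into (or prepends onto) a segment.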
    fn rebuild_path_subtree(&mut self, history: &[Symbol]) -> ChildRef {
        let mut built = ChildRef::NONE;

        for depth in (1..=self.max_depth).rev() {
            let level = self.levels[depth - 1];
            let visits = level.symbol_count[0] + level.symbol_count[1];
            if visits == 0 {
                built = ChildRef::NONE;
                continue;
            }

            let path_edge = if depth < self.max_depth {
                history_symbol(history, depth) as usize
            } else {
                0
            };
            let has_path_child = built.is_some();
            let has_sibling = level.sibling.is_some();
            let force_node = depth <= self.hot_prefix_depth();

            if force_node || (has_path_child && has_sibling) {
                let node = self
                    .arena
                    .alloc_node_with_state(level.symbol_count, level.log_prob_kt);
                if has_path_child {
                    self.arena.set_child(node, path_edge, built);
                }
                if has_sibling {
                    self.arena.set_child(node, path_edge ^ 1, level.sibling);
                }
                self.arena.recompute_node_weight(node);
                built = ChildRef::from_node(node);
            } else {
                let (edge, child) = if has_path_child {
                    (path_edge, built)
                } else if has_sibling {
                    (path_edge ^ 1, level.sibling)
                } else {
                    (path_edge, ChildRef::NONE)
                };
                built = self.arena.prepend_or_alloc_segment(
                    history,
                    depth,
                    level.symbol_count,
                    level.log_prob_kt,
                    child,
                    edge,
                    false,
                );
            }
        }

        built
    }

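    /// Applies the detaches queued by `collect_existing_levels`, severing the
    /// old path structure before the rebuilt subtree is spliced back in.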
    fn apply_detaches(&mut self) {
        for detach in self.detaches.drain(..) {
            match detach {
                Detach::NodeChild { node, edge } => {
                    self.arena.set_child(node, edge, ChildRef::NONE);
                }
                Detach::SegmentNext { segment, new_len } => {
                    self.arena.segments[segment.get()].set_len(new_len);
                    self.arena.set_segment_tail(segment, ChildRef::NONE);
                }
            }
        }
    }

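    /// Core single-symbol update with caller-provided log caches, so that
    /// byte-level callers can amortize one cache fill across several trees.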
    fn update_with_logs(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        sym: Symbol,
        history: &[Symbol],
    ) {
        let sym_idx = sym as usize;
        let singleton_log_prob_kt = log_half[0] - log_int[1];
        {
            let slot = self.root.get();
            let mut counts = self.arena.nodes[slot].symbol_count;
            let mut log_prob_kt = self.arena.nodes[slot].log_prob_kt;
            apply_update_to_state_raw(log_int, log_half, &mut counts, &mut log_prob_kt, sym_idx);
            self.arena.nodes[slot].symbol_count = counts;
            self.arena.nodes[slot].log_prob_kt = log_prob_kt;
        }

        if self.max_depth > 0 {
            let root_edge = history_symbol(history, 0) as usize;
            let old_child = self.arena.child(self.root, root_edge);
            let new_child = self.update_root_child(
                log_int,
                log_half,
                old_child,
                history,
                sym_idx,
                singleton_log_prob_kt,
            );
            self.arena.set_child(self.root, root_edge, new_child);
        }

        self.arena.recompute_node_weight(self.root);
    }

    fn update(&mut self, sym: Symbol, history: &[Symbol]) {
        let upto = self.root_visits() + 1;
        self.with_logs(upto, |this, log_int, log_half| {
            this.update_with_logs(log_int, log_half, sym, history);
        });
    }

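    /// Update that can replay the walk cached by the preceding `predict` call
    /// (`use_prepared`), skipping the path re-walk when the history is
    /// unchanged; otherwise it behaves like `update`.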
    fn update_prepared(&mut self, sym: Symbol, history: &[Symbol], use_prepared: bool) {
        let upto = self.root_visits() + 1;
        let sym_idx = sym as usize;
        self.with_logs(upto, |this, log_int, log_half| {
            let singleton_log_prob_kt = log_half[0] - log_int[1];
            {
                let slot = this.root.get();
                let mut counts = this.arena.nodes[slot].symbol_count;
                let mut log_prob_kt = this.arena.nodes[slot].log_prob_kt;
                apply_update_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                this.arena.nodes[slot].symbol_count = counts;
                this.arena.nodes[slot].log_prob_kt = log_prob_kt;
            }

            if this.max_depth > 0 {
                let root_edge = history_symbol(history, 0) as usize;
                let old_child = this.arena.child(this.root, root_edge);
                let new_child = if use_prepared {
                    match this.prepared_end {
                        PreparedEnd::MissingAtRoot => {
                            this.build_missing_path(1, history, sym_idx, singleton_log_prob_kt)
                        }
                        PreparedEnd::MaxDepth | PreparedEnd::MissingAfterCurrent => {
                            if !this.prepared_steps.is_empty() {
                                this.update_prepared_cached_path(
                                    log_int,
                                    log_half,
                                    history,
                                    sym_idx,
                                    singleton_log_prob_kt,
                                );
                            }
                            old_child
                        }
                        PreparedEnd::MismatchAtCurrentSegment => this.update_prepared_mismatch(
                            log_int,
                            log_half,
                            history,
                            sym_idx,
                            singleton_log_prob_kt,
                        ),
                    }
                } else {
                    this.update_root_child(
                        log_int,
                        log_half,
                        old_child,
                        history,
                        sym_idx,
                        singleton_log_prob_kt,
                    )
                };
                this.arena.set_child(this.root, root_edge, new_child);
            }

            this.arena.recompute_node_weight(this.root);
        });
    }

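    /// Undoes the most recent `update(sym, history)`: rolls back counts and KT
    /// state along the context path, then rebuilds the path subtree without
    /// the reverted visit and frees structure that is no longer reachable.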
    fn revert(&mut self, sym: Symbol, history: &[Symbol]) {
        let upto = self.root_visits();
        let sym_idx = sym as usize;
        self.with_logs(upto, |this, log_int, log_half| {
            let old_child = this.collect_existing_levels(history);

            {
                let slot = this.root.get();
                let mut counts = this.arena.nodes[slot].symbol_count;
                let mut log_prob_kt = this.arena.nodes[slot].log_prob_kt;
                apply_revert_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                this.arena.nodes[slot].symbol_count = counts;
                this.arena.nodes[slot].log_prob_kt = log_prob_kt;
            }

            for level in &mut this.levels {
                let mut counts = level.symbol_count;
                let mut log_prob_kt = level.log_prob_kt;
                apply_revert_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                level.symbol_count = counts;
                level.log_prob_kt = log_prob_kt;
            }

            if this.max_depth > 0 {
                let new_child = this.rebuild_path_subtree(history);
                let root_edge = history_symbol(history, 0) as usize;
                this.apply_detaches();
                this.arena.free_child_ref(old_child);
                this.arena.set_child(this.root, root_edge, new_child);
            }

            this.arena.recompute_node_weight(this.root);
        });
    }

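    /// Returns P(`sym` | `history`) under the CTW mixture. As a side effect the
    /// walk is cached in `prepared_steps` / `prepared_end` so a following
    /// `update_prepared` on the same history can reuse it.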
    fn predict(&mut self, sym: Symbol, history: &[Symbol]) -> f64 {
        self.prepared_steps.clear();
        self.prepared_levels = 0;
        self.prepared_end = PreparedEnd::MaxDepth;

        let (root_sibling, root_has_sibling, mut source) = if self.max_depth > 0 {
            let root_edge = history_symbol(history, 0) as usize;
            let path_child = self.arena.child(self.root, root_edge);
            let sibling = self.arena.child(self.root, root_edge ^ 1);
            (
                self.arena.child_ref_weighted(sibling),
                sibling.is_some() as u8,
                Self::child_to_existing_source(path_child).unwrap_or(ExistingSource::None),
            )
        } else {
            (0.0, 0, ExistingSource::None)
        };
        if self.max_depth > 0 && matches!(source, ExistingSource::None) {
            self.prepared_end = PreparedEnd::MissingAtRoot;
        }

        let history_len = history.len();
        let mut depth = 1usize;
        'walk: while depth <= self.max_depth {
            match source {
                ExistingSource::None => break,
                ExistingSource::Node(node_idx) => {
                    let slot = node_idx.get();
                    let counts = self.arena.nodes[slot].symbol_count;
                    let kt_log_prob = self.arena.nodes[slot].log_prob_kt;
                    if depth == self.max_depth {
                        self.prepared_steps.push(PreparedStep {
                            source: ExistingSource::Node(node_idx),
                            counts,
                            kt_log_prob,
                            span: 1,
                            sibling_weight: 0.0,
                            has_sibling: 0,
                        });
                        self.prepared_levels += 1;
                        break;
                    }
                    let path_edge = history_symbol(history, depth) as usize;
                    let sibling = self.arena.child(node_idx, path_edge ^ 1);
                    self.prepared_steps.push(PreparedStep {
                        source: ExistingSource::Node(node_idx),
                        counts,
                        kt_log_prob,
                        span: 1,
                        sibling_weight: self.arena.child_ref_weighted(sibling),
                        has_sibling: sibling.is_some() as u8,
                    });
                    self.prepared_levels += 1;
                    let next = self.arena.child(node_idx, path_edge);
                    source = Self::child_to_existing_source(next).unwrap_or(ExistingSource::None);
                    if matches!(source, ExistingSource::None) {
                        self.prepared_end = PreparedEnd::MissingAfterCurrent;
                        break;
                    }
                    depth += 1;
                }
                ExistingSource::Segment(segment_idx, _) => {
                    let segment = self.arena.segments[segment_idx.get()];
                    let seg_len = segment.len() as usize;
                    let counts = segment.symbol_count;
                    let kt_log_prob = segment.log_prob_kt;
                    for offset in 0..seg_len {
                        let node_depth = depth + offset;
                        let span = (offset + 1) as u32;

                        if node_depth == self.max_depth {
                            self.prepared_steps.push(PreparedStep {
                                source: ExistingSource::Segment(segment_idx, offset as u32),
                                counts,
                                kt_log_prob,
                                span,
                                sibling_weight: 0.0,
                                has_sibling: 0,
                            });
                            self.prepared_levels += span as usize;
                            break 'walk;
                        }

                        if offset + 1 >= seg_len && segment.tail.is_none() {
                            self.prepared_steps.push(PreparedStep {
                                source: ExistingSource::Segment(segment_idx, offset as u32),
                                counts,
                                kt_log_prob,
                                span,
                                sibling_weight: 0.0,
                                has_sibling: 0,
                            });
                            self.prepared_levels += span as usize;
                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
                            break 'walk;
                        }

                        let path_edge = path_edge_at_depth(history, history_len, node_depth);
                        let existing_edge =
                            segment_edge_from_parts(segment, offset, history, history_len);
                        if path_edge != existing_edge {
                            self.prepared_steps.push(PreparedStep {
                                source: ExistingSource::Segment(segment_idx, offset as u32),
                                counts,
                                kt_log_prob,
                                span,
                                sibling_weight: self
                                    .arena
                                    .segment_continuation_weight(segment_idx, offset as u32),
                                has_sibling: 1,
                            });
                            self.prepared_levels += span as usize;
                            self.prepared_end = PreparedEnd::MismatchAtCurrentSegment;
                            break 'walk;
                        }

                        if offset + 1 < seg_len {
                            continue;
                        }

                        self.prepared_steps.push(PreparedStep {
                            source: ExistingSource::Segment(segment_idx, offset as u32),
                            counts,
                            kt_log_prob,
                            span,
                            sibling_weight: 0.0,
                            has_sibling: 0,
                        });
                        self.prepared_levels += span as usize;
                        let tail = segment.tail;
                        source =
                            Self::child_to_existing_source(tail).unwrap_or(ExistingSource::None);
                        if matches!(source, ExistingSource::None) {
                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
                            break 'walk;
                        }
                        depth = node_depth + 1;
                        continue 'walk;
                    }
                }
            }
        }

        let sym_idx = sym as usize;
        if self.prepared_levels == 0 {
            let counts = self.arena.counts(self.root);
            let kt_log_prob = self.arena.log_prob_kt(self.root);
            return if self.prepared_end == PreparedEnd::MaxDepth || root_has_sibling == 0 {
                predict_ratio_kt(counts, sym_idx)
            } else {
                predict_ratio_internal(kt_log_prob, counts, 0.0, root_sibling, 0.5, sym_idx)
            };
        }

        let last_step = *self.prepared_steps.last().unwrap();
        let last_counts = last_step.counts;
        let last_kt_log_prob = last_step.kt_log_prob;
        let (mut child_weight, mut ratio) = if self.prepared_end == PreparedEnd::MaxDepth
            && self.prepared_levels == self.max_depth
        {
            (last_kt_log_prob, predict_ratio_kt(last_counts, sym_idx))
        } else if last_step.has_sibling == 0 {
            (last_kt_log_prob, predict_ratio_kt(last_counts, sym_idx))
        } else {
            combined_weight_ratio_internal(
                last_kt_log_prob,
                last_counts,
                0.0,
                last_step.sibling_weight,
                0.5,
                sym_idx,
            )
        };

        if let ExistingSource::Segment(_, _) = last_step.source {
            if last_step.span > 1 {
                let (alpha, log_alpha, log_one_minus_alpha) =
                    self.segment_constants(last_step.span - 1);
                (child_weight, ratio) = unary_chain_ratio_transform_precomputed(
                    last_step.kt_log_prob,
                    last_step.counts,
                    child_weight,
                    ratio,
                    alpha,
                    log_alpha,
                    log_one_minus_alpha,
                    sym_idx,
                );
            }
        }

        for idx in (0..self.prepared_steps.len() - 1).rev() {
            let step = self.prepared_steps[idx];
            match step.source {
                ExistingSource::Node(_) => {
                    (child_weight, ratio) = combined_weight_ratio_internal(
                        step.kt_log_prob,
                        step.counts,
                        child_weight,
                        step.sibling_weight,
                        ratio,
                        sym_idx,
                    );
                }
                ExistingSource::Segment(_, _) => {
                    let (alpha, log_alpha, log_one_minus_alpha) = self.segment_constants(step.span);
                    (child_weight, ratio) = unary_chain_ratio_transform_precomputed(
                        step.kt_log_prob,
                        step.counts,
                        child_weight,
                        ratio,
                        alpha,
                        log_alpha,
                        log_one_minus_alpha,
                        sym_idx,
                    );
                }
                ExistingSource::None => unreachable!("prepared step should never store None"),
            }
        }

        let root_counts = self.arena.counts(self.root);
        let root_kt_log_prob = self.arena.log_prob_kt(self.root);
        predict_ratio_internal(
            root_kt_log_prob,
            root_counts,
            child_weight,
            root_sibling,
            ratio,
            sym_idx,
        )
    }

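    /// Same walk as `predict`, specialized to `sym == true` via the `_one`
    /// ratio helpers; it caches the prepared path identically.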
    fn predict_one(&mut self, history: &[Symbol]) -> f64 {
        self.prepared_steps.clear();
        self.prepared_levels = 0;
        self.prepared_end = PreparedEnd::MaxDepth;

        let (root_sibling, root_has_sibling, mut source) = if self.max_depth > 0 {
            let root_edge = history_symbol(history, 0) as usize;
            let path_child = self.arena.child(self.root, root_edge);
            let sibling = self.arena.child(self.root, root_edge ^ 1);
            (
                self.arena.child_ref_weighted(sibling),
                sibling.is_some() as u8,
                Self::child_to_existing_source(path_child).unwrap_or(ExistingSource::None),
            )
        } else {
            (0.0, 0, ExistingSource::None)
        };
        if self.max_depth > 0 && matches!(source, ExistingSource::None) {
            self.prepared_end = PreparedEnd::MissingAtRoot;
        }

        let history_len = history.len();
        let mut depth = 1usize;
        'walk: while depth <= self.max_depth {
            match source {
                ExistingSource::None => break,
                ExistingSource::Node(node_idx) => {
                    let slot = node_idx.get();
                    let counts = self.arena.nodes[slot].symbol_count;
                    let kt_log_prob = self.arena.nodes[slot].log_prob_kt;
                    if depth == self.max_depth {
                        self.prepared_steps.push(PreparedStep {
                            source: ExistingSource::Node(node_idx),
                            counts,
                            kt_log_prob,
                            span: 1,
                            sibling_weight: 0.0,
                            has_sibling: 0,
                        });
                        self.prepared_levels += 1;
                        break;
                    }
                    let path_edge = history_symbol(history, depth) as usize;
                    let sibling = self.arena.child(node_idx, path_edge ^ 1);
                    self.prepared_steps.push(PreparedStep {
                        source: ExistingSource::Node(node_idx),
                        counts,
                        kt_log_prob,
                        span: 1,
                        sibling_weight: self.arena.child_ref_weighted(sibling),
                        has_sibling: sibling.is_some() as u8,
                    });
                    self.prepared_levels += 1;
                    let next = self.arena.child(node_idx, path_edge);
                    source = Self::child_to_existing_source(next).unwrap_or(ExistingSource::None);
                    if matches!(source, ExistingSource::None) {
                        self.prepared_end = PreparedEnd::MissingAfterCurrent;
                        break;
                    }
                    depth += 1;
                }
                ExistingSource::Segment(segment_idx, _) => {
                    let segment = self.arena.segments[segment_idx.get()];
                    let seg_len = segment.len() as usize;
                    let counts = segment.symbol_count;
                    let kt_log_prob = segment.log_prob_kt;
                    for offset in 0..seg_len {
                        let node_depth = depth + offset;
                        let span = (offset + 1) as u32;

                        if node_depth == self.max_depth {
                            self.prepared_steps.push(PreparedStep {
                                source: ExistingSource::Segment(segment_idx, offset as u32),
                                counts,
                                kt_log_prob,
                                span,
                                sibling_weight: 0.0,
                                has_sibling: 0,
                            });
                            self.prepared_levels += span as usize;
                            break 'walk;
                        }

                        if offset + 1 >= seg_len && segment.tail.is_none() {
                            self.prepared_steps.push(PreparedStep {
                                source: ExistingSource::Segment(segment_idx, offset as u32),
                                counts,
                                kt_log_prob,
                                span,
                                sibling_weight: 0.0,
                                has_sibling: 0,
                            });
                            self.prepared_levels += span as usize;
                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
                            break 'walk;
                        }

                        let path_edge = path_edge_at_depth(history, history_len, node_depth);
                        let existing_edge =
                            segment_edge_from_parts(segment, offset, history, history_len);
                        if path_edge != existing_edge {
                            self.prepared_steps.push(PreparedStep {
                                source: ExistingSource::Segment(segment_idx, offset as u32),
                                counts,
                                kt_log_prob,
                                span,
                                sibling_weight: self
                                    .arena
                                    .segment_continuation_weight(segment_idx, offset as u32),
                                has_sibling: 1,
                            });
                            self.prepared_levels += span as usize;
                            self.prepared_end = PreparedEnd::MismatchAtCurrentSegment;
                            break 'walk;
                        }

                        if offset + 1 < seg_len {
                            continue;
                        }

                        self.prepared_steps.push(PreparedStep {
                            source: ExistingSource::Segment(segment_idx, offset as u32),
                            counts,
                            kt_log_prob,
                            span,
                            sibling_weight: 0.0,
                            has_sibling: 0,
                        });
                        self.prepared_levels += span as usize;
                        let tail = segment.tail;
                        source =
                            Self::child_to_existing_source(tail).unwrap_or(ExistingSource::None);
                        if matches!(source, ExistingSource::None) {
                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
                            break 'walk;
                        }
                        depth = node_depth + 1;
                        continue 'walk;
                    }
                }
            }
        }

        if self.prepared_levels == 0 {
            let counts = self.arena.counts(self.root);
            let kt_log_prob = self.arena.log_prob_kt(self.root);
            return if self.prepared_end == PreparedEnd::MaxDepth || root_has_sibling == 0 {
                predict_ratio_kt_one(counts)
            } else {
                predict_ratio_internal_one(kt_log_prob, counts, 0.0, root_sibling, 0.5)
            };
        }

        let last_step = *self.prepared_steps.last().unwrap();
        let last_counts = last_step.counts;
        let last_kt_log_prob = last_step.kt_log_prob;
        let (mut child_weight, mut ratio) = if self.prepared_end == PreparedEnd::MaxDepth
            && self.prepared_levels == self.max_depth
        {
            (last_kt_log_prob, predict_ratio_kt_one(last_counts))
        } else if last_step.has_sibling == 0 {
            (last_kt_log_prob, predict_ratio_kt_one(last_counts))
        } else {
            combined_weight_ratio_internal_one(
                last_kt_log_prob,
                last_counts,
                0.0,
                last_step.sibling_weight,
                0.5,
            )
        };

        if let ExistingSource::Segment(_, _) = last_step.source
            && last_step.span > 1
        {
            let (alpha, log_alpha, log_one_minus_alpha) =
                self.segment_constants(last_step.span - 1);
            (child_weight, ratio) = unary_chain_ratio_transform_precomputed_one(
                last_step.kt_log_prob,
                last_step.counts,
                child_weight,
                ratio,
                alpha,
                log_alpha,
                log_one_minus_alpha,
            );
        }

        for idx in (0..self.prepared_steps.len() - 1).rev() {
            let step = self.prepared_steps[idx];
            match step.source {
                ExistingSource::Node(_) => {
                    (child_weight, ratio) = combined_weight_ratio_internal_one(
                        step.kt_log_prob,
                        step.counts,
                        child_weight,
                        step.sibling_weight,
                        ratio,
                    );
                }
                ExistingSource::Segment(_, _) => {
                    let (alpha, log_alpha, log_one_minus_alpha) = self.segment_constants(step.span);
                    (child_weight, ratio) = unary_chain_ratio_transform_precomputed_one(
                        step.kt_log_prob,
                        step.counts,
                        child_weight,
                        ratio,
                        alpha,
                        log_alpha,
                        log_one_minus_alpha,
                    );
                }
                ExistingSource::None => unreachable!("prepared step should never store None"),
            }
        }

        let root_counts = self.arena.counts(self.root);
        let root_kt_log_prob = self.arena.log_prob_kt(self.root);
        predict_ratio_internal_one(
            root_kt_log_prob,
            root_counts,
            child_weight,
            root_sibling,
            ratio,
        )
    }

    fn memory_usage(&self) -> usize {
        self.arena.memory_usage()
            + self.segment_alpha.capacity() * size_of::<f64>()
            + self.segment_log_alpha.capacity() * size_of::<f64>()
            + self.segment_log_one_minus_alpha.capacity() * size_of::<f64>()
            + self.levels.capacity() * size_of::<LevelState>()
            + self.detaches.capacity() * size_of::<Detach>()
            + self.prepared_steps.capacity() * size_of::<PreparedStep>()
    }
}

#[derive(Clone)]
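/// Convenience wrapper that owns both a `CtEngine` and the symbol history it
/// conditions on.
///
/// A minimal usage sketch (all calls shown exist below):
///
/// ```ignore
/// let mut tree = ContextTree::new(4);
/// let p = tree.predict(true); // P(next symbol == 1 | history)
/// tree.update(true);
/// let log_p = tree.get_log_block_probability();
/// tree.revert(); // undo the last update
/// ```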
pub struct ContextTree {
    engine: CtEngine,
    history: Vec<Symbol>,
}

impl ContextTree {
    pub fn new(depth: usize) -> Self {
        Self {
            engine: CtEngine::new(depth),
            history: Vec::new(),
        }
    }

    pub fn clear(&mut self) {
        self.history.clear();
        self.engine.clear();
    }

    #[inline]
    pub fn update(&mut self, sym: Symbol) {
        self.engine.update(sym, &self.history);
        self.history.push(sym);
    }

    #[inline]
    pub fn revert(&mut self) {
        let Some(last_sym) = self.history.pop() else {
            return;
        };
        self.engine.revert(last_sym, &self.history);
    }

    #[inline]
    pub fn update_history(&mut self, symbols: &[Symbol]) {
        self.history.extend_from_slice(symbols);
    }

    #[inline]
    pub fn revert_history(&mut self) {
        self.history.pop();
    }

    pub fn truncate_history(&mut self, new_size: usize) {
        if new_size < self.history.len() {
            self.history.truncate(new_size);
        }
    }

    #[inline]
    pub fn predict(&mut self, sym: Symbol) -> f64 {
        self.engine.predict(sym, &self.history)
    }

    #[inline]
    pub fn predict_sym_prob(&mut self) -> f64 {
        self.predict(true)
    }

    #[inline]
    pub fn get_log_block_probability(&self) -> f64 {
        self.engine.get_log_block_probability()
    }

    #[inline]
    pub fn depth(&self) -> usize {
        self.engine.max_depth
    }

    #[inline]
    pub fn history_size(&self) -> usize {
        self.history.len()
    }
}

#[derive(Clone)]
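/// Per-bit engine used by `FacContextTree`. It remembers whether the most
/// recent `predict` walk still matches the shared history (by length and
/// version) so that `update_predicted` can safely replay the prepared path.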
struct ContextTreeCore {
    engine: CtEngine,
    prepared_valid: bool,
    prepared_history_len: usize,
    prepared_history_version: u64,
}

impl ContextTreeCore {
    fn new(depth: usize) -> Self {
        Self {
            engine: CtEngine::new(depth),
            prepared_valid: false,
            prepared_history_len: 0,
            prepared_history_version: 0,
        }
    }

    fn clear(&mut self) {
        self.engine.clear();
        self.prepared_valid = false;
        self.prepared_history_len = 0;
        self.prepared_history_version = 0;
    }

    #[inline]
    fn reserve_for_symbols(&mut self, total_symbols: usize) {
        self.engine.reserve_for_symbols(total_symbols);
    }

    #[inline]
    fn update(&mut self, sym: Symbol, shared_history: &[Symbol]) {
        self.prepared_valid = false;
        self.engine.update(sym, shared_history);
    }

    #[inline]
    fn update_predicted(&mut self, sym: Symbol, shared_history: &[Symbol], history_version: u64) {
        let use_prepared = self.prepared_valid
            && self.prepared_history_len == shared_history.len()
            && self.prepared_history_version == history_version;
        self.prepared_valid = false;
        self.engine
            .update_prepared(sym, shared_history, use_prepared);
    }

    #[inline]
    fn revert(&mut self, last_sym: Symbol, shared_history: &[Symbol]) {
        self.prepared_valid = false;
        self.engine.revert(last_sym, shared_history);
    }

    #[inline]
    fn predict(&mut self, sym: Symbol, shared_history: &[Symbol], history_version: u64) -> f64 {
        let prob = self.engine.predict(sym, shared_history);
        self.prepared_valid = true;
        self.prepared_history_len = shared_history.len();
        self.prepared_history_version = history_version;
        prob
    }

    #[inline]
    fn predict_one(&mut self, shared_history: &[Symbol], history_version: u64) -> f64 {
        let prob = self.engine.predict_one(shared_history);
        self.prepared_valid = true;
        self.prepared_history_len = shared_history.len();
        self.prepared_history_version = history_version;
        prob
    }

    #[inline]
    fn get_log_block_probability(&self) -> f64 {
        self.engine.get_log_block_probability()
    }
}

#[derive(Clone)]
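/// Factored context tree: one `ContextTreeCore` per percept bit, all sharing a
/// single history buffer. Tree `i` has depth `base_depth + i`.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let mut fac = FacContextTree::new(4, 8);
/// fac.update_byte_msb(b'a'); // feed one byte, most significant bit first
/// let log_p = fac.get_log_block_probability();
/// ```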
pub struct FacContextTree {
    trees: Vec<ContextTreeCore>,
    shared_history: Vec<Symbol>,
    base_depth: usize,
    num_bits: usize,
    shared_history_version: u64,
}

impl FacContextTree {
    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
        let trees = (0..num_percept_bits)
            .map(|i| ContextTreeCore::new(base_depth + i))
            .collect();
        Self {
            trees,
            shared_history: Vec::new(),
            base_depth,
            num_bits: num_percept_bits,
            shared_history_version: 0,
        }
    }

    #[inline(always)]
    fn bump_shared_history_version(&mut self) {
        self.shared_history_version = self.shared_history_version.wrapping_add(1);
    }

    #[inline]
    pub fn reserve_for_symbols(&mut self, total_symbols: usize) {
        if total_symbols == 0 {
            return;
        }
        self.shared_history
            .reserve_exact(total_symbols.saturating_mul(self.num_bits));
        for tree in &mut self.trees {
            tree.reserve_for_symbols(total_symbols);
        }
    }

    #[inline]
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    #[inline]
    pub fn base_depth(&self) -> usize {
        self.base_depth
    }

    #[inline]
    pub fn update(&mut self, sym: Symbol, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);
        self.trees[bit_index].update(sym, &self.shared_history);
        self.shared_history.push(sym);
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn update_byte_msb(&mut self, byte: u8) {
        if self.num_bits != 8 {
            for bit_idx in 0..self.num_bits {
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                self.update(bit, bit_idx);
            }
            return;
        }

        let upto = self.trees[0].engine.root_visits() + 1;
        debug_assert!(
            self.trees
                .iter()
                .all(|tree| tree.engine.root_visits() + 1 == upto)
        );
        with_shared_log_cache(upto, |log_int, log_half| {
            for bit_idx in 0..8usize {
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                let tree = &mut self.trees[bit_idx];
                tree.prepared_valid = false;
                tree.engine
                    .update_with_logs(log_int, log_half, bit, &self.shared_history);
                self.shared_history.push(bit);
            }
        });
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn update_byte_lsb(&mut self, byte: u8) {
        let bits = self.num_bits.clamp(1, 8);
        let upto = self.trees[0].engine.root_visits() + 1;
        debug_assert!(
            self.trees
                .iter()
                .take(bits)
                .all(|tree| tree.engine.root_visits() + 1 == upto)
        );
        with_shared_log_cache(upto, |log_int, log_half| {
            for bit_idx in 0..bits {
                let bit = ((byte >> bit_idx) & 1) == 1;
                let tree = &mut self.trees[bit_idx];
                tree.prepared_valid = false;
                tree.engine
                    .update_with_logs(log_int, log_half, bit, &self.shared_history);
                self.shared_history.push(bit);
            }
        });
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn update_predicted(&mut self, sym: Symbol, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);
        self.trees[bit_index].update_predicted(
            sym,
            &self.shared_history,
            self.shared_history_version,
        );
        self.shared_history.push(sym);
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn predict(&mut self, sym: Symbol, bit_index: usize) -> f64 {
        debug_assert!(bit_index < self.num_bits);
        self.trees[bit_index].predict(sym, &self.shared_history, self.shared_history_version)
    }

    #[inline]
    pub(crate) fn predict_one(&mut self, bit_index: usize) -> f64 {
        debug_assert!(bit_index < self.num_bits);
        self.trees[bit_index].predict_one(&self.shared_history, self.shared_history_version)
    }

    #[inline]
    pub fn revert(&mut self, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);
        let Some(last_sym) = self.shared_history.pop() else {
            return;
        };
        self.trees[bit_index].revert(last_sym, &self.shared_history);
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn update_history(&mut self, symbols: &[Symbol]) {
        if symbols.is_empty() {
            return;
        }
        self.shared_history.extend_from_slice(symbols);
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn revert_history(&mut self, count: usize) {
        let old_len = self.shared_history.len();
        let new_len = self.shared_history.len().saturating_sub(count);
        if new_len == old_len {
            return;
        }
        self.shared_history.truncate(new_len);
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn reset_history_only(&mut self) {
        if self.shared_history.is_empty() {
            return;
        }
        self.shared_history.clear();
        self.bump_shared_history_version();
    }

    #[inline]
    pub fn get_log_block_probability(&self) -> f64 {
        self.trees
            .iter()
            .map(|t| t.get_log_block_probability())
            .sum()
    }

    pub fn clear(&mut self) {
        for tree in &mut self.trees {
            tree.clear();
        }
        self.shared_history.clear();
        self.shared_history_version = 0;
    }

    pub fn memory_usage(&self) -> usize {
        let tree_mem: usize = self.trees.iter().map(|t| t.engine.memory_usage()).sum();
        let log_cache_mem = self
            .trees
            .first()
            .map(|t| t.engine.log_cache_memory_usage())
            .unwrap_or(0);
        let history_mem = self.shared_history.capacity() * size_of::<Symbol>();
        tree_mem + log_cache_mem + history_mem
    }
}

#[cfg(test)]
mod tests {
    use super::*;

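    // Reference implementations: a plain boxed-pointer CTW (`RefContextTree`)
    // and its factored counterpart, used to cross-check the arena/segment
    // engine against exhaustive short sequences.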
    #[derive(Clone)]
    struct RefNode {
        children: [Option<Box<RefNode>>; 2],
        log_prob_kt: f64,
        log_prob_weighted: f64,
        symbol_count: [u32; 2],
    }

    impl Default for RefNode {
        fn default() -> Self {
            Self {
                children: [None, None],
                log_prob_kt: 0.0,
                log_prob_weighted: 0.0,
                symbol_count: [0, 0],
            }
        }
    }

    #[derive(Clone)]
    struct RefContextTree {
        root: RefNode,
        history: Vec<Symbol>,
        max_depth: usize,
        log_int: Vec<f64>,
        log_half: Vec<f64>,
    }

    impl RefContextTree {
        fn new(depth: usize) -> Self {
            Self {
                root: RefNode::default(),
                history: Vec::new(),
                max_depth: depth,
                log_int: vec![f64::NEG_INFINITY],
                log_half: vec![(0.5f64).ln()],
            }
        }

        fn root_visits(&self) -> usize {
            (self.root.symbol_count[0] + self.root.symbol_count[1]) as usize
        }

        fn recompute(node: &mut RefNode) {
            let w0 = node.children[0]
                .as_ref()
                .map(|c| c.log_prob_weighted)
                .unwrap_or(0.0);
            let w1 = node.children[1]
                .as_ref()
                .map(|c| c.log_prob_weighted)
                .unwrap_or(0.0);
            let is_leaf = node.children[0].is_none() && node.children[1].is_none();
            node.log_prob_weighted = update_weighted_log_prob(node.log_prob_kt, w0, w1, is_leaf);
        }

        fn update(&mut self, sym: Symbol) {
            let upto = self.root_visits() + 1;
            ensure_log_caches(&mut self.log_int, &mut self.log_half, upto);
            let sym_idx = sym as usize;
            Self::update_node(
                &mut self.root,
                0,
                self.max_depth,
                &self.history,
                sym_idx,
                &self.log_int,
                &self.log_half,
            );
            self.history.push(sym);
        }

        fn revert(&mut self) {
            let Some(last_sym) = self.history.pop() else {
                return;
            };
            let upto = self.root_visits();
            ensure_log_caches(&mut self.log_int, &mut self.log_half, upto);
            let sym_idx = last_sym as usize;
            let _ = Self::revert_node(
                &mut self.root,
                0,
                self.max_depth,
                &self.history,
                sym_idx,
                &self.log_int,
                &self.log_half,
            );
        }

        fn predict(&mut self, sym: Symbol) -> f64 {
            let sym_idx = sym as usize;
            let mut entries = Vec::with_capacity(self.max_depth + 1);
            let reached_max_depth = Self::collect_predict_entries(
                &self.root,
                0,
                self.max_depth,
                &self.history,
                &mut entries,
            );

            let deepest = entries.len() - 1;
            let mut ratio = if reached_max_depth && deepest == self.max_depth {
                predict_ratio_kt(entries[deepest].symbol_count, sym_idx)
            } else {
                0.5
            };
            for idx in (0..=deepest).rev() {
                if reached_max_depth && idx == deepest {
                    continue;
                }
                let child_weight = if idx + 1 <= deepest {
                    entries[idx + 1].log_prob_weighted
                } else {
                    0.0
                };
                ratio = predict_ratio_internal(
                    entries[idx].log_prob_kt,
                    entries[idx].symbol_count,
                    child_weight,
                    entries[idx].sibling_weight,
                    ratio,
                    sym_idx,
                );
            }
            ratio
        }

        fn get_log_block_probability(&self) -> f64 {
            self.root.log_prob_weighted
        }

        fn update_node(
            node: &mut RefNode,
            depth: usize,
            max_depth: usize,
            history: &[Symbol],
            sym_idx: usize,
            log_int: &[f64],
            log_half: &[f64],
        ) {
            if depth < max_depth {
                let edge = history_symbol(history, depth) as usize;
                if node.children[edge].is_none() {
                    node.children[edge] = Some(Box::new(RefNode::default()));
                }
                Self::update_node(
                    node.children[edge].as_deref_mut().unwrap(),
                    depth + 1,
                    max_depth,
                    history,
                    sym_idx,
                    log_int,
                    log_half,
                );
            }
            apply_update_to_state_raw(
                log_int,
                log_half,
                &mut node.symbol_count,
                &mut node.log_prob_kt,
                sym_idx,
            );
            Self::recompute(node);
        }

        fn revert_node(
            node: &mut RefNode,
            depth: usize,
            max_depth: usize,
            history: &[Symbol],
            sym_idx: usize,
            log_int: &[f64],
            log_half: &[f64],
        ) -> bool {
            if depth < max_depth {
                let edge = history_symbol(history, depth) as usize;
                let remove_child = if let Some(child) = node.children[edge].as_deref_mut() {
                    Self::revert_node(
                        child,
                        depth + 1,
                        max_depth,
                        history,
                        sym_idx,
                        log_int,
                        log_half,
                    )
                } else {
                    false
                };
                if remove_child {
                    node.children[edge] = None;
                }
            }
            apply_revert_to_state_raw(
                log_int,
                log_half,
                &mut node.symbol_count,
                &mut node.log_prob_kt,
                sym_idx,
            );
            Self::recompute(node);
            node.symbol_count[0] + node.symbol_count[1] == 0
        }

        fn collect_predict_entries(
            node: &RefNode,
            depth: usize,
            max_depth: usize,
            history: &[Symbol],
            entries: &mut Vec<PredictEntry>,
        ) -> bool {
            let sibling_weight = if depth < max_depth {
                let path_edge = history_symbol(history, depth) as usize;
                node.children[path_edge ^ 1]
                    .as_ref()
                    .map(|c| c.log_prob_weighted)
                    .unwrap_or(0.0)
            } else {
                0.0
            };
            entries.push(PredictEntry {
                symbol_count: node.symbol_count,
                log_prob_kt: node.log_prob_kt,
                log_prob_weighted: node.log_prob_weighted,
                sibling_weight,
                has_sibling: depth < max_depth
                    && node.children[(history_symbol(history, depth) as usize) ^ 1].is_some(),
            });
            if depth == max_depth {
                return true;
            }
            let edge = history_symbol(history, depth) as usize;
            let Some(child) = node.children[edge].as_ref() else {
                return false;
            };
            Self::collect_predict_entries(child, depth + 1, max_depth, history, entries)
        }
    }

    #[derive(Clone)]
    struct RefFacContextTree {
        trees: Vec<RefContextTree>,
        history: Vec<Symbol>,
    }

    impl RefFacContextTree {
        fn new(base_depth: usize, num_bits: usize) -> Self {
            Self {
                trees: (0..num_bits)
                    .map(|i| RefContextTree::new(base_depth + i))
                    .collect(),
                history: Vec::new(),
            }
        }

        fn update(&mut self, sym: Symbol, bit_index: usize) {
            let tree = &mut self.trees[bit_index];
            tree.history = self.history.clone();
            tree.update(sym);
            self.history.push(sym);
        }

        fn predict(&mut self, sym: Symbol, bit_index: usize) -> f64 {
            let tree = &mut self.trees[bit_index];
            tree.history = self.history.clone();
            tree.predict(sym)
        }

        fn revert(&mut self, bit_index: usize) {
            let Some(last_sym) = self.history.pop() else {
                return;
            };
            let tree = &mut self.trees[bit_index];
            tree.history = self.history.clone();
            tree.history.push(last_sym);
            tree.revert();
        }

        fn get_log_block_probability(&self) -> f64 {
            self.trees
                .iter()
                .map(RefContextTree::get_log_block_probability)
                .sum()
        }
    }

    fn assert_close(a: f64, b: f64) {
        let diff = (a - b).abs();
        let scale = a.abs().max(b.abs()).max(1.0);
        assert!(diff <= 1e-12 * scale, "a={a} b={b} diff={diff}");
    }

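    // Test helper: follows the hot-prefix nodes along `history_before_update`
    // and returns the child hanging off the last hot-prefix node.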
3795 fn child_after_hot_prefix(tree: &ContextTree, history_before_update: &[Symbol]) -> ChildRef {
3796 let hot_prefix_depth = tree.engine.hot_prefix_depth();
3797 if hot_prefix_depth == 0 {
3798 return ChildRef::NONE;
3799 }
3800
3801 let root_edge = history_symbol(history_before_update, 0) as usize;
3802 let mut current = tree
3803 .engine
3804 .arena
3805 .child(tree.engine.root, root_edge)
3806 .as_node()
3807 .expect("hot-prefix node");
3808 for node_depth in 1..hot_prefix_depth {
3809 let edge = history_symbol(history_before_update, node_depth) as usize;
3810 current = tree
3811 .engine
3812 .arena
3813 .child(current, edge)
3814 .as_node()
3815 .expect("next hot-prefix node");
3816 }
3817 let tail_edge = history_symbol(history_before_update, hot_prefix_depth) as usize;
3818 tree.engine.arena.child(current, tail_edge)
3819 }
3820
3821 #[test]
3822 #[should_panic(expected = "ctw node index overflow")]
3823 fn node_index_from_usize_rejects_overflow() {
3824 let _ = NodeIndex::from_usize(INDEX_LIMIT);
3825 }
3826
3827 #[test]
3828 #[should_panic(expected = "ctw node index overflow")]
3829 fn node_index_from_usize_rejects_large_values() {
3830 let _ = NodeIndex::from_usize(u32::MAX as usize);
3831 }
3832
3833 #[test]
3834 fn ctw_count_lane_stays_packed() {
3835 assert_eq!(std::mem::size_of::<CtNode>(), 32);
3836 }
3837
3838 #[test]
3839 fn ctw_segment_payload_stays_packed() {
3840 assert_eq!(std::mem::size_of::<CtSegment>(), 40);
3841 }
3842
3843 #[test]
3844 fn context_tree_singleton_paths_use_hot_prefix_nodes() {
3845 let mut tree = ContextTree::new(12);
3846 tree.update(false);
3847
3848 let hot_prefix_depth = tree.engine.hot_prefix_depth();
3849 let child = tree.engine.arena.child(tree.engine.root, 0);
3850 let mut current = child.as_node().expect("hot-prefix node");
3851 let mut visited_hot_prefix_nodes = 1usize;
3852 for depth in 1..hot_prefix_depth {
3853 let next = tree.engine.arena.child(current, 0);
3854 current = next.as_node().expect("next hot-prefix node");
3855 visited_hot_prefix_nodes += 1;
3856 assert!(depth < hot_prefix_depth);
3857 }
3858 assert_eq!(visited_hot_prefix_nodes, hot_prefix_depth);
3859 let segment = tree
3860 .engine
3861 .arena
3862 .child(current, 0)
3863 .as_segment()
3864 .expect("segment tail");
3865 assert!(tree.engine.arena.child(current, 1).is_none());
3866 assert!(tree.engine.arena.segments[segment.get()].tail.is_none());
3867 assert_close(
3868 tree.engine.arena.segments[segment.get()].head_log_prob_weighted,
3869 -std::f64::consts::LN_2,
3870 );
3871 assert_close(tree.get_log_block_probability(), -std::f64::consts::LN_2);
3872 }
3873
3874 #[test]
3875 fn context_tree_missing_path_tail_uses_exact_segment_payloads() {
3876 let mut tree = ContextTree::new(12);
3877 tree.update(true);
3878 let child = tree.engine.arena.child(tree.engine.root, 0);
3879 let mut current = child.as_node().expect("hot-prefix node");
3880 for _ in 1..tree.engine.hot_prefix_depth() {
3881 current = tree
3882 .engine
3883 .arena
3884 .child(current, 0)
3885 .as_node()
3886 .expect("next hot-prefix node");
3887 }
3888 let segment = tree
3889 .engine
3890 .arena
3891 .child(current, 0)
3892 .as_segment()
3893 .expect("segment tail");
3894 let payload = tree.engine.arena.segments[segment.get()].payload;
3895 assert!(payload.is_exact());
3896 assert_eq!(
3897 payload.len() as usize,
3898 tree.engine.max_depth - tree.engine.hot_prefix_depth()
3899 );
3900 assert_eq!(payload.exact_bits() & low_bits_mask_u64(payload.len()), 0);
3901 }
3902
3903 #[test]
3904 fn context_tree_missing_path_tail_uses_const_payload_beyond_exact_limit() {
3905 let mut tree = ContextTree::new(80);
3906 let history_before = tree.history.clone();
3907 tree.update(false);
3908
3909 let segment = child_after_hot_prefix(&tree, &history_before)
3910 .as_segment()
3911 .expect("segment tail");
3912 let segment = tree.engine.arena.segments[segment.get()];
3913 assert_eq!(segment.payload.mode(), SEG_MODE_CONST);
3914 assert_eq!(
3915 segment.payload.len() as usize,
3916 tree.engine.max_depth - tree.engine.hot_prefix_depth()
3917 );
3918 assert!(!segment.payload.const_bit());
3919 assert!(segment.tail.is_none());
3920 }
3921
3922 #[test]
3923 fn context_tree_missing_path_tail_uses_history_and_const_payloads_beyond_exact_limit() {
3924 let mut tree = ContextTree::new(80);
3925 let seeded_history: Vec<Symbol> = (0..80).map(|i| (i & 1) == 1).collect();
3926 tree.update_history(&seeded_history);
3927 let history_before = tree.history.clone();
3928 tree.update(false);
3929
3930 let first_segment = child_after_hot_prefix(&tree, &history_before)
3931 .as_segment()
3932 .expect("history-backed segment tail");
3933 let first_segment = tree.engine.arena.segments[first_segment.get()];
3934 assert_eq!(first_segment.payload.mode(), SEG_MODE_HISTORY);
3935 assert_eq!(first_segment.payload.len(), 69);
3936 for offset in [0usize, 1, 7, 31, 68] {
3937 assert_eq!(
3938 segment_edge_from_parts(
3939 first_segment,
3940 offset,
3941 &history_before,
3942 history_before.len()
3943 ),
3944 history_symbol(&history_before, tree.engine.hot_prefix_depth() + 1 + offset)
3945 );
3946 }
3947
3948 let tail_segment = first_segment
3949 .tail
3950 .as_segment()
3951 .expect("constant fallback tail");
3952 let tail_segment = tree.engine.arena.segments[tail_segment.get()];
3953 assert_eq!(tail_segment.payload.mode(), SEG_MODE_CONST);
3954 assert_eq!(tail_segment.payload.len(), 1);
3955 assert!(!tail_segment.payload.const_bit());
3956 assert!(tail_segment.tail.is_none());
3957 }
3958
    #[test]
    fn context_tree_matches_reference_on_short_sequences() {
        for depth in 0..=6usize {
            for len in 0..=6usize {
                for mask in 0..(1usize << len) {
                    let mut prod = ContextTree::new(depth);
                    let mut reference = RefContextTree::new(depth);
                    for step in 0..len {
                        let p_prod_0 = prod.predict(false);
                        let p_ref_0 = reference.predict(false);
                        assert!(
                            (p_prod_0 - p_ref_0).abs()
                                <= 1e-12 * p_prod_0.abs().max(p_ref_0.abs()).max(1.0),
                            "predict0 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_0} ref={p_ref_0} history={:?}",
                            prod.history
                        );
                        let p_prod_1 = prod.predict(true);
                        let p_ref_1 = reference.predict(true);
                        assert!(
                            (p_prod_1 - p_ref_1).abs()
                                <= 1e-12 * p_prod_1.abs().max(p_ref_1.abs()).max(1.0),
                            "predict1 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_1} ref={p_ref_1} history={:?}",
                            prod.history
                        );
                        let log_prod = prod.get_log_block_probability();
                        let log_ref = reference.get_log_block_probability();
                        assert!(
                            (log_prod - log_ref).abs()
                                <= 1e-12 * log_prod.abs().max(log_ref.abs()).max(1.0),
                            "log mismatch before update depth={depth} len={len} mask={mask} step={step} prod={log_prod} ref={log_ref} history={:?}",
                            prod.history
                        );
                        let bit = ((mask >> step) & 1) == 1;
                        prod.update(bit);
                        reference.update(bit);
                        let log_prod = prod.get_log_block_probability();
                        let log_ref = reference.get_log_block_probability();
                        assert!(
                            (log_prod - log_ref).abs()
                                <= 1e-12 * log_prod.abs().max(log_ref.abs()).max(1.0),
                            "log mismatch after update depth={depth} len={len} mask={mask} step={step} bit={bit} prod={log_prod} ref={log_ref} history={:?}",
                            prod.history
                        );
                    }
                    while prod.history_size() > 0 {
                        let p_prod_0 = prod.predict(false);
                        let p_ref_0 = reference.predict(false);
                        assert!(
                            (p_prod_0 - p_ref_0).abs()
                                <= 1e-12 * p_prod_0.abs().max(p_ref_0.abs()).max(1.0),
                            "revert predict0 mismatch depth={depth} len={len} mask={mask} prod={p_prod_0} ref={p_ref_0} history={:?}",
                            prod.history
                        );
                        let p_prod_1 = prod.predict(true);
                        let p_ref_1 = reference.predict(true);
                        assert!(
                            (p_prod_1 - p_ref_1).abs()
                                <= 1e-12 * p_prod_1.abs().max(p_ref_1.abs()).max(1.0),
                            "revert predict1 mismatch depth={depth} len={len} mask={mask} prod={p_prod_1} ref={p_ref_1} history={:?}",
                            prod.history
                        );
                        prod.revert();
                        reference.revert();
                        let log_prod = prod.get_log_block_probability();
                        let log_ref = reference.get_log_block_probability();
                        assert!(
                            (log_prod - log_ref).abs()
                                <= 1e-12 * log_prod.abs().max(log_ref.abs()).max(1.0),
                            "revert log mismatch depth={depth} len={len} mask={mask} prod={log_prod} ref={log_ref} history={:?}",
                            prod.history
                        );
                    }
                }
            }
        }
    }

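    // The same exhaustive differential check at depths 65 and 80, which force the
    // segment-compressed long-depth code paths beyond the exact-segment limit.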
    #[test]
    fn context_tree_long_depth_matches_reference_on_short_sequences() {
        for &depth in &[65usize, 80usize] {
            for len in 0..=6usize {
                for mask in 0..(1usize << len) {
                    let mut prod = ContextTree::new(depth);
                    let mut reference = RefContextTree::new(depth);
                    for step in 0..len {
                        let p_prod_0 = prod.predict(false);
                        let p_ref_0 = reference.predict(false);
                        assert!(
                            (p_prod_0 - p_ref_0).abs()
                                <= 1e-12 * p_prod_0.abs().max(p_ref_0.abs()).max(1.0),
                            "long-depth predict0 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_0} ref={p_ref_0} history={:?}",
                            prod.history
                        );
                        let p_prod_1 = prod.predict(true);
                        let p_ref_1 = reference.predict(true);
                        assert!(
                            (p_prod_1 - p_ref_1).abs()
                                <= 1e-12 * p_prod_1.abs().max(p_ref_1.abs()).max(1.0),
                            "long-depth predict1 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_1} ref={p_ref_1} history={:?}",
                            prod.history
                        );
                        assert_close(
                            prod.get_log_block_probability(),
                            reference.get_log_block_probability(),
                        );
                        let bit = ((mask >> step) & 1) == 1;
                        prod.update(bit);
                        reference.update(bit);
                        assert_close(
                            prod.get_log_block_probability(),
                            reference.get_log_block_probability(),
                        );
                    }

                    while prod.history_size() > 0 {
                        assert_close(prod.predict(false), reference.predict(false));
                        assert_close(prod.predict(true), reference.predict(true));
                        prod.revert();
                        reference.revert();
                        assert_close(
                            prod.get_log_block_probability(),
                            reference.get_log_block_probability(),
                        );
                    }
                }
            }
        }
    }

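    // Differential test for the factored CTW: interleaved (bit, tree index)
    // updates followed by a full revert must match the reference implementation
    // at every step.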
    #[test]
    fn fac_ctw_matches_reference_on_short_sequences() {
        let mut fac = FacContextTree::new(4, 4);
        let mut reference = RefFacContextTree::new(4, 4);
        let stream = [
            (true, 0usize),
            (false, 1usize),
            (true, 2usize),
            (true, 3usize),
            (false, 0usize),
            (false, 1usize),
            (true, 2usize),
            (false, 3usize),
        ];

        for &(bit, idx) in &stream {
            assert_close(fac.predict(false, idx), reference.predict(false, idx));
            assert_close(fac.predict(true, idx), reference.predict(true, idx));
            fac.update(bit, idx);
            reference.update(bit, idx);
            assert_close(
                fac.get_log_block_probability(),
                reference.get_log_block_probability(),
            );
        }

        for &(_, idx) in stream.iter().rev() {
            fac.revert(idx);
            reference.revert(idx);
            assert_close(
                fac.get_log_block_probability(),
                reference.get_log_block_probability(),
            );
        }
    }

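    // The shared history must grow by one symbol per update and shrink by one per
    // revert, regardless of which component tree is addressed.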
    #[test]
    fn fac_ctw_history_consistency() {
        let mut fac = FacContextTree::new(4, 4);

        fac.update_history(&[true, false, true]);
        assert_eq!(fac.shared_history.len(), 3);

        fac.update(true, 0);
        fac.update(false, 1);
        assert_eq!(fac.shared_history.len(), 5);

        fac.revert(1);
        assert_eq!(fac.shared_history.len(), 4);

        fac.revert(0);
        assert_eq!(fac.shared_history.len(), 3);
    }

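    // predict_one is a specialized fast path; it must return exactly what the
    // generic predict(true, _) returns, including when interleaved with
    // update_predicted.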
    #[test]
    fn fac_ctw_predict_one_matches_predict_true() {
        let mut fac = FacContextTree::new(6, 8);
        for &byte in b"predict-one exactness regression payload" {
            for bit_idx in 0..8usize {
                let p_generic = fac.predict(true, bit_idx);
                let p_one = fac.predict_one(bit_idx);
                assert_close(p_generic, p_one);
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                fac.update_predicted(bit, bit_idx);
            }
        }
    }

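    // Same predict_one parity check at depth 78, exercising the long-depth
    // (segmented) representation.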
    #[test]
    fn fac_ctw_long_depth_predict_one_matches_predict_true() {
        let mut fac = FacContextTree::new(78, 4);
        for step in 0..24usize {
            for bit_idx in 0..fac.num_bits() {
                let p_generic = fac.predict(true, bit_idx);
                let p_one = fac.predict_one(bit_idx);
                assert_close(p_generic, p_one);
                let bit = ((step * 5 + bit_idx * 3) & 1) == 1;
                fac.update_predicted(bit, bit_idx);
            }
        }
    }

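    // update_byte_msb must be indistinguishable from eight MSB-first bit updates,
    // both in log probability and in the shared history it leaves behind.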
    #[test]
    fn fac_ctw_update_byte_msb_matches_bit_updates() {
        let mut by_byte = FacContextTree::new(6, 8);
        let mut by_bits = FacContextTree::new(6, 8);
        for &byte in b"byte update msb regression payload" {
            by_byte.update_byte_msb(byte);
            for bit_idx in 0..8usize {
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                by_bits.update(bit, bit_idx);
            }
            assert_close(
                by_byte.get_log_block_probability(),
                by_bits.get_log_block_probability(),
            );
            assert_eq!(by_byte.shared_history, by_bits.shared_history);
        }
    }

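    // LSB-first counterpart of the check above, with a 5-bit symbol width.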
    #[test]
    fn fac_ctw_update_byte_lsb_matches_bit_updates() {
        let mut by_byte = FacContextTree::new(6, 5);
        let mut by_bits = FacContextTree::new(6, 5);
        for &byte in b"byte update lsb regression payload" {
            by_byte.update_byte_lsb(byte);
            for bit_idx in 0..5usize {
                let bit = ((byte >> bit_idx) & 1) == 1;
                by_bits.update(bit, bit_idx);
            }
            assert_close(
                by_byte.get_log_block_probability(),
                by_bits.get_log_block_probability(),
            );
            assert_eq!(by_byte.shared_history, by_bits.shared_history);
        }
    }

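    // The thread-local log caches are sized by per-tree visit counts, not by the
    // total shared history: 512 updates to each of 8 trees append 4096 history
    // symbols but should leave the caches no larger than
    // max(previous length, updates_per_tree + 1).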
    #[test]
    fn fac_ctw_log_cache_tracks_tree_visits_not_shared_history() {
        let mut fac = FacContextTree::new(8, 8);
        let updates_per_tree = 512usize;
        let (log_int_before, log_half_before) = shared_log_cache_lens();

        for step in 0..updates_per_tree {
            let bit = (step & 1) == 1;
            for bit_idx in 0..8usize {
                fac.update(bit, bit_idx);
            }
        }

        assert_eq!(fac.shared_history.len(), updates_per_tree * 8);
        for tree in &fac.trees {
            let visits = tree.engine.arena.visits(tree.engine.root) as usize;
            assert_eq!(visits, updates_per_tree);
        }

        let (log_int_after, log_half_after) = shared_log_cache_lens();
        let expected_len = updates_per_tree + 1;
        assert!(
            log_int_after <= log_int_before.max(expected_len),
            "log_int grew to {log_int_after} (before={log_int_before}, expected_len={expected_len})"
        );
        assert!(
            log_half_after <= log_half_before.max(expected_len),
            "log_half grew to {log_half_after} (before={log_half_before}, expected_len={expected_len})"
        );
    }

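    // Drives a deterministic mixed bit pattern into every component tree so the
    // cache-regression tests below start from a nontrivial state.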
    fn seed_fac_cache_regression_state(fac: &mut FacContextTree) {
        for step in 0..24usize {
            for bit_idx in 0..fac.num_bits() {
                let bit = ((step * 3 + bit_idx) & 1) == 1;
                fac.update(bit, bit_idx);
            }
        }
    }

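    // Core regression harness: after predict() has populated a prepared-step
    // cache for one tree, rewrite the shared history (same length, different
    // contents) and verify that update_predicted does not replay the stale cache,
    // i.e. the result matches a fresh update on an equally rewritten clone.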
    fn assert_update_predicted_matches_fresh_after_history_rewrite<F>(mut rewrite: F)
    where
        F: FnMut(&mut FacContextTree),
    {
        let mut predicted = FacContextTree::new(6, 4);
        seed_fac_cache_regression_state(&mut predicted);
        let mut fresh = predicted.clone();
        let original_history = predicted.shared_history.clone();
        let target_bit = 2usize;

        let _ = predicted.predict(true, target_bit);
        rewrite(&mut predicted);
        rewrite(&mut fresh);

        assert_eq!(predicted.shared_history.len(), original_history.len());
        assert_ne!(predicted.shared_history, original_history);
        assert_eq!(predicted.shared_history, fresh.shared_history);

        predicted.update_predicted(false, target_bit);
        fresh.update(false, target_bit);

        assert_eq!(predicted.shared_history, fresh.shared_history);
        assert_close(
            predicted.get_log_block_probability(),
            fresh.get_log_block_probability(),
        );
        for bit_idx in 0..predicted.num_bits() {
            assert_close(
                predicted.predict(false, bit_idx),
                fresh.predict(false, bit_idx),
            );
            assert_close(
                predicted.predict(true, bit_idx),
                fresh.predict(true, bit_idx),
            );
        }
    }

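    // Rewrite path 1: clear the history outright, then replay an inverted copy.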
    #[test]
    fn fac_ctw_update_predicted_ignores_stale_cache_after_reset_and_rewrite() {
        assert_update_predicted_matches_fresh_after_history_rewrite(|fac| {
            let mut rewritten = fac.shared_history.clone();
            for bit in &mut rewritten {
                *bit = !*bit;
            }
            fac.reset_history_only();
            fac.update_history(&rewritten);
        });
    }

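    // Rewrite path 2: revert most of the history, then append an inverted suffix
    // of the same length, so the total length is unchanged.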
    #[test]
    fn fac_ctw_update_predicted_ignores_stale_cache_after_revert_and_rewrite() {
        assert_update_predicted_matches_fresh_after_history_rewrite(|fac| {
            let original = fac.shared_history.clone();
            let keep = original.len() / 3;
            let remove = original.len() - keep;
            let mut rewritten_suffix = original[keep..].to_vec();
            for bit in &mut rewritten_suffix {
                *bit = !*bit;
            }
            fac.revert_history(remove);
            fac.update_history(&rewritten_suffix);
        });
    }

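    // shared_history_version must bump exactly when the history mutates: no-op
    // calls (empty update_history, revert_history(0), plain predict) leave it
    // unchanged, while real mutations always change it.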
    #[test]
    fn fac_ctw_shared_history_version_tracks_mutations() {
        let mut fac = FacContextTree::new(4, 2);
        let mut version = fac.shared_history_version;

        fac.update_history(&[]);
        assert_eq!(fac.shared_history_version, version);

        fac.update_history(&[true, false]);
        assert_ne!(fac.shared_history_version, version);
        version = fac.shared_history_version;

        fac.revert_history(0);
        assert_eq!(fac.shared_history_version, version);

        fac.revert_history(1);
        assert_ne!(fac.shared_history_version, version);
        version = fac.shared_history_version;

        let _ = fac.predict(true, 0);
        assert_eq!(fac.shared_history_version, version);

        fac.update_predicted(true, 0);
        assert_ne!(fac.shared_history_version, version);
        version = fac.shared_history_version;

        fac.reset_history_only();
        assert_ne!(fac.shared_history_version, version);
    }

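    // predict must be observably pure: history, log block probability, and
    // subsequent predictions are unchanged by calling it.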
    #[test]
    fn context_tree_predict_preserves_state() {
        let mut tree = ContextTree::new(6);
        for &bit in &[true, false, true, true, false, false, true, false] {
            tree.update(bit);
        }
        let p0_before = tree.predict(false);
        let p1_before = tree.predict(true);
        let log_before = tree.get_log_block_probability();
        let history_before = tree.history.clone();
        let _ = tree.predict(true);

        assert_eq!(tree.history, history_before);
        assert_close(tree.get_log_block_probability(), log_before);
        assert_close(tree.predict(false), p0_before);
        assert_close(tree.predict(true), p1_before);
    }

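    // predict(sym) must equal the exact predictive ratio
    // P(history + sym) / P(history), i.e. exp of the log-probability delta
    // produced by actually performing the update.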
    #[test]
    fn context_tree_predict_matches_update_ratio() {
        let mut tree = ContextTree::new(7);
        for &bit in &[true, false, true, false, true, true, false, true, false] {
            tree.update(bit);
        }
        for &sym in &[false, true] {
            let predicted = tree.predict(sym);
            let mut reference = tree.clone();
            let before = reference.get_log_block_probability();
            reference.update(sym);
            let after = reference.get_log_block_probability();
            assert_close(predicted, (after - before).exp());
        }
    }

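    // Factored-tree version of the state-preservation check above.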
    #[test]
    fn fac_ctw_predict_preserves_state() {
        let mut fac = FacContextTree::new(5, 8);
        for &byte in b"fac ctw state preservation" {
            for bit_idx in 0..8usize {
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                fac.update(bit, bit_idx);
            }
        }
        let p0_before = fac.predict(false, 3);
        let p1_before = fac.predict(true, 3);
        let log_before = fac.get_log_block_probability();
        let history_before = fac.shared_history.clone();
        let _ = fac.predict(true, 3);

        assert_eq!(fac.shared_history, history_before);
        assert_close(fac.get_log_block_probability(), log_before);
        assert_close(fac.predict(false, 3), p0_before);
        assert_close(fac.predict(true, 3), p1_before);
    }

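    // Factored-tree version of the predictive-ratio check above.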
    #[test]
    fn fac_ctw_predict_matches_update_ratio() {
        let mut fac = FacContextTree::new(6, 8);
        for &byte in b"fac ctw exact predictive ratio" {
            for bit_idx in 0..8usize {
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                fac.update(bit, bit_idx);
            }
        }
        for &sym in &[false, true] {
            let predicted = fac.predict(sym, 4);
            let mut reference = fac.clone();
            let before = reference.get_log_block_probability();
            reference.update(sym, 4);
            let after = reference.get_log_block_probability();
            assert_close(predicted, (after - before).exp());
        }
    }

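    // After every predict + update_predicted step, the tree must be equivalent
    // (to 1e-12 relative tolerance) to one that only ever saw plain updates; the
    // failure message dumps arenas and prepared steps to aid debugging.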
    #[test]
    fn fac_ctw_update_predicted_matches_fresh_update_on_byte_stream() {
        let mut predicted = FacContextTree::new(6, 8);
        let mut fresh = predicted.clone();
        let stream = b"fac-ctw prepared update exactness regression";

        for (byte_pos, &byte) in stream.iter().enumerate() {
            for bit_idx in 0..8usize {
                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
                let _ = predicted.predict(true, bit_idx);
                predicted.update_predicted(bit, bit_idx);
                fresh.update(bit, bit_idx);

                let predicted_log = predicted.get_log_block_probability();
                let fresh_log = fresh.get_log_block_probability();
                assert!(
                    (predicted_log - fresh_log).abs()
                        <= 1e-12 * predicted_log.abs().max(fresh_log.abs()).max(1.0),
                    "log mismatch byte_pos={byte_pos} bit_idx={bit_idx} bit={bit} predicted={predicted_log} fresh={fresh_log}\nshared_history={:?}\npredicted_arena={:#?}\nfresh_arena={:#?}\npredicted_steps={:?}\nfresh_steps={:?}",
                    predicted.shared_history,
                    predicted.trees[bit_idx].engine.arena,
                    fresh.trees[bit_idx].engine.arena,
                    predicted.trees[bit_idx].engine.prepared_steps,
                    fresh.trees[bit_idx].engine.prepared_steps,
                );
                for probe_idx in 0..8usize {
                    let p_pred_0 = predicted.predict(false, probe_idx);
                    let p_fresh_0 = fresh.predict(false, probe_idx);
                    assert!(
                        (p_pred_0 - p_fresh_0).abs()
                            <= 1e-12 * p_pred_0.abs().max(p_fresh_0.abs()).max(1.0),
                        "predict0 mismatch byte_pos={byte_pos} bit_idx={bit_idx} probe_idx={probe_idx} predicted={p_pred_0} fresh={p_fresh_0}",
                    );
                    let p_pred_1 = predicted.predict(true, probe_idx);
                    let p_fresh_1 = fresh.predict(true, probe_idx);
                    assert!(
                        (p_pred_1 - p_fresh_1).abs()
                            <= 1e-12 * p_pred_1.abs().max(p_fresh_1.abs()).max(1.0),
                        "predict1 mismatch byte_pos={byte_pos} bit_idx={bit_idx} probe_idx={probe_idx} predicted={p_pred_1} fresh={p_fresh_1}",
                    );
                }
            }
        }
    }

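    // Long-depth (78) variant of the prepared-update parity check above.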
    #[test]
    fn fac_ctw_long_depth_update_predicted_matches_fresh_update_on_bit_stream() {
        let mut predicted = FacContextTree::new(78, 4);
        let mut fresh = predicted.clone();

        for step in 0..20usize {
            for bit_idx in 0..predicted.num_bits() {
                let bit = ((step * 7 + bit_idx * 11) & 1) == 1;
                let _ = predicted.predict(true, bit_idx);
                predicted.update_predicted(bit, bit_idx);
                fresh.update(bit, bit_idx);
                assert_eq!(predicted.shared_history, fresh.shared_history);
                assert_close(
                    predicted.get_log_block_probability(),
                    fresh.get_log_block_probability(),
                );
            }
        }

        for bit_idx in 0..predicted.num_bits() {
            assert_close(
                predicted.predict(false, bit_idx),
                fresh.predict(false, bit_idx),
            );
            assert_close(
                predicted.predict(true, bit_idx),
                fresh.predict(true, bit_idx),
            );
        }
    }

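    // Recursively visits every bit pattern of a symbol via matched update/revert
    // pairs, leaving the tree in its original state; this mimics computing a full
    // per-symbol distribution.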
    fn scan_symbol_space(tree: &mut FacContextTree, bits: usize) {
        fn rec(tree: &mut FacContextTree, bits: usize, depth: usize) {
            if depth == bits {
                return;
            }
            for bit in [false, true] {
                let bit_idx = depth;
                tree.update(bit, bit_idx);
                rec(tree, bits, depth + 1);
                tree.revert(bit_idx);
            }
        }
        rec(tree, bits, 0);
    }

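    // Log probability of a single symbol, computed by updating through its bits
    // and reverting afterwards, so the tree state is untouched.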
    fn byte_log_prob(tree: &mut FacContextTree, symbol: u8, msb_first: bool, bits: usize) -> f64 {
        let before = tree.get_log_block_probability();
        for bit_idx in 0..bits {
            let bit = if msb_first {
                ((symbol >> (7 - bit_idx)) & 1) == 1
            } else {
                ((symbol >> bit_idx) & 1) == 1
            };
            tree.update(bit, bit_idx);
        }
        let after = tree.get_log_block_probability();
        for bit_idx in (0..bits).rev() {
            tree.revert(bit_idx);
        }
        after - before
    }

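    // After a full symbol-space scan (which should be a no-op), observing one more
    // symbol must leave the model in exactly the state that plain updates produce:
    // every per-symbol log probability must match.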
    fn assert_symbol_scan_then_update_matches_plain(msb_first: bool) {
        let bits = 8usize;
        let mut with_scan = FacContextTree::new(7, bits);
        let mut plain = with_scan.clone();
        for &byte in b"pdf then update parity payload" {
            for bit_idx in 0..bits {
                let bit = if msb_first {
                    ((byte >> (7 - bit_idx)) & 1) == 1
                } else {
                    ((byte >> bit_idx) & 1) == 1
                };
                with_scan.update(bit, bit_idx);
                plain.update(bit, bit_idx);
            }
        }

        scan_symbol_space(&mut with_scan, bits);

        let observed = b'n';
        for bit_idx in 0..bits {
            let bit = if msb_first {
                ((observed >> (7 - bit_idx)) & 1) == 1
            } else {
                ((observed >> bit_idx) & 1) == 1
            };
            with_scan.update(bit, bit_idx);
            plain.update(bit, bit_idx);
        }

        for sym in 0u8..=255u8 {
            let lp_scan = byte_log_prob(&mut with_scan, sym, msb_first, bits);
            let lp_plain = byte_log_prob(&mut plain, sym, msb_first, bits);
            let diff = (lp_scan - lp_plain).abs();
            assert!(
                diff < 1e-12,
                "symbol={sym} lp_scan={lp_scan} lp_plain={lp_plain} diff={diff}",
            );
        }
    }

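    // MSB-first and LSB-first instantiations of the scan-then-update parity check.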
    #[test]
    fn fac_ctw_symbol_scan_then_update_matches_plain_msb() {
        assert_symbol_scan_then_update_matches_plain(true);
    }

    #[test]
    fn fac_ctw_symbol_scan_then_update_matches_plain_lsb() {
        assert_symbol_scan_then_update_matches_plain(false);
    }
}