// infotheory/backends/ctw.rs

//! Context Tree Weighting (CTW) and Factorized Action-Conditional CTW (FAC-CTW).
//!
//! This implementation stores maximal non-root unary runs as chains while keeping
//! the root and true branching points explicit. Updates and reverts rebuild only
//! the touched context path from exact per-node state, which preserves predictive
//! semantics while avoiding the `O(depth)` explicit-node blow-up for singleton
//! paths.
8
9use std::cell::RefCell;
10use std::f64;
11use std::mem::size_of;
12
13type Symbol = bool;
14
/// Grow the paired log tables so that index `upto` is valid in both.
///
/// After the call, `log_int[n] == ln(n)` (with `ln(0) == -inf`) and
/// `log_half[n] == ln(n + 0.5)` for every `n <= upto`. Existing entries are
/// never recomputed; the tables are only extended from `log_int.len()`.
#[inline(always)]
fn ensure_log_caches(log_int: &mut Vec<f64>, log_half: &mut Vec<f64>, upto: usize) {
    let first_missing = log_int.len();
    if first_missing > upto {
        return;
    }
    let missing = upto + 1 - first_missing;
    log_int.reserve(missing);
    log_half.reserve(missing);
    log_int.extend((first_missing..=upto).map(|n| {
        if n == 0 {
            f64::NEG_INFINITY
        } else {
            (n as f64).ln()
        }
    }));
    log_half.extend((first_missing..=upto).map(|n| (n as f64 + 0.5).ln()));
}
32
/// Paired logarithm lookup tables used by the KT count updates.
#[derive(Default)]
struct SharedLogCache {
    /// `log_int[n] == ln(n)`, with `ln(0)` stored as negative infinity.
    log_int: Vec<f64>,
    /// `log_half[n] == ln(n + 0.5)`.
    log_half: Vec<f64>,
}
38
39impl SharedLogCache {
40    fn new() -> Self {
41        Self {
42            log_int: vec![f64::NEG_INFINITY],
43            log_half: vec![(0.5f64).ln()],
44        }
45    }
46
47    #[inline(always)]
48    fn ensure(&mut self, upto: usize) {
49        ensure_log_caches(&mut self.log_int, &mut self.log_half, upto);
50    }
51
52    #[inline(always)]
53    fn memory_usage(&self) -> usize {
54        self.log_int.capacity() * size_of::<f64>() + self.log_half.capacity() * size_of::<f64>()
55    }
56}
57
58thread_local! {
59    static CTW_SHARED_LOG_CACHE: RefCell<SharedLogCache> =
60        RefCell::new(SharedLogCache::new());
61}
62
63#[inline]
64fn with_shared_log_cache<R>(upto: usize, f: impl FnOnce(&[f64], &[f64]) -> R) -> R {
65    CTW_SHARED_LOG_CACHE.with(|cache_cell| {
66        let mut cache = cache_cell.borrow_mut();
67        cache.ensure(upto);
68        f(&cache.log_int, &cache.log_half)
69    })
70}
71
72#[inline]
73fn shared_log_cache_memory_usage() -> usize {
74    CTW_SHARED_LOG_CACHE.with(|cache_cell| cache_cell.borrow().memory_usage())
75}
76
/// Test helper: current lengths of `(log_int, log_half)` in the thread cache.
#[cfg(test)]
#[inline]
fn shared_log_cache_lens() -> (usize, usize) {
    CTW_SHARED_LOG_CACHE.with(|cell| {
        let SharedLogCache { log_int, log_half } = &*cell.borrow();
        (log_int.len(), log_half.len())
    })
}
85
86#[inline(always)]
87fn history_symbol(history: &[Symbol], depth: usize) -> Symbol {
88    let idx = history.len().wrapping_sub(depth + 1);
89    if depth < history.len() {
90        unsafe { *history.get_unchecked(idx) }
91    } else {
92        false
93    }
94}
95
96#[inline(always)]
97unsafe fn history_at_or_zero(history_ptr: *const Symbol, history_len: isize, idx: isize) -> Symbol {
98    if idx >= 0 && idx < history_len {
99        *history_ptr.add(idx as usize)
100    } else {
101        false
102    }
103}
104
// Arena indices use 31 bits; the top bit of a `ChildRef` marks a segment.
const INDEX_BITS: u32 = 31;
const INDEX_LIMIT: usize = 1usize << INDEX_BITS;
const CHILD_SEGMENT_TAG: u32 = 1u32 << INDEX_BITS;
const CHILD_INDEX_MASK: u32 = !CHILD_SEGMENT_TAG;

// Segment `meta` word: two mode bits on top, 30-bit length below.
const SEG_META_MODE_SHIFT: u32 = 30;
const SEG_META_MODE_MASK: u32 = 0b11 << SEG_META_MODE_SHIFT;
const SEG_LEN_MASK: u32 = !SEG_META_MODE_MASK;
const SEG_MODE_EXACT: u32 = 0b00 << SEG_META_MODE_SHIFT;
const SEG_MODE_HISTORY: u32 = 0b01 << SEG_META_MODE_SHIFT;
const SEG_MODE_HISTORY_INVERT: u32 = 0b10 << SEG_META_MODE_SHIFT;
const SEG_MODE_CONST: u32 = 0b11 << SEG_META_MODE_SHIFT;
// Exact segments store their edges inline in a single `u64`.
const SEG_EXACT_MAX_LEN: u32 = u64::BITS;
117
118/// Index into the explicit-node arena.
119#[derive(Clone, Copy, Debug, PartialEq, Eq)]
120pub struct NodeIndex(u32);
121
122impl NodeIndex {
123    #[cold]
124    #[inline(never)]
125    fn overflow() -> ! {
126        panic!("ctw node index overflow");
127    }
128
129    #[inline(always)]
130    fn from_usize(idx: usize) -> Self {
131        if idx >= INDEX_LIMIT {
132            Self::overflow();
133        }
134        Self(idx as u32)
135    }
136
137    #[inline(always)]
138    fn get(self) -> usize {
139        self.0 as usize
140    }
141}
142
143/// Index into the unary-segment arena.
144#[derive(Clone, Copy, Debug, PartialEq, Eq)]
145struct SegmentIndex(u32);
146
147impl SegmentIndex {
148    #[cold]
149    #[inline(never)]
150    fn overflow() -> ! {
151        panic!("ctw segment index overflow");
152    }
153
154    #[inline(always)]
155    fn from_usize(idx: usize) -> Self {
156        if idx >= INDEX_LIMIT {
157            Self::overflow();
158        }
159        Self(idx as u32)
160    }
161
162    #[inline(always)]
163    fn get(self) -> usize {
164        self.0 as usize
165    }
166}
167
168/// Tagged child reference: `None`, explicit node, or unary segment.
169#[derive(Clone, Copy, Debug, PartialEq, Eq)]
170struct ChildRef(u32);
171
172impl ChildRef {
173    const NONE: ChildRef = ChildRef(u32::MAX);
174
175    #[inline(always)]
176    fn from_node(idx: NodeIndex) -> Self {
177        debug_assert!(idx.0 < CHILD_SEGMENT_TAG);
178        Self(idx.0)
179    }
180
181    #[inline(always)]
182    fn from_segment(idx: SegmentIndex) -> Self {
183        debug_assert!(idx.0 < CHILD_SEGMENT_TAG);
184        Self(CHILD_SEGMENT_TAG | idx.0)
185    }
186
187    #[inline(always)]
188    fn is_none(self) -> bool {
189        self.0 == u32::MAX
190    }
191
192    #[inline(always)]
193    fn is_some(self) -> bool {
194        self.0 != u32::MAX
195    }
196
197    #[inline(always)]
198    fn as_node(self) -> Option<NodeIndex> {
199        if self.is_none() || (self.0 & CHILD_SEGMENT_TAG) != 0 {
200            None
201        } else {
202            Some(NodeIndex(self.0))
203        }
204    }
205
206    #[inline(always)]
207    fn as_segment(self) -> Option<SegmentIndex> {
208        if self.is_none() || (self.0 & CHILD_SEGMENT_TAG) == 0 {
209            None
210        } else {
211            Some(SegmentIndex(self.0 & CHILD_INDEX_MASK))
212        }
213    }
214}
215
216impl Default for ChildRef {
217    fn default() -> Self {
218        Self::NONE
219    }
220}
221
/// Packed description of the edges along a unary segment.
///
/// `meta` holds the mode (top two bits) and length (remaining bits);
/// `repr_lo`/`repr_hi` hold mode-specific data (inline bits, a history
/// anchor, or a constant bit).
#[derive(Clone, Copy, Debug, Default)]
struct SegmentPayload {
    /// Low 32 bits of the representation.
    repr_lo: u32,
    /// High 32 bits of the representation (only used by exact mode).
    repr_hi: u32,
    /// Mode bits plus segment length.
    meta: u32,
}
228
229impl SegmentPayload {
230    #[inline(always)]
231    fn exact(bits: u64, len: u32) -> Self {
232        debug_assert!(len <= SEG_EXACT_MAX_LEN);
233        debug_assert!(len <= SEG_LEN_MASK);
234        Self {
235            repr_lo: bits as u32,
236            repr_hi: (bits >> 32) as u32,
237            meta: SEG_MODE_EXACT | len,
238        }
239    }
240
241    #[inline(always)]
242    fn history(anchor: u32, len: u32, invert: bool) -> Self {
243        debug_assert!(len <= SEG_LEN_MASK);
244        Self {
245            repr_lo: anchor,
246            repr_hi: 0,
247            meta: if invert {
248                SEG_MODE_HISTORY_INVERT | len
249            } else {
250                SEG_MODE_HISTORY | len
251            },
252        }
253    }
254
255    #[inline(always)]
256    fn constant(bit: bool, len: u32) -> Self {
257        debug_assert!(len <= SEG_LEN_MASK);
258        Self {
259            repr_lo: bit as u32,
260            repr_hi: 0,
261            meta: SEG_MODE_CONST | len,
262        }
263    }
264
265    #[inline(always)]
266    fn len(self) -> u32 {
267        self.meta & SEG_LEN_MASK
268    }
269
270    #[inline(always)]
271    fn set_len(&mut self, len: u32) {
272        debug_assert!(len <= SEG_LEN_MASK);
273        self.meta = (self.meta & SEG_META_MODE_MASK) | len;
274    }
275
276    #[inline(always)]
277    fn mode(self) -> u32 {
278        self.meta & SEG_META_MODE_MASK
279    }
280
281    #[inline(always)]
282    fn is_exact(self) -> bool {
283        self.mode() == SEG_MODE_EXACT
284    }
285
286    #[inline(always)]
287    fn exact_bits(self) -> u64 {
288        (self.repr_lo as u64) | ((self.repr_hi as u64) << 32)
289    }
290
291    #[inline(always)]
292    fn anchor_or_const(self) -> u32 {
293        self.repr_lo
294    }
295
296    #[inline(always)]
297    fn const_bit(self) -> bool {
298        (self.repr_lo & 1) != 0
299    }
300
301    #[inline(always)]
302    fn prepend_exact(self, edge: usize) -> Option<Self> {
303        if !self.is_exact() || self.len() >= SEG_EXACT_MAX_LEN {
304            return None;
305        }
306        let len = self.len() + 1;
307        let bits = ((edge as u64) & 1) | (self.exact_bits() << 1);
308        Some(Self::exact(bits, len))
309    }
310
311    #[inline(always)]
312    fn prefix(self, len: u32) -> Self {
313        debug_assert!(len <= self.len());
314        match self.mode() {
315            SEG_MODE_EXACT => Self::exact(self.exact_bits() & low_bits_mask_u64(len), len),
316            SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => Self {
317                meta: (self.meta & SEG_META_MODE_MASK) | len,
318                ..self
319            },
320            SEG_MODE_CONST => Self::constant(self.const_bit(), len),
321            _ => unreachable!("invalid ctw segment payload mode"),
322        }
323    }
324
325    #[inline(always)]
326    fn suffix_after(self, skip: u32) -> Self {
327        debug_assert!(skip <= self.len());
328        let new_len = self.len() - skip;
329        match self.mode() {
330            SEG_MODE_EXACT => Self::exact(self.exact_bits() >> skip, new_len),
331            SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => Self {
332                repr_lo: self
333                    .anchor_or_const()
334                    .checked_sub(skip)
335                    .expect("ctw history segment anchor underflow"),
336                meta: (self.meta & SEG_META_MODE_MASK) | new_len,
337                ..self
338            },
339            SEG_MODE_CONST => Self::constant(self.const_bit(), new_len),
340            _ => unreachable!("invalid ctw segment payload mode"),
341        }
342    }
343
344    #[inline(always)]
345    fn from_path(history: &[Symbol], depth: usize, len: u32) -> Option<Self> {
346        if len > SEG_EXACT_MAX_LEN {
347            return None;
348        }
349        Some(Self::exact(
350            path_bits_from_history(history, depth, len as usize),
351            len,
352        ))
353    }
354}
355
356#[derive(Clone, Copy, Debug, Default)]
357struct LevelState {
358    symbol_count: [u32; 2],
359    log_prob_kt: f64,
360    sibling: ChildRef,
361}
362
/// Test-only record of the per-level quantities involved in a prediction.
#[cfg(test)]
#[allow(dead_code)] // Fields are only read by some tests / debug output.
#[derive(Clone, Copy, Debug)]
struct PredictEntry {
    /// Per-symbol observation counts.
    symbol_count: [u32; 2],
    /// KT estimator log-probability.
    log_prob_kt: f64,
    /// Weighted (CTW mixture) log-probability.
    log_prob_weighted: f64,
    /// Weighted log-probability contributed by the sibling subtree.
    sibling_weight: f64,
    /// Whether a sibling subtree exists.
    has_sibling: bool,
}
373
374#[derive(Clone, Copy, Debug)]
375enum Detach {
376    NodeChild { node: NodeIndex, edge: usize },
377    SegmentNext { segment: SegmentIndex, new_len: u32 },
378}
379
380#[derive(Clone, Copy, Debug, PartialEq, Eq)]
381enum ExistingSource {
382    None,
383    Node(NodeIndex),
384    Segment(SegmentIndex, u32),
385}
386
/// Reason a path preparation stopped (semantics inferred from variant names).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PreparedEnd {
    /// The configured maximum depth was reached.
    MaxDepth,
    /// The root had no child along the path.
    MissingAtRoot,
    /// The current element had no continuation along the path.
    MissingAfterCurrent,
    /// The path diverged inside the current segment.
    MismatchAtCurrentSegment,
}
394
395#[derive(Clone, Copy, Debug, PartialEq)]
396struct PreparedStep {
397    source: ExistingSource,
398    counts: [u32; 2],
399    kt_log_prob: f64,
400    span: u32,
401    sibling_weight: f64,
402    has_sibling: u8,
403}
404
405#[derive(Clone, Copy, Debug)]
406/// Compact node payload used by [`CtArena`].
407pub struct CtNode {
408    children: [ChildRef; 2],
409    log_prob_kt: f64,
410    log_prob_weighted: f64,
411    symbol_count: [u32; 2],
412}
413
414#[derive(Clone, Copy, Debug)]
415struct CtSegment {
416    tail: ChildRef,
417    log_prob_kt: f64,
418    head_log_prob_weighted: f64,
419    symbol_count: [u32; 2],
420    payload: SegmentPayload,
421}
422
423impl Default for CtSegment {
424    fn default() -> Self {
425        Self {
426            tail: ChildRef::NONE,
427            log_prob_kt: 0.0,
428            head_log_prob_weighted: 0.0,
429            symbol_count: [0, 0],
430            payload: SegmentPayload::default(),
431        }
432    }
433}
434
435impl CtSegment {
436    #[inline(always)]
437    fn len(self) -> u32 {
438        self.payload.len()
439    }
440
441    #[inline(always)]
442    fn set_len(&mut self, len: u32) {
443        self.payload.set_len(len);
444    }
445}
446
/// Mask with the lowest `len` bits set; `len >= 64` gives all ones.
#[inline(always)]
fn low_bits_mask_u64(len: u32) -> u64 {
    1u64.checked_shl(len).map_or(u64::MAX, |bit| bit - 1)
}
455
456#[inline(always)]
457fn path_bits_from_history(history: &[Symbol], depth: usize, len: usize) -> u64 {
458    let history_len = history.len();
459    let available = history_len.saturating_sub(depth).min(len);
460    if available == 0 {
461        return 0;
462    }
463
464    let mut bits = 0u64;
465    let mut hist_idx = history_len - depth - 1;
466    for offset in 0..available {
467        bits |= (unsafe { *history.get_unchecked(hist_idx) } as u64) << offset;
468        if hist_idx == 0 {
469            break;
470        }
471        hist_idx -= 1;
472    }
473    bits
474}
475
/// Drop the first `consumed` packed edges; shifting by 64 or more yields 0.
#[inline(always)]
fn shift_path_bits(path_bits: u64, consumed: usize) -> u64 {
    u32::try_from(consumed)
        .ok()
        .and_then(|shift| path_bits.checked_shr(shift))
        .unwrap_or(0)
}
484
485#[inline(always)]
486fn first_exact_segment_mismatch(
487    exact_bits: u64,
488    path_bits: u64,
489    comparable_len: usize,
490) -> Option<(usize, bool, bool)> {
491    if comparable_len == 0 {
492        return None;
493    }
494
495    let diff = (exact_bits ^ path_bits) & low_bits_mask_u64(comparable_len as u32);
496    if diff == 0 {
497        None
498    } else {
499        let offset = diff.trailing_zeros() as usize;
500        Some((
501            offset,
502            ((path_bits >> offset) & 1) != 0,
503            ((exact_bits >> offset) & 1) != 0,
504        ))
505    }
506}
507
/// KT estimate of the probability of `sym_idx`: `(n_sym + 1/2) / (n + 1)`.
///
/// The two counts are summed in `f64` rather than `u32`, so large counts no
/// longer overflow. In the non-overflowing range the result is bit-identical,
/// because every sum below 2^53 is exact in `f64`.
#[inline(always)]
fn predict_ratio_kt(counts: [u32; 2], sym_idx: usize) -> f64 {
    let total = f64::from(counts[0]) + f64::from(counts[1]);
    let sym_count = f64::from(counts[sym_idx]);
    (sym_count + 0.5) / (total + 1.0)
}
514
/// KT estimate of the probability of symbol `1`: `(n_1 + 1/2) / (n + 1)`.
///
/// The two counts are summed in `f64` (not `u32`) to avoid overflow; the
/// result is bit-identical to the `u32` sum whenever that sum fits.
#[inline(always)]
fn predict_ratio_kt_one(counts: [u32; 2]) -> f64 {
    let total = f64::from(counts[0]) + f64::from(counts[1]);
    let sym_count = f64::from(counts[1]);
    (sym_count + 0.5) / (total + 1.0)
}
521
522#[inline(always)]
523fn update_weighted_log_prob_non_leaf(kt_log_prob: f64, log_prob_w0: f64, log_prob_w1: f64) -> f64 {
524    let child_log_prob = log_prob_w0 + log_prob_w1;
525    let delta = child_log_prob - kt_log_prob;
526    let log_prob_weighted = if delta >= 0.0 {
527        child_log_prob + (-delta).exp().ln_1p() - std::f64::consts::LN_2
528    } else {
529        kt_log_prob + delta.exp().ln_1p() - std::f64::consts::LN_2
530    };
531    clamp_log_prob(log_prob_weighted)
532}
533
534#[inline(always)]
535fn update_weighted_log_prob(
536    kt_log_prob: f64,
537    log_prob_w0: f64,
538    log_prob_w1: f64,
539    is_leaf: bool,
540) -> f64 {
541    if is_leaf {
542        clamp_log_prob(kt_log_prob)
543    } else {
544        update_weighted_log_prob_non_leaf(kt_log_prob, log_prob_w0, log_prob_w1)
545    }
546}
547
/// Snap log-probabilities that drifted above `1e-10` (rounding noise) to `0.0`.
///
/// Values at or below the tolerance, including NaN, pass through unchanged.
#[inline(always)]
fn clamp_log_prob(log_prob: f64) -> f64 {
    const POSITIVE_DRIFT_TOLERANCE: f64 = 1.0e-10;
    let drifted = log_prob > POSITIVE_DRIFT_TOLERANCE;
    if drifted {
        0.0
    } else {
        log_prob
    }
}
552
/// `ln(e^lhs + e^rhs)`, short-circuiting when either operand is `-inf`.
#[inline(always)]
fn logsumexp_pair(lhs: f64, rhs: f64) -> f64 {
    match (lhs == f64::NEG_INFINITY, rhs == f64::NEG_INFINITY) {
        (true, _) => rhs,
        (false, true) => lhs,
        (false, false) => {
            let pivot = lhs.max(rhs);
            let scaled_sum = (lhs - pivot).exp() + (rhs - pivot).exp();
            pivot + scaled_sum.ln()
        }
    }
}
564
565#[inline(always)]
566fn unary_chain_log_weight(kt_log_prob: f64, continuation_log_prob: f64, len: u32) -> f64 {
567    debug_assert!(len > 0);
568    if kt_log_prob.to_bits() == continuation_log_prob.to_bits() {
569        return kt_log_prob;
570    }
571    let log_alpha = -(len as f64) * std::f64::consts::LN_2;
572    let alpha = log_alpha.exp();
573    let log_kt_mass = kt_log_prob + (-alpha).ln_1p();
574    let log_cont_mass = continuation_log_prob + log_alpha;
575    clamp_log_prob(logsumexp_pair(log_kt_mass, log_cont_mass))
576}
577
578#[inline(always)]
579fn combined_weight_ratio_internal(
580    kt_log_prob: f64,
581    counts: [u32; 2],
582    path_child_log_prob: f64,
583    sibling_log_prob: f64,
584    child_ratio: f64,
585    sym_idx: usize,
586) -> (f64, f64) {
587    let kt_ratio = predict_ratio_kt(counts, sym_idx);
588    let child_log_prob = path_child_log_prob + sibling_log_prob;
589    let delta = child_log_prob - kt_log_prob;
590    if delta >= 0.0 {
591        let x = (-delta).exp();
592        (
593            clamp_log_prob(child_log_prob + x.ln_1p() - std::f64::consts::LN_2),
594            (kt_ratio * x + child_ratio) / (1.0 + x),
595        )
596    } else {
597        let x = delta.exp();
598        (
599            clamp_log_prob(kt_log_prob + x.ln_1p() - std::f64::consts::LN_2),
600            (kt_ratio + x * child_ratio) / (1.0 + x),
601        )
602    }
603}
604
605#[inline(always)]
606fn combined_weight_ratio_internal_one(
607    kt_log_prob: f64,
608    counts: [u32; 2],
609    path_child_log_prob: f64,
610    sibling_log_prob: f64,
611    child_ratio: f64,
612) -> (f64, f64) {
613    let kt_ratio = predict_ratio_kt_one(counts);
614    let child_log_prob = path_child_log_prob + sibling_log_prob;
615    let delta = child_log_prob - kt_log_prob;
616    if delta >= 0.0 {
617        let x = (-delta).exp();
618        (
619            clamp_log_prob(child_log_prob + x.ln_1p() - std::f64::consts::LN_2),
620            (kt_ratio * x + child_ratio) / (1.0 + x),
621        )
622    } else {
623        let x = delta.exp();
624        (
625            clamp_log_prob(kt_log_prob + x.ln_1p() - std::f64::consts::LN_2),
626            (kt_ratio + x * child_ratio) / (1.0 + x),
627        )
628    }
629}
630
631#[inline(always)]
632fn unary_chain_log_weight_precomputed(
633    kt_log_prob: f64,
634    continuation_log_prob: f64,
635    alpha: f64,
636    log_alpha: f64,
637    log_one_minus_alpha: f64,
638) -> f64 {
639    if kt_log_prob.to_bits() == continuation_log_prob.to_bits() {
640        return clamp_log_prob(kt_log_prob);
641    }
642
643    let delta = continuation_log_prob - kt_log_prob;
644    let log_prob_weighted = if delta >= 0.0 {
645        let x = ((1.0 - alpha) / alpha) * (-delta).exp();
646        continuation_log_prob + log_alpha + x.ln_1p()
647    } else {
648        let x = (alpha / (1.0 - alpha)) * delta.exp();
649        kt_log_prob + log_one_minus_alpha + x.ln_1p()
650    };
651    clamp_log_prob(log_prob_weighted)
652}
653
654#[inline(always)]
655fn unary_chain_ratio_transform_precomputed(
656    kt_log_prob: f64,
657    counts: [u32; 2],
658    continuation_log_prob: f64,
659    continuation_ratio: f64,
660    alpha: f64,
661    log_alpha: f64,
662    log_one_minus_alpha: f64,
663    sym_idx: usize,
664) -> (f64, f64) {
665    let kt_ratio = predict_ratio_kt(counts, sym_idx);
666    if kt_log_prob.to_bits() == continuation_log_prob.to_bits()
667        && kt_ratio.to_bits() == continuation_ratio.to_bits()
668    {
669        return (clamp_log_prob(kt_log_prob), kt_ratio);
670    }
671
672    let delta = continuation_log_prob - kt_log_prob;
673    if delta >= 0.0 {
674        let x = ((1.0 - alpha) / alpha) * (-delta).exp();
675        (
676            clamp_log_prob(continuation_log_prob + log_alpha + x.ln_1p()),
677            (kt_ratio * x + continuation_ratio) / (1.0 + x),
678        )
679    } else {
680        let x = (alpha / (1.0 - alpha)) * delta.exp();
681        (
682            clamp_log_prob(kt_log_prob + log_one_minus_alpha + x.ln_1p()),
683            (kt_ratio + x * continuation_ratio) / (1.0 + x),
684        )
685    }
686}
687
688#[inline(always)]
689fn unary_chain_ratio_transform_precomputed_one(
690    kt_log_prob: f64,
691    counts: [u32; 2],
692    continuation_log_prob: f64,
693    continuation_ratio: f64,
694    alpha: f64,
695    log_alpha: f64,
696    log_one_minus_alpha: f64,
697) -> (f64, f64) {
698    let kt_ratio = predict_ratio_kt_one(counts);
699    if kt_log_prob.to_bits() == continuation_log_prob.to_bits()
700        && kt_ratio.to_bits() == continuation_ratio.to_bits()
701    {
702        return (clamp_log_prob(kt_log_prob), kt_ratio);
703    }
704
705    let delta = continuation_log_prob - kt_log_prob;
706    if delta >= 0.0 {
707        let x = ((1.0 - alpha) / alpha) * (-delta).exp();
708        (
709            clamp_log_prob(continuation_log_prob + log_alpha + x.ln_1p()),
710            (kt_ratio * x + continuation_ratio) / (1.0 + x),
711        )
712    } else {
713        let x = (alpha / (1.0 - alpha)) * delta.exp();
714        (
715            clamp_log_prob(kt_log_prob + log_one_minus_alpha + x.ln_1p()),
716            (kt_ratio + x * continuation_ratio) / (1.0 + x),
717        )
718    }
719}
720
721#[inline(always)]
722fn predict_ratio_internal(
723    kt_log_prob: f64,
724    counts: [u32; 2],
725    path_child_log_prob: f64,
726    sibling_log_prob: f64,
727    child_ratio: f64,
728    sym_idx: usize,
729) -> f64 {
730    let kt_ratio = predict_ratio_kt(counts, sym_idx);
731    let delta = path_child_log_prob + sibling_log_prob - kt_log_prob;
732    if delta >= 0.0 {
733        let inv_rho = (-delta).exp();
734        (kt_ratio * inv_rho + child_ratio) / (1.0 + inv_rho)
735    } else {
736        let rho = delta.exp();
737        (kt_ratio + rho * child_ratio) / (1.0 + rho)
738    }
739}
740
741#[inline(always)]
742fn predict_ratio_internal_one(
743    kt_log_prob: f64,
744    counts: [u32; 2],
745    path_child_log_prob: f64,
746    sibling_log_prob: f64,
747    child_ratio: f64,
748) -> f64 {
749    let kt_ratio = predict_ratio_kt_one(counts);
750    let delta = path_child_log_prob + sibling_log_prob - kt_log_prob;
751    if delta >= 0.0 {
752        let inv_rho = (-delta).exp();
753        (kt_ratio * inv_rho + child_ratio) / (1.0 + inv_rho)
754    } else {
755        let rho = delta.exp();
756        (kt_ratio + rho * child_ratio) / (1.0 + rho)
757    }
758}
759
760#[inline(always)]
761fn path_edge_at_depth(history: &[Symbol], history_len: usize, depth: usize) -> bool {
762    if depth < history_len {
763        history[history_len - depth - 1]
764    } else {
765        false
766    }
767}
768
769#[inline(always)]
770fn segment_edge_from_parts(
771    segment: CtSegment,
772    offset: usize,
773    history: &[Symbol],
774    history_len: usize,
775) -> bool {
776    match segment.payload.mode() {
777        SEG_MODE_EXACT => ((segment.payload.exact_bits() >> offset) & 1) != 0,
778        SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => {
779            if segment.payload.anchor_or_const() as usize >= offset {
780                let hist_idx = segment.payload.anchor_or_const() as usize - offset;
781                if hist_idx < history_len {
782                    let raw = history[hist_idx];
783                    if segment.payload.mode() == SEG_MODE_HISTORY_INVERT {
784                        !raw
785                    } else {
786                        raw
787                    }
788                } else {
789                    segment.payload.mode() == SEG_MODE_HISTORY_INVERT
790                }
791            } else {
792                segment.payload.mode() == SEG_MODE_HISTORY_INVERT
793            }
794        }
795        SEG_MODE_CONST => segment.payload.const_bit(),
796        _ => unreachable!("invalid ctw segment payload mode"),
797    }
798}
799
800#[inline(always)]
801fn first_segment_mismatch(
802    segment: CtSegment,
803    depth: usize,
804    history: &[Symbol],
805    comparable_len: usize,
806) -> Option<(usize, bool, bool)> {
807    if comparable_len == 0 {
808        return None;
809    }
810
811    match segment.payload.mode() {
812        SEG_MODE_EXACT => first_exact_segment_mismatch(
813            segment.payload.exact_bits(),
814            path_bits_from_history(history, depth, comparable_len),
815            comparable_len,
816        ),
817        SEG_MODE_HISTORY | SEG_MODE_HISTORY_INVERT => {
818            let history_ptr = history.as_ptr();
819            let history_len = history.len() as isize;
820            let mut path_hist_idx = history_len - depth as isize - 1;
821            let mut seg_hist_idx = segment.payload.anchor_or_const() as isize;
822            let invert = segment.payload.mode() == SEG_MODE_HISTORY_INVERT;
823            for offset in 0..comparable_len {
824                let path_edge =
825                    unsafe { history_at_or_zero(history_ptr, history_len, path_hist_idx) };
826                let existing_raw =
827                    unsafe { history_at_or_zero(history_ptr, history_len, seg_hist_idx) };
828                let existing_edge = if invert { !existing_raw } else { existing_raw };
829                if existing_edge != path_edge {
830                    return Some((offset, path_edge, existing_edge));
831                }
832                path_hist_idx -= 1;
833                seg_hist_idx -= 1;
834            }
835            None
836        }
837        SEG_MODE_CONST => {
838            let history_ptr = history.as_ptr();
839            let history_len = history.len() as isize;
840            let mut path_hist_idx = history_len - depth as isize - 1;
841            let existing_edge = segment.payload.const_bit();
842            for offset in 0..comparable_len {
843                let path_edge =
844                    unsafe { history_at_or_zero(history_ptr, history_len, path_hist_idx) };
845                if existing_edge != path_edge {
846                    return Some((offset, path_edge, existing_edge));
847                }
848                path_hist_idx -= 1;
849            }
850            None
851        }
852        _ => unreachable!("invalid ctw segment payload mode"),
853    }
854}
855
/// Record one occurrence of `sym_idx` in KT state.
///
/// Adds `ln(n_sym + 1/2) - ln(n + 1)` to `log_prob_kt` (clamping positive drift
/// above `1e-10` to zero), then increments the symbol's count.
///
/// `log_int`/`log_half` must cover indices up to `n + 1` and `n_sym`
/// respectively (where `n` is the total count before the update).
///
/// # Panics
///
/// Panics if either table is too short, or if the symbol count would overflow
/// `u32`.
#[inline]
fn apply_update_to_state_raw(
    log_int: &[f64],
    log_half: &[f64],
    symbol_count: &mut [u32; 2],
    log_prob_kt: &mut f64,
    sym_idx: usize,
) {
    // Widen before adding so two large per-symbol counts cannot overflow `u32`
    // (on 64-bit targets).
    let total_before = symbol_count[0] as usize + symbol_count[1] as usize;
    let sym_before = symbol_count[sym_idx] as usize;
    debug_assert!(sym_before <= total_before);
    // Checked indexing: a cache that was not grown far enough panics instead of
    // reading out of bounds in release builds.
    let log_half_before = log_half[sym_before];
    let log_total_after = log_int[total_before + 1];
    *log_prob_kt += log_half_before - log_total_after;
    if *log_prob_kt > 1.0e-10 {
        *log_prob_kt = 0.0;
    }
    symbol_count[sym_idx] = symbol_count[sym_idx]
        .checked_add(1)
        .expect("ctw symbol count overflow");
}
879
/// Undo one occurrence of `sym_idx` previously recorded by
/// `apply_update_to_state_raw`.
///
/// Does nothing to the counts when `sym_idx` has count zero; in all cases the
/// resulting log-probability has positive drift above `1e-10` clamped to zero.
///
/// # Panics
///
/// Panics if `log_int`/`log_half` do not cover the indices being read.
#[inline]
fn apply_revert_to_state_raw(
    log_int: &[f64],
    log_half: &[f64],
    symbol_count: &mut [u32; 2],
    log_prob_kt: &mut f64,
    sym_idx: usize,
) {
    // Widen before adding so the total cannot overflow `u32` (on 64-bit targets).
    let total = symbol_count[0] as usize + symbol_count[1] as usize;
    let sym_count = symbol_count[sym_idx] as usize;
    if sym_count > 0 && total > 0 {
        // Checked indexing keeps this safe fn sound when the cache is too short.
        let log_half_before = log_half[sym_count - 1];
        let log_total = log_int[total];
        *log_prob_kt -= log_half_before - log_total;
        symbol_count[sym_idx] -= 1;
    }
    if *log_prob_kt > 1.0e-10 {
        *log_prob_kt = 0.0;
    }
}
902
903#[derive(Clone, Debug)]
904/// Arena allocator and storage for CTW nodes/segments.
905///
906/// This structure owns all backing memory for compressed CTW trees and provides
907/// index-based access used by the update/revert engine.
908pub struct CtArena {
909    nodes: Vec<CtNode>,
910    segments: Vec<CtSegment>,
911    free_nodes: Vec<NodeIndex>,
912    free_segments: Vec<SegmentIndex>,
913}
914
915impl CtArena {
916    /// Create an empty arena with small default capacities.
917    pub fn new() -> Self {
918        Self {
919            nodes: Vec::with_capacity(1024),
920            segments: Vec::with_capacity(1024),
921            free_nodes: Vec::new(),
922            free_segments: Vec::new(),
923        }
924    }
925
926    /// Create an empty arena with explicit node-oriented capacity hint.
927    pub fn with_capacity(cap: usize) -> Self {
928        Self {
929            nodes: Vec::with_capacity(cap),
930            segments: Vec::with_capacity(cap / 4 + 1),
931            free_nodes: Vec::new(),
932            free_segments: Vec::new(),
933        }
934    }
935
936    #[inline]
937    /// Reserve additional capacity for upcoming node/segment allocations.
938    pub fn reserve_exact(&mut self, additional: usize) {
939        self.nodes.reserve_exact(additional);
940        self.segments.reserve_exact(additional / 4 + 1);
941    }
942
943    #[inline(always)]
944    fn reset_node_slot(&mut self, idx: NodeIndex) {
945        self.nodes[idx.get()] = CtNode {
946            children: [ChildRef::NONE, ChildRef::NONE],
947            log_prob_kt: 0.0,
948            log_prob_weighted: 0.0,
949            symbol_count: [0, 0],
950        };
951    }
952
953    #[inline(always)]
954    fn reset_segment_slot(&mut self, idx: SegmentIndex) {
955        self.segments[idx.get()] = CtSegment::default();
956    }
957
958    #[inline(always)]
959    fn alloc_node(&mut self) -> NodeIndex {
960        if let Some(idx) = self.free_nodes.pop() {
961            self.reset_node_slot(idx);
962            idx
963        } else {
964            let idx = NodeIndex::from_usize(self.nodes.len());
965            self.nodes.push(CtNode {
966                children: [ChildRef::NONE, ChildRef::NONE],
967                log_prob_kt: 0.0,
968                log_prob_weighted: 0.0,
969                symbol_count: [0, 0],
970            });
971            idx
972        }
973    }
974
975    #[inline(always)]
976    fn alloc_node_with_state(&mut self, symbol_count: [u32; 2], log_prob_kt: f64) -> NodeIndex {
977        let idx = self.alloc_node();
978        self.nodes[idx.get()].symbol_count = symbol_count;
979        self.nodes[idx.get()].log_prob_kt = log_prob_kt;
980        idx
981    }
982
983    #[inline(always)]
984    fn free_node(&mut self, idx: NodeIndex) {
985        self.free_nodes.push(idx);
986    }
987
988    #[inline(always)]
989    fn alloc_segment(&mut self) -> SegmentIndex {
990        if let Some(idx) = self.free_segments.pop() {
991            self.reset_segment_slot(idx);
992            idx
993        } else {
994            let idx = SegmentIndex::from_usize(self.segments.len());
995            self.segments.push(CtSegment::default());
996            idx
997        }
998    }
999
1000    #[inline(always)]
1001    fn free_segment(&mut self, idx: SegmentIndex) {
1002        self.reset_segment_slot(idx);
1003        self.free_segments.push(idx);
1004    }
1005
1006    /// Drop all nodes/segments and free-list bookkeeping.
1007    pub fn clear(&mut self) {
1008        self.nodes.clear();
1009        self.segments.clear();
1010        self.free_nodes.clear();
1011        self.free_segments.clear();
1012    }
1013
1014    #[inline(always)]
1015    fn child(&self, parent_idx: NodeIndex, child_idx: usize) -> ChildRef {
1016        debug_assert!(parent_idx.get() < self.nodes.len());
1017        debug_assert!(child_idx < 2);
1018        unsafe {
1019            *self
1020                .nodes
1021                .get_unchecked(parent_idx.get())
1022                .children
1023                .get_unchecked(child_idx)
1024        }
1025    }
1026
1027    #[inline(always)]
1028    fn set_child(&mut self, parent_idx: NodeIndex, child_idx: usize, child: ChildRef) {
1029        debug_assert!(parent_idx.get() < self.nodes.len());
1030        debug_assert!(child_idx < 2);
1031        unsafe {
1032            *self
1033                .nodes
1034                .get_unchecked_mut(parent_idx.get())
1035                .children
1036                .get_unchecked_mut(child_idx) = child;
1037        }
1038    }
1039
1040    #[inline(always)]
1041    fn set_segment_tail(&mut self, segment_idx: SegmentIndex, child: ChildRef) {
1042        self.segments[segment_idx.get()].tail = child;
1043    }
1044
1045    #[inline(always)]
1046    fn counts(&self, idx: NodeIndex) -> [u32; 2] {
1047        self.nodes[idx.get()].symbol_count
1048    }
1049
1050    #[inline(always)]
1051    fn visits(&self, idx: NodeIndex) -> u32 {
1052        let counts = self.nodes[idx.get()].symbol_count;
1053        counts[0] + counts[1]
1054    }
1055
1056    #[inline(always)]
1057    fn segment_symbol_count(&self, segment_idx: SegmentIndex) -> [u32; 2] {
1058        self.segments[segment_idx.get()].symbol_count
1059    }
1060
1061    #[inline(always)]
1062    fn segment_log_prob_kt(&self, segment_idx: SegmentIndex) -> f64 {
1063        self.segments[segment_idx.get()].log_prob_kt
1064    }
1065
1066    #[inline(always)]
1067    fn segment_len(&self, segment_idx: SegmentIndex) -> u32 {
1068        self.segments[segment_idx.get()].len()
1069    }
1070
1071    #[inline(always)]
1072    fn segment_has_child(&self, segment_idx: SegmentIndex, offset: u32) -> bool {
1073        let segment = self.segments[segment_idx.get()];
1074        offset + 1 < segment.len() || segment.tail.is_some()
1075    }
1076
1077    #[inline(always)]
1078    fn log_prob_weighted(&self, idx: NodeIndex) -> f64 {
1079        self.nodes[idx.get()].log_prob_weighted
1080    }
1081
1082    #[inline(always)]
1083    fn log_prob_kt(&self, idx: NodeIndex) -> f64 {
1084        self.nodes[idx.get()].log_prob_kt
1085    }
1086
1087    #[inline(always)]
1088    unsafe fn child_ref_weighted_unchecked(&self, child: ChildRef) -> f64 {
1089        if child.is_none() {
1090            return 0.0;
1091        }
1092
1093        let raw = child.0;
1094        if (raw & CHILD_SEGMENT_TAG) == 0 {
1095            debug_assert!((raw as usize) < self.nodes.len());
1096            self.nodes.get_unchecked(raw as usize).log_prob_weighted
1097        } else {
1098            let idx = (raw & CHILD_INDEX_MASK) as usize;
1099            debug_assert!(idx < self.segments.len());
1100            self.segments.get_unchecked(idx).head_log_prob_weighted
1101        }
1102    }
1103
1104    #[inline(always)]
1105    fn child_ref_weighted(&self, child: ChildRef) -> f64 {
1106        unsafe { self.child_ref_weighted_unchecked(child) }
1107    }
1108
1109    #[inline(always)]
1110    fn singleton_segment_payload(&self, edge: usize) -> SegmentPayload {
1111        SegmentPayload::exact((edge & 1) as u64, 1)
1112    }
1113
1114    #[inline(always)]
1115    fn segment_edge(&self, segment_idx: SegmentIndex, offset: u32, history: &[Symbol]) -> usize {
1116        let segment = self.segments[segment_idx.get()];
1117        segment_edge_from_parts(segment, offset as usize, history, history.len()) as usize
1118    }
1119
1120    fn segment_suffix_weight(&self, segment_idx: SegmentIndex, offset: u32) -> f64 {
1121        let segment = self.segments[segment_idx.get()];
1122        if offset >= segment.len() {
1123            return self.child_ref_weighted(segment.tail);
1124        }
1125        if segment.tail.is_none() {
1126            return segment.log_prob_kt;
1127        }
1128        let remaining = segment.len() - offset;
1129        unary_chain_log_weight(
1130            segment.log_prob_kt,
1131            self.child_ref_weighted(segment.tail),
1132            remaining,
1133        )
1134    }
1135
1136    #[inline(always)]
1137    fn segment_continuation_weight(&self, segment_idx: SegmentIndex, offset: u32) -> f64 {
1138        let segment = self.segments[segment_idx.get()];
1139        if offset + 1 < segment.len() {
1140            self.segment_suffix_weight(segment_idx, offset + 1)
1141        } else {
1142            self.child_ref_weighted(segment.tail)
1143        }
1144    }
1145
1146    fn recompute_segment_head(&mut self, segment_idx: SegmentIndex) {
1147        let segment = self.segments[segment_idx.get()];
1148        let head = if segment.tail.is_some() {
1149            unary_chain_log_weight(
1150                segment.log_prob_kt,
1151                self.child_ref_weighted(segment.tail),
1152                segment.len(),
1153            )
1154        } else {
1155            segment.log_prob_kt
1156        };
1157        self.segments[segment_idx.get()].head_log_prob_weighted = head;
1158    }
1159
1160    fn recompute_node_weight(&mut self, idx: NodeIndex) {
1161        let slot = idx.get();
1162        debug_assert!(slot < self.nodes.len());
1163        let node = unsafe { *self.nodes.get_unchecked(slot) };
1164        let [left, right] = node.children;
1165        let weighted = if left.is_none() && right.is_none() {
1166            clamp_log_prob(node.log_prob_kt)
1167        } else {
1168            let w0 = unsafe { self.child_ref_weighted_unchecked(left) };
1169            let w1 = unsafe { self.child_ref_weighted_unchecked(right) };
1170            update_weighted_log_prob_non_leaf(node.log_prob_kt, w0, w1)
1171        };
1172        unsafe {
1173            self.nodes.get_unchecked_mut(slot).log_prob_weighted = weighted;
1174        }
1175    }
1176
1177    fn alloc_segment_with_parts(
1178        &mut self,
1179        symbol_count: [u32; 2],
1180        log_prob_kt: f64,
1181        tail: ChildRef,
1182        payload: SegmentPayload,
1183    ) -> SegmentIndex {
1184        let segment_idx = self.alloc_segment();
1185        self.segments[segment_idx.get()] = CtSegment {
1186            tail,
1187            log_prob_kt,
1188            head_log_prob_weighted: 0.0,
1189            symbol_count,
1190            payload,
1191        };
1192        if payload.len() == 1 && tail.is_none() {
1193            self.segments[segment_idx.get()].head_log_prob_weighted = log_prob_kt;
1194        } else {
1195            self.recompute_segment_head(segment_idx);
1196        }
1197        segment_idx
1198    }
1199
1200    fn detach_segment_continuation(
1201        &mut self,
1202        segment_idx: SegmentIndex,
1203        offset: u32,
1204        detaches: &mut Vec<Detach>,
1205    ) -> ChildRef {
1206        let segment = self.segments[segment_idx.get()];
1207        if offset + 1 < segment.len() {
1208            let suffix = self.alloc_segment_with_parts(
1209                segment.symbol_count,
1210                segment.log_prob_kt,
1211                segment.tail,
1212                segment.payload.suffix_after(offset + 1),
1213            );
1214            detaches.push(Detach::SegmentNext {
1215                segment: segment_idx,
1216                new_len: offset + 1,
1217            });
1218            ChildRef::from_segment(suffix)
1219        } else {
1220            let tail = segment.tail;
1221            if tail.is_some() {
1222                detaches.push(Detach::SegmentNext {
1223                    segment: segment_idx,
1224                    new_len: segment.len(),
1225                });
1226            }
1227            tail
1228        }
1229    }
1230
    /// Represents a new unary node at `depth` whose only child is `child`.
    ///
    /// The node holds the KT state (`symbol_count`, `log_prob_kt`) and reaches
    /// `child` through `edge`.
    ///
    /// If `child` is a segment with a bit-identical KT state, the code tries to grow
    /// that segment by one node at its front instead of allocating:
    /// - An exact payload is extended with `prepend_exact(edge)`.
    /// - Otherwise, when `allow_history_pattern` is set, a
    ///   `SegmentPayload::from_path` payload of length `len + 1` is used, but only if
    ///   it reproduces every existing edge of the segment.
    ///
    /// On success, the cached head weight is refreshed from the old head with
    /// `update_weighted_log_prob(log_prob_kt, old_head, 0.0, false)` and the same
    /// segment is returned.
    ///
    /// Otherwise a new length-1 segment is allocated. Its payload is the low bit of
    /// `edge` and its tail is `child`.
    ///
    /// NOTE(review): the merge also requires `segment.tail == child`, but `child`
    /// is the reference to this very segment. The test can only pass if a
    /// segment's tail points back at itself. As written, the merge branch appears
    /// unreachable and every call allocates a fresh singleton segment. Confirm the
    /// intended condition.
    fn prepend_or_alloc_segment(
        &mut self,
        history: &[Symbol],
        depth: usize,
        symbol_count: [u32; 2],
        log_prob_kt: f64,
        child: ChildRef,
        edge: usize,
        allow_history_pattern: bool,
    ) -> ChildRef {
        // Only consumed by the allocation fallback at the end.
        let singleton_payload = self.singleton_segment_payload(edge);

        if let Some(segment_idx) = child.as_segment() {
            let segment = self.segments[segment_idx.get()];
            // KT estimates are compared bit-for-bit, not with a float tolerance.
            let same_state = segment.symbol_count == symbol_count
                && segment.log_prob_kt.to_bits() == log_prob_kt.to_bits();
            // NOTE(review): `child` is this segment itself, so this compares the
            // segment's tail with the segment; see the function docs.
            if same_state && segment.tail == child {
                let segment = &mut self.segments[segment_idx.get()];
                let extended_payload = if segment.payload.is_exact() {
                    segment.payload.prepend_exact(edge)
                } else if allow_history_pattern {
                    let path_payload =
                        SegmentPayload::from_path(history, depth, segment.len().saturating_add(1));
                    // Keep the history-derived payload only if, for every position
                    // `offset`, its bit `offset + 1` equals the segment's existing
                    // edge at `offset`. Bit 0 is the new front edge.
                    path_payload.filter(|payload| {
                        let mut matches = true;
                        for offset in 0..segment.len() as usize {
                            let seg_edge =
                                segment_edge_from_parts(*segment, offset, history, history.len());
                            let payload_edge = ((payload.exact_bits() >> (offset + 1)) & 1) != 0;
                            if seg_edge != payload_edge {
                                matches = false;
                                break;
                            }
                        }
                        matches
                    })
                } else {
                    None
                };
                if let Some(payload) = extended_payload {
                    let old_head = segment.head_log_prob_weighted;
                    segment.payload = payload;
                    // The old head weight acts as the single child of the new front node.
                    segment.head_log_prob_weighted =
                        update_weighted_log_prob(log_prob_kt, old_head, 0.0, false);
                    return ChildRef::from_segment(segment_idx);
                }
            }
        }

        let segment_idx =
            self.alloc_segment_with_parts(symbol_count, log_prob_kt, child, singleton_payload);
        ChildRef::from_segment(segment_idx)
    }
1284
1285    fn free_child_ref(&mut self, child: ChildRef) {
1286        let mut stack = Vec::with_capacity(16);
1287        if child.is_some() {
1288            stack.push(child);
1289        }
1290        while let Some(next) = stack.pop() {
1291            if let Some(node_idx) = next.as_node() {
1292                let children = self.nodes[node_idx.get()].children;
1293                if children[0].is_some() {
1294                    stack.push(children[0]);
1295                }
1296                if children[1].is_some() {
1297                    stack.push(children[1]);
1298                }
1299                self.free_node(node_idx);
1300            } else if let Some(segment_idx) = next.as_segment() {
1301                let tail = self.segments[segment_idx.get()].tail;
1302                if tail.is_some() {
1303                    stack.push(tail);
1304                }
1305                self.free_segment(segment_idx);
1306            }
1307        }
1308    }
1309
1310    /// Approximate heap usage (bytes) for arena-owned storage.
1311    pub fn memory_usage(&self) -> usize {
1312        self.nodes.capacity() * size_of::<CtNode>()
1313            + self.segments.capacity() * size_of::<CtSegment>()
1314            + self.free_nodes.capacity() * size_of::<NodeIndex>()
1315            + self.free_segments.capacity() * size_of::<SegmentIndex>()
1316    }
1317}
1318
1319impl Default for CtArena {
1320    fn default() -> Self {
1321        Self::new()
1322    }
1323}
1324
/// CTW context tree plus the scratch state used while updating one context path.
#[derive(Clone)]
struct CtEngine {
    /// Node/segment storage for the whole tree.
    arena: CtArena,
    /// Root node; re-allocated by `clear`.
    root: NodeIndex,
    /// Maximum context depth (the `depth` passed to `new`).
    max_depth: usize,
    /// `segment_alpha[len] == 2^-len` for `len` in `0..=max_depth` (filled in `new`).
    segment_alpha: Vec<f64>,
    /// `segment_log_alpha[len] == -len * ln 2`, i.e. `ln(segment_alpha[len])`.
    segment_log_alpha: Vec<f64>,
    /// `ln(1 - segment_alpha[len])`; `-inf` at `len == 0`.
    segment_log_one_minus_alpha: Vec<f64>,
    /// One `LevelState` per depth (`max_depth` entries), reset to defaults by `clear`.
    levels: Vec<LevelState>,
    /// Scratch list of `Detach` records; emptied by `clear`.
    detaches: Vec<Detach>,
    /// Steps of the currently prepared context path; step 0 hangs directly off the root.
    prepared_steps: Vec<PreparedStep>,
    /// Depth of the deepest level covered by `prepared_steps`.
    prepared_levels: usize,
    /// How the prepared walk ended (e.g. `MaxDepth`, or `MissingAfterCurrent` when
    /// the next child does not exist yet).
    prepared_end: PreparedEnd,
}
1339
1340impl CtEngine {
1341    const RESERVE_MIN_NODES: usize = 4 * 1024;
1342    const RESERVE_MAX_NODES: usize = 1 << 18;
1343    const HOT_PREFIX_DEPTH: usize = 10;
1344
1345    fn new(depth: usize) -> Self {
1346        let mut arena = CtArena::with_capacity(1024.min(1 << depth.min(16)));
1347        let root = arena.alloc_node();
1348        let mut segment_alpha = Vec::with_capacity(depth + 1);
1349        let mut segment_log_alpha = Vec::with_capacity(depth + 1);
1350        let mut segment_log_one_minus_alpha = Vec::with_capacity(depth + 1);
1351        segment_alpha.push(1.0);
1352        segment_log_alpha.push(0.0);
1353        segment_log_one_minus_alpha.push(f64::NEG_INFINITY);
1354        let mut alpha = 1.0f64;
1355        for len in 1..=depth {
1356            alpha *= 0.5;
1357            segment_alpha.push(alpha);
1358            segment_log_alpha.push(-(len as f64) * std::f64::consts::LN_2);
1359            segment_log_one_minus_alpha.push((-alpha).ln_1p());
1360        }
1361        Self {
1362            arena,
1363            root,
1364            max_depth: depth,
1365            segment_alpha,
1366            segment_log_alpha,
1367            segment_log_one_minus_alpha,
1368            levels: vec![LevelState::default(); depth],
1369            detaches: Vec::with_capacity(depth),
1370            prepared_steps: Vec::with_capacity(depth),
1371            prepared_levels: 0,
1372            prepared_end: PreparedEnd::MaxDepth,
1373        }
1374    }
1375
1376    #[inline(always)]
1377    fn root_visits(&self) -> usize {
1378        self.arena.visits(self.root) as usize
1379    }
1380
1381    #[inline(always)]
1382    fn hot_prefix_depth(&self) -> usize {
1383        self.max_depth.min(Self::HOT_PREFIX_DEPTH)
1384    }
1385
1386    fn clear(&mut self) {
1387        self.arena.clear();
1388        self.root = self.arena.alloc_node();
1389        self.levels.fill(LevelState::default());
1390        self.detaches.clear();
1391        self.prepared_steps.clear();
1392        self.prepared_levels = 0;
1393        self.prepared_end = PreparedEnd::MaxDepth;
1394    }
1395
1396    #[inline]
1397    fn reserve_for_symbols(&mut self, total_symbols: usize) {
1398        if total_symbols == 0 {
1399            return;
1400        }
1401
1402        let depth_scale = self.max_depth.saturating_add(1);
1403        let reserve_nodes = total_symbols
1404            .saturating_div(depth_scale)
1405            .clamp(Self::RESERVE_MIN_NODES, Self::RESERVE_MAX_NODES);
1406        let free_nodes = self
1407            .arena
1408            .nodes
1409            .capacity()
1410            .saturating_sub(self.arena.nodes.len());
1411        if reserve_nodes > free_nodes {
1412            self.arena.reserve_exact(reserve_nodes - free_nodes);
1413        }
1414    }
1415
1416    #[inline]
1417    fn get_log_block_probability(&self) -> f64 {
1418        self.arena.log_prob_weighted(self.root)
1419    }
1420
1421    #[inline]
1422    fn with_logs<R>(&mut self, upto: usize, f: impl FnOnce(&mut Self, &[f64], &[f64]) -> R) -> R {
1423        with_shared_log_cache(upto, |log_int, log_half| f(self, log_int, log_half))
1424    }
1425
1426    #[inline]
1427    fn log_cache_memory_usage(&self) -> usize {
1428        shared_log_cache_memory_usage()
1429    }
1430
1431    #[inline(always)]
1432    fn segment_constants(&self, len: u32) -> (f64, f64, f64) {
1433        let idx = len as usize;
1434        debug_assert!(idx < self.segment_alpha.len());
1435        debug_assert!(idx < self.segment_log_alpha.len());
1436        debug_assert!(idx < self.segment_log_one_minus_alpha.len());
1437        (
1438            unsafe { *self.segment_alpha.get_unchecked(idx) },
1439            unsafe { *self.segment_log_alpha.get_unchecked(idx) },
1440            unsafe { *self.segment_log_one_minus_alpha.get_unchecked(idx) },
1441        )
1442    }
1443
    /// Builds a brand-new context path for depths `depth..=max_depth` using
    /// segments only.
    ///
    /// Every node on the path holds the one-symbol KT state: `sym_idx` counted
    /// once, with `singleton_log_prob_kt`. Returns `ChildRef::NONE` when
    /// `depth > max_depth`.
    ///
    /// If `SegmentPayload::from_path` can encode the whole path, it becomes one
    /// segment. Otherwise two segments are built bottom-up:
    /// - a constant-`false` segment for depths past the available history;
    /// - a history-backed segment covering depths
    ///   `depth..=min(max_depth, history.len() - 1)`, whose tail is the constant
    ///   segment.
    fn build_missing_segment_path(
        &mut self,
        depth: usize,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        let log_prob_kt = singleton_log_prob_kt;
        // Nodes at depths `depth..=max_depth`.
        let total_len = self.max_depth - depth + 1;

        if let Some(payload) = SegmentPayload::from_path(history, depth, total_len as u32) {
            let segment =
                self.arena
                    .alloc_segment_with_parts(counts, log_prob_kt, ChildRef::NONE, payload);
            return ChildRef::from_segment(segment);
        }

        // Depths that still index into `history`. The rest get a constant `false`
        // payload, presumably mirroring how `history_symbol` pads past the start
        // of history; confirm.
        let history_nodes = if depth < history.len() {
            (self.max_depth.min(history.len() - 1) - depth) + 1
        } else {
            0
        };
        let const_nodes = total_len - history_nodes;

        // Deepest part first, so the history segment can point at it.
        let mut built = ChildRef::NONE;
        if const_nodes > 0 {
            let const_segment = self.arena.alloc_segment_with_parts(
                counts,
                log_prob_kt,
                ChildRef::NONE,
                SegmentPayload::constant(false, const_nodes as u32),
            );
            built = ChildRef::from_segment(const_segment);
        }
        if history_nodes > 0 {
            let history_segment = self.arena.alloc_segment_with_parts(
                counts,
                log_prob_kt,
                built,
                // First argument is presumably the `history` index of the edge at
                // `depth` (`history.len() - 1 - depth`); confirm against `SegmentPayload`.
                SegmentPayload::history(
                    (history.len() - depth - 1) as u32,
                    history_nodes as u32,
                    false,
                ),
            );
            built = ChildRef::from_segment(history_segment);
        }
        built
    }
1499
    /// Builds a brand-new context path for depths `depth..=max_depth`.
    ///
    /// Every node holds the one-symbol KT state: `sym_idx` counted once, with
    /// `singleton_log_prob_kt`.
    ///
    /// Depths up to `hot_prefix_depth()` become explicit nodes. They are allocated
    /// deepest first, so each node links to the already-built subtree through edge
    /// `history_symbol(history, node_depth)`. Anything deeper is delegated to
    /// `build_missing_segment_path`.
    ///
    /// Returns the reference for depth `depth`, or `ChildRef::NONE` when
    /// `depth > max_depth`.
    fn build_missing_path(
        &mut self,
        depth: usize,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let hot_prefix_depth = self.hot_prefix_depth();
        if depth > hot_prefix_depth {
            // Entire path lies below the explicit-node prefix.
            return self.build_missing_segment_path(depth, history, sym_idx, singleton_log_prob_kt);
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        // Segment-encoded remainder below the explicit prefix, if any.
        let mut built = if hot_prefix_depth < self.max_depth {
            self.build_missing_segment_path(
                hot_prefix_depth + 1,
                history,
                sym_idx,
                singleton_log_prob_kt,
            )
        } else {
            ChildRef::NONE
        };

        // Bottom-up: each new node becomes the parent of the previously built subtree.
        for node_depth in (depth..=hot_prefix_depth).rev() {
            let node = self
                .arena
                .alloc_node_with_state(counts, singleton_log_prob_kt);
            // Nodes at `max_depth` have no children.
            if node_depth < self.max_depth {
                let edge = history_symbol(history, node_depth) as usize;
                self.arena.set_child(node, edge, built);
            }
            self.arena.recompute_node_weight(node);
            built = ChildRef::from_node(node);
        }
        built
    }
1542
1543    #[inline(always)]
1544    fn build_missing_segment_path_exact_bits(
1545        &mut self,
1546        depth: usize,
1547        path_bits: u64,
1548        sym_idx: usize,
1549        singleton_log_prob_kt: f64,
1550    ) -> ChildRef {
1551        debug_assert!(self.max_depth <= SEG_EXACT_MAX_LEN as usize);
1552        if depth > self.max_depth {
1553            return ChildRef::NONE;
1554        }
1555
1556        let mut counts = [0u32; 2];
1557        counts[sym_idx] = 1;
1558        let total_len = self.max_depth - depth + 1;
1559        let payload = SegmentPayload::exact(
1560            path_bits & low_bits_mask_u64(total_len as u32),
1561            total_len as u32,
1562        );
1563        let segment = self.arena.alloc_segment_with_parts(
1564            counts,
1565            singleton_log_prob_kt,
1566            ChildRef::NONE,
1567            payload,
1568        );
1569        ChildRef::from_segment(segment)
1570    }
1571
    /// Bit-packed counterpart of `build_missing_path`.
    ///
    /// Bit `k` of `path_bits` is the edge taken at depth `depth + k`. Depths up to
    /// `hot_prefix_depth()` become explicit nodes, allocated deepest first. Deeper
    /// levels go into a single exact-payload segment built by
    /// `build_missing_segment_path_exact_bits` from the shifted bits.
    ///
    /// Returns `ChildRef::NONE` when `depth > max_depth`. Requires
    /// `max_depth <= SEG_EXACT_MAX_LEN` (checked in debug builds).
    #[inline(always)]
    fn build_missing_path_exact_bits(
        &mut self,
        depth: usize,
        path_bits: u64,
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        debug_assert!(self.max_depth <= SEG_EXACT_MAX_LEN as usize);
        if depth > self.max_depth {
            return ChildRef::NONE;
        }

        let hot_prefix_depth = self.hot_prefix_depth();
        if depth > hot_prefix_depth {
            // Entire path lies below the explicit-node prefix.
            return self.build_missing_segment_path_exact_bits(
                depth,
                path_bits,
                sym_idx,
                singleton_log_prob_kt,
            );
        }

        let mut counts = [0u32; 2];
        counts[sym_idx] = 1;
        // Re-base the bits so the segment's bit 0 is the edge at `hot_prefix_depth + 1`.
        let mut built = if hot_prefix_depth < self.max_depth {
            self.build_missing_segment_path_exact_bits(
                hot_prefix_depth + 1,
                shift_path_bits(path_bits, hot_prefix_depth + 1 - depth),
                sym_idx,
                singleton_log_prob_kt,
            )
        } else {
            ChildRef::NONE
        };

        // Bottom-up: each new node becomes the parent of the previously built subtree.
        for node_depth in (depth..=hot_prefix_depth).rev() {
            let node = self
                .arena
                .alloc_node_with_state(counts, singleton_log_prob_kt);
            // Nodes at `max_depth` have no children.
            if node_depth < self.max_depth {
                let edge = ((path_bits >> (node_depth - depth)) & 1) as usize;
                self.arena.set_child(node, edge, built);
            }
            self.arena.recompute_node_weight(node);
            built = ChildRef::from_node(node);
        }
        built
    }
1621
1622    #[inline(always)]
1623    fn child_to_existing_source(child: ChildRef) -> Option<ExistingSource> {
1624        if let Some(node) = child.as_node() {
1625            Some(ExistingSource::Node(node))
1626        } else if let Some(segment) = child.as_segment() {
1627            Some(ExistingSource::Segment(segment, 0))
1628        } else {
1629            None
1630        }
1631    }
1632
1633    #[inline(always)]
1634    fn update_source_state(
1635        &mut self,
1636        log_int: &[f64],
1637        log_half: &[f64],
1638        source: ExistingSource,
1639        sym_idx: usize,
1640    ) {
1641        match source {
1642            ExistingSource::Node(node_idx) => {
1643                let slot = node_idx.get();
1644                let mut counts = self.arena.nodes[slot].symbol_count;
1645                let mut log_prob_kt = self.arena.nodes[slot].log_prob_kt;
1646                apply_update_to_state_raw(
1647                    log_int,
1648                    log_half,
1649                    &mut counts,
1650                    &mut log_prob_kt,
1651                    sym_idx,
1652                );
1653                self.arena.nodes[slot].symbol_count = counts;
1654                self.arena.nodes[slot].log_prob_kt = log_prob_kt;
1655            }
1656            ExistingSource::Segment(segment_idx, _) => {
1657                let slot = segment_idx.get();
1658                let mut counts = self.arena.segments[slot].symbol_count;
1659                let mut log_prob_kt = self.arena.segments[slot].log_prob_kt;
1660                apply_update_to_state_raw(
1661                    log_int,
1662                    log_half,
1663                    &mut counts,
1664                    &mut log_prob_kt,
1665                    sym_idx,
1666                );
1667                self.arena.segments[slot].symbol_count = counts;
1668                self.arena.segments[slot].log_prob_kt = log_prob_kt;
1669            }
1670            ExistingSource::None => unreachable!("prepared update should never visit None"),
1671        }
1672    }
1673
1674    #[inline(always)]
1675    fn recompute_source_weight(&mut self, source: ExistingSource) {
1676        match source {
1677            ExistingSource::Node(node_idx) => self.arena.recompute_node_weight(node_idx),
1678            ExistingSource::Segment(segment_idx, _) => self.recompute_segment_head(segment_idx),
1679            ExistingSource::None => unreachable!("prepared update should never visit None"),
1680        }
1681    }
1682
1683    #[inline(always)]
1684    fn recompute_segment_head(&mut self, segment_idx: SegmentIndex) {
1685        let segment = self.arena.segments[segment_idx.get()];
1686        let head = if segment.tail.is_some() {
1687            let (alpha, log_alpha, log_one_minus_alpha) = self.segment_constants(segment.len());
1688            unary_chain_log_weight_precomputed(
1689                segment.log_prob_kt,
1690                self.arena.child_ref_weighted(segment.tail),
1691                alpha,
1692                log_alpha,
1693                log_one_minus_alpha,
1694            )
1695        } else {
1696            segment.log_prob_kt
1697        };
1698        self.arena.segments[segment_idx.get()].head_log_prob_weighted = head;
1699    }
1700
1701    fn attach_missing_after_prepared_path(
1702        &mut self,
1703        history: &[Symbol],
1704        sym_idx: usize,
1705        singleton_log_prob_kt: f64,
1706    ) {
1707        let Some(last_step) = self.prepared_steps.last().copied() else {
1708            return;
1709        };
1710        let depth = self.prepared_levels;
1711        match last_step.source {
1712            ExistingSource::Node(node_idx) => {
1713                debug_assert!(depth < self.max_depth);
1714                let path_edge = history_symbol(history, depth) as usize;
1715                debug_assert!(self.arena.child(node_idx, path_edge).is_none());
1716                let new_child =
1717                    self.build_missing_path(depth + 1, history, sym_idx, singleton_log_prob_kt);
1718                self.arena.set_child(node_idx, path_edge, new_child);
1719            }
1720            ExistingSource::Segment(segment_idx, offset) => {
1721                debug_assert!(depth < self.max_depth);
1722                debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
1723                debug_assert!(self.arena.segments[segment_idx.get()].tail.is_none());
1724                let new_tail =
1725                    self.build_missing_path(depth + 1, history, sym_idx, singleton_log_prob_kt);
1726                self.arena.set_segment_tail(segment_idx, new_tail);
1727            }
1728            ExistingSource::None => unreachable!("prepared path should never end in None source"),
1729        }
1730    }
1731
1732    fn replace_prepared_child(
1733        &mut self,
1734        history: &[Symbol],
1735        step_index: usize,
1736        current_start_depth: usize,
1737        new_child: ChildRef,
1738    ) {
1739        if step_index == 0 {
1740            let root_edge = history_symbol(history, 0) as usize;
1741            self.arena.set_child(self.root, root_edge, new_child);
1742            return;
1743        }
1744
1745        match self.prepared_steps[step_index - 1].source {
1746            ExistingSource::Node(node_idx) => {
1747                let edge = history_symbol(history, current_start_depth - 1) as usize;
1748                self.arena.set_child(node_idx, edge, new_child);
1749            }
1750            ExistingSource::Segment(segment_idx, offset) => {
1751                debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
1752                self.arena.set_segment_tail(segment_idx, new_child);
1753            }
1754            ExistingSource::None => unreachable!("prepared path should never parent from None"),
1755        }
1756    }
1757
    /// Applies an update whose context path diverges inside a segment.
    ///
    /// Preconditions visible here:
    /// - `prepared_steps` is non-empty.
    /// - Its last step is `Segment(segment_idx, offset)`.
    /// - The segment's edge at `offset` differs from the edge `history` requires
    ///   at that depth.
    ///
    /// Steps:
    /// 1. Every earlier prepared step has its KT state updated with `sym_idx`.
    /// 2. The part of the run after `offset` (or the segment's tail, if `offset`
    ///    is the last node) becomes `old_continuation`. With `offset == 0` and a
    ///    longer run, the segment slot itself is reused for that suffix; otherwise
    ///    a new segment is allocated.
    /// 3. A fresh path for the current context is built below the divergence
    ///    depth. A branching node with the updated KT state receives both
    ///    continuations.
    /// 4. The branch is linked in.
    ///    - At `offset == 0` it replaces the segment in its parent; a length-1
    ///      segment is freed.
    ///    - Otherwise the segment is truncated to its first `offset` nodes, with
    ///      the branch as its tail and the updated KT state.
    /// 5. Cached weights of the earlier steps are recomputed bottom-up.
    ///
    /// Returns the root's child on the path edge. The root's own weight is not
    /// recomputed here.
    fn update_prepared_mismatch(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        // Steps above the mismatching segment: KT state only for now. Their cached
        // weights are refreshed at the end, once the subtree below has changed.
        let last_index = self.prepared_steps.len() - 1;
        for idx in 0..last_index {
            self.update_source_state(log_int, log_half, self.prepared_steps[idx].source, sym_idx);
        }

        let last_step = self.prepared_steps[last_index];
        let ExistingSource::Segment(segment_idx, offset_u32) = last_step.source else {
            unreachable!("prepared segment mismatch must end at a segment");
        };

        // Snapshot of the segment before any modification.
        let original = self.arena.segments[segment_idx.get()];
        let offset = offset_u32 as usize;
        let seg_len = original.len() as usize;
        let history_len = history.len();
        // Depth of the segment's first node: the last step spans `span` levels ending
        // at `prepared_levels`.
        let current_start_depth = self.prepared_levels - last_step.span as usize + 1;
        // Depth at which the stored edge and the required edge disagree.
        let node_depth = current_start_depth + offset;
        let path_edge = path_edge_at_depth(history, history_len, node_depth);
        let existing_edge = segment_edge_from_parts(original, offset, history, history_len);
        debug_assert_ne!(path_edge, existing_edge);

        // Existing continuation below the divergence point, kept on `existing_edge`.
        let old_continuation = if offset + 1 < seg_len {
            if offset == 0 {
                // The head node becomes the branch, so reuse this slot for the suffix.
                let segment = &mut self.arena.segments[segment_idx.get()];
                segment.payload = original.payload.suffix_after(1);
                segment.tail = original.tail;
                segment.symbol_count = original.symbol_count;
                segment.log_prob_kt = original.log_prob_kt;
                self.recompute_segment_head(segment_idx);
                ChildRef::from_segment(segment_idx)
            } else {
                ChildRef::from_segment(self.arena.alloc_segment_with_parts(
                    original.symbol_count,
                    original.log_prob_kt,
                    original.tail,
                    original.payload.suffix_after(offset as u32 + 1),
                ))
            }
        } else {
            original.tail
        };

        // New path for the current context, hung on `path_edge`.
        let new_tail =
            self.build_missing_path(node_depth + 1, history, sym_idx, singleton_log_prob_kt);
        // KT state of nodes on the path (the prefix and the branch) after this symbol.
        let mut updated_counts = original.symbol_count;
        let mut updated_log_prob_kt = original.log_prob_kt;
        apply_update_to_state_raw(
            log_int,
            log_half,
            &mut updated_counts,
            &mut updated_log_prob_kt,
            sym_idx,
        );

        // The node at `offset` becomes an explicit branching node.
        let branch = self
            .arena
            .alloc_node_with_state(updated_counts, updated_log_prob_kt);
        self.arena
            .set_child(branch, existing_edge as usize, old_continuation);
        self.arena.set_child(branch, path_edge as usize, new_tail);
        self.arena.recompute_node_weight(branch);

        if offset == 0 {
            // No prefix to keep. A length-1 segment has no suffix either, so its slot
            // is released; then the branch takes the segment's place in the parent.
            if seg_len == 1 {
                self.arena.free_segment(segment_idx);
            }
            self.replace_prepared_child(
                history,
                last_index,
                current_start_depth,
                ChildRef::from_node(branch),
            );
        } else {
            // Keep the first `offset` nodes as a shorter segment ending in the branch.
            let segment = &mut self.arena.segments[segment_idx.get()];
            segment.payload = original.payload.prefix(offset as u32);
            segment.tail = ChildRef::from_node(branch);
            segment.symbol_count = updated_counts;
            segment.log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
        }

        // Bottom-up refresh of cached weights for the steps above the segment.
        for idx in (0..last_index).rev() {
            self.recompute_source_weight(self.prepared_steps[idx].source);
        }

        let root_edge = history_symbol(history, 0) as usize;
        self.arena.child(self.root, root_edge)
    }
1853
    /// Replays a symbol update along the context path cached by the preceding
    /// prepare step, writing the updated counts, KT log-probabilities and weighted
    /// log-probabilities straight into the existing nodes/segments instead of
    /// re-walking the tree.
    ///
    /// Expects `prepared_end` to be `MaxDepth` or `MissingAfterCurrent` and
    /// `prepared_steps` to be non-empty (both only checked in debug builds). For
    /// `MissingAfterCurrent` the not-yet-existing part of the path is attached
    /// below the last prepared step first; its weight then seeds the bottom-up
    /// recomputation.
    fn update_prepared_cached_path(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) {
        debug_assert!(!self.prepared_steps.is_empty());
        debug_assert!(matches!(
            self.prepared_end,
            PreparedEnd::MaxDepth | PreparedEnd::MissingAfterCurrent
        ));

        // Build and hook up the missing suffix first so the last prepared step has
        // a real child whose weight can be read just below.
        if self.prepared_end == PreparedEnd::MissingAfterCurrent {
            self.attach_missing_after_prepared_path(history, sym_idx, singleton_log_prob_kt);
        }

        let last_index = self.prepared_steps.len() - 1;
        // Weighted log-probability of the child below the step being processed.
        // For `MaxDepth` the deepest step is handled by the leaf branch (which uses
        // only its own KT estimate), so the 0.0 seed is never consumed.
        let mut child_weight = if self.prepared_end == PreparedEnd::MissingAfterCurrent {
            let last_step = self.prepared_steps[last_index];
            match last_step.source {
                ExistingSource::Node(node_idx) => {
                    // The freshly attached child hangs off the path edge at the
                    // depth of the last prepared node.
                    let depth = self.prepared_levels;
                    let edge = history_symbol(history, depth) as usize;
                    self.arena
                        .child_ref_weighted(self.arena.child(node_idx, edge))
                }
                ExistingSource::Segment(segment_idx, offset) => {
                    // A segment can only end the prepared path at its final
                    // position; the attached continuation is its tail.
                    debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
                    self.arena
                        .child_ref_weighted(self.arena.segments[segment_idx.get()].tail)
                }
                ExistingSource::None => unreachable!("prepared path should never end in None"),
            }
        } else {
            0.0
        };

        // Walk from the deepest cached step back towards the root, folding each
        // updated child weight into its parent.
        for idx in (0..=last_index).rev() {
            let step = self.prepared_steps[idx];
            match step.source {
                ExistingSource::Node(node_idx) => {
                    // Start from the state captured at prepare time, not the arena.
                    let mut counts = step.counts;
                    let mut log_prob_kt = step.kt_log_prob;
                    apply_update_to_state_raw(
                        log_int,
                        log_half,
                        &mut counts,
                        &mut log_prob_kt,
                        sym_idx,
                    );
                    let weighted =
                        if idx == last_index && self.prepared_end == PreparedEnd::MaxDepth {
                            // Leaf at max depth: its weight is its KT estimate.
                            debug_assert_eq!(step.has_sibling, 0);
                            clamp_log_prob(log_prob_kt)
                        } else {
                            // Combine KT estimate with the updated path child and the
                            // sibling weight cached at prepare time.
                            // NOTE(review): semantics of the trailing `false` flag are
                            // defined by `update_weighted_log_prob`; confirm there.
                            update_weighted_log_prob(
                                log_prob_kt,
                                child_weight,
                                step.sibling_weight,
                                false,
                            )
                        };
                    let slot = node_idx.get();
                    self.arena.nodes[slot].symbol_count = counts;
                    self.arena.nodes[slot].log_prob_kt = log_prob_kt;
                    self.arena.nodes[slot].log_prob_weighted = weighted;
                    child_weight = weighted;
                }
                ExistingSource::Segment(segment_idx, offset) => {
                    let mut counts = step.counts;
                    let mut log_prob_kt = step.kt_log_prob;
                    apply_update_to_state_raw(
                        log_int,
                        log_half,
                        &mut counts,
                        &mut log_prob_kt,
                        sym_idx,
                    );
                    let slot = segment_idx.get();
                    let weighted =
                        if idx == last_index && self.prepared_end == PreparedEnd::MaxDepth {
                            // A run ending at max depth has no tail; its head weight
                            // is its KT estimate.
                            debug_assert_eq!(offset + 1, self.arena.segment_len(segment_idx));
                            debug_assert!(self.arena.segments[slot].tail.is_none());
                            clamp_log_prob(log_prob_kt)
                        } else {
                            // The whole unary run collapses into one mix of its KT
                            // estimate and the tail weight; the mixing constants
                            // depend only on the run length.
                            let (alpha, log_alpha, log_one_minus_alpha) =
                                self.segment_constants(self.arena.segments[slot].len());
                            unary_chain_log_weight_precomputed(
                                log_prob_kt,
                                child_weight,
                                alpha,
                                log_alpha,
                                log_one_minus_alpha,
                            )
                        };
                    self.arena.segments[slot].symbol_count = counts;
                    self.arena.segments[slot].log_prob_kt = log_prob_kt;
                    self.arena.segments[slot].head_log_prob_weighted = weighted;
                    child_weight = weighted;
                }
                ExistingSource::None => unreachable!("prepared update should never visit None"),
            }
        }
    }
1960
    /// Recursively applies an update for symbol `sym_idx` to the subtree `child`,
    /// whose first context position sits at `depth`, following the edges given by
    /// `history`.
    ///
    /// Returns the reference the parent should store for this edge. It differs
    /// from `child` when a missing path is built, or when a segment's first
    /// position has to become an explicit branching node.
    fn update_child_fast(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        child: ChildRef,
        depth: usize,
        history: &[Symbol],
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        // Beyond the modelled depth nothing is stored or updated.
        if depth > self.max_depth {
            return child;
        }
        // Unseen context: materialize the whole remaining path in one go.
        if child.is_none() {
            return self.build_missing_path(depth, history, sym_idx, singleton_log_prob_kt);
        }

        if let Some(node_idx) = child.as_node() {
            // Explicit node: update the path child first (only nodes above max
            // depth have children), then this node's own state and weight.
            if depth < self.max_depth {
                let path_edge = history_symbol(history, depth) as usize;
                let next = self.arena.child(node_idx, path_edge);
                let updated = self.update_child_fast(
                    log_int,
                    log_half,
                    next,
                    depth + 1,
                    history,
                    sym_idx,
                    singleton_log_prob_kt,
                );
                if updated != next {
                    self.arena.set_child(node_idx, path_edge, updated);
                }
            }
            let mut counts = self.arena.nodes[node_idx.get()].symbol_count;
            let mut log_prob_kt = self.arena.nodes[node_idx.get()].log_prob_kt;
            apply_update_to_state_raw(log_int, log_half, &mut counts, &mut log_prob_kt, sym_idx);
            self.arena.nodes[node_idx.get()].symbol_count = counts;
            self.arena.nodes[node_idx.get()].log_prob_kt = log_prob_kt;
            self.arena.recompute_node_weight(node_idx);
            return ChildRef::from_node(node_idx);
        }

        // Segment: a unary run storing a single count/KT state for all of its
        // positions, so the post-update state is computed once up front.
        let segment_idx = child.as_segment().unwrap();
        let original = self.arena.segments[segment_idx.get()];
        let seg_len = original.len() as usize;
        let mut updated_counts = original.symbol_count;
        let mut updated_log_prob_kt = original.log_prob_kt;
        apply_update_to_state_raw(
            log_int,
            log_half,
            &mut updated_counts,
            &mut updated_log_prob_kt,
            sym_idx,
        );

        // Only positions shallower than max depth carry a path edge, and without a
        // tail the last stored edge leads nowhere, so it is not compared.
        let depth_budget = self.max_depth.saturating_sub(depth);
        let comparable_len = if original.tail.is_none() {
            seg_len.saturating_sub(1)
        } else {
            seg_len
        }
        .min(depth_budget);
        let mismatch = first_segment_mismatch(original, depth, history, comparable_len).map(
            |(offset, path_edge, existing_edge)| (offset, depth + offset, path_edge, existing_edge),
        );

        if let Some((offset, node_depth, path_edge, existing_edge)) = mismatch {
            // The path leaves the run at `offset`. Positions after it keep their
            // pre-update state and become the off-path continuation.
            let old_continuation = if offset + 1 < seg_len {
                if offset == 0 {
                    // Reuse this segment slot for the remainder of the run.
                    let segment = &mut self.arena.segments[segment_idx.get()];
                    segment.payload = original.payload.suffix_after(1);
                    segment.tail = original.tail;
                    segment.symbol_count = original.symbol_count;
                    segment.log_prob_kt = original.log_prob_kt;
                    self.recompute_segment_head(segment_idx);
                    ChildRef::from_segment(segment_idx)
                } else {
                    ChildRef::from_segment(self.arena.alloc_segment_with_parts(
                        original.symbol_count,
                        original.log_prob_kt,
                        original.tail,
                        original.payload.suffix_after(offset as u32 + 1),
                    ))
                }
            } else {
                original.tail
            };

            // The mismatching position becomes an explicit branching node: the old
            // continuation on one edge, a freshly built path on the other.
            let new_tail =
                self.build_missing_path(node_depth + 1, history, sym_idx, singleton_log_prob_kt);
            let branch = self
                .arena
                .alloc_node_with_state(updated_counts, updated_log_prob_kt);
            self.arena
                .set_child(branch, existing_edge as usize, old_continuation);
            self.arena.set_child(branch, path_edge as usize, new_tail);
            self.arena.recompute_node_weight(branch);

            if offset == 0 {
                // A length-1 run was entirely replaced by the branch node.
                if offset + 1 >= seg_len {
                    self.arena.free_segment(segment_idx);
                }
                return ChildRef::from_node(branch);
            }

            // Otherwise the positions before `offset` stay in this segment, now
            // ending in the branch node.
            let segment = &mut self.arena.segments[segment_idx.get()];
            segment.payload = original.payload.prefix(offset as u32);
            segment.tail = ChildRef::from_node(branch);
            segment.symbol_count = updated_counts;
            segment.log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        // The run reaches max depth: no tail to descend into.
        if depth_budget < seg_len {
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        // Run ends above max depth with nothing below: build the rest of the path.
        if original.tail.is_none() {
            let new_tail =
                self.build_missing_path(depth + seg_len, history, sym_idx, singleton_log_prob_kt);
            self.arena.segments[segment_idx.get()].tail = new_tail;
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        // Whole run matched: continue into the tail, then refresh this run's state.
        let tail = original.tail;
        let updated_tail = self.update_child_fast(
            log_int,
            log_half,
            tail,
            depth + seg_len,
            history,
            sym_idx,
            singleton_log_prob_kt,
        );
        if updated_tail != tail {
            self.arena.set_segment_tail(segment_idx, updated_tail);
        }
        self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
        self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
        self.recompute_segment_head(segment_idx);
        ChildRef::from_segment(segment_idx)
    }
2111
    /// Variant of `update_child_fast` that follows a pre-packed path instead of
    /// indexing `history` at every level.
    ///
    /// `path_bits` holds the path edges starting at `depth`, with the edge at
    /// `depth` in the lowest bit; `shift_path_bits` drops the consumed edges
    /// before recursing. Only valid when `max_depth <= SEG_EXACT_MAX_LEN`.
    /// Segments with a non-exact payload fall back to `update_child_fast` for
    /// their whole subtree. Returns the reference the parent should store for
    /// this edge.
    fn update_child_fast_exact(
        &mut self,
        log_int: &[f64],
        log_half: &[f64],
        child: ChildRef,
        depth: usize,
        history: &[Symbol],
        path_bits: u64,
        sym_idx: usize,
        singleton_log_prob_kt: f64,
    ) -> ChildRef {
        debug_assert!(self.max_depth <= SEG_EXACT_MAX_LEN as usize);
        if depth > self.max_depth {
            return child;
        }
        // Unseen context: build the remaining path directly from the packed bits.
        if child.is_none() {
            return self.build_missing_path_exact_bits(
                depth,
                path_bits,
                sym_idx,
                singleton_log_prob_kt,
            );
        }

        if let Some(node_idx) = child.as_node() {
            if depth < self.max_depth {
                // Low bit is the path edge at this depth.
                let path_edge = (path_bits & 1) as usize;
                let next = self.arena.child(node_idx, path_edge);
                let updated = self.update_child_fast_exact(
                    log_int,
                    log_half,
                    next,
                    depth + 1,
                    history,
                    shift_path_bits(path_bits, 1),
                    sym_idx,
                    singleton_log_prob_kt,
                );
                if updated != next {
                    self.arena.set_child(node_idx, path_edge, updated);
                }
            }
            let mut counts = self.arena.nodes[node_idx.get()].symbol_count;
            let mut log_prob_kt = self.arena.nodes[node_idx.get()].log_prob_kt;
            apply_update_to_state_raw(log_int, log_half, &mut counts, &mut log_prob_kt, sym_idx);
            self.arena.nodes[node_idx.get()].symbol_count = counts;
            self.arena.nodes[node_idx.get()].log_prob_kt = log_prob_kt;
            self.arena.recompute_node_weight(node_idx);
            return ChildRef::from_node(node_idx);
        }

        let segment_idx = child.as_segment().unwrap();
        let original = self.arena.segments[segment_idx.get()];
        // Bit-level comparison needs an exact payload; otherwise use the generic
        // walk (which never switches back to this variant).
        if !original.payload.is_exact() {
            return self.update_child_fast(
                log_int,
                log_half,
                child,
                depth,
                history,
                sym_idx,
                singleton_log_prob_kt,
            );
        }

        // One shared count/KT state for the whole run, updated once.
        let seg_len = original.len() as usize;
        let mut updated_counts = original.symbol_count;
        let mut updated_log_prob_kt = original.log_prob_kt;
        apply_update_to_state_raw(
            log_int,
            log_half,
            &mut updated_counts,
            &mut updated_log_prob_kt,
            sym_idx,
        );

        // Same comparable-length rule as the generic walk: no edges at or past max
        // depth, and the last edge of a tail-less run is not compared.
        let depth_budget = self.max_depth.saturating_sub(depth);
        let comparable_len = if original.tail.is_none() {
            seg_len.saturating_sub(1)
        } else {
            seg_len
        }
        .min(depth_budget);
        let mismatch =
            first_exact_segment_mismatch(original.payload.exact_bits(), path_bits, comparable_len)
                .map(|(offset, path_edge, existing_edge)| {
                    (offset, depth + offset, path_edge, existing_edge)
                });

        if let Some((offset, node_depth, path_edge, existing_edge)) = mismatch {
            // Positions after the mismatch keep their pre-update state and become
            // the off-path continuation.
            let old_continuation = if offset + 1 < seg_len {
                if offset == 0 {
                    let segment = &mut self.arena.segments[segment_idx.get()];
                    segment.payload = original.payload.suffix_after(1);
                    segment.tail = original.tail;
                    segment.symbol_count = original.symbol_count;
                    segment.log_prob_kt = original.log_prob_kt;
                    self.recompute_segment_head(segment_idx);
                    ChildRef::from_segment(segment_idx)
                } else {
                    ChildRef::from_segment(self.arena.alloc_segment_with_parts(
                        original.symbol_count,
                        original.log_prob_kt,
                        original.tail,
                        original.payload.suffix_after(offset as u32 + 1),
                    ))
                }
            } else {
                original.tail
            };

            // New path starts one below the branch node: drop `offset + 1` edges.
            let new_tail = self.build_missing_path_exact_bits(
                node_depth + 1,
                shift_path_bits(path_bits, offset + 1),
                sym_idx,
                singleton_log_prob_kt,
            );
            let branch = self
                .arena
                .alloc_node_with_state(updated_counts, updated_log_prob_kt);
            self.arena
                .set_child(branch, existing_edge as usize, old_continuation);
            self.arena.set_child(branch, path_edge as usize, new_tail);
            self.arena.recompute_node_weight(branch);

            if offset == 0 {
                // A length-1 run was entirely replaced by the branch node.
                if offset + 1 >= seg_len {
                    self.arena.free_segment(segment_idx);
                }
                return ChildRef::from_node(branch);
            }

            let segment = &mut self.arena.segments[segment_idx.get()];
            segment.payload = original.payload.prefix(offset as u32);
            segment.tail = ChildRef::from_node(branch);
            segment.symbol_count = updated_counts;
            segment.log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        // Run reaches max depth: just refresh its state.
        if depth_budget < seg_len {
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        // Run ends above max depth with nothing below: build the rest of the path.
        if original.tail.is_none() {
            let new_tail = self.build_missing_path_exact_bits(
                depth + seg_len,
                shift_path_bits(path_bits, seg_len),
                sym_idx,
                singleton_log_prob_kt,
            );
            self.arena.segments[segment_idx.get()].tail = new_tail;
            self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
            self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
            self.recompute_segment_head(segment_idx);
            return ChildRef::from_segment(segment_idx);
        }

        // Whole run matched: descend into the tail with the run's edges consumed.
        let tail = original.tail;
        let updated_tail = self.update_child_fast_exact(
            log_int,
            log_half,
            tail,
            depth + seg_len,
            history,
            shift_path_bits(path_bits, seg_len),
            sym_idx,
            singleton_log_prob_kt,
        );
        if updated_tail != tail {
            self.arena.set_segment_tail(segment_idx, updated_tail);
        }
        self.arena.segments[segment_idx.get()].symbol_count = updated_counts;
        self.arena.segments[segment_idx.get()].log_prob_kt = updated_log_prob_kt;
        self.recompute_segment_head(segment_idx);
        ChildRef::from_segment(segment_idx)
    }
2293
2294    #[inline(always)]
2295    fn update_root_child(
2296        &mut self,
2297        log_int: &[f64],
2298        log_half: &[f64],
2299        child: ChildRef,
2300        history: &[Symbol],
2301        sym_idx: usize,
2302        singleton_log_prob_kt: f64,
2303    ) -> ChildRef {
2304        if self.max_depth <= SEG_EXACT_MAX_LEN as usize {
2305            let path_bits = path_bits_from_history(history, 1, self.max_depth);
2306            self.update_child_fast_exact(
2307                log_int,
2308                log_half,
2309                child,
2310                1,
2311                history,
2312                path_bits,
2313                sym_idx,
2314                singleton_log_prob_kt,
2315            )
2316        } else {
2317            self.update_child_fast(
2318                log_int,
2319                log_half,
2320                child,
2321                1,
2322                history,
2323                sym_idx,
2324                singleton_log_prob_kt,
2325            )
2326        }
2327    }
2328
    /// Snapshots the existing context path for `history` into `self.levels`
    /// (one entry per depth `1..=max_depth`, at index `depth - 1`) and queues the
    /// detaches needed to cut off-path siblings away from the old path.
    ///
    /// Each level records the count/KT state found at that depth and the off-path
    /// subtree (`sibling`), if any. Depths the existing tree does not reach stay at
    /// `LevelState::default()`. Returns the root's original child on the path
    /// edge. With `max_depth == 0` only the detach queue is cleared.
    fn collect_existing_levels(&mut self, history: &[Symbol]) -> ChildRef {
        if self.max_depth == 0 {
            self.detaches.clear();
            return ChildRef::NONE;
        }

        self.detaches.clear();
        self.levels.fill(LevelState::default());

        let root_edge = history_symbol(history, 0) as usize;
        let old_child = self.arena.child(self.root, root_edge);
        // Cursor over the existing path: a node, a position inside a segment, or
        // nothing once the path has run out.
        let mut source = if let Some(node) = old_child.as_node() {
            ExistingSource::Node(node)
        } else if let Some(segment) = old_child.as_segment() {
            ExistingSource::Segment(segment, 0)
        } else {
            ExistingSource::None
        };

        for depth in 1..=self.max_depth {
            let slot = depth - 1;
            // NOTE(review): redundant with the `fill` above; kept as-is.
            self.levels[slot] = LevelState::default();

            match source {
                ExistingSource::None => {}
                ExistingSource::Node(node_idx) => {
                    self.levels[slot].symbol_count = self.arena.counts(node_idx);
                    self.levels[slot].log_prob_kt = self.arena.log_prob_kt(node_idx);
                    if depth < self.max_depth {
                        let path_edge = history_symbol(history, depth) as usize;
                        let sibling_edge = path_edge ^ 1;
                        let sibling = self.arena.child(node_idx, sibling_edge);
                        self.levels[slot].sibling = sibling;
                        // Remember to unlink the sibling from the old node; it is
                        // re-attached when the path is rebuilt.
                        if sibling.is_some() {
                            self.detaches.push(Detach::NodeChild {
                                node: node_idx,
                                edge: sibling_edge,
                            });
                        }
                        let next = self.arena.child(node_idx, path_edge);
                        source = if let Some(next_node) = next.as_node() {
                            ExistingSource::Node(next_node)
                        } else if let Some(next_segment) = next.as_segment() {
                            ExistingSource::Segment(next_segment, 0)
                        } else {
                            ExistingSource::None
                        };
                    }
                }
                ExistingSource::Segment(segment_idx, offset) => {
                    // All positions of a run share one count/KT state.
                    self.levels[slot].symbol_count = self.arena.segment_symbol_count(segment_idx);
                    self.levels[slot].log_prob_kt = self.arena.segment_log_prob_kt(segment_idx);
                    if depth < self.max_depth {
                        let path_edge = history_symbol(history, depth) as usize;
                        if self.arena.segment_has_child(segment_idx, offset) {
                            let existing_edge =
                                self.arena.segment_edge(segment_idx, offset, history);
                            if path_edge == existing_edge {
                                // Path continues along the run (or into its tail).
                                let seg_len = self.arena.segment_len(segment_idx);
                                if offset + 1 < seg_len {
                                    source = ExistingSource::Segment(segment_idx, offset + 1);
                                } else {
                                    let tail = self.arena.segments[segment_idx.get()].tail;
                                    source = if let Some(next_node) = tail.as_node() {
                                        ExistingSource::Node(next_node)
                                    } else if let Some(next_segment) = tail.as_segment() {
                                        ExistingSource::Segment(next_segment, 0)
                                    } else {
                                        ExistingSource::None
                                    };
                                }
                            } else {
                                // Path leaves the run here: the rest of the run is this
                                // level's off-path sibling, and no deeper path exists.
                                // The arena helper is handed the detach queue.
                                let continuation = self.arena.detach_segment_continuation(
                                    segment_idx,
                                    offset,
                                    &mut self.detaches,
                                );
                                self.levels[slot].sibling = continuation;
                                source = ExistingSource::None;
                            }
                        } else {
                            source = ExistingSource::None;
                        }
                    }
                }
            }
        }

        old_child
    }
2419
2420    fn rebuild_path_subtree(&mut self, history: &[Symbol]) -> ChildRef {
2421        let mut built = ChildRef::NONE;
2422
2423        for depth in (1..=self.max_depth).rev() {
2424            let level = self.levels[depth - 1];
2425            let visits = level.symbol_count[0] + level.symbol_count[1];
2426            if visits == 0 {
2427                built = ChildRef::NONE;
2428                continue;
2429            }
2430
2431            let path_edge = if depth < self.max_depth {
2432                history_symbol(history, depth) as usize
2433            } else {
2434                0
2435            };
2436            let has_path_child = built.is_some();
2437            let has_sibling = level.sibling.is_some();
2438            let force_node = depth <= self.hot_prefix_depth();
2439
2440            if force_node || (has_path_child && has_sibling) {
2441                let node = self
2442                    .arena
2443                    .alloc_node_with_state(level.symbol_count, level.log_prob_kt);
2444                if has_path_child {
2445                    self.arena.set_child(node, path_edge, built);
2446                }
2447                if has_sibling {
2448                    self.arena.set_child(node, path_edge ^ 1, level.sibling);
2449                }
2450                self.arena.recompute_node_weight(node);
2451                built = ChildRef::from_node(node);
2452            } else {
2453                let (edge, child) = if has_path_child {
2454                    (path_edge, built)
2455                } else if has_sibling {
2456                    (path_edge ^ 1, level.sibling)
2457                } else {
2458                    (path_edge, ChildRef::NONE)
2459                };
2460                built = self.arena.prepend_or_alloc_segment(
2461                    history,
2462                    depth,
2463                    level.symbol_count,
2464                    level.log_prob_kt,
2465                    child,
2466                    edge,
2467                    false,
2468                );
2469            }
2470        }
2471
2472        built
2473    }
2474
2475    fn apply_detaches(&mut self) {
2476        for detach in self.detaches.drain(..) {
2477            match detach {
2478                Detach::NodeChild { node, edge } => {
2479                    self.arena.set_child(node, edge, ChildRef::NONE);
2480                }
2481                Detach::SegmentNext { segment, new_len } => {
2482                    self.arena.segments[segment.get()].set_len(new_len);
2483                    self.arena.set_segment_tail(segment, ChildRef::NONE);
2484                }
2485            }
2486        }
2487    }
2488
2489    fn update_with_logs(
2490        &mut self,
2491        log_int: &[f64],
2492        log_half: &[f64],
2493        sym: Symbol,
2494        history: &[Symbol],
2495    ) {
2496        let sym_idx = sym as usize;
2497        let singleton_log_prob_kt = log_half[0] - log_int[1];
2498        {
2499            let slot = self.root.get();
2500            let mut counts = self.arena.nodes[slot].symbol_count;
2501            let mut log_prob_kt = self.arena.nodes[slot].log_prob_kt;
2502            apply_update_to_state_raw(log_int, log_half, &mut counts, &mut log_prob_kt, sym_idx);
2503            self.arena.nodes[slot].symbol_count = counts;
2504            self.arena.nodes[slot].log_prob_kt = log_prob_kt;
2505        }
2506
2507        if self.max_depth > 0 {
2508            let root_edge = history_symbol(history, 0) as usize;
2509            let old_child = self.arena.child(self.root, root_edge);
2510            let new_child = self.update_root_child(
2511                log_int,
2512                log_half,
2513                old_child,
2514                history,
2515                sym_idx,
2516                singleton_log_prob_kt,
2517            );
2518            self.arena.set_child(self.root, root_edge, new_child);
2519        }
2520
2521        self.arena.recompute_node_weight(self.root);
2522    }
2523
2524    fn update(&mut self, sym: Symbol, history: &[Symbol]) {
2525        let upto = self.root_visits() + 1;
2526        self.with_logs(upto, |this, log_int, log_half| {
2527            this.update_with_logs(log_int, log_half, sym, history);
2528        });
2529    }
2530
2531    fn update_prepared(&mut self, sym: Symbol, history: &[Symbol], use_prepared: bool) {
2532        let upto = self.root_visits() + 1;
2533        let sym_idx = sym as usize;
2534        self.with_logs(upto, |this, log_int, log_half| {
2535            let singleton_log_prob_kt = log_half[0] - log_int[1];
2536            {
2537                let slot = this.root.get();
2538                let mut counts = this.arena.nodes[slot].symbol_count;
2539                let mut log_prob_kt = this.arena.nodes[slot].log_prob_kt;
2540                apply_update_to_state_raw(
2541                    log_int,
2542                    log_half,
2543                    &mut counts,
2544                    &mut log_prob_kt,
2545                    sym_idx,
2546                );
2547                this.arena.nodes[slot].symbol_count = counts;
2548                this.arena.nodes[slot].log_prob_kt = log_prob_kt;
2549            }
2550
2551            if this.max_depth > 0 {
2552                let root_edge = history_symbol(history, 0) as usize;
2553                let old_child = this.arena.child(this.root, root_edge);
2554                let new_child = if use_prepared {
2555                    match this.prepared_end {
2556                        PreparedEnd::MissingAtRoot => {
2557                            this.build_missing_path(1, history, sym_idx, singleton_log_prob_kt)
2558                        }
2559                        PreparedEnd::MaxDepth | PreparedEnd::MissingAfterCurrent => {
2560                            if !this.prepared_steps.is_empty() {
2561                                this.update_prepared_cached_path(
2562                                    log_int,
2563                                    log_half,
2564                                    history,
2565                                    sym_idx,
2566                                    singleton_log_prob_kt,
2567                                );
2568                            }
2569                            old_child
2570                        }
2571                        PreparedEnd::MismatchAtCurrentSegment => this.update_prepared_mismatch(
2572                            log_int,
2573                            log_half,
2574                            history,
2575                            sym_idx,
2576                            singleton_log_prob_kt,
2577                        ),
2578                    }
2579                } else {
2580                    this.update_root_child(
2581                        log_int,
2582                        log_half,
2583                        old_child,
2584                        history,
2585                        sym_idx,
2586                        singleton_log_prob_kt,
2587                    )
2588                };
2589                this.arena.set_child(this.root, root_edge, new_child);
2590            }
2591
2592            this.arena.recompute_node_weight(this.root);
2593        });
2594    }
2595
    /// Undo a previous update of `sym` that was made in the context `history`.
    ///
    /// Callers pass the context that update saw, i.e. without `sym` itself
    /// (`ContextTree::revert` pops `sym` from its history before calling).
    /// Reverts the root's KT state and the KT state of every collected path
    /// level, then (for `max_depth > 0`) swaps the root's on-path child for a
    /// subtree rebuilt from the reverted levels, and recomputes the root weight.
    fn revert(&mut self, sym: Symbol, history: &[Symbol]) {
        // Unlike `update_prepared` (which requests `root_visits() + 1`), log
        // tables are only requested up to the current root visit count here.
        let upto = self.root_visits();
        let sym_idx = sym as usize;
        self.with_logs(upto, |this, log_int, log_half| {
            // Snapshot the existing path before any state changes. The returned
            // child ref is the root's current on-path child, freed below.
            // NOTE(review): presumably this also fills `this.levels` with the
            // per-level state iterated below; confirm in `collect_existing_levels`.
            let old_child = this.collect_existing_levels(history);

            // Root: copy its KT state out, revert `sym`, write it back.
            {
                let slot = this.root.get();
                let mut counts = this.arena.nodes[slot].symbol_count;
                let mut log_prob_kt = this.arena.nodes[slot].log_prob_kt;
                apply_revert_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                this.arena.nodes[slot].symbol_count = counts;
                this.arena.nodes[slot].log_prob_kt = log_prob_kt;
            }

            // Revert the same symbol in every collected level's copied state.
            for level in &mut this.levels {
                let mut counts = level.symbol_count;
                let mut log_prob_kt = level.log_prob_kt;
                apply_revert_to_state_raw(
                    log_int,
                    log_half,
                    &mut counts,
                    &mut log_prob_kt,
                    sym_idx,
                );
                level.symbol_count = counts;
                level.log_prob_kt = log_prob_kt;
            }

            if this.max_depth > 0 {
                // Ordering is deliberate: build the replacement from the reverted
                // levels first, apply any recorded detaches, and only then free
                // the old subtree and link the new one under the root.
                let new_child = this.rebuild_path_subtree(history);
                let root_edge = history_symbol(history, 0) as usize;
                this.apply_detaches();
                this.arena.free_child_ref(old_child);
                this.arena.set_child(this.root, root_edge, new_child);
            }

            this.arena.recompute_node_weight(this.root);
        });
    }
2642
2643    fn predict(&mut self, sym: Symbol, history: &[Symbol]) -> f64 {
2644        self.prepared_steps.clear();
2645        self.prepared_levels = 0;
2646        self.prepared_end = PreparedEnd::MaxDepth;
2647
2648        let (root_sibling, root_has_sibling, mut source) = if self.max_depth > 0 {
2649            let root_edge = history_symbol(history, 0) as usize;
2650            let path_child = self.arena.child(self.root, root_edge);
2651            let sibling = self.arena.child(self.root, root_edge ^ 1);
2652            (
2653                self.arena.child_ref_weighted(sibling),
2654                sibling.is_some() as u8,
2655                Self::child_to_existing_source(path_child).unwrap_or(ExistingSource::None),
2656            )
2657        } else {
2658            (0.0, 0, ExistingSource::None)
2659        };
2660        if self.max_depth > 0 && matches!(source, ExistingSource::None) {
2661            self.prepared_end = PreparedEnd::MissingAtRoot;
2662        }
2663
2664        let history_len = history.len();
2665        let mut depth = 1usize;
2666        'walk: while depth <= self.max_depth {
2667            match source {
2668                ExistingSource::None => break,
2669                ExistingSource::Node(node_idx) => {
2670                    let slot = node_idx.get();
2671                    let counts = self.arena.nodes[slot].symbol_count;
2672                    let kt_log_prob = self.arena.nodes[slot].log_prob_kt;
2673                    if depth == self.max_depth {
2674                        self.prepared_steps.push(PreparedStep {
2675                            source: ExistingSource::Node(node_idx),
2676                            counts,
2677                            kt_log_prob,
2678                            span: 1,
2679                            sibling_weight: 0.0,
2680                            has_sibling: 0,
2681                        });
2682                        self.prepared_levels += 1;
2683                        break;
2684                    }
2685                    let path_edge = history_symbol(history, depth) as usize;
2686                    let sibling = self.arena.child(node_idx, path_edge ^ 1);
2687                    self.prepared_steps.push(PreparedStep {
2688                        source: ExistingSource::Node(node_idx),
2689                        counts,
2690                        kt_log_prob,
2691                        span: 1,
2692                        sibling_weight: self.arena.child_ref_weighted(sibling),
2693                        has_sibling: sibling.is_some() as u8,
2694                    });
2695                    self.prepared_levels += 1;
2696                    let next = self.arena.child(node_idx, path_edge);
2697                    source = Self::child_to_existing_source(next).unwrap_or(ExistingSource::None);
2698                    if matches!(source, ExistingSource::None) {
2699                        self.prepared_end = PreparedEnd::MissingAfterCurrent;
2700                        break;
2701                    }
2702                    depth += 1;
2703                }
2704                ExistingSource::Segment(segment_idx, _) => {
2705                    let segment = self.arena.segments[segment_idx.get()];
2706                    let seg_len = segment.len() as usize;
2707                    let counts = segment.symbol_count;
2708                    let kt_log_prob = segment.log_prob_kt;
2709                    for offset in 0..seg_len {
2710                        let node_depth = depth + offset;
2711                        let span = (offset + 1) as u32;
2712
2713                        if node_depth == self.max_depth {
2714                            self.prepared_steps.push(PreparedStep {
2715                                source: ExistingSource::Segment(segment_idx, offset as u32),
2716                                counts,
2717                                kt_log_prob,
2718                                span,
2719                                sibling_weight: 0.0,
2720                                has_sibling: 0,
2721                            });
2722                            self.prepared_levels += span as usize;
2723                            break 'walk;
2724                        }
2725
2726                        if offset + 1 >= seg_len && segment.tail.is_none() {
2727                            self.prepared_steps.push(PreparedStep {
2728                                source: ExistingSource::Segment(segment_idx, offset as u32),
2729                                counts,
2730                                kt_log_prob,
2731                                span,
2732                                sibling_weight: 0.0,
2733                                has_sibling: 0,
2734                            });
2735                            self.prepared_levels += span as usize;
2736                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
2737                            break 'walk;
2738                        }
2739
2740                        let path_edge = path_edge_at_depth(history, history_len, node_depth);
2741                        let existing_edge =
2742                            segment_edge_from_parts(segment, offset, history, history_len);
2743                        if path_edge != existing_edge {
2744                            self.prepared_steps.push(PreparedStep {
2745                                source: ExistingSource::Segment(segment_idx, offset as u32),
2746                                counts,
2747                                kt_log_prob,
2748                                span,
2749                                sibling_weight: self
2750                                    .arena
2751                                    .segment_continuation_weight(segment_idx, offset as u32),
2752                                has_sibling: 1,
2753                            });
2754                            self.prepared_levels += span as usize;
2755                            self.prepared_end = PreparedEnd::MismatchAtCurrentSegment;
2756                            break 'walk;
2757                        }
2758
2759                        if offset + 1 < seg_len {
2760                            continue;
2761                        }
2762
2763                        self.prepared_steps.push(PreparedStep {
2764                            source: ExistingSource::Segment(segment_idx, offset as u32),
2765                            counts,
2766                            kt_log_prob,
2767                            span,
2768                            sibling_weight: 0.0,
2769                            has_sibling: 0,
2770                        });
2771                        self.prepared_levels += span as usize;
2772                        let tail = segment.tail;
2773                        source =
2774                            Self::child_to_existing_source(tail).unwrap_or(ExistingSource::None);
2775                        if matches!(source, ExistingSource::None) {
2776                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
2777                            break 'walk;
2778                        }
2779                        depth = node_depth + 1;
2780                        continue 'walk;
2781                    }
2782                }
2783            }
2784        }
2785
2786        let sym_idx = sym as usize;
2787        if self.prepared_levels == 0 {
2788            let counts = self.arena.counts(self.root);
2789            let kt_log_prob = self.arena.log_prob_kt(self.root);
2790            return if self.prepared_end == PreparedEnd::MaxDepth || root_has_sibling == 0 {
2791                predict_ratio_kt(counts, sym_idx)
2792            } else {
2793                predict_ratio_internal(kt_log_prob, counts, 0.0, root_sibling, 0.5, sym_idx)
2794            };
2795        }
2796
2797        let last_step = *self.prepared_steps.last().unwrap();
2798        let last_counts = last_step.counts;
2799        let last_kt_log_prob = last_step.kt_log_prob;
2800        let (mut child_weight, mut ratio) = if self.prepared_end == PreparedEnd::MaxDepth
2801            && self.prepared_levels == self.max_depth
2802        {
2803            (last_kt_log_prob, predict_ratio_kt(last_counts, sym_idx))
2804        } else if last_step.has_sibling == 0 {
2805            (last_kt_log_prob, predict_ratio_kt(last_counts, sym_idx))
2806        } else {
2807            combined_weight_ratio_internal(
2808                last_kt_log_prob,
2809                last_counts,
2810                0.0,
2811                last_step.sibling_weight,
2812                0.5,
2813                sym_idx,
2814            )
2815        };
2816
2817        if let ExistingSource::Segment(_, _) = last_step.source {
2818            if last_step.span > 1 {
2819                let (alpha, log_alpha, log_one_minus_alpha) =
2820                    self.segment_constants(last_step.span - 1);
2821                (child_weight, ratio) = unary_chain_ratio_transform_precomputed(
2822                    last_step.kt_log_prob,
2823                    last_step.counts,
2824                    child_weight,
2825                    ratio,
2826                    alpha,
2827                    log_alpha,
2828                    log_one_minus_alpha,
2829                    sym_idx,
2830                );
2831            }
2832        }
2833
2834        for idx in (0..self.prepared_steps.len() - 1).rev() {
2835            let step = self.prepared_steps[idx];
2836            match step.source {
2837                ExistingSource::Node(_) => {
2838                    (child_weight, ratio) = combined_weight_ratio_internal(
2839                        step.kt_log_prob,
2840                        step.counts,
2841                        child_weight,
2842                        step.sibling_weight,
2843                        ratio,
2844                        sym_idx,
2845                    );
2846                }
2847                ExistingSource::Segment(_, _) => {
2848                    let (alpha, log_alpha, log_one_minus_alpha) = self.segment_constants(step.span);
2849                    (child_weight, ratio) = unary_chain_ratio_transform_precomputed(
2850                        step.kt_log_prob,
2851                        step.counts,
2852                        child_weight,
2853                        ratio,
2854                        alpha,
2855                        log_alpha,
2856                        log_one_minus_alpha,
2857                        sym_idx,
2858                    );
2859                }
2860                ExistingSource::None => unreachable!("prepared step should never store None"),
2861            }
2862        }
2863
2864        let root_counts = self.arena.counts(self.root);
2865        let root_kt_log_prob = self.arena.log_prob_kt(self.root);
2866        predict_ratio_internal(
2867            root_kt_log_prob,
2868            root_counts,
2869            child_weight,
2870            root_sibling,
2871            ratio,
2872            sym_idx,
2873        )
2874    }
2875
2876    fn predict_one(&mut self, history: &[Symbol]) -> f64 {
2877        self.prepared_steps.clear();
2878        self.prepared_levels = 0;
2879        self.prepared_end = PreparedEnd::MaxDepth;
2880
2881        let (root_sibling, root_has_sibling, mut source) = if self.max_depth > 0 {
2882            let root_edge = history_symbol(history, 0) as usize;
2883            let path_child = self.arena.child(self.root, root_edge);
2884            let sibling = self.arena.child(self.root, root_edge ^ 1);
2885            (
2886                self.arena.child_ref_weighted(sibling),
2887                sibling.is_some() as u8,
2888                Self::child_to_existing_source(path_child).unwrap_or(ExistingSource::None),
2889            )
2890        } else {
2891            (0.0, 0, ExistingSource::None)
2892        };
2893        if self.max_depth > 0 && matches!(source, ExistingSource::None) {
2894            self.prepared_end = PreparedEnd::MissingAtRoot;
2895        }
2896
2897        let history_len = history.len();
2898        let mut depth = 1usize;
2899        'walk: while depth <= self.max_depth {
2900            match source {
2901                ExistingSource::None => break,
2902                ExistingSource::Node(node_idx) => {
2903                    let slot = node_idx.get();
2904                    let counts = self.arena.nodes[slot].symbol_count;
2905                    let kt_log_prob = self.arena.nodes[slot].log_prob_kt;
2906                    if depth == self.max_depth {
2907                        self.prepared_steps.push(PreparedStep {
2908                            source: ExistingSource::Node(node_idx),
2909                            counts,
2910                            kt_log_prob,
2911                            span: 1,
2912                            sibling_weight: 0.0,
2913                            has_sibling: 0,
2914                        });
2915                        self.prepared_levels += 1;
2916                        break;
2917                    }
2918                    let path_edge = history_symbol(history, depth) as usize;
2919                    let sibling = self.arena.child(node_idx, path_edge ^ 1);
2920                    self.prepared_steps.push(PreparedStep {
2921                        source: ExistingSource::Node(node_idx),
2922                        counts,
2923                        kt_log_prob,
2924                        span: 1,
2925                        sibling_weight: self.arena.child_ref_weighted(sibling),
2926                        has_sibling: sibling.is_some() as u8,
2927                    });
2928                    self.prepared_levels += 1;
2929                    let next = self.arena.child(node_idx, path_edge);
2930                    source = Self::child_to_existing_source(next).unwrap_or(ExistingSource::None);
2931                    if matches!(source, ExistingSource::None) {
2932                        self.prepared_end = PreparedEnd::MissingAfterCurrent;
2933                        break;
2934                    }
2935                    depth += 1;
2936                }
2937                ExistingSource::Segment(segment_idx, _) => {
2938                    let segment = self.arena.segments[segment_idx.get()];
2939                    let seg_len = segment.len() as usize;
2940                    let counts = segment.symbol_count;
2941                    let kt_log_prob = segment.log_prob_kt;
2942                    for offset in 0..seg_len {
2943                        let node_depth = depth + offset;
2944                        let span = (offset + 1) as u32;
2945
2946                        if node_depth == self.max_depth {
2947                            self.prepared_steps.push(PreparedStep {
2948                                source: ExistingSource::Segment(segment_idx, offset as u32),
2949                                counts,
2950                                kt_log_prob,
2951                                span,
2952                                sibling_weight: 0.0,
2953                                has_sibling: 0,
2954                            });
2955                            self.prepared_levels += span as usize;
2956                            break 'walk;
2957                        }
2958
2959                        if offset + 1 >= seg_len && segment.tail.is_none() {
2960                            self.prepared_steps.push(PreparedStep {
2961                                source: ExistingSource::Segment(segment_idx, offset as u32),
2962                                counts,
2963                                kt_log_prob,
2964                                span,
2965                                sibling_weight: 0.0,
2966                                has_sibling: 0,
2967                            });
2968                            self.prepared_levels += span as usize;
2969                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
2970                            break 'walk;
2971                        }
2972
2973                        let path_edge = path_edge_at_depth(history, history_len, node_depth);
2974                        let existing_edge =
2975                            segment_edge_from_parts(segment, offset, history, history_len);
2976                        if path_edge != existing_edge {
2977                            self.prepared_steps.push(PreparedStep {
2978                                source: ExistingSource::Segment(segment_idx, offset as u32),
2979                                counts,
2980                                kt_log_prob,
2981                                span,
2982                                sibling_weight: self
2983                                    .arena
2984                                    .segment_continuation_weight(segment_idx, offset as u32),
2985                                has_sibling: 1,
2986                            });
2987                            self.prepared_levels += span as usize;
2988                            self.prepared_end = PreparedEnd::MismatchAtCurrentSegment;
2989                            break 'walk;
2990                        }
2991
2992                        if offset + 1 < seg_len {
2993                            continue;
2994                        }
2995
2996                        self.prepared_steps.push(PreparedStep {
2997                            source: ExistingSource::Segment(segment_idx, offset as u32),
2998                            counts,
2999                            kt_log_prob,
3000                            span,
3001                            sibling_weight: 0.0,
3002                            has_sibling: 0,
3003                        });
3004                        self.prepared_levels += span as usize;
3005                        let tail = segment.tail;
3006                        source =
3007                            Self::child_to_existing_source(tail).unwrap_or(ExistingSource::None);
3008                        if matches!(source, ExistingSource::None) {
3009                            self.prepared_end = PreparedEnd::MissingAfterCurrent;
3010                            break 'walk;
3011                        }
3012                        depth = node_depth + 1;
3013                        continue 'walk;
3014                    }
3015                }
3016            }
3017        }
3018
3019        if self.prepared_levels == 0 {
3020            let counts = self.arena.counts(self.root);
3021            let kt_log_prob = self.arena.log_prob_kt(self.root);
3022            return if self.prepared_end == PreparedEnd::MaxDepth || root_has_sibling == 0 {
3023                predict_ratio_kt_one(counts)
3024            } else {
3025                predict_ratio_internal_one(kt_log_prob, counts, 0.0, root_sibling, 0.5)
3026            };
3027        }
3028
3029        let last_step = *self.prepared_steps.last().unwrap();
3030        let last_counts = last_step.counts;
3031        let last_kt_log_prob = last_step.kt_log_prob;
3032        let (mut child_weight, mut ratio) = if self.prepared_end == PreparedEnd::MaxDepth
3033            && self.prepared_levels == self.max_depth
3034        {
3035            (last_kt_log_prob, predict_ratio_kt_one(last_counts))
3036        } else if last_step.has_sibling == 0 {
3037            (last_kt_log_prob, predict_ratio_kt_one(last_counts))
3038        } else {
3039            combined_weight_ratio_internal_one(
3040                last_kt_log_prob,
3041                last_counts,
3042                0.0,
3043                last_step.sibling_weight,
3044                0.5,
3045            )
3046        };
3047
3048        if let ExistingSource::Segment(_, _) = last_step.source
3049            && last_step.span > 1
3050        {
3051            let (alpha, log_alpha, log_one_minus_alpha) =
3052                self.segment_constants(last_step.span - 1);
3053            (child_weight, ratio) = unary_chain_ratio_transform_precomputed_one(
3054                last_step.kt_log_prob,
3055                last_step.counts,
3056                child_weight,
3057                ratio,
3058                alpha,
3059                log_alpha,
3060                log_one_minus_alpha,
3061            );
3062        }
3063
3064        for idx in (0..self.prepared_steps.len() - 1).rev() {
3065            let step = self.prepared_steps[idx];
3066            match step.source {
3067                ExistingSource::Node(_) => {
3068                    (child_weight, ratio) = combined_weight_ratio_internal_one(
3069                        step.kt_log_prob,
3070                        step.counts,
3071                        child_weight,
3072                        step.sibling_weight,
3073                        ratio,
3074                    );
3075                }
3076                ExistingSource::Segment(_, _) => {
3077                    let (alpha, log_alpha, log_one_minus_alpha) = self.segment_constants(step.span);
3078                    (child_weight, ratio) = unary_chain_ratio_transform_precomputed_one(
3079                        step.kt_log_prob,
3080                        step.counts,
3081                        child_weight,
3082                        ratio,
3083                        alpha,
3084                        log_alpha,
3085                        log_one_minus_alpha,
3086                    );
3087                }
3088                ExistingSource::None => unreachable!("prepared step should never store None"),
3089            }
3090        }
3091
3092        let root_counts = self.arena.counts(self.root);
3093        let root_kt_log_prob = self.arena.log_prob_kt(self.root);
3094        predict_ratio_internal_one(
3095            root_kt_log_prob,
3096            root_counts,
3097            child_weight,
3098            root_sibling,
3099            ratio,
3100        )
3101    }
3102
3103    fn memory_usage(&self) -> usize {
3104        self.arena.memory_usage()
3105            + self.segment_alpha.capacity() * size_of::<f64>()
3106            + self.segment_log_alpha.capacity() * size_of::<f64>()
3107            + self.segment_log_one_minus_alpha.capacity() * size_of::<f64>()
3108            + self.levels.capacity() * size_of::<LevelState>()
3109            + self.detaches.capacity() * size_of::<Detach>()
3110            + self.prepared_steps.capacity() * size_of::<PreparedStep>()
3111    }
3112}
3113
3114/// A Context Tree for binary sequence prediction.
3115#[derive(Clone)]
3116pub struct ContextTree {
3117    engine: CtEngine,
3118    history: Vec<Symbol>,
3119}
3120
3121impl ContextTree {
3122    /// Construct a binary CTW predictor with maximum context depth `depth`.
3123    pub fn new(depth: usize) -> Self {
3124        Self {
3125            engine: CtEngine::new(depth),
3126            history: Vec::new(),
3127        }
3128    }
3129
3130    /// Reset tree parameters and clear conditioning history.
3131    pub fn clear(&mut self) {
3132        self.history.clear();
3133        self.engine.clear();
3134    }
3135
3136    #[inline]
3137    /// Observe one binary symbol and update the model.
3138    pub fn update(&mut self, sym: Symbol) {
3139        self.engine.update(sym, &self.history);
3140        self.history.push(sym);
3141    }
3142
3143    #[inline]
3144    /// Revert the last symbol update if history is non-empty.
3145    pub fn revert(&mut self) {
3146        let Some(last_sym) = self.history.pop() else {
3147            return;
3148        };
3149        self.engine.revert(last_sym, &self.history);
3150    }
3151
3152    #[inline]
3153    /// Append external symbols to history without touching model state.
3154    pub fn update_history(&mut self, symbols: &[Symbol]) {
3155        self.history.extend_from_slice(symbols);
3156    }
3157
3158    #[inline]
3159    /// Remove one history symbol without reverting model statistics.
3160    pub fn revert_history(&mut self) {
3161        self.history.pop();
3162    }
3163
3164    /// Truncate the stored history to `new_size` symbols.
3165    pub fn truncate_history(&mut self, new_size: usize) {
3166        if new_size < self.history.len() {
3167            self.history.truncate(new_size);
3168        }
3169    }
3170
3171    #[inline]
3172    /// Predict `P(sym | history)` under current weighted CTW model.
3173    pub fn predict(&mut self, sym: Symbol) -> f64 {
3174        self.engine.predict(sym, &self.history)
3175    }
3176
3177    #[inline]
3178    /// Predict probability of symbol `true`.
3179    pub fn predict_sym_prob(&mut self) -> f64 {
3180        self.predict(true)
3181    }
3182
3183    #[inline]
3184    /// Return log block probability accumulated by the root model.
3185    pub fn get_log_block_probability(&self) -> f64 {
3186        self.engine.get_log_block_probability()
3187    }
3188
3189    #[inline]
3190    /// Maximum context depth configured for this tree.
3191    pub fn depth(&self) -> usize {
3192        self.engine.max_depth
3193    }
3194
3195    #[inline]
3196    /// Current number of stored history symbols.
3197    pub fn history_size(&self) -> usize {
3198        self.history.len()
3199    }
3200}
3201
3202#[derive(Clone)]
3203struct ContextTreeCore {
3204    engine: CtEngine,
3205    prepared_valid: bool,
3206    prepared_history_len: usize,
3207    prepared_history_version: u64,
3208}
3209
3210impl ContextTreeCore {
3211    fn new(depth: usize) -> Self {
3212        Self {
3213            engine: CtEngine::new(depth),
3214            prepared_valid: false,
3215            prepared_history_len: 0,
3216            prepared_history_version: 0,
3217        }
3218    }
3219
3220    fn clear(&mut self) {
3221        self.engine.clear();
3222        self.prepared_valid = false;
3223        self.prepared_history_len = 0;
3224        self.prepared_history_version = 0;
3225    }
3226
3227    #[inline]
3228    fn reserve_for_symbols(&mut self, total_symbols: usize) {
3229        self.engine.reserve_for_symbols(total_symbols);
3230    }
3231
3232    #[inline]
3233    fn update(&mut self, sym: Symbol, shared_history: &[Symbol]) {
3234        self.prepared_valid = false;
3235        self.engine.update(sym, shared_history);
3236    }
3237
3238    #[inline]
3239    fn update_predicted(&mut self, sym: Symbol, shared_history: &[Symbol], history_version: u64) {
3240        let use_prepared = self.prepared_valid
3241            && self.prepared_history_len == shared_history.len()
3242            && self.prepared_history_version == history_version;
3243        self.prepared_valid = false;
3244        self.engine
3245            .update_prepared(sym, shared_history, use_prepared);
3246    }
3247
3248    #[inline]
3249    fn revert(&mut self, last_sym: Symbol, shared_history: &[Symbol]) {
3250        self.prepared_valid = false;
3251        self.engine.revert(last_sym, shared_history);
3252    }
3253
3254    #[inline]
3255    fn predict(&mut self, sym: Symbol, shared_history: &[Symbol], history_version: u64) -> f64 {
3256        let prob = self.engine.predict(sym, shared_history);
3257        self.prepared_valid = true;
3258        self.prepared_history_len = shared_history.len();
3259        self.prepared_history_version = history_version;
3260        prob
3261    }
3262
3263    #[inline]
3264    fn predict_one(&mut self, shared_history: &[Symbol], history_version: u64) -> f64 {
3265        let prob = self.engine.predict_one(shared_history);
3266        self.prepared_valid = true;
3267        self.prepared_history_len = shared_history.len();
3268        self.prepared_history_version = history_version;
3269        prob
3270    }
3271
3272    #[inline]
3273    fn get_log_block_probability(&self) -> f64 {
3274        self.engine.get_log_block_probability()
3275    }
3276}
3277
3278/// Factorized Action-Conditional Context Tree Weighting.
3279#[derive(Clone)]
3280pub struct FacContextTree {
3281    trees: Vec<ContextTreeCore>,
3282    shared_history: Vec<Symbol>,
3283    base_depth: usize,
3284    num_bits: usize,
3285    shared_history_version: u64,
3286}
3287
3288impl FacContextTree {
3289    /// Create a factorized CTW stack over `num_percept_bits` bit positions.
3290    ///
3291    /// Tree `i` uses depth `base_depth + i`, matching FAC-CTW's increasing context.
3292    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
3293        let trees = (0..num_percept_bits)
3294            .map(|i| ContextTreeCore::new(base_depth + i))
3295            .collect();
3296        Self {
3297            trees,
3298            shared_history: Vec::new(),
3299            base_depth,
3300            num_bits: num_percept_bits,
3301            shared_history_version: 0,
3302        }
3303    }
3304
3305    #[inline(always)]
3306    fn bump_shared_history_version(&mut self) {
3307        self.shared_history_version = self.shared_history_version.wrapping_add(1);
3308    }
3309
3310    #[inline]
3311    /// Reserve history/tree capacity for approximately `total_symbols` updates.
3312    pub fn reserve_for_symbols(&mut self, total_symbols: usize) {
3313        if total_symbols == 0 {
3314            return;
3315        }
3316        self.shared_history
3317            .reserve_exact(total_symbols.saturating_mul(self.num_bits));
3318        for tree in &mut self.trees {
3319            tree.reserve_for_symbols(total_symbols);
3320        }
3321    }
3322
3323    #[inline]
3324    /// Number of bit positions currently modeled per symbol.
3325    pub fn num_bits(&self) -> usize {
3326        self.num_bits
3327    }
3328
3329    #[inline]
3330    /// Base depth used to construct the first factorized tree.
3331    pub fn base_depth(&self) -> usize {
3332        self.base_depth
3333    }
3334
3335    #[inline]
3336    /// Update one bit position with a binary symbol.
3337    pub fn update(&mut self, sym: Symbol, bit_index: usize) {
3338        debug_assert!(bit_index < self.num_bits);
3339        self.trees[bit_index].update(sym, &self.shared_history);
3340        self.shared_history.push(sym);
3341        self.bump_shared_history_version();
3342    }
3343
3344    #[inline]
3345    /// Update all bit positions from one byte, most-significant bit first.
3346    pub fn update_byte_msb(&mut self, byte: u8) {
3347        if self.num_bits != 8 {
3348            for bit_idx in 0..self.num_bits {
3349                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
3350                self.update(bit, bit_idx);
3351            }
3352            return;
3353        }
3354
3355        let upto = self.trees[0].engine.root_visits() + 1;
3356        debug_assert!(
3357            self.trees
3358                .iter()
3359                .all(|tree| tree.engine.root_visits() + 1 == upto)
3360        );
3361        with_shared_log_cache(upto, |log_int, log_half| {
3362            for bit_idx in 0..8usize {
3363                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
3364                let tree = &mut self.trees[bit_idx];
3365                tree.prepared_valid = false;
3366                tree.engine
3367                    .update_with_logs(log_int, log_half, bit, &self.shared_history);
3368                self.shared_history.push(bit);
3369            }
3370        });
3371        self.bump_shared_history_version();
3372    }
3373
3374    #[inline]
3375    /// Update all active bit positions from one byte, least-significant bit first.
3376    pub fn update_byte_lsb(&mut self, byte: u8) {
3377        let bits = self.num_bits.clamp(1, 8);
3378        let upto = self.trees[0].engine.root_visits() + 1;
3379        debug_assert!(
3380            self.trees
3381                .iter()
3382                .take(bits)
3383                .all(|tree| tree.engine.root_visits() + 1 == upto)
3384        );
3385        with_shared_log_cache(upto, |log_int, log_half| {
3386            for bit_idx in 0..bits {
3387                let bit = ((byte >> bit_idx) & 1) == 1;
3388                let tree = &mut self.trees[bit_idx];
3389                tree.prepared_valid = false;
3390                tree.engine
3391                    .update_with_logs(log_int, log_half, bit, &self.shared_history);
3392                self.shared_history.push(bit);
3393            }
3394        });
3395        self.bump_shared_history_version();
3396    }
3397
3398    #[inline]
3399    /// Commit an update using the fast-path prepared by a prior prediction call.
3400    pub fn update_predicted(&mut self, sym: Symbol, bit_index: usize) {
3401        debug_assert!(bit_index < self.num_bits);
3402        self.trees[bit_index].update_predicted(
3403            sym,
3404            &self.shared_history,
3405            self.shared_history_version,
3406        );
3407        self.shared_history.push(sym);
3408        self.bump_shared_history_version();
3409    }
3410
3411    #[inline]
3412    /// Predict a single bit probability for `bit_index`.
3413    pub fn predict(&mut self, sym: Symbol, bit_index: usize) -> f64 {
3414        debug_assert!(bit_index < self.num_bits);
3415        self.trees[bit_index].predict(sym, &self.shared_history, self.shared_history_version)
3416    }
3417
3418    #[inline]
3419    pub(crate) fn predict_one(&mut self, bit_index: usize) -> f64 {
3420        debug_assert!(bit_index < self.num_bits);
3421        self.trees[bit_index].predict_one(&self.shared_history, self.shared_history_version)
3422    }
3423
3424    #[inline]
3425    /// Revert the most recent bit update for `bit_index`.
3426    pub fn revert(&mut self, bit_index: usize) {
3427        debug_assert!(bit_index < self.num_bits);
3428        let Some(last_sym) = self.shared_history.pop() else {
3429            return;
3430        };
3431        self.trees[bit_index].revert(last_sym, &self.shared_history);
3432        self.bump_shared_history_version();
3433    }
3434
3435    #[inline]
3436    /// Append raw shared-history symbols without model updates.
3437    pub fn update_history(&mut self, symbols: &[Symbol]) {
3438        if symbols.is_empty() {
3439            return;
3440        }
3441        self.shared_history.extend_from_slice(symbols);
3442        self.bump_shared_history_version();
3443    }
3444
3445    #[inline]
3446    /// Drop the last `count` symbols from shared history.
3447    pub fn revert_history(&mut self, count: usize) {
3448        let old_len = self.shared_history.len();
3449        let new_len = self.shared_history.len().saturating_sub(count);
3450        if new_len == old_len {
3451            return;
3452        }
3453        self.shared_history.truncate(new_len);
3454        self.bump_shared_history_version();
3455    }
3456
3457    #[inline]
3458    /// Clear shared history while preserving learned tree parameters.
3459    pub fn reset_history_only(&mut self) {
3460        if self.shared_history.is_empty() {
3461            return;
3462        }
3463        self.shared_history.clear();
3464        self.bump_shared_history_version();
3465    }
3466
3467    #[inline]
3468    /// Sum of per-tree log block probabilities.
3469    pub fn get_log_block_probability(&self) -> f64 {
3470        self.trees
3471            .iter()
3472            .map(|t| t.get_log_block_probability())
3473            .sum()
3474    }
3475
3476    /// Clear all trees and shared history.
3477    pub fn clear(&mut self) {
3478        for tree in &mut self.trees {
3479            tree.clear();
3480        }
3481        self.shared_history.clear();
3482        self.shared_history_version = 0;
3483    }
3484
3485    /// Approximate heap memory usage in bytes.
3486    pub fn memory_usage(&self) -> usize {
3487        let tree_mem: usize = self.trees.iter().map(|t| t.engine.memory_usage()).sum();
3488        let log_cache_mem = self
3489            .trees
3490            .first()
3491            .map(|t| t.engine.log_cache_memory_usage())
3492            .unwrap_or(0);
3493        let history_mem = self.shared_history.capacity() * size_of::<Symbol>();
3494        tree_mem + log_cache_mem + history_mem
3495    }
3496}
3497
3498#[cfg(test)]
3499mod tests {
3500    use super::*;
3501
3502    #[derive(Clone)]
3503    struct RefNode {
3504        children: [Option<Box<RefNode>>; 2],
3505        log_prob_kt: f64,
3506        log_prob_weighted: f64,
3507        symbol_count: [u32; 2],
3508    }
3509
3510    impl Default for RefNode {
3511        fn default() -> Self {
3512            Self {
3513                children: [None, None],
3514                log_prob_kt: 0.0,
3515                log_prob_weighted: 0.0,
3516                symbol_count: [0, 0],
3517            }
3518        }
3519    }
3520
3521    #[derive(Clone)]
3522    struct RefContextTree {
3523        root: RefNode,
3524        history: Vec<Symbol>,
3525        max_depth: usize,
3526        log_int: Vec<f64>,
3527        log_half: Vec<f64>,
3528    }
3529
    /// Naive reference CTW used as a test oracle for `ContextTree`.
    ///
    /// Every node on an updated context path is stored as an explicit boxed `RefNode`
    /// (no run compression). Weighted log-probabilities are recomputed bottom-up along
    /// the path on each update and revert.
    impl RefContextTree {
        /// Empty tree over contexts of length `depth`, with log tables seeded for `n = 0`.
        fn new(depth: usize) -> Self {
            Self {
                root: RefNode::default(),
                history: Vec::new(),
                max_depth: depth,
                log_int: vec![f64::NEG_INFINITY],
                log_half: vec![(0.5f64).ln()],
            }
        }

        /// Sum of the root's symbol counts.
        fn root_visits(&self) -> usize {
            (self.root.symbol_count[0] + self.root.symbol_count[1]) as usize
        }

        /// Recompute `node.log_prob_weighted` from its KT estimate and its children.
        ///
        /// A missing child contributes 0.0 (probability 1). A node with no children is
        /// passed to `update_weighted_log_prob` as a leaf.
        fn recompute(node: &mut RefNode) {
            let w0 = node.children[0]
                .as_ref()
                .map(|c| c.log_prob_weighted)
                .unwrap_or(0.0);
            let w1 = node.children[1]
                .as_ref()
                .map(|c| c.log_prob_weighted)
                .unwrap_or(0.0);
            let is_leaf = node.children[0].is_none() && node.children[1].is_none();
            node.log_prob_weighted = update_weighted_log_prob(node.log_prob_kt, w0, w1, is_leaf);
        }

        /// Add `sym` along the context path selected by the current history, then
        /// append `sym` to the history.
        fn update(&mut self, sym: Symbol) {
            // Extend the log tables to cover the root count this update reaches.
            let upto = self.root_visits() + 1;
            ensure_log_caches(&mut self.log_int, &mut self.log_half, upto);
            let sym_idx = sym as usize;
            Self::update_node(
                &mut self.root,
                0,
                self.max_depth,
                &self.history,
                sym_idx,
                &self.log_int,
                &self.log_half,
            );
            self.history.push(sym);
        }

        /// Undo the most recent `update`: pop its symbol and remove it from the context
        /// path selected by the remaining history. No-op when the history is empty.
        fn revert(&mut self) {
            let Some(last_sym) = self.history.pop() else {
                return;
            };
            // Counts are still pre-revert here, so cover the current root count.
            let upto = self.root_visits();
            ensure_log_caches(&mut self.log_int, &mut self.log_half, upto);
            let sym_idx = last_sym as usize;
            // The "node emptied" flag is ignored: the root is a plain field, never pruned.
            let _ = Self::revert_node(
                &mut self.root,
                0,
                self.max_depth,
                &self.history,
                sym_idx,
                &self.log_int,
                &self.log_half,
            );
        }

        /// Probability of `sym` given the current history. Only reads the tree.
        fn predict(&mut self, sym: Symbol) -> f64 {
            let sym_idx = sym as usize;
            let mut entries = Vec::with_capacity(self.max_depth + 1);
            // One entry per existing node on the context path, root first.
            let reached_max_depth = Self::collect_predict_entries(
                &self.root,
                0,
                self.max_depth,
                &self.history,
                &mut entries,
            );

            // The root is always collected, so `entries` is non-empty.
            let deepest = entries.len() - 1;
            // Seed the bottom-up ratio. If the path reaches a max-depth leaf, use the
            // KT prediction; otherwise use 0.5, presumably the ratio of the missing
            // (not yet created) child on the path.
            let mut ratio = if reached_max_depth && deepest == self.max_depth {
                predict_ratio_kt(entries[deepest].symbol_count, sym_idx)
            } else {
                0.5
            };
            for idx in (0..=deepest).rev() {
                // The max-depth leaf was already folded into the seed above.
                if reached_max_depth && idx == deepest {
                    continue;
                }
                // Weighted log-prob of the on-path child; 0.0 when the path ends here.
                let child_weight = if idx + 1 <= deepest {
                    entries[idx + 1].log_prob_weighted
                } else {
                    0.0
                };
                ratio = predict_ratio_internal(
                    entries[idx].log_prob_kt,
                    entries[idx].symbol_count,
                    child_weight,
                    entries[idx].sibling_weight,
                    ratio,
                    sym_idx,
                );
            }
            ratio
        }

        /// Weighted log-probability stored at the root.
        fn get_log_block_probability(&self) -> f64 {
            self.root.log_prob_weighted
        }

        /// Recursively add `sym_idx` to `node` and every node below it on the context
        /// path, creating missing children. Each node's weighting is recomputed after
        /// its subtree has been updated.
        fn update_node(
            node: &mut RefNode,
            depth: usize,
            max_depth: usize,
            history: &[Symbol],
            sym_idx: usize,
            log_int: &[f64],
            log_half: &[f64],
        ) {
            if depth < max_depth {
                let edge = history_symbol(history, depth) as usize;
                if node.children[edge].is_none() {
                    node.children[edge] = Some(Box::new(RefNode::default()));
                }
                Self::update_node(
                    node.children[edge].as_deref_mut().unwrap(),
                    depth + 1,
                    max_depth,
                    history,
                    sym_idx,
                    log_int,
                    log_half,
                );
            }
            apply_update_to_state_raw(
                log_int,
                log_half,
                &mut node.symbol_count,
                &mut node.log_prob_kt,
                sym_idx,
            );
            Self::recompute(node);
        }

        /// Recursively remove `sym_idx` from `node` and the nodes below it on the
        /// context path. A child whose counts drop to zero is pruned.
        ///
        /// Returns `true` when `node` itself has zero counts afterwards, so the caller
        /// can prune it.
        fn revert_node(
            node: &mut RefNode,
            depth: usize,
            max_depth: usize,
            history: &[Symbol],
            sym_idx: usize,
            log_int: &[f64],
            log_half: &[f64],
        ) -> bool {
            if depth < max_depth {
                let edge = history_symbol(history, depth) as usize;
                let remove_child = if let Some(child) = node.children[edge].as_deref_mut() {
                    Self::revert_node(
                        child,
                        depth + 1,
                        max_depth,
                        history,
                        sym_idx,
                        log_int,
                        log_half,
                    )
                } else {
                    false
                };
                if remove_child {
                    node.children[edge] = None;
                }
            }
            apply_revert_to_state_raw(
                log_int,
                log_half,
                &mut node.symbol_count,
                &mut node.log_prob_kt,
                sym_idx,
            );
            Self::recompute(node);
            node.symbol_count[0] + node.symbol_count[1] == 0
        }

        /// Push a `PredictEntry` for `node` and each existing descendant on the context
        /// path. `sibling_weight` is the off-path child's weighted log-prob (0.0 if absent
        /// or at max depth).
        ///
        /// Returns `true` if the walk reached `max_depth`, or `false` if an on-path child
        /// was missing first.
        fn collect_predict_entries(
            node: &RefNode,
            depth: usize,
            max_depth: usize,
            history: &[Symbol],
            entries: &mut Vec<PredictEntry>,
        ) -> bool {
            let sibling_weight = if depth < max_depth {
                let path_edge = history_symbol(history, depth) as usize;
                node.children[path_edge ^ 1]
                    .as_ref()
                    .map(|c| c.log_prob_weighted)
                    .unwrap_or(0.0)
            } else {
                0.0
            };
            entries.push(PredictEntry {
                symbol_count: node.symbol_count,
                log_prob_kt: node.log_prob_kt,
                log_prob_weighted: node.log_prob_weighted,
                sibling_weight,
                has_sibling: depth < max_depth
                    && node.children[(history_symbol(history, depth) as usize) ^ 1].is_some(),
            });
            if depth == max_depth {
                return true;
            }
            let edge = history_symbol(history, depth) as usize;
            let Some(child) = node.children[edge].as_ref() else {
                return false;
            };
            Self::collect_predict_entries(child, depth + 1, max_depth, history, entries)
        }
    }
3741
3742    #[derive(Clone)]
3743    struct RefFacContextTree {
3744        trees: Vec<RefContextTree>,
3745        history: Vec<Symbol>,
3746    }
3747
3748    impl RefFacContextTree {
3749        fn new(base_depth: usize, num_bits: usize) -> Self {
3750            Self {
3751                trees: (0..num_bits)
3752                    .map(|i| RefContextTree::new(base_depth + i))
3753                    .collect(),
3754                history: Vec::new(),
3755            }
3756        }
3757
3758        fn update(&mut self, sym: Symbol, bit_index: usize) {
3759            let tree = &mut self.trees[bit_index];
3760            tree.history = self.history.clone();
3761            tree.update(sym);
3762            self.history.push(sym);
3763        }
3764
3765        fn predict(&mut self, sym: Symbol, bit_index: usize) -> f64 {
3766            let tree = &mut self.trees[bit_index];
3767            tree.history = self.history.clone();
3768            tree.predict(sym)
3769        }
3770
3771        fn revert(&mut self, bit_index: usize) {
3772            let Some(last_sym) = self.history.pop() else {
3773                return;
3774            };
3775            let tree = &mut self.trees[bit_index];
3776            tree.history = self.history.clone();
3777            tree.history.push(last_sym);
3778            tree.revert();
3779        }
3780
3781        fn get_log_block_probability(&self) -> f64 {
3782            self.trees
3783                .iter()
3784                .map(RefContextTree::get_log_block_probability)
3785                .sum()
3786        }
3787    }
3788
3789    fn assert_close(a: f64, b: f64) {
3790        let diff = (a - b).abs();
3791        let scale = a.abs().max(b.abs()).max(1.0);
3792        assert!(diff <= 1e-12 * scale, "a={a} b={b} diff={diff}");
3793    }
3794
3795    fn child_after_hot_prefix(tree: &ContextTree, history_before_update: &[Symbol]) -> ChildRef {
3796        let hot_prefix_depth = tree.engine.hot_prefix_depth();
3797        if hot_prefix_depth == 0 {
3798            return ChildRef::NONE;
3799        }
3800
3801        let root_edge = history_symbol(history_before_update, 0) as usize;
3802        let mut current = tree
3803            .engine
3804            .arena
3805            .child(tree.engine.root, root_edge)
3806            .as_node()
3807            .expect("hot-prefix node");
3808        for node_depth in 1..hot_prefix_depth {
3809            let edge = history_symbol(history_before_update, node_depth) as usize;
3810            current = tree
3811                .engine
3812                .arena
3813                .child(current, edge)
3814                .as_node()
3815                .expect("next hot-prefix node");
3816        }
3817        let tail_edge = history_symbol(history_before_update, hot_prefix_depth) as usize;
3818        tree.engine.arena.child(current, tail_edge)
3819    }
3820
3821    #[test]
3822    #[should_panic(expected = "ctw node index overflow")]
3823    fn node_index_from_usize_rejects_overflow() {
3824        let _ = NodeIndex::from_usize(INDEX_LIMIT);
3825    }
3826
3827    #[test]
3828    #[should_panic(expected = "ctw node index overflow")]
3829    fn node_index_from_usize_rejects_large_values() {
3830        let _ = NodeIndex::from_usize(u32::MAX as usize);
3831    }
3832
3833    #[test]
3834    fn ctw_count_lane_stays_packed() {
3835        assert_eq!(std::mem::size_of::<CtNode>(), 32);
3836    }
3837
3838    #[test]
3839    fn ctw_segment_payload_stays_packed() {
3840        assert_eq!(std::mem::size_of::<CtSegment>(), 40);
3841    }
3842
3843    #[test]
3844    fn context_tree_singleton_paths_use_hot_prefix_nodes() {
3845        let mut tree = ContextTree::new(12);
3846        tree.update(false);
3847
3848        let hot_prefix_depth = tree.engine.hot_prefix_depth();
3849        let child = tree.engine.arena.child(tree.engine.root, 0);
3850        let mut current = child.as_node().expect("hot-prefix node");
3851        let mut visited_hot_prefix_nodes = 1usize;
3852        for depth in 1..hot_prefix_depth {
3853            let next = tree.engine.arena.child(current, 0);
3854            current = next.as_node().expect("next hot-prefix node");
3855            visited_hot_prefix_nodes += 1;
3856            assert!(depth < hot_prefix_depth);
3857        }
3858        assert_eq!(visited_hot_prefix_nodes, hot_prefix_depth);
3859        let segment = tree
3860            .engine
3861            .arena
3862            .child(current, 0)
3863            .as_segment()
3864            .expect("segment tail");
3865        assert!(tree.engine.arena.child(current, 1).is_none());
3866        assert!(tree.engine.arena.segments[segment.get()].tail.is_none());
3867        assert_close(
3868            tree.engine.arena.segments[segment.get()].head_log_prob_weighted,
3869            -std::f64::consts::LN_2,
3870        );
3871        assert_close(tree.get_log_block_probability(), -std::f64::consts::LN_2);
3872    }
3873
3874    #[test]
3875    fn context_tree_missing_path_tail_uses_exact_segment_payloads() {
3876        let mut tree = ContextTree::new(12);
3877        tree.update(true);
3878        let child = tree.engine.arena.child(tree.engine.root, 0);
3879        let mut current = child.as_node().expect("hot-prefix node");
3880        for _ in 1..tree.engine.hot_prefix_depth() {
3881            current = tree
3882                .engine
3883                .arena
3884                .child(current, 0)
3885                .as_node()
3886                .expect("next hot-prefix node");
3887        }
3888        let segment = tree
3889            .engine
3890            .arena
3891            .child(current, 0)
3892            .as_segment()
3893            .expect("segment tail");
3894        let payload = tree.engine.arena.segments[segment.get()].payload;
3895        assert!(payload.is_exact());
3896        assert_eq!(
3897            payload.len() as usize,
3898            tree.engine.max_depth - tree.engine.hot_prefix_depth()
3899        );
3900        assert_eq!(payload.exact_bits() & low_bits_mask_u64(payload.len()), 0);
3901    }
3902
3903    #[test]
3904    fn context_tree_missing_path_tail_uses_const_payload_beyond_exact_limit() {
3905        let mut tree = ContextTree::new(80);
3906        let history_before = tree.history.clone();
3907        tree.update(false);
3908
3909        let segment = child_after_hot_prefix(&tree, &history_before)
3910            .as_segment()
3911            .expect("segment tail");
3912        let segment = tree.engine.arena.segments[segment.get()];
3913        assert_eq!(segment.payload.mode(), SEG_MODE_CONST);
3914        assert_eq!(
3915            segment.payload.len() as usize,
3916            tree.engine.max_depth - tree.engine.hot_prefix_depth()
3917        );
3918        assert!(!segment.payload.const_bit());
3919        assert!(segment.tail.is_none());
3920    }
3921
3922    #[test]
3923    fn context_tree_missing_path_tail_uses_history_and_const_payloads_beyond_exact_limit() {
3924        let mut tree = ContextTree::new(80);
3925        let seeded_history: Vec<Symbol> = (0..80).map(|i| (i & 1) == 1).collect();
3926        tree.update_history(&seeded_history);
3927        let history_before = tree.history.clone();
3928        tree.update(false);
3929
3930        let first_segment = child_after_hot_prefix(&tree, &history_before)
3931            .as_segment()
3932            .expect("history-backed segment tail");
3933        let first_segment = tree.engine.arena.segments[first_segment.get()];
3934        assert_eq!(first_segment.payload.mode(), SEG_MODE_HISTORY);
3935        assert_eq!(first_segment.payload.len(), 69);
3936        for offset in [0usize, 1, 7, 31, 68] {
3937            assert_eq!(
3938                segment_edge_from_parts(
3939                    first_segment,
3940                    offset,
3941                    &history_before,
3942                    history_before.len()
3943                ),
3944                history_symbol(&history_before, tree.engine.hot_prefix_depth() + 1 + offset)
3945            );
3946        }
3947
3948        let tail_segment = first_segment
3949            .tail
3950            .as_segment()
3951            .expect("constant fallback tail");
3952        let tail_segment = tree.engine.arena.segments[tail_segment.get()];
3953        assert_eq!(tail_segment.payload.mode(), SEG_MODE_CONST);
3954        assert_eq!(tail_segment.payload.len(), 1);
3955        assert!(!tail_segment.payload.const_bit());
3956        assert!(tail_segment.tail.is_none());
3957    }
3958
3959    #[test]
3960    fn context_tree_matches_reference_on_short_sequences() {
3961        for depth in 0..=6usize {
3962            for len in 0..=6usize {
3963                for mask in 0..(1usize << len) {
3964                    let mut prod = ContextTree::new(depth);
3965                    let mut reference = RefContextTree::new(depth);
3966                    for step in 0..len {
3967                        let p_prod_0 = prod.predict(false);
3968                        let p_ref_0 = reference.predict(false);
3969                        assert!(
3970                            (p_prod_0 - p_ref_0).abs()
3971                                <= 1e-12 * p_prod_0.abs().max(p_ref_0.abs()).max(1.0),
3972                            "predict0 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_0} ref={p_ref_0} history={:?}",
3973                            prod.history
3974                        );
3975                        let p_prod_1 = prod.predict(true);
3976                        let p_ref_1 = reference.predict(true);
3977                        assert!(
3978                            (p_prod_1 - p_ref_1).abs()
3979                                <= 1e-12 * p_prod_1.abs().max(p_ref_1.abs()).max(1.0),
3980                            "predict1 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_1} ref={p_ref_1} history={:?}",
3981                            prod.history
3982                        );
3983                        let log_prod = prod.get_log_block_probability();
3984                        let log_ref = reference.get_log_block_probability();
3985                        assert!(
3986                            (log_prod - log_ref).abs()
3987                                <= 1e-12 * log_prod.abs().max(log_ref.abs()).max(1.0),
3988                            "log mismatch before update depth={depth} len={len} mask={mask} step={step} prod={log_prod} ref={log_ref} history={:?}",
3989                            prod.history
3990                        );
3991                        let bit = ((mask >> step) & 1) == 1;
3992                        prod.update(bit);
3993                        reference.update(bit);
3994                        let log_prod = prod.get_log_block_probability();
3995                        let log_ref = reference.get_log_block_probability();
3996                        assert!(
3997                            (log_prod - log_ref).abs()
3998                                <= 1e-12 * log_prod.abs().max(log_ref.abs()).max(1.0),
3999                            "log mismatch after update depth={depth} len={len} mask={mask} step={step} bit={bit} prod={log_prod} ref={log_ref} history={:?}",
4000                            prod.history
4001                        );
4002                    }
4003                    while prod.history_size() > 0 {
4004                        let p_prod_0 = prod.predict(false);
4005                        let p_ref_0 = reference.predict(false);
4006                        assert!(
4007                            (p_prod_0 - p_ref_0).abs()
4008                                <= 1e-12 * p_prod_0.abs().max(p_ref_0.abs()).max(1.0),
4009                            "revert predict0 mismatch depth={depth} len={len} mask={mask} prod={p_prod_0} ref={p_ref_0} history={:?}",
4010                            prod.history
4011                        );
4012                        let p_prod_1 = prod.predict(true);
4013                        let p_ref_1 = reference.predict(true);
4014                        assert!(
4015                            (p_prod_1 - p_ref_1).abs()
4016                                <= 1e-12 * p_prod_1.abs().max(p_ref_1.abs()).max(1.0),
4017                            "revert predict1 mismatch depth={depth} len={len} mask={mask} prod={p_prod_1} ref={p_ref_1} history={:?}",
4018                            prod.history
4019                        );
4020                        prod.revert();
4021                        reference.revert();
4022                        let log_prod = prod.get_log_block_probability();
4023                        let log_ref = reference.get_log_block_probability();
4024                        assert!(
4025                            (log_prod - log_ref).abs()
4026                                <= 1e-12 * log_prod.abs().max(log_ref.abs()).max(1.0),
4027                            "revert log mismatch depth={depth} len={len} mask={mask} prod={log_prod} ref={log_ref} history={:?}",
4028                            prod.history
4029                        );
4030                    }
4031                }
4032            }
4033        }
4034    }
4035
4036    #[test]
4037    fn context_tree_long_depth_matches_reference_on_short_sequences() {
4038        for &depth in &[65usize, 80usize] {
4039            for len in 0..=6usize {
4040                for mask in 0..(1usize << len) {
4041                    let mut prod = ContextTree::new(depth);
4042                    let mut reference = RefContextTree::new(depth);
4043                    for step in 0..len {
4044                        let p_prod_0 = prod.predict(false);
4045                        let p_ref_0 = reference.predict(false);
4046                        assert!(
4047                            (p_prod_0 - p_ref_0).abs()
4048                                <= 1e-12 * p_prod_0.abs().max(p_ref_0.abs()).max(1.0),
4049                            "long-depth predict0 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_0} ref={p_ref_0} history={:?}",
4050                            prod.history
4051                        );
4052                        let p_prod_1 = prod.predict(true);
4053                        let p_ref_1 = reference.predict(true);
4054                        assert!(
4055                            (p_prod_1 - p_ref_1).abs()
4056                                <= 1e-12 * p_prod_1.abs().max(p_ref_1.abs()).max(1.0),
4057                            "long-depth predict1 mismatch depth={depth} len={len} mask={mask} step={step} prod={p_prod_1} ref={p_ref_1} history={:?}",
4058                            prod.history
4059                        );
4060                        assert_close(
4061                            prod.get_log_block_probability(),
4062                            reference.get_log_block_probability(),
4063                        );
4064                        let bit = ((mask >> step) & 1) == 1;
4065                        prod.update(bit);
4066                        reference.update(bit);
4067                        assert_close(
4068                            prod.get_log_block_probability(),
4069                            reference.get_log_block_probability(),
4070                        );
4071                    }
4072
4073                    while prod.history_size() > 0 {
4074                        assert_close(prod.predict(false), reference.predict(false));
4075                        assert_close(prod.predict(true), reference.predict(true));
4076                        prod.revert();
4077                        reference.revert();
4078                        assert_close(
4079                            prod.get_log_block_probability(),
4080                            reference.get_log_block_probability(),
4081                        );
4082                    }
4083                }
4084            }
4085        }
4086    }
4087
4088    #[test]
4089    fn fac_ctw_matches_reference_on_short_sequences() {
4090        let mut fac = FacContextTree::new(4, 4);
4091        let mut reference = RefFacContextTree::new(4, 4);
4092        let stream = [
4093            (true, 0usize),
4094            (false, 1usize),
4095            (true, 2usize),
4096            (true, 3usize),
4097            (false, 0usize),
4098            (false, 1usize),
4099            (true, 2usize),
4100            (false, 3usize),
4101        ];
4102
4103        for &(bit, idx) in &stream {
4104            assert_close(fac.predict(false, idx), reference.predict(false, idx));
4105            assert_close(fac.predict(true, idx), reference.predict(true, idx));
4106            fac.update(bit, idx);
4107            reference.update(bit, idx);
4108            assert_close(
4109                fac.get_log_block_probability(),
4110                reference.get_log_block_probability(),
4111            );
4112        }
4113
4114        for &(_, idx) in stream.iter().rev() {
4115            fac.revert(idx);
4116            reference.revert(idx);
4117            assert_close(
4118                fac.get_log_block_probability(),
4119                reference.get_log_block_probability(),
4120            );
4121        }
4122    }
4123
4124    #[test]
4125    fn fac_ctw_history_consistency() {
4126        let mut fac = FacContextTree::new(4, 4);
4127
4128        fac.update_history(&[true, false, true]);
4129        assert_eq!(fac.shared_history.len(), 3);
4130
4131        fac.update(true, 0);
4132        fac.update(false, 1);
4133        assert_eq!(fac.shared_history.len(), 5);
4134
4135        fac.revert(1);
4136        assert_eq!(fac.shared_history.len(), 4);
4137
4138        fac.revert(0);
4139        assert_eq!(fac.shared_history.len(), 3);
4140    }
4141
4142    #[test]
4143    fn fac_ctw_predict_one_matches_predict_true() {
4144        let mut fac = FacContextTree::new(6, 8);
4145        for &byte in b"predict-one exactness regression payload" {
4146            for bit_idx in 0..8usize {
4147                let p_generic = fac.predict(true, bit_idx);
4148                let p_one = fac.predict_one(bit_idx);
4149                assert_close(p_generic, p_one);
4150                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
4151                fac.update_predicted(bit, bit_idx);
4152            }
4153        }
4154    }
4155
4156    #[test]
4157    fn fac_ctw_long_depth_predict_one_matches_predict_true() {
4158        let mut fac = FacContextTree::new(78, 4);
4159        for step in 0..24usize {
4160            for bit_idx in 0..fac.num_bits() {
4161                let p_generic = fac.predict(true, bit_idx);
4162                let p_one = fac.predict_one(bit_idx);
4163                assert_close(p_generic, p_one);
4164                let bit = ((step * 5 + bit_idx * 3) & 1) == 1;
4165                fac.update_predicted(bit, bit_idx);
4166            }
4167        }
4168    }
4169
4170    #[test]
4171    fn fac_ctw_update_byte_msb_matches_bit_updates() {
4172        let mut by_byte = FacContextTree::new(6, 8);
4173        let mut by_bits = FacContextTree::new(6, 8);
4174        for &byte in b"byte update msb regression payload" {
4175            by_byte.update_byte_msb(byte);
4176            for bit_idx in 0..8usize {
4177                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
4178                by_bits.update(bit, bit_idx);
4179            }
4180            assert_close(
4181                by_byte.get_log_block_probability(),
4182                by_bits.get_log_block_probability(),
4183            );
4184            assert_eq!(by_byte.shared_history, by_bits.shared_history);
4185        }
4186    }
4187
4188    #[test]
4189    fn fac_ctw_update_byte_lsb_matches_bit_updates() {
4190        let mut by_byte = FacContextTree::new(6, 5);
4191        let mut by_bits = FacContextTree::new(6, 5);
4192        for &byte in b"byte update lsb regression payload" {
4193            by_byte.update_byte_lsb(byte);
4194            for bit_idx in 0..5usize {
4195                let bit = ((byte >> bit_idx) & 1) == 1;
4196                by_bits.update(bit, bit_idx);
4197            }
4198            assert_close(
4199                by_byte.get_log_block_probability(),
4200                by_bits.get_log_block_probability(),
4201            );
4202            assert_eq!(by_byte.shared_history, by_bits.shared_history);
4203        }
4204    }
4205
4206    #[test]
4207    fn fac_ctw_log_cache_tracks_tree_visits_not_shared_history() {
4208        let mut fac = FacContextTree::new(8, 8);
4209        let updates_per_tree = 512usize;
4210        let (log_int_before, log_half_before) = shared_log_cache_lens();
4211
4212        for step in 0..updates_per_tree {
4213            let bit = (step & 1) == 1;
4214            for bit_idx in 0..8usize {
4215                fac.update(bit, bit_idx);
4216            }
4217        }
4218
4219        assert_eq!(fac.shared_history.len(), updates_per_tree * 8);
4220        for tree in &fac.trees {
4221            let visits = tree.engine.arena.visits(tree.engine.root) as usize;
4222            assert_eq!(visits, updates_per_tree);
4223        }
4224
4225        let (log_int_after, log_half_after) = shared_log_cache_lens();
4226        let expected_len = updates_per_tree + 1;
4227        assert!(
4228            log_int_after <= log_int_before.max(expected_len),
4229            "log_int grew to {log_int_after} (before={log_int_before}, expected_len={expected_len})"
4230        );
4231        assert!(
4232            log_half_after <= log_half_before.max(expected_len),
4233            "log_half grew to {log_half_after} (before={log_half_before}, expected_len={expected_len})"
4234        );
4235    }
4236
4237    fn seed_fac_cache_regression_state(fac: &mut FacContextTree) {
4238        for step in 0..24usize {
4239            for bit_idx in 0..fac.num_bits() {
4240                let bit = ((step * 3 + bit_idx) & 1) == 1;
4241                fac.update(bit, bit_idx);
4242            }
4243        }
4244    }
4245
4246    fn assert_update_predicted_matches_fresh_after_history_rewrite<F>(mut rewrite: F)
4247    where
4248        F: FnMut(&mut FacContextTree),
4249    {
4250        let mut predicted = FacContextTree::new(6, 4);
4251        seed_fac_cache_regression_state(&mut predicted);
4252        let mut fresh = predicted.clone();
4253        let original_history = predicted.shared_history.clone();
4254        let target_bit = 2usize;
4255
4256        let _ = predicted.predict(true, target_bit);
4257        rewrite(&mut predicted);
4258        rewrite(&mut fresh);
4259
4260        assert_eq!(predicted.shared_history.len(), original_history.len());
4261        assert_ne!(predicted.shared_history, original_history);
4262        assert_eq!(predicted.shared_history, fresh.shared_history);
4263
4264        predicted.update_predicted(false, target_bit);
4265        fresh.update(false, target_bit);
4266
4267        assert_eq!(predicted.shared_history, fresh.shared_history);
4268        assert_close(
4269            predicted.get_log_block_probability(),
4270            fresh.get_log_block_probability(),
4271        );
4272        for bit_idx in 0..predicted.num_bits() {
4273            assert_close(
4274                predicted.predict(false, bit_idx),
4275                fresh.predict(false, bit_idx),
4276            );
4277            assert_close(
4278                predicted.predict(true, bit_idx),
4279                fresh.predict(true, bit_idx),
4280            );
4281        }
4282    }
4283
4284    #[test]
4285    fn fac_ctw_update_predicted_ignores_stale_cache_after_reset_and_rewrite() {
4286        assert_update_predicted_matches_fresh_after_history_rewrite(|fac| {
4287            let mut rewritten = fac.shared_history.clone();
4288            for bit in &mut rewritten {
4289                *bit = !*bit;
4290            }
4291            fac.reset_history_only();
4292            fac.update_history(&rewritten);
4293        });
4294    }
4295
4296    #[test]
4297    fn fac_ctw_update_predicted_ignores_stale_cache_after_revert_and_rewrite() {
4298        assert_update_predicted_matches_fresh_after_history_rewrite(|fac| {
4299            let original = fac.shared_history.clone();
4300            let keep = original.len() / 3;
4301            let remove = original.len() - keep;
4302            let mut rewritten_suffix = original[keep..].to_vec();
4303            for bit in &mut rewritten_suffix {
4304                *bit = !*bit;
4305            }
4306            fac.revert_history(remove);
4307            fac.update_history(&rewritten_suffix);
4308        });
4309    }
4310
4311    #[test]
4312    fn fac_ctw_shared_history_version_tracks_mutations() {
4313        let mut fac = FacContextTree::new(4, 2);
4314        let mut version = fac.shared_history_version;
4315
4316        fac.update_history(&[]);
4317        assert_eq!(fac.shared_history_version, version);
4318
4319        fac.update_history(&[true, false]);
4320        assert_ne!(fac.shared_history_version, version);
4321        version = fac.shared_history_version;
4322
4323        fac.revert_history(0);
4324        assert_eq!(fac.shared_history_version, version);
4325
4326        fac.revert_history(1);
4327        assert_ne!(fac.shared_history_version, version);
4328        version = fac.shared_history_version;
4329
4330        let _ = fac.predict(true, 0);
4331        assert_eq!(fac.shared_history_version, version);
4332
4333        fac.update_predicted(true, 0);
4334        assert_ne!(fac.shared_history_version, version);
4335        version = fac.shared_history_version;
4336
4337        fac.reset_history_only();
4338        assert_ne!(fac.shared_history_version, version);
4339    }
4340
4341    #[test]
4342    fn context_tree_predict_preserves_state() {
4343        let mut tree = ContextTree::new(6);
4344        for &bit in &[true, false, true, true, false, false, true, false] {
4345            tree.update(bit);
4346        }
4347        let p0_before = tree.predict(false);
4348        let p1_before = tree.predict(true);
4349        let log_before = tree.get_log_block_probability();
4350        let history_before = tree.history.clone();
4351        let _ = tree.predict(true);
4352
4353        assert_eq!(tree.history, history_before);
4354        assert_close(tree.get_log_block_probability(), log_before);
4355        assert_close(tree.predict(false), p0_before);
4356        assert_close(tree.predict(true), p1_before);
4357    }
4358
4359    #[test]
4360    fn context_tree_predict_matches_update_ratio() {
4361        let mut tree = ContextTree::new(7);
4362        for &bit in &[true, false, true, false, true, true, false, true, false] {
4363            tree.update(bit);
4364        }
4365        for &sym in &[false, true] {
4366            let predicted = tree.predict(sym);
4367            let mut reference = tree.clone();
4368            let before = reference.get_log_block_probability();
4369            reference.update(sym);
4370            let after = reference.get_log_block_probability();
4371            assert_close(predicted, (after - before).exp());
4372        }
4373    }
4374
4375    #[test]
4376    fn fac_ctw_predict_preserves_state() {
4377        let mut fac = FacContextTree::new(5, 8);
4378        for &byte in b"fac ctw state preservation" {
4379            for bit_idx in 0..8usize {
4380                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
4381                fac.update(bit, bit_idx);
4382            }
4383        }
4384        let p0_before = fac.predict(false, 3);
4385        let p1_before = fac.predict(true, 3);
4386        let log_before = fac.get_log_block_probability();
4387        let history_before = fac.shared_history.clone();
4388        let _ = fac.predict(true, 3);
4389
4390        assert_eq!(fac.shared_history, history_before);
4391        assert_close(fac.get_log_block_probability(), log_before);
4392        assert_close(fac.predict(false, 3), p0_before);
4393        assert_close(fac.predict(true, 3), p1_before);
4394    }
4395
4396    #[test]
4397    fn fac_ctw_predict_matches_update_ratio() {
4398        let mut fac = FacContextTree::new(6, 8);
4399        for &byte in b"fac ctw exact predictive ratio" {
4400            for bit_idx in 0..8usize {
4401                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
4402                fac.update(bit, bit_idx);
4403            }
4404        }
4405        for &sym in &[false, true] {
4406            let predicted = fac.predict(sym, 4);
4407            let mut reference = fac.clone();
4408            let before = reference.get_log_block_probability();
4409            reference.update(sym, 4);
4410            let after = reference.get_log_block_probability();
4411            assert_close(predicted, (after - before).exp());
4412        }
4413    }
4414
4415    #[test]
4416    fn fac_ctw_update_predicted_matches_fresh_update_on_byte_stream() {
4417        let mut predicted = FacContextTree::new(6, 8);
4418        let mut fresh = predicted.clone();
4419        let stream = b"fac-ctw prepared update exactness regression";
4420
4421        for (byte_pos, &byte) in stream.iter().enumerate() {
4422            for bit_idx in 0..8usize {
4423                let bit = ((byte >> (7 - bit_idx)) & 1) == 1;
4424                let _ = predicted.predict(true, bit_idx);
4425                predicted.update_predicted(bit, bit_idx);
4426                fresh.update(bit, bit_idx);
4427
4428                let predicted_log = predicted.get_log_block_probability();
4429                let fresh_log = fresh.get_log_block_probability();
4430                assert!(
4431                    (predicted_log - fresh_log).abs()
4432                        <= 1e-12 * predicted_log.abs().max(fresh_log.abs()).max(1.0),
4433                    "log mismatch byte_pos={byte_pos} bit_idx={bit_idx} bit={bit} predicted={predicted_log} fresh={fresh_log}\nshared_history={:?}\npredicted_arena={:#?}\nfresh_arena={:#?}\npredicted_steps={:?}\nfresh_steps={:?}",
4434                    predicted.shared_history,
4435                    predicted.trees[bit_idx].engine.arena,
4436                    fresh.trees[bit_idx].engine.arena,
4437                    predicted.trees[bit_idx].engine.prepared_steps,
4438                    fresh.trees[bit_idx].engine.prepared_steps,
4439                );
4440                for probe_idx in 0..8usize {
4441                    let p_pred_0 = predicted.predict(false, probe_idx);
4442                    let p_fresh_0 = fresh.predict(false, probe_idx);
4443                    assert!(
4444                        (p_pred_0 - p_fresh_0).abs()
4445                            <= 1e-12 * p_pred_0.abs().max(p_fresh_0.abs()).max(1.0),
4446                        "predict0 mismatch byte_pos={byte_pos} bit_idx={bit_idx} probe_idx={probe_idx} predicted={p_pred_0} fresh={p_fresh_0}",
4447                    );
4448                    let p_pred_1 = predicted.predict(true, probe_idx);
4449                    let p_fresh_1 = fresh.predict(true, probe_idx);
4450                    assert!(
4451                        (p_pred_1 - p_fresh_1).abs()
4452                            <= 1e-12 * p_pred_1.abs().max(p_fresh_1.abs()).max(1.0),
4453                        "predict1 mismatch byte_pos={byte_pos} bit_idx={bit_idx} probe_idx={probe_idx} predicted={p_pred_1} fresh={p_fresh_1}",
4454                    );
4455                }
4456            }
4457        }
4458    }
4459
4460    #[test]
4461    fn fac_ctw_long_depth_update_predicted_matches_fresh_update_on_bit_stream() {
4462        let mut predicted = FacContextTree::new(78, 4);
4463        let mut fresh = predicted.clone();
4464
4465        for step in 0..20usize {
4466            for bit_idx in 0..predicted.num_bits() {
4467                let bit = ((step * 7 + bit_idx * 11) & 1) == 1;
4468                let _ = predicted.predict(true, bit_idx);
4469                predicted.update_predicted(bit, bit_idx);
4470                fresh.update(bit, bit_idx);
4471                assert_eq!(predicted.shared_history, fresh.shared_history);
4472                assert_close(
4473                    predicted.get_log_block_probability(),
4474                    fresh.get_log_block_probability(),
4475                );
4476            }
4477        }
4478
4479        for bit_idx in 0..predicted.num_bits() {
4480            assert_close(
4481                predicted.predict(false, bit_idx),
4482                fresh.predict(false, bit_idx),
4483            );
4484            assert_close(
4485                predicted.predict(true, bit_idx),
4486                fresh.predict(true, bit_idx),
4487            );
4488        }
4489    }
4490
4491    fn scan_symbol_space(tree: &mut FacContextTree, bits: usize) {
4492        fn rec(tree: &mut FacContextTree, bits: usize, depth: usize) {
4493            if depth == bits {
4494                return;
4495            }
4496            for bit in [false, true] {
4497                let bit_idx = depth;
4498                tree.update(bit, bit_idx);
4499                rec(tree, bits, depth + 1);
4500                tree.revert(bit_idx);
4501            }
4502        }
4503        rec(tree, bits, 0);
4504    }
4505
4506    fn byte_log_prob(tree: &mut FacContextTree, symbol: u8, msb_first: bool, bits: usize) -> f64 {
4507        let before = tree.get_log_block_probability();
4508        if msb_first {
4509            for bit_idx in 0..bits {
4510                let bit = ((symbol >> (7 - bit_idx)) & 1) == 1;
4511                tree.update(bit, bit_idx);
4512            }
4513            let after = tree.get_log_block_probability();
4514            for bit_idx in (0..bits).rev() {
4515                tree.revert(bit_idx);
4516            }
4517            after - before
4518        } else {
4519            for bit_idx in 0..bits {
4520                let bit = ((symbol >> bit_idx) & 1) == 1;
4521                tree.update(bit, bit_idx);
4522            }
4523            let after = tree.get_log_block_probability();
4524            for bit_idx in (0..bits).rev() {
4525                tree.revert(bit_idx);
4526            }
4527            after - before
4528        }
4529    }
4530
4531    fn assert_symbol_scan_then_update_matches_plain(msb_first: bool) {
4532        let bits = 8usize;
4533        let mut with_scan = FacContextTree::new(7, bits);
4534        let mut plain = with_scan.clone();
4535        for &byte in b"pdf then update parity payload" {
4536            for bit_idx in 0..bits {
4537                let bit = if msb_first {
4538                    ((byte >> (7 - bit_idx)) & 1) == 1
4539                } else {
4540                    ((byte >> bit_idx) & 1) == 1
4541                };
4542                with_scan.update(bit, bit_idx);
4543                plain.update(bit, bit_idx);
4544            }
4545        }
4546
4547        scan_symbol_space(&mut with_scan, bits);
4548
4549        let observed = b'n';
4550        for bit_idx in 0..bits {
4551            let bit = if msb_first {
4552                ((observed >> (7 - bit_idx)) & 1) == 1
4553            } else {
4554                ((observed >> bit_idx) & 1) == 1
4555            };
4556            with_scan.update(bit, bit_idx);
4557            plain.update(bit, bit_idx);
4558        }
4559
4560        for sym in 0u8..=255u8 {
4561            let lp_scan = byte_log_prob(&mut with_scan, sym, msb_first, bits);
4562            let lp_plain = byte_log_prob(&mut plain, sym, msb_first, bits);
4563            let diff = (lp_scan - lp_plain).abs();
4564            assert!(
4565                diff < 1e-12,
4566                "symbol={sym} lp_scan={lp_scan} lp_plain={lp_plain} diff={diff}",
4567            );
4568        }
4569    }
4570
4571    #[test]
4572    fn fac_ctw_symbol_scan_then_update_matches_plain_msb() {
4573        assert_symbol_scan_then_update_matches_plain(true);
4574    }
4575
4576    #[test]
4577    fn fac_ctw_symbol_scan_then_update_matches_plain_lsb() {
4578        assert_symbol_scan_then_update_matches_plain(false);
4579    }
4580}