infotheory/ctw.rs

//! Context Tree Weighting (CTW) and Factorized Action-Conditional CTW (FAC-CTW).
//!
//! This module implements both the standard CTW algorithm for binary sequence prediction
//! and the FAC-CTW variant described in Veness et al. (2011) for agent-based prediction.
//!
//! # Arena Allocator
//! For performance with deep trees (D > 64), nodes are stored in a flat arena and addressed
//! by indices rather than `Box` pointers, eliminating pointer chasing and improving cache locality.
//!
//! # Shared History Optimization (FAC-CTW)
//! FAC-CTW uses k trees that condition on the same base history. Rather than duplicating the
//! history k times, a single shared history vector is maintained and each tree reads its
//! context directly from it.
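//!
//! # Example
//! A minimal usage sketch of the `ContextTree` API defined below (marked `ignore`
//! because the enclosing crate path is not assumed here):
//!
//! ```ignore
//! let mut tree = ContextTree::new(8);
//! for &bit in &[true, false, true, true] {
//!     let p = tree.predict(bit); // probability CTW currently assigns to `bit`
//!     assert!(p > 0.0 && p < 1.0);
//!     tree.update(bit);          // then learn it
//! }
//! assert!(tree.get_log_block_probability() < 0.0);
//! ```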

type Symbol = bool;

/// Index into the node arena. `NONE` indicates no child.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NodeIndex(u32);

impl NodeIndex {
    pub const NONE: NodeIndex = NodeIndex(u32::MAX);

    #[inline(always)]
    pub fn is_none(self) -> bool {
        self.0 == u32::MAX
    }

    #[inline(always)]
    pub fn is_some(self) -> bool {
        self.0 != u32::MAX
    }

    #[inline(always)]
    pub fn get(self) -> usize {
        self.0 as usize
    }
}

/// A node in the Context Tree, stored in an arena.
#[derive(Clone, Debug)]
pub struct CtNode {
    /// Child node indices for context extensions (0 and 1).
    pub children: [NodeIndex; 2],
    /// Log-probability estimated by the KT-estimator.
    pub log_prob_kt: f64,
    /// Weighted log-probability (mix of KT and children's weighted probabilities).
    pub log_prob_weighted: f64,
    /// Counts of symbols (0 and 1) observed in this context.
    pub symbol_count: [u32; 2],
}

impl CtNode {
    #[inline(always)]
    pub fn new() -> Self {
        Self {
            children: [NodeIndex::NONE, NodeIndex::NONE],
            log_prob_kt: 0.0,
            log_prob_weighted: 0.0,
            symbol_count: [0, 0],
        }
    }

    /// Returns the total number of visits (symbol observations) at this node.
    #[inline(always)]
    pub fn visits(&self) -> u32 {
        self.symbol_count[0] + self.symbol_count[1]
    }
}

/// Arena allocator for context tree nodes.
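///
/// A brief sketch of the index-based API (an `ignore` doctest, since the crate
/// path is not assumed here); the calls shown only use methods defined in this module:
///
/// ```ignore
/// let mut arena = CtArena::new();
/// let idx = arena.alloc();                  // fresh zero-initialized `CtNode`
/// arena.get_mut(idx).symbol_count[1] += 1;  // mutate through the index
/// assert_eq!(arena.get(idx).visits(), 1);
/// arena.free(idx);                          // slot is recycled by a later `alloc`
/// ```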
#[derive(Clone, Debug)]
pub struct CtArena {
    nodes: Vec<CtNode>,
    free_list: Vec<NodeIndex>,
}

impl CtArena {
    pub fn new() -> Self {
        Self {
            nodes: Vec::with_capacity(1024),
            free_list: Vec::new(),
        }
    }

    pub fn with_capacity(cap: usize) -> Self {
        Self {
            nodes: Vec::with_capacity(cap),
            free_list: Vec::new(),
        }
    }

    #[inline(always)]
    pub fn alloc(&mut self) -> NodeIndex {
        if let Some(idx) = self.free_list.pop() {
            self.nodes[idx.get()] = CtNode::new();
            idx
        } else {
            let idx = NodeIndex(self.nodes.len() as u32);
            self.nodes.push(CtNode::new());
            idx
        }
    }

    #[inline(always)]
    pub fn free(&mut self, idx: NodeIndex) {
        if idx.is_some() {
            self.free_list.push(idx);
        }
    }

    #[inline(always)]
    pub fn get(&self, idx: NodeIndex) -> &CtNode {
        &self.nodes[idx.get()]
    }

    #[inline(always)]
    pub fn get_mut(&mut self, idx: NodeIndex) -> &mut CtNode {
        &mut self.nodes[idx.get()]
    }

    pub fn clear(&mut self) {
        self.nodes.clear();
        self.free_list.clear();
    }

    /// Returns approximate memory usage in bytes.
    pub fn memory_usage(&self) -> usize {
        self.nodes.capacity() * std::mem::size_of::<CtNode>()
            + self.free_list.capacity() * std::mem::size_of::<NodeIndex>()
    }
}

/// A Context Tree for binary sequence prediction using arena allocation.
#[derive(Clone)]
pub struct ContextTree {
    arena: CtArena,
    root: NodeIndex,
    history: Vec<Symbol>,
    max_depth: usize,
    context_buf: Vec<Symbol>,
    path_buf: Vec<NodeIndex>,
}

impl ContextTree {
    /// Creates a new `ContextTree` with the given depth.
    pub fn new(depth: usize) -> Self {
        let mut arena = CtArena::with_capacity(1024.min(1 << depth.min(16)));
        let root = arena.alloc();
        Self {
            arena,
            root,
            history: Vec::new(),
            max_depth: depth,
            context_buf: vec![false; depth],
            path_buf: Vec::with_capacity(depth + 1),
        }
    }

    /// Resets the tree and history.
    pub fn clear(&mut self) {
        self.history.clear();
        self.arena.clear();
        self.root = self.arena.alloc();
        self.context_buf.fill(false);
    }

    /// Updates the tree with a new symbol.
    #[inline]
    pub fn update(&mut self, sym: Symbol) {
        self.prepare_context();
        self.update_from_root(sym, false);
        self.history.push(sym);
    }

    /// Reverts the last symbol update.
    #[inline]
    pub fn revert(&mut self) {
        let Some(last_sym) = self.history.pop() else {
            return;
        };
        self.prepare_context();
        self.update_from_root(last_sym, true);
    }

    /// Appends symbols to the history without updating the tree (for action conditioning).
    #[inline]
    pub fn update_history(&mut self, symbols: &[Symbol]) {
        self.history.extend_from_slice(symbols);
    }

    /// Removes the last symbol from history without tree update.
    #[inline]
    pub fn revert_history(&mut self) {
        self.history.pop();
    }

    /// Truncates the history to `new_size`.
    pub fn truncate_history(&mut self, new_size: usize) {
        if new_size < self.history.len() {
            self.history.truncate(new_size);
        }
    }

    /// Predicts the probability of the next symbol being `sym`.
    #[inline]
    pub fn predict(&mut self, sym: Symbol) -> f64 {
        let log_prob_before = self.arena.get(self.root).log_prob_weighted;
        self.update(sym);
        let log_prob_after = self.arena.get(self.root).log_prob_weighted;
        self.revert();
        (log_prob_after - log_prob_before).exp()
    }

    /// Shorthand for predicting the probability of symbol `1` (`true`).
    #[inline]
    pub fn predict_sym_prob(&mut self) -> f64 {
        self.predict(true)
    }

    /// Returns the total log-probability of the sequence observed so far.
    #[inline]
    pub fn get_log_block_probability(&self) -> f64 {
        self.arena.get(self.root).log_prob_weighted
    }

    /// Returns the configured maximum depth of the tree.
    #[inline]
    pub fn depth(&self) -> usize {
        self.max_depth
    }

    /// Returns the current length of the history.
    #[inline]
    pub fn history_size(&self) -> usize {
        self.history.len()
    }

    // --- Internal methods ---

    #[inline(always)]
    fn prepare_context(&mut self) {
        self.context_buf.fill(false);
        let history_len = self.history.len();
        let copy_len = history_len.min(self.max_depth);
        if copy_len > 0 {
            self.context_buf[self.max_depth - copy_len..]
                .copy_from_slice(&self.history[history_len - copy_len..]);
        }
    }

    #[inline]
    fn update_from_root(&mut self, sym: Symbol, revert: bool) {
        self.update_node_iterative(self.root, sym, revert);
    }

    /// Iterative update to avoid deep recursion and enable better inlining.
    #[inline]
    fn update_node_iterative(&mut self, root_idx: NodeIndex, sym: Symbol, revert: bool) {
        let max_depth = self.max_depth;

        // Build the path from root to leaf, reusing the path buffer to avoid
        // repeated allocations.
        let mut path = std::mem::take(&mut self.path_buf);
        path.clear();
        path.push(root_idx);

        let mut current = root_idx;
        for depth in 0..max_depth {
            let child_sym = self.context_buf[max_depth - 1 - depth];
            let child_idx = self.arena.get(current).children[child_sym as usize];

            if revert {
                if child_idx.is_none() {
                    break;
                }
                current = child_idx;
            } else {
                let child = if child_idx.is_none() {
                    let new_child = self.arena.alloc();
                    self.arena.get_mut(current).children[child_sym as usize] = new_child;
                    new_child
                } else {
                    child_idx
                };
                current = child;
            }
            path.push(current);
        }

        // Update nodes from leaf to root
        let leaf_depth = path.len() - 1;
        for (i, &node_idx) in path.iter().enumerate().rev() {
            let is_leaf = i == leaf_depth;
            self.update_single_node(node_idx, sym, revert, is_leaf);

            // Clean up empty children during revert
            if revert && i > 0 {
                let parent_idx = path[i - 1];
                let depth = i - 1;
                let child_sym = self.context_buf[max_depth - 1 - depth];
                if self.arena.get(node_idx).visits() == 0 {
                    self.arena.get_mut(parent_idx).children[child_sym as usize] = NodeIndex::NONE;
                    self.arena.free(node_idx);
                }
            }
        }

        self.path_buf = path;
    }

    #[inline(always)]
    fn update_single_node(&mut self, idx: NodeIndex, sym: Symbol, revert: bool, is_leaf: bool) {
        // Read the children's weighted log-probabilities before taking a mutable
        // borrow of this node.
        let (log_prob_w0, log_prob_w1) = if !is_leaf {
            let node = self.arena.get(idx);
            let child0 = node.children[0];
            let child1 = node.children[1];
            let w0 = if child0.is_some() {
                self.arena.get(child0).log_prob_weighted
            } else {
                0.0
            };
            let w1 = if child1.is_some() {
                self.arena.get(child1).log_prob_weighted
            } else {
                0.0
            };
            (w0, w1)
        } else {
            (0.0, 0.0)
        };

        let node = self.arena.get_mut(idx);

        // Update the KT estimator: each observation multiplies P_kt by
        // (count[sym] + 1/2) / (total + 1).
        let sym_idx = sym as usize;
        if !revert {
            node.log_prob_kt += log_kt_mul(node.symbol_count, sym);
            node.symbol_count[sym_idx] += 1;
        } else {
            // Undo the factor added by the matching forward update; the counts here
            // are the post-increment values, hence the -0.5 and `total`.
            let total = node.symbol_count[0] + node.symbol_count[1];
            if node.symbol_count[sym_idx] > 0 && total > 0 {
                let numerator = (node.symbol_count[sym_idx] as f64 - 0.5).ln();
                let denominator = (total as f64).ln();
                node.log_prob_kt -= numerator - denominator;
                node.symbol_count[sym_idx] -= 1;
            }
        }

        // Update the weighted probability: P_w = (P_kt + P_w0 * P_w1) / 2,
        // computed in the log domain with the larger term factored out.
        if is_leaf {
            node.log_prob_weighted = node.log_prob_kt;
        } else {
            let mut prob_w01_kt_ratio = (log_prob_w0 + log_prob_w1 - node.log_prob_kt).exp();
            if prob_w01_kt_ratio > 1.0 {
                prob_w01_kt_ratio = (node.log_prob_kt - log_prob_w0 - log_prob_w1).exp();
                node.log_prob_weighted = log_prob_w0 + log_prob_w1;
            } else {
                node.log_prob_weighted = node.log_prob_kt;
            }

            if prob_w01_kt_ratio.is_nan() {
                prob_w01_kt_ratio = 0.0;
            }
            node.log_prob_weighted += prob_w01_kt_ratio.ln_1p() - std::f64::consts::LN_2;
        }

        // Sanity check: log-probabilities must stay non-positive; reset any
        // positive drift caused by floating-point rounding.
        if node.log_prob_kt > 1.0e-10 {
            node.log_prob_kt = 0.0;
        }
        if node.log_prob_weighted > 1.0e-10 {
            node.log_prob_weighted = 0.0;
        }
    }
}

/// KT estimator log-multiplier: `ln((counts[sym] + 1/2) / (counts[0] + counts[1] + 1))`,
/// i.e. the log of the sequential update ratio applied when `sym` is observed.
#[inline(always)]
fn log_kt_mul(counts: [u32; 2], sym: Symbol) -> f64 {
    let sym_idx = sym as usize;
    let denominator = ((counts[0] + counts[1] + 1) as f64).ln();
    (counts[sym_idx] as f64 + 0.5).ln() - denominator
}

// =============================================================================
// Factorized Action-Conditional CTW (FAC-CTW)
// =============================================================================

/// Core tree structure without owned history (for use in shared-history FAC-CTW).
#[derive(Clone)]
struct ContextTreeCore {
    arena: CtArena,
    root: NodeIndex,
    max_depth: usize,
    context_buf: Vec<Symbol>,
    path_buf: Vec<NodeIndex>,
}

impl ContextTreeCore {
    fn new(depth: usize) -> Self {
        let mut arena = CtArena::with_capacity(1024.min(1 << depth.min(16)));
        let root = arena.alloc();
        Self {
            arena,
            root,
            max_depth: depth,
            context_buf: vec![false; depth],
            path_buf: Vec::with_capacity(depth + 1),
        }
    }

    fn clear(&mut self) {
        self.arena.clear();
        self.root = self.arena.alloc();
        self.context_buf.fill(false);
    }

    /// Prepares the context buffer from the last `max_depth` symbols of the shared history.
    #[inline(always)]
    fn prepare_context(&mut self, shared_history: &[Symbol]) {
        self.context_buf.fill(false);
        let history_len = shared_history.len();
        let copy_len = history_len.min(self.max_depth);
        if copy_len > 0 {
            self.context_buf[self.max_depth - copy_len..]
                .copy_from_slice(&shared_history[history_len - copy_len..]);
        }
    }

    /// Updates the tree with `sym`, using the shared history for context.
    #[inline]
    fn update(&mut self, sym: Symbol, shared_history: &[Symbol]) {
        self.prepare_context(shared_history);
        self.update_node_iterative(sym, false);
    }

    /// Reverts the last update, using the shared history for context.
    #[inline]
    fn revert(&mut self, last_sym: Symbol, shared_history: &[Symbol]) {
        self.prepare_context(shared_history);
        self.update_node_iterative(last_sym, true);
    }

    /// Predicts the probability of `sym` given the shared history (update, read, then revert).
    #[inline]
    fn predict(&mut self, sym: Symbol, shared_history: &[Symbol]) -> f64 {
        let log_prob_before = self.arena.get(self.root).log_prob_weighted;
        self.update(sym, shared_history);
        let log_prob_after = self.arena.get(self.root).log_prob_weighted;
        self.prepare_context(shared_history);
        self.update_node_iterative(sym, true);
        (log_prob_after - log_prob_before).exp()
    }

    #[inline]
    fn get_log_block_probability(&self) -> f64 {
        self.arena.get(self.root).log_prob_weighted
    }

    /// Mirrors `ContextTree::update_node_iterative`; see the comments there.
    #[inline]
    fn update_node_iterative(&mut self, sym: Symbol, revert: bool) {
        let max_depth = self.max_depth;

        let mut path = std::mem::take(&mut self.path_buf);
        path.clear();
        path.push(self.root);

        let mut current = self.root;
        for depth in 0..max_depth {
            let child_sym = self.context_buf[max_depth - 1 - depth];
            let child_idx = self.arena.get(current).children[child_sym as usize];

            if revert {
                if child_idx.is_none() {
                    break;
                }
                current = child_idx;
            } else {
                let child = if child_idx.is_none() {
                    let new_child = self.arena.alloc();
                    self.arena.get_mut(current).children[child_sym as usize] = new_child;
                    new_child
                } else {
                    child_idx
                };
                current = child;
            }
            path.push(current);
        }

        let leaf_depth = path.len() - 1;
        for (i, &node_idx) in path.iter().enumerate().rev() {
            let is_leaf = i == leaf_depth;
            self.update_single_node(node_idx, sym, revert, is_leaf);

            if revert && i > 0 {
                let parent_idx = path[i - 1];
                let depth = i - 1;
                let child_sym = self.context_buf[max_depth - 1 - depth];
                if self.arena.get(node_idx).visits() == 0 {
                    self.arena.get_mut(parent_idx).children[child_sym as usize] = NodeIndex::NONE;
                    self.arena.free(node_idx);
                }
            }
        }

        self.path_buf = path;
    }

    /// Mirrors `ContextTree::update_single_node`; see the comments there.
    #[inline(always)]
    fn update_single_node(&mut self, idx: NodeIndex, sym: Symbol, revert: bool, is_leaf: bool) {
        let (log_prob_w0, log_prob_w1) = if !is_leaf {
            let node = self.arena.get(idx);
            let child0 = node.children[0];
            let child1 = node.children[1];
            let w0 = if child0.is_some() {
                self.arena.get(child0).log_prob_weighted
            } else {
                0.0
            };
            let w1 = if child1.is_some() {
                self.arena.get(child1).log_prob_weighted
            } else {
                0.0
            };
            (w0, w1)
        } else {
            (0.0, 0.0)
        };

        let node = self.arena.get_mut(idx);

        let sym_idx = sym as usize;
        if !revert {
            node.log_prob_kt += log_kt_mul(node.symbol_count, sym);
            node.symbol_count[sym_idx] += 1;
        } else {
            let total = node.symbol_count[0] + node.symbol_count[1];
            if node.symbol_count[sym_idx] > 0 && total > 0 {
                let numerator = (node.symbol_count[sym_idx] as f64 - 0.5).ln();
                let denominator = (total as f64).ln();
                node.log_prob_kt -= numerator - denominator;
                node.symbol_count[sym_idx] -= 1;
            }
        }

        if is_leaf {
            node.log_prob_weighted = node.log_prob_kt;
        } else {
            let mut prob_w01_kt_ratio = (log_prob_w0 + log_prob_w1 - node.log_prob_kt).exp();
            if prob_w01_kt_ratio > 1.0 {
                prob_w01_kt_ratio = (node.log_prob_kt - log_prob_w0 - log_prob_w1).exp();
                node.log_prob_weighted = log_prob_w0 + log_prob_w1;
            } else {
                node.log_prob_weighted = node.log_prob_kt;
            }

            if prob_w01_kt_ratio.is_nan() {
                prob_w01_kt_ratio = 0.0;
            }
            node.log_prob_weighted += prob_w01_kt_ratio.ln_1p() - std::f64::consts::LN_2;
        }

        if node.log_prob_kt > 1.0e-10 {
            node.log_prob_kt = 0.0;
        }
        if node.log_prob_weighted > 1.0e-10 {
            node.log_prob_weighted = 0.0;
        }
    }
}

/// Factorized Action-Conditional Context Tree Weighting.
///
/// FAC-CTW uses `k` separate context trees, one for each bit of the percept space.
/// Tree `i` (0-indexed) has depth `base_depth + i`, so each percept bit conditions on
/// the same window of past history while also seeing the `i` earlier bits of the
/// current percept (which are already in the shared history when tree `i` is updated).
///
/// **Optimization**: all trees share a single history vector and read their contexts
/// directly from it, avoiding k-way duplication of history data.
///
/// Reference: Veness et al. (2011), Section 5, Equations 33-34.
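///
/// An illustrative sketch of one percept cycle (the 2 action bits and 4 percept
/// bits are arbitrary example sizes, not values taken from this crate); the
/// doctest is marked `ignore` because the crate path is not assumed here:
///
/// ```ignore
/// let mut model = FacContextTree::new(8, 4);
/// model.update_history(&[true, false]);      // action bits: history only, no learning
/// for (i, &bit) in [true, true, false, true].iter().enumerate() {
///     let p = model.predict(bit, i);         // P(bit i | history, earlier percept bits)
///     assert!(p > 0.0 && p < 1.0);
///     model.update(bit, i);                  // learn bit i and extend the shared history
/// }
/// assert!(model.get_log_block_probability() < 0.0);
/// ```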
#[derive(Clone)]
pub struct FacContextTree {
    /// Core trees without owned history.
    trees: Vec<ContextTreeCore>,
    /// Single shared history for all trees.
    shared_history: Vec<Symbol>,
    /// Base context depth D.
    base_depth: usize,
    /// Number of percept bits (k = l_X).
    num_bits: usize,
}

impl FacContextTree {
    /// Creates a new FAC-CTW with `base_depth` D and `num_percept_bits` k.
    ///
    /// Tree i will have depth D + i to ensure proper context overlap.
    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
        let trees = (0..num_percept_bits)
            .map(|i| ContextTreeCore::new(base_depth + i))
            .collect();
        Self {
            trees,
            shared_history: Vec::new(),
            base_depth,
            num_bits: num_percept_bits,
        }
    }

    /// Returns the number of percept bits (k).
    #[inline]
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    /// Returns the base context depth (D).
    #[inline]
    pub fn base_depth(&self) -> usize {
        self.base_depth
    }

    /// Updates tree `bit_index` with symbol `sym`, then appends `sym` to the shared history.
    ///
    /// Call this in sequence for bits 0..k of each percept.
    #[inline]
    pub fn update(&mut self, sym: Symbol, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);

        // Update the tree responsible for this bit
        self.trees[bit_index].update(sym, &self.shared_history);

        // Append to shared history
        self.shared_history.push(sym);
    }

    /// Predicts the probability of `sym` at `bit_index`.
    #[inline]
    pub fn predict(&mut self, sym: Symbol, bit_index: usize) -> f64 {
        debug_assert!(bit_index < self.num_bits);
        self.trees[bit_index].predict(sym, &self.shared_history)
    }

    /// Reverts the update at `bit_index`.
    #[inline]
    pub fn revert(&mut self, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);

        // Pop from shared history first to get the symbol
        let Some(last_sym) = self.shared_history.pop() else {
            return;
        };

        // Revert the tree responsible for this bit
        self.trees[bit_index].revert(last_sym, &self.shared_history);
    }

    /// Appends action symbols to the shared history without updating any tree (no KT update).
    #[inline]
    pub fn update_history(&mut self, symbols: &[Symbol]) {
        self.shared_history.extend_from_slice(symbols);
    }

    /// Removes the last `count` symbols from the shared history.
    #[inline]
    pub fn revert_history(&mut self, count: usize) {
        let new_len = self.shared_history.len().saturating_sub(count);
        self.shared_history.truncate(new_len);
    }

    /// Returns the combined log block probability (sum of all trees).
    #[inline]
    pub fn get_log_block_probability(&self) -> f64 {
        self.trees
            .iter()
            .map(|t| t.get_log_block_probability())
            .sum()
    }

    /// Clears all trees and shared history.
    pub fn clear(&mut self) {
        for tree in &mut self.trees {
            tree.clear();
        }
        self.shared_history.clear();
    }

    /// Returns approximate memory usage in bytes (including shared history).
    pub fn memory_usage(&self) -> usize {
        let tree_mem: usize = self.trees.iter().map(|t| t.arena.memory_usage()).sum();
        let history_mem = self.shared_history.capacity() * std::mem::size_of::<Symbol>();
        tree_mem + history_mem
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fac_ctw_history_consistency() {
        let mut fac = FacContextTree::new(4, 4);

        // Add action history
        fac.update_history(&[true, false, true]);
        assert_eq!(fac.shared_history.len(), 3);

        // Update percept bits
        fac.update(true, 0);
        fac.update(false, 1);
        assert_eq!(fac.shared_history.len(), 5);

        // Revert
        fac.revert(1);
        assert_eq!(fac.shared_history.len(), 4);

        fac.revert(0);
        assert_eq!(fac.shared_history.len(), 3);
    }
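
    // The tests below are additional sketches of expected CTW behaviour: a fresh
    // tree is uniform over the first bit, update/revert is an exact inverse, and
    // FAC-CTW's per-bit predictions form a distribution. The tolerances are
    // assumptions (loose floating-point bounds), not derived guarantees.

    #[test]
    fn ctw_first_symbol_is_uniform() {
        let mut tree = ContextTree::new(4);
        // With no history every KT estimator sees counts [0, 0], so the weighted
        // probability of the first bit should be 1/2.
        let p = tree.predict(true);
        assert!((p - 0.5).abs() < 1e-12);
    }

    #[test]
    fn ctw_update_then_revert_restores_probability() {
        let mut tree = ContextTree::new(4);
        for &bit in &[true, true, false, true, false] {
            tree.update(bit);
        }
        let before = tree.get_log_block_probability();
        tree.update(false);
        tree.revert();
        let after = tree.get_log_block_probability();
        assert!((before - after).abs() < 1e-12);
        assert_eq!(tree.history_size(), 5);
    }

    #[test]
    fn fac_ctw_predictions_sum_to_one() {
        let mut fac = FacContextTree::new(4, 2);
        fac.update_history(&[true, false, true]);
        fac.update(true, 0);
        // CTW defines a probability measure, so the predicted probabilities of the
        // two possible values of the next bit should sum to one (up to rounding).
        let p1 = fac.predict(true, 1);
        let p0 = fac.predict(false, 1);
        assert!((p0 + p1 - 1.0).abs() < 1e-9);
    }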
}