//! Context Tree Weighting (CTW) over binary symbols: each node mixes a
//! Krichevsky–Trofimov (KT) estimator with the product of its children's
//! estimates. All probabilities are stored in log space; nodes live in a
//! flat arena indexed by `NodeIndex`.

/// A single binary symbol: `false` = 0, `true` = 1.
type Symbol = bool;

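/// Compact handle to a node in [`CtArena`]. `u32::MAX` is reserved as the
/// sentinel for "no node", so no `Option` wrapper is needed.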
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NodeIndex(u32);

impl NodeIndex {
    pub const NONE: NodeIndex = NodeIndex(u32::MAX);

    #[inline(always)]
    pub fn is_none(self) -> bool {
        self.0 == u32::MAX
    }

    #[inline(always)]
    pub fn is_some(self) -> bool {
        self.0 != u32::MAX
    }

    #[inline(always)]
    pub fn get(self) -> usize {
        self.0 as usize
    }
}

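/// A single context-tree node: two child links, the node's KT log-probability,
/// the CTW-weighted log-probability, and per-symbol counts.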
#[derive(Clone, Debug)]
pub struct CtNode {
    /// Child links for symbols 0 and 1; `NodeIndex::NONE` if absent.
    pub children: [NodeIndex; 2],
    /// Log of the KT estimate for the subsequence seen in this context.
    pub log_prob_kt: f64,
    /// Log of the CTW-weighted probability of this node's subtree.
    pub log_prob_weighted: f64,
    /// Number of 0s and 1s observed in this context.
    pub symbol_count: [u32; 2],
}

impl CtNode {
    #[inline(always)]
    pub fn new() -> Self {
        Self {
            children: [NodeIndex::NONE, NodeIndex::NONE],
            log_prob_kt: 0.0,
            log_prob_weighted: 0.0,
            symbol_count: [0, 0],
        }
    }

    /// Total number of symbols observed at this node.
    #[inline(always)]
    pub fn visits(&self) -> u32 {
        self.symbol_count[0] + self.symbol_count[1]
    }
}

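/// Arena allocator for [`CtNode`]s. Freed slots are recycled through a free
/// list, so long-running trees do not leak capacity.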
#[derive(Clone, Debug)]
pub struct CtArena {
    nodes: Vec<CtNode>,
    free_list: Vec<NodeIndex>,
}

impl CtArena {
    pub fn new() -> Self {
        Self {
            nodes: Vec::with_capacity(1024),
            free_list: Vec::new(),
        }
    }

    pub fn with_capacity(cap: usize) -> Self {
        Self {
            nodes: Vec::with_capacity(cap),
            free_list: Vec::new(),
        }
    }

    /// Allocates a fresh node, reusing a freed slot when one is available.
    #[inline(always)]
    pub fn alloc(&mut self) -> NodeIndex {
        if let Some(idx) = self.free_list.pop() {
            self.nodes[idx.get()] = CtNode::new();
            idx
        } else {
            let idx = NodeIndex(self.nodes.len() as u32);
            self.nodes.push(CtNode::new());
            idx
        }
    }

    /// Returns a node's slot to the free list; its contents are reset on the
    /// next `alloc`.
    #[inline(always)]
    pub fn free(&mut self, idx: NodeIndex) {
        if idx.is_some() {
            self.free_list.push(idx);
        }
    }

    #[inline(always)]
    pub fn get(&self, idx: NodeIndex) -> &CtNode {
        &self.nodes[idx.get()]
    }

    #[inline(always)]
    pub fn get_mut(&mut self, idx: NodeIndex) -> &mut CtNode {
        &mut self.nodes[idx.get()]
    }

    pub fn clear(&mut self) {
        self.nodes.clear();
        self.free_list.clear();
    }

    /// Approximate heap usage in bytes (based on capacity, not length).
    pub fn memory_usage(&self) -> usize {
        self.nodes.capacity() * std::mem::size_of::<CtNode>()
            + self.free_list.capacity() * std::mem::size_of::<NodeIndex>()
    }
}

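/// A single CTW context tree over binary symbols. The tree owns its history;
/// `update` and `revert` keep the model and the history in sync.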
#[derive(Clone)]
pub struct ContextTree {
    arena: CtArena,
    root: NodeIndex,
    history: Vec<Symbol>,
    max_depth: usize,
    /// Scratch buffer holding the current context (last `max_depth` symbols).
    context_buf: Vec<Symbol>,
    /// Scratch buffer for the root-to-leaf path during updates.
    path_buf: Vec<NodeIndex>,
}

impl ContextTree {
    /// Creates a tree with maximum context depth `depth`.
    pub fn new(depth: usize) -> Self {
        // Pre-allocate at most 1024 nodes (fewer for shallow trees).
        let mut arena = CtArena::with_capacity(1024.min(1 << depth.min(16)));
        let root = arena.alloc();
        Self {
            arena,
            root,
            history: Vec::new(),
            max_depth: depth,
            context_buf: vec![false; depth],
            path_buf: Vec::with_capacity(depth + 1),
        }
    }

    /// Resets both the model and the history.
    pub fn clear(&mut self) {
        self.history.clear();
        self.arena.clear();
        self.root = self.arena.alloc();
        self.context_buf.fill(false);
    }

    /// Observes one symbol: updates the tree in the current context, then
    /// appends the symbol to the history.
    #[inline]
    pub fn update(&mut self, sym: Symbol) {
        self.prepare_context();
        self.update_from_root(sym, false);
        self.history.push(sym);
    }

    /// Undoes the most recent `update`.
    #[inline]
    pub fn revert(&mut self) {
        let Some(last_sym) = self.history.pop() else {
            return;
        };
        self.prepare_context();
        self.update_from_root(last_sym, true);
    }

    /// Appends symbols to the history without updating the model.
    #[inline]
    pub fn update_history(&mut self, symbols: &[Symbol]) {
        self.history.extend_from_slice(symbols);
    }

    /// Removes the most recent history symbol without updating the model.
    #[inline]
    pub fn revert_history(&mut self) {
        self.history.pop();
    }

    /// Shrinks the history to `new_size` symbols, if it is longer.
    pub fn truncate_history(&mut self, new_size: usize) {
        if new_size < self.history.len() {
            self.history.truncate(new_size);
        }
    }

    /// Probability of `sym` in the current context, computed by a temporary
    /// update that is immediately reverted.
    #[inline]
    pub fn predict(&mut self, sym: Symbol) -> f64 {
        let log_prob_before = self.arena.get(self.root).log_prob_weighted;
        self.update(sym);
        let log_prob_after = self.arena.get(self.root).log_prob_weighted;
        self.revert();
        (log_prob_after - log_prob_before).exp()
    }

    /// Probability that the next symbol is `true`.
    #[inline]
    pub fn predict_sym_prob(&mut self) -> f64 {
        self.predict(true)
    }

    /// Log-probability of the entire observed block under the weighted model.
    #[inline]
    pub fn get_log_block_probability(&self) -> f64 {
        self.arena.get(self.root).log_prob_weighted
    }

    #[inline]
    pub fn depth(&self) -> usize {
        self.max_depth
    }

    #[inline]
    pub fn history_size(&self) -> usize {
        self.history.len()
    }

    /// Fills `context_buf` with the last `max_depth` history symbols,
    /// zero-padding on the left when the history is shorter than the depth.
    #[inline(always)]
    fn prepare_context(&mut self) {
        self.context_buf.fill(false);
        let history_len = self.history.len();
        let copy_len = history_len.min(self.max_depth);
        if copy_len > 0 {
            self.context_buf[self.max_depth - copy_len..]
                .copy_from_slice(&self.history[history_len - copy_len..]);
        }
    }

    #[inline]
    fn update_from_root(&mut self, sym: Symbol, revert: bool) {
        self.update_node_iterative(self.root, sym, revert);
    }

    /// Walks the context path from the root, then updates every node on the
    /// path bottom-up. On revert, nodes whose counts drop to zero are pruned
    /// and returned to the arena.
    #[inline]
    fn update_node_iterative(&mut self, root_idx: NodeIndex, sym: Symbol, revert: bool) {
        let max_depth = self.max_depth;

        // Temporarily take the scratch path buffer to avoid borrow conflicts
        // with `self.arena`.
        let mut path = std::mem::take(&mut self.path_buf);
        path.clear();
        path.push(root_idx);

        // Descend along the context, allocating missing children on update;
        // on revert the path must already exist, so stop at a missing child.
        let mut current = root_idx;
        for depth in 0..max_depth {
            let child_sym = self.context_buf[max_depth - 1 - depth];
            let child_idx = self.arena.get(current).children[child_sym as usize];

            if revert {
                if child_idx.is_none() {
                    break;
                }
                current = child_idx;
            } else {
                let child = if child_idx.is_none() {
                    let new_child = self.arena.alloc();
                    self.arena.get_mut(current).children[child_sym as usize] = new_child;
                    new_child
                } else {
                    child_idx
                };
                current = child;
            }
            path.push(current);
        }

        // Update leaf-to-root so each parent mixes its children's fresh values.
        let leaf_depth = path.len() - 1;
        for (i, &node_idx) in path.iter().enumerate().rev() {
            let is_leaf = i == leaf_depth;
            self.update_single_node(node_idx, sym, revert, is_leaf);

            // Prune nodes that no longer carry any observations.
            if revert && i > 0 {
                let parent_idx = path[i - 1];
                let depth = i - 1;
                let child_sym = self.context_buf[max_depth - 1 - depth];
                if self.arena.get(node_idx).visits() == 0 {
                    self.arena.get_mut(parent_idx).children[child_sym as usize] = NodeIndex::NONE;
                    self.arena.free(node_idx);
                }
            }
        }

        self.path_buf = path;
    }

    #[inline(always)]
    fn update_single_node(&mut self, idx: NodeIndex, sym: Symbol, revert: bool, is_leaf: bool) {
        // Read the children's weighted log-probabilities first, before taking
        // a mutable borrow of this node.
        let (log_prob_w0, log_prob_w1) = if !is_leaf {
            let node = self.arena.get(idx);
            let child0 = node.children[0];
            let child1 = node.children[1];
            let w0 = if child0.is_some() {
                self.arena.get(child0).log_prob_weighted
            } else {
                0.0
            };
            let w1 = if child1.is_some() {
                self.arena.get(child1).log_prob_weighted
            } else {
                0.0
            };
            (w0, w1)
        } else {
            (0.0, 0.0)
        };

        let node = self.arena.get_mut(idx);

        // KT update: add (or, on revert, subtract) the log-probability of
        // this symbol under the KT estimator, then adjust the counts.
        let sym_idx = sym as usize;
        if !revert {
            node.log_prob_kt += log_kt_mul(node.symbol_count, sym);
            node.symbol_count[sym_idx] += 1;
        } else {
            let total = node.symbol_count[0] + node.symbol_count[1];
            if node.symbol_count[sym_idx] > 0 && total > 0 {
                let numerator = (node.symbol_count[sym_idx] as f64 - 0.5).ln();
                let denominator = (total as f64).ln();
                node.log_prob_kt -= numerator - denominator;
                node.symbol_count[sym_idx] -= 1;
            }
        }

        // CTW mixing: P_w = (P_kt + P_w0 * P_w1) / 2, computed in log space
        // as a log-sum-exp that keeps the larger term as the base.
        if is_leaf {
            node.log_prob_weighted = node.log_prob_kt;
        } else {
            let mut prob_w01_kt_ratio = (log_prob_w0 + log_prob_w1 - node.log_prob_kt).exp();
            if prob_w01_kt_ratio > 1.0 {
                prob_w01_kt_ratio = (node.log_prob_kt - log_prob_w0 - log_prob_w1).exp();
                node.log_prob_weighted = log_prob_w0 + log_prob_w1;
            } else {
                node.log_prob_weighted = node.log_prob_kt;
            }

            if prob_w01_kt_ratio.is_nan() {
                prob_w01_kt_ratio = 0.0;
            }
            node.log_prob_weighted += prob_w01_kt_ratio.ln_1p() - std::f64::consts::LN_2;
        }

        // Log-probabilities can drift slightly above zero after a revert;
        // clamp them back to exactly zero.
        if node.log_prob_kt > 1.0e-10 {
            node.log_prob_kt = 0.0;
        }
        if node.log_prob_weighted > 1.0e-10 {
            node.log_prob_weighted = 0.0;
        }
    }
}

/// Log of the KT-estimator probability of seeing `sym` given `counts`:
/// `ln((counts[sym] + 1/2) / (counts[0] + counts[1] + 1))`.
#[inline(always)]
fn log_kt_mul(counts: [u32; 2], sym: Symbol) -> f64 {
    let sym_idx = sym as usize;
    let denominator = ((counts[0] + counts[1] + 1) as f64).ln();
    (counts[sym_idx] as f64 + 0.5).ln() - denominator
}
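
/// Internal CTW tree used by [`FacContextTree`]. It mirrors [`ContextTree`]
/// but does not own a history; callers pass the shared history explicitly.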
#[derive(Clone)]
struct ContextTreeCore {
    arena: CtArena,
    root: NodeIndex,
    max_depth: usize,
    context_buf: Vec<Symbol>,
    path_buf: Vec<NodeIndex>,
}

impl ContextTreeCore {
    fn new(depth: usize) -> Self {
        // Pre-allocate at most 1024 nodes (fewer for shallow trees).
        let mut arena = CtArena::with_capacity(1024.min(1 << depth.min(16)));
        let root = arena.alloc();
        Self {
            arena,
            root,
            max_depth: depth,
            context_buf: vec![false; depth],
            path_buf: Vec::with_capacity(depth + 1),
        }
    }

    fn clear(&mut self) {
        self.arena.clear();
        self.root = self.arena.alloc();
        self.context_buf.fill(false);
    }

    /// Fills `context_buf` from the caller-supplied shared history,
    /// zero-padding on the left when the history is shorter than the depth.
    #[inline(always)]
    fn prepare_context(&mut self, shared_history: &[Symbol]) {
        self.context_buf.fill(false);
        let history_len = shared_history.len();
        let copy_len = history_len.min(self.max_depth);
        if copy_len > 0 {
            self.context_buf[self.max_depth - copy_len..]
                .copy_from_slice(&shared_history[history_len - copy_len..]);
        }
    }

    /// Observes one symbol in the context given by `shared_history`.
    #[inline]
    fn update(&mut self, sym: Symbol, shared_history: &[Symbol]) {
        self.prepare_context(shared_history);
        self.update_node_iterative(sym, false);
    }

    /// Undoes an update of `last_sym`; `shared_history` must already have the
    /// symbol popped, so it matches the context the update was made in.
    #[inline]
    fn revert(&mut self, last_sym: Symbol, shared_history: &[Symbol]) {
        self.prepare_context(shared_history);
        self.update_node_iterative(last_sym, true);
    }

    /// Probability of `sym` in the current context, computed by a temporary
    /// update that is immediately reverted.
    #[inline]
    fn predict(&mut self, sym: Symbol, shared_history: &[Symbol]) -> f64 {
        let log_prob_before = self.arena.get(self.root).log_prob_weighted;
        self.update(sym, shared_history);
        let log_prob_after = self.arena.get(self.root).log_prob_weighted;
        self.revert(sym, shared_history);
        (log_prob_after - log_prob_before).exp()
    }

    #[inline]
    fn get_log_block_probability(&self) -> f64 {
        self.arena.get(self.root).log_prob_weighted
    }

    /// Same traversal as [`ContextTree::update_node_iterative`], operating on
    /// this core's own root; see the comments there.
    #[inline]
    fn update_node_iterative(&mut self, sym: Symbol, revert: bool) {
        let max_depth = self.max_depth;

        let mut path = std::mem::take(&mut self.path_buf);
        path.clear();
        path.push(self.root);

        let mut current = self.root;
        for depth in 0..max_depth {
            let child_sym = self.context_buf[max_depth - 1 - depth];
            let child_idx = self.arena.get(current).children[child_sym as usize];

            if revert {
                if child_idx.is_none() {
                    break;
                }
                current = child_idx;
            } else {
                let child = if child_idx.is_none() {
                    let new_child = self.arena.alloc();
                    self.arena.get_mut(current).children[child_sym as usize] = new_child;
                    new_child
                } else {
                    child_idx
                };
                current = child;
            }
            path.push(current);
        }

        let leaf_depth = path.len() - 1;
        for (i, &node_idx) in path.iter().enumerate().rev() {
            let is_leaf = i == leaf_depth;
            self.update_single_node(node_idx, sym, revert, is_leaf);

            if revert && i > 0 {
                let parent_idx = path[i - 1];
                let depth = i - 1;
                let child_sym = self.context_buf[max_depth - 1 - depth];
                if self.arena.get(node_idx).visits() == 0 {
                    self.arena.get_mut(parent_idx).children[child_sym as usize] = NodeIndex::NONE;
                    self.arena.free(node_idx);
                }
            }
        }

        self.path_buf = path;
    }

    /// Identical to [`ContextTree::update_single_node`]; see the comments there.
    #[inline(always)]
    fn update_single_node(&mut self, idx: NodeIndex, sym: Symbol, revert: bool, is_leaf: bool) {
        let (log_prob_w0, log_prob_w1) = if !is_leaf {
            let node = self.arena.get(idx);
            let child0 = node.children[0];
            let child1 = node.children[1];
            let w0 = if child0.is_some() {
                self.arena.get(child0).log_prob_weighted
            } else {
                0.0
            };
            let w1 = if child1.is_some() {
                self.arena.get(child1).log_prob_weighted
            } else {
                0.0
            };
            (w0, w1)
        } else {
            (0.0, 0.0)
        };

        let node = self.arena.get_mut(idx);

        let sym_idx = sym as usize;
        if !revert {
            node.log_prob_kt += log_kt_mul(node.symbol_count, sym);
            node.symbol_count[sym_idx] += 1;
        } else {
            let total = node.symbol_count[0] + node.symbol_count[1];
            if node.symbol_count[sym_idx] > 0 && total > 0 {
                let numerator = (node.symbol_count[sym_idx] as f64 - 0.5).ln();
                let denominator = (total as f64).ln();
                node.log_prob_kt -= numerator - denominator;
                node.symbol_count[sym_idx] -= 1;
            }
        }

        if is_leaf {
            node.log_prob_weighted = node.log_prob_kt;
        } else {
            let mut prob_w01_kt_ratio = (log_prob_w0 + log_prob_w1 - node.log_prob_kt).exp();
            if prob_w01_kt_ratio > 1.0 {
                prob_w01_kt_ratio = (node.log_prob_kt - log_prob_w0 - log_prob_w1).exp();
                node.log_prob_weighted = log_prob_w0 + log_prob_w1;
            } else {
                node.log_prob_weighted = node.log_prob_kt;
            }

            if prob_w01_kt_ratio.is_nan() {
                prob_w01_kt_ratio = 0.0;
            }
            node.log_prob_weighted += prob_w01_kt_ratio.ln_1p() - std::f64::consts::LN_2;
        }

        if node.log_prob_kt > 1.0e-10 {
            node.log_prob_kt = 0.0;
        }
        if node.log_prob_weighted > 1.0e-10 {
            node.log_prob_weighted = 0.0;
        }
    }
}

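/// Factored context-tree model: one [`ContextTreeCore`] per percept bit, all
/// conditioned on a single shared symbol history. The tree for bit `i` uses
/// depth `base_depth + i`, so its context window also covers the bits emitted
/// earlier within the same percept.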
#[derive(Clone)]
pub struct FacContextTree {
    /// One tree per percept bit; tree `i` has depth `base_depth + i`.
    trees: Vec<ContextTreeCore>,
    /// Symbol history shared by all trees.
    shared_history: Vec<Symbol>,
    /// Context depth of the first bit's tree.
    base_depth: usize,
    /// Number of percept bits (and trees).
    num_bits: usize,
}

impl FacContextTree {
    /// Creates `num_percept_bits` trees; the tree for bit `i` has context
    /// depth `base_depth + i`.
    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
        let trees = (0..num_percept_bits)
            .map(|i| ContextTreeCore::new(base_depth + i))
            .collect();
        Self {
            trees,
            shared_history: Vec::new(),
            base_depth,
            num_bits: num_percept_bits,
        }
    }

    #[inline]
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    #[inline]
    pub fn base_depth(&self) -> usize {
        self.base_depth
    }

    /// Observes one symbol with the given bit's tree, then appends it to the
    /// shared history so subsequent bits see it as context.
    #[inline]
    pub fn update(&mut self, sym: Symbol, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);

        self.trees[bit_index].update(sym, &self.shared_history);

        self.shared_history.push(sym);
    }

    /// Probability of `sym` for the given bit in the current shared context.
    #[inline]
    pub fn predict(&mut self, sym: Symbol, bit_index: usize) -> f64 {
        debug_assert!(bit_index < self.num_bits);
        self.trees[bit_index].predict(sym, &self.shared_history)
    }

    /// Undoes the most recent `update`, which must have been for `bit_index`.
    /// The symbol is popped first so the tree reverts in the same context it
    /// was updated in.
    #[inline]
    pub fn revert(&mut self, bit_index: usize) {
        debug_assert!(bit_index < self.num_bits);

        let Some(last_sym) = self.shared_history.pop() else {
            return;
        };

        self.trees[bit_index].revert(last_sym, &self.shared_history);
    }

    /// Appends symbols to the shared history without updating any tree.
    #[inline]
    pub fn update_history(&mut self, symbols: &[Symbol]) {
        self.shared_history.extend_from_slice(symbols);
    }

    /// Drops the last `count` symbols from the shared history without
    /// updating any tree.
    #[inline]
    pub fn revert_history(&mut self, count: usize) {
        let new_len = self.shared_history.len().saturating_sub(count);
        self.shared_history.truncate(new_len);
    }

    /// Joint log-probability of the observed block: the sum of the per-bit
    /// trees' log block probabilities.
    #[inline]
    pub fn get_log_block_probability(&self) -> f64 {
        self.trees
            .iter()
            .map(|t| t.get_log_block_probability())
            .sum()
    }

    pub fn clear(&mut self) {
        for tree in &mut self.trees {
            tree.clear();
        }
        self.shared_history.clear();
    }

    /// Approximate heap usage of all trees plus the shared history, in bytes.
    pub fn memory_usage(&self) -> usize {
        let tree_mem: usize = self.trees.iter().map(|t| t.arena.memory_usage()).sum();
        let history_mem = self.shared_history.capacity() * std::mem::size_of::<Symbol>();
        tree_mem + history_mem
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fac_ctw_history_consistency() {
        let mut fac = FacContextTree::new(4, 4);

        fac.update_history(&[true, false, true]);
        assert_eq!(fac.shared_history.len(), 3);

        fac.update(true, 0);
        fac.update(false, 1);
        assert_eq!(fac.shared_history.len(), 5);

        fac.revert(1);
        assert_eq!(fac.shared_history.len(), 4);

        fac.revert(0);
        assert_eq!(fac.shared_history.len(), 3);
    }
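
    // Added sanity checks, not part of the original suite. They exercise only
    // properties the code above is designed to provide: `predict` is an
    // update followed by a revert, so it must leave the model unchanged, and
    // the conditional probabilities of the two symbols must sum to one.
    #[test]
    fn ctw_predict_is_normalized_and_side_effect_free() {
        let mut tree = ContextTree::new(3);
        for &s in &[true, true, false, true, false] {
            tree.update(s);
        }

        let before = tree.get_log_block_probability();
        assert!(before < 0.0);

        let p1 = tree.predict(true);
        let p0 = tree.predict(false);
        let after = tree.get_log_block_probability();

        assert!(p1 > 0.0 && p1 < 1.0);
        assert!((p0 + p1 - 1.0).abs() < 1e-9);
        assert!((before - after).abs() < 1e-9);
    }

    #[test]
    fn fac_ctw_predict_is_normalized() {
        let mut fac = FacContextTree::new(2, 2);
        // Feed two 2-bit percepts, one bit at a time.
        fac.update(true, 0);
        fac.update(false, 1);
        fac.update(true, 0);
        fac.update(true, 1);

        let p1 = fac.predict(true, 0);
        let p0 = fac.predict(false, 0);
        assert!((p0 + p1 - 1.0).abs() < 1e-9);
    }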
}