1use crate::ParticleSpec;
8use crate::simd_math::{axpy_wide, dot_wide, logsumexp_wide, max_wide};
9use std::collections::VecDeque;
10
/// Deterministic 64-bit mixer: folds `a`, `b`, `c` into `seed` with three
/// multiply-add / xor-shift rounds. This is the single source of all
/// pseudo-randomness in the file, so results are fully reproducible.
#[inline]
fn det_hash(seed: u64, a: u64, b: u64, c: u64) -> u64 {
    // Each round: multiply by a large odd constant, add the next input,
    // then fold the high bits down with an xor-shift.
    const MULTS: [u64; 3] = [0x517cc1b727220a95, 0x4cf5ad432745937f, 0x6c62272e07bb0142];
    let mut state = seed;
    for (&mult, input) in MULTS.iter().zip([a, b, c]) {
        state = state.wrapping_mul(mult).wrapping_add(input);
        state ^= state >> 33;
    }
    state
}
28
/// Maps a hash word to a float in [-1, 1), using the top 53 bits
/// (the width of an f64 mantissa) so the spacing is uniform.
#[inline]
fn hash_to_f64(h: u64) -> f64 {
    const DENOM: f64 = (1u64 << 53) as f64;
    let unit = (h >> 11) as f64 / DENOM; // uniform in [0, 1)
    2.0 * unit - 1.0
}
36
37#[inline]
39fn init_param(seed: u64, layer: u64, row: u64, col: u64, scale: f64) -> f64 {
40 hash_to_f64(det_hash(seed, layer, row, col)) * scale
41}
42
/// Symmetric saturation: constrains `x` to the interval [-limit, limit].
/// (`clamp` keeps NaN inputs as NaN, which downstream clipping relies on.)
#[inline]
fn clip(x: f64, limit: f64) -> f64 {
    let lo = -limit;
    x.clamp(lo, limit)
}
51
52fn softmax_inplace(xs: &mut [f64]) {
53 let max_v = max_wide(xs);
54 let mut sum = 0.0;
55 for x in xs.iter_mut() {
56 *x = (*x - max_v).exp();
57 sum += *x;
58 }
59 if sum > 0.0 {
60 let inv = 1.0 / sum;
61 for x in xs.iter_mut() {
62 *x *= inv;
63 }
64 }
65}
66
/// Writes the log-softmax of `logits` into `out`, flooring every
/// log-probability at `ln(min_prob)`, then renormalizing in log space so
/// the floored distribution again sums to one in probability.
fn log_softmax_with_floor(logits: &[f64], out: &mut [f64], min_prob: f64) {
    // Stable log-partition: log_z = max + ln(sum(exp(l - max))).
    let max_v = max_wide(logits);
    let mut sum = 0.0;
    for &l in logits {
        sum += (l - max_v).exp();
    }
    let log_z = max_v + sum.ln();
    let log_floor = min_prob.ln();
    // Streaming log-sum-exp of the floored log-probs, accumulated while
    // writing them out. Uses logaddexp(a, b) = max + ln(1 + exp(min - max)).
    let mut log_sum_exp_floor = f64::NEG_INFINITY;
    for (i, &l) in logits.iter().enumerate() {
        // Floor the per-symbol log-probability at ln(min_prob).
        let lp = (l - log_z).max(log_floor);
        out[i] = lp;
        if lp > log_sum_exp_floor {
            let diff = log_sum_exp_floor - lp;
            // Non-finite diff means the accumulator is still -inf
            // (first contributing term): just adopt lp.
            if diff.is_finite() {
                log_sum_exp_floor = lp + (1.0 + diff.exp()).ln();
            } else {
                log_sum_exp_floor = lp;
            }
        } else {
            let diff = lp - log_sum_exp_floor;
            // Non-finite diff (e.g. both sides -inf) contributes nothing.
            if diff.is_finite() {
                log_sum_exp_floor += (1.0 + diff.exp()).ln();
            }
        }
    }
    // Flooring can only add probability mass; subtract its log-sum to
    // renormalize. Skipped when nothing finite was accumulated.
    if log_sum_exp_floor.is_finite() {
        for v in out.iter_mut() {
            *v -= log_sum_exp_floor;
        }
    }
}
102
/// A fully-connected layer with row-major weights plus SGD momentum
/// (velocity) buffers mirroring every parameter.
#[derive(Clone)]
struct DenseLayer {
    /// Row-major weight matrix: `out_dim` rows of `in_dim` columns.
    weights: Vec<f64>,
    /// One bias per output row.
    bias: Vec<f64>,
    /// Momentum buffer mirroring `weights` (only touched when momentum != 0).
    vel_weights: Vec<f64>,
    /// Momentum buffer mirroring `bias`.
    vel_bias: Vec<f64>,
    in_dim: usize,
    out_dim: usize,
}
117
118impl DenseLayer {
119 fn new(in_dim: usize, out_dim: usize) -> Self {
120 Self {
121 weights: vec![0.0; out_dim * in_dim],
122 bias: vec![0.0; out_dim],
123 vel_weights: vec![0.0; out_dim * in_dim],
124 vel_bias: vec![0.0; out_dim],
125 in_dim,
126 out_dim,
127 }
128 }
129
130 fn init(&mut self, seed: u64, layer_id: u64, scale: f64) {
131 for r in 0..self.out_dim {
132 for c in 0..self.in_dim {
133 self.weights[r * self.in_dim + c] =
134 init_param(seed, layer_id, r as u64, c as u64, scale);
135 }
136 self.bias[r] = 0.0;
137 }
138 }
139
140 fn forward(&self, x: &[f64], out: &mut [f64]) {
142 debug_assert!(x.len() >= self.in_dim);
143 debug_assert!(out.len() >= self.out_dim);
144 for (r, slot) in out.iter_mut().enumerate().take(self.out_dim) {
145 let row = &self.weights[r * self.in_dim..(r + 1) * self.in_dim];
146 *slot = dot_wide(row, &x[..self.in_dim]) + self.bias[r];
147 }
148 }
149
150 fn forward_relu(&self, x: &[f64], out: &mut [f64]) {
152 self.forward(x, out);
153 for v in out[..self.out_dim].iter_mut() {
154 *v = v.max(0.0);
155 }
156 }
157
158 fn sgd_update(&mut self, grad_out: &[f64], x: &[f64], lr: f64, grad_clip: f64, momentum: f64) {
161 if momentum == 0.0 {
162 for (r, &grad) in grad_out.iter().enumerate().take(self.out_dim) {
163 let g = clip(grad, grad_clip);
164 let row = &mut self.weights[r * self.in_dim..(r + 1) * self.in_dim];
165 axpy_wide(row, -lr * g, &x[..self.in_dim]);
166 self.bias[r] -= lr * g;
167 }
168 return;
169 }
170
171 for (r, &grad) in grad_out.iter().enumerate().take(self.out_dim) {
172 let g = clip(grad, grad_clip);
173 for (c, &x_c) in x.iter().enumerate().take(self.in_dim) {
174 let idx = r * self.in_dim + c;
175 let grad_w = g * x_c;
176 self.vel_weights[idx] = momentum * self.vel_weights[idx] + grad_w;
177 self.weights[idx] -= lr * self.vel_weights[idx];
178 }
179 self.vel_bias[r] = momentum * self.vel_bias[r] + g;
180 self.bias[r] -= lr * self.vel_bias[r];
181 }
182 }
183}
184
/// Per-cell rule selector: a small MLP whose `gate` layer emits one logit
/// per rule; the logits are softmaxed into mixing weights in `forward`.
#[derive(Clone)]
struct CellSelector {
    // `hidden`: ReLU layer over the selector input; `gate`: linear layer
    // producing `num_rules` gate logits.
    hidden: DenseLayer, gate: DenseLayer, }
195
/// One candidate update rule: a two-layer MLP (ReLU hidden, linear output)
/// mapping the rule input (selector features + noise) to a cell-state delta.
#[derive(Clone)]
struct CellRule {
    hidden: DenseLayer, output: DenseLayer, }
202
/// All learnable parameters of a single cell: its selector plus the bank
/// of candidate rules the selector mixes over.
#[derive(Clone)]
struct CellParams {
    selector: CellSelector,
    rules: Vec<CellRule>,
}
209
/// The complete parameter set of one particle's model.
#[derive(Clone)]
struct ParticleModel {
    /// Byte embedding table: 256 rows of `cell_dim` values, flattened.
    embed: Vec<f64>,
    /// One parameter bundle per cell.
    cells: Vec<CellParams>,
    /// Linear readout from the pooled feature vector `phi` to 256 logits.
    readout: DenseLayer,
    /// Width of one cell-state vector.
    cell_dim: usize,
    num_cells: usize,
    /// Number of deterministic-noise inputs appended to each rule input.
    noise_dim: usize,
    // `phi_dim`: pooled feature width (5 * cell_dim);
    // `selector_in_dim`: selector input width (also 5 * cell_dim).
    phi_dim: usize, selector_in_dim: usize, }
231
232impl ParticleModel {
233 fn new(spec: &ParticleSpec) -> Self {
234 let cell_dim = spec.cell_dim;
235 let selector_in_dim = 5 * cell_dim;
236 let rule_in_dim = 5 * cell_dim + spec.noise_dim;
237 let phi_dim = 5 * cell_dim;
242
243 let embed = vec![0.0; 256 * cell_dim];
244 let cells = (0..spec.num_cells)
245 .map(|_| CellParams {
246 selector: CellSelector {
247 hidden: DenseLayer::new(selector_in_dim, spec.selector_hidden),
248 gate: DenseLayer::new(spec.selector_hidden, spec.num_rules),
249 },
250 rules: (0..spec.num_rules)
251 .map(|_| CellRule {
252 hidden: DenseLayer::new(rule_in_dim, spec.rule_hidden),
253 output: DenseLayer::new(spec.rule_hidden, cell_dim),
254 })
255 .collect(),
256 })
257 .collect();
258 let readout = DenseLayer::new(phi_dim, 256);
259
260 Self {
261 embed,
262 cells,
263 readout,
264 cell_dim,
265 num_cells: spec.num_cells,
266 noise_dim: spec.noise_dim,
267 phi_dim,
268 selector_in_dim,
269 }
270 }
271
272 fn init(&mut self, seed: u64, spec: &ParticleSpec) {
273 let scale = 0.1;
274 let embed_scale = 0.3;
280 for i in 0..256 {
282 for j in 0..self.cell_dim {
283 self.embed[i * self.cell_dim + j] =
284 init_param(seed, 0, i as u64, j as u64, embed_scale);
285 }
286 }
287 for (ci, cp) in self.cells.iter_mut().enumerate() {
289 let cell_seed = ci as u64 + 1;
290 cp.selector.hidden.init(seed, cell_seed * 100 + 1, scale);
291 cp.selector
292 .gate
293 .init(seed, cell_seed * 100 + 2, scale * 0.1);
294 for (ri, rule) in cp.rules.iter_mut().enumerate() {
295 let r_off = cell_seed * 100 + 10 + ri as u64;
296 rule.hidden.init(seed, r_off * 10 + 1, scale);
297 rule.output.init(seed, r_off * 10 + 2, scale * 0.5);
298 }
299 }
300 self.readout.init(seed, 9999, scale * 0.1);
302 let _ = spec; }
304}
305
/// One particle: a model replica, its dynamic state, and all scratch
/// buffers (preallocated once so the hot path never allocates).
#[derive(Clone)]
struct ParticleState {
    /// Stable id, mixed into this particle's deterministic noise stream.
    particle_id: u64,
    /// Flattened cell states: `num_cells * cell_dim`.
    cells: Vec<f64>,
    /// Ring buffer of recent input bytes.
    context: Vec<u8>,
    /// Next write position in `context`.
    ctx_pos: usize,
    /// Count of bytes pushed so far (compared against the window length).
    ctx_len: usize,
    model: ParticleModel,
    /// Per-byte log-probabilities produced by the last `forward` call.
    cached_log_probs: [f64; 256],
    /// True while `cached_log_probs` reflects the current state.
    cache_valid: bool,
    // ---- forward-pass scratch buffers ----
    scratch_ctx: Vec<f64>,
    scratch_mean_cells: Vec<f64>,
    scratch_p: Vec<f64>,
    scratch_sel_h: Vec<f64>,
    scratch_gate: Vec<f64>,
    scratch_rule_in: Vec<f64>,
    scratch_rule_h: Vec<f64>,
    scratch_delta_k: Vec<f64>,
    scratch_delta: Vec<f64>,
    scratch_phi: Vec<f64>,
    scratch_logits: Vec<f64>,
    // ---- backward-pass scratch buffers ----
    scratch_d_logits: Vec<f64>,
    scratch_d_phi: Vec<f64>,
    scratch_softmax: Vec<f64>,
    scratch_d_rule_out: Vec<f64>,
    scratch_d_rule_h: Vec<f64>,
    scratch_d_gate: Vec<f64>,
    scratch_d_gate_logits: Vec<f64>,
    scratch_d_sel_h: Vec<f64>,
    /// Recent forward activations kept for the truncated backward pass.
    trace_history: VecDeque<StepTrace>,
}
351
/// Forward activations of one rule, captured for the backward pass.
#[derive(Clone)]
struct RuleTrace {
    /// Post-ReLU hidden activations.
    rule_h: Vec<f64>,
    /// The rule's proposed cell delta (before gating).
    rule_out: Vec<f64>,
}
357
/// Forward activations of one cell for one step, replayed by the
/// truncated backward pass.
#[derive(Clone)]
struct CellTrace {
    /// Selector input vector.
    p: Vec<f64>,
    /// Selector hidden activations (post-ReLU).
    sel_h: Vec<f64>,
    /// Softmaxed gate weights over the rules.
    gate: Vec<f64>,
    /// Rule input vector (selector input + noise).
    rule_in: Vec<f64>,
    rules: Vec<RuleTrace>,
}
366
/// Per-step capture of every cell's forward activations.
#[derive(Clone)]
struct StepTrace {
    cells: Vec<CellTrace>,
}
371
372impl ParticleState {
    /// Builds a particle with zeroed dynamic state and scratch buffers
    /// sized from `spec`, so the per-step hot path never allocates.
    fn new(spec: &ParticleSpec, model: ParticleModel, particle_id: u64) -> Self {
        let cd = spec.cell_dim;
        let nc = spec.num_cells;
        // Selector input: five cell-width slots; rule input adds noise.
        let sel_in = 5 * cd;
        let rule_in = 5 * cd + spec.noise_dim;
        let phi_dim = model.phi_dim; Self {
            particle_id,
            cells: vec![0.0; nc * cd],
            context: vec![0; spec.context_window],
            ctx_pos: 0,
            ctx_len: 0,
            model,
            cached_log_probs: [0.0; 256],
            cache_valid: false,
            scratch_ctx: vec![0.0; cd],
            scratch_mean_cells: vec![0.0; cd],
            scratch_p: vec![0.0; sel_in],
            scratch_sel_h: vec![0.0; spec.selector_hidden],
            scratch_gate: vec![0.0; spec.num_rules],
            scratch_rule_in: vec![0.0; rule_in],
            scratch_rule_h: vec![0.0; spec.rule_hidden],
            scratch_delta_k: vec![0.0; cd],
            scratch_delta: vec![0.0; cd],
            scratch_phi: vec![0.0; phi_dim],
            scratch_logits: vec![0.0; 256],
            scratch_d_logits: vec![0.0; 256],
            scratch_d_phi: vec![0.0; phi_dim],
            scratch_softmax: vec![0.0; 256],
            scratch_d_rule_out: vec![0.0; cd],
            scratch_d_rule_h: vec![0.0; spec.rule_hidden],
            scratch_d_gate: vec![0.0; spec.num_rules],
            scratch_d_gate_logits: vec![0.0; spec.num_rules],
            scratch_d_sel_h: vec![0.0; spec.selector_hidden],
            trace_history: VecDeque::with_capacity(spec.bptt_depth.max(1)),
        }
    }
410
    /// Rebuilds `scratch_ctx`: an exponentially-decayed weighted average of
    /// the embeddings of the buffered context bytes (newest byte weighted 1,
    /// each older byte down-weighted by 0.90), normalized by the weight sum.
    fn build_ctx(&mut self) {
        let cd = self.model.cell_dim;
        self.scratch_ctx.iter_mut().for_each(|v| *v = 0.0);
        let len = self.ctx_len.min(self.context.len());
        if len == 0 {
            return;
        }
        let cw = self.context.len();
        let decay = 0.90_f64;
        let mut weight_sum = 0.0_f64;
        let mut w = 1.0_f64;
        // Walk the ring buffer from newest to oldest byte.
        for age in 0..len {
            let pos = (self.ctx_pos + cw - 1 - age) % cw;
            let byte = self.context[pos] as usize;
            let emb = &self.model.embed[byte * cd..(byte + 1) * cd];
            weight_sum += w;
            for (ctx, &emb_j) in self.scratch_ctx.iter_mut().zip(emb.iter()) {
                *ctx += emb_j * w;
            }
            w *= decay;
        }
        // Normalize so the result is a weighted mean, not a sum.
        if weight_sum > 0.0 {
            let inv = 1.0 / weight_sum;
            for v in &mut self.scratch_ctx {
                *v *= inv;
            }
        }
    }
442
443 fn compute_mean_cells(&mut self) {
445 let cd = self.model.cell_dim;
446 let nc = self.model.num_cells;
447 self.scratch_mean_cells.iter_mut().for_each(|v| *v = 0.0);
448 if nc == 0 {
449 return;
450 }
451 let inv = 1.0 / nc as f64;
452 for ci in 0..nc {
453 let off = ci * cd;
454 for j in 0..cd {
455 self.scratch_mean_cells[j] += self.cells[off + j] * inv;
456 }
457 }
458 }
459
460 fn build_selector_input(&mut self, cell_idx: usize) {
462 let cd = self.model.cell_dim;
463 let nc = self.model.num_cells.max(1);
464 let off = cell_idx * cd;
465 let left_idx = if nc <= 1 {
466 cell_idx
467 } else {
468 (cell_idx + nc - 1) % nc
469 };
470 let right_idx = if nc <= 1 {
471 cell_idx
472 } else {
473 (cell_idx + 1) % nc
474 };
475 let left_off = left_idx * cd;
476 let right_off = right_idx * cd;
477 self.scratch_p[..cd].copy_from_slice(&self.cells[off..off + cd]);
478 self.scratch_p[cd..2 * cd].copy_from_slice(&self.cells[left_off..left_off + cd]);
479 self.scratch_p[2 * cd..3 * cd].copy_from_slice(&self.cells[right_off..right_off + cd]);
480 self.scratch_p[3 * cd..4 * cd].copy_from_slice(&self.scratch_ctx[..cd]);
481 self.scratch_p[4 * cd..5 * cd].copy_from_slice(&self.scratch_mean_cells[..cd]);
482 }
483
    /// Fills `scratch_rule_in` with the selector features followed by
    /// `noise_dim` deterministic exploration-noise values. The noise is
    /// annealed linearly to zero over `noise_anneal_steps` and is a pure
    /// function of (seed, particle, step, unroll, cell, dim), so replays
    /// are reproducible.
    fn build_rule_input(
        &mut self,
        spec: &ParticleSpec,
        step_idx: u64,
        unroll_idx: usize,
        cell_idx: usize,
    ) {
        let sel_in = self.model.selector_in_dim;
        let nd = self.model.noise_dim;
        self.scratch_rule_in[..sel_in].copy_from_slice(&self.scratch_p[..sel_in]);
        // Noise disabled: zero the tail so stale values never leak through.
        if nd == 0 || !spec.enable_noise || spec.noise_scale <= 0.0 {
            for j in sel_in..sel_in + nd {
                self.scratch_rule_in[j] = 0.0;
            }
            return;
        }
        // Linear anneal from 1 down to 0 over `noise_anneal_steps` steps
        // (0 means "never anneal").
        let anneal = if spec.noise_anneal_steps == 0 {
            1.0
        } else {
            let rem = spec.noise_anneal_steps.saturating_sub(step_idx as usize) as f64;
            rem / spec.noise_anneal_steps as f64
        };
        let scale = spec.noise_scale * anneal.max(0.0);
        for j in 0..nd {
            // Pack (unroll, cell, dim) into one hash input; the trailing
            // constant is just a stream tag separating this hash usage
            // from the others in the file.
            let h = det_hash(
                spec.seed ^ self.particle_id,
                step_idx,
                ((unroll_idx as u64) << 40) ^ ((cell_idx as u64) << 20) ^ j as u64,
                0xD1A6_51EED,
            );
            self.scratch_rule_in[sel_in + j] = hash_to_f64(h) * scale;
        }
    }
518
    /// Builds the pooled readout feature vector `scratch_phi`, five
    /// cell-width segments: [mean over cells | elementwise max |
    /// elementwise std | embedding of 2nd-most-recent byte | embedding of
    /// most recent byte]. Assumes `compute_mean_cells` ran beforehand.
    fn build_phi(&mut self) {
        let cd = self.model.cell_dim;
        let nc = self.model.num_cells;
        self.scratch_phi[..cd].copy_from_slice(&self.scratch_mean_cells[..cd]);
        // Segment 2: per-dimension max; 0.0 when there are no cells
        // (max stays -inf) so the feature remains finite.
        for j in 0..cd {
            let mut mx = f64::NEG_INFINITY;
            for ci in 0..nc {
                let v = self.cells[ci * cd + j];
                if v > mx {
                    mx = v;
                }
            }
            self.scratch_phi[cd + j] = if mx.is_finite() { mx } else { 0.0 };
        }
        // Segment 3: per-dimension population standard deviation.
        for j in 0..cd {
            let mean = self.scratch_mean_cells[j];
            let mut var = 0.0_f64;
            for ci in 0..nc {
                let d = self.cells[ci * cd + j] - mean;
                var += d * d;
            }
            self.scratch_phi[2 * cd + j] = (var / nc.max(1) as f64).sqrt();
        }
        // Segments 4 and 5: raw embeddings of the two most recent bytes
        // (zeroed while the context is too short).
        let cw = self.context.len();
        if self.ctx_len >= 2 {
            let pos2 = (self.ctx_pos + cw - 2) % cw;
            let byte2 = self.context[pos2] as usize;
            self.scratch_phi[3 * cd..4 * cd]
                .copy_from_slice(&self.model.embed[byte2 * cd..(byte2 + 1) * cd]);
        } else {
            self.scratch_phi[3 * cd..4 * cd].fill(0.0);
        }
        if self.ctx_len >= 1 {
            let pos1 = (self.ctx_pos + cw - 1) % cw;
            let byte1 = self.context[pos1] as usize;
            self.scratch_phi[4 * cd..5 * cd]
                .copy_from_slice(&self.model.embed[byte1 * cd..(byte1 + 1) * cd]);
        } else {
            self.scratch_phi[4 * cd..5 * cd].fill(0.0);
        }
    }
581
    /// Runs one full prediction step: `unroll_steps` rounds of cell
    /// updates (selector-gated mixture of rule deltas), then the pooled
    /// readout, filling `cached_log_probs` and `scratch_softmax`.
    /// When learning is enabled, the forward activations are recorded in
    /// `trace_history` for the truncated backward pass.
    fn forward(&mut self, spec: &ParticleSpec, step_idx: u64) {
        self.build_ctx();
        self.compute_mean_cells();
        // Traces are only needed when some learning rate is nonzero.
        let capture_trace = spec.learning_rate_selector > 0.0 || spec.learning_rate_rule > 0.0;
        let mut step_trace = if capture_trace {
            Some(StepTrace {
                cells: Vec::with_capacity(self.model.num_cells),
            })
        } else {
            None
        };

        for unroll_idx in 0..spec.unroll_steps {
            for ci in 0..self.model.num_cells {
                self.build_selector_input(ci);

                // Selector: hidden ReLU layer, then gate logits -> softmax.
                self.model.cells[ci]
                    .selector
                    .hidden
                    .forward_relu(&self.scratch_p, &mut self.scratch_sel_h);
                self.model.cells[ci]
                    .selector
                    .gate
                    .forward(&self.scratch_sel_h, &mut self.scratch_gate);
                softmax_inplace(&mut self.scratch_gate[..spec.num_rules]);

                self.build_rule_input(spec, step_idx, unroll_idx, ci);

                // Accumulate the gate-weighted mixture of rule deltas.
                let cd = self.model.cell_dim;
                self.scratch_delta[..cd].fill(0.0);
                let mut rule_traces = if capture_trace {
                    Some(Vec::with_capacity(spec.num_rules))
                } else {
                    None
                };
                for ki in 0..spec.num_rules {
                    let gate_k = self.scratch_gate[ki];
                    self.model.cells[ci].rules[ki]
                        .hidden
                        .forward_relu(&self.scratch_rule_in, &mut self.scratch_rule_h);
                    self.model.cells[ci].rules[ki]
                        .output
                        .forward(&self.scratch_rule_h, &mut self.scratch_delta_k);
                    if let Some(rt) = &mut rule_traces {
                        rt.push(RuleTrace {
                            rule_h: self.scratch_rule_h.clone(),
                            rule_out: self.scratch_delta_k[..cd].to_vec(),
                        });
                    }
                    for j in 0..cd {
                        self.scratch_delta[j] += gate_k * self.scratch_delta_k[j];
                    }
                }
                if let (Some(st), Some(rt)) = (&mut step_trace, rule_traces) {
                    st.cells.push(CellTrace {
                        p: self.scratch_p.clone(),
                        sel_h: self.scratch_sel_h.clone(),
                        gate: self.scratch_gate[..spec.num_rules].to_vec(),
                        rule_in: self.scratch_rule_in.clone(),
                        rules: rt,
                    });
                }

                // Apply the delta; cell state is clamped to +-state_clip.
                let off = ci * cd;
                for j in 0..cd {
                    self.cells[off + j] =
                        clip(self.cells[off + j] + self.scratch_delta[j], spec.state_clip);
                }
            }
            // Refresh the mean so the next unroll round sees updated cells.
            self.compute_mean_cells();
        }

        self.build_phi();

        self.model
            .readout
            .forward(&self.scratch_phi, &mut self.scratch_logits);

        // Probability-space softmax is kept for the learning step's
        // cross-entropy gradient.
        self.scratch_softmax.copy_from_slice(&self.scratch_logits);
        softmax_inplace(&mut self.scratch_softmax);

        // Floored log-probs are what the mixture/weighting consumes.
        log_softmax_with_floor(
            &self.scratch_logits,
            &mut self.cached_log_probs,
            spec.min_prob,
        );
        if let Some(st) = step_trace {
            self.trace_history.push_back(st);
            // Keep only the most recent `bptt_depth` traces.
            while self.trace_history.len() > spec.bptt_depth.max(1) {
                self.trace_history.pop_front();
            }
        }
        self.cache_valid = true;
    }
690
    /// Applies one approximate first-order gradient step to every cell's
    /// selector and rule networks by replaying the recorded forward
    /// activations in `trace`.
    ///
    /// The readout gradient `d_phi` is routed only through the mean-pool
    /// slot of phi, so each cell's delta receives `d_phi / num_cells`,
    /// additionally discounted by `temporal` for older traces.
    fn apply_selector_rule_update_from_trace(
        &mut self,
        trace: &StepTrace,
        d_phi: &[f64],
        temporal: f64,
        spec: &ParticleSpec,
    ) {
        let cd = self.model.cell_dim;
        let nc = self.model.num_cells.max(1);
        let d_delta_scale = (1.0 / nc as f64) * temporal;

        for ci in 0..nc.min(trace.cells.len()) {
            let ct = &trace.cells[ci];

            self.scratch_d_gate[..spec.num_rules].fill(0.0);

            for ki in 0..spec.num_rules.min(ct.rules.len()) {
                let gate_k = ct.gate[ki];
                // Gradient w.r.t. this rule's output: gate-weighted share
                // of the (scaled) phi gradient.
                self.scratch_d_rule_out[..cd].fill(0.0);
                for (dst, &d_phi_j) in self.scratch_d_rule_out[..cd]
                    .iter_mut()
                    .zip(d_phi.iter().take(cd))
                {
                    *dst = d_phi_j * d_delta_scale * gate_k;
                }

                // Update the rule's output layer against the traced hidden
                // activations.
                self.model.cells[ci].rules[ki].output.sgd_update(
                    &self.scratch_d_rule_out,
                    &ct.rules[ki].rule_h,
                    spec.learning_rate_rule,
                    spec.grad_clip,
                    spec.optimizer_momentum,
                );

                // Backpropagate through the output weights into the hidden
                // layer; tiny gradients are skipped for speed.
                let rh = spec.rule_hidden;
                self.scratch_d_rule_h[..rh].fill(0.0);
                for r in 0..cd {
                    let g = clip(self.scratch_d_rule_out[r], spec.grad_clip);
                    if g.abs() < 1e-15 {
                        continue;
                    }
                    for c in 0..rh {
                        self.scratch_d_rule_h[c] +=
                            g * self.model.cells[ci].rules[ki].output.weights[r * rh + c];
                    }
                }
                // ReLU mask: no gradient where the traced activation was 0.
                for (j, h) in ct.rules[ki].rule_h.iter().enumerate().take(rh) {
                    if *h <= 0.0 {
                        self.scratch_d_rule_h[j] = 0.0;
                    }
                }
                self.model.cells[ci].rules[ki].hidden.sgd_update(
                    &self.scratch_d_rule_h[..rh],
                    &ct.rule_in,
                    spec.learning_rate_rule,
                    spec.grad_clip,
                    spec.optimizer_momentum,
                );

                // Gate gradient: how much this rule's output aligned with
                // the phi gradient.
                for (&d_phi_j, &rule_out_j) in d_phi
                    .iter()
                    .take(cd)
                    .zip(ct.rules[ki].rule_out.iter().take(cd))
                {
                    self.scratch_d_gate[ki] += d_phi_j * d_delta_scale * rule_out_j;
                }
            }

            // Softmax Jacobian: dL/dlogit_k = g_k * (dL/dg_k - sum_j g_j dL/dg_j).
            let dot_gd: f64 = (0..spec.num_rules.min(ct.gate.len()))
                .map(|k| ct.gate[k] * self.scratch_d_gate[k])
                .sum();
            self.scratch_d_gate_logits[..spec.num_rules].fill(0.0);
            for k in 0..spec.num_rules.min(ct.gate.len()) {
                self.scratch_d_gate_logits[k] = ct.gate[k] * (self.scratch_d_gate[k] - dot_gd);
            }

            self.model.cells[ci].selector.gate.sgd_update(
                &self.scratch_d_gate_logits[..spec.num_rules],
                &ct.sel_h,
                spec.learning_rate_selector,
                spec.grad_clip,
                spec.optimizer_momentum,
            );

            // Backpropagate the gate-logit gradient into the selector
            // hidden layer, again masked by the traced ReLU.
            let sh = spec.selector_hidden;
            self.scratch_d_sel_h[..sh].fill(0.0);
            for r in 0..spec.num_rules.min(ct.gate.len()) {
                let g = clip(self.scratch_d_gate_logits[r], spec.grad_clip);
                if g.abs() < 1e-15 {
                    continue;
                }
                for c in 0..sh {
                    self.scratch_d_sel_h[c] +=
                        g * self.model.cells[ci].selector.gate.weights[r * sh + c];
                }
            }
            for (j, h) in ct.sel_h.iter().enumerate().take(sh) {
                if *h <= 0.0 {
                    self.scratch_d_sel_h[j] = 0.0;
                }
            }
            self.model.cells[ci].selector.hidden.sgd_update(
                &self.scratch_d_sel_h[..sh],
                &ct.p,
                spec.learning_rate_selector,
                spec.grad_clip,
                spec.optimizer_momentum,
            );
        }
    }
804
    /// One online learning step against observed byte `y`: updates the
    /// readout with the cross-entropy gradient, backpropagates into phi,
    /// then replays recent traces (newest first, discounted by 0.7 per
    /// step) to update the selector and rule networks.
    fn sgd_update(&mut self, y: u8, spec: &ParticleSpec) {
        // Cross-entropy gradient w.r.t. logits: softmax - one_hot(y).
        self.scratch_d_logits.copy_from_slice(&self.scratch_softmax);
        self.scratch_d_logits[y as usize] -= 1.0;

        for v in self.scratch_d_logits.iter_mut() {
            *v = clip(*v, spec.grad_clip);
        }

        // Readout update uses plain SGD (momentum 0).
        self.model.readout.sgd_update(
            &self.scratch_d_logits,
            &self.scratch_phi,
            spec.learning_rate_readout,
            spec.grad_clip,
            0.0,
        );

        // Backpropagate the logit gradient through the readout weights
        // into phi; near-zero entries are skipped for speed.
        let phi_dim = self.model.phi_dim;
        self.scratch_d_phi[..phi_dim].fill(0.0);
        for r in 0..256 {
            let g = clip(self.scratch_d_logits[r], spec.grad_clip);
            if g.abs() < 1e-15 {
                continue;
            }
            let row_start = r * phi_dim;
            for c in 0..phi_dim {
                self.scratch_d_phi[c] += g * self.model.readout.weights[row_start + c];
            }
        }

        for v in self.scratch_d_phi[..phi_dim].iter_mut() {
            *v = clip(*v, spec.grad_clip);
        }

        if spec.learning_rate_selector > 0.0 || spec.learning_rate_rule > 0.0 {
            let depth = spec.bptt_depth.max(1).min(self.trace_history.len());
            // Temporarily move the history out so we can call a &mut self
            // method while iterating it.
            let traces = std::mem::take(&mut self.trace_history);
            let d_phi = self.scratch_d_phi[..phi_dim].to_vec();
            let mut temporal = 1.0_f64;
            let temporal_decay = 0.7_f64;
            // Newest trace first; older traces get geometrically smaller
            // updates.
            for idx in 0..depth {
                let hist_idx = traces.len() - 1 - idx;
                let trace = &traces[hist_idx];
                self.apply_selector_rule_update_from_trace(trace, &d_phi, temporal, spec);
                temporal *= temporal_decay;
            }
            self.trace_history = traces;
        }
    }
859
860 fn push_context(&mut self, byte: u8) {
862 self.context[self.ctx_pos] = byte;
863 self.ctx_pos = (self.ctx_pos + 1) % self.context.len();
864 self.ctx_len += 1;
865 }
866
867 fn reset_dynamic_state(&mut self) {
868 self.cells.fill(0.0);
869 self.context.fill(0);
870 self.ctx_pos = 0;
871 self.ctx_len = 0;
872 self.cached_log_probs.fill(0.0);
873 self.cache_valid = false;
874 self.scratch_ctx.fill(0.0);
875 self.scratch_mean_cells.fill(0.0);
876 self.scratch_p.fill(0.0);
877 self.scratch_sel_h.fill(0.0);
878 self.scratch_gate.fill(0.0);
879 self.scratch_rule_in.fill(0.0);
880 self.scratch_rule_h.fill(0.0);
881 self.scratch_delta_k.fill(0.0);
882 self.scratch_delta.fill(0.0);
883 self.scratch_phi.fill(0.0);
884 self.scratch_logits.fill(0.0);
885 self.scratch_d_logits.fill(0.0);
886 self.scratch_d_phi.fill(0.0);
887 self.scratch_softmax.fill(0.0);
888 self.scratch_d_rule_out.fill(0.0);
889 self.scratch_d_rule_h.fill(0.0);
890 self.scratch_d_gate.fill(0.0);
891 self.scratch_d_gate_logits.fill(0.0);
892 self.scratch_d_sel_h.fill(0.0);
893 self.trace_history.clear();
894 }
895}
896
/// Particle-filter ensemble over `ParticleState`s: maintains the particle
/// log-weights, the cached mixture prediction, and the global step index.
pub struct ParticleRuntime {
    spec: ParticleSpec,
    particles: Vec<ParticleState>,
    /// Normalized log-weights (their logsumexp is kept at 0).
    log_weights: Vec<f64>,
    /// Mixture log-probabilities for the next byte.
    mix_log_probs: [f64; 256],
    /// The same distribution as `mix_log_probs`, in probability space.
    mix_pdf: Vec<f64>,
    /// True while `mix_log_probs`/`mix_pdf` reflect current particle caches.
    cache_valid: bool,
    step_idx: u64,
    /// Reusable buffer for the per-particle logsumexp terms.
    scratch_lse: Vec<f64>,
}
920
921impl ParticleRuntime {
922 #[inline]
923 fn likelihood_beta(&self) -> f64 {
924 const BETA_MIN: f64 = 0.35;
927 const WARMUP_STEPS: u64 = 2048;
928 if self.step_idx >= WARMUP_STEPS {
929 1.0
930 } else {
931 BETA_MIN + (1.0 - BETA_MIN) * (self.step_idx as f64 / WARMUP_STEPS as f64)
932 }
933 }
934
935 #[inline]
936 fn diagnostics_enabled(&self) -> bool {
937 self.spec.diagnostics_interval > 0
938 && self
939 .step_idx
940 .is_multiple_of(self.spec.diagnostics_interval as u64)
941 }
942
943 #[inline]
944 fn weight_stats(&self) -> (f64, f64) {
945 let mut sum_sq = 0.0;
946 let mut max_w = 0.0;
947 for &lw in &self.log_weights {
948 let w = lw.exp();
949 sum_sq += w * w;
950 if w > max_w {
951 max_w = w;
952 }
953 }
954 let n_eff = if sum_sq > 0.0 { 1.0 / sum_sq } else { 0.0 };
955 (n_eff, max_w)
956 }
957
    /// Diagnostic: the weight-averaged KL divergence between each
    /// particle's predictive distribution and the mixture,
    /// `sum_i w_i * KL(p_i || mix)`. The mixture is recomputed locally
    /// (rather than read from `mix_log_probs`) so this is valid regardless
    /// of the runtime cache state.
    fn weighted_prediction_kl_divergence(&self) -> f64 {
        let n = self.particles.len();
        if n == 0 {
            return 0.0;
        }
        let log_z = logsumexp_wide(&self.log_weights);
        let mut mix_log_probs = [0.0_f64; 256];
        let mut scratch_lse = vec![0.0_f64; n];
        // mix(v) = logsumexp_i(log w_i + log p_i(v)) - log Z.
        for (v, mix_logp) in mix_log_probs.iter_mut().enumerate() {
            for (slot, (log_weight, particle)) in scratch_lse
                .iter_mut()
                .zip(self.log_weights.iter().zip(self.particles.iter()))
            {
                *slot = *log_weight + particle.cached_log_probs[v];
            }
            *mix_logp = logsumexp_wide(&scratch_lse) - log_z;
        }
        let mut d = 0.0_f64;
        for (i, p) in self.particles.iter().enumerate() {
            let alpha = self.log_weights[i].exp();
            if alpha <= 0.0 {
                continue;
            }
            let mut kl_i = 0.0_f64;
            for (&lp_i, &mix_logp) in p.cached_log_probs.iter().zip(mix_log_probs.iter()) {
                let prob_i = lp_i.exp();
                kl_i += prob_i * (lp_i - mix_logp);
            }
            // Clamp tiny negative values caused by floating-point error.
            d += alpha * kl_i.max(0.0);
        }
        d
    }
992
    /// Emits one diagnostics line to stderr: effective sample size, peak
    /// weight, mixture KL spread, tempering beta, and whether a resample
    /// is about to happen.
    fn log_diagnostics(
        &self,
        n_eff: f64,
        max_weight: f64,
        divergence: f64,
        beta: f64,
        will_resample: bool,
    ) {
        eprintln!(
            "[particle] step={} neff={:.3}/{:.0} max_w={:.3}% div_kl={:.6} beta={:.3} resample={}",
            self.step_idx,
            n_eff,
            self.particles.len() as f64,
            max_weight * 100.0,
            divergence,
            beta,
            will_resample
        );
    }
1012
1013 fn diversify_initial_particles(&mut self) {
1014 let scale = 5e-3_f64;
1015 if self.particles.len() <= 1 {
1016 return;
1017 }
1018 for pi in 1..self.particles.len() {
1019 let p = &mut self.particles[pi];
1020 for (idx, v) in p.cells.iter_mut().enumerate() {
1021 let noise = hash_to_f64(det_hash(self.spec.seed, pi as u64, idx as u64, 1000));
1022 *v += noise * scale;
1023 }
1024 for (idx, v) in p.model.readout.bias.iter_mut().enumerate() {
1025 let noise = hash_to_f64(det_hash(self.spec.seed, pi as u64, idx as u64, 1001));
1026 *v += noise * scale;
1027 }
1028 for (idx, v) in p.model.readout.weights.iter_mut().enumerate() {
1029 let noise = hash_to_f64(det_hash(self.spec.seed, pi as u64, idx as u64, 1002));
1030 *v += noise * (scale * 0.5);
1031 }
1032 }
1033 }
1034
    /// Builds the runtime: `num_particles` independently-seeded models
    /// (seeds strided by the 64-bit golden-ratio constant), uniform
    /// initial log-weights, then a small deterministic diversification of
    /// all but the first particle.
    pub fn new(spec: &ParticleSpec) -> Self {
        let n = spec.num_particles;

        let particles: Vec<ParticleState> = (0..n)
            .map(|pi| {
                let particle_seed = spec
                    .seed
                    .wrapping_add((pi as u64).wrapping_mul(0x9e3779b97f4a7c15u64));
                let mut model = ParticleModel::new(spec);
                model.init(particle_seed, spec);
                ParticleState::new(spec, model, pi as u64)
            })
            .collect();

        // Uniform weights: log(1/n) each.
        let log_w = -(n as f64).ln();
        let mut rt = Self {
            spec: spec.clone(),
            particles,
            log_weights: vec![log_w; n],
            mix_log_probs: [0.0; 256],
            mix_pdf: vec![0.0; 256],
            cache_valid: false,
            step_idx: 0,
            scratch_lse: vec![0.0; n],
        };
        rt.diversify_initial_particles();
        rt
    }
1078
1079 fn ensure_predictions(&mut self) {
1081 if self.cache_valid {
1082 return;
1083 }
1084 let spec = &self.spec;
1085 for p in &mut self.particles {
1086 if !p.cache_valid {
1087 p.forward(spec, self.step_idx);
1088 }
1089 }
1090 self.compute_mixture_log_probs();
1091 self.cache_valid = true;
1092 }
1093
    /// Rebuilds the mixture prediction from the per-particle caches:
    /// `mix(v) = logsumexp_i(log w_i + log p_i(v)) - log Z`, then a
    /// probability-space copy in `mix_pdf` (shifted by the max for
    /// stability and renormalized).
    fn compute_mixture_log_probs(&mut self) {
        let n = self.particles.len();
        let log_z = logsumexp_wide(&self.log_weights);

        for v in 0..256 {
            for i in 0..n {
                self.scratch_lse[i] = self.log_weights[i] + self.particles[i].cached_log_probs[v];
            }
            self.mix_log_probs[v] = logsumexp_wide(&self.scratch_lse) - log_z;
        }

        // Exponentiate relative to the max entry and renormalize, so the
        // pdf sums to exactly one even after per-particle flooring.
        let max_lp = max_wide(&self.mix_log_probs);
        let mut sum = 0.0;
        for v in 0..256 {
            let p = (self.mix_log_probs[v] - max_lp).exp();
            self.mix_pdf[v] = p;
            sum += p;
        }
        if sum > 0.0 {
            let inv = 1.0 / sum;
            for v in &mut self.mix_pdf {
                *v *= inv;
            }
        }
    }
1122
    /// Mixture log-probability of `symbol` at the next position, without
    /// advancing any state (predictions are refreshed if stale).
    pub fn peek_log_prob(&mut self, symbol: u8) -> f64 {
        self.ensure_predictions();
        self.mix_log_probs[symbol as usize]
    }
1128
    /// Copies the full 256-entry mixture log-probability table into `out`
    /// (predictions are refreshed if stale).
    pub fn fill_log_probs_cached(&mut self, out: &mut [f64; 256]) {
        self.ensure_predictions();
        *out = self.mix_log_probs;
    }
1134
    /// The mixture distribution over the next byte in probability space
    /// (normalized; predictions are refreshed if stale).
    pub fn pdf_next(&mut self) -> &[f64] {
        self.ensure_predictions();
        &self.mix_pdf
    }
1140
    /// Consumes one observed `symbol` and returns its mixture
    /// log-probability (scored BEFORE any learning). In order: tempered
    /// Bayes weight update + renormalize, optional forgetting toward
    /// uniform, diagnostics snapshot, per-particle online SGD, context
    /// push, optional resampling, cache invalidation, step advance.
    pub fn step(&mut self, symbol: u8) -> f64 {
        self.ensure_predictions();
        let log_prob = self.mix_log_probs[symbol as usize];

        let n = self.particles.len();
        let spec = &self.spec;

        // Bayes update, tempered by beta during warmup.
        let beta = self.likelihood_beta();
        for i in 0..n {
            self.log_weights[i] += beta * self.particles[i].cached_log_probs[symbol as usize];
        }
        let log_z = logsumexp_wide(&self.log_weights);
        for w in &mut self.log_weights {
            *w -= log_z;
        }

        // Forgetting: blend each log-weight toward log(1/n), then
        // renormalize again.
        if spec.forget_lambda > 0.0 {
            let uniform = -(n as f64).ln();
            for w in &mut self.log_weights {
                *w = (1.0 - spec.forget_lambda) * *w + spec.forget_lambda * uniform;
            }
            let log_z2 = logsumexp_wide(&self.log_weights);
            for w in &mut self.log_weights {
                *w -= log_z2;
            }
        }

        // Snapshot diagnostics before learning mutates the particles.
        let (n_eff_before, max_w_before) = self.weight_stats();
        let will_resample = n_eff_before < self.spec.resample_threshold * n as f64;
        let should_log = self.diagnostics_enabled();
        let divergence = if should_log {
            self.weighted_prediction_kl_divergence()
        } else {
            0.0
        };

        // Online learning on the just-observed symbol.
        for p in &mut self.particles {
            p.sgd_update(symbol, spec);
        }

        for p in &mut self.particles {
            p.push_context(symbol);
        }

        if should_log {
            self.log_diagnostics(n_eff_before, max_w_before, divergence, beta, will_resample);
        }

        let _ = self.maybe_resample();

        // Everything downstream must be recomputed next step.
        for p in &mut self.particles {
            p.cache_valid = false;
        }
        self.cache_valid = false;
        self.step_idx += 1;

        log_prob
    }
1208
1209 pub fn reset_frozen_state(&mut self) {
1215 for particle in &mut self.particles {
1216 particle.reset_dynamic_state();
1217 }
1218 self.mix_log_probs.fill(0.0);
1219 self.mix_pdf.fill(1.0 / 256.0);
1220 self.cache_valid = false;
1221 self.step_idx = 0;
1222 }
1223
    /// Advances on `symbol` in frozen mode: the same tempered weight
    /// update and forgetting as `step`, but with no learning and no
    /// resampling — only the weights and the context windows move.
    pub fn update_frozen(&mut self, symbol: u8) {
        self.ensure_predictions();
        let n = self.particles.len();
        let spec = &self.spec;

        // Tempered Bayes update + renormalize (mirrors `step`).
        let beta = self.likelihood_beta();
        for i in 0..n {
            self.log_weights[i] += beta * self.particles[i].cached_log_probs[symbol as usize];
        }
        let log_z = logsumexp_wide(&self.log_weights);
        for weight in &mut self.log_weights {
            *weight -= log_z;
        }

        // Optional forgetting toward uniform, then renormalize again.
        if spec.forget_lambda > 0.0 {
            let uniform = -(n as f64).ln();
            for weight in &mut self.log_weights {
                *weight = (1.0 - spec.forget_lambda) * *weight + spec.forget_lambda * uniform;
            }
            let log_z2 = logsumexp_wide(&self.log_weights);
            for weight in &mut self.log_weights {
                *weight -= log_z2;
            }
        }

        for particle in &mut self.particles {
            particle.push_context(symbol);
            particle.cache_valid = false;
        }
        self.cache_valid = false;
        self.step_idx += 1;
    }
1261
    /// Resamples the particle set when the effective sample size drops
    /// below `resample_threshold * n`, using deterministic systematic
    /// (low-variance) resampling, then mutates a deterministic random
    /// subset of the survivors and resets the weights to uniform.
    /// Returns whether a resample actually happened.
    fn maybe_resample(&mut self) -> bool {
        let n = self.particles.len();
        if n <= 1 {
            return false;
        }

        // Effective sample size of the normalized weights: 1 / sum(w^2).
        let mut sum_sq = 0.0;
        for &lw in &self.log_weights {
            let w = lw.exp();
            sum_sq += w * w;
        }
        let n_eff = if sum_sq > 0.0 { 1.0 / sum_sq } else { 0.0 };

        if n_eff >= self.spec.resample_threshold * n as f64 {
            return false;
        }

        // Cumulative distribution over particles in probability space.
        let weights: Vec<f64> = self.log_weights.iter().map(|lw| lw.exp()).collect();
        let cdf: Vec<f64> = weights
            .iter()
            .scan(0.0, |acc, &w| {
                *acc += w;
                Some(*acc)
            })
            .collect();
        let total = *cdf.last().unwrap_or(&1.0);

        // Systematic resampling: n evenly-spaced probes with a single
        // deterministic offset u0 in [0, step).
        let step = total / n as f64;
        let u0 =
            ((det_hash(self.spec.seed, self.step_idx, 0, 0) >> 11) as f64) / ((1u64 << 53) as f64);
        let mut u = u0 * step;
        let mut indices = Vec::with_capacity(n);
        let mut j = 0;
        for _ in 0..n {
            while j < n - 1 && cdf[j] < u {
                j += 1;
            }
            indices.push(j);
            u += step;
        }

        // Replace the population with the selected clones.
        let new_particles: Vec<ParticleState> = indices
            .iter()
            .map(|&idx| self.particles[idx].clone())
            .collect();
        self.particles = new_particles;

        // Mutate up to `mutate_fraction * n` distinct particles; the draw
        // budget (n * 8) bounds the rejection loop.
        let n_mutate = ((self.spec.mutate_fraction * n as f64).round() as usize).min(n);
        let mut mutated = vec![false; n];
        let mut picked = 0usize;
        let mut draw = 0u64;
        while picked < n_mutate && draw < (n * 8) as u64 {
            let mi = (det_hash(self.spec.seed ^ self.step_idx, draw, 0xA5A5, 0x5A5A) as usize) % n;
            if !mutated[mi] {
                self.mutate_particle(mi);
                mutated[mi] = true;
                picked += 1;
            }
            draw += 1;
        }

        // Resampling equalizes the weights.
        let uniform = -(n as f64).ln();
        for w in &mut self.log_weights {
            *w = uniform;
        }
        true
    }
1339
1340 fn mutate_particle(&mut self, particle_idx: usize) {
1343 let seed = self.spec.seed;
1344 let step = self.step_idx;
1345 let pi = particle_idx as u64;
1346 let scale = self.spec.mutate_scale;
1347 let state_clip = self.spec.state_clip;
1348
1349 let p = &mut self.particles[particle_idx];
1350 let mut param_idx = 0u64;
1351
1352 for v in p.cells.iter_mut() {
1354 let noise = hash_to_f64(det_hash(seed ^ step, pi, param_idx, 0)) * scale;
1355 *v = clip(*v + noise, state_clip);
1356 param_idx += 1;
1357 }
1358
1359 if !self.spec.mutate_model_params {
1360 return;
1361 }
1362
1363 let layer_scale = |vals: &[f64]| -> f64 {
1364 if vals.is_empty() {
1365 return 1.0;
1366 }
1367 let mut s = 0.0_f64;
1368 for &v in vals {
1369 s += v * v;
1370 }
1371 (s / vals.len() as f64).sqrt().max(1e-6)
1372 };
1373
1374 let embed_layer = layer_scale(&p.model.embed);
1376 for v in p.model.embed.iter_mut() {
1377 let noise =
1378 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 1)) * (scale * embed_layer);
1379 *v += noise;
1380 param_idx += 1;
1381 }
1382
1383 for cp in p.model.cells.iter_mut() {
1385 let sel_h_w = layer_scale(&cp.selector.hidden.weights);
1386 let sel_h_b = layer_scale(&cp.selector.hidden.bias);
1387 for v in cp.selector.hidden.weights.iter_mut() {
1388 let noise =
1389 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 2)) * (scale * sel_h_w);
1390 *v += noise;
1391 param_idx += 1;
1392 }
1393 for v in cp.selector.hidden.bias.iter_mut() {
1394 let noise =
1395 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 3)) * (scale * sel_h_b);
1396 *v += noise;
1397 param_idx += 1;
1398 }
1399 let sel_g_w = layer_scale(&cp.selector.gate.weights);
1400 let sel_g_b = layer_scale(&cp.selector.gate.bias);
1401 for v in cp.selector.gate.weights.iter_mut() {
1402 let noise =
1403 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 4)) * (scale * sel_g_w);
1404 *v += noise;
1405 param_idx += 1;
1406 }
1407 for v in cp.selector.gate.bias.iter_mut() {
1408 let noise =
1409 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 5)) * (scale * sel_g_b);
1410 *v += noise;
1411 param_idx += 1;
1412 }
1413 for rule in cp.rules.iter_mut() {
1414 let rule_h_w = layer_scale(&rule.hidden.weights);
1415 let rule_h_b = layer_scale(&rule.hidden.bias);
1416 let rule_o_w = layer_scale(&rule.output.weights);
1417 let rule_o_b = layer_scale(&rule.output.bias);
1418 for v in rule.hidden.weights.iter_mut() {
1419 let noise =
1420 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 6)) * (scale * rule_h_w);
1421 *v += noise;
1422 param_idx += 1;
1423 }
1424 for v in rule.hidden.bias.iter_mut() {
1425 let noise =
1426 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 7)) * (scale * rule_h_b);
1427 *v += noise;
1428 param_idx += 1;
1429 }
1430 for v in rule.output.weights.iter_mut() {
1431 let noise =
1432 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 8)) * (scale * rule_o_w);
1433 *v += noise;
1434 param_idx += 1;
1435 }
1436 for v in rule.output.bias.iter_mut() {
1437 let noise =
1438 hash_to_f64(det_hash(seed ^ step, pi, param_idx, 9)) * (scale * rule_o_b);
1439 *v += noise;
1440 param_idx += 1;
1441 }
1442 }
1443 }
1444
1445 let readout_w = layer_scale(&p.model.readout.weights);
1447 let readout_b = layer_scale(&p.model.readout.bias);
1448 for v in p.model.readout.weights.iter_mut() {
1449 let noise = hash_to_f64(det_hash(seed ^ step, pi, param_idx, 10)) * (scale * readout_w);
1450 *v += noise;
1451 param_idx += 1;
1452 }
1453 for v in p.model.readout.bias.iter_mut() {
1454 let noise = hash_to_f64(det_hash(seed ^ step, pi, param_idx, 11)) * (scale * readout_b);
1455 *v += noise;
1456 param_idx += 1;
1457 }
1458 }
1459}
1460
1461impl Clone for ParticleRuntime {
1462 fn clone(&self) -> Self {
1463 Self {
1464 spec: self.spec.clone(),
1465 particles: self.particles.clone(),
1466 log_weights: self.log_weights.clone(),
1467 mix_log_probs: self.mix_log_probs,
1468 mix_pdf: self.mix_pdf.clone(),
1469 cache_valid: self.cache_valid,
1470 step_idx: self.step_idx,
1471 scratch_lse: self.scratch_lse.clone(),
1472 }
1473 }
1474}
1475
1476impl crate::mixture::OnlineBytePredictor for ParticleRuntime {
1478 fn log_prob(&mut self, symbol: u8) -> f64 {
1479 self.peek_log_prob(symbol)
1480 }
1481
1482 fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
1483 self.fill_log_probs_cached(out)
1484 }
1485
1486 fn update(&mut self, symbol: u8) {
1487 self.step(symbol);
1488 }
1489}
1490
#[cfg(test)]
mod tests {
    use super::*;

    /// Deliberately tiny model spec so every test builds and steps quickly.
    fn default_spec() -> ParticleSpec {
        ParticleSpec {
            num_particles: 4,
            num_cells: 2,
            cell_dim: 4,
            num_rules: 2,
            selector_hidden: 8,
            rule_hidden: 8,
            context_window: 8,
            unroll_steps: 1,
            noise_dim: 2,
            ..ParticleSpec::default()
        }
    }

    #[test]
    fn pdf_sums_to_one() {
        let mut rt = ParticleRuntime::new(&default_spec());
        let sum: f64 = rt.pdf_next().iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "PDF sum = {sum}, expected ~1.0");
    }

    #[test]
    fn log_probs_finite_and_nonpositive() {
        let mut rt = ParticleRuntime::new(&default_spec());
        for &b in b"hello world" {
            let lp = rt.peek_log_prob(b);
            assert!(lp.is_finite(), "log_prob not finite: {lp}");
            assert!(lp <= 0.0, "log_prob positive: {lp}");
            rt.step(b);
        }
    }

    #[test]
    fn deterministic_same_seed() {
        let spec = default_spec();
        let (mut rt1, mut rt2) = (ParticleRuntime::new(&spec), ParticleRuntime::new(&spec));
        for &b in b"abcdefghij" {
            let (lp1, lp2) = (rt1.step(b), rt2.step(b));
            assert!(
                (lp1 - lp2).abs() < 1e-12,
                "Mismatch at byte {b}: {lp1} vs {lp2}"
            );
        }
    }

    #[test]
    fn deterministic_with_hash_noise_enabled() {
        let spec = ParticleSpec {
            enable_noise: true,
            noise_scale: 0.15,
            noise_anneal_steps: 128,
            ..default_spec()
        };
        let (mut rt1, mut rt2) = (ParticleRuntime::new(&spec), ParticleRuntime::new(&spec));
        for &b in b"particle noise determinism" {
            let (lp1, lp2) = (rt1.step(b), rt2.step(b));
            assert!(
                (lp1 - lp2).abs() < 1e-12,
                "Hash-noise path non-deterministic at byte {b}: {lp1} vs {lp2}"
            );
        }
    }

    #[test]
    fn resample_forced() {
        // A threshold of 1.0 triggers resampling whenever the particle
        // weights are not exactly uniform.
        let spec = ParticleSpec {
            resample_threshold: 1.0,
            ..default_spec()
        };
        let mut rt = ParticleRuntime::new(&spec);
        for &b in b"test resampling works ok" {
            let lp = rt.step(b);
            assert!(lp.is_finite(), "log_prob not finite after resample: {lp}");
        }
    }

    #[test]
    fn mutation_determinism() {
        // Force resampling + full mutation every step; two identically
        // seeded runtimes must still agree exactly.
        let spec = ParticleSpec {
            resample_threshold: 1.0,
            mutate_fraction: 1.0,
            ..default_spec()
        };
        let (mut rt1, mut rt2) = (ParticleRuntime::new(&spec), ParticleRuntime::new(&spec));
        for &b in b"test mutation" {
            let (lp1, lp2) = (rt1.step(b), rt2.step(b));
            assert!(
                (lp1 - lp2).abs() < 1e-12,
                "Mutation non-deterministic at byte {b}: {lp1} vs {lp2}"
            );
        }
    }

    #[test]
    fn empty_input_no_crash() {
        // Query before any byte has been observed.
        let mut rt = ParticleRuntime::new(&default_spec());
        assert!(rt.peek_log_prob(0).is_finite());
    }

    #[test]
    fn fill_log_probs_consistency() {
        let mut rt = ParticleRuntime::new(&default_spec());
        rt.step(b'a');
        rt.step(b'b');

        let mut bulk = [0.0; 256];
        rt.fill_log_probs_cached(&mut bulk);

        // The bulk fill must agree with symbol-at-a-time queries.
        for sym in 0..256u16 {
            let single = rt.peek_log_prob(sym as u8);
            assert!(
                (bulk[sym as usize] - single).abs() < 1e-12,
                "Mismatch for sym {sym}: bulk={} single={}",
                bulk[sym as usize],
                single
            );
        }
    }

    #[test]
    fn spec_validation() {
        let mut spec = ParticleSpec::default();
        assert!(spec.validate().is_ok());

        // Break one field at a time, restoring it before the next check.
        spec.num_particles = 0;
        assert!(spec.validate().is_err());
        spec.num_particles = 4;

        spec.resample_threshold = 0.0;
        assert!(spec.validate().is_err());
        spec.resample_threshold = 0.5;

        spec.min_prob = -1.0;
        assert!(spec.validate().is_err());
    }

    /// Field-by-field equality over every parameter (and velocity buffer)
    /// of two particle models.
    fn assert_models_equal(lhs: &ParticleModel, rhs: &ParticleModel) {
        assert_eq!(lhs.embed, rhs.embed);
        assert_eq!(lhs.readout.weights, rhs.readout.weights);
        assert_eq!(lhs.readout.bias, rhs.readout.bias);
        assert_eq!(lhs.readout.vel_weights, rhs.readout.vel_weights);
        assert_eq!(lhs.readout.vel_bias, rhs.readout.vel_bias);
        assert_eq!(lhs.cells.len(), rhs.cells.len());
        for (lc, rc) in lhs.cells.iter().zip(rhs.cells.iter()) {
            let (ls, rs) = (&lc.selector, &rc.selector);
            assert_eq!(ls.hidden.weights, rs.hidden.weights);
            assert_eq!(ls.hidden.bias, rs.hidden.bias);
            assert_eq!(ls.hidden.vel_weights, rs.hidden.vel_weights);
            assert_eq!(ls.hidden.vel_bias, rs.hidden.vel_bias);
            assert_eq!(ls.gate.weights, rs.gate.weights);
            assert_eq!(ls.gate.bias, rs.gate.bias);
            assert_eq!(ls.gate.vel_weights, rs.gate.vel_weights);
            assert_eq!(ls.gate.vel_bias, rs.gate.vel_bias);
            assert_eq!(lc.rules.len(), rc.rules.len());
            for (lr, rr) in lc.rules.iter().zip(rc.rules.iter()) {
                assert_eq!(lr.hidden.weights, rr.hidden.weights);
                assert_eq!(lr.hidden.bias, rr.hidden.bias);
                assert_eq!(lr.hidden.vel_weights, rr.hidden.vel_weights);
                assert_eq!(lr.hidden.vel_bias, rr.hidden.vel_bias);
                assert_eq!(lr.output.weights, rr.output.weights);
                assert_eq!(lr.output.bias, rr.output.bias);
                assert_eq!(lr.output.vel_weights, rr.output.vel_weights);
                assert_eq!(lr.output.vel_bias, rr.output.vel_bias);
            }
        }
    }

    #[test]
    fn frozen_update_preserves_model_parameters() {
        let mut rt = ParticleRuntime::new(&default_spec());
        for &b in b"particle plugin separation" {
            rt.step(b);
        }

        // Snapshot every model, then run one frozen update cycle.
        let before: Vec<_> = rt.particles.iter().map(|p| p.model.clone()).collect();
        rt.reset_frozen_state();
        assert!(rt.particles.iter().all(|p| p.ctx_len == 0));

        assert!(rt.peek_log_prob(b'x').is_finite());
        rt.update_frozen(b'x');

        // Model parameters must be untouched; only context/step state moves.
        for (snapshot, particle) in before.iter().zip(rt.particles.iter()) {
            assert_models_equal(snapshot, &particle.model);
        }
        assert_eq!(rt.step_idx, 1);
        assert!(rt.particles.iter().all(|p| p.ctx_len == 1));
    }
}