// infotheory/backends/particle.rs

//! Particle-latent filter ensemble rate backend.
//!
//! Implements a deterministic-by-default sequential particle model with latent
//! cells, selector/rule dynamics, online SGD, Bayesian particle weighting,
//! and resample+mutation.

use std::collections::VecDeque;

use crate::simd_math::{axpy_wide, dot_wide, logsumexp_wide, max_wide};
use crate::ParticleSpec;
10
// ---------------------------------------------------------------------------
// Deterministic hash utilities
// ---------------------------------------------------------------------------

/// Deterministic hash producing a u64 from four index components.
/// Uses a simple multiply-xor-shift chain seeded by `seed`.
#[inline]
fn det_hash(seed: u64, a: u64, b: u64, c: u64) -> u64 {
    // Three identical mixing rounds: multiply by a fixed odd constant,
    // fold in the next operand, then xor-shift to diffuse high bits down.
    // Each round is a bijection of the accumulator for a fixed operand.
    const MIX: [u64; 3] = [
        0x517cc1b727220a95,
        0x4cf5ad432745937f,
        0x6c62272e07bb0142,
    ];
    [a, b, c].iter().zip(MIX.iter()).fold(seed, |h, (&op, &m)| {
        let mixed = h.wrapping_mul(m).wrapping_add(op);
        mixed ^ (mixed >> 33)
    })
}

/// Map a hash to a deterministic f64 in [-1, 1].
#[inline]
fn hash_to_f64(h: u64) -> f64 {
    // Keep the top 53 bits (the f64 mantissa width) for a uniform value in
    // [0, 1); multiplying by the exact power-of-two reciprocal is bit-exact.
    let unit = (h >> 11) as f64 * (1.0 / (1u64 << 53) as f64);
    2.0 * unit - 1.0
}

37/// Deterministic init value for a parameter, small magnitude.
38#[inline]
39fn init_param(seed: u64, layer: u64, row: u64, col: u64, scale: f64) -> f64 {
40    hash_to_f64(det_hash(seed, layer, row, col)) * scale
41}
42
// ---------------------------------------------------------------------------
// Small math helpers
// ---------------------------------------------------------------------------

/// Symmetric clamp of `x` into [-limit, limit]; NaN inputs pass through
/// unchanged (std `clamp` semantics).
#[inline]
fn clip(x: f64, limit: f64) -> f64 {
    f64::clamp(x, -limit, limit)
}

52fn softmax_inplace(xs: &mut [f64]) {
53    let max_v = max_wide(xs);
54    let mut sum = 0.0;
55    for x in xs.iter_mut() {
56        *x = (*x - max_v).exp();
57        sum += *x;
58    }
59    if sum > 0.0 {
60        let inv = 1.0 / sum;
61        for x in xs.iter_mut() {
62            *x *= inv;
63        }
64    }
65}
66
/// Compute log-softmax of `logits` into `out`, flooring each log-probability
/// at `ln(min_prob)` and then renormalizing so the floored distribution sums
/// to one again. Presumably the floor keeps downstream weighting numerically
/// stable when a logit is driven very negative — confirm against callers.
fn log_softmax_with_floor(logits: &[f64], out: &mut [f64], min_prob: f64) {
    // Stable log-partition: log_z = max + ln(sum(exp(l - max))).
    let max_v = max_wide(logits);
    let mut sum = 0.0;
    for &l in logits {
        sum += (l - max_v).exp();
    }
    let log_z = max_v + sum.ln();
    let log_floor = min_prob.ln();
    // First pass: compute raw log-softmax with floor
    let mut log_sum_exp_floor = f64::NEG_INFINITY;
    for (i, &l) in logits.iter().enumerate() {
        let lp = (l - log_z).max(log_floor);
        out[i] = lp;
        // accumulate for renormalization
        // Streaming log-sum-exp: lse(a, b) = max(a, b) + ln(1 + exp(min - max)),
        // split into two branches depending on which operand is the new max.
        if lp > log_sum_exp_floor {
            let diff = log_sum_exp_floor - lp;
            if diff.is_finite() {
                log_sum_exp_floor = lp + (1.0 + diff.exp()).ln();
            } else {
                // Accumulator still -inf (first finite term): lp is the sum.
                log_sum_exp_floor = lp;
            }
        } else {
            let diff = lp - log_sum_exp_floor;
            if diff.is_finite() {
                log_sum_exp_floor += (1.0 + diff.exp()).ln();
            }
        }
    }
    // Renormalize so that sum(exp(out)) = 1
    if log_sum_exp_floor.is_finite() {
        for v in out.iter_mut() {
            *v -= log_sum_exp_floor;
        }
    }
}

// ---------------------------------------------------------------------------
// MLP layer: y = relu(W * x + b) or y = W * x + b
// ---------------------------------------------------------------------------

/// Dense layer parameters (row-major: weights[out_dim * in_dim]).
#[derive(Clone)]
struct DenseLayer {
    // Row-major weight matrix: row r holds the in_dim weights of output r.
    weights: Vec<f64>,
    // One bias per output unit.
    bias: Vec<f64>,
    // Momentum accumulators; only touched when sgd_update gets momentum != 0.
    vel_weights: Vec<f64>,
    vel_bias: Vec<f64>,
    in_dim: usize,
    out_dim: usize,
}

impl DenseLayer {
    /// Allocate a zeroed layer; weights are populated later via `init`.
    fn new(in_dim: usize, out_dim: usize) -> Self {
        Self {
            weights: vec![0.0; out_dim * in_dim],
            bias: vec![0.0; out_dim],
            vel_weights: vec![0.0; out_dim * in_dim],
            vel_bias: vec![0.0; out_dim],
            in_dim,
            out_dim,
        }
    }

    /// Deterministic init: each weight from (seed, layer_id, row, col) via
    /// `init_param`; biases reset to zero.
    fn init(&mut self, seed: u64, layer_id: u64, scale: f64) {
        for r in 0..self.out_dim {
            for c in 0..self.in_dim {
                self.weights[r * self.in_dim + c] =
                    init_param(seed, layer_id, r as u64, c as u64, scale);
            }
            self.bias[r] = 0.0;
        }
    }

    /// Forward: out = W * x + b (no activation).
    fn forward(&self, x: &[f64], out: &mut [f64]) {
        debug_assert!(x.len() >= self.in_dim);
        debug_assert!(out.len() >= self.out_dim);
        for (r, slot) in out.iter_mut().enumerate().take(self.out_dim) {
            let row = &self.weights[r * self.in_dim..(r + 1) * self.in_dim];
            *slot = dot_wide(row, &x[..self.in_dim]) + self.bias[r];
        }
    }

    /// Forward with ReLU: out = max(0, W * x + b).
    fn forward_relu(&self, x: &[f64], out: &mut [f64]) {
        self.forward(x, out);
        for v in out[..self.out_dim].iter_mut() {
            *v = v.max(0.0);
        }
    }

    /// SGD update: weights -= lr * grad_out ⊗ x, bias -= lr * grad_out.
    /// Each output gradient is clipped to ±grad_clip before use; the
    /// parameters themselves are NOT clipped here.
    fn sgd_update(&mut self, grad_out: &[f64], x: &[f64], lr: f64, grad_clip: f64, momentum: f64) {
        // Fast path: plain SGD; velocity buffers are left untouched.
        if momentum == 0.0 {
            for (r, &grad) in grad_out.iter().enumerate().take(self.out_dim) {
                let g = clip(grad, grad_clip);
                let row = &mut self.weights[r * self.in_dim..(r + 1) * self.in_dim];
                axpy_wide(row, -lr * g, &x[..self.in_dim]);
                self.bias[r] -= lr * g;
            }
            return;
        }

        // Momentum path: v = momentum * v + g; w -= lr * v (heavy-ball form).
        for (r, &grad) in grad_out.iter().enumerate().take(self.out_dim) {
            let g = clip(grad, grad_clip);
            for (c, &x_c) in x.iter().enumerate().take(self.in_dim) {
                let idx = r * self.in_dim + c;
                let grad_w = g * x_c;
                self.vel_weights[idx] = momentum * self.vel_weights[idx] + grad_w;
                self.weights[idx] -= lr * self.vel_weights[idx];
            }
            self.vel_bias[r] = momentum * self.vel_bias[r] + g;
            self.bias[r] -= lr * self.vel_bias[r];
        }
    }
}

// ---------------------------------------------------------------------------
// Per-cell selector + rules
// ---------------------------------------------------------------------------

/// Selector MLP for one cell: maps input → rule gate probabilities.
/// The gate layer emits raw logits; the caller softmaxes them (see `forward`).
#[derive(Clone)]
struct CellSelector {
    hidden: DenseLayer, // in: 5*cell_dim → selector_hidden (relu)
    gate: DenseLayer,   // in: selector_hidden → num_rules (softmax)
}

/// One rule MLP for one cell: maps (input, noise) → cell delta.
/// The linear output is gate-weighted and added to the cell state.
#[derive(Clone)]
struct CellRule {
    hidden: DenseLayer, // in: 5*cell_dim + noise_dim → rule_hidden (relu)
    output: DenseLayer, // in: rule_hidden → cell_dim (linear)
}

/// All selector + rule params for one cell.
#[derive(Clone)]
struct CellParams {
    selector: CellSelector,
    // One CellRule per rule slot; gated by the selector's softmax output.
    rules: Vec<CellRule>,
}

// ---------------------------------------------------------------------------
// ParticleModel: shared parameters across the ensemble
// ---------------------------------------------------------------------------

/// The neural model parameters shared by all particles (cloned per-particle
/// for independent SGD, but initialized identically).
#[derive(Clone)]
struct ParticleModel {
    /// Byte embedding table: [256][cell_dim].
    embed: Vec<f64>,
    /// Per-cell selector + rules.
    cells: Vec<CellParams>,
    /// Readout layer: phi_dim → 256.
    readout: DenseLayer,
    /// Spec dimensions (cached).
    cell_dim: usize,
    num_cells: usize,
    noise_dim: usize,
    phi_dim: usize, // 5 * cell_dim (mean, max, stddev, second_last_emb, last_emb)
    selector_in_dim: usize, // 5 * cell_dim
}

impl ParticleModel {
    /// Allocate all layers with dimensions derived from `spec`.
    /// Parameters are zero until `init` is called.
    fn new(spec: &ParticleSpec) -> Self {
        let cell_dim = spec.cell_dim;
        let selector_in_dim = 5 * cell_dim;
        let rule_in_dim = 5 * cell_dim + spec.noise_dim;
        // phi = [mean_cells, max_cells, stddev_cells, second_last_emb, last_emb]
        // (layout defined by build_phi).
        // The last_emb component is the direct embedding of the most recent byte.
        // This skip connection gives the readout immediate access to the nearest
        // predecessor without waiting for the slow cell warm-up.
        let phi_dim = 5 * cell_dim;

        let embed = vec![0.0; 256 * cell_dim];
        let cells = (0..spec.num_cells)
            .map(|_| CellParams {
                selector: CellSelector {
                    hidden: DenseLayer::new(selector_in_dim, spec.selector_hidden),
                    gate: DenseLayer::new(spec.selector_hidden, spec.num_rules),
                },
                rules: (0..spec.num_rules)
                    .map(|_| CellRule {
                        hidden: DenseLayer::new(rule_in_dim, spec.rule_hidden),
                        output: DenseLayer::new(spec.rule_hidden, cell_dim),
                    })
                    .collect(),
            })
            .collect();
        let readout = DenseLayer::new(phi_dim, 256);

        Self {
            embed,
            cells,
            readout,
            cell_dim,
            num_cells: spec.num_cells,
            noise_dim: spec.noise_dim,
            phi_dim,
            selector_in_dim,
        }
    }

    /// Deterministically initialize all parameters from `seed`.
    fn init(&mut self, seed: u64, spec: &ParticleSpec) {
        let scale = 0.1;
        // Use a larger scale for the embedding table so that the ctx and last_emb
        // components of phi carry significant signal from the first byte onward.
        // With scale=0.1 the effective cell-update magnitude is ~0.0005/byte,
        // meaning cells only become meaningful after thousands of bytes.
        // With embed_scale=0.3 the cell dynamics warm up ~3× faster.
        let embed_scale = 0.3;
        // Embedding table
        for i in 0..256 {
            for j in 0..self.cell_dim {
                self.embed[i * self.cell_dim + j] =
                    init_param(seed, 0, i as u64, j as u64, embed_scale);
            }
        }
        // Per-cell params
        // NOTE(review): the layer-id scheme can collide across cells: cell 11's
        // selector.hidden id (11*100 + 1 = 1101) equals cell 1 / rule 0's hidden
        // id ((1*100 + 10 + 0)*10 + 1 = 1101), so overlapping (row, col) slots
        // of those two layers start with identical weights. Confirm intended.
        for (ci, cp) in self.cells.iter_mut().enumerate() {
            let cell_seed = ci as u64 + 1;
            cp.selector.hidden.init(seed, cell_seed * 100 + 1, scale);
            cp.selector
                .gate
                .init(seed, cell_seed * 100 + 2, scale * 0.1);
            for (ri, rule) in cp.rules.iter_mut().enumerate() {
                let r_off = cell_seed * 100 + 10 + ri as u64;
                rule.hidden.init(seed, r_off * 10 + 1, scale);
                rule.output.init(seed, r_off * 10 + 2, scale * 0.5);
            }
        }
        // Readout: small init
        self.readout.init(seed, 9999, scale * 0.1);
        let _ = spec; // spec used for future extensions
    }
}

// ---------------------------------------------------------------------------
// ParticleState: per-particle latent state
// ---------------------------------------------------------------------------

/// Per-particle latent state plus a private copy of the model parameters.
#[derive(Clone)]
struct ParticleState {
    /// Stable particle id for deterministic hash-noise generation.
    particle_id: u64,
    /// Latent cell values: [num_cells * cell_dim].
    cells: Vec<f64>,
    /// Context ring buffer (stores raw byte values).
    context: Vec<u8>,
    /// Write position in ring buffer.
    ctx_pos: usize,
    /// Number of bytes seen (for context length tracking).
    ctx_len: usize,
    /// Per-particle model parameters (for independent SGD).
    model: ParticleModel,
    /// Cached 256-way log-probabilities from last forward pass.
    cached_log_probs: [f64; 256],
    /// Whether cached_log_probs is valid.
    cache_valid: bool,
    // Scratch buffers (reused across forward passes to avoid allocation).
    scratch_ctx: Vec<f64>,        // decay-pooled context embedding [cell_dim]
    scratch_mean_cells: Vec<f64>, // mean over cells [cell_dim]
    scratch_p: Vec<f64>,          // selector input [5 * cell_dim]
    scratch_sel_h: Vec<f64>,      // selector hidden activations
    scratch_gate: Vec<f64>,       // softmaxed rule gates
    scratch_rule_in: Vec<f64>,    // rule input (p + noise section)
    scratch_rule_h: Vec<f64>,     // rule hidden activations
    scratch_delta_k: Vec<f64>,    // single-rule output delta
    scratch_delta: Vec<f64>,      // gate-weighted cell delta
    scratch_phi: Vec<f64>,        // featurized readout input [phi_dim]
    scratch_logits: Vec<f64>,     // 256-way readout logits
    // Backprop scratch buffers
    scratch_d_logits: Vec<f64>,
    scratch_d_phi: Vec<f64>,
    scratch_softmax: Vec<f64>,
    scratch_d_rule_out: Vec<f64>,
    scratch_d_rule_h: Vec<f64>,
    scratch_d_gate: Vec<f64>,
    scratch_d_gate_logits: Vec<f64>,
    scratch_d_sel_h: Vec<f64>,
    // Recent forward traces for truncated selector/rule backprop;
    // capped at bptt_depth entries by `forward`.
    trace_history: VecDeque<StepTrace>,
}

/// Forward-pass activations of one rule MLP, kept for the backprop pass.
#[derive(Clone)]
struct RuleTrace {
    rule_h: Vec<f64>,   // post-ReLU hidden activations
    rule_out: Vec<f64>, // linear output (cell delta before gate weighting)
}

/// Forward-pass activations of one cell's selector and rules for one update.
#[derive(Clone)]
struct CellTrace {
    p: Vec<f64>,       // selector input (cell, neighbours, ctx, mean)
    sel_h: Vec<f64>,   // selector hidden activations (post-ReLU)
    gate: Vec<f64>,    // softmaxed rule gate probabilities
    rule_in: Vec<f64>, // rule input (p plus noise section)
    rules: Vec<RuleTrace>,
}

/// Per-step trace of cell updates, queued for truncated-BPTT-style SGD.
#[derive(Clone)]
struct StepTrace {
    cells: Vec<CellTrace>,
}

372impl ParticleState {
    /// Create a fresh particle: zeroed latent cells, empty context ring, and
    /// scratch buffers preallocated from `spec`/`model` dimensions.
    ///
    /// NOTE(review): `context` has length `spec.context_window`; `build_phi`
    /// computes `% context.len()` whenever `ctx_len >= 1`, so a zero-length
    /// window would divide by zero — confirm context_window >= 1 upstream.
    fn new(spec: &ParticleSpec, model: ParticleModel, particle_id: u64) -> Self {
        let cd = spec.cell_dim;
        let nc = spec.num_cells;
        let sel_in = 5 * cd;
        let rule_in = 5 * cd + spec.noise_dim;
        let phi_dim = model.phi_dim; // 5 * cell_dim
        Self {
            particle_id,
            cells: vec![0.0; nc * cd],
            context: vec![0; spec.context_window],
            ctx_pos: 0,
            ctx_len: 0,
            model,
            cached_log_probs: [0.0; 256],
            cache_valid: false,
            scratch_ctx: vec![0.0; cd],
            scratch_mean_cells: vec![0.0; cd],
            scratch_p: vec![0.0; sel_in],
            scratch_sel_h: vec![0.0; spec.selector_hidden],
            scratch_gate: vec![0.0; spec.num_rules],
            scratch_rule_in: vec![0.0; rule_in],
            scratch_rule_h: vec![0.0; spec.rule_hidden],
            scratch_delta_k: vec![0.0; cd],
            scratch_delta: vec![0.0; cd],
            scratch_phi: vec![0.0; phi_dim],
            scratch_logits: vec![0.0; 256],
            scratch_d_logits: vec![0.0; 256],
            scratch_d_phi: vec![0.0; phi_dim],
            scratch_softmax: vec![0.0; 256],
            scratch_d_rule_out: vec![0.0; cd],
            scratch_d_rule_h: vec![0.0; spec.rule_hidden],
            scratch_d_gate: vec![0.0; spec.num_rules],
            scratch_d_gate_logits: vec![0.0; spec.num_rules],
            scratch_d_sel_h: vec![0.0; spec.selector_hidden],
            trace_history: VecDeque::with_capacity(spec.bptt_depth.max(1)),
        }
    }

411    /// Build context vector by mean-pooling byte embeddings from the ring buffer.
412    fn build_ctx(&mut self) {
413        let cd = self.model.cell_dim;
414        self.scratch_ctx.iter_mut().for_each(|v| *v = 0.0);
415        let len = self.ctx_len.min(self.context.len());
416        if len == 0 {
417            return;
418        }
419        let cw = self.context.len();
420        // Exponential-decay pooling preserves order information by emphasizing
421        // recent bytes while keeping fixed-size context.
422        let decay = 0.90_f64;
423        let mut weight_sum = 0.0_f64;
424        let mut w = 1.0_f64;
425        for age in 0..len {
426            let pos = (self.ctx_pos + cw - 1 - age) % cw;
427            let byte = self.context[pos] as usize;
428            let emb = &self.model.embed[byte * cd..(byte + 1) * cd];
429            weight_sum += w;
430            for (ctx, &emb_j) in self.scratch_ctx.iter_mut().zip(emb.iter()) {
431                *ctx += emb_j * w;
432            }
433            w *= decay;
434        }
435        if weight_sum > 0.0 {
436            let inv = 1.0 / weight_sum;
437            for v in &mut self.scratch_ctx {
438                *v *= inv;
439            }
440        }
441    }
442
443    /// Compute mean of all cell vectors.
444    fn compute_mean_cells(&mut self) {
445        let cd = self.model.cell_dim;
446        let nc = self.model.num_cells;
447        self.scratch_mean_cells.iter_mut().for_each(|v| *v = 0.0);
448        if nc == 0 {
449            return;
450        }
451        let inv = 1.0 / nc as f64;
452        for ci in 0..nc {
453            let off = ci * cd;
454            for j in 0..cd {
455                self.scratch_mean_cells[j] += self.cells[off + j] * inv;
456            }
457        }
458    }
459
460    /// Build selector input p = concat(cell_i, left, right, ctx, mean_cells).
461    fn build_selector_input(&mut self, cell_idx: usize) {
462        let cd = self.model.cell_dim;
463        let nc = self.model.num_cells.max(1);
464        let off = cell_idx * cd;
465        let left_idx = if nc <= 1 {
466            cell_idx
467        } else {
468            (cell_idx + nc - 1) % nc
469        };
470        let right_idx = if nc <= 1 {
471            cell_idx
472        } else {
473            (cell_idx + 1) % nc
474        };
475        let left_off = left_idx * cd;
476        let right_off = right_idx * cd;
477        self.scratch_p[..cd].copy_from_slice(&self.cells[off..off + cd]);
478        self.scratch_p[cd..2 * cd].copy_from_slice(&self.cells[left_off..left_off + cd]);
479        self.scratch_p[2 * cd..3 * cd].copy_from_slice(&self.cells[right_off..right_off + cd]);
480        self.scratch_p[3 * cd..4 * cd].copy_from_slice(&self.scratch_ctx[..cd]);
481        self.scratch_p[4 * cd..5 * cd].copy_from_slice(&self.scratch_mean_cells[..cd]);
482    }
483
484    /// Build rule input = concat(p, z) with deterministic hash-noise annealing.
485    fn build_rule_input(
486        &mut self,
487        spec: &ParticleSpec,
488        step_idx: u64,
489        unroll_idx: usize,
490        cell_idx: usize,
491    ) {
492        let sel_in = self.model.selector_in_dim;
493        let nd = self.model.noise_dim;
494        self.scratch_rule_in[..sel_in].copy_from_slice(&self.scratch_p[..sel_in]);
495        if nd == 0 || !spec.enable_noise || spec.noise_scale <= 0.0 {
496            for j in sel_in..sel_in + nd {
497                self.scratch_rule_in[j] = 0.0;
498            }
499            return;
500        }
501        let anneal = if spec.noise_anneal_steps == 0 {
502            1.0
503        } else {
504            let rem = spec.noise_anneal_steps.saturating_sub(step_idx as usize) as f64;
505            rem / spec.noise_anneal_steps as f64
506        };
507        let scale = spec.noise_scale * anneal.max(0.0);
508        for j in 0..nd {
509            let h = det_hash(
510                spec.seed ^ self.particle_id,
511                step_idx,
512                ((unroll_idx as u64) << 40) ^ ((cell_idx as u64) << 20) ^ j as u64,
513                0xD1A6_51EED,
514            );
515            self.scratch_rule_in[sel_in + j] = hash_to_f64(h) * scale;
516        }
517    }
518
519    /// Build featurize vector phi = concat(mean_cells, max_cells, stddev_cells, second_last_emb, last_emb).
520    ///
521    /// Components:
522    ///   [0..cd]     mean_cells      — long-range context via latent cells
523    ///   [cd..2cd]   max_cells       — long-range context peak
524    ///   [2cd..3cd]  stddev_cells    — latent uncertainty
525    ///   [3cd..4cd]  second_last_emb — direct embedding of x_{t-2} (skip connection)
526    ///   [4cd..5cd]  last_emb        — direct embedding of x_{t-1} (skip connection)
527    ///
528    /// The two skip connections give the readout immediate access to the exact
529    /// preceding two bytes, enabling bigram/trigram statistics to be learned
530    /// from the very first observation.  We no longer include the blurred
531    /// exponential-decay context average as a phi component; the selector/rule
532    /// cell dynamics still use it (via scratch_ctx), but for prediction it is
533    /// dominated by the direct embeddings.
534    fn build_phi(&mut self) {
535        let cd = self.model.cell_dim;
536        let nc = self.model.num_cells;
537        // mean_cells already in scratch_mean_cells
538        self.scratch_phi[..cd].copy_from_slice(&self.scratch_mean_cells[..cd]);
539        // max_cells
540        for j in 0..cd {
541            let mut mx = f64::NEG_INFINITY;
542            for ci in 0..nc {
543                let v = self.cells[ci * cd + j];
544                if v > mx {
545                    mx = v;
546                }
547            }
548            self.scratch_phi[cd + j] = if mx.is_finite() { mx } else { 0.0 };
549        }
550        // stddev across cells (captures latent uncertainty)
551        for j in 0..cd {
552            let mean = self.scratch_mean_cells[j];
553            let mut var = 0.0_f64;
554            for ci in 0..nc {
555                let d = self.cells[ci * cd + j] - mean;
556                var += d * d;
557            }
558            self.scratch_phi[2 * cd + j] = (var / nc.max(1) as f64).sqrt();
559        }
560        // second_last_emb: direct embedding of x_{t-2}.
561        let cw = self.context.len();
562        if self.ctx_len >= 2 {
563            // ctx_pos points to the NEXT write slot; walk back 2 bytes.
564            let pos2 = (self.ctx_pos + cw - 2) % cw;
565            let byte2 = self.context[pos2] as usize;
566            self.scratch_phi[3 * cd..4 * cd]
567                .copy_from_slice(&self.model.embed[byte2 * cd..(byte2 + 1) * cd]);
568        } else {
569            self.scratch_phi[3 * cd..4 * cd].fill(0.0);
570        }
571        // last_emb: direct embedding of x_{t-1}.
572        if self.ctx_len >= 1 {
573            let pos1 = (self.ctx_pos + cw - 1) % cw;
574            let byte1 = self.context[pos1] as usize;
575            self.scratch_phi[4 * cd..5 * cd]
576                .copy_from_slice(&self.model.embed[byte1 * cd..(byte1 + 1) * cd]);
577        } else {
578            self.scratch_phi[4 * cd..5 * cd].fill(0.0);
579        }
580    }
581
    /// Full forward pass: update latent cells, compute log-probabilities.
    ///
    /// Fills `cached_log_probs` (floored, renormalized log-softmax) and
    /// `scratch_softmax` (exact softmax, reused later by `sgd_update`), and
    /// appends a StepTrace when selector/rule learning is enabled.
    fn forward(&mut self, spec: &ParticleSpec, step_idx: u64) {
        self.build_ctx();
        self.compute_mean_cells();
        // Traces are only needed when selector/rule SGD will consume them.
        let capture_trace = spec.learning_rate_selector > 0.0 || spec.learning_rate_rule > 0.0;
        let mut step_trace = if capture_trace {
            Some(StepTrace {
                cells: Vec::with_capacity(self.model.num_cells),
            })
        } else {
            None
        };

        // Latent update (unroll_steps iterations)
        for unroll_idx in 0..spec.unroll_steps {
            for ci in 0..self.model.num_cells {
                self.build_selector_input(ci);

                // Selector: hidden = relu(W_sel * p + b_sel)
                self.model.cells[ci]
                    .selector
                    .hidden
                    .forward_relu(&self.scratch_p, &mut self.scratch_sel_h);
                // Gate: gate_logits = V_sel * h + c_sel, then softmax
                self.model.cells[ci]
                    .selector
                    .gate
                    .forward(&self.scratch_sel_h, &mut self.scratch_gate);
                softmax_inplace(&mut self.scratch_gate[..spec.num_rules]);

                // Build rule input
                self.build_rule_input(spec, step_idx, unroll_idx, ci);

                // Compute weighted delta
                let cd = self.model.cell_dim;
                self.scratch_delta[..cd].fill(0.0);
                let mut rule_traces = if capture_trace {
                    Some(Vec::with_capacity(spec.num_rules))
                } else {
                    None
                };
                for ki in 0..spec.num_rules {
                    let gate_k = self.scratch_gate[ki];
                    // Rule hidden
                    self.model.cells[ci].rules[ki]
                        .hidden
                        .forward_relu(&self.scratch_rule_in, &mut self.scratch_rule_h);
                    // Rule output
                    self.model.cells[ci].rules[ki]
                        .output
                        .forward(&self.scratch_rule_h, &mut self.scratch_delta_k);
                    if let Some(rt) = &mut rule_traces {
                        rt.push(RuleTrace {
                            rule_h: self.scratch_rule_h.clone(),
                            rule_out: self.scratch_delta_k[..cd].to_vec(),
                        });
                    }
                    // Accumulate gate-weighted delta.
                    for j in 0..cd {
                        self.scratch_delta[j] += gate_k * self.scratch_delta_k[j];
                    }
                }
                // NOTE(review): num_cells CellTraces are pushed per unroll
                // iteration, so with unroll_steps > 1 this StepTrace holds
                // unroll_steps * num_cells entries, while the consumer
                // (apply_selector_rule_update_from_trace) only reads the first
                // num_cells — i.e. only the FIRST unroll pass — confirm intended.
                if let (Some(st), Some(rt)) = (&mut step_trace, rule_traces) {
                    st.cells.push(CellTrace {
                        p: self.scratch_p.clone(),
                        sel_h: self.scratch_sel_h.clone(),
                        gate: self.scratch_gate[..spec.num_rules].to_vec(),
                        rule_in: self.scratch_rule_in.clone(),
                        rules: rt,
                    });
                }

                // Update cell, clamped to ±state_clip to keep dynamics bounded.
                let off = ci * cd;
                for j in 0..cd {
                    self.cells[off + j] =
                        clip(self.cells[off + j] + self.scratch_delta[j], spec.state_clip);
                }
            }
            // Recompute mean_cells after each unroll step
            self.compute_mean_cells();
        }

        // Featurize
        self.build_phi();

        // Readout
        self.model
            .readout
            .forward(&self.scratch_phi, &mut self.scratch_logits);

        // Reuse exact softmax(logits) in SGD to avoid recomputing it.
        self.scratch_softmax.copy_from_slice(&self.scratch_logits);
        softmax_inplace(&mut self.scratch_softmax);

        // Log-softmax with floor
        log_softmax_with_floor(
            &self.scratch_logits,
            &mut self.cached_log_probs,
            spec.min_prob,
        );
        if let Some(st) = step_trace {
            self.trace_history.push_back(st);
            // Bound history to the truncated-backprop depth.
            while self.trace_history.len() > spec.bptt_depth.max(1) {
                self.trace_history.pop_front();
            }
        }
        self.cache_valid = true;
    }

    /// Apply one SGD step to every cell's selector and rule MLPs using the
    /// cached forward activations in `trace`.
    ///
    /// The gradient flows only through the mean-cells component of phi
    /// (`d_phi[..cell_dim]`, scaled by 1/num_cells) — the max/stddev phi paths
    /// carry no gradient here. `temporal` down-weights older traces.
    fn apply_selector_rule_update_from_trace(
        &mut self,
        trace: &StepTrace,
        d_phi: &[f64],
        temporal: f64,
        spec: &ParticleSpec,
    ) {
        let cd = self.model.cell_dim;
        let nc = self.model.num_cells.max(1);
        // Each cell contributes 1/nc to mean_cells; temporal decays with age.
        let d_delta_scale = (1.0 / nc as f64) * temporal;

        for ci in 0..nc.min(trace.cells.len()) {
            let ct = &trace.cells[ci];

            // d_gate[k] = dot(d_delta, rule_out_k)
            self.scratch_d_gate[..spec.num_rules].fill(0.0);

            for ki in 0..spec.num_rules.min(ct.rules.len()) {
                let gate_k = ct.gate[ki];
                // d_rule_out = d_phi(mean section) * scale * gate_k.
                self.scratch_d_rule_out[..cd].fill(0.0);
                for (dst, &d_phi_j) in self.scratch_d_rule_out[..cd]
                    .iter_mut()
                    .zip(d_phi.iter().take(cd))
                {
                    *dst = d_phi_j * d_delta_scale * gate_k;
                }

                // output layer update
                self.model.cells[ci].rules[ki].output.sgd_update(
                    &self.scratch_d_rule_out,
                    &ct.rules[ki].rule_h,
                    spec.learning_rate_rule,
                    spec.grad_clip,
                    spec.optimizer_momentum,
                );

                // hidden layer update
                let rh = spec.rule_hidden;
                self.scratch_d_rule_h[..rh].fill(0.0);
                // Backprop through output weights: d_h = W_out^T * d_out.
                for r in 0..cd {
                    let g = clip(self.scratch_d_rule_out[r], spec.grad_clip);
                    if g.abs() < 1e-15 {
                        continue;
                    }
                    for c in 0..rh {
                        self.scratch_d_rule_h[c] +=
                            g * self.model.cells[ci].rules[ki].output.weights[r * rh + c];
                    }
                }
                // ReLU mask: zero gradient where the hidden unit was inactive.
                for (j, h) in ct.rules[ki].rule_h.iter().enumerate().take(rh) {
                    if *h <= 0.0 {
                        self.scratch_d_rule_h[j] = 0.0;
                    }
                }
                self.model.cells[ci].rules[ki].hidden.sgd_update(
                    &self.scratch_d_rule_h[..rh],
                    &ct.rule_in,
                    spec.learning_rate_rule,
                    spec.grad_clip,
                    spec.optimizer_momentum,
                );

                // Accumulate the gate gradient for this rule.
                for (&d_phi_j, &rule_out_j) in d_phi
                    .iter()
                    .take(cd)
                    .zip(ct.rules[ki].rule_out.iter().take(cd))
                {
                    self.scratch_d_gate[ki] += d_phi_j * d_delta_scale * rule_out_j;
                }
            }

            // Softmax backprop: d_logits = gate * (d_gate - <gate, d_gate>).
            let dot_gd: f64 = (0..spec.num_rules.min(ct.gate.len()))
                .map(|k| ct.gate[k] * self.scratch_d_gate[k])
                .sum();
            self.scratch_d_gate_logits[..spec.num_rules].fill(0.0);
            for k in 0..spec.num_rules.min(ct.gate.len()) {
                self.scratch_d_gate_logits[k] = ct.gate[k] * (self.scratch_d_gate[k] - dot_gd);
            }

            self.model.cells[ci].selector.gate.sgd_update(
                &self.scratch_d_gate_logits[..spec.num_rules],
                &ct.sel_h,
                spec.learning_rate_selector,
                spec.grad_clip,
                spec.optimizer_momentum,
            );

            // Backprop to selector hidden: d_h = W_gate^T * d_gate_logits.
            let sh = spec.selector_hidden;
            self.scratch_d_sel_h[..sh].fill(0.0);
            for r in 0..spec.num_rules.min(ct.gate.len()) {
                let g = clip(self.scratch_d_gate_logits[r], spec.grad_clip);
                if g.abs() < 1e-15 {
                    continue;
                }
                for c in 0..sh {
                    self.scratch_d_sel_h[c] +=
                        g * self.model.cells[ci].selector.gate.weights[r * sh + c];
                }
            }
            // ReLU mask on the selector hidden layer.
            for (j, h) in ct.sel_h.iter().enumerate().take(sh) {
                if *h <= 0.0 {
                    self.scratch_d_sel_h[j] = 0.0;
                }
            }
            self.model.cells[ci].selector.hidden.sgd_update(
                &self.scratch_d_sel_h[..sh],
                &ct.p,
                spec.learning_rate_selector,
                spec.grad_clip,
                spec.optimizer_momentum,
            );
        }
    }

    /// Online SGD update after observing byte `y`.
    ///
    /// Pipeline: (1) cross-entropy gradient at the readout logits,
    /// (2) readout weight/bias update, (3) backprop of the gradient onto the
    /// feature vector `phi`, (4) truncated-BPTT updates of the selector and
    /// rule MLPs over the most recent stored forward traces.
    fn sgd_update(&mut self, y: u8, spec: &ParticleSpec) {
        // Readout gradient: d_logits = softmax - onehot(y)
        self.scratch_d_logits.copy_from_slice(&self.scratch_softmax);
        self.scratch_d_logits[y as usize] -= 1.0;

        // Clip readout gradients
        for v in self.scratch_d_logits.iter_mut() {
            *v = clip(*v, spec.grad_clip);
        }

        // Update readout: W -= lr * d_logits ⊗ phi, b -= lr * d_logits
        // (momentum argument 0.0: the readout uses plain SGD).
        self.model.readout.sgd_update(
            &self.scratch_d_logits,
            &self.scratch_phi,
            spec.learning_rate_readout,
            spec.grad_clip,
            0.0,
        );

        // Backprop to phi: d_phi = readout.W^T * d_logits
        let phi_dim = self.model.phi_dim;
        self.scratch_d_phi[..phi_dim].fill(0.0);
        for r in 0..256 {
            let g = clip(self.scratch_d_logits[r], spec.grad_clip);
            // Skip near-zero rows: saves the inner loop for symbols whose
            // gradient cannot meaningfully move phi.
            if g.abs() < 1e-15 {
                continue;
            }
            let row_start = r * phi_dim;
            for c in 0..phi_dim {
                self.scratch_d_phi[c] += g * self.model.readout.weights[row_start + c];
            }
        }

        // Clip d_phi
        for v in self.scratch_d_phi[..phi_dim].iter_mut() {
            *v = clip(*v, spec.grad_clip);
        }

        if spec.learning_rate_selector > 0.0 || spec.learning_rate_rule > 0.0 {
            // Truncated BPTT over up to `bptt_depth` most-recent traces,
            // geometrically down-weighting older steps via `temporal`.
            let depth = spec.bptt_depth.max(1).min(self.trace_history.len());
            // Move the history out of `self` so the per-trace update can take
            // `&mut self` while a trace is borrowed; restored below.
            let traces = std::mem::take(&mut self.trace_history);
            let d_phi = self.scratch_d_phi[..phi_dim].to_vec();
            let mut temporal = 1.0_f64;
            let temporal_decay = 0.7_f64;
            for idx in 0..depth {
                let hist_idx = traces.len() - 1 - idx;
                let trace = &traces[hist_idx];
                self.apply_selector_rule_update_from_trace(trace, &d_phi, temporal, spec);
                temporal *= temporal_decay;
            }
            self.trace_history = traces;
        }
    }
859
860    /// Push a byte into the context ring buffer.
861    fn push_context(&mut self, byte: u8) {
862        self.context[self.ctx_pos] = byte;
863        self.ctx_pos = (self.ctx_pos + 1) % self.context.len();
864        self.ctx_len += 1;
865    }
866
    /// Clear all stream-local state (latent cells, context ring, prediction
    /// cache, forward/backward scratch buffers, BPTT traces) while leaving
    /// the learned model parameters untouched.
    fn reset_dynamic_state(&mut self) {
        // Latent cell state and context ring buffer.
        self.cells.fill(0.0);
        self.context.fill(0);
        self.ctx_pos = 0;
        self.ctx_len = 0;
        // Prediction cache.
        self.cached_log_probs.fill(0.0);
        self.cache_valid = false;
        // Forward-pass scratch buffers.
        self.scratch_ctx.fill(0.0);
        self.scratch_mean_cells.fill(0.0);
        self.scratch_p.fill(0.0);
        self.scratch_sel_h.fill(0.0);
        self.scratch_gate.fill(0.0);
        self.scratch_rule_in.fill(0.0);
        self.scratch_rule_h.fill(0.0);
        self.scratch_delta_k.fill(0.0);
        self.scratch_delta.fill(0.0);
        self.scratch_phi.fill(0.0);
        self.scratch_logits.fill(0.0);
        // Backward-pass scratch buffers.
        self.scratch_d_logits.fill(0.0);
        self.scratch_d_phi.fill(0.0);
        self.scratch_softmax.fill(0.0);
        self.scratch_d_rule_out.fill(0.0);
        self.scratch_d_rule_h.fill(0.0);
        self.scratch_d_gate.fill(0.0);
        self.scratch_d_gate_logits.fill(0.0);
        self.scratch_d_sel_h.fill(0.0);
        // Stored traces for truncated BPTT.
        self.trace_history.clear();
    }
895}
896
897// ---------------------------------------------------------------------------
898// ParticleRuntime: the ensemble predictor (public API)
899// ---------------------------------------------------------------------------
900
/// Runtime for the particle-latent filter ensemble.
///
/// Implements [`crate::mixture::OnlineBytePredictor`] and provides
/// `pdf_next()` for compression compatibility.
pub struct ParticleRuntime {
    /// Configuration shared by all particles.
    spec: ParticleSpec,
    /// Ensemble members; each carries its own model and dynamic state.
    particles: Vec<ParticleState>,
    /// Posterior log-weights, one per particle (kept normalized after updates).
    log_weights: Vec<f64>,
    /// Cached mixture log-probabilities [256].
    mix_log_probs: [f64; 256],
    /// Cached mixture PDF [256].
    mix_pdf: Vec<f64>,
    /// Whether mix_log_probs is valid.
    cache_valid: bool,
    /// Step counter for deterministic hash.
    step_idx: u64,
    /// Scratch for logsumexp across particles.
    scratch_lse: Vec<f64>,
}
920
921impl ParticleRuntime {
922    #[inline]
923    fn likelihood_beta(&self) -> f64 {
924        // Early temperature prevents particle-weight collapse before experts
925        // have adapted. Anneal to full Bayes update.
926        const BETA_MIN: f64 = 0.35;
927        const WARMUP_STEPS: u64 = 2048;
928        if self.step_idx >= WARMUP_STEPS {
929            1.0
930        } else {
931            BETA_MIN + (1.0 - BETA_MIN) * (self.step_idx as f64 / WARMUP_STEPS as f64)
932        }
933    }
934
935    #[inline]
936    fn diagnostics_enabled(&self) -> bool {
937        self.spec.diagnostics_interval > 0
938            && self
939                .step_idx
940                .is_multiple_of(self.spec.diagnostics_interval as u64)
941    }
942
943    #[inline]
944    fn weight_stats(&self) -> (f64, f64) {
945        let mut sum_sq = 0.0;
946        let mut max_w = 0.0;
947        for &lw in &self.log_weights {
948            let w = lw.exp();
949            sum_sq += w * w;
950            if w > max_w {
951                max_w = w;
952            }
953        }
954        let n_eff = if sum_sq > 0.0 { 1.0 / sum_sq } else { 0.0 };
955        (n_eff, max_w)
956    }
957
    /// Ensemble disagreement: D = Σ_i α_i KL(p_i || p_mix), α = softmax(log_weights).
    ///
    /// Zero means every particle predicts the same distribution (collapse).
    /// Takes `&self`, so it recomputes the mixture into local buffers instead
    /// of reusing the runtime scratch; it only runs on the diagnostics path,
    /// where the extra allocation is acceptable.
    fn weighted_prediction_kl_divergence(&self) -> f64 {
        // D = Σ_i α_i KL(p_i || p_mix), where α_i = softmax(log_weights).
        // This is the ensemble disagreement signal: zero means collapse.
        let n = self.particles.len();
        if n == 0 {
            return 0.0;
        }
        let log_z = logsumexp_wide(&self.log_weights);
        let mut mix_log_probs = [0.0_f64; 256];
        let mut scratch_lse = vec![0.0_f64; n];
        for (v, mix_logp) in mix_log_probs.iter_mut().enumerate() {
            for (slot, (log_weight, particle)) in scratch_lse
                .iter_mut()
                .zip(self.log_weights.iter().zip(self.particles.iter()))
            {
                *slot = *log_weight + particle.cached_log_probs[v];
            }
            *mix_logp = logsumexp_wide(&scratch_lse) - log_z;
        }
        let mut d = 0.0_f64;
        for (i, p) in self.particles.iter().enumerate() {
            let alpha = self.log_weights[i].exp();
            if alpha <= 0.0 {
                continue;
            }
            let mut kl_i = 0.0_f64;
            for (&lp_i, &mix_logp) in p.cached_log_probs.iter().zip(mix_log_probs.iter()) {
                let prob_i = lp_i.exp();
                kl_i += prob_i * (lp_i - mix_logp);
            }
            // max(0.0) guards against tiny negative KL from rounding error.
            d += alpha * kl_i.max(0.0);
        }
        d
    }
992
993    fn log_diagnostics(
994        &self,
995        n_eff: f64,
996        max_weight: f64,
997        divergence: f64,
998        beta: f64,
999        will_resample: bool,
1000    ) {
1001        eprintln!(
1002            "[particle] step={} neff={:.3}/{:.0} max_w={:.3}% div_kl={:.6} beta={:.3} resample={}",
1003            self.step_idx,
1004            n_eff,
1005            self.particles.len() as f64,
1006            max_weight * 100.0,
1007            divergence,
1008            beta,
1009            will_resample
1010        );
1011    }
1012
1013    fn diversify_initial_particles(&mut self) {
1014        let scale = 5e-3_f64;
1015        if self.particles.len() <= 1 {
1016            return;
1017        }
1018        for pi in 1..self.particles.len() {
1019            let p = &mut self.particles[pi];
1020            for (idx, v) in p.cells.iter_mut().enumerate() {
1021                let noise = hash_to_f64(det_hash(self.spec.seed, pi as u64, idx as u64, 1000));
1022                *v += noise * scale;
1023            }
1024            for (idx, v) in p.model.readout.bias.iter_mut().enumerate() {
1025                let noise = hash_to_f64(det_hash(self.spec.seed, pi as u64, idx as u64, 1001));
1026                *v += noise * scale;
1027            }
1028            for (idx, v) in p.model.readout.weights.iter_mut().enumerate() {
1029                let noise = hash_to_f64(det_hash(self.spec.seed, pi as u64, idx as u64, 1002));
1030                *v += noise * (scale * 0.5);
1031            }
1032        }
1033    }
1034
1035    /// Create a new particle runtime from a spec.
1036    pub fn new(spec: &ParticleSpec) -> Self {
1037        let n = spec.num_particles;
1038
1039        // Each particle is independently initialized with its own seed derived
1040        // from the master seed via a multiplicative hash spread.  This ensures
1041        // genuine parameter diversity: every embedding table, selector MLP, rule
1042        // MLP, and readout starts from a distinct random point in weight space.
1043        //
1044        // Without this, all particles are near-identical clones (only differing
1045        // by the tiny 5e-3 perturbation in diversify_initial_particles), and with
1046        // identical selector/rule MLPs they evolve identically regardless of
1047        // learning rate.  The ensemble then collapses to a single-particle model,
1048        // explaining why changing num_particles has no effect on compression.
1049        //
1050        // Using φ-like multiplicative spread (0x9e3779b9 ≈ 2^32/φ) ensures that
1051        // seeds for different particle indices are maximally well-separated in the
1052        // 64-bit hash space.
1053        let particles: Vec<ParticleState> = (0..n)
1054            .map(|pi| {
1055                let particle_seed = spec
1056                    .seed
1057                    .wrapping_add((pi as u64).wrapping_mul(0x9e3779b97f4a7c15u64));
1058                let mut model = ParticleModel::new(spec);
1059                model.init(particle_seed, spec);
1060                ParticleState::new(spec, model, pi as u64)
1061            })
1062            .collect();
1063
1064        let log_w = -(n as f64).ln();
1065        let mut rt = Self {
1066            spec: spec.clone(),
1067            particles,
1068            log_weights: vec![log_w; n],
1069            mix_log_probs: [0.0; 256],
1070            mix_pdf: vec![0.0; 256],
1071            cache_valid: false,
1072            step_idx: 0,
1073            scratch_lse: vec![0.0; n],
1074        };
1075        rt.diversify_initial_particles();
1076        rt
1077    }
1078
1079    /// Ensure all particles have valid cached log-probabilities.
1080    fn ensure_predictions(&mut self) {
1081        if self.cache_valid {
1082            return;
1083        }
1084        let spec = &self.spec;
1085        for p in &mut self.particles {
1086            if !p.cache_valid {
1087                p.forward(spec, self.step_idx);
1088            }
1089        }
1090        self.compute_mixture_log_probs();
1091        self.cache_valid = true;
1092    }
1093
1094    /// Compute mixture log-probabilities by weighting particle predictions.
1095    fn compute_mixture_log_probs(&mut self) {
1096        let n = self.particles.len();
1097        // log_z = logsumexp(log_weights)
1098        let log_z = logsumexp_wide(&self.log_weights);
1099
1100        for v in 0..256 {
1101            for i in 0..n {
1102                self.scratch_lse[i] = self.log_weights[i] + self.particles[i].cached_log_probs[v];
1103            }
1104            self.mix_log_probs[v] = logsumexp_wide(&self.scratch_lse) - log_z;
1105        }
1106
1107        // Also compute PDF for compression
1108        let max_lp = max_wide(&self.mix_log_probs);
1109        let mut sum = 0.0;
1110        for v in 0..256 {
1111            let p = (self.mix_log_probs[v] - max_lp).exp();
1112            self.mix_pdf[v] = p;
1113            sum += p;
1114        }
1115        if sum > 0.0 {
1116            let inv = 1.0 / sum;
1117            for v in &mut self.mix_pdf {
1118                *v *= inv;
1119            }
1120        }
1121    }
1122
1123    /// Non-mutating log-probability query for a single symbol.
1124    pub fn peek_log_prob(&mut self, symbol: u8) -> f64 {
1125        self.ensure_predictions();
1126        self.mix_log_probs[symbol as usize]
1127    }
1128
1129    /// Fill 256-way log-probabilities (non-mutating).
1130    pub fn fill_log_probs_cached(&mut self, out: &mut [f64; 256]) {
1131        self.ensure_predictions();
1132        *out = self.mix_log_probs;
1133    }
1134
1135    /// Return 256-element PDF slice for compression.
1136    pub fn pdf_next(&mut self) -> &[f64] {
1137        self.ensure_predictions();
1138        &self.mix_pdf
1139    }
1140
    /// Observe byte `y`: return ln(p(y)) then update ensemble state.
    ///
    /// Prequential ordering: the returned score comes from predictions made
    /// before `symbol` was revealed; only afterwards do the posterior
    /// weights, per-particle models, contexts, and resampler advance.
    pub fn step(&mut self, symbol: u8) -> f64 {
        self.ensure_predictions();
        let log_prob = self.mix_log_probs[symbol as usize];

        let n = self.particles.len();
        let spec = &self.spec;

        // (1) Weight update: logw_i += logq_i[y]
        // `beta` tempers the likelihood during warmup (see likelihood_beta).
        let beta = self.likelihood_beta();
        for i in 0..n {
            self.log_weights[i] += beta * self.particles[i].cached_log_probs[symbol as usize];
        }
        // Normalize log-weights
        let log_z = logsumexp_wide(&self.log_weights);
        for w in &mut self.log_weights {
            *w -= log_z;
        }

        // (2) Forgetting: blend log-weights toward the uniform prior so
        // particles written off long ago can recover on regime changes.
        if spec.forget_lambda > 0.0 {
            let uniform = -(n as f64).ln();
            for w in &mut self.log_weights {
                *w = (1.0 - spec.forget_lambda) * *w + spec.forget_lambda * uniform;
            }
            // Renormalize
            let log_z2 = logsumexp_wide(&self.log_weights);
            for w in &mut self.log_weights {
                *w -= log_z2;
            }
        }

        // Diagnostics are gathered before SGD/resampling mutate the ensemble,
        // so they describe the state that produced this step's prediction.
        let (n_eff_before, max_w_before) = self.weight_stats();
        let will_resample = n_eff_before < self.spec.resample_threshold * n as f64;
        let should_log = self.diagnostics_enabled();
        let divergence = if should_log {
            self.weighted_prediction_kl_divergence()
        } else {
            0.0
        };

        // (3) Online SGD per particle
        for p in &mut self.particles {
            p.sgd_update(symbol, spec);
        }

        // (4) Push context byte
        for p in &mut self.particles {
            p.push_context(symbol);
        }

        if should_log {
            self.log_diagnostics(n_eff_before, max_w_before, divergence, beta, will_resample);
        }

        // (5) Resample check (return value intentionally unused here)
        let _ = self.maybe_resample();

        // Invalidate caches
        for p in &mut self.particles {
            p.cache_valid = false;
        }
        self.cache_valid = false;
        self.step_idx += 1;

        log_prob
    }
1208
1209    /// Reset dynamic inference state while preserving learned model parameters and particle weights.
1210    ///
1211    /// Frozen/plugin scoring for particles intentionally keeps the learned
1212    /// ensemble and posterior weights from the fit pass. Only stream-local
1213    /// recurrent/context state is reset before the score pass begins.
1214    pub fn reset_frozen_state(&mut self) {
1215        for particle in &mut self.particles {
1216            particle.reset_dynamic_state();
1217        }
1218        self.mix_log_probs.fill(0.0);
1219        self.mix_pdf.fill(1.0 / 256.0);
1220        self.cache_valid = false;
1221        self.step_idx = 0;
1222    }
1223
1224    /// Advance the ensemble without SGD/model adaptation using the current observation.
1225    ///
1226    /// This still performs the Bayesian posterior-weight update over particles,
1227    /// because that latent filtering step is part of the fixed model's
1228    /// inference dynamics rather than a new fit/update of model parameters.
1229    pub fn update_frozen(&mut self, symbol: u8) {
1230        self.ensure_predictions();
1231        let n = self.particles.len();
1232        let spec = &self.spec;
1233
1234        let beta = self.likelihood_beta();
1235        for i in 0..n {
1236            self.log_weights[i] += beta * self.particles[i].cached_log_probs[symbol as usize];
1237        }
1238        let log_z = logsumexp_wide(&self.log_weights);
1239        for weight in &mut self.log_weights {
1240            *weight -= log_z;
1241        }
1242
1243        if spec.forget_lambda > 0.0 {
1244            let uniform = -(n as f64).ln();
1245            for weight in &mut self.log_weights {
1246                *weight = (1.0 - spec.forget_lambda) * *weight + spec.forget_lambda * uniform;
1247            }
1248            let log_z2 = logsumexp_wide(&self.log_weights);
1249            for weight in &mut self.log_weights {
1250                *weight -= log_z2;
1251            }
1252        }
1253
1254        for particle in &mut self.particles {
1255            particle.push_context(symbol);
1256            particle.cache_valid = false;
1257        }
1258        self.cache_valid = false;
1259        self.step_idx += 1;
1260    }
1261
1262    /// Check effective sample size and resample if needed.
1263    fn maybe_resample(&mut self) -> bool {
1264        let n = self.particles.len();
1265        if n <= 1 {
1266            return false;
1267        }
1268
1269        // Compute Neff = 1 / Σ α_i^2 where α = softmax(logw)
1270        // = exp(-logsumexp(2*logw)) / exp(2*(-logsumexp(logw)))
1271        // But logw is already normalized, so logsumexp(logw) ≈ 0
1272        let mut sum_sq = 0.0;
1273        for &lw in &self.log_weights {
1274            let w = lw.exp();
1275            sum_sq += w * w;
1276        }
1277        let n_eff = if sum_sq > 0.0 { 1.0 / sum_sq } else { 0.0 };
1278
1279        if n_eff >= self.spec.resample_threshold * n as f64 {
1280            return false;
1281        }
1282
1283        // Deterministic systematic resampling with fixed offset 0.5/n
1284        let weights: Vec<f64> = self.log_weights.iter().map(|lw| lw.exp()).collect();
1285        let cdf: Vec<f64> = weights
1286            .iter()
1287            .scan(0.0, |acc, &w| {
1288                *acc += w;
1289                Some(*acc)
1290            })
1291            .collect();
1292        let total = *cdf.last().unwrap_or(&1.0);
1293
1294        let step = total / n as f64;
1295        // Deterministic stratified offset from hash to avoid repeating the same
1296        // systematic pattern at every resample event.
1297        let u0 =
1298            ((det_hash(self.spec.seed, self.step_idx, 0, 0) >> 11) as f64) / ((1u64 << 53) as f64);
1299        let mut u = u0 * step;
1300        let mut indices = Vec::with_capacity(n);
1301        let mut j = 0;
1302        for _ in 0..n {
1303            while j < n - 1 && cdf[j] < u {
1304                j += 1;
1305            }
1306            indices.push(j);
1307            u += step;
1308        }
1309
1310        // Clone selected particles
1311        let new_particles: Vec<ParticleState> = indices
1312            .iter()
1313            .map(|&idx| self.particles[idx].clone())
1314            .collect();
1315        self.particles = new_particles;
1316
1317        // Mutate a fraction of particles
1318        let n_mutate = ((self.spec.mutate_fraction * n as f64).round() as usize).min(n);
1319        let mut mutated = vec![false; n];
1320        let mut picked = 0usize;
1321        let mut draw = 0u64;
1322        while picked < n_mutate && draw < (n * 8) as u64 {
1323            let mi = (det_hash(self.spec.seed ^ self.step_idx, draw, 0xA5A5, 0x5A5A) as usize) % n;
1324            if !mutated[mi] {
1325                self.mutate_particle(mi);
1326                mutated[mi] = true;
1327                picked += 1;
1328            }
1329            draw += 1;
1330        }
1331
1332        // Reset log-weights to uniform
1333        let uniform = -(n as f64).ln();
1334        for w in &mut self.log_weights {
1335            *w = uniform;
1336        }
1337        true
1338    }
1339
1340    /// Apply deterministic mutation: always perturb latent state; optionally
1341    /// perturb model parameters when explicitly enabled.
1342    fn mutate_particle(&mut self, particle_idx: usize) {
1343        let seed = self.spec.seed;
1344        let step = self.step_idx;
1345        let pi = particle_idx as u64;
1346        let scale = self.spec.mutate_scale;
1347        let state_clip = self.spec.state_clip;
1348
1349        let p = &mut self.particles[particle_idx];
1350        let mut param_idx = 0u64;
1351
1352        // Mutate cell states
1353        for v in p.cells.iter_mut() {
1354            let noise = hash_to_f64(det_hash(seed ^ step, pi, param_idx, 0)) * scale;
1355            *v = clip(*v + noise, state_clip);
1356            param_idx += 1;
1357        }
1358
1359        if !self.spec.mutate_model_params {
1360            return;
1361        }
1362
1363        let layer_scale = |vals: &[f64]| -> f64 {
1364            if vals.is_empty() {
1365                return 1.0;
1366            }
1367            let mut s = 0.0_f64;
1368            for &v in vals {
1369                s += v * v;
1370            }
1371            (s / vals.len() as f64).sqrt().max(1e-6)
1372        };
1373
1374        // Mutate model parameters with layer-adaptive scale to avoid destructive jumps.
1375        let embed_layer = layer_scale(&p.model.embed);
1376        for v in p.model.embed.iter_mut() {
1377            let noise =
1378                hash_to_f64(det_hash(seed ^ step, pi, param_idx, 1)) * (scale * embed_layer);
1379            *v += noise;
1380            param_idx += 1;
1381        }
1382
1383        // Cell params
1384        for cp in p.model.cells.iter_mut() {
1385            let sel_h_w = layer_scale(&cp.selector.hidden.weights);
1386            let sel_h_b = layer_scale(&cp.selector.hidden.bias);
1387            for v in cp.selector.hidden.weights.iter_mut() {
1388                let noise =
1389                    hash_to_f64(det_hash(seed ^ step, pi, param_idx, 2)) * (scale * sel_h_w);
1390                *v += noise;
1391                param_idx += 1;
1392            }
1393            for v in cp.selector.hidden.bias.iter_mut() {
1394                let noise =
1395                    hash_to_f64(det_hash(seed ^ step, pi, param_idx, 3)) * (scale * sel_h_b);
1396                *v += noise;
1397                param_idx += 1;
1398            }
1399            let sel_g_w = layer_scale(&cp.selector.gate.weights);
1400            let sel_g_b = layer_scale(&cp.selector.gate.bias);
1401            for v in cp.selector.gate.weights.iter_mut() {
1402                let noise =
1403                    hash_to_f64(det_hash(seed ^ step, pi, param_idx, 4)) * (scale * sel_g_w);
1404                *v += noise;
1405                param_idx += 1;
1406            }
1407            for v in cp.selector.gate.bias.iter_mut() {
1408                let noise =
1409                    hash_to_f64(det_hash(seed ^ step, pi, param_idx, 5)) * (scale * sel_g_b);
1410                *v += noise;
1411                param_idx += 1;
1412            }
1413            for rule in cp.rules.iter_mut() {
1414                let rule_h_w = layer_scale(&rule.hidden.weights);
1415                let rule_h_b = layer_scale(&rule.hidden.bias);
1416                let rule_o_w = layer_scale(&rule.output.weights);
1417                let rule_o_b = layer_scale(&rule.output.bias);
1418                for v in rule.hidden.weights.iter_mut() {
1419                    let noise =
1420                        hash_to_f64(det_hash(seed ^ step, pi, param_idx, 6)) * (scale * rule_h_w);
1421                    *v += noise;
1422                    param_idx += 1;
1423                }
1424                for v in rule.hidden.bias.iter_mut() {
1425                    let noise =
1426                        hash_to_f64(det_hash(seed ^ step, pi, param_idx, 7)) * (scale * rule_h_b);
1427                    *v += noise;
1428                    param_idx += 1;
1429                }
1430                for v in rule.output.weights.iter_mut() {
1431                    let noise =
1432                        hash_to_f64(det_hash(seed ^ step, pi, param_idx, 8)) * (scale * rule_o_w);
1433                    *v += noise;
1434                    param_idx += 1;
1435                }
1436                for v in rule.output.bias.iter_mut() {
1437                    let noise =
1438                        hash_to_f64(det_hash(seed ^ step, pi, param_idx, 9)) * (scale * rule_o_b);
1439                    *v += noise;
1440                    param_idx += 1;
1441                }
1442            }
1443        }
1444
1445        // Readout
1446        let readout_w = layer_scale(&p.model.readout.weights);
1447        let readout_b = layer_scale(&p.model.readout.bias);
1448        for v in p.model.readout.weights.iter_mut() {
1449            let noise = hash_to_f64(det_hash(seed ^ step, pi, param_idx, 10)) * (scale * readout_w);
1450            *v += noise;
1451            param_idx += 1;
1452        }
1453        for v in p.model.readout.bias.iter_mut() {
1454            let noise = hash_to_f64(det_hash(seed ^ step, pi, param_idx, 11)) * (scale * readout_b);
1455            *v += noise;
1456            param_idx += 1;
1457        }
1458    }
1459}
1460
1461impl Clone for ParticleRuntime {
1462    fn clone(&self) -> Self {
1463        Self {
1464            spec: self.spec.clone(),
1465            particles: self.particles.clone(),
1466            log_weights: self.log_weights.clone(),
1467            mix_log_probs: self.mix_log_probs,
1468            mix_pdf: self.mix_pdf.clone(),
1469            cache_valid: self.cache_valid,
1470            step_idx: self.step_idx,
1471            scratch_lse: self.scratch_lse.clone(),
1472        }
1473    }
1474}
1475
1476// Implement OnlineBytePredictor
1477impl crate::mixture::OnlineBytePredictor for ParticleRuntime {
1478    fn log_prob(&mut self, symbol: u8) -> f64 {
1479        self.peek_log_prob(symbol)
1480    }
1481
1482    fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
1483        self.fill_log_probs_cached(out)
1484    }
1485
1486    fn update(&mut self, symbol: u8) {
1487        self.step(symbol);
1488    }
1489}
1490
1491// ---------------------------------------------------------------------------
1492// Tests
1493// ---------------------------------------------------------------------------
1494
1495#[cfg(test)]
1496mod tests {
1497    use super::*;
1498
1499    fn default_spec() -> ParticleSpec {
1500        ParticleSpec {
1501            num_particles: 4,
1502            context_window: 8,
1503            unroll_steps: 1,
1504            num_cells: 2,
1505            cell_dim: 4,
1506            num_rules: 2,
1507            selector_hidden: 8,
1508            rule_hidden: 8,
1509            noise_dim: 2,
1510            ..ParticleSpec::default()
1511        }
1512    }
1513
1514    #[test]
1515    fn pdf_sums_to_one() {
1516        let spec = default_spec();
1517        let mut rt = ParticleRuntime::new(&spec);
1518        let pdf = rt.pdf_next();
1519        let sum: f64 = pdf.iter().sum();
1520        assert!((sum - 1.0).abs() < 1e-6, "PDF sum = {sum}, expected ~1.0");
1521    }
1522
1523    #[test]
1524    fn log_probs_finite_and_nonpositive() {
1525        let spec = default_spec();
1526        let mut rt = ParticleRuntime::new(&spec);
1527        let data = b"hello world";
1528        for &b in data.iter() {
1529            let lp = rt.peek_log_prob(b);
1530            assert!(lp.is_finite(), "log_prob not finite: {lp}");
1531            assert!(lp <= 0.0, "log_prob positive: {lp}");
1532            rt.step(b);
1533        }
1534    }
1535
1536    #[test]
1537    fn deterministic_same_seed() {
1538        let spec = default_spec();
1539        let data = b"abcdefghij";
1540
1541        let mut rt1 = ParticleRuntime::new(&spec);
1542        let mut rt2 = ParticleRuntime::new(&spec);
1543
1544        for &b in data.iter() {
1545            let lp1 = rt1.step(b);
1546            let lp2 = rt2.step(b);
1547            assert!(
1548                (lp1 - lp2).abs() < 1e-12,
1549                "Mismatch at byte {b}: {lp1} vs {lp2}"
1550            );
1551        }
1552    }
1553
1554    #[test]
1555    fn deterministic_with_hash_noise_enabled() {
1556        let spec = ParticleSpec {
1557            enable_noise: true,
1558            noise_scale: 0.15,
1559            noise_anneal_steps: 128,
1560            ..default_spec()
1561        };
1562        let data = b"particle noise determinism";
1563
1564        let mut rt1 = ParticleRuntime::new(&spec);
1565        let mut rt2 = ParticleRuntime::new(&spec);
1566
1567        for &b in data {
1568            let lp1 = rt1.step(b);
1569            let lp2 = rt2.step(b);
1570            assert!(
1571                (lp1 - lp2).abs() < 1e-12,
1572                "Hash-noise path non-deterministic at byte {b}: {lp1} vs {lp2}"
1573            );
1574        }
1575    }
1576
1577    #[test]
1578    fn resample_forced() {
1579        let spec = ParticleSpec {
1580            resample_threshold: 1.0, // always resample
1581            ..default_spec()
1582        };
1583        let mut rt = ParticleRuntime::new(&spec);
1584        // Should not panic even with constant resampling
1585        for &b in b"test resampling works ok" {
1586            let lp = rt.step(b);
1587            assert!(lp.is_finite(), "log_prob not finite after resample: {lp}");
1588        }
1589    }
1590
1591    #[test]
1592    fn mutation_determinism() {
1593        let spec = ParticleSpec {
1594            resample_threshold: 1.0,
1595            mutate_fraction: 1.0,
1596            ..default_spec()
1597        };
1598        let data = b"test mutation";
1599
1600        let mut rt1 = ParticleRuntime::new(&spec);
1601        let mut rt2 = ParticleRuntime::new(&spec);
1602
1603        for &b in data.iter() {
1604            let lp1 = rt1.step(b);
1605            let lp2 = rt2.step(b);
1606            assert!(
1607                (lp1 - lp2).abs() < 1e-12,
1608                "Mutation non-deterministic at byte {b}: {lp1} vs {lp2}"
1609            );
1610        }
1611    }
1612
1613    #[test]
1614    fn empty_input_no_crash() {
1615        let spec = default_spec();
1616        let mut rt = ParticleRuntime::new(&spec);
1617        // Just verify we can get predictions without any input
1618        let lp = rt.peek_log_prob(0);
1619        assert!(lp.is_finite());
1620    }
1621
1622    #[test]
1623    fn fill_log_probs_consistency() {
1624        let spec = default_spec();
1625        let mut rt = ParticleRuntime::new(&spec);
1626        rt.step(b'a');
1627        rt.step(b'b');
1628
1629        let mut bulk = [0.0; 256];
1630        rt.fill_log_probs_cached(&mut bulk);
1631
1632        for sym in 0..256u16 {
1633            let single = rt.peek_log_prob(sym as u8);
1634            assert!(
1635                (bulk[sym as usize] - single).abs() < 1e-12,
1636                "Mismatch for sym {sym}: bulk={} single={}",
1637                bulk[sym as usize],
1638                single
1639            );
1640        }
1641    }
1642
    #[test]
    fn spec_validation() {
        let mut spec = ParticleSpec::default();
        assert!(spec.validate().is_ok());

        // Each invalid field must be rejected; restore it before the next case
        // so the failures are independent.
        spec.num_particles = 0;
        assert!(spec.validate().is_err());
        spec.num_particles = 4;

        spec.resample_threshold = 0.0;
        assert!(spec.validate().is_err());
        spec.resample_threshold = 0.5;

        spec.min_prob = -1.0;
        assert!(spec.validate().is_err());
    }
1659
1660    fn assert_models_equal(lhs: &ParticleModel, rhs: &ParticleModel) {
1661        assert_eq!(lhs.embed, rhs.embed);
1662        assert_eq!(lhs.readout.weights, rhs.readout.weights);
1663        assert_eq!(lhs.readout.bias, rhs.readout.bias);
1664        assert_eq!(lhs.readout.vel_weights, rhs.readout.vel_weights);
1665        assert_eq!(lhs.readout.vel_bias, rhs.readout.vel_bias);
1666        assert_eq!(lhs.cells.len(), rhs.cells.len());
1667        for (lhs_cell, rhs_cell) in lhs.cells.iter().zip(rhs.cells.iter()) {
1668            assert_eq!(
1669                lhs_cell.selector.hidden.weights,
1670                rhs_cell.selector.hidden.weights
1671            );
1672            assert_eq!(lhs_cell.selector.hidden.bias, rhs_cell.selector.hidden.bias);
1673            assert_eq!(
1674                lhs_cell.selector.hidden.vel_weights,
1675                rhs_cell.selector.hidden.vel_weights
1676            );
1677            assert_eq!(
1678                lhs_cell.selector.hidden.vel_bias,
1679                rhs_cell.selector.hidden.vel_bias
1680            );
1681            assert_eq!(
1682                lhs_cell.selector.gate.weights,
1683                rhs_cell.selector.gate.weights
1684            );
1685            assert_eq!(lhs_cell.selector.gate.bias, rhs_cell.selector.gate.bias);
1686            assert_eq!(
1687                lhs_cell.selector.gate.vel_weights,
1688                rhs_cell.selector.gate.vel_weights
1689            );
1690            assert_eq!(
1691                lhs_cell.selector.gate.vel_bias,
1692                rhs_cell.selector.gate.vel_bias
1693            );
1694            assert_eq!(lhs_cell.rules.len(), rhs_cell.rules.len());
1695            for (lhs_rule, rhs_rule) in lhs_cell.rules.iter().zip(rhs_cell.rules.iter()) {
1696                assert_eq!(lhs_rule.hidden.weights, rhs_rule.hidden.weights);
1697                assert_eq!(lhs_rule.hidden.bias, rhs_rule.hidden.bias);
1698                assert_eq!(lhs_rule.hidden.vel_weights, rhs_rule.hidden.vel_weights);
1699                assert_eq!(lhs_rule.hidden.vel_bias, rhs_rule.hidden.vel_bias);
1700                assert_eq!(lhs_rule.output.weights, rhs_rule.output.weights);
1701                assert_eq!(lhs_rule.output.bias, rhs_rule.output.bias);
1702                assert_eq!(lhs_rule.output.vel_weights, rhs_rule.output.vel_weights);
1703                assert_eq!(lhs_rule.output.vel_bias, rhs_rule.output.vel_bias);
1704            }
1705        }
1706    }
1707
1708    #[test]
1709    fn frozen_update_preserves_model_parameters() {
1710        let spec = default_spec();
1711        let mut rt = ParticleRuntime::new(&spec);
1712        for &b in b"particle plugin separation" {
1713            rt.step(b);
1714        }
1715
1716        let before_models: Vec<_> = rt.particles.iter().map(|p| p.model.clone()).collect();
1717        rt.reset_frozen_state();
1718        assert!(rt.particles.iter().all(|p| p.ctx_len == 0));
1719
1720        let lp = rt.peek_log_prob(b'x');
1721        assert!(lp.is_finite());
1722        rt.update_frozen(b'x');
1723
1724        for (before, particle) in before_models.iter().zip(rt.particles.iter()) {
1725            assert_models_equal(before, &particle.model);
1726        }
1727        assert_eq!(rt.step_idx, 1);
1728        assert!(rt.particles.iter().all(|p| p.ctx_len == 1));
1729    }
1730}