// infotheory/backends/rwkvzip/rwkv7/model.rs

1//! RWKV7 model implementation with portable SIMD-optimized inference.
2//!
3//! This is a high-performance implementation built on `wide` kernels.
4//! Single-token inference is the primary use case (streaming compression).
5
6use crate::backends::llm_policy::OptimizerKind;
7use anyhow::{Context, Result, bail};
8use serde_json::json;
9use std::fs::File;
10use std::io::Write;
11use std::path::Path;
12use std::time::Instant;
13use wide::f32x8;
14
15use super::kernel;
16use super::profiling::{NullProfiler, ProfilerSink};
17use super::tensor::Tensor1D;
18use super::weights::Weights;
19
/// Model configuration.
///
/// Values are normally inferred from checkpoint tensor shapes at load time;
/// [`Config::validate`] enforces the invariants the SIMD kernels rely on.
#[derive(Debug, Clone)]
pub struct Config {
    /// Vocabulary size (byte-level models use 256).
    pub vocab_size: usize,
    /// Hidden channel width `C`.
    pub hidden_size: usize,
    /// Number of RWKV blocks.
    pub num_layers: usize,
    /// Number of attention heads `H`.
    pub num_heads: usize,
    /// Per-head channel width `N` (currently fixed to 64 in kernels).
    pub head_dim: usize,
    /// Feed-forward intermediate width.
    pub intermediate_size: usize,
    /// Epsilon for layer normalization.
    pub layer_norm_eps: f32,
    /// Epsilon for group normalization (`64e-5` in reference).
    pub group_norm_eps: f32,

    /// Low-rank width for decay projection (`w_lora` in checkpoints).
    pub decay_low_rank: usize,
    /// Low-rank width for `a` projection.
    pub a_low_rank: usize,
    /// Low-rank width for `v` projection.
    pub v_low_rank: usize,
    /// Low-rank width for `g` projection.
    pub g_low_rank: usize,
}
49
impl Default for Config {
    /// Default configuration: a 12-layer byte-level model with `C = 256`.
    fn default() -> Self {
        Self {
            vocab_size: 256,
            hidden_size: 256,
            num_layers: 12,
            num_heads: 4, // hidden_size / head_dim = 256 / 64
            head_dim: 64,
            intermediate_size: 1024,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5, // matches the reference implementation
            decay_low_rank: 32,
            a_low_rank: 32,
            v_low_rank: 32,
            g_low_rank: 64,
        }
    }
}
68
69impl Config {
70    /// Validate configuration invariants required by current kernels.
71    pub fn validate(&self) -> Result<()> {
72        if self.vocab_size == 0 {
73            bail!("rwkv7 vocab_size must be > 0");
74        }
75        if self.head_dim != 64 {
76            bail!("rwkv7 head_dim must be 64 for current kernels");
77        }
78        if self.hidden_size != self.num_heads * self.head_dim {
79            bail!(
80                "rwkv7 hidden_size must equal num_heads * head_dim ({} != {} * {})",
81                self.hidden_size,
82                self.num_heads,
83                self.head_dim
84            );
85        }
86        if self.num_layers == 0 {
87            bail!("rwkv7 num_layers must be > 0");
88        }
89        if self.intermediate_size == 0 {
90            bail!("rwkv7 intermediate_size must be > 0");
91        }
92        Ok(())
93    }
94}
95
/// Per-layer recurrent state for RWKV7.
#[derive(Clone)]
pub struct LayerState {
    /// Previous token embedding for attention time-shift (hidden_size,)
    pub att_x_prev: Tensor1D,
    /// Attention state matrix (num_heads, head_dim, head_dim) = (H, N, N)
    pub att_state: Tensor1D, // Flat for SIMD access
    /// Previous token embedding for FFN time-shift (hidden_size,)
    pub ffn_x_prev: Tensor1D,
}
106
107impl LayerState {
108    fn new(cfg: &Config) -> Self {
109        let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
110        Self {
111            att_x_prev: Tensor1D::zeros(cfg.hidden_size),
112            att_state: Tensor1D::zeros(state_size),
113            ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
114        }
115    }
116}
117
/// Full model state.
#[derive(Clone)]
pub struct State {
    /// Per-layer recurrent state.
    pub layers: Vec<LayerState>,
    /// First layer's value output (for residual connection) - pre-allocated.
    pub v_first: Tensor1D,
    /// Flag to indicate if `v_first` has been written during the current pass.
    pub v_first_set: bool,
}
128
129impl State {
130    /// Allocate a zero-initialized recurrent state for a model configuration.
131    pub fn new(cfg: &Config) -> Self {
132        Self {
133            layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
134            v_first: Tensor1D::zeros(cfg.hidden_size),
135            v_first_set: false,
136        }
137    }
138
139    /// Reset recurrent buffers to their initial (all-zero) state.
140    pub fn reset(&mut self) {
141        self.v_first_set = false;
142        self.v_first.zero();
143        for layer in &mut self.layers {
144            layer.att_x_prev.zero();
145            layer.att_state.zero();
146            layer.ffn_x_prev.zero();
147        }
148    }
149}
150
/// Weights for a single attention layer.
///
/// Shapes use `C` = hidden_size, `H` = num_heads, `N` = head_dim,
/// `D_*` = the corresponding low-rank width.
#[derive(Clone)]
struct AttentionWeights {
    // Token shift mixing factors, each (C,)
    x_r: Tensor1D,
    x_w: Tensor1D,
    x_k: Tensor1D,
    x_v: Tensor1D,
    x_a: Tensor1D,
    x_g: Tensor1D,

    // Packed r/k/v projections for parallel computation
    // Layout: [r_proj (C*C), k_proj (C*C), v_proj (C*C)]
    rkv_proj: Tensor1D,

    // Output projection (stored transposed for efficient gemv)
    o_proj: Tensor1D,

    // Low-rank W: w = tanh(x @ w1) @ w2 + w0
    w1: Tensor1D, // (C, D_w)
    w2: Tensor1D, // (D_w, C)
    w0: Tensor1D, // (C,)

    // Low-rank A: a = sigmoid(x @ a1 @ a2 + a0)
    a1: Tensor1D, // (C, D_a)
    a2: Tensor1D, // (D_a, C)
    a0: Tensor1D, // (C,)

    // Low-rank V (layers > 0): nu = sigmoid(x @ v1 @ v2 + v0)
    // `None` on layer 0, which supplies v_first instead of consuming it.
    v1: Option<Tensor1D>, // (C, D_v)
    v2: Option<Tensor1D>, // (D_v, C)
    v0: Option<Tensor1D>, // (C,)

    // Low-rank G: g = sigmoid(x @ g1) @ g2
    g1: Tensor1D, // (C, D_g)
    g2: Tensor1D, // (D_g, C)

    // Key scaling
    k_k: Tensor1D, // (C,)
    k_a: Tensor1D, // (C,)
    r_k: Tensor1D, // (H, N)

    // Group norm for output
    g_norm_w: Tensor1D, // (C,)
    g_norm_b: Tensor1D, // (C,)
}
197
/// Weights for a single FFN layer (`I` = intermediate_size).
#[derive(Clone)]
struct FfnWeights {
    x_k: Tensor1D,     // (C,) time shift mix
    key_w: Tensor1D,   // (C, I) -> relu(x @ W)^2
    value_w: Tensor1D, // (I, C)
}
205
/// Weights for a single block: norms plus attention and FFN sub-layers.
#[derive(Clone)]
struct BlockWeights {
    // Pre-norm (layer 0 only); `None` on all other layers
    pre_norm_w: Option<Tensor1D>,
    pre_norm_b: Option<Tensor1D>,

    // Attention norm
    attn_norm_w: Tensor1D,
    attn_norm_b: Tensor1D,

    // FFN norm
    ffn_norm_w: Tensor1D,
    ffn_norm_b: Tensor1D,

    // Sub-layer weights
    attn: AttentionWeights,
    ffn: FfnWeights,
}
224
/// RWKV7 model: embeddings, a stack of blocks, output norm and LM head.
#[derive(Clone)]
pub struct Model {
    cfg: Config,

    // Embeddings (vocab_size, hidden_size), flattened row-major
    embeddings: Tensor1D,

    // Output norm
    ln_out_w: Tensor1D,
    ln_out_b: Tensor1D,

    // LM head (vocab_size, hidden_size)
    lm_head: Tensor1D,

    // Per-layer block weights, in layer order
    blocks: Vec<BlockWeights>,
}
243
/// Adam moment buffers for one parameter tensor.
#[derive(Clone)]
struct AdamTensorState {
    /// First-moment accumulator, same length as the parameter.
    m: Tensor1D,
    /// Second-moment accumulator, same length as the parameter.
    v: Tensor1D,
}
249
250impl AdamTensorState {
251    #[inline]
252    fn new(len: usize) -> Self {
253        Self {
254            m: Tensor1D::zeros(len),
255            v: Tensor1D::zeros(len),
256        }
257    }
258}
259
/// Adam moments for every tensor in [`AttentionWeights`]; field layout mirrors it.
#[derive(Clone)]
struct AttentionAdamState {
    x_r: AdamTensorState,
    x_w: AdamTensorState,
    x_k: AdamTensorState,
    x_v: AdamTensorState,
    x_a: AdamTensorState,
    x_g: AdamTensorState,
    rkv_proj: AdamTensorState,
    o_proj: AdamTensorState,
    w1: AdamTensorState,
    w2: AdamTensorState,
    w0: AdamTensorState,
    a1: AdamTensorState,
    a2: AdamTensorState,
    a0: AdamTensorState,
    // Present only where the corresponding v1/v2/v0 weights exist (layers > 0)
    v1: Option<AdamTensorState>,
    v2: Option<AdamTensorState>,
    v0: Option<AdamTensorState>,
    g1: AdamTensorState,
    g2: AdamTensorState,
    k_k: AdamTensorState,
    k_a: AdamTensorState,
    r_k: AdamTensorState,
    g_norm_w: AdamTensorState,
    g_norm_b: AdamTensorState,
}
287
/// Adam moments for every tensor in [`FfnWeights`]; field layout mirrors it.
#[derive(Clone)]
struct FfnAdamState {
    x_k: AdamTensorState,
    key_w: AdamTensorState,
    value_w: AdamTensorState,
}
294
/// Adam moments for one block; field layout mirrors [`BlockWeights`].
#[derive(Clone)]
struct BlockAdamState {
    pre_norm_w: Option<AdamTensorState>,
    pre_norm_b: Option<AdamTensorState>,
    attn_norm_w: AdamTensorState,
    attn_norm_b: AdamTensorState,
    ffn_norm_w: AdamTensorState,
    ffn_norm_b: AdamTensorState,
    attn: AttentionAdamState,
    ffn: FfnAdamState,
}
306
#[derive(Clone)]
/// Adam moments for full-parameter RWKV online training.
///
/// Field layout mirrors [`Model`] one-to-one.
pub struct FullAdamState {
    embeddings: AdamTensorState,
    ln_out_w: AdamTensorState,
    ln_out_b: AdamTensorState,
    lm_head: AdamTensorState,
    blocks: Vec<BlockAdamState>,
}
316
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
/// Train-scope mask for RWKV full-parameter online updates.
///
/// The derived `Default` is all-false (train nothing); use
/// [`TrainScopeMask::all`] to enable every scope.
pub struct TrainScopeMask {
    /// Train token embeddings.
    pub embed: bool,
    /// Train optional pre-norm parameters.
    pub pre_norm: bool,
    /// Train attention norm parameters.
    pub attn_norm: bool,
    /// Train FFN norm parameters.
    pub ffn_norm: bool,
    /// Train attention block parameters.
    pub attn: bool,
    /// Train FFN block parameters.
    pub ffn: bool,
    /// Train LM-head weights.
    pub head: bool,
    /// Train additive output-bias terms.
    pub bias: bool,
}
337
338impl TrainScopeMask {
339    #[inline]
340    /// Enable all train scopes.
341    pub fn all() -> Self {
342        Self {
343            embed: true,
344            pre_norm: true,
345            attn_norm: true,
346            ffn_norm: true,
347            attn: true,
348            ffn: true,
349            head: true,
350            bias: true,
351        }
352    }
353
354    #[inline]
355    /// Returns whether any non-head model parameters are trainable.
356    pub fn trains_non_head_params(&self) -> bool {
357        self.embed || self.pre_norm || self.attn_norm || self.ffn_norm || self.attn || self.ffn
358    }
359
360    #[inline]
361    /// Returns whether any parameter/bias updates are enabled.
362    pub fn trains_any_params(&self) -> bool {
363        self.trains_non_head_params() || self.head || self.bias
364    }
365}
366
/// Gradient accumulators for every tensor in [`AttentionWeights`]; layout mirrors it.
#[derive(Clone)]
struct AttentionGradState {
    x_r: Tensor1D,
    x_w: Tensor1D,
    x_k: Tensor1D,
    x_v: Tensor1D,
    x_a: Tensor1D,
    x_g: Tensor1D,
    rkv_proj: Tensor1D,
    o_proj: Tensor1D,
    w1: Tensor1D,
    w2: Tensor1D,
    w0: Tensor1D,
    a1: Tensor1D,
    a2: Tensor1D,
    a0: Tensor1D,
    // Present only where the corresponding v1/v2/v0 weights exist (layers > 0)
    v1: Option<Tensor1D>,
    v2: Option<Tensor1D>,
    v0: Option<Tensor1D>,
    g1: Tensor1D,
    g2: Tensor1D,
    k_k: Tensor1D,
    k_a: Tensor1D,
    r_k: Tensor1D,
    g_norm_w: Tensor1D,
    g_norm_b: Tensor1D,
}
394
/// Gradient accumulators for every tensor in [`FfnWeights`]; layout mirrors it.
#[derive(Clone)]
struct FfnGradState {
    x_k: Tensor1D,
    key_w: Tensor1D,
    value_w: Tensor1D,
}
401
/// Gradient accumulators for one block; layout mirrors [`BlockWeights`].
#[derive(Clone)]
struct BlockGradState {
    pre_norm_w: Option<Tensor1D>,
    pre_norm_b: Option<Tensor1D>,
    attn_norm_w: Tensor1D,
    attn_norm_b: Tensor1D,
    ffn_norm_w: Tensor1D,
    ffn_norm_b: Tensor1D,
    attn: AttentionGradState,
    ffn: FfnGradState,
}
413
/// Full-model gradient accumulators; layout mirrors [`Model`].
#[derive(Clone)]
struct FullGradState {
    embeddings: Tensor1D,
    ln_out_w: Tensor1D,
    ln_out_b: Tensor1D,
    lm_head: Tensor1D,
    blocks: Vec<BlockGradState>,
}
422
/// Hyper-parameters and precomputed factors for one Adam update step.
struct AdamStep {
    /// Learning rate.
    lr: f32,
    /// Gradient clipping threshold.
    clip: f32,
    /// Exponential-decay rate for the first moment (beta1).
    b1: f32,
    /// Exponential-decay rate for the second moment (beta2).
    b2: f32,
    /// Numerical-stability epsilon for the update denominator.
    eps: f32,
    // NOTE(review): presumably 1 - b1^t for the current step t — confirm
    // against the code that constructs AdamStep.
    bias_corr1: f32,
    // NOTE(review): presumably 1 - b2^t — confirm at the construction site.
    bias_corr2: f32,
}
432
/// Per-layer buffers captured during a forward pass for use by the
/// backward pass. Sizes are fixed by [`LayerTrainTrace::new`]; names follow
/// the RWKV7 quantities they hold.
#[derive(Clone)]
struct LayerTrainTrace {
    // Layer input and post-pre-norm hidden state (C,)
    x_in: Tensor1D,
    x_after_pre: Tensor1D,
    attn_norm: Tensor1D,
    // Recurrent state snapshots taken BEFORE this token updated them
    att_x_prev_old: Tensor1D,
    ffn_x_prev_old: Tensor1D,
    att_state_old: Tensor1D,
    // Token-shift mixed inputs (C,)
    xr: Tensor1D,
    xw: Tensor1D,
    xk: Tensor1D,
    xv: Tensor1D,
    xa: Tensor1D,
    xg: Tensor1D,
    // Projections and their pre-activation values (C,)
    r: Tensor1D,
    k_pre: Tensor1D,
    k: Tensor1D,
    v_pre: Tensor1D,
    v: Tensor1D,
    nu: Tensor1D,
    // Decay LoRA intermediates (w_hidden is (D_w,))
    w_hidden: Tensor1D,
    w_pre: Tensor1D,
    w_sigmoid: Tensor1D,
    w_decay: Tensor1D,
    // `a` LoRA intermediates (a_hidden is (D_a,))
    a_hidden: Tensor1D,
    a: Tensor1D,
    // Gate LoRA intermediates (g_hidden is (D_g,))
    g_hidden: Tensor1D,
    g: Tensor1D,
    // Normalized key before/after scaling (C,)
    kk_pre: Tensor1D,
    kk: Tensor1D,
    // WKV output, group-norm output, per-head scale (H,), and head/gate outputs
    y_wkv: Tensor1D,
    y_gn: Tensor1D,
    alpha: Tensor1D,
    y_head: Tensor1D,
    y_gate: Tensor1D,
    att_out: Tensor1D,
    x_after_attn: Tensor1D,
    // FFN intermediates (ffn_pre/ffn_k are (I,))
    ffn_norm: Tensor1D,
    ffn_xk: Tensor1D,
    ffn_pre: Tensor1D,
    ffn_k: Tensor1D,
    ffn_out: Tensor1D,
    x_out: Tensor1D,
    // v LoRA hidden (D_v, min 1 so layer 0 still has a buffer)
    v_hidden: Tensor1D,
    // True when this layer blended v with v_first via the v residual path
    uses_v_residual: bool,
}
479
480impl LayerTrainTrace {
481    fn new(cfg: &Config) -> Self {
482        let c = cfg.hidden_size;
483        let i = cfg.intermediate_size;
484        let state = cfg.num_heads * cfg.head_dim * cfg.head_dim;
485        Self {
486            x_in: Tensor1D::zeros(c),
487            x_after_pre: Tensor1D::zeros(c),
488            attn_norm: Tensor1D::zeros(c),
489            att_x_prev_old: Tensor1D::zeros(c),
490            ffn_x_prev_old: Tensor1D::zeros(c),
491            att_state_old: Tensor1D::zeros(state),
492            xr: Tensor1D::zeros(c),
493            xw: Tensor1D::zeros(c),
494            xk: Tensor1D::zeros(c),
495            xv: Tensor1D::zeros(c),
496            xa: Tensor1D::zeros(c),
497            xg: Tensor1D::zeros(c),
498            r: Tensor1D::zeros(c),
499            k_pre: Tensor1D::zeros(c),
500            k: Tensor1D::zeros(c),
501            v_pre: Tensor1D::zeros(c),
502            v: Tensor1D::zeros(c),
503            nu: Tensor1D::zeros(c),
504            w_hidden: Tensor1D::zeros(cfg.decay_low_rank),
505            w_pre: Tensor1D::zeros(c),
506            w_sigmoid: Tensor1D::zeros(c),
507            w_decay: Tensor1D::zeros(c),
508            a_hidden: Tensor1D::zeros(cfg.a_low_rank),
509            a: Tensor1D::zeros(c),
510            g_hidden: Tensor1D::zeros(cfg.g_low_rank),
511            g: Tensor1D::zeros(c),
512            kk_pre: Tensor1D::zeros(c),
513            kk: Tensor1D::zeros(c),
514            y_wkv: Tensor1D::zeros(c),
515            y_gn: Tensor1D::zeros(c),
516            alpha: Tensor1D::zeros(cfg.num_heads),
517            y_head: Tensor1D::zeros(c),
518            y_gate: Tensor1D::zeros(c),
519            att_out: Tensor1D::zeros(c),
520            x_after_attn: Tensor1D::zeros(c),
521            ffn_norm: Tensor1D::zeros(c),
522            ffn_xk: Tensor1D::zeros(c),
523            ffn_pre: Tensor1D::zeros(i),
524            ffn_k: Tensor1D::zeros(i),
525            ffn_out: Tensor1D::zeros(c),
526            x_out: Tensor1D::zeros(c),
527            v_hidden: Tensor1D::zeros(cfg.v_low_rank.max(1)),
528            uses_v_residual: false,
529        }
530    }
531}
532
/// Snapshot of one token's forward-pass buffers, cloned out of [`ScratchBuffers`].
#[derive(Clone)]
struct TokenTrainTrace {
    // Token id this trace was captured for
    token: usize,
    x: Tensor1D,
    x_normed: Tensor1D,
    v_first: Tensor1D,
    layers: Vec<LayerTrainTrace>,
}
541
impl TokenTrainTrace {
    /// Deep-copy the trace-relevant buffers out of `scratch`.
    ///
    /// Clones are required because the scratch buffers are overwritten by the
    /// next token's forward pass.
    fn from_scratch(scratch: &ScratchBuffers) -> Self {
        Self {
            token: scratch.train_token,
            x: scratch.x.clone(),
            x_normed: scratch.x_normed.clone(),
            v_first: scratch.train_v_first.clone(),
            layers: scratch.train_trace_layers.clone(),
        }
    }
}
553
/// Per-layer gradient buffers for the recurrent state; layout mirrors [`LayerState`].
#[derive(Clone)]
struct LayerRecurrentGradState {
    att_x_prev: Tensor1D,
    att_state: Tensor1D,
    ffn_x_prev: Tensor1D,
}
560
561impl LayerRecurrentGradState {
562    fn new(cfg: &Config) -> Self {
563        let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
564        Self {
565            att_x_prev: Tensor1D::zeros(cfg.hidden_size),
566            att_state: Tensor1D::zeros(state_size),
567            ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
568        }
569    }
570}
571
/// Gradient buffers for the recurrent state of every layer.
#[derive(Clone)]
struct RecurrentGradState {
    layers: Vec<LayerRecurrentGradState>,
}
576
577impl RecurrentGradState {
578    fn new(cfg: &Config) -> Self {
579        Self {
580            layers: (0..cfg.num_layers)
581                .map(|_| LayerRecurrentGradState::new(cfg))
582                .collect(),
583        }
584    }
585
586    fn zero(&mut self) {
587        for layer in &mut self.layers {
588            layer.att_x_prev.zero();
589            layer.att_state.zero();
590            layer.ffn_x_prev.zero();
591        }
592    }
593}
594
/// Pre-allocated scratch buffers to avoid allocations in hot path.
#[derive(Clone)]
pub struct ScratchBuffers {
    x: Tensor1D,          // Current hidden state
    x_normed: Tensor1D,   // After layer norm
    xr: Tensor1D,         // Token-shifted for r
    xw: Tensor1D,         // Token-shifted for w
    xk: Tensor1D,         // Token-shifted for k
    xv: Tensor1D,         // Token-shifted for v
    xa: Tensor1D,         // Token-shifted for a
    xg: Tensor1D,         // Token-shifted for g
    r: Tensor1D,          // Receptance
    k: Tensor1D,          // Key
    v: Tensor1D,          // Value
    w_lora_tmp: Tensor1D, // Low-rank temp (sized for the largest LoRA width)
    w_decay: Tensor1D,    // Decay factor
    a: Tensor1D,          // Gate a
    g: Tensor1D,          // Gate g
    kk: Tensor1D,         // Normalized key
    y: Tensor1D,          // WKV output
    att_out: Tensor1D,    // Attention output
    ffn_k: Tensor1D,      // FFN key
    ffn_out: Tensor1D,    // FFN output
    logits: Tensor1D,     // Output logits
    // Backward-pass temporaries, reused across layers (each (C,))
    grad_x: Tensor1D,
    grad_x2: Tensor1D,
    grad_x3: Tensor1D,
    grad_x4: Tensor1D,
    grad_x5: Tensor1D,
    grad_x6: Tensor1D,
    grad_v_first: Tensor1D,
    grad_param: Tensor1D,
    grad_param2: Tensor1D,
    grad_saved: Tensor1D,
    // FFN-sized ((I,)) and low-rank-sized gradient temporaries
    grad_ffn: Tensor1D,
    grad_ffn2: Tensor1D,
    grad_low_rank: Tensor1D,
    grad_low_rank2: Tensor1D,
    grad_att_state: Tensor1D, // (H*N*N,)
    grad_logits: Tensor1D,    // (vocab,)
    // Training-trace capture (one LayerTrainTrace per layer)
    train_trace_layers: Vec<LayerTrainTrace>,
    train_token: usize,
    train_v_first: Tensor1D,
    train_trace_valid: bool,
    capture_train_trace: bool,
}
641
642impl ScratchBuffers {
643    /// Allocate reusable per-token scratch buffers sized for `cfg`.
644    pub fn new(cfg: &Config) -> Self {
645        let c = cfg.hidden_size;
646        let i = cfg.intermediate_size;
647        let v = cfg.vocab_size;
648        let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
649        let d_rank = cfg
650            .decay_low_rank
651            .max(cfg.a_low_rank)
652            .max(cfg.v_low_rank)
653            .max(cfg.g_low_rank)
654            .max(64);
655        let mut train_trace_layers = Vec::with_capacity(cfg.num_layers);
656        for _ in 0..cfg.num_layers {
657            train_trace_layers.push(LayerTrainTrace::new(cfg));
658        }
659
660        Self {
661            x: Tensor1D::zeros(c),
662            x_normed: Tensor1D::zeros(c),
663            xr: Tensor1D::zeros(c),
664            xw: Tensor1D::zeros(c),
665            xk: Tensor1D::zeros(c),
666            xv: Tensor1D::zeros(c),
667            xa: Tensor1D::zeros(c),
668            xg: Tensor1D::zeros(c),
669            r: Tensor1D::zeros(c),
670            k: Tensor1D::zeros(c),
671            v: Tensor1D::zeros(c),
672            w_lora_tmp: Tensor1D::zeros(d_rank),
673            w_decay: Tensor1D::zeros(c),
674            a: Tensor1D::zeros(c),
675            g: Tensor1D::zeros(c),
676            kk: Tensor1D::zeros(c),
677            y: Tensor1D::zeros(c),
678            att_out: Tensor1D::zeros(c),
679            ffn_k: Tensor1D::zeros(i),
680            ffn_out: Tensor1D::zeros(c),
681            logits: Tensor1D::zeros(v),
682            grad_x: Tensor1D::zeros(c),
683            grad_x2: Tensor1D::zeros(c),
684            grad_x3: Tensor1D::zeros(c),
685            grad_x4: Tensor1D::zeros(c),
686            grad_x5: Tensor1D::zeros(c),
687            grad_x6: Tensor1D::zeros(c),
688            grad_v_first: Tensor1D::zeros(c),
689            grad_param: Tensor1D::zeros(c),
690            grad_param2: Tensor1D::zeros(c),
691            grad_saved: Tensor1D::zeros(c),
692            grad_ffn: Tensor1D::zeros(i),
693            grad_ffn2: Tensor1D::zeros(i),
694            grad_low_rank: Tensor1D::zeros(d_rank),
695            grad_low_rank2: Tensor1D::zeros(d_rank),
696            grad_att_state: Tensor1D::zeros(state_size),
697            grad_logits: Tensor1D::zeros(v),
698            train_trace_layers,
699            train_token: 0,
700            train_v_first: Tensor1D::zeros(c),
701            train_trace_valid: false,
702            capture_train_trace: false,
703        }
704    }
705
706    /// Final normalized hidden state consumed by LM head.
707    #[inline]
708    pub fn lm_head_input(&self) -> &[f32] {
709        self.x_normed.as_slice()
710    }
711
712    #[inline]
713    /// Borrow logits from the latest forward pass scratch buffer.
714    pub fn logits(&self) -> &[f32] {
715        self.logits.as_slice()
716    }
717
718    /// Restore LM-head input snapshot for reversible online updates.
719    #[inline]
720    pub fn set_lm_head_input(&mut self, value: &[f32]) {
721        self.x_normed.as_mut_slice().copy_from_slice(value);
722    }
723
724    /// Enable or disable per-token training trace capture.
725    #[inline]
726    pub fn set_capture_train_trace(&mut self, enabled: bool) {
727        self.capture_train_trace = enabled;
728        if !enabled {
729            self.train_trace_valid = false;
730        }
731    }
732
733    /// Whether the current scratch contains a valid full-trace for the latest forward pass.
734    #[inline]
735    pub fn has_train_trace(&self) -> bool {
736        self.train_trace_valid
737    }
738}
739
740impl Model {
741    fn tensor_from(weights: &Weights, name: &str) -> Result<Tensor1D> {
742        Ok(Tensor1D::from_vec(weights.require(name)?.data().to_vec()))
743    }
744
745    fn optional_tensor_from(weights: &Weights, name: &str) -> Option<Tensor1D> {
746        weights
747            .get(name)
748            .map(|tensor| Tensor1D::from_vec(tensor.data().to_vec()))
749    }
750
751    /// Load model from safetensors file.
752    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
753        let weights = Weights::load(path.as_ref()).context("Failed to load model weights")?;
754
755        // Infer config from weights
756        let emb = weights.require("model.embeddings.weight")?;
757        let vocab_size = emb.shape()[0];
758        let hidden_size = emb.shape()[1];
759
760        let num_heads = hidden_size / 64; // Assume head_dim=64
761        let head_dim = 64;
762
763        // Count layers by looking for layer weights
764        let mut num_layers = 0;
765        while weights
766            .get(&format!("model.layers.{}.attn.r_proj.weight", num_layers))
767            .is_some()
768        {
769            num_layers += 1;
770        }
771
772        // Get intermediate size from FFN
773        let ffn_key = weights.require("model.layers.0.ffn.key.weight")?;
774        let intermediate_size = ffn_key.shape()[0];
775
776        // Get low-rank dimensions
777        let w1 = weights.require("model.layers.0.attn.w_lora.lora.0.weight")?;
778        let decay_low_rank = w1.shape()[0];
779
780        let a1 = weights.require("model.layers.0.attn.a_lora.lora.0.weight")?;
781        let a_low_rank = a1.shape()[0];
782
783        let g1 = weights.require("model.layers.0.attn.g_lora.lora.0.weight")?;
784        let g_low_rank = g1.shape()[0];
785
786        // v_low_rank from layer 1 (layer 0 doesn't have it)
787        let v_low_rank = if num_layers > 1 {
788            if let Some(v1) = weights.get("model.layers.1.attn.v_lora.lora.0.weight") {
789                v1.shape()[0]
790            } else {
791                32
792            }
793        } else {
794            32
795        };
796
797        let cfg = Config {
798            vocab_size,
799            hidden_size,
800            num_layers,
801            num_heads,
802            head_dim,
803            intermediate_size,
804            layer_norm_eps: 1e-5,
805            group_norm_eps: 64e-5,
806            decay_low_rank,
807            a_low_rank,
808            v_low_rank,
809            g_low_rank,
810        };
811
812        // Load embeddings
813        let embeddings = Self::tensor_from(&weights, "model.embeddings.weight")?;
814
815        // Load output norm
816        let ln_out_w = Self::tensor_from(&weights, "model.norm.weight")?;
817        let ln_out_b = Self::tensor_from(&weights, "model.norm.bias")?;
818
819        // Load LM head
820        let lm_head = Self::tensor_from(&weights, "lm_head.weight")?;
821
822        // Load blocks
823        let mut blocks = Vec::with_capacity(num_layers);
824        for i in 0..num_layers {
825            let prefix = format!("model.layers.{}", i);
826
827            // Pre-norm (layer 0 only)
828            let (pre_norm_w, pre_norm_b) = if i == 0 {
829                (
830                    Some(Self::tensor_from(
831                        &weights,
832                        &format!("{}.pre_norm.weight", prefix),
833                    )?),
834                    Some(Self::tensor_from(
835                        &weights,
836                        &format!("{}.pre_norm.bias", prefix),
837                    )?),
838                )
839            } else {
840                (None, None)
841            };
842
843            // Norms
844            let attn_norm_w = Self::tensor_from(&weights, &format!("{}.attn_norm.weight", prefix))?;
845            let attn_norm_b = Self::tensor_from(&weights, &format!("{}.attn_norm.bias", prefix))?;
846            let ffn_norm_w = Self::tensor_from(&weights, &format!("{}.ffn_norm.weight", prefix))?;
847            let ffn_norm_b = Self::tensor_from(&weights, &format!("{}.ffn_norm.bias", prefix))?;
848
849            // Attention weights
850            // Load r/k/v projections and pack them contiguously
851            let r_proj_data = weights
852                .require(&format!("{}.attn.r_proj.weight", prefix))?
853                .data();
854            let k_proj_data = weights
855                .require(&format!("{}.attn.k_proj.weight", prefix))?
856                .data();
857            let v_proj_data = weights
858                .require(&format!("{}.attn.v_proj.weight", prefix))?
859                .data();
860
861            // Create packed RKV tensor: [r_proj, k_proj, v_proj]
862            let proj_size = hidden_size * hidden_size;
863            let mut rkv_proj = Tensor1D::zeros(3 * proj_size);
864            rkv_proj.as_mut_slice()[0..proj_size].copy_from_slice(r_proj_data);
865            rkv_proj.as_mut_slice()[proj_size..2 * proj_size].copy_from_slice(k_proj_data);
866            rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size].copy_from_slice(v_proj_data);
867
868            let attn = AttentionWeights {
869                x_r: Self::tensor_from(&weights, &format!("{}.attn.x_r", prefix))?,
870                x_w: Self::tensor_from(&weights, &format!("{}.attn.x_w", prefix))?,
871                x_k: Self::tensor_from(&weights, &format!("{}.attn.x_k", prefix))?,
872                x_v: Self::tensor_from(&weights, &format!("{}.attn.x_v", prefix))?,
873                x_a: Self::tensor_from(&weights, &format!("{}.attn.x_a", prefix))?,
874                x_g: Self::tensor_from(&weights, &format!("{}.attn.x_g", prefix))?,
875
876                rkv_proj,
877                o_proj: Self::tensor_from(&weights, &format!("{}.attn.o_proj.weight", prefix))?,
878
879                w1: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.0.weight", prefix))?,
880                w2: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.weight", prefix))?,
881                w0: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.bias", prefix))?,
882
883                a1: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.0.weight", prefix))?,
884                a2: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.weight", prefix))?,
885                a0: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.bias", prefix))?,
886
887                v1: Self::optional_tensor_from(
888                    &weights,
889                    &format!("{}.attn.v_lora.lora.0.weight", prefix),
890                ),
891                v2: Self::optional_tensor_from(
892                    &weights,
893                    &format!("{}.attn.v_lora.lora.2.weight", prefix),
894                ),
895                v0: Self::optional_tensor_from(
896                    &weights,
897                    &format!("{}.attn.v_lora.lora.2.bias", prefix),
898                ),
899
900                g1: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.0.weight", prefix))?,
901                g2: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.2.weight", prefix))?,
902
903                k_k: Self::tensor_from(&weights, &format!("{}.attn.k_k", prefix))?,
904                k_a: Self::tensor_from(&weights, &format!("{}.attn.k_a", prefix))?,
905                r_k: Self::tensor_from(&weights, &format!("{}.attn.r_k", prefix))?,
906
907                g_norm_w: Self::tensor_from(&weights, &format!("{}.attn.g_norm.weight", prefix))?,
908                g_norm_b: Self::tensor_from(&weights, &format!("{}.attn.g_norm.bias", prefix))?,
909            };
910
911            // FFN weights
912            let ffn = FfnWeights {
913                x_k: Self::tensor_from(&weights, &format!("{}.ffn.x_k", prefix))?,
914                key_w: Self::tensor_from(&weights, &format!("{}.ffn.key.weight", prefix))?,
915                value_w: Self::tensor_from(&weights, &format!("{}.ffn.value.weight", prefix))?,
916            };
917
918            blocks.push(BlockWeights {
919                pre_norm_w,
920                pre_norm_b,
921                attn_norm_w,
922                attn_norm_b,
923                ffn_norm_w,
924                ffn_norm_b,
925                attn,
926                ffn,
927            });
928        }
929
930        Ok(Self {
931            cfg,
932            embeddings,
933            ln_out_w,
934            ln_out_b,
935            lm_head,
936            blocks,
937        })
938    }
939
    /// Create a randomly initialized model for online-training workflows.
    ///
    /// `seed` drives a deterministic `RwkvRng`, so the same `(cfg, seed)`
    /// pair always yields identical weights. The RNG is consumed
    /// sequentially by the `init_*` calls below — their order is part of
    /// the reproducibility contract; do not reorder them.
    ///
    /// Returns an error if `cfg.validate()` rejects the configuration.
    pub fn new_random(cfg: Config, seed: u64) -> Result<Self> {
        cfg.validate()?;

        let mut rng = RwkvRng::new(seed);
        // Dimension shorthands used for every tensor allocation below.
        let c = cfg.hidden_size;
        let v = cfg.vocab_size;
        let i = cfg.intermediate_size;
        let d_w = cfg.decay_low_rank;
        let d_a = cfg.a_low_rank;
        let d_v = cfg.v_low_rank;
        let d_g = cfg.g_low_rank;

        // Embedding table (vocab, hidden): small uniform init.
        let mut embeddings = Tensor1D::zeros(v * c);
        init_uniform(&mut embeddings, &mut rng, 0.02);

        // Output layer norm starts as the identity (weight 1, bias 0).
        let mut ln_out_w = Tensor1D::zeros(c);
        let mut ln_out_b = Tensor1D::zeros(c);
        init_const(&mut ln_out_w, 1.0);
        init_const(&mut ln_out_b, 0.0);

        // LM head (vocab, hidden): small uniform init, untied from embeddings.
        let mut lm_head = Tensor1D::zeros(v * c);
        init_uniform(&mut lm_head, &mut rng, 0.02);

        let mut blocks = Vec::with_capacity(cfg.num_layers);
        for layer_idx in 0..cfg.num_layers {
            // Only layer 0 carries a pre-norm; later layers skip it.
            let (pre_norm_w, pre_norm_b) = if layer_idx == 0 {
                let mut w = Tensor1D::zeros(c);
                let mut b = Tensor1D::zeros(c);
                init_const(&mut w, 1.0);
                init_const(&mut b, 0.0);
                (Some(w), Some(b))
            } else {
                (None, None)
            };

            // Per-block layer norms start as the identity.
            let mut attn_norm_w = Tensor1D::zeros(c);
            let mut attn_norm_b = Tensor1D::zeros(c);
            init_const(&mut attn_norm_w, 1.0);
            init_const(&mut attn_norm_b, 0.0);

            let mut ffn_norm_w = Tensor1D::zeros(c);
            let mut ffn_norm_b = Tensor1D::zeros(c);
            init_const(&mut ffn_norm_w, 1.0);
            init_const(&mut ffn_norm_b, 0.0);

            // Fused r/k/v projection stored as one (3, C, C) buffer.
            let mut rkv_proj = Tensor1D::zeros(3 * c * c);
            init_uniform(&mut rkv_proj, &mut rng, 0.02);

            let mut o_proj = Tensor1D::zeros(c * c);
            init_uniform(&mut o_proj, &mut rng, 0.02);

            // Decay ("w") low-rank adapter: down (d_w, C), up (C, d_w),
            // plus a zero bias w0.
            let mut w1 = Tensor1D::zeros(d_w * c);
            let mut w2 = Tensor1D::zeros(c * d_w);
            let mut w0 = Tensor1D::zeros(c);
            init_uniform(&mut w1, &mut rng, 0.02);
            init_uniform(&mut w2, &mut rng, 0.02);
            init_const(&mut w0, 0.0);

            // "a" low-rank adapter, same layout as the decay adapter.
            let mut a1 = Tensor1D::zeros(d_a * c);
            let mut a2 = Tensor1D::zeros(c * d_a);
            let mut a0 = Tensor1D::zeros(c);
            init_uniform(&mut a1, &mut rng, 0.02);
            init_uniform(&mut a2, &mut rng, 0.02);
            init_const(&mut a0, 0.0);

            // The "v" adapter exists only on layers > 0 (layer 0 has none);
            // `save_safetensors` and the Adam-state builders mirror this.
            let (v1, v2, v0) = if layer_idx == 0 {
                (None, None, None)
            } else {
                let mut v1 = Tensor1D::zeros(d_v * c);
                let mut v2 = Tensor1D::zeros(c * d_v);
                let mut v0 = Tensor1D::zeros(c);
                init_uniform(&mut v1, &mut rng, 0.02);
                init_uniform(&mut v2, &mut rng, 0.02);
                init_const(&mut v0, 0.0);
                (Some(v1), Some(v2), Some(v0))
            };

            // Gate ("g") low-rank adapter has no bias term.
            let mut g1 = Tensor1D::zeros(d_g * c);
            let mut g2 = Tensor1D::zeros(c * d_g);
            init_uniform(&mut g1, &mut rng, 0.02);
            init_uniform(&mut g2, &mut rng, 0.02);

            // Token-shift mixing vectors, initialized near 0.5
            // (per `init_centered(center = 0.5, spread = 0.02)` — see its
            // definition for the exact distribution).
            let mut x_r = Tensor1D::zeros(c);
            let mut x_w = Tensor1D::zeros(c);
            let mut x_k = Tensor1D::zeros(c);
            let mut x_v = Tensor1D::zeros(c);
            let mut x_a = Tensor1D::zeros(c);
            let mut x_g = Tensor1D::zeros(c);
            init_centered(&mut x_r, &mut rng, 0.5, 0.02);
            init_centered(&mut x_w, &mut rng, 0.5, 0.02);
            init_centered(&mut x_k, &mut rng, 0.5, 0.02);
            init_centered(&mut x_v, &mut rng, 0.5, 0.02);
            init_centered(&mut x_a, &mut rng, 0.5, 0.02);
            init_centered(&mut x_g, &mut rng, 0.5, 0.02);

            // Per-channel scaling vectors start at 1 (neutral).
            let mut k_k = Tensor1D::zeros(c);
            let mut k_a = Tensor1D::zeros(c);
            let mut r_k = Tensor1D::zeros(c);
            init_const(&mut k_k, 1.0);
            init_const(&mut k_a, 1.0);
            init_const(&mut r_k, 1.0);

            // Group norm after the attention output starts as identity.
            let mut g_norm_w = Tensor1D::zeros(c);
            let mut g_norm_b = Tensor1D::zeros(c);
            init_const(&mut g_norm_w, 1.0);
            init_const(&mut g_norm_b, 0.0);

            let attn = AttentionWeights {
                x_r,
                x_w,
                x_k,
                x_v,
                x_a,
                x_g,
                rkv_proj,
                o_proj,
                w1,
                w2,
                w0,
                a1,
                a2,
                a0,
                v1,
                v2,
                v0,
                g1,
                g2,
                k_k,
                k_a,
                r_k,
                g_norm_w,
                g_norm_b,
            };

            // FFN: token-shift vector plus key (I, C) / value (C, I) mats.
            let mut ffn_x_k = Tensor1D::zeros(c);
            init_centered(&mut ffn_x_k, &mut rng, 0.5, 0.02);
            let mut key_w = Tensor1D::zeros(i * c);
            let mut value_w = Tensor1D::zeros(c * i);
            init_uniform(&mut key_w, &mut rng, 0.02);
            init_uniform(&mut value_w, &mut rng, 0.02);

            let ffn = FfnWeights {
                x_k: ffn_x_k,
                key_w,
                value_w,
            };

            blocks.push(BlockWeights {
                pre_norm_w,
                pre_norm_b,
                attn_norm_w,
                attn_norm_b,
                ffn_norm_w,
                ffn_norm_b,
                attn,
                ffn,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            ln_out_w,
            ln_out_b,
            lm_head,
            blocks,
        })
    }
1109
1110    /// Save model weights to a `.safetensors` file plus JSON sidecar config.
1111    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<()> {
1112        #[derive(Clone)]
1113        struct TensorRec {
1114            name: String,
1115            shape: Vec<usize>,
1116            data: Vec<f32>,
1117        }
1118
1119        let c = self.cfg.hidden_size;
1120        let v = self.cfg.vocab_size;
1121        let i = self.cfg.intermediate_size;
1122        let d_w = self.cfg.decay_low_rank;
1123        let d_a = self.cfg.a_low_rank;
1124        let d_v = self.cfg.v_low_rank;
1125        let d_g = self.cfg.g_low_rank;
1126
1127        let mut recs = Vec::<TensorRec>::new();
1128        let push = |recs: &mut Vec<TensorRec>, name: String, shape: Vec<usize>, src: &Tensor1D| {
1129            recs.push(TensorRec {
1130                name,
1131                shape,
1132                data: src.as_slice().to_vec(),
1133            });
1134        };
1135
1136        push(
1137            &mut recs,
1138            "model.embeddings.weight".to_string(),
1139            vec![v, c],
1140            &self.embeddings,
1141        );
1142        push(
1143            &mut recs,
1144            "model.norm.weight".to_string(),
1145            vec![c],
1146            &self.ln_out_w,
1147        );
1148        push(
1149            &mut recs,
1150            "model.norm.bias".to_string(),
1151            vec![c],
1152            &self.ln_out_b,
1153        );
1154        push(
1155            &mut recs,
1156            "lm_head.weight".to_string(),
1157            vec![v, c],
1158            &self.lm_head,
1159        );
1160
1161        for (idx, b) in self.blocks.iter().enumerate() {
1162            let pfx = format!("model.layers.{idx}");
1163            if let (Some(w), Some(bias)) = (&b.pre_norm_w, &b.pre_norm_b) {
1164                push(&mut recs, format!("{pfx}.pre_norm.weight"), vec![c], w);
1165                push(&mut recs, format!("{pfx}.pre_norm.bias"), vec![c], bias);
1166            }
1167
1168            push(
1169                &mut recs,
1170                format!("{pfx}.attn_norm.weight"),
1171                vec![c],
1172                &b.attn_norm_w,
1173            );
1174            push(
1175                &mut recs,
1176                format!("{pfx}.attn_norm.bias"),
1177                vec![c],
1178                &b.attn_norm_b,
1179            );
1180            push(
1181                &mut recs,
1182                format!("{pfx}.ffn_norm.weight"),
1183                vec![c],
1184                &b.ffn_norm_w,
1185            );
1186            push(
1187                &mut recs,
1188                format!("{pfx}.ffn_norm.bias"),
1189                vec![c],
1190                &b.ffn_norm_b,
1191            );
1192
1193            let proj = b.attn.rkv_proj.as_slice();
1194            let proj_size = c * c;
1195            recs.push(TensorRec {
1196                name: format!("{pfx}.attn.r_proj.weight"),
1197                shape: vec![c, c],
1198                data: proj[0..proj_size].to_vec(),
1199            });
1200            recs.push(TensorRec {
1201                name: format!("{pfx}.attn.k_proj.weight"),
1202                shape: vec![c, c],
1203                data: proj[proj_size..2 * proj_size].to_vec(),
1204            });
1205            recs.push(TensorRec {
1206                name: format!("{pfx}.attn.v_proj.weight"),
1207                shape: vec![c, c],
1208                data: proj[2 * proj_size..3 * proj_size].to_vec(),
1209            });
1210
1211            push(
1212                &mut recs,
1213                format!("{pfx}.attn.o_proj.weight"),
1214                vec![c, c],
1215                &b.attn.o_proj,
1216            );
1217            push(&mut recs, format!("{pfx}.attn.x_r"), vec![c], &b.attn.x_r);
1218            push(&mut recs, format!("{pfx}.attn.x_w"), vec![c], &b.attn.x_w);
1219            push(&mut recs, format!("{pfx}.attn.x_k"), vec![c], &b.attn.x_k);
1220            push(&mut recs, format!("{pfx}.attn.x_v"), vec![c], &b.attn.x_v);
1221            push(&mut recs, format!("{pfx}.attn.x_a"), vec![c], &b.attn.x_a);
1222            push(&mut recs, format!("{pfx}.attn.x_g"), vec![c], &b.attn.x_g);
1223
1224            push(
1225                &mut recs,
1226                format!("{pfx}.attn.w_lora.lora.0.weight"),
1227                vec![d_w, c],
1228                &b.attn.w1,
1229            );
1230            push(
1231                &mut recs,
1232                format!("{pfx}.attn.w_lora.lora.2.weight"),
1233                vec![c, d_w],
1234                &b.attn.w2,
1235            );
1236            push(
1237                &mut recs,
1238                format!("{pfx}.attn.w_lora.lora.2.bias"),
1239                vec![c],
1240                &b.attn.w0,
1241            );
1242
1243            push(
1244                &mut recs,
1245                format!("{pfx}.attn.a_lora.lora.0.weight"),
1246                vec![d_a, c],
1247                &b.attn.a1,
1248            );
1249            push(
1250                &mut recs,
1251                format!("{pfx}.attn.a_lora.lora.2.weight"),
1252                vec![c, d_a],
1253                &b.attn.a2,
1254            );
1255            push(
1256                &mut recs,
1257                format!("{pfx}.attn.a_lora.lora.2.bias"),
1258                vec![c],
1259                &b.attn.a0,
1260            );
1261
1262            if let Some(v1) = &b.attn.v1 {
1263                push(
1264                    &mut recs,
1265                    format!("{pfx}.attn.v_lora.lora.0.weight"),
1266                    vec![d_v, c],
1267                    v1,
1268                );
1269            }
1270            if let Some(v2) = &b.attn.v2 {
1271                push(
1272                    &mut recs,
1273                    format!("{pfx}.attn.v_lora.lora.2.weight"),
1274                    vec![c, d_v],
1275                    v2,
1276                );
1277            }
1278            if let Some(v0) = &b.attn.v0 {
1279                push(
1280                    &mut recs,
1281                    format!("{pfx}.attn.v_lora.lora.2.bias"),
1282                    vec![c],
1283                    v0,
1284                );
1285            }
1286
1287            push(
1288                &mut recs,
1289                format!("{pfx}.attn.g_lora.lora.0.weight"),
1290                vec![d_g, c],
1291                &b.attn.g1,
1292            );
1293            push(
1294                &mut recs,
1295                format!("{pfx}.attn.g_lora.lora.2.weight"),
1296                vec![c, d_g],
1297                &b.attn.g2,
1298            );
1299
1300            push(&mut recs, format!("{pfx}.attn.k_k"), vec![c], &b.attn.k_k);
1301            push(&mut recs, format!("{pfx}.attn.k_a"), vec![c], &b.attn.k_a);
1302            push(&mut recs, format!("{pfx}.attn.r_k"), vec![c], &b.attn.r_k);
1303            push(
1304                &mut recs,
1305                format!("{pfx}.attn.g_norm.weight"),
1306                vec![c],
1307                &b.attn.g_norm_w,
1308            );
1309            push(
1310                &mut recs,
1311                format!("{pfx}.attn.g_norm.bias"),
1312                vec![c],
1313                &b.attn.g_norm_b,
1314            );
1315
1316            push(&mut recs, format!("{pfx}.ffn.x_k"), vec![c], &b.ffn.x_k);
1317            push(
1318                &mut recs,
1319                format!("{pfx}.ffn.key.weight"),
1320                vec![i, c],
1321                &b.ffn.key_w,
1322            );
1323            push(
1324                &mut recs,
1325                format!("{pfx}.ffn.value.weight"),
1326                vec![c, i],
1327                &b.ffn.value_w,
1328            );
1329        }
1330
1331        recs.sort_by(|a, b| a.name.cmp(&b.name));
1332        let mut offset = 0usize;
1333        let mut header = serde_json::Map::new();
1334        header.insert("__metadata__".to_string(), json!({}));
1335        for rec in &recs {
1336            let bytes = rec.data.len() * 4;
1337            header.insert(
1338                rec.name.clone(),
1339                json!({
1340                    "dtype": "F32",
1341                    "shape": rec.shape,
1342                    "data_offsets": [offset, offset + bytes]
1343                }),
1344            );
1345            offset += bytes;
1346        }
1347        let header_bytes = serde_json::to_vec(&header)?;
1348        let mut f = File::create(path.as_ref())?;
1349        f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
1350        f.write_all(&header_bytes)?;
1351        for rec in &recs {
1352            for v in &rec.data {
1353                f.write_all(&v.to_le_bytes())?;
1354            }
1355        }
1356        Ok(())
1357    }
1358
1359    /// Allocate zero-initialized Adam moments matching all trainable tensors.
1360    pub fn new_full_adam_state(&self) -> FullAdamState {
1361        let mut blocks = Vec::with_capacity(self.blocks.len());
1362        for b in &self.blocks {
1363            blocks.push(BlockAdamState {
1364                pre_norm_w: b.pre_norm_w.as_ref().map(|t| AdamTensorState::new(t.len())),
1365                pre_norm_b: b.pre_norm_b.as_ref().map(|t| AdamTensorState::new(t.len())),
1366                attn_norm_w: AdamTensorState::new(b.attn_norm_w.len()),
1367                attn_norm_b: AdamTensorState::new(b.attn_norm_b.len()),
1368                ffn_norm_w: AdamTensorState::new(b.ffn_norm_w.len()),
1369                ffn_norm_b: AdamTensorState::new(b.ffn_norm_b.len()),
1370                attn: AttentionAdamState {
1371                    x_r: AdamTensorState::new(b.attn.x_r.len()),
1372                    x_w: AdamTensorState::new(b.attn.x_w.len()),
1373                    x_k: AdamTensorState::new(b.attn.x_k.len()),
1374                    x_v: AdamTensorState::new(b.attn.x_v.len()),
1375                    x_a: AdamTensorState::new(b.attn.x_a.len()),
1376                    x_g: AdamTensorState::new(b.attn.x_g.len()),
1377                    rkv_proj: AdamTensorState::new(b.attn.rkv_proj.len()),
1378                    o_proj: AdamTensorState::new(b.attn.o_proj.len()),
1379                    w1: AdamTensorState::new(b.attn.w1.len()),
1380                    w2: AdamTensorState::new(b.attn.w2.len()),
1381                    w0: AdamTensorState::new(b.attn.w0.len()),
1382                    a1: AdamTensorState::new(b.attn.a1.len()),
1383                    a2: AdamTensorState::new(b.attn.a2.len()),
1384                    a0: AdamTensorState::new(b.attn.a0.len()),
1385                    v1: b.attn.v1.as_ref().map(|t| AdamTensorState::new(t.len())),
1386                    v2: b.attn.v2.as_ref().map(|t| AdamTensorState::new(t.len())),
1387                    v0: b.attn.v0.as_ref().map(|t| AdamTensorState::new(t.len())),
1388                    g1: AdamTensorState::new(b.attn.g1.len()),
1389                    g2: AdamTensorState::new(b.attn.g2.len()),
1390                    k_k: AdamTensorState::new(b.attn.k_k.len()),
1391                    k_a: AdamTensorState::new(b.attn.k_a.len()),
1392                    r_k: AdamTensorState::new(b.attn.r_k.len()),
1393                    g_norm_w: AdamTensorState::new(b.attn.g_norm_w.len()),
1394                    g_norm_b: AdamTensorState::new(b.attn.g_norm_b.len()),
1395                },
1396                ffn: FfnAdamState {
1397                    x_k: AdamTensorState::new(b.ffn.x_k.len()),
1398                    key_w: AdamTensorState::new(b.ffn.key_w.len()),
1399                    value_w: AdamTensorState::new(b.ffn.value_w.len()),
1400                },
1401            });
1402        }
1403        FullAdamState {
1404            embeddings: AdamTensorState::new(self.embeddings.len()),
1405            ln_out_w: AdamTensorState::new(self.ln_out_w.len()),
1406            ln_out_b: AdamTensorState::new(self.ln_out_b.len()),
1407            lm_head: AdamTensorState::new(self.lm_head.len()),
1408            blocks,
1409        }
1410    }
1411
1412    /// Allocate zero-initialized gradient storage matching all trainable tensors.
1413    fn new_full_grad_state(&self) -> FullGradState {
1414        let mut blocks = Vec::with_capacity(self.blocks.len());
1415        for b in &self.blocks {
1416            blocks.push(BlockGradState {
1417                pre_norm_w: b.pre_norm_w.as_ref().map(|t| Tensor1D::zeros(t.len())),
1418                pre_norm_b: b.pre_norm_b.as_ref().map(|t| Tensor1D::zeros(t.len())),
1419                attn_norm_w: Tensor1D::zeros(b.attn_norm_w.len()),
1420                attn_norm_b: Tensor1D::zeros(b.attn_norm_b.len()),
1421                ffn_norm_w: Tensor1D::zeros(b.ffn_norm_w.len()),
1422                ffn_norm_b: Tensor1D::zeros(b.ffn_norm_b.len()),
1423                attn: AttentionGradState {
1424                    x_r: Tensor1D::zeros(b.attn.x_r.len()),
1425                    x_w: Tensor1D::zeros(b.attn.x_w.len()),
1426                    x_k: Tensor1D::zeros(b.attn.x_k.len()),
1427                    x_v: Tensor1D::zeros(b.attn.x_v.len()),
1428                    x_a: Tensor1D::zeros(b.attn.x_a.len()),
1429                    x_g: Tensor1D::zeros(b.attn.x_g.len()),
1430                    rkv_proj: Tensor1D::zeros(b.attn.rkv_proj.len()),
1431                    o_proj: Tensor1D::zeros(b.attn.o_proj.len()),
1432                    w1: Tensor1D::zeros(b.attn.w1.len()),
1433                    w2: Tensor1D::zeros(b.attn.w2.len()),
1434                    w0: Tensor1D::zeros(b.attn.w0.len()),
1435                    a1: Tensor1D::zeros(b.attn.a1.len()),
1436                    a2: Tensor1D::zeros(b.attn.a2.len()),
1437                    a0: Tensor1D::zeros(b.attn.a0.len()),
1438                    v1: b.attn.v1.as_ref().map(|t| Tensor1D::zeros(t.len())),
1439                    v2: b.attn.v2.as_ref().map(|t| Tensor1D::zeros(t.len())),
1440                    v0: b.attn.v0.as_ref().map(|t| Tensor1D::zeros(t.len())),
1441                    g1: Tensor1D::zeros(b.attn.g1.len()),
1442                    g2: Tensor1D::zeros(b.attn.g2.len()),
1443                    k_k: Tensor1D::zeros(b.attn.k_k.len()),
1444                    k_a: Tensor1D::zeros(b.attn.k_a.len()),
1445                    r_k: Tensor1D::zeros(b.attn.r_k.len()),
1446                    g_norm_w: Tensor1D::zeros(b.attn.g_norm_w.len()),
1447                    g_norm_b: Tensor1D::zeros(b.attn.g_norm_b.len()),
1448                },
1449                ffn: FfnGradState {
1450                    x_k: Tensor1D::zeros(b.ffn.x_k.len()),
1451                    key_w: Tensor1D::zeros(b.ffn.key_w.len()),
1452                    value_w: Tensor1D::zeros(b.ffn.value_w.len()),
1453                },
1454            });
1455        }
1456        FullGradState {
1457            embeddings: Tensor1D::zeros(self.embeddings.len()),
1458            ln_out_w: Tensor1D::zeros(self.ln_out_w.len()),
1459            ln_out_b: Tensor1D::zeros(self.ln_out_b.len()),
1460            lm_head: Tensor1D::zeros(self.lm_head.len()),
1461            blocks,
1462        }
1463    }
1464
    /// Allocate a fresh `RecurrentGradState` sized from this model's config.
    fn new_recurrent_grad_state(&self) -> RecurrentGradState {
        RecurrentGradState::new(&self.cfg)
    }
1468
1469    /// Save full-parameter Adam moments for exact online-training continuation.
1470    pub fn save_full_adam_safetensors<P: AsRef<Path>>(
1471        &self,
1472        adam: &FullAdamState,
1473        path: P,
1474    ) -> Result<()> {
1475        #[derive(Clone)]
1476        struct TensorRec {
1477            name: String,
1478            shape: Vec<usize>,
1479            data: Vec<f32>,
1480        }
1481        let c = self.cfg.hidden_size;
1482        let i = self.cfg.intermediate_size;
1483        let v = self.cfg.vocab_size;
1484        let h = self.cfg.num_heads;
1485        let n = self.cfg.head_dim;
1486        let d_w = self.cfg.decay_low_rank;
1487        let d_a = self.cfg.a_low_rank;
1488        let d_v = self.cfg.v_low_rank;
1489        let d_g = self.cfg.g_low_rank;
1490        let mut recs = Vec::<TensorRec>::new();
1491        let mut push_state = |name: &str, shape: Vec<usize>, st: &AdamTensorState| {
1492            recs.push(TensorRec {
1493                name: format!("{name}.m"),
1494                shape: shape.clone(),
1495                data: st.m.as_slice().to_vec(),
1496            });
1497            recs.push(TensorRec {
1498                name: format!("{name}.v"),
1499                shape,
1500                data: st.v.as_slice().to_vec(),
1501            });
1502        };
1503
1504        push_state("opt.model.embeddings.weight", vec![v, c], &adam.embeddings);
1505        push_state("opt.model.norm.weight", vec![c], &adam.ln_out_w);
1506        push_state("opt.model.norm.bias", vec![c], &adam.ln_out_b);
1507        push_state("opt.lm_head.weight", vec![v, c], &adam.lm_head);
1508        for (idx, b) in adam.blocks.iter().enumerate() {
1509            let p = format!("opt.model.layers.{idx}");
1510            if let Some(st) = &b.pre_norm_w {
1511                push_state(&format!("{p}.pre_norm.weight"), vec![c], st);
1512            }
1513            if let Some(st) = &b.pre_norm_b {
1514                push_state(&format!("{p}.pre_norm.bias"), vec![c], st);
1515            }
1516            push_state(&format!("{p}.attn_norm.weight"), vec![c], &b.attn_norm_w);
1517            push_state(&format!("{p}.attn_norm.bias"), vec![c], &b.attn_norm_b);
1518            push_state(&format!("{p}.ffn_norm.weight"), vec![c], &b.ffn_norm_w);
1519            push_state(&format!("{p}.ffn_norm.bias"), vec![c], &b.ffn_norm_b);
1520
1521            push_state(&format!("{p}.attn.x_r"), vec![c], &b.attn.x_r);
1522            push_state(&format!("{p}.attn.x_w"), vec![c], &b.attn.x_w);
1523            push_state(&format!("{p}.attn.x_k"), vec![c], &b.attn.x_k);
1524            push_state(&format!("{p}.attn.x_v"), vec![c], &b.attn.x_v);
1525            push_state(&format!("{p}.attn.x_a"), vec![c], &b.attn.x_a);
1526            push_state(&format!("{p}.attn.x_g"), vec![c], &b.attn.x_g);
1527            push_state(
1528                &format!("{p}.attn.rkv_proj"),
1529                vec![3, c, c],
1530                &b.attn.rkv_proj,
1531            );
1532            push_state(
1533                &format!("{p}.attn.o_proj.weight"),
1534                vec![c, c],
1535                &b.attn.o_proj,
1536            );
1537            push_state(
1538                &format!("{p}.attn.w_lora.lora.0.weight"),
1539                vec![d_w, c],
1540                &b.attn.w1,
1541            );
1542            push_state(
1543                &format!("{p}.attn.w_lora.lora.2.weight"),
1544                vec![c, d_w],
1545                &b.attn.w2,
1546            );
1547            push_state(&format!("{p}.attn.w_lora.lora.2.bias"), vec![c], &b.attn.w0);
1548            push_state(
1549                &format!("{p}.attn.a_lora.lora.0.weight"),
1550                vec![d_a, c],
1551                &b.attn.a1,
1552            );
1553            push_state(
1554                &format!("{p}.attn.a_lora.lora.2.weight"),
1555                vec![c, d_a],
1556                &b.attn.a2,
1557            );
1558            push_state(&format!("{p}.attn.a_lora.lora.2.bias"), vec![c], &b.attn.a0);
1559            if let Some(st) = &b.attn.v1 {
1560                push_state(&format!("{p}.attn.v_lora.lora.0.weight"), vec![d_v, c], st);
1561            }
1562            if let Some(st) = &b.attn.v2 {
1563                push_state(&format!("{p}.attn.v_lora.lora.2.weight"), vec![c, d_v], st);
1564            }
1565            if let Some(st) = &b.attn.v0 {
1566                push_state(&format!("{p}.attn.v_lora.lora.2.bias"), vec![c], st);
1567            }
1568            push_state(
1569                &format!("{p}.attn.g_lora.lora.0.weight"),
1570                vec![d_g, c],
1571                &b.attn.g1,
1572            );
1573            push_state(
1574                &format!("{p}.attn.g_lora.lora.2.weight"),
1575                vec![c, d_g],
1576                &b.attn.g2,
1577            );
1578            push_state(&format!("{p}.attn.k_k"), vec![c], &b.attn.k_k);
1579            push_state(&format!("{p}.attn.k_a"), vec![c], &b.attn.k_a);
1580            push_state(&format!("{p}.attn.r_k"), vec![h, n], &b.attn.r_k);
1581            push_state(
1582                &format!("{p}.attn.g_norm.weight"),
1583                vec![c],
1584                &b.attn.g_norm_w,
1585            );
1586            push_state(&format!("{p}.attn.g_norm.bias"), vec![c], &b.attn.g_norm_b);
1587
1588            push_state(&format!("{p}.ffn.x_k"), vec![c], &b.ffn.x_k);
1589            push_state(&format!("{p}.ffn.key.weight"), vec![i, c], &b.ffn.key_w);
1590            push_state(&format!("{p}.ffn.value.weight"), vec![c, i], &b.ffn.value_w);
1591        }
1592
1593        recs.sort_by(|a, b| a.name.cmp(&b.name));
1594        let mut offset = 0usize;
1595        let mut header = serde_json::Map::new();
1596        header.insert("__metadata__".to_string(), json!({}));
1597        for rec in &recs {
1598            let bytes = rec.data.len() * 4;
1599            header.insert(
1600                rec.name.clone(),
1601                json!({
1602                    "dtype": "F32",
1603                    "shape": rec.shape,
1604                    "data_offsets": [offset, offset + bytes],
1605                }),
1606            );
1607            offset += bytes;
1608        }
1609
1610        let header_bytes = serde_json::to_vec(&header)?;
1611        let mut f = File::create(path)?;
1612        f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
1613        f.write_all(&header_bytes)?;
1614        for rec in &recs {
1615            for v in &rec.data {
1616                f.write_all(&v.to_le_bytes())?;
1617            }
1618        }
1619        Ok(())
1620    }
1621
1622    /// Load full-parameter Adam moments and validate tensor shapes.
1623    pub fn load_full_adam_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<FullAdamState> {
1624        let weights = Weights::load(path.as_ref()).with_context(|| {
1625            format!(
1626                "failed to load optimizer moments from {}",
1627                path.as_ref().display()
1628            )
1629        })?;
1630        let mut adam = self.new_full_adam_state();
1631        let load_state = |name: &str, st: &mut AdamTensorState| -> Result<()> {
1632            let m_name = format!("{name}.m");
1633            let v_name = format!("{name}.v");
1634            let m_t = weights
1635                .require(&m_name)
1636                .with_context(|| format!("missing optimizer tensor '{m_name}'"))?;
1637            let v_t = weights
1638                .require(&v_name)
1639                .with_context(|| format!("missing optimizer tensor '{v_name}'"))?;
1640            if m_t.data().len() != st.m.len() {
1641                bail!(
1642                    "optimizer tensor '{}' len {} != expected {}",
1643                    m_name,
1644                    m_t.data().len(),
1645                    st.m.len()
1646                );
1647            }
1648            if v_t.data().len() != st.v.len() {
1649                bail!(
1650                    "optimizer tensor '{}' len {} != expected {}",
1651                    v_name,
1652                    v_t.data().len(),
1653                    st.v.len()
1654                );
1655            }
1656            st.m.as_mut_slice().copy_from_slice(m_t.data());
1657            st.v.as_mut_slice().copy_from_slice(v_t.data());
1658            Ok(())
1659        };
1660
1661        let c = self.cfg.hidden_size;
1662        let i = self.cfg.intermediate_size;
1663        let v = self.cfg.vocab_size;
1664        let h = self.cfg.num_heads;
1665        let n = self.cfg.head_dim;
1666        let _ = (c, i, v, h, n);
1667        load_state("opt.model.embeddings.weight", &mut adam.embeddings)?;
1668        load_state("opt.model.norm.weight", &mut adam.ln_out_w)?;
1669        load_state("opt.model.norm.bias", &mut adam.ln_out_b)?;
1670        load_state("opt.lm_head.weight", &mut adam.lm_head)?;
1671        for (idx, b) in adam.blocks.iter_mut().enumerate() {
1672            let p = format!("opt.model.layers.{idx}");
1673            if let Some(st) = b.pre_norm_w.as_mut() {
1674                load_state(&format!("{p}.pre_norm.weight"), st)?;
1675            }
1676            if let Some(st) = b.pre_norm_b.as_mut() {
1677                load_state(&format!("{p}.pre_norm.bias"), st)?;
1678            }
1679            load_state(&format!("{p}.attn_norm.weight"), &mut b.attn_norm_w)?;
1680            load_state(&format!("{p}.attn_norm.bias"), &mut b.attn_norm_b)?;
1681            load_state(&format!("{p}.ffn_norm.weight"), &mut b.ffn_norm_w)?;
1682            load_state(&format!("{p}.ffn_norm.bias"), &mut b.ffn_norm_b)?;
1683            load_state(&format!("{p}.attn.x_r"), &mut b.attn.x_r)?;
1684            load_state(&format!("{p}.attn.x_w"), &mut b.attn.x_w)?;
1685            load_state(&format!("{p}.attn.x_k"), &mut b.attn.x_k)?;
1686            load_state(&format!("{p}.attn.x_v"), &mut b.attn.x_v)?;
1687            load_state(&format!("{p}.attn.x_a"), &mut b.attn.x_a)?;
1688            load_state(&format!("{p}.attn.x_g"), &mut b.attn.x_g)?;
1689            load_state(&format!("{p}.attn.rkv_proj"), &mut b.attn.rkv_proj)?;
1690            load_state(&format!("{p}.attn.o_proj.weight"), &mut b.attn.o_proj)?;
1691            load_state(&format!("{p}.attn.w_lora.lora.0.weight"), &mut b.attn.w1)?;
1692            load_state(&format!("{p}.attn.w_lora.lora.2.weight"), &mut b.attn.w2)?;
1693            load_state(&format!("{p}.attn.w_lora.lora.2.bias"), &mut b.attn.w0)?;
1694            load_state(&format!("{p}.attn.a_lora.lora.0.weight"), &mut b.attn.a1)?;
1695            load_state(&format!("{p}.attn.a_lora.lora.2.weight"), &mut b.attn.a2)?;
1696            load_state(&format!("{p}.attn.a_lora.lora.2.bias"), &mut b.attn.a0)?;
1697            if let Some(st) = b.attn.v1.as_mut() {
1698                load_state(&format!("{p}.attn.v_lora.lora.0.weight"), st)?;
1699            }
1700            if let Some(st) = b.attn.v2.as_mut() {
1701                load_state(&format!("{p}.attn.v_lora.lora.2.weight"), st)?;
1702            }
1703            if let Some(st) = b.attn.v0.as_mut() {
1704                load_state(&format!("{p}.attn.v_lora.lora.2.bias"), st)?;
1705            }
1706            load_state(&format!("{p}.attn.g_lora.lora.0.weight"), &mut b.attn.g1)?;
1707            load_state(&format!("{p}.attn.g_lora.lora.2.weight"), &mut b.attn.g2)?;
1708            load_state(&format!("{p}.attn.k_k"), &mut b.attn.k_k)?;
1709            load_state(&format!("{p}.attn.k_a"), &mut b.attn.k_a)?;
1710            load_state(&format!("{p}.attn.r_k"), &mut b.attn.r_k)?;
1711            load_state(&format!("{p}.attn.g_norm.weight"), &mut b.attn.g_norm_w)?;
1712            load_state(&format!("{p}.attn.g_norm.bias"), &mut b.attn.g_norm_b)?;
1713            load_state(&format!("{p}.ffn.x_k"), &mut b.ffn.x_k)?;
1714            load_state(&format!("{p}.ffn.key.weight"), &mut b.ffn.key_w)?;
1715            load_state(&format!("{p}.ffn.value.weight"), &mut b.ffn.value_w)?;
1716        }
1717        Ok(adam)
1718    }
1719
1720    /// Get model configuration.
1721    pub fn config(&self) -> &Config {
1722        &self.cfg
1723    }
1724
1725    /// Create new state for this model.
1726    pub fn new_state(&self) -> State {
1727        State::new(&self.cfg)
1728    }
1729
1730    /// Immutable LM-head weights, row-major `(vocab, hidden)`.
1731    #[inline]
1732    pub fn lm_head_weights(&self) -> &[f32] {
1733        self.lm_head.as_slice()
1734    }
1735
1736    /// Mutable LM-head weights, row-major `(vocab, hidden)`.
1737    #[inline]
1738    pub fn lm_head_weights_mut(&mut self) -> &mut [f32] {
1739        self.lm_head.as_mut_slice()
1740    }
1741
    /// Apply accumulated full-model gradients to the weights in place.
    ///
    /// `scope` selects which parameter groups (output bias, LM head + final
    /// norm, per-block FFN / attention / norms, embeddings) actually receive
    /// an update; gradients for masked-out groups are ignored. `optimizer`
    /// chooses between plain SGD (`sgd_vec_update`) and Adam
    /// (`apply_adam_vec_update` with per-parameter moment state).
    ///
    /// For Adam, `adam_t` is the shared step counter (incremented once per
    /// call) and `model_adam` holds the per-parameter first/second moments;
    /// `out_bias_adam_m` / `out_bias_adam_v` carry the moments for the
    /// optional output bias, which is updated from the caller-supplied
    /// `out_bias` / `out_bias_grad` slices rather than from `grads`.
    ///
    /// # Errors
    /// Fails if Adam is selected but the required optimizer state
    /// (`model_adam` for non-head parameters, or the output-bias moment
    /// buffers) is missing.
    #[allow(clippy::too_many_arguments)]
    fn apply_full_gradients(
        &mut self,
        grads: &FullGradState,
        scope: TrainScopeMask,
        optimizer: OptimizerKind,
        lr: f32,
        clip: f32,
        adam_t: &mut usize,
        model_adam: Option<&mut FullAdamState>,
        out_bias: Option<&mut [f32]>,
        out_bias_grad: Option<&[f32]>,
        out_bias_adam_m: Option<&mut [f32]>,
        out_bias_adam_v: Option<&mut [f32]>,
    ) -> Result<()> {
        let mut adam_step = None::<AdamStep>;
        // Rebind so the Option can be `as_mut()`-borrowed repeatedly below.
        let mut model_adam = model_adam;
        if matches!(optimizer, OptimizerKind::Adam) {
            // Advance the shared Adam step counter once per call and
            // precompute the step-dependent bias-correction terms
            // (1 - beta^t) used by every per-tensor update.
            *adam_t = adam_t.saturating_add(1);
            let t = (*adam_t).max(1) as i32;
            let b1 = 0.9f32;
            let b2 = 0.999f32;
            adam_step = Some(AdamStep {
                lr,
                clip: clip.max(0.0),
                b1,
                b2,
                eps: 1e-8,
                bias_corr1: 1.0 - b1.powi(t),
                bias_corr2: 1.0 - b2.powi(t),
            });
            // Fail early: training anything beyond the head requires the
            // full per-parameter moment state.
            if scope.trains_non_head_params() && model_adam.is_none() {
                bail!("rwkv Adam full-training state is missing");
            }
        }

        // --- Output bias (external to the model; caller owns the buffers) ---
        if scope.bias
            && let (Some(bias), Some(grad)) = (out_bias, out_bias_grad)
        {
            match optimizer {
                OptimizerKind::Sgd => sgd_vec_update(bias, grad, lr, clip),
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let Some(m) = out_bias_adam_m else {
                        bail!("rwkv Adam output-bias state is missing (m)");
                    };
                    let Some(v) = out_bias_adam_v else {
                        bail!("rwkv Adam output-bias state is missing (v)");
                    };
                    apply_adam_vec_update_raw(bias, grad, m, v, cfg);
                }
            }
        }

        // --- LM head and final layer norm ---
        if scope.head {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.lm_head.as_mut_slice(),
                        grads.lm_head.as_slice(),
                        lr,
                        clip,
                    );
                    sgd_vec_update(
                        self.ln_out_w.as_mut_slice(),
                        grads.ln_out_w.as_slice(),
                        lr,
                        clip,
                    );
                    sgd_vec_update(
                        self.ln_out_b.as_mut_slice(),
                        grads.ln_out_b.as_slice(),
                        lr,
                        clip,
                    );
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.lm_head.as_mut_slice(),
                        grads.lm_head.as_slice(),
                        &mut adam.lm_head,
                        cfg,
                    );
                    apply_adam_vec_update(
                        self.ln_out_w.as_mut_slice(),
                        grads.ln_out_w.as_slice(),
                        &mut adam.ln_out_w,
                        cfg,
                    );
                    apply_adam_vec_update(
                        self.ln_out_b.as_mut_slice(),
                        grads.ln_out_b.as_slice(),
                        &mut adam.ln_out_b,
                        cfg,
                    );
                }
            }
        }

        // --- Per-block parameters. Each enabled scope group is updated with
        // the same optimizer step; tensor lists mirror the load order used by
        // the weight loader. ---
        for layer_idx in 0..self.cfg.num_layers {
            let block = &mut self.blocks[layer_idx];
            let grad = &grads.blocks[layer_idx];
            match optimizer {
                OptimizerKind::Sgd => {
                    if scope.ffn {
                        sgd_vec_update(
                            block.ffn.x_k.as_mut_slice(),
                            grad.ffn.x_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.ffn.key_w.as_mut_slice(),
                            grad.ffn.key_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.ffn.value_w.as_mut_slice(),
                            grad.ffn.value_w.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    if scope.ffn_norm {
                        sgd_vec_update(
                            block.ffn_norm_w.as_mut_slice(),
                            grad.ffn_norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.ffn_norm_b.as_mut_slice(),
                            grad.ffn_norm_b.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    if scope.attn {
                        sgd_vec_update(
                            block.attn.o_proj.as_mut_slice(),
                            grad.attn.o_proj.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.r_k.as_mut_slice(),
                            grad.attn.r_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g_norm_w.as_mut_slice(),
                            grad.attn.g_norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g_norm_b.as_mut_slice(),
                            grad.attn.g_norm_b.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.k_a.as_mut_slice(),
                            grad.attn.k_a.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.k_k.as_mut_slice(),
                            grad.attn.k_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.rkv_proj.as_mut_slice(),
                            grad.attn.rkv_proj.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.w0.as_mut_slice(),
                            grad.attn.w0.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.w2.as_mut_slice(),
                            grad.attn.w2.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.w1.as_mut_slice(),
                            grad.attn.w1.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.a0.as_mut_slice(),
                            grad.attn.a0.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.a2.as_mut_slice(),
                            grad.attn.a2.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.a1.as_mut_slice(),
                            grad.attn.a1.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g2.as_mut_slice(),
                            grad.attn.g2.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g1.as_mut_slice(),
                            grad.attn.g1.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_r.as_mut_slice(),
                            grad.attn.x_r.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_w.as_mut_slice(),
                            grad.attn.x_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_k.as_mut_slice(),
                            grad.attn.x_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_v.as_mut_slice(),
                            grad.attn.x_v.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_a.as_mut_slice(),
                            grad.attn.x_a.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_g.as_mut_slice(),
                            grad.attn.x_g.as_slice(),
                            lr,
                            clip,
                        );
                        // v-LoRA tensors are optional (absent on layer 0 in
                        // the reference layout — TODO confirm); update only
                        // when both weight and gradient are present.
                        if let (Some(v1), Some(gv1)) =
                            (block.attn.v1.as_mut(), grad.attn.v1.as_ref())
                        {
                            sgd_vec_update(v1.as_mut_slice(), gv1.as_slice(), lr, clip);
                        }
                        if let (Some(v2), Some(gv2)) =
                            (block.attn.v2.as_mut(), grad.attn.v2.as_ref())
                        {
                            sgd_vec_update(v2.as_mut_slice(), gv2.as_slice(), lr, clip);
                        }
                        if let (Some(v0), Some(gv0)) =
                            (block.attn.v0.as_mut(), grad.attn.v0.as_ref())
                        {
                            sgd_vec_update(v0.as_mut_slice(), gv0.as_slice(), lr, clip);
                        }
                    }
                    if scope.attn_norm {
                        sgd_vec_update(
                            block.attn_norm_w.as_mut_slice(),
                            grad.attn_norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn_norm_b.as_mut_slice(),
                            grad.attn_norm_b.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    // pre_norm is optional per block; skip silently when the
                    // block has none.
                    if scope.pre_norm
                        && let (Some(w), Some(gw)) =
                            (block.pre_norm_w.as_mut(), grad.pre_norm_w.as_ref())
                    {
                        sgd_vec_update(w.as_mut_slice(), gw.as_slice(), lr, clip);
                    }
                    if scope.pre_norm
                        && let (Some(b), Some(gb)) =
                            (block.pre_norm_b.as_mut(), grad.pre_norm_b.as_ref())
                    {
                        sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                    }
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam =
                        &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                    if scope.ffn {
                        apply_adam_vec_update(
                            block.ffn.x_k.as_mut_slice(),
                            grad.ffn.x_k.as_slice(),
                            &mut adam.ffn.x_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.ffn.key_w.as_mut_slice(),
                            grad.ffn.key_w.as_slice(),
                            &mut adam.ffn.key_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.ffn.value_w.as_mut_slice(),
                            grad.ffn.value_w.as_slice(),
                            &mut adam.ffn.value_w,
                            cfg,
                        );
                    }
                    if scope.ffn_norm {
                        apply_adam_vec_update(
                            block.ffn_norm_w.as_mut_slice(),
                            grad.ffn_norm_w.as_slice(),
                            &mut adam.ffn_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.ffn_norm_b.as_mut_slice(),
                            grad.ffn_norm_b.as_slice(),
                            &mut adam.ffn_norm_b,
                            cfg,
                        );
                    }
                    if scope.attn {
                        apply_adam_vec_update(
                            block.attn.o_proj.as_mut_slice(),
                            grad.attn.o_proj.as_slice(),
                            &mut adam.attn.o_proj,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.r_k.as_mut_slice(),
                            grad.attn.r_k.as_slice(),
                            &mut adam.attn.r_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g_norm_w.as_mut_slice(),
                            grad.attn.g_norm_w.as_slice(),
                            &mut adam.attn.g_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g_norm_b.as_mut_slice(),
                            grad.attn.g_norm_b.as_slice(),
                            &mut adam.attn.g_norm_b,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.k_a.as_mut_slice(),
                            grad.attn.k_a.as_slice(),
                            &mut adam.attn.k_a,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.k_k.as_mut_slice(),
                            grad.attn.k_k.as_slice(),
                            &mut adam.attn.k_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.rkv_proj.as_mut_slice(),
                            grad.attn.rkv_proj.as_slice(),
                            &mut adam.attn.rkv_proj,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.w0.as_mut_slice(),
                            grad.attn.w0.as_slice(),
                            &mut adam.attn.w0,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.w2.as_mut_slice(),
                            grad.attn.w2.as_slice(),
                            &mut adam.attn.w2,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.w1.as_mut_slice(),
                            grad.attn.w1.as_slice(),
                            &mut adam.attn.w1,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.a0.as_mut_slice(),
                            grad.attn.a0.as_slice(),
                            &mut adam.attn.a0,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.a2.as_mut_slice(),
                            grad.attn.a2.as_slice(),
                            &mut adam.attn.a2,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.a1.as_mut_slice(),
                            grad.attn.a1.as_slice(),
                            &mut adam.attn.a1,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g2.as_mut_slice(),
                            grad.attn.g2.as_slice(),
                            &mut adam.attn.g2,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g1.as_mut_slice(),
                            grad.attn.g1.as_slice(),
                            &mut adam.attn.g1,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_r.as_mut_slice(),
                            grad.attn.x_r.as_slice(),
                            &mut adam.attn.x_r,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_w.as_mut_slice(),
                            grad.attn.x_w.as_slice(),
                            &mut adam.attn.x_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_k.as_mut_slice(),
                            grad.attn.x_k.as_slice(),
                            &mut adam.attn.x_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_v.as_mut_slice(),
                            grad.attn.x_v.as_slice(),
                            &mut adam.attn.x_v,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_a.as_mut_slice(),
                            grad.attn.x_a.as_slice(),
                            &mut adam.attn.x_a,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_g.as_mut_slice(),
                            grad.attn.x_g.as_slice(),
                            &mut adam.attn.x_g,
                            cfg,
                        );
                        // Optional v-LoRA tensors: update only when weight,
                        // gradient, and Adam state all exist for this block.
                        if let (Some(v1), Some(gv1), Some(av1)) = (
                            block.attn.v1.as_mut(),
                            grad.attn.v1.as_ref(),
                            adam.attn.v1.as_mut(),
                        ) {
                            apply_adam_vec_update(v1.as_mut_slice(), gv1.as_slice(), av1, cfg);
                        }
                        if let (Some(v2), Some(gv2), Some(av2)) = (
                            block.attn.v2.as_mut(),
                            grad.attn.v2.as_ref(),
                            adam.attn.v2.as_mut(),
                        ) {
                            apply_adam_vec_update(v2.as_mut_slice(), gv2.as_slice(), av2, cfg);
                        }
                        if let (Some(v0), Some(gv0), Some(av0)) = (
                            block.attn.v0.as_mut(),
                            grad.attn.v0.as_ref(),
                            adam.attn.v0.as_mut(),
                        ) {
                            apply_adam_vec_update(v0.as_mut_slice(), gv0.as_slice(), av0, cfg);
                        }
                    }
                    if scope.attn_norm {
                        apply_adam_vec_update(
                            block.attn_norm_w.as_mut_slice(),
                            grad.attn_norm_w.as_slice(),
                            &mut adam.attn_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn_norm_b.as_mut_slice(),
                            grad.attn_norm_b.as_slice(),
                            &mut adam.attn_norm_b,
                            cfg,
                        );
                    }
                    // Optional pre-norm parameters, as in the SGD branch.
                    if scope.pre_norm
                        && let (Some(w), Some(gw), Some(aw)) = (
                            block.pre_norm_w.as_mut(),
                            grad.pre_norm_w.as_ref(),
                            adam.pre_norm_w.as_mut(),
                        )
                    {
                        apply_adam_vec_update(w.as_mut_slice(), gw.as_slice(), aw, cfg);
                    }
                    if scope.pre_norm
                        && let (Some(b), Some(gb), Some(ab)) = (
                            block.pre_norm_b.as_mut(),
                            grad.pre_norm_b.as_ref(),
                            adam.pre_norm_b.as_mut(),
                        )
                    {
                        apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                    }
                }
            }
        }

        // --- Token embedding table ---
        if scope.embed {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.embeddings.as_mut_slice(),
                        grads.embeddings.as_slice(),
                        lr,
                        clip,
                    );
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.embeddings.as_mut_slice(),
                        grads.embeddings.as_slice(),
                        &mut adam.embeddings,
                        cfg,
                    );
                }
            }
        }
        Ok(())
    }
2299
2300    #[allow(clippy::needless_range_loop)]
2301    fn accumulate_token_step_gradients(
2302        &self,
2303        scratch: &mut ScratchBuffers,
2304        trace: &TokenTrainTrace,
2305        state_new: &State,
2306        symbol: u8,
2307        pdf: &[f64],
2308        grad_scale: f32,
2309        scope: TrainScopeMask,
2310        grads: &mut FullGradState,
2311        out_bias_grad: Option<&mut [f32]>,
2312        future: &mut RecurrentGradState,
2313    ) -> Result<()> {
2314        let c = self.cfg.hidden_size;
2315        let h = self.cfg.num_heads;
2316        let n = self.cfg.head_dim;
2317        let i = self.cfg.intermediate_size;
2318        let d_w = self.cfg.decay_low_rank;
2319        let d_a = self.cfg.a_low_rank;
2320        let d_v = self.cfg.v_low_rank;
2321        let d_g = self.cfg.g_low_rank;
2322        let vocab = self.cfg.vocab_size.min(pdf.len());
2323        if vocab == 0 {
2324            return Ok(());
2325        }
2326
2327        scratch.grad_logits.zero();
2328        for idx in 0..vocab {
2329            let p = pdf[idx].clamp(1e-12, 1.0) as f32;
2330            let target = if idx == symbol as usize { 1.0 } else { 0.0 };
2331            scratch.grad_logits[idx] = (target - p) * grad_scale;
2332        }
2333
2334        if scope.bias
2335            && let Some(bias_grad) = out_bias_grad
2336        {
2337            add_vec_grad(
2338                &mut bias_grad[0..vocab],
2339                &scratch.grad_logits.as_slice()[0..vocab],
2340            );
2341        }
2342
2343        scratch.grad_x.zero();
2344        if scope.head {
2345            add_outer_grad(
2346                grads.lm_head.as_mut_slice(),
2347                vocab,
2348                c,
2349                &scratch.grad_logits.as_slice()[0..vocab],
2350                trace.x_normed.as_slice(),
2351            );
2352        }
2353        for row in 0..vocab {
2354            let g = scratch.grad_logits[row];
2355            if g == 0.0 {
2356                continue;
2357            }
2358            let row_off = row * c;
2359            for col in 0..c {
2360                scratch.grad_x[col] += self.lm_head[row_off + col] * g;
2361            }
2362        }
2363
2364        let needs_backprop = scope.trains_non_head_params() || scope.head;
2365        if !needs_backprop {
2366            return Ok(());
2367        }
2368
2369        layer_norm_backward(
2370            trace.x.as_slice(),
2371            self.ln_out_w.as_slice(),
2372            scratch.grad_x.as_slice(),
2373            self.cfg.layer_norm_eps,
2374            scratch.grad_x2.as_mut_slice(),
2375            scratch.grad_x3.as_mut_slice(),
2376            scratch.grad_x4.as_mut_slice(),
2377        );
2378        if scope.head {
2379            add_vec_grad(grads.ln_out_w.as_mut_slice(), scratch.grad_x3.as_slice());
2380            add_vec_grad(grads.ln_out_b.as_mut_slice(), scratch.grad_x4.as_slice());
2381        }
2382        scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
2383        scratch.grad_v_first.zero();
2384
2385        for layer_idx in (0..self.cfg.num_layers).rev() {
2386            let tr = &trace.layers[layer_idx];
2387            let block = &self.blocks[layer_idx];
2388            let block_grads = &mut grads.blocks[layer_idx];
2389            let future_layer = &mut future.layers[layer_idx];
2390
2391            scratch.grad_x2.copy_from_slice(scratch.grad_x.as_slice());
2392            scratch.grad_x3.copy_from_slice(scratch.grad_x.as_slice());
2393
2394            unsafe {
2395                kernel::gemv_t_avx(
2396                    block.ffn.value_w.as_ptr(),
2397                    scratch.grad_x3.as_ptr(),
2398                    scratch.grad_ffn.as_mut_ptr(),
2399                    c,
2400                    i,
2401                );
2402            }
2403            if scope.ffn {
2404                add_outer_grad(
2405                    block_grads.ffn.value_w.as_mut_slice(),
2406                    c,
2407                    i,
2408                    scratch.grad_x3.as_slice(),
2409                    tr.ffn_k.as_slice(),
2410                );
2411            }
2412
2413            for col in 0..i {
2414                let pre = tr.ffn_pre[col];
2415                scratch.grad_ffn2[col] = if pre > 0.0 {
2416                    scratch.grad_ffn[col] * (2.0 * pre)
2417                } else {
2418                    0.0
2419                };
2420            }
2421
2422            unsafe {
2423                kernel::gemv_t_avx(
2424                    block.ffn.key_w.as_ptr(),
2425                    scratch.grad_ffn2.as_ptr(),
2426                    scratch.grad_x4.as_mut_ptr(),
2427                    i,
2428                    c,
2429                );
2430            }
2431            if scope.ffn {
2432                add_outer_grad(
2433                    block_grads.ffn.key_w.as_mut_slice(),
2434                    i,
2435                    c,
2436                    scratch.grad_ffn2.as_slice(),
2437                    tr.ffn_xk.as_slice(),
2438                );
2439            }
2440
2441            scratch
2442                .grad_x5
2443                .copy_from_slice(future_layer.ffn_x_prev.as_slice());
2444            future_layer.ffn_x_prev.zero();
2445            for col in 0..c {
2446                let g = scratch.grad_x4[col];
2447                let mix = block.ffn.x_k[col];
2448                let base = tr.ffn_norm[col];
2449                let prev = tr.ffn_x_prev_old[col];
2450                scratch.grad_x5[col] += g * (1.0 - mix);
2451                future_layer.ffn_x_prev[col] = g * mix;
2452                scratch.grad_param[col] = g * (prev - base);
2453            }
2454            if scope.ffn {
2455                add_vec_grad(
2456                    block_grads.ffn.x_k.as_mut_slice(),
2457                    scratch.grad_param.as_slice(),
2458                );
2459            }
2460
2461            layer_norm_backward(
2462                tr.x_after_attn.as_slice(),
2463                block.ffn_norm_w.as_slice(),
2464                scratch.grad_x5.as_slice(),
2465                self.cfg.layer_norm_eps,
2466                scratch.grad_x4.as_mut_slice(),
2467                scratch.grad_x3.as_mut_slice(),
2468                scratch.grad_x6.as_mut_slice(),
2469            );
2470            if scope.ffn_norm {
2471                add_vec_grad(
2472                    block_grads.ffn_norm_w.as_mut_slice(),
2473                    scratch.grad_x3.as_slice(),
2474                );
2475                add_vec_grad(
2476                    block_grads.ffn_norm_b.as_mut_slice(),
2477                    scratch.grad_x6.as_slice(),
2478                );
2479            }
2480            for col in 0..c {
2481                scratch.grad_x2[col] += scratch.grad_x4[col];
2482            }
2483
2484            scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
2485            scratch.grad_x3.copy_from_slice(scratch.grad_x2.as_slice());
2486
2487            unsafe {
2488                kernel::gemv_t_avx(
2489                    block.attn.o_proj.as_ptr(),
2490                    scratch.grad_x3.as_ptr(),
2491                    scratch.grad_x4.as_mut_ptr(),
2492                    c,
2493                    c,
2494                );
2495            }
2496            if scope.attn {
2497                add_outer_grad(
2498                    block_grads.attn.o_proj.as_mut_slice(),
2499                    c,
2500                    c,
2501                    scratch.grad_x3.as_slice(),
2502                    tr.y_gate.as_slice(),
2503                );
2504            }
2505
2506            for col in 0..c {
2507                let gy = scratch.grad_x4[col];
2508                scratch.grad_saved[col] = gy * tr.y_head[col];
2509                scratch.grad_x4[col] = gy * tr.g[col];
2510            }
2511
2512            scratch.grad_x2.zero();
2513            scratch.grad_x3.zero();
2514            scratch.grad_x6.zero();
2515            scratch.grad_param.zero();
2516            for head_idx in 0..h {
2517                let off = head_idx * n;
2518                let mut g_alpha = 0.0f32;
2519                for j in 0..n {
2520                    let g = scratch.grad_x4[off + j];
2521                    g_alpha += g * tr.v[off + j];
2522                    scratch.grad_x6[off + j] += g * tr.alpha[head_idx];
2523                }
2524                for j in 0..n {
2525                    let idx = off + j;
2526                    let rk = block.attn.r_k[idx];
2527                    let rv = tr.r[idx];
2528                    let kv = tr.k[idx];
2529                    let g = g_alpha * rk;
2530                    scratch.grad_x2[idx] += g * kv;
2531                    scratch.grad_x3[idx] += g * rv;
2532                    scratch.grad_param[idx] += g_alpha * rv * kv;
2533                }
2534            }
2535            if scope.attn {
2536                add_vec_grad(
2537                    block_grads.attn.r_k.as_mut_slice(),
2538                    scratch.grad_param.as_slice(),
2539                );
2540            }
2541
2542            scratch.grad_x5.as_mut_slice()[0..c].copy_from_slice(&scratch.grad_x4.as_slice()[0..c]);
2543            group_norm_backward(
2544                tr.y_wkv.as_slice(),
2545                block.attn.g_norm_w.as_slice(),
2546                scratch.grad_x5.as_slice(),
2547                h,
2548                n,
2549                self.cfg.group_norm_eps,
2550                scratch.grad_x4.as_mut_slice(),
2551                scratch.grad_param.as_mut_slice(),
2552                scratch.grad_param2.as_mut_slice(),
2553            );
2554            if scope.attn {
2555                add_vec_grad(
2556                    block_grads.attn.g_norm_w.as_mut_slice(),
2557                    scratch.grad_param.as_slice(),
2558                );
2559                add_vec_grad(
2560                    block_grads.attn.g_norm_b.as_mut_slice(),
2561                    scratch.grad_param2.as_slice(),
2562                );
2563            }
2564
2565            scratch.grad_param.zero();
2566            scratch.grad_x5.zero();
2567            scratch.grad_param2.zero();
2568            scratch
2569                .grad_att_state
2570                .copy_from_slice(future_layer.att_state.as_slice());
2571            future_layer.att_state.zero();
2572            let s_old = tr.att_state_old.as_slice();
2573            let s_new = state_new.layers[layer_idx].att_state.as_slice();
2574            for head_idx in 0..h {
2575                let off = head_idx * n;
2576                let s_off = head_idx * n * n;
2577                let grad_y = &scratch.grad_x4.as_slice()[off..off + n];
2578                let r_head = &tr.r.as_slice()[off..off + n];
2579                let k_head = &tr.k.as_slice()[off..off + n];
2580                let kk_head = &tr.kk.as_slice()[off..off + n];
2581                let a_head = &tr.a.as_slice()[off..off + n];
2582                let v_head = &tr.v.as_slice()[off..off + n];
2583                let w_head = &tr.w_decay.as_slice()[off..off + n];
2584
2585                unsafe {
2586                    kernel::gemv_t_avx(
2587                        s_new.as_ptr().add(s_off),
2588                        grad_y.as_ptr(),
2589                        scratch.grad_low_rank.as_mut_ptr(),
2590                        n,
2591                        n,
2592                    );
2593                }
2594                for j in 0..n {
2595                    scratch.grad_x2[off + j] += scratch.grad_low_rank[j];
2596                }
2597
2598                let g_state = &mut scratch.grad_att_state.as_mut_slice()[s_off..s_off + n * n];
2599                for irow in 0..n {
2600                    let gy = grad_y[irow];
2601                    let row_off = irow * n;
2602                    for j in 0..n {
2603                        g_state[row_off + j] += gy * r_head[j];
2604                    }
2605                }
2606
2607                unsafe {
2608                    kernel::gemv_avx(
2609                        s_old.as_ptr().add(s_off),
2610                        kk_head.as_ptr(),
2611                        scratch.grad_low_rank.as_mut_ptr(),
2612                        n,
2613                        n,
2614                    );
2615                }
2616                let u = &scratch.grad_low_rank.as_slice()[0..n];
2617
2618                for j in 0..n {
2619                    let mut grad_w = 0.0f32;
2620                    let mut grad_k = 0.0f32;
2621                    let mut grad_b = 0.0f32;
2622                    for irow in 0..n {
2623                        let g = g_state[irow * n + j];
2624                        grad_w += g * s_old[s_off + irow * n + j];
2625                        grad_k += g * v_head[irow];
2626                        grad_b -= g * u[irow];
2627                        future_layer.att_state[s_off + irow * n + j] = g * w_head[j];
2628                    }
2629                    scratch.grad_param[off + j] += grad_w;
2630                    scratch.grad_x3[off + j] += grad_k;
2631                    scratch.grad_param2[off + j] += grad_b * a_head[j];
2632                    scratch.grad_x5[off + j] += grad_b * kk_head[j];
2633                }
2634
2635                for irow in 0..n {
2636                    let mut grad_u = 0.0f32;
2637                    for j in 0..n {
2638                        grad_u -= g_state[irow * n + j] * kk_head[j] * a_head[j];
2639                    }
2640                    scratch.grad_low_rank2[irow] = grad_u;
2641                    let row_off = irow * n;
2642                    for j in 0..n {
2643                        future_layer.att_state[s_off + row_off + j] += grad_u * kk_head[j];
2644                    }
2645                }
2646                unsafe {
2647                    kernel::gemv_t_avx(
2648                        s_old.as_ptr().add(s_off),
2649                        scratch.grad_low_rank2.as_ptr(),
2650                        scratch.grad_low_rank.as_mut_ptr(),
2651                        n,
2652                        n,
2653                    );
2654                }
2655                for j in 0..n {
2656                    scratch.grad_param2[off + j] += scratch.grad_low_rank[j];
2657                }
2658
2659                for irow in 0..n {
2660                    let mut grad_v = 0.0f32;
2661                    for j in 0..n {
2662                        grad_v += g_state[irow * n + j] * k_head[j];
2663                    }
2664                    scratch.grad_x6[off + irow] += grad_v;
2665                }
2666            }
2667
2668            for col in 0..c {
2669                let gk = scratch.grad_x3[col];
2670                let scale = 1.0 + (tr.a[col] - 1.0) * block.attn.k_a[col];
2671                let d_scale = gk * tr.k_pre[col];
2672                scratch.grad_x3[col] = gk * scale;
2673                scratch.grad_x5[col] += d_scale * block.attn.k_a[col];
2674                scratch.grad_param[col] = d_scale * (tr.a[col] - 1.0);
2675            }
2676            for head_idx in 0..h {
2677                let off = head_idx * n;
2678                l2_normalize_backward(
2679                    &tr.kk_pre.as_slice()[off..off + n],
2680                    &tr.kk.as_slice()[off..off + n],
2681                    &scratch.grad_param2.as_slice()[off..off + n],
2682                    1e-12,
2683                    &mut scratch.grad_x4.as_mut_slice()[off..off + n],
2684                );
2685            }
2686            for col in 0..c {
2687                let g = scratch.grad_x4[col];
2688                scratch.grad_x3[col] += g * block.attn.k_k[col];
2689                scratch.grad_param2[col] = g * tr.k_pre[col];
2690            }
2691            if scope.attn {
2692                add_vec_grad(
2693                    block_grads.attn.k_a.as_mut_slice(),
2694                    scratch.grad_param.as_slice(),
2695                );
2696                add_vec_grad(
2697                    block_grads.attn.k_k.as_mut_slice(),
2698                    scratch.grad_param2.as_slice(),
2699                );
2700            }
2701
2702            scratch
2703                .grad_param2
2704                .copy_from_slice(scratch.grad_x6.as_slice());
2705            if layer_idx == 0 {
2706                for col in 0..c {
2707                    scratch.grad_x6[col] += scratch.grad_v_first[col];
2708                }
2709            } else if tr.uses_v_residual
2710                && block.attn.v1.is_some()
2711                && block.attn.v2.is_some()
2712                && block.attn.v0.is_some()
2713            {
2714                let v1 = block.attn.v1.as_ref().expect("v1");
2715                let v2 = block.attn.v2.as_ref().expect("v2");
2716                for col in 0..c {
2717                    let gv = scratch.grad_param2[col];
2718                    let nu = tr.nu[col];
2719                    scratch.grad_x6[col] = gv * (1.0 - nu);
2720                    scratch.grad_x3[col] = gv * (trace.v_first[col] - tr.v_pre[col]);
2721                    scratch.grad_v_first[col] += gv * nu;
2722                }
2723                for col in 0..c {
2724                    let nu = tr.nu[col];
2725                    scratch.grad_x3[col] *= nu * (1.0 - nu);
2726                }
2727                if scope.attn {
2728                    add_vec_grad(
2729                        block_grads
2730                            .attn
2731                            .v0
2732                            .as_mut()
2733                            .expect("grad v0")
2734                            .as_mut_slice(),
2735                        scratch.grad_x3.as_slice(),
2736                    );
2737                    add_outer_grad(
2738                        block_grads
2739                            .attn
2740                            .v2
2741                            .as_mut()
2742                            .expect("grad v2")
2743                            .as_mut_slice(),
2744                        c,
2745                        d_v,
2746                        scratch.grad_x3.as_slice(),
2747                        &tr.v_hidden.as_slice()[0..d_v],
2748                    );
2749                }
2750                unsafe {
2751                    kernel::gemv_t_avx(
2752                        v2.as_ptr(),
2753                        scratch.grad_x3.as_ptr(),
2754                        scratch.grad_low_rank.as_mut_ptr(),
2755                        c,
2756                        d_v,
2757                    );
2758                }
2759                if scope.attn {
2760                    add_outer_grad(
2761                        block_grads
2762                            .attn
2763                            .v1
2764                            .as_mut()
2765                            .expect("grad v1")
2766                            .as_mut_slice(),
2767                        d_v,
2768                        c,
2769                        &scratch.grad_low_rank.as_slice()[0..d_v],
2770                        tr.xv.as_slice(),
2771                    );
2772                }
2773                for col in 0..c {
2774                    let mut acc = 0.0f32;
2775                    for row in 0..d_v {
2776                        acc += v1[row * c + col] * scratch.grad_low_rank[row];
2777                    }
2778                    scratch.grad_x4[col] += acc;
2779                }
2780            }
2781
2782            let proj_size = c * c;
2783            if scope.attn {
2784                add_outer_grad(
2785                    &mut block_grads.attn.rkv_proj.as_mut_slice()[0..proj_size],
2786                    c,
2787                    c,
2788                    scratch.grad_x2.as_slice(),
2789                    tr.xr.as_slice(),
2790                );
2791                add_outer_grad(
2792                    &mut block_grads.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
2793                    c,
2794                    c,
2795                    scratch.grad_x3.as_slice(),
2796                    tr.xk.as_slice(),
2797                );
2798                add_outer_grad(
2799                    &mut block_grads.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
2800                    c,
2801                    c,
2802                    scratch.grad_x6.as_slice(),
2803                    tr.xv.as_slice(),
2804                );
2805            }
2806            let proj = block.attn.rkv_proj.as_slice();
2807            unsafe {
2808                kernel::gemv_t_avx(
2809                    proj.as_ptr(),
2810                    scratch.grad_x2.as_ptr(),
2811                    scratch.grad_param.as_mut_ptr(),
2812                    c,
2813                    c,
2814                );
2815                kernel::gemv_t_avx(
2816                    proj.as_ptr().add(proj_size),
2817                    scratch.grad_x3.as_ptr(),
2818                    scratch.grad_param2.as_mut_ptr(),
2819                    c,
2820                    c,
2821                );
2822                kernel::gemv_t_avx(
2823                    proj.as_ptr().add(2 * proj_size),
2824                    scratch.grad_x6.as_ptr(),
2825                    scratch.grad_x4.as_mut_ptr(),
2826                    c,
2827                    c,
2828                );
2829            }
2830
2831            let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
2832            for col in 0..c {
2833                let sig = tr.w_sigmoid[col];
2834                let d_sig = scratch.grad_param[col] * (-inv_sqrt_e) * tr.w_decay[col];
2835                scratch.grad_param[col] = d_sig * sig * (1.0 - sig);
2836            }
2837            if scope.attn {
2838                add_vec_grad(
2839                    block_grads.attn.w0.as_mut_slice(),
2840                    scratch.grad_param.as_slice(),
2841                );
2842                add_outer_grad(
2843                    block_grads.attn.w2.as_mut_slice(),
2844                    c,
2845                    d_w,
2846                    scratch.grad_param.as_slice(),
2847                    &tr.w_hidden.as_slice()[0..d_w],
2848                );
2849            }
2850            unsafe {
2851                kernel::gemv_t_avx(
2852                    block.attn.w2.as_ptr(),
2853                    scratch.grad_param.as_ptr(),
2854                    scratch.grad_low_rank.as_mut_ptr(),
2855                    c,
2856                    d_w,
2857                );
2858            }
2859            for col in 0..d_w {
2860                let t = tr.w_hidden[col];
2861                scratch.grad_low_rank[col] *= 1.0 - t * t;
2862            }
2863            if scope.attn {
2864                add_outer_grad(
2865                    block_grads.attn.w1.as_mut_slice(),
2866                    d_w,
2867                    c,
2868                    &scratch.grad_low_rank.as_slice()[0..d_w],
2869                    tr.xw.as_slice(),
2870                );
2871            }
2872            unsafe {
2873                kernel::gemv_t_avx(
2874                    block.attn.w1.as_ptr(),
2875                    scratch.grad_low_rank.as_ptr(),
2876                    scratch.grad_x6.as_mut_ptr(),
2877                    d_w,
2878                    c,
2879                );
2880            }
2881
2882            for col in 0..c {
2883                let a = tr.a[col];
2884                scratch.grad_x5[col] *= a * (1.0 - a);
2885            }
2886            if scope.attn {
2887                add_vec_grad(
2888                    block_grads.attn.a0.as_mut_slice(),
2889                    scratch.grad_x5.as_slice(),
2890                );
2891                add_outer_grad(
2892                    block_grads.attn.a2.as_mut_slice(),
2893                    c,
2894                    d_a,
2895                    scratch.grad_x5.as_slice(),
2896                    &tr.a_hidden.as_slice()[0..d_a],
2897                );
2898            }
2899            unsafe {
2900                kernel::gemv_t_avx(
2901                    block.attn.a2.as_ptr(),
2902                    scratch.grad_x5.as_ptr(),
2903                    scratch.grad_low_rank.as_mut_ptr(),
2904                    c,
2905                    d_a,
2906                );
2907            }
2908            if scope.attn {
2909                add_outer_grad(
2910                    block_grads.attn.a1.as_mut_slice(),
2911                    d_a,
2912                    c,
2913                    &scratch.grad_low_rank.as_slice()[0..d_a],
2914                    tr.xa.as_slice(),
2915                );
2916            }
2917            unsafe {
2918                kernel::gemv_t_avx(
2919                    block.attn.a1.as_ptr(),
2920                    scratch.grad_low_rank.as_ptr(),
2921                    scratch.grad_x5.as_mut_ptr(),
2922                    d_a,
2923                    c,
2924                );
2925            }
2926
2927            if scope.attn {
2928                add_outer_grad(
2929                    block_grads.attn.g2.as_mut_slice(),
2930                    c,
2931                    d_g,
2932                    scratch.grad_saved.as_slice(),
2933                    &tr.g_hidden.as_slice()[0..d_g],
2934                );
2935            }
2936            unsafe {
2937                kernel::gemv_t_avx(
2938                    block.attn.g2.as_ptr(),
2939                    scratch.grad_saved.as_ptr(),
2940                    scratch.grad_low_rank.as_mut_ptr(),
2941                    c,
2942                    d_g,
2943                );
2944            }
2945            for col in 0..d_g {
2946                let sig = tr.g_hidden[col];
2947                scratch.grad_low_rank2[col] = scratch.grad_low_rank[col] * sig * (1.0 - sig);
2948            }
2949            if scope.attn {
2950                add_outer_grad(
2951                    block_grads.attn.g1.as_mut_slice(),
2952                    d_g,
2953                    c,
2954                    &scratch.grad_low_rank2.as_slice()[0..d_g],
2955                    tr.xg.as_slice(),
2956                );
2957            }
2958            unsafe {
2959                kernel::gemv_t_avx(
2960                    block.attn.g1.as_ptr(),
2961                    scratch.grad_low_rank2.as_ptr(),
2962                    scratch.grad_saved.as_mut_ptr(),
2963                    d_g,
2964                    c,
2965                );
2966            }
2967
2968            scratch
2969                .grad_x3
2970                .copy_from_slice(future_layer.att_x_prev.as_slice());
2971            future_layer.att_x_prev.zero();
2972
2973            for col in 0..c {
2974                let g = scratch.grad_param[col];
2975                let mix = block.attn.x_r[col];
2976                let base = tr.attn_norm[col];
2977                let prev = tr.att_x_prev_old[col];
2978                scratch.grad_x3[col] += g * (1.0 - mix);
2979                future_layer.att_x_prev[col] += g * mix;
2980                scratch.grad_x2[col] = g * (prev - base);
2981            }
2982            if scope.attn {
2983                add_vec_grad(
2984                    block_grads.attn.x_r.as_mut_slice(),
2985                    scratch.grad_x2.as_slice(),
2986                );
2987            }
2988
2989            for col in 0..c {
2990                let g = scratch.grad_x6[col];
2991                let mix = block.attn.x_w[col];
2992                let base = tr.attn_norm[col];
2993                let prev = tr.att_x_prev_old[col];
2994                scratch.grad_x3[col] += g * (1.0 - mix);
2995                future_layer.att_x_prev[col] += g * mix;
2996                scratch.grad_x2[col] = g * (prev - base);
2997            }
2998            if scope.attn {
2999                add_vec_grad(
3000                    block_grads.attn.x_w.as_mut_slice(),
3001                    scratch.grad_x2.as_slice(),
3002                );
3003            }
3004
3005            for col in 0..c {
3006                let g = scratch.grad_param2[col];
3007                let mix = block.attn.x_k[col];
3008                let base = tr.attn_norm[col];
3009                let prev = tr.att_x_prev_old[col];
3010                scratch.grad_x3[col] += g * (1.0 - mix);
3011                future_layer.att_x_prev[col] += g * mix;
3012                scratch.grad_x2[col] = g * (prev - base);
3013            }
3014            if scope.attn {
3015                add_vec_grad(
3016                    block_grads.attn.x_k.as_mut_slice(),
3017                    scratch.grad_x2.as_slice(),
3018                );
3019            }
3020
3021            for col in 0..c {
3022                let g = scratch.grad_x4[col];
3023                let mix = block.attn.x_v[col];
3024                let base = tr.attn_norm[col];
3025                let prev = tr.att_x_prev_old[col];
3026                scratch.grad_x3[col] += g * (1.0 - mix);
3027                future_layer.att_x_prev[col] += g * mix;
3028                scratch.grad_x2[col] = g * (prev - base);
3029            }
3030            if scope.attn {
3031                add_vec_grad(
3032                    block_grads.attn.x_v.as_mut_slice(),
3033                    scratch.grad_x2.as_slice(),
3034                );
3035            }
3036
3037            for col in 0..c {
3038                let g = scratch.grad_x5[col];
3039                let mix = block.attn.x_a[col];
3040                let base = tr.attn_norm[col];
3041                let prev = tr.att_x_prev_old[col];
3042                scratch.grad_x3[col] += g * (1.0 - mix);
3043                future_layer.att_x_prev[col] += g * mix;
3044                scratch.grad_x2[col] = g * (prev - base);
3045            }
3046            if scope.attn {
3047                add_vec_grad(
3048                    block_grads.attn.x_a.as_mut_slice(),
3049                    scratch.grad_x2.as_slice(),
3050                );
3051            }
3052
3053            for col in 0..c {
3054                let g = scratch.grad_saved[col];
3055                let mix = block.attn.x_g[col];
3056                let base = tr.attn_norm[col];
3057                let prev = tr.att_x_prev_old[col];
3058                scratch.grad_x3[col] += g * (1.0 - mix);
3059                future_layer.att_x_prev[col] += g * mix;
3060                scratch.grad_x2[col] = g * (prev - base);
3061            }
3062            if scope.attn {
3063                add_vec_grad(
3064                    block_grads.attn.x_g.as_mut_slice(),
3065                    scratch.grad_x2.as_slice(),
3066                );
3067            }
3068
3069            layer_norm_backward(
3070                tr.x_after_pre.as_slice(),
3071                block.attn_norm_w.as_slice(),
3072                scratch.grad_x3.as_slice(),
3073                self.cfg.layer_norm_eps,
3074                scratch.grad_x2.as_mut_slice(),
3075                scratch.grad_x4.as_mut_slice(),
3076                scratch.grad_x5.as_mut_slice(),
3077            );
3078            if scope.attn_norm {
3079                add_vec_grad(
3080                    block_grads.attn_norm_w.as_mut_slice(),
3081                    scratch.grad_x4.as_slice(),
3082                );
3083                add_vec_grad(
3084                    block_grads.attn_norm_b.as_mut_slice(),
3085                    scratch.grad_x5.as_slice(),
3086                );
3087            }
3088            for col in 0..c {
3089                scratch.grad_x[col] += scratch.grad_x2[col];
3090            }
3091
3092            if layer_idx == 0
3093                && let (Some(w), Some(_b)) = (&block.pre_norm_w, &block.pre_norm_b)
3094            {
3095                layer_norm_backward(
3096                    tr.x_in.as_slice(),
3097                    w.as_slice(),
3098                    scratch.grad_x.as_slice(),
3099                    self.cfg.layer_norm_eps,
3100                    scratch.grad_x2.as_mut_slice(),
3101                    scratch.grad_x3.as_mut_slice(),
3102                    scratch.grad_x4.as_mut_slice(),
3103                );
3104                if scope.pre_norm {
3105                    add_vec_grad(
3106                        block_grads
3107                            .pre_norm_w
3108                            .as_mut()
3109                            .expect("grad pre_norm_w")
3110                            .as_mut_slice(),
3111                        scratch.grad_x3.as_slice(),
3112                    );
3113                    add_vec_grad(
3114                        block_grads
3115                            .pre_norm_b
3116                            .as_mut()
3117                            .expect("grad pre_norm_b")
3118                            .as_mut_slice(),
3119                        scratch.grad_x4.as_slice(),
3120                    );
3121                }
3122                scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
3123            }
3124        }
3125
3126        if scope.embed {
3127            let token_idx = trace.token.min(self.cfg.vocab_size.saturating_sub(1));
3128            let off = token_idx * c;
3129            add_vec_grad(
3130                &mut grads.embeddings.as_mut_slice()[off..off + c],
3131                scratch.grad_x.as_slice(),
3132            );
3133        }
3134
3135        Ok(())
3136    }
3137
3138    #[allow(clippy::too_many_arguments)]
3139    /// Run one TBPTT training segment and write the resulting live state.
3140    pub fn online_train_segment_tbptt(
3141        &mut self,
3142        scratch: &mut ScratchBuffers,
3143        start_state: &State,
3144        steps: &[(u32, u8)],
3145        scope: TrainScopeMask,
3146        optimizer: OptimizerKind,
3147        lr: f32,
3148        clip: f32,
3149        replay_chunk: usize,
3150        adam_t: &mut usize,
3151        model_adam: Option<&mut FullAdamState>,
3152        out_bias: Option<&mut [f32]>,
3153        out_bias_adam_m: Option<&mut [f32]>,
3154        out_bias_adam_v: Option<&mut [f32]>,
3155        live_state_out: &mut State,
3156    ) -> Result<()> {
3157        if steps.is_empty() {
3158            *live_state_out = start_state.clone();
3159            return Ok(());
3160        }
3161
3162        let grad_scale = 1.0f32 / (steps.len() as f32);
3163        let chunk = replay_chunk.max(1).min(steps.len().max(1));
3164        let mut grads = self.new_full_grad_state();
3165        let mut recurrent = self.new_recurrent_grad_state();
3166        recurrent.zero();
3167        let mut bias_grad = out_bias.as_deref().map(|b| vec![0.0f32; b.len()]);
3168
3169        {
3170            let mut checkpoints = Vec::<State>::new();
3171            let mut checkpoint_state = start_state.clone();
3172            scratch.set_capture_train_trace(false);
3173            for chunk_start in (0..steps.len()).step_by(chunk) {
3174                checkpoints.push(checkpoint_state.clone());
3175                let chunk_end = (chunk_start + chunk).min(steps.len());
3176                for &(input_token, _) in &steps[chunk_start..chunk_end] {
3177                    self.forward(scratch, input_token, &mut checkpoint_state);
3178                }
3179            }
3180
3181            for chunk_idx in (0..checkpoints.len()).rev() {
3182                let chunk_start = chunk_idx * chunk;
3183                let chunk_end = (chunk_start + chunk).min(steps.len());
3184                let mut state = checkpoints[chunk_idx].clone();
3185                let mut step_states = Vec::<State>::with_capacity(chunk_end - chunk_start + 1);
3186                let mut step_traces =
3187                    Vec::<TokenTrainTrace>::with_capacity(chunk_end - chunk_start);
3188                let mut step_pdfs =
3189                    Vec::<Vec<f64>>::with_capacity(chunk_end.saturating_sub(chunk_start));
3190                step_states.push(state.clone());
3191
3192                for &(input_token, _) in &steps[chunk_start..chunk_end] {
3193                    scratch.set_capture_train_trace(true);
3194                    let logits = self.forward(scratch, input_token, &mut state);
3195                    let mut pdf = vec![0.0f64; self.cfg.vocab_size];
3196                    super::super::softmax_pdf_floor_with_bias(
3197                        logits,
3198                        out_bias.as_deref(),
3199                        &mut pdf,
3200                    );
3201                    step_pdfs.push(pdf);
3202                    step_traces.push(TokenTrainTrace::from_scratch(scratch));
3203                    step_states.push(state.clone());
3204                }
3205
3206                for local_idx in (0..step_traces.len()).rev() {
3207                    let (_, target_symbol) = steps[chunk_start + local_idx];
3208                    self.accumulate_token_step_gradients(
3209                        scratch,
3210                        &step_traces[local_idx],
3211                        &step_states[local_idx + 1],
3212                        target_symbol,
3213                        &step_pdfs[local_idx],
3214                        grad_scale,
3215                        scope,
3216                        &mut grads,
3217                        bias_grad.as_deref_mut(),
3218                        &mut recurrent,
3219                    )?;
3220                }
3221            }
3222        }
3223
3224        self.apply_full_gradients(
3225            &grads,
3226            scope,
3227            optimizer,
3228            lr,
3229            clip,
3230            adam_t,
3231            model_adam,
3232            out_bias,
3233            bias_grad.as_deref(),
3234            out_bias_adam_m,
3235            out_bias_adam_v,
3236        )?;
3237
3238        scratch.set_capture_train_trace(false);
3239        *live_state_out = start_state.clone();
3240        for &(input_token, _) in steps {
3241            self.forward(scratch, input_token, live_state_out);
3242        }
3243        Ok(())
3244    }
3245
3246    /// Perform one exact bptt=1 online training step over the latest forward trace.
3247    #[allow(clippy::too_many_arguments)]
3248    #[allow(clippy::needless_range_loop)]
3249    pub fn online_train_step_bptt1(
3250        &mut self,
3251        scratch: &mut ScratchBuffers,
3252        state: &State,
3253        symbol: u8,
3254        pdf: &[f64],
3255        scope: TrainScopeMask,
3256        optimizer: OptimizerKind,
3257        lr: f32,
3258        clip: f32,
3259        adam_t: &mut usize,
3260        model_adam: Option<&mut FullAdamState>,
3261        out_bias: Option<&mut [f32]>,
3262        out_bias_adam_m: Option<&mut [f32]>,
3263        out_bias_adam_v: Option<&mut [f32]>,
3264    ) -> Result<()> {
3265        if !scope.trains_any_params() {
3266            return Ok(());
3267        }
3268        if scope.trains_non_head_params() && !scratch.train_trace_valid {
3269            bail!("rwkv full training trace is missing; run one forward step first");
3270        }
3271        let c = self.cfg.hidden_size;
3272        let h = self.cfg.num_heads;
3273        let n = self.cfg.head_dim;
3274        let i = self.cfg.intermediate_size;
3275        let d_w = self.cfg.decay_low_rank;
3276        let d_a = self.cfg.a_low_rank;
3277        let d_v = self.cfg.v_low_rank;
3278        let d_g = self.cfg.g_low_rank;
3279        let vocab = self.cfg.vocab_size.min(pdf.len());
3280        if vocab == 0 {
3281            return Ok(());
3282        }
3283        let mut adam_step = None::<AdamStep>;
3284        let mut model_adam = model_adam;
3285        if matches!(optimizer, OptimizerKind::Adam) {
3286            *adam_t = adam_t.saturating_add(1);
3287            let t = (*adam_t).max(1) as i32;
3288            let b1 = 0.9f32;
3289            let b2 = 0.999f32;
3290            adam_step = Some(AdamStep {
3291                lr,
3292                clip: clip.max(0.0),
3293                b1,
3294                b2,
3295                eps: 1e-8,
3296                bias_corr1: 1.0 - b1.powi(t),
3297                bias_corr2: 1.0 - b2.powi(t),
3298            });
3299            if scope.trains_non_head_params() && model_adam.is_none() {
3300                bail!("rwkv Adam full-training state is missing");
3301            }
3302        }
3303
3304        scratch.grad_logits.zero();
3305        for idx in 0..vocab {
3306            let p = pdf[idx].clamp(1e-12, 1.0) as f32;
3307            let target = if idx == symbol as usize { 1.0 } else { 0.0 };
3308            let mut g = target - p;
3309            if clip > 0.0 {
3310                g = g.clamp(-clip, clip);
3311            }
3312            scratch.grad_logits[idx] = g;
3313        }
3314
3315        if scope.bias
3316            && let Some(bias) = out_bias
3317        {
3318            match optimizer {
3319                OptimizerKind::Sgd => {
3320                    for idx in 0..bias.len().min(vocab) {
3321                        bias[idx] += lr * scratch.grad_logits[idx];
3322                    }
3323                }
3324                OptimizerKind::Adam => {
3325                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
3326                    let Some(m) = out_bias_adam_m else {
3327                        bail!("rwkv Adam output-bias state is missing (m)");
3328                    };
3329                    let Some(vv) = out_bias_adam_v else {
3330                        bail!("rwkv Adam output-bias state is missing (v)");
3331                    };
3332                    let n = bias.len().min(vocab);
3333                    apply_adam_vec_update_raw(
3334                        &mut bias[0..n],
3335                        &scratch.grad_logits.as_slice()[0..n],
3336                        &mut m[0..n],
3337                        &mut vv[0..n],
3338                        cfg,
3339                    );
3340                }
3341            }
3342        }
3343
3344        scratch.grad_x.zero();
3345        if scope.head {
3346            match optimizer {
3347                OptimizerKind::Sgd => {
3348                    fused_sgd_head_backward_update(
3349                        self.lm_head.as_mut_slice(),
3350                        vocab,
3351                        c,
3352                        &scratch.grad_logits.as_slice()[0..vocab],
3353                        scratch.x_normed.as_slice(),
3354                        scratch.grad_x.as_mut_slice(),
3355                        lr,
3356                        clip,
3357                    );
3358                }
3359                OptimizerKind::Adam => {
3360                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
3361                    let adam = model_adam.as_mut().expect("adam state exists");
3362                    fused_adam_head_backward_update(
3363                        self.lm_head.as_mut_slice(),
3364                        vocab,
3365                        c,
3366                        &scratch.grad_logits.as_slice()[0..vocab],
3367                        scratch.x_normed.as_slice(),
3368                        scratch.grad_x.as_mut_slice(),
3369                        adam.lm_head.m.as_mut_slice(),
3370                        adam.lm_head.v.as_mut_slice(),
3371                        cfg,
3372                    );
3373                }
3374            }
3375        } else {
3376            for row in 0..vocab {
3377                let g = scratch.grad_logits[row];
3378                if g == 0.0 {
3379                    continue;
3380                }
3381                let row_off = row * c;
3382                for col in 0..c {
3383                    scratch.grad_x[col] += self.lm_head[row_off + col] * g;
3384                }
3385            }
3386        }
3387
3388        let needs_backprop = scope.trains_non_head_params() || scope.head;
3389        if !needs_backprop {
3390            return Ok(());
3391        }
3392        layer_norm_backward(
3393            scratch.x.as_slice(),
3394            self.ln_out_w.as_slice(),
3395            scratch.grad_x.as_slice(),
3396            self.cfg.layer_norm_eps,
3397            scratch.grad_x2.as_mut_slice(),
3398            scratch.grad_x3.as_mut_slice(),
3399            scratch.grad_x4.as_mut_slice(),
3400        );
3401        if scope.head {
3402            match optimizer {
3403                OptimizerKind::Sgd => {
3404                    sgd_vec_update(
3405                        self.ln_out_w.as_mut_slice(),
3406                        scratch.grad_x3.as_slice(),
3407                        lr,
3408                        clip,
3409                    );
3410                    sgd_vec_update(
3411                        self.ln_out_b.as_mut_slice(),
3412                        scratch.grad_x4.as_slice(),
3413                        lr,
3414                        clip,
3415                    );
3416                }
3417                OptimizerKind::Adam => {
3418                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
3419                    let adam = model_adam.as_mut().expect("adam state exists");
3420                    apply_adam_vec_update(
3421                        self.ln_out_w.as_mut_slice(),
3422                        scratch.grad_x3.as_slice(),
3423                        &mut adam.ln_out_w,
3424                        cfg,
3425                    );
3426                    apply_adam_vec_update(
3427                        self.ln_out_b.as_mut_slice(),
3428                        scratch.grad_x4.as_slice(),
3429                        &mut adam.ln_out_b,
3430                        cfg,
3431                    );
3432                }
3433            }
3434        }
3435        scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
3436        scratch.grad_v_first.zero();
3437
3438        for layer_idx in (0..self.cfg.num_layers).rev() {
3439            let tr = &scratch.train_trace_layers[layer_idx];
3440            let block = &mut self.blocks[layer_idx];
3441
3442            // FFN residual split: x_out = x_after_attn + ffn_out.
3443            scratch.grad_x2.copy_from_slice(scratch.grad_x.as_slice()); // d x_after_attn
3444            scratch.grad_x3.copy_from_slice(scratch.grad_x.as_slice()); // d ffn_out
3445
3446            // ffn_out = value_w @ ffn_k
3447            unsafe {
3448                kernel::gemv_t_avx(
3449                    block.ffn.value_w.as_ptr(),
3450                    scratch.grad_x3.as_ptr(),
3451                    scratch.grad_ffn.as_mut_ptr(),
3452                    c,
3453                    i,
3454                );
3455            }
3456            if scope.ffn {
3457                match optimizer {
3458                    OptimizerKind::Sgd => sgd_outer_update(
3459                        block.ffn.value_w.as_mut_slice(),
3460                        c,
3461                        i,
3462                        scratch.grad_x3.as_slice(),
3463                        tr.ffn_k.as_slice(),
3464                        lr,
3465                        clip,
3466                    ),
3467                    OptimizerKind::Adam => {
3468                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3469                        let adam =
3470                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3471                        apply_adam_outer_update(
3472                            block.ffn.value_w.as_mut_slice(),
3473                            c,
3474                            i,
3475                            scratch.grad_x3.as_slice(),
3476                            tr.ffn_k.as_slice(),
3477                            &mut adam.ffn.value_w,
3478                            cfg,
3479                        );
3480                    }
3481                }
3482            }
3483
3484            // relu^2 backward
3485            for col in 0..i {
3486                let pre = tr.ffn_pre[col];
3487                scratch.grad_ffn2[col] = if pre > 0.0 {
3488                    scratch.grad_ffn[col] * (2.0 * pre)
3489                } else {
3490                    0.0
3491                };
3492            }
3493
3494            // key_w backward
3495            unsafe {
3496                kernel::gemv_t_avx(
3497                    block.ffn.key_w.as_ptr(),
3498                    scratch.grad_ffn2.as_ptr(),
3499                    scratch.grad_x4.as_mut_ptr(),
3500                    i,
3501                    c,
3502                );
3503            }
3504            if scope.ffn {
3505                match optimizer {
3506                    OptimizerKind::Sgd => sgd_outer_update(
3507                        block.ffn.key_w.as_mut_slice(),
3508                        i,
3509                        c,
3510                        scratch.grad_ffn2.as_slice(),
3511                        tr.ffn_xk.as_slice(),
3512                        lr,
3513                        clip,
3514                    ),
3515                    OptimizerKind::Adam => {
3516                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3517                        let adam =
3518                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3519                        apply_adam_outer_update(
3520                            block.ffn.key_w.as_mut_slice(),
3521                            i,
3522                            c,
3523                            scratch.grad_ffn2.as_slice(),
3524                            tr.ffn_xk.as_slice(),
3525                            &mut adam.ffn.key_w,
3526                            cfg,
3527                        );
3528                    }
3529                }
3530            }
3531
3532            // token_shift backward (ffn)
3533            for col in 0..c {
3534                let g = scratch.grad_x4[col];
3535                let mix = block.ffn.x_k[col];
3536                let base = tr.ffn_norm[col];
3537                let prev = tr.ffn_x_prev_old[col];
3538                scratch.grad_x5[col] = g * (1.0 - mix); // d ffn_norm
3539                scratch.grad_param[col] = g * (prev - base); // d x_k
3540            }
3541            if scope.ffn {
3542                match optimizer {
3543                    OptimizerKind::Sgd => sgd_vec_update(
3544                        block.ffn.x_k.as_mut_slice(),
3545                        scratch.grad_param.as_slice(),
3546                        lr,
3547                        clip,
3548                    ),
3549                    OptimizerKind::Adam => {
3550                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3551                        let adam =
3552                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3553                        apply_adam_vec_update(
3554                            block.ffn.x_k.as_mut_slice(),
3555                            scratch.grad_param.as_slice(),
3556                            &mut adam.ffn.x_k,
3557                            cfg,
3558                        );
3559                    }
3560                }
3561            }
3562
3563            // ffn norm backward: d x_after_attn contribution.
3564            layer_norm_backward(
3565                tr.x_after_attn.as_slice(),
3566                block.ffn_norm_w.as_slice(),
3567                scratch.grad_x5.as_slice(),
3568                self.cfg.layer_norm_eps,
3569                scratch.grad_x4.as_mut_slice(),
3570                scratch.grad_x3.as_mut_slice(),
3571                scratch.grad_x6.as_mut_slice(),
3572            );
3573            if scope.ffn_norm {
3574                match optimizer {
3575                    OptimizerKind::Sgd => {
3576                        sgd_vec_update(
3577                            block.ffn_norm_w.as_mut_slice(),
3578                            scratch.grad_x3.as_slice(),
3579                            lr,
3580                            clip,
3581                        );
3582                        sgd_vec_update(
3583                            block.ffn_norm_b.as_mut_slice(),
3584                            scratch.grad_x6.as_slice(),
3585                            lr,
3586                            clip,
3587                        );
3588                    }
3589                    OptimizerKind::Adam => {
3590                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3591                        let adam =
3592                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3593                        apply_adam_vec_update(
3594                            block.ffn_norm_w.as_mut_slice(),
3595                            scratch.grad_x3.as_slice(),
3596                            &mut adam.ffn_norm_w,
3597                            cfg,
3598                        );
3599                        apply_adam_vec_update(
3600                            block.ffn_norm_b.as_mut_slice(),
3601                            scratch.grad_x6.as_slice(),
3602                            &mut adam.ffn_norm_b,
3603                            cfg,
3604                        );
3605                    }
3606                }
3607            }
3608            for col in 0..c {
3609                scratch.grad_x2[col] += scratch.grad_x4[col];
3610            }
3611
3612            // Attention residual split.
3613            scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice()); // d x_after_pre
3614            scratch.grad_x3.copy_from_slice(scratch.grad_x2.as_slice()); // d att_out
3615
3616            // out_proj backward
3617            unsafe {
3618                kernel::gemv_t_avx(
3619                    block.attn.o_proj.as_ptr(),
3620                    scratch.grad_x3.as_ptr(),
3621                    scratch.grad_x4.as_mut_ptr(),
3622                    c,
3623                    c,
3624                );
3625            }
3626            if scope.attn {
3627                match optimizer {
3628                    OptimizerKind::Sgd => sgd_outer_update(
3629                        block.attn.o_proj.as_mut_slice(),
3630                        c,
3631                        c,
3632                        scratch.grad_x3.as_slice(),
3633                        tr.y_gate.as_slice(),
3634                        lr,
3635                        clip,
3636                    ),
3637                    OptimizerKind::Adam => {
3638                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3639                        let adam =
3640                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3641                        apply_adam_outer_update(
3642                            block.attn.o_proj.as_mut_slice(),
3643                            c,
3644                            c,
3645                            scratch.grad_x3.as_slice(),
3646                            tr.y_gate.as_slice(),
3647                            &mut adam.attn.o_proj,
3648                            cfg,
3649                        );
3650                    }
3651                }
3652            }
3653
3654            // Gate backward.
3655            for col in 0..c {
3656                let gy = scratch.grad_x4[col];
3657                scratch.grad_saved[col] = gy * tr.y_head[col]; // d g
3658                scratch.grad_x4[col] = gy * tr.g[col]; // d y_head
3659            }
3660
3661            // Head-qk branch.
3662            scratch.grad_x2.zero(); // d r
3663            scratch.grad_x3.zero(); // d k_scaled
3664            scratch.grad_x6.zero(); // d v_final
3665            scratch.grad_param.zero(); // d r_k
3666            for head_idx in 0..h {
3667                let off = head_idx * n;
3668                let mut g_alpha = 0.0f32;
3669                for j in 0..n {
3670                    let g = scratch.grad_x4[off + j];
3671                    g_alpha += g * tr.v[off + j];
3672                    scratch.grad_x6[off + j] += g * tr.alpha[head_idx];
3673                }
3674                for j in 0..n {
3675                    let idx = off + j;
3676                    let rk = block.attn.r_k[idx];
3677                    let rv = tr.r[idx];
3678                    let kv = tr.k[idx];
3679                    let g = g_alpha * rk;
3680                    scratch.grad_x2[idx] += g * kv;
3681                    scratch.grad_x3[idx] += g * rv;
3682                    scratch.grad_param[idx] += g_alpha * rv * kv;
3683                }
3684            }
3685            if scope.attn {
3686                match optimizer {
3687                    OptimizerKind::Sgd => sgd_vec_update(
3688                        block.attn.r_k.as_mut_slice(),
3689                        scratch.grad_param.as_slice(),
3690                        lr,
3691                        clip,
3692                    ),
3693                    OptimizerKind::Adam => {
3694                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3695                        let adam =
3696                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3697                        apply_adam_vec_update(
3698                            block.attn.r_k.as_mut_slice(),
3699                            scratch.grad_param.as_slice(),
3700                            &mut adam.attn.r_k,
3701                            cfg,
3702                        );
3703                    }
3704                }
3705            }
3706
3707            // GroupNorm backward.
3708            scratch.grad_x5.as_mut_slice()[0..c].copy_from_slice(&scratch.grad_x4.as_slice()[0..c]);
3709            group_norm_backward(
3710                tr.y_wkv.as_slice(),
3711                block.attn.g_norm_w.as_slice(),
3712                scratch.grad_x5.as_slice(),
3713                h,
3714                n,
3715                self.cfg.group_norm_eps,
3716                scratch.grad_x4.as_mut_slice(),     // d y_wkv
3717                scratch.grad_param.as_mut_slice(),  // d g_norm_w
3718                scratch.grad_param2.as_mut_slice(), // d g_norm_b
3719            );
3720            if scope.attn {
3721                match optimizer {
3722                    OptimizerKind::Sgd => {
3723                        sgd_vec_update(
3724                            block.attn.g_norm_w.as_mut_slice(),
3725                            scratch.grad_param.as_slice(),
3726                            lr,
3727                            clip,
3728                        );
3729                        sgd_vec_update(
3730                            block.attn.g_norm_b.as_mut_slice(),
3731                            scratch.grad_param2.as_slice(),
3732                            lr,
3733                            clip,
3734                        );
3735                    }
3736                    OptimizerKind::Adam => {
3737                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3738                        let adam =
3739                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3740                        apply_adam_vec_update(
3741                            block.attn.g_norm_w.as_mut_slice(),
3742                            scratch.grad_param.as_slice(),
3743                            &mut adam.attn.g_norm_w,
3744                            cfg,
3745                        );
3746                        apply_adam_vec_update(
3747                            block.attn.g_norm_b.as_mut_slice(),
3748                            scratch.grad_param2.as_slice(),
3749                            &mut adam.attn.g_norm_b,
3750                            cfg,
3751                        );
3752                    }
3753                }
3754            }
3755
3756            // WKV kernel backward.
3757            scratch.grad_param.zero(); // d w_decay
3758            scratch.grad_x5.zero(); // d a
3759            scratch.grad_param2.zero(); // d kk
3760            let s_old = tr.att_state_old.as_slice();
3761            let s_new = state.layers[layer_idx].att_state.as_slice();
3762            for head_idx in 0..h {
3763                let off = head_idx * n;
3764                let s_head_old_off = head_idx * n * n;
3765                let s_head_new_off = head_idx * n * n;
3766                let grad_y = &scratch.grad_x4.as_slice()[off..off + n];
3767                let r_head = &tr.r.as_slice()[off..off + n];
3768                let k_head = &tr.k.as_slice()[off..off + n];
3769                let kk_head = &tr.kk.as_slice()[off..off + n];
3770                let a_head = &tr.a.as_slice()[off..off + n];
3771                let v_head = &tr.v.as_slice()[off..off + n];
3772
3773                unsafe {
3774                    kernel::gemv_t_avx(
3775                        s_new.as_ptr().add(s_head_new_off),
3776                        grad_y.as_ptr(),
3777                        scratch.grad_low_rank.as_mut_ptr(),
3778                        n,
3779                        n,
3780                    );
3781                    kernel::gemv_t_avx(
3782                        s_old.as_ptr().add(s_head_old_off),
3783                        grad_y.as_ptr(),
3784                        scratch.grad_low_rank2.as_mut_ptr(),
3785                        n,
3786                        n,
3787                    );
3788                }
3789
3790                for j in 0..n {
3791                    let idx = off + j;
3792                    scratch.grad_x2[idx] += scratch.grad_low_rank[j];
3793                    scratch.grad_param[idx] += r_head[j] * scratch.grad_low_rank2[j];
3794                }
3795
3796                unsafe {
3797                    kernel::gemv_avx(
3798                        s_old.as_ptr().add(s_head_old_off),
3799                        kk_head.as_ptr(),
3800                        scratch.grad_low_rank.as_mut_ptr(),
3801                        n,
3802                        n,
3803                    );
3804                }
3805
3806                let mut dot_gv = 0.0f32;
3807                let mut dot_rk = 0.0f32;
3808                let mut dot_r_kka = 0.0f32;
3809                let mut sum_gy_u = 0.0f32;
3810                for j in 0..n {
3811                    dot_gv += grad_y[j] * v_head[j];
3812                    dot_rk += r_head[j] * k_head[j];
3813                    dot_r_kka += r_head[j] * kk_head[j] * a_head[j];
3814                    sum_gy_u += grad_y[j] * scratch.grad_low_rank[j];
3815                }
3816
3817                for j in 0..n {
3818                    let idx = off + j;
3819                    scratch.grad_x3[idx] += r_head[j] * dot_gv;
3820                    scratch.grad_x6[idx] += grad_y[j] * dot_rk;
3821                    scratch.grad_x5[idx] -= sum_gy_u * r_head[j] * kk_head[j];
3822                    scratch.grad_low_rank[j] = -grad_y[j] * dot_r_kka;
3823                }
3824
3825                unsafe {
3826                    kernel::gemv_t_avx(
3827                        s_old.as_ptr().add(s_head_old_off),
3828                        scratch.grad_low_rank.as_ptr(),
3829                        scratch.grad_low_rank2.as_mut_ptr(),
3830                        n,
3831                        n,
3832                    );
3833                }
3834                for j in 0..n {
3835                    let idx = off + j;
3836                    scratch.grad_param2[idx] +=
3837                        scratch.grad_low_rank2[j] - sum_gy_u * r_head[j] * a_head[j];
3838                }
3839            }
3840
3841            // k scaling + kk normalization backward.
3842            for col in 0..c {
3843                let gk = scratch.grad_x3[col];
3844                let scale = 1.0 + (tr.a[col] - 1.0) * block.attn.k_a[col];
3845                let d_scale = gk * tr.k_pre[col];
3846                scratch.grad_x3[col] = gk * scale; // d k_pre (scaled path)
3847                scratch.grad_x5[col] += d_scale * block.attn.k_a[col]; // d a
3848                scratch.grad_param[col] = d_scale * (tr.a[col] - 1.0); // d k_a
3849            }
3850            for head_idx in 0..h {
3851                let off = head_idx * n;
3852                l2_normalize_backward(
3853                    &tr.kk_pre.as_slice()[off..off + n],
3854                    &tr.kk.as_slice()[off..off + n],
3855                    &scratch.grad_param2.as_slice()[off..off + n],
3856                    1e-12,
3857                    &mut scratch.grad_x4.as_mut_slice()[off..off + n],
3858                );
3859            }
3860            for col in 0..c {
3861                let g = scratch.grad_x4[col];
3862                scratch.grad_x3[col] += g * block.attn.k_k[col]; // d k_pre from kk_pre
3863                scratch.grad_param2[col] = g * tr.k_pre[col]; // d k_k
3864            }
3865            if scope.attn {
3866                match optimizer {
3867                    OptimizerKind::Sgd => {
3868                        sgd_vec_update(
3869                            block.attn.k_a.as_mut_slice(),
3870                            scratch.grad_param.as_slice(),
3871                            lr,
3872                            clip,
3873                        );
3874                        sgd_vec_update(
3875                            block.attn.k_k.as_mut_slice(),
3876                            scratch.grad_param2.as_slice(),
3877                            lr,
3878                            clip,
3879                        );
3880                    }
3881                    OptimizerKind::Adam => {
3882                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
3883                        let adam =
3884                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3885                        apply_adam_vec_update(
3886                            block.attn.k_a.as_mut_slice(),
3887                            scratch.grad_param.as_slice(),
3888                            &mut adam.attn.k_a,
3889                            cfg,
3890                        );
3891                        apply_adam_vec_update(
3892                            block.attn.k_k.as_mut_slice(),
3893                            scratch.grad_param2.as_slice(),
3894                            &mut adam.attn.k_k,
3895                            cfg,
3896                        );
3897                    }
3898                }
3899            }
3900
            // V-residual backward (layers > 0).
            //
            // Inverts the forward mixing v_final = v_pre + nu * (v_first - v_pre)
            // with nu = sigmoid(v0 + v2 @ (v1 @ xv)); layer 0 defines v_first = v_pre
            // (both facts read off the branch structure below).
            scratch
                .grad_param2
                .copy_from_slice(scratch.grad_x6.as_slice()); // d v_final snapshot
            if layer_idx == 0 {
                // Layer 0 produced v_first: fold the d v_first accumulated from
                // later layers back into d v_pre.
                for col in 0..c {
                    scratch.grad_x6[col] += scratch.grad_v_first[col];
                }
            } else if tr.uses_v_residual
                && let (Some(v1), Some(v2), Some(v0)) =
                    (&mut block.attn.v1, &mut block.attn.v2, &mut block.attn.v0)
            {
                // Split d v_final into its three contributions (product rule on
                // the mixing expression above).
                for col in 0..c {
                    let gv = scratch.grad_param2[col];
                    let nu = tr.nu[col];
                    scratch.grad_x6[col] = gv * (1.0 - nu); // d v_pre
                    scratch.grad_x3[col] = gv * (scratch.train_v_first[col] - tr.v_pre[col]); // d nu
                    scratch.grad_v_first[col] += gv * nu; // d v_first
                }
                // Sigmoid backward: d nu_pre = d nu * nu * (1 - nu).
                for col in 0..c {
                    let nu = tr.nu[col];
                    scratch.grad_x3[col] *= nu * (1.0 - nu); // d nu_pre
                }
                // v0 bias update — the bias grad is d nu_pre unchanged.
                if scope.attn {
                    match optimizer {
                        OptimizerKind::Sgd => {
                            sgd_vec_update(v0.as_mut_slice(), scratch.grad_x3.as_slice(), lr, clip)
                        }
                        OptimizerKind::Adam => {
                            let cfg = adam_step.as_ref().expect("adam cfg initialized");
                            let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
                                [layer_idx];
                            apply_adam_vec_update(
                                v0.as_mut_slice(),
                                scratch.grad_x3.as_slice(),
                                adam.attn.v0.as_mut().expect("adam v0 state"),
                                cfg,
                            );
                        }
                    }
                }
                // v2 (c x d_v up-projection) update: outer product of d nu_pre
                // and the low-rank hidden activation v_hidden.
                if scope.attn {
                    match optimizer {
                        OptimizerKind::Sgd => sgd_outer_update(
                            v2.as_mut_slice(),
                            c,
                            d_v,
                            scratch.grad_x3.as_slice(),
                            &tr.v_hidden.as_slice()[0..d_v],
                            lr,
                            clip,
                        ),
                        OptimizerKind::Adam => {
                            let cfg = adam_step.as_ref().expect("adam cfg initialized");
                            let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
                                [layer_idx];
                            apply_adam_outer_update(
                                v2.as_mut_slice(),
                                c,
                                d_v,
                                scratch.grad_x3.as_slice(),
                                &tr.v_hidden.as_slice()[0..d_v],
                                adam.attn.v2.as_mut().expect("adam v2 state"),
                                cfg,
                            );
                        }
                    }
                }
                // Backprop d nu_pre through v2^T into the low-rank space.
                // NOTE(review): v2 was updated just above, so this matvec uses
                // post-update weights — confirm this ordering is intentional.
                unsafe {
                    kernel::gemv_t_avx(
                        v2.as_ptr(),
                        scratch.grad_x3.as_ptr(),
                        scratch.grad_low_rank.as_mut_ptr(),
                        c,
                        d_v,
                    );
                }
                // v1 (d_v x c down-projection) update against the mixed input xv.
                if scope.attn {
                    match optimizer {
                        OptimizerKind::Sgd => sgd_outer_update(
                            v1.as_mut_slice(),
                            d_v,
                            c,
                            &scratch.grad_low_rank.as_slice()[0..d_v],
                            tr.xv.as_slice(),
                            lr,
                            clip,
                        ),
                        OptimizerKind::Adam => {
                            let cfg = adam_step.as_ref().expect("adam cfg initialized");
                            let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
                                [layer_idx];
                            apply_adam_outer_update(
                                v1.as_mut_slice(),
                                d_v,
                                c,
                                &scratch.grad_low_rank.as_slice()[0..d_v],
                                tr.xv.as_slice(),
                                adam.attn.v1.as_mut().expect("adam v1 state"),
                                cfg,
                            );
                        }
                    }
                }
                // d xv contribution via v1^T (manual transposed matvec so it can
                // accumulate into grad_x4 instead of overwriting it).
                for col in 0..c {
                    let mut acc = 0.0f32;
                    for row in 0..d_v {
                        acc += v1[row * c + col] * scratch.grad_low_rank[row];
                    }
                    scratch.grad_x4[col] += acc; // add into d xv after projection transpose
                }
            }
4013
            // R/K/V projection updates and input grads.
            //
            // rkv_proj packs the three c x c projection matrices contiguously as
            // [r | k | v], addressed via proj_size offsets below; grad_x2/x3/x6
            // carry the upstream grads for the r/k/v projection outputs and
            // xr/xk/xv are the corresponding token-shifted inputs.
            let proj_size = c * c;
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => {
                        sgd_outer_update(
                            &mut block.attn.rkv_proj.as_mut_slice()[0..proj_size],
                            c,
                            c,
                            scratch.grad_x2.as_slice(),
                            tr.xr.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_outer_update(
                            &mut block.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
                            c,
                            c,
                            scratch.grad_x3.as_slice(),
                            tr.xk.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_outer_update(
                            &mut block.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
                            c,
                            c,
                            scratch.grad_x6.as_slice(),
                            tr.xv.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        // Raw variant: Adam moment slices are indexed with the same
                        // proj_size offsets as the packed weight matrix.
                        apply_adam_outer_update_raw(
                            &mut block.attn.rkv_proj.as_mut_slice()[0..proj_size],
                            c,
                            c,
                            scratch.grad_x2.as_slice(),
                            tr.xr.as_slice(),
                            &mut adam.attn.rkv_proj.m.as_mut_slice()[0..proj_size],
                            &mut adam.attn.rkv_proj.v.as_mut_slice()[0..proj_size],
                            cfg,
                        );
                        apply_adam_outer_update_raw(
                            &mut block.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
                            c,
                            c,
                            scratch.grad_x3.as_slice(),
                            tr.xk.as_slice(),
                            &mut adam.attn.rkv_proj.m.as_mut_slice()[proj_size..2 * proj_size],
                            &mut adam.attn.rkv_proj.v.as_mut_slice()[proj_size..2 * proj_size],
                            cfg,
                        );
                        apply_adam_outer_update_raw(
                            &mut block.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
                            c,
                            c,
                            scratch.grad_x6.as_slice(),
                            tr.xv.as_slice(),
                            &mut adam.attn.rkv_proj.m.as_mut_slice()[2 * proj_size..3 * proj_size],
                            &mut adam.attn.rkv_proj.v.as_mut_slice()[2 * proj_size..3 * proj_size],
                            cfg,
                        );
                    }
                }
            }
            // Input grads via the transposed projections.
            // NOTE(review): the projections were updated above, so these matvecs
            // run against post-update weights; also confirm grad_x3 still holds
            // d k here — the V-residual branch above reuses grad_x3 as d nu_pre
            // scratch for layers with the residual enabled.
            let proj = block.attn.rkv_proj.as_slice();
            unsafe {
                kernel::gemv_t_avx(
                    proj.as_ptr(),
                    scratch.grad_x2.as_ptr(),
                    scratch.grad_param.as_mut_ptr(),
                    c,
                    c,
                );
                kernel::gemv_t_avx(
                    proj.as_ptr().add(proj_size),
                    scratch.grad_x3.as_ptr(),
                    scratch.grad_param2.as_mut_ptr(),
                    c,
                    c,
                );
                kernel::gemv_t_avx(
                    proj.as_ptr().add(2 * proj_size),
                    scratch.grad_x6.as_ptr(),
                    scratch.grad_x4.as_mut_ptr(),
                    c,
                    c,
                );
            }
4108
            // W low-rank backward.
            //
            // inv_sqrt_e = e^{-1/2}; together with the d_sig expression below this
            // matches a forward decay of the form w_decay = exp(-e^{-1/2} * sigmoid(w_pre))
            // (d w_decay/d sigmoid = -e^{-1/2} * w_decay).
            let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
            for col in 0..c {
                let sig = tr.w_sigmoid[col];
                let d_sig = scratch.grad_param[col] * (-inv_sqrt_e) * tr.w_decay[col];
                scratch.grad_param[col] = d_sig * sig * (1.0 - sig); // d w_pre
            }
            // w0 bias update with d w_pre.
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.w0.as_mut_slice(),
                        scratch.grad_param.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.w0.as_mut_slice(),
                            scratch.grad_param.as_slice(),
                            &mut adam.attn.w0,
                            cfg,
                        );
                    }
                }
                // w2 (c x d_w up-projection) update against the LoRA hidden state.
                match optimizer {
                    OptimizerKind::Sgd => sgd_outer_update(
                        block.attn.w2.as_mut_slice(),
                        c,
                        d_w,
                        scratch.grad_param.as_slice(),
                        &tr.w_hidden.as_slice()[0..d_w],
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_outer_update(
                            block.attn.w2.as_mut_slice(),
                            c,
                            d_w,
                            scratch.grad_param.as_slice(),
                            &tr.w_hidden.as_slice()[0..d_w],
                            &mut adam.attn.w2,
                            cfg,
                        );
                    }
                }
            }
            // Backprop d w_pre through w2^T into the low-rank space
            // (post-update weights, consistent with the other LoRA sections).
            unsafe {
                kernel::gemv_t_avx(
                    block.attn.w2.as_ptr(),
                    scratch.grad_param.as_ptr(),
                    scratch.grad_low_rank.as_mut_ptr(),
                    c,
                    d_w,
                );
            }
            // tanh backward on the hidden activation: d pre = d h * (1 - tanh^2).
            for col in 0..d_w {
                let t = tr.w_hidden[col];
                scratch.grad_low_rank[col] *= 1.0 - t * t;
            }
            // w1 (d_w x c down-projection) update against the mixed input xw.
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_outer_update(
                        block.attn.w1.as_mut_slice(),
                        d_w,
                        c,
                        &scratch.grad_low_rank.as_slice()[0..d_w],
                        tr.xw.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_outer_update(
                            block.attn.w1.as_mut_slice(),
                            d_w,
                            c,
                            &scratch.grad_low_rank.as_slice()[0..d_w],
                            tr.xw.as_slice(),
                            &mut adam.attn.w1,
                            cfg,
                        );
                    }
                }
            }
            // d xw via w1^T, written into grad_x6 (consumed by the x_w
            // token-shift backward below).
            unsafe {
                kernel::gemv_t_avx(
                    block.attn.w1.as_ptr(),
                    scratch.grad_low_rank.as_ptr(),
                    scratch.grad_x6.as_mut_ptr(),
                    d_w,
                    c,
                );
            }
4211
            // A low-rank backward.
            //
            // a = sigmoid(a0 + a2 @ (a1 @ xa)); start from d a in grad_x5 and
            // apply the sigmoid derivative in place.
            for col in 0..c {
                let a = tr.a[col];
                scratch.grad_x5[col] *= a * (1.0 - a); // d a_pre
            }
            if scope.attn {
                // a0 bias update with d a_pre.
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.a0.as_mut_slice(),
                        scratch.grad_x5.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.a0.as_mut_slice(),
                            scratch.grad_x5.as_slice(),
                            &mut adam.attn.a0,
                            cfg,
                        );
                    }
                }
                // a2 (c x d_a up-projection) update against the LoRA hidden state.
                match optimizer {
                    OptimizerKind::Sgd => sgd_outer_update(
                        block.attn.a2.as_mut_slice(),
                        c,
                        d_a,
                        scratch.grad_x5.as_slice(),
                        &tr.a_hidden.as_slice()[0..d_a],
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_outer_update(
                            block.attn.a2.as_mut_slice(),
                            c,
                            d_a,
                            scratch.grad_x5.as_slice(),
                            &tr.a_hidden.as_slice()[0..d_a],
                            &mut adam.attn.a2,
                            cfg,
                        );
                    }
                }
            }
            // Backprop d a_pre through a2^T into the low-rank space
            // (post-update weights; no activation between a1 and a2 here —
            // unlike the W branch there is no tanh backward step).
            unsafe {
                kernel::gemv_t_avx(
                    block.attn.a2.as_ptr(),
                    scratch.grad_x5.as_ptr(),
                    scratch.grad_low_rank.as_mut_ptr(),
                    c,
                    d_a,
                );
            }
            // a1 (d_a x c down-projection) update against the mixed input xa.
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_outer_update(
                        block.attn.a1.as_mut_slice(),
                        d_a,
                        c,
                        &scratch.grad_low_rank.as_slice()[0..d_a],
                        tr.xa.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_outer_update(
                            block.attn.a1.as_mut_slice(),
                            d_a,
                            c,
                            &scratch.grad_low_rank.as_slice()[0..d_a],
                            tr.xa.as_slice(),
                            &mut adam.attn.a1,
                            cfg,
                        );
                    }
                }
            }
            // d xa via a1^T, written into grad_x5 (consumed by the x_a
            // token-shift backward below).
            unsafe {
                kernel::gemv_t_avx(
                    block.attn.a1.as_ptr(),
                    scratch.grad_low_rank.as_ptr(),
                    scratch.grad_x5.as_mut_ptr(),
                    d_a,
                    c,
                );
            }
4308
            // G low-rank backward.
            //
            // Gate path: g = g2 @ sigmoid(g1 @ xg); upstream grad d g arrives in
            // grad_saved. The sigmoid backward is applied between the two matmuls.
            if scope.attn {
                // g2 (c x d_g up-projection) update against the gated hidden state.
                match optimizer {
                    OptimizerKind::Sgd => sgd_outer_update(
                        block.attn.g2.as_mut_slice(),
                        c,
                        d_g,
                        scratch.grad_saved.as_slice(),
                        &tr.g_hidden.as_slice()[0..d_g],
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_outer_update(
                            block.attn.g2.as_mut_slice(),
                            c,
                            d_g,
                            scratch.grad_saved.as_slice(),
                            &tr.g_hidden.as_slice()[0..d_g],
                            &mut adam.attn.g2,
                            cfg,
                        );
                    }
                }
            }
            // Backprop d g through g2^T into the low-rank space (post-update weights).
            unsafe {
                kernel::gemv_t_avx(
                    block.attn.g2.as_ptr(),
                    scratch.grad_saved.as_ptr(),
                    scratch.grad_low_rank.as_mut_ptr(),
                    c,
                    d_g,
                );
            }
            // Sigmoid backward on the hidden activation (g_hidden stores the
            // post-sigmoid value, so sig * (1 - sig) is the local derivative).
            for col in 0..d_g {
                let sig = tr.g_hidden[col];
                scratch.grad_low_rank2[col] = scratch.grad_low_rank[col] * sig * (1.0 - sig);
            }
            // g1 (d_g x c down-projection) update against the mixed input xg.
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_outer_update(
                        block.attn.g1.as_mut_slice(),
                        d_g,
                        c,
                        &scratch.grad_low_rank2.as_slice()[0..d_g],
                        tr.xg.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_outer_update(
                            block.attn.g1.as_mut_slice(),
                            d_g,
                            c,
                            &scratch.grad_low_rank2.as_slice()[0..d_g],
                            tr.xg.as_slice(),
                            &mut adam.attn.g1,
                            cfg,
                        );
                    }
                }
            }
            // d xg via g1^T, written into grad_saved (consumed by the x_g
            // token-shift backward below).
            unsafe {
                kernel::gemv_t_avx(
                    block.attn.g1.as_ptr(),
                    scratch.grad_low_rank2.as_ptr(),
                    scratch.grad_saved.as_mut_ptr(),
                    d_g,
                    c,
                );
            }
4386
            // Token-shift backward for attention branches.
            //
            // Forward mixing per branch: x_b = base + mix_b * (prev - base), where
            // base = attn_norm output and prev = previous token's value. Hence:
            //   d base += g * (1 - mix)      (accumulated into grad_x3)
            //   d mix   = g * (prev - base)  (staged in grad_x2, applied per branch)
            // NOTE(review): the g * mix gradient w.r.t. the previous-token state is
            // not propagated anywhere here — presumably single-step truncated BPTT;
            // confirm against the overall training design.
            scratch.grad_x3.zero(); // d attn_norm

            // x_r
            for col in 0..c {
                let g = scratch.grad_param[col];
                let mix = block.attn.x_r[col];
                let base = tr.attn_norm[col];
                let prev = tr.att_x_prev_old[col];
                scratch.grad_x3[col] += g * (1.0 - mix);
                scratch.grad_x2[col] = g * (prev - base);
            }
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.x_r.as_mut_slice(),
                        scratch.grad_x2.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.x_r.as_mut_slice(),
                            scratch.grad_x2.as_slice(),
                            &mut adam.attn.x_r,
                            cfg,
                        );
                    }
                }
            }

            // x_w (upstream grad d xw arrived in grad_x6 from the W section).
            for col in 0..c {
                let g = scratch.grad_x6[col];
                let mix = block.attn.x_w[col];
                let base = tr.attn_norm[col];
                let prev = tr.att_x_prev_old[col];
                scratch.grad_x3[col] += g * (1.0 - mix);
                scratch.grad_x2[col] = g * (prev - base);
            }
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.x_w.as_mut_slice(),
                        scratch.grad_x2.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.x_w.as_mut_slice(),
                            scratch.grad_x2.as_slice(),
                            &mut adam.attn.x_w,
                            cfg,
                        );
                    }
                }
            }

            // x_k (upstream grad d xk arrived in grad_param2 from the k projection).
            for col in 0..c {
                let g = scratch.grad_param2[col];
                let mix = block.attn.x_k[col];
                let base = tr.attn_norm[col];
                let prev = tr.att_x_prev_old[col];
                scratch.grad_x3[col] += g * (1.0 - mix);
                scratch.grad_x2[col] = g * (prev - base);
            }
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.x_k.as_mut_slice(),
                        scratch.grad_x2.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.x_k.as_mut_slice(),
                            scratch.grad_x2.as_slice(),
                            &mut adam.attn.x_k,
                            cfg,
                        );
                    }
                }
            }

            // x_v (upstream grad d xv accumulated in grad_x4).
            for col in 0..c {
                let g = scratch.grad_x4[col];
                let mix = block.attn.x_v[col];
                let base = tr.attn_norm[col];
                let prev = tr.att_x_prev_old[col];
                scratch.grad_x3[col] += g * (1.0 - mix);
                scratch.grad_x2[col] = g * (prev - base);
            }
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.x_v.as_mut_slice(),
                        scratch.grad_x2.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.x_v.as_mut_slice(),
                            scratch.grad_x2.as_slice(),
                            &mut adam.attn.x_v,
                            cfg,
                        );
                    }
                }
            }

            // x_a (upstream grad d xa arrived in grad_x5 from the A section).
            for col in 0..c {
                let g = scratch.grad_x5[col];
                let mix = block.attn.x_a[col];
                let base = tr.attn_norm[col];
                let prev = tr.att_x_prev_old[col];
                scratch.grad_x3[col] += g * (1.0 - mix);
                scratch.grad_x2[col] = g * (prev - base);
            }
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.x_a.as_mut_slice(),
                        scratch.grad_x2.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.x_a.as_mut_slice(),
                            scratch.grad_x2.as_slice(),
                            &mut adam.attn.x_a,
                            cfg,
                        );
                    }
                }
            }

            // x_g (upstream grad d xg arrived in grad_saved from the G section).
            for col in 0..c {
                let g = scratch.grad_saved[col];
                let mix = block.attn.x_g[col];
                let base = tr.attn_norm[col];
                let prev = tr.att_x_prev_old[col];
                scratch.grad_x3[col] += g * (1.0 - mix);
                scratch.grad_x2[col] = g * (prev - base);
            }
            if scope.attn {
                match optimizer {
                    OptimizerKind::Sgd => sgd_vec_update(
                        block.attn.x_g.as_mut_slice(),
                        scratch.grad_x2.as_slice(),
                        lr,
                        clip,
                    ),
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn.x_g.as_mut_slice(),
                            scratch.grad_x2.as_slice(),
                            &mut adam.attn.x_g,
                            cfg,
                        );
                    }
                }
            }
4575
            // attn norm backward.
            //
            // layer_norm_backward(input, weight, upstream, eps, d_input, d_weight,
            // d_bias) — roles inferred from how the outputs are consumed below:
            // grad_x4 updates attn_norm_w, grad_x5 updates attn_norm_b, and
            // grad_x2 is the input grad folded into the residual stream.
            layer_norm_backward(
                tr.x_after_pre.as_slice(),
                block.attn_norm_w.as_slice(),
                scratch.grad_x3.as_slice(),
                self.cfg.layer_norm_eps,
                scratch.grad_x2.as_mut_slice(),
                scratch.grad_x4.as_mut_slice(),
                scratch.grad_x5.as_mut_slice(),
            );
            if scope.attn_norm {
                match optimizer {
                    OptimizerKind::Sgd => {
                        sgd_vec_update(
                            block.attn_norm_w.as_mut_slice(),
                            scratch.grad_x4.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn_norm_b.as_mut_slice(),
                            scratch.grad_x5.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    OptimizerKind::Adam => {
                        let cfg = adam_step.as_ref().expect("adam cfg initialized");
                        let adam =
                            &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                        apply_adam_vec_update(
                            block.attn_norm_w.as_mut_slice(),
                            scratch.grad_x4.as_slice(),
                            &mut adam.attn_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn_norm_b.as_mut_slice(),
                            scratch.grad_x5.as_slice(),
                            &mut adam.attn_norm_b,
                            cfg,
                        );
                    }
                }
            }
            // Residual connection backward: the norm-input grad adds to the
            // running d x carried across blocks.
            for col in 0..c {
                scratch.grad_x[col] += scratch.grad_x2[col];
            }

            // pre_norm backward (layer 0 only) — the extra LayerNorm applied to
            // the embedding before the first block; its presence is optional,
            // hence the Option fields.
            if layer_idx == 0
                && let (Some(w), Some(b)) = (&mut block.pre_norm_w, &mut block.pre_norm_b)
            {
                layer_norm_backward(
                    tr.x_in.as_slice(),
                    w.as_slice(),
                    scratch.grad_x.as_slice(),
                    self.cfg.layer_norm_eps,
                    scratch.grad_x2.as_mut_slice(),
                    scratch.grad_x3.as_mut_slice(),
                    scratch.grad_x4.as_mut_slice(),
                );
                if scope.pre_norm {
                    match optimizer {
                        OptimizerKind::Sgd => {
                            sgd_vec_update(w.as_mut_slice(), scratch.grad_x3.as_slice(), lr, clip);
                            sgd_vec_update(b.as_mut_slice(), scratch.grad_x4.as_slice(), lr, clip);
                        }
                        OptimizerKind::Adam => {
                            let cfg = adam_step.as_ref().expect("adam cfg initialized");
                            let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
                                [layer_idx];
                            apply_adam_vec_update(
                                w.as_mut_slice(),
                                scratch.grad_x3.as_slice(),
                                adam.pre_norm_w.as_mut().expect("adam pre_norm_w"),
                                cfg,
                            );
                            apply_adam_vec_update(
                                b.as_mut_slice(),
                                scratch.grad_x4.as_slice(),
                                adam.pre_norm_b.as_mut().expect("adam pre_norm_b"),
                                cfg,
                            );
                        }
                    }
                }
                // Replace d x with the pre_norm input grad for the embedding update.
                scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
            }
4665        }
4666
4667        if scope.embed {
4668            let token_idx = scratch
4669                .train_token
4670                .min(self.cfg.vocab_size.saturating_sub(1));
4671            let off = token_idx * c;
4672            let row = &mut self.embeddings.as_mut_slice()[off..off + c];
4673            match optimizer {
4674                OptimizerKind::Sgd => {
4675                    sgd_vec_update(row, scratch.grad_x.as_slice(), lr, clip);
4676                }
4677                OptimizerKind::Adam => {
4678                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
4679                    let adam = model_adam.as_mut().expect("adam state exists");
4680                    let m = &mut adam.embeddings.m.as_mut_slice()[off..off + c];
4681                    let v = &mut adam.embeddings.v.as_mut_slice()[off..off + c];
4682                    apply_adam_vec_update_raw(row, scratch.grad_x.as_slice(), m, v, cfg);
4683                }
4684            }
4685        }
4686        Ok(())
4687    }
4688
4689    /// Forward pass for a single token.
4690    /// Returns logits for next token prediction.
4691    #[inline(never)]
4692    pub fn forward<'a>(
4693        &'a self,
4694        scratch: &'a mut ScratchBuffers,
4695        token: u32,
4696        state: &mut State,
4697    ) -> &'a [f32] {
4698        let mut sink = NullProfiler;
4699        self.forward_with_sink(scratch, token, state, &mut sink)
4700    }
4701
    /// Forward pass that records per-layer timings through a custom sink.
    ///
    /// Identical to [`Self::forward`] except that per-layer attention/FFN
    /// durations are reported to `profiler` (only when the sink's
    /// `S::ENABLED` is true; otherwise the timing code is skipped).
    #[inline(never)]
    pub fn forward_with_profiler<'a, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        self.forward_with_sink(scratch, token, state, profiler)
    }
4713
4714    #[inline(never)]
4715    fn forward_with_sink<'a, S: ProfilerSink>(
4716        &'a self,
4717        scratch: &'a mut ScratchBuffers,
4718        token: u32,
4719        state: &mut State,
4720        profiler: &mut S,
4721    ) -> &'a [f32] {
4722        if scratch.capture_train_trace {
4723            self.forward_with_sink_impl::<true, S>(scratch, token, state, profiler)
4724        } else {
4725            self.forward_with_sink_impl::<false, S>(scratch, token, state, profiler)
4726        }
4727    }
4728
    /// Core single-token forward pass, monomorphized on `CAPTURE`.
    ///
    /// Pipeline per layer: optional pre-norm (layer 0), attention norm,
    /// time-mix attention + residual add, FFN norm, channel-mix FFN +
    /// residual add; then the final output norm and LM head GEMV.
    ///
    /// When `CAPTURE` is true, the embedding token index and every layer's
    /// intermediate activations are copied into `scratch.train_trace_layers`
    /// (plus `train_v_first` at the end) for a later backward pass; when
    /// false, all trace copies compile away. Per-layer timings go to
    /// `profiler` only when `S::ENABLED`. Returns the logits slice held in
    /// `scratch`.
    fn forward_with_sink_impl<'a, const CAPTURE: bool, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        let c = self.cfg.hidden_size;
        // Head dims are unused at this level; the attention helper re-reads
        // them from `self.cfg` itself.
        let _h = self.cfg.num_heads;
        let _n = self.cfg.head_dim;
        let num_layers = self.cfg.num_layers;
        // Clamp out-of-range tokens to the last vocab entry instead of panicking.
        let token_idx = (token as usize).min(self.cfg.vocab_size.saturating_sub(1));

        // Get token embedding
        let emb_offset = token_idx * c;
        let emb_slice = &self.embeddings.as_slice()[emb_offset..emb_offset + c];
        scratch.x.as_mut_slice().copy_from_slice(emb_slice);
        if CAPTURE {
            scratch.train_token = token_idx;
            scratch.train_trace_valid = true;
        } else {
            // Invalidate any stale trace so backward can't consume it.
            scratch.train_trace_valid = false;
        }

        profiler.begin_token();

        // SAFETY(review): the kernel calls below operate on raw pointers into
        // scratch, weight, and state buffers; this assumes every buffer was
        // allocated to match the configured dims — confirm at
        // ScratchBuffers/State construction.
        unsafe {
            // Process each layer (using index to avoid borrow conflicts)
            for layer_idx in 0..num_layers {
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_in
                        .copy_from(&scratch.x);
                }
                // Pre-norm (layer 0 only)
                if let (Some(w), Some(b)) = (
                    &self.blocks[layer_idx].pre_norm_w,
                    &self.blocks[layer_idx].pre_norm_b,
                ) {
                    kernel::layer_norm_avx(
                        scratch.x.as_ptr(),
                        w.as_ptr(),
                        b.as_ptr(),
                        scratch.x.as_mut_ptr(),
                        c,
                        self.cfg.layer_norm_eps,
                    );
                }
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_after_pre
                        .copy_from(&scratch.x);
                }

                // Attention norm
                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].attn_norm_w.as_ptr(),
                    self.blocks[layer_idx].attn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .attn_norm
                        .copy_from(&scratch.x_normed);
                }

                // Raw pointer to the layer trace; null when not capturing, and
                // only dereferenced inside `CAPTURE` branches downstream.
                let trace_ptr = if CAPTURE {
                    &mut scratch.train_trace_layers[layer_idx] as *mut LayerTrainTrace
                } else {
                    std::ptr::null_mut()
                };
                if S::ENABLED {
                    let attn_start = Instant::now();
                    self.attention_forward_impl::<CAPTURE>(scratch, layer_idx, state, trace_ptr);
                    profiler.record_attention(layer_idx, attn_start.elapsed());
                } else {
                    self.attention_forward_impl::<CAPTURE>(scratch, layer_idx, state, trace_ptr);
                }

                // Add attention residual: x = x + att_out
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.att_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_after_attn
                        .copy_from(&scratch.x);
                }

                // FFN norm
                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_w.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .ffn_norm
                        .copy_from(&scratch.x_normed);
                }

                if S::ENABLED {
                    let ffn_start = Instant::now();
                    self.ffn_forward_impl::<CAPTURE>(
                        scratch,
                        layer_idx,
                        &mut state.layers[layer_idx],
                        trace_ptr,
                    );
                    profiler.record_ffn(layer_idx, ffn_start.elapsed());
                } else {
                    self.ffn_forward_impl::<CAPTURE>(
                        scratch,
                        layer_idx,
                        &mut state.layers[layer_idx],
                        trace_ptr,
                    );
                }

                // Add FFN residual: x = x + ffn_out
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.ffn_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_out
                        .copy_from(&scratch.x);
                }
            }

            // Output norm
            kernel::layer_norm_avx(
                scratch.x.as_ptr(),
                self.ln_out_w.as_ptr(),
                self.ln_out_b.as_ptr(),
                scratch.x_normed.as_mut_ptr(),
                c,
                self.cfg.layer_norm_eps,
            );

            // LM head: logits = x @ lm_head.T
            kernel::gemv_avx(
                self.lm_head.as_ptr(),
                scratch.x_normed.as_ptr(),
                scratch.logits.as_mut_ptr(),
                self.cfg.vocab_size,
                c,
            );
        }
        if CAPTURE {
            scratch.train_v_first.copy_from(&state.v_first);
        }

        scratch.logits.as_slice()
    }
4896
    /// RWKV7 time-mix ("attention") for one token on layer `layer_idx`.
    ///
    /// Reads the layer-normed activations from `scratch.x_normed`, updates
    /// the recurrent attention state (`att_x_prev`, `att_state`, `v_first`),
    /// and leaves the projected output in `scratch.att_out`.
    ///
    /// When `CAPTURE` is true, `trace` must point to this layer's
    /// `LayerTrainTrace`, into which every intermediate is copied for the
    /// backward pass; otherwise `trace` is null and never dereferenced.
    ///
    /// # Safety
    /// Relies on all scratch/weight/state buffers matching the configured
    /// dimensions (per-head indexing assumes `hidden_size == num_heads *
    /// head_dim`) — the kernel calls use raw pointers with no bounds checks.
    #[inline(always)]
    unsafe fn attention_forward_impl<const CAPTURE: bool>(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        state: &mut State,
        trace: *mut LayerTrainTrace,
    ) {
        let attn = &self.blocks[layer_idx].attn;
        let layer_state = &mut state.layers[layer_idx];
        let c = self.cfg.hidden_size;
        let h = self.cfg.num_heads;
        let n = self.cfg.head_dim;
        let d_w = self.cfg.decay_low_rank;
        let d_a = self.cfg.a_low_rank;
        let d_g = self.cfg.g_low_rank;
        if CAPTURE {
            // Snapshot pre-update recurrent state for the backward pass.
            let tr = &mut *trace;
            tr.att_x_prev_old.copy_from(&layer_state.att_x_prev);
            tr.att_state_old.copy_from(&layer_state.att_state);
        }

        // Six-way token shift: mixes the current normed input with the
        // previous token's, with separate mix weights per target (r/w/k/v/a/g).
        kernel::token_shift_multi6_avx(
            scratch.x_normed.as_ptr(),
            layer_state.att_x_prev.as_ptr(),
            attn.x_r.as_ptr(),
            attn.x_w.as_ptr(),
            attn.x_k.as_ptr(),
            attn.x_v.as_ptr(),
            attn.x_a.as_ptr(),
            attn.x_g.as_ptr(),
            scratch.xr.as_mut_ptr(),
            scratch.xw.as_mut_ptr(),
            scratch.xk.as_mut_ptr(),
            scratch.xv.as_mut_ptr(),
            scratch.xa.as_mut_ptr(),
            scratch.xg.as_mut_ptr(),
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.xr.copy_from(&scratch.xr);
            tr.xw.copy_from(&scratch.xw);
            tr.xk.copy_from(&scratch.xk);
            tr.xv.copy_from(&scratch.xv);
            tr.xa.copy_from(&scratch.xa);
            tr.xg.copy_from(&scratch.xg);
        }

        // Update prev state for next token
        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.att_x_prev.as_mut_ptr(),
            c,
        );

        // r/k/v projections from packed matrix (sequential for better cache)
        // Packed layout: [r_proj (C*C), k_proj (C*C), v_proj (C*C)]
        let proj_size = c * c;
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr(),
            scratch.xr.as_ptr(),
            scratch.r.as_mut_ptr(),
            c,
            c,
        );
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr().add(proj_size),
            scratch.xk.as_ptr(),
            scratch.k.as_mut_ptr(),
            c,
            c,
        );
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr().add(2 * proj_size),
            scratch.xv.as_ptr(),
            scratch.v.as_mut_ptr(),
            c,
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.r.copy_from(&scratch.r);
            tr.k_pre.copy_from(&scratch.k);
            tr.v_pre.copy_from(&scratch.v);
        }

        // w decay: w = exp(-sigmoid(tanh(xw @ w1) @ w2 + w0) / sqrt(e))
        // Step 1: tmp = xw @ w1.T (D_w output)
        kernel::gemv_avx(
            attn.w1.as_ptr(),
            scratch.xw.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_w,
            c,
        );
        // Step 2: tanh
        kernel::tanh_avx(
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_w,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.w_hidden.as_mut_slice()[0..d_w]
                .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_w]);
        }
        // Step 3: tmp @ w2.T + w0
        kernel::gemv_avx(
            attn.w2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            c,
            d_w,
        );
        // Add bias w0
        kernel::add_avx(
            scratch.w_decay.as_ptr(),
            attn.w0.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.w_pre.copy_from(&scratch.w_decay);
        }
        // Step 4: exp(-sigmoid(x) / sqrt(e))
        let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
        // Fused kernel optionally records the intermediate sigmoid into the
        // trace (null pointer disables that side output).
        kernel::sigmoid_exp_neg_scaled_avx(
            scratch.w_decay.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            if CAPTURE {
                (*trace).w_sigmoid.as_mut_ptr()
            } else {
                std::ptr::null_mut()
            },
            inv_sqrt_e,
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.w_decay.copy_from(&scratch.w_decay);
        }

        // a = sigmoid(xa @ a1.T @ a2.T + a0)
        kernel::gemv_avx(
            attn.a1.as_ptr(),
            scratch.xa.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_a,
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.a_hidden.as_mut_slice()[0..d_a]
                .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_a]);
        }
        kernel::gemv_avx(
            attn.a2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.a.as_mut_ptr(),
            c,
            d_a,
        );
        kernel::add_avx(
            scratch.a.as_ptr(),
            attn.a0.as_ptr(),
            scratch.a.as_mut_ptr(),
            c,
        );
        kernel::sigmoid_avx(scratch.a.as_ptr(), scratch.a.as_mut_ptr(), c);
        if CAPTURE {
            let tr = &mut *trace;
            tr.a.copy_from(&scratch.a);
        }

        // g = sigmoid(xg @ g1.T) @ g2.T
        kernel::gemv_avx(
            attn.g1.as_ptr(),
            scratch.xg.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_g,
            c,
        );
        kernel::sigmoid_avx(
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_g,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.g_hidden.as_mut_slice()[0..d_g]
                .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_g]);
        }
        kernel::gemv_avx(
            attn.g2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.g.as_mut_ptr(),
            c,
            d_g,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.g.copy_from(&scratch.g);
        }

        // Value residual (layer > 0)
        if layer_idx == 0 {
            // Copy v to v_first buffer (no allocation)
            state.v_first.copy_from(&scratch.v);
            state.v_first_set = true;
            if CAPTURE {
                let tr = &mut *trace;
                tr.uses_v_residual = false;
                tr.v.copy_from(&scratch.v);
            }
        } else if state.v_first_set
            && let (Some(v1), Some(v2), Some(v0)) = (&attn.v1, &attn.v2, &attn.v0)
        {
            let d_v = self.cfg.v_low_rank;
            // nu = sigmoid(xv @ v1.T @ v2.T + v0)
            kernel::gemv_avx(
                v1.as_ptr(),
                scratch.xv.as_ptr(),
                scratch.w_lora_tmp.as_mut_ptr(),
                d_v,
                c,
            );
            if CAPTURE {
                let tr = &mut *trace;
                tr.v_hidden.as_mut_slice()[0..d_v]
                    .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_v]);
            }
            kernel::gemv_avx(
                v2.as_ptr(),
                scratch.w_lora_tmp.as_ptr(),
                scratch.att_out.as_mut_ptr(), // reuse as temp
                c,
                d_v,
            );
            kernel::add_avx(
                scratch.att_out.as_ptr(),
                v0.as_ptr(),
                scratch.att_out.as_mut_ptr(),
                c,
            );
            kernel::sigmoid_avx(scratch.att_out.as_ptr(), scratch.att_out.as_mut_ptr(), c);
            if CAPTURE {
                let tr = &mut *trace;
                tr.uses_v_residual = true;
                tr.nu.copy_from(&scratch.att_out);
            }
            // v = v + (v_first - v) * nu
            for i in 0..c {
                let nu = scratch.att_out[i];
                scratch.v[i] += (state.v_first[i] - scratch.v[i]) * nu;
            }
            if CAPTURE {
                let tr = &mut *trace;
                tr.v.copy_from(&scratch.v);
            }
        } else if CAPTURE {
            let tr = &mut *trace;
            tr.uses_v_residual = false;
            tr.v.copy_from(&scratch.v);
        }

        // kk = k * k_k, then L2 normalize per head
        kernel::mul_avx(
            scratch.k.as_ptr(),
            attn.k_k.as_ptr(),
            scratch.kk.as_mut_ptr(),
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.kk_pre.copy_from(&scratch.kk);
        }
        // Normalize per head
        for head in 0..h {
            let offset = head * n;
            kernel::l2_normalize_avx(
                scratch.kk.as_ptr().add(offset),
                scratch.kk.as_mut_ptr().add(offset),
                n,
                1e-12,
            );
        }
        if CAPTURE {
            let tr = &mut *trace;
            tr.kk.copy_from(&scratch.kk);
        }

        // k = k * (1 + (a - 1) * k_a)
        for i in 0..c {
            let scale = 1.0 + (scratch.a[i] - 1.0) * attn.k_a[i];
            scratch.k[i] *= scale;
        }
        if CAPTURE {
            let tr = &mut *trace;
            tr.k.copy_from(&scratch.k);
        }

        // WKV state update: S = S*w.T - S@kk*(kk*a).T + v*k.T; y = S@r
        kernel::rwkv7_wkv_update_avx(
            layer_state.att_state.as_mut_ptr(),
            scratch.w_decay.as_ptr(),
            scratch.k.as_ptr(),
            scratch.v.as_ptr(),
            scratch.kk.as_ptr(),
            scratch.a.as_ptr(),
            scratch.r.as_ptr(),
            scratch.y.as_mut_ptr(),
            h,
            n,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.y_wkv.copy_from(&scratch.y);
        }

        // Group norm
        kernel::group_norm_avx(
            scratch.y.as_ptr(),
            attn.g_norm_w.as_ptr(),
            attn.g_norm_b.as_ptr(),
            scratch.y.as_mut_ptr(),
            h,
            n,
            self.cfg.group_norm_eps,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.y_gn.copy_from(&scratch.y);
        }

        // Add head-qk term: y += ((r * k * r_k).sum_per_head) * v
        for head in 0..h {
            let offset = head * n;
            let mut alpha = 0.0f32;
            for j in 0..n {
                alpha += scratch.r[offset + j] * scratch.k[offset + j] * attn.r_k[head * n + j];
            }
            if CAPTURE {
                let tr = &mut *trace;
                tr.alpha[head] = alpha;
            }
            for j in 0..n {
                scratch.y[offset + j] += alpha * scratch.v[offset + j];
            }
        }
        if CAPTURE {
            let tr = &mut *trace;
            tr.y_head.copy_from(&scratch.y);
        }

        // Apply gate: y = y * g
        kernel::mul_avx(
            scratch.y.as_ptr(),
            scratch.g.as_ptr(),
            scratch.y.as_mut_ptr(),
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.y_gate.copy_from(&scratch.y);
        }

        // Output projection: att_out = o_proj @ y
        kernel::gemv_avx(
            attn.o_proj.as_ptr(),
            scratch.y.as_ptr(),
            scratch.att_out.as_mut_ptr(),
            c,
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.att_out.copy_from(&scratch.att_out);
        }
    }
5278
    /// ReLU² channel-mix ("FFN") for one token on layer `layer_idx`.
    ///
    /// Reads the layer-normed activations from `scratch.x_normed`, updates
    /// the FFN token-shift state (`ffn_x_prev`), and leaves the result in
    /// `scratch.ffn_out`: `ffn_out = (relu(token_shift(x) @ key_w.T)^2) @
    /// value_w.T`. When `CAPTURE` is true, `trace` must point to this layer's
    /// `LayerTrainTrace`; otherwise it is null and never dereferenced.
    ///
    /// # Safety
    /// Kernel calls use raw pointers with no bounds checks; all buffers must
    /// match the configured `hidden_size` / `intermediate_size`.
    #[inline(always)]
    unsafe fn ffn_forward_impl<const CAPTURE: bool>(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        layer_state: &mut LayerState,
        trace: *mut LayerTrainTrace,
    ) {
        let ffn = &self.blocks[layer_idx].ffn;
        let c = self.cfg.hidden_size;
        // `i` is the FFN intermediate width, not a loop index.
        let i = self.cfg.intermediate_size;
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_x_prev_old.copy_from(&layer_state.ffn_x_prev);
        }

        // Token shift: xk = x_normed + x_k * (prev - x_normed)
        kernel::token_shift_avx(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_ptr(),
            ffn.x_k.as_ptr(),
            scratch.xk.as_mut_ptr(),
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_xk.copy_from(&scratch.xk);
        }

        // Update prev state
        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_mut_ptr(),
            c,
        );

        // k = relu(xk @ key_w.T)^2
        kernel::gemv_avx(
            ffn.key_w.as_ptr(),
            scratch.xk.as_ptr(),
            scratch.ffn_k.as_mut_ptr(),
            i,
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_pre.copy_from(&scratch.ffn_k);
        }
        kernel::relu_squared_avx(scratch.ffn_k.as_ptr(), scratch.ffn_k.as_mut_ptr(), i);
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_k.copy_from(&scratch.ffn_k);
        }

        // ffn_out = k @ value_w.T
        kernel::gemv_avx(
            ffn.value_w.as_ptr(),
            scratch.ffn_k.as_ptr(),
            scratch.ffn_out.as_mut_ptr(),
            c,
            i,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_out.copy_from(&scratch.ffn_out);
        }
    }
5346}
5347
/// Backward pass for layer normalization over a single vector.
///
/// Given the forward `input`, the scale `weight`, and the upstream gradient
/// `grad_out`, writes the gradients w.r.t. the input, weight, and bias into
/// the corresponding output slices. All lengths are clamped to the common
/// minimum of the six slices; a zero length is a no-op.
#[allow(clippy::needless_range_loop)]
fn layer_norm_backward(
    input: &[f32],
    weight: &[f32],
    grad_out: &[f32],
    eps: f32,
    grad_input: &mut [f32],
    grad_weight: &mut [f32],
    grad_bias: &mut [f32],
) {
    let n = [
        input.len(),
        weight.len(),
        grad_out.len(),
        grad_input.len(),
        grad_weight.len(),
        grad_bias.len(),
    ]
    .into_iter()
    .min()
    .unwrap_or(0);
    if n == 0 {
        return;
    }
    let nf = n as f32;
    // Forward statistics, recomputed from the saved input.
    let mean = input[..n].iter().sum::<f32>() / nf;
    let var = input[..n]
        .iter()
        .map(|&x| (x - mean) * (x - mean))
        .sum::<f32>()
        / nf;
    let inv_std = (var + eps).sqrt().recip();
    // First sweep: weight/bias grads plus the two reductions needed for
    // the input gradient.
    let mut sum_gw = 0.0f32;
    let mut sum_gw_xhat = 0.0f32;
    for idx in 0..n {
        let xhat = (input[idx] - mean) * inv_std;
        let gw = grad_out[idx] * weight[idx];
        grad_weight[idx] = grad_out[idx] * xhat;
        grad_bias[idx] = grad_out[idx];
        sum_gw += gw;
        sum_gw_xhat += gw * xhat;
    }
    // Second sweep: d(input) = (gw*n - sum(gw) - xhat*sum(gw*xhat)) / (n*std).
    for idx in 0..n {
        let xhat = (input[idx] - mean) * inv_std;
        let gw = grad_out[idx] * weight[idx];
        grad_input[idx] = (gw * nf - sum_gw - xhat * sum_gw_xhat) * inv_std / nf;
    }
}
5397
/// Backward pass for group normalization over a single vector.
///
/// The forward pass normalizes each contiguous group of `group_size`
/// channels independently; this computes the matching gradients w.r.t. the
/// input, per-channel weight, and per-channel bias. All gradient buffers are
/// zeroed first, slice lengths are clamped to their common minimum, and the
/// number of groups is clamped so no group starts past the clamped length.
///
/// Fix over the previous version: the input-gradient term divided by the
/// nominal `group_size` while mean/var used the actual group length `len`.
/// With the current `c / group_size` clamp the two never differ (groups are
/// always full), but the formula now uses `len` consistently so a truncated
/// tail group would also be handled correctly.
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn group_norm_backward(
    input: &[f32],
    weight: &[f32],
    grad_out: &[f32],
    num_groups: usize,
    group_size: usize,
    eps: f32,
    grad_input: &mut [f32],
    grad_weight: &mut [f32],
    grad_bias: &mut [f32],
) {
    let c = input
        .len()
        .min(weight.len())
        .min(grad_out.len())
        .min(grad_input.len())
        .min(grad_weight.len())
        .min(grad_bias.len());
    if c == 0 || num_groups == 0 || group_size == 0 {
        return;
    }
    grad_input[0..c].fill(0.0);
    grad_weight[0..c].fill(0.0);
    grad_bias[0..c].fill(0.0);
    let g = num_groups.min(c / group_size);
    for group in 0..g {
        let off = group * group_size;
        let end = (off + group_size).min(c);
        let len = end - off;
        if len == 0 {
            continue;
        }
        // Per-group length used for mean/var AND the gradient formula.
        let n = len as f32;
        let mut mean = 0.0f32;
        for idx in off..end {
            mean += input[idx];
        }
        mean /= n;
        let mut var = 0.0f32;
        for idx in off..end {
            let d = input[idx] - mean;
            var += d * d;
        }
        var /= n;
        let inv_std = (var + eps).sqrt().recip();
        // First sweep: accumulate weight/bias grads and the two reductions
        // needed for the input gradient of this group.
        let mut sum_gw = 0.0f32;
        let mut sum_gw_xhat = 0.0f32;
        for idx in off..end {
            let xhat = (input[idx] - mean) * inv_std;
            let gw = grad_out[idx] * weight[idx];
            grad_weight[idx] += grad_out[idx] * xhat;
            grad_bias[idx] += grad_out[idx];
            sum_gw += gw;
            sum_gw_xhat += gw * xhat;
        }
        // Second sweep: standard normalization backward, per group.
        for idx in off..end {
            let xhat = (input[idx] - mean) * inv_std;
            let gw = grad_out[idx] * weight[idx];
            grad_input[idx] = (gw * n - sum_gw - xhat * sum_gw_xhat) * inv_std / n;
        }
    }
}
5461
/// Backward pass for L2 normalization `y = x / max(‖x‖, min_norm)`.
///
/// `x` is the forward input, `y` the forward output, `grad_out` the upstream
/// gradient; the result lands in `grad_input`. Lengths are clamped to the
/// common minimum of all four slices.
fn l2_normalize_backward(
    x: &[f32],
    y: &[f32],
    grad_out: &[f32],
    min_norm: f32,
    grad_input: &mut [f32],
) {
    let n = [x.len(), y.len(), grad_out.len(), grad_input.len()]
        .into_iter()
        .min()
        .unwrap_or(0);
    if n == 0 {
        return;
    }
    let norm = x[..n].iter().map(|&v| v * v).sum::<f32>().sqrt();
    if norm <= min_norm {
        // Clamped branch: the normalizer was the constant `min_norm`, so the
        // Jacobian is a plain scale.
        let inv = min_norm.recip();
        for (gi, &go) in grad_input[..n].iter_mut().zip(&grad_out[..n]) {
            *gi = go * inv;
        }
        return;
    }
    // Project out the gradient component along y, then rescale by 1/‖x‖.
    let dot = grad_out[..n]
        .iter()
        .zip(&y[..n])
        .map(|(&go, &yv)| go * yv)
        .sum::<f32>();
    let inv = norm.recip();
    for idx in 0..n {
        grad_input[idx] = (grad_out[idx] - y[idx] * dot) * inv;
    }
}
5499
/// Element-wise gradient accumulation: `dst[i] += src[i]` over the
/// overlapping prefix of the two slices.
#[inline(always)]
fn add_vec_grad(dst: &mut [f32], src: &[f32]) {
    // `zip` truncates to the shorter slice, matching the explicit min-length
    // loop this replaces.
    for (d, &s) in dst.iter_mut().zip(src) {
        *d += s;
    }
}
5507
/// Accumulates a rank-1 outer product into a row-major matrix:
/// `dst[r * cols + c] += left[r] * right[c]`.
///
/// `rows`/`cols` are clamped to the lengths of `left`/`right`, and rows that
/// would run past the end of `dst` are truncated; rows with a zero gradient
/// are skipped entirely.
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn add_outer_grad(dst: &mut [f32], rows: usize, cols: usize, left: &[f32], right: &[f32]) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let total = dst.len();
    if rows == 0 || cols == 0 || total == 0 {
        return;
    }
    for (r, &g) in left.iter().enumerate().take(rows) {
        if g == 0.0 {
            // A zero row gradient contributes nothing; skip the whole row.
            continue;
        }
        let off = r * cols;
        if off >= total {
            break;
        }
        let row = &mut dst[off..total.min(off + cols)];
        for (d, &rv) in row.iter_mut().zip(right) {
            *d += g * rv;
        }
    }
}
5532
/// SGD-style in-place update: `param[i] += lr * clip(grad[i])`.
///
/// A non-positive `clip` disables per-element gradient clipping. Note the
/// update is additive — the descent direction is encoded in the sign of `lr`
/// by the caller. Lengths are clamped to the shorter of the two slices.
#[inline(always)]
fn sgd_vec_update(param: &mut [f32], grad: &[f32], lr: f32, clip: f32) {
    let clipping = clip > 0.0;
    for (p, &g) in param.iter_mut().zip(grad) {
        let g = if clipping { g.clamp(-clip, clip) } else { g };
        *p += lr * g;
    }
}
5549
/// SGD-style rank-1 update on a row-major matrix:
/// `param[r * cols + c] += lr * clip(left[r] * right[c])`.
///
/// Bounds are clamped like [`add_outer_grad`]; a non-positive `clip`
/// disables clipping. Unlike `add_outer_grad`, zero row gradients are not
/// special-cased (the update is simply zero for those cells).
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn sgd_outer_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    lr: f32,
    clip: f32,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let total = param.len();
    if rows == 0 || cols == 0 || total == 0 {
        return;
    }
    let clipping = clip > 0.0;
    for r in 0..rows {
        let g = left[r];
        let off = r * cols;
        if off >= total {
            break;
        }
        let row = &mut param[off..total.min(off + cols)];
        for (p, &rv) in row.iter_mut().zip(right) {
            let step = g * rv;
            *p += lr * if clipping { step.clamp(-clip, clip) } else { step };
        }
    }
}
5585
/// Fused backward pass + SGD step for a row-major `rows x cols` head matrix.
///
/// For every row `r` with a nonzero output gradient `left[r]`, a single pass
/// over the row both:
///   * accumulates the input gradient from the *pre-update* weights:
///     `grad_input[c] += param[r * cols + c] * left[r]`, and
///   * applies the SGD step `param[r * cols + c] += lr * left[r] * right[c]`
///     (per-element gradient clamped to `±clip` when `clip > 0`).
///
/// Rows with `left[r] == 0` contribute nothing to either output and are
/// skipped. `rows`/`cols` are clamped to the slice lengths and rows falling
/// past the end of `param` are truncated. The unclipped path uses 8-lane
/// `f32x8` SIMD with a scalar tail; the clipped path stays scalar.
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn fused_sgd_head_backward_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    grad_input: &mut [f32],
    lr: f32,
    clip: f32,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len()).min(grad_input.len());
    let n = param.len();
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let do_clip = clip > 0.0;
    let lr8 = f32x8::splat(lr);
    for row in 0..rows {
        let g = left[row];
        // Zero output gradient: both the input-gradient and weight-update
        // contributions of this row are zero, so skip it entirely.
        if g == 0.0 {
            continue;
        }
        let off = row * cols;
        if off >= n {
            break;
        }
        let row_cols = cols.min(n - off);
        if do_clip {
            // Scalar clipped path; `w_old` is read before the update so the
            // input gradient uses the pre-step weight.
            for col in 0..row_cols {
                let idx = off + col;
                let w_old = param[idx];
                grad_input[col] += w_old * g;
                param[idx] = w_old + lr * (g * right[col]).clamp(-clip, clip);
            }
            continue;
        }
        let mut col = 0usize;
        // SAFETY: the loop guarantees `col + 8 <= row_cols`, with
        // `row_cols <= cols <= right.len()`, `row_cols <= cols <=
        // grad_input.len()`, and `off + row_cols <= n == param.len()`, so
        // every 8-lane access below stays in bounds; unaligned reads/writes
        // are used throughout.
        unsafe {
            let g8 = f32x8::splat(g);
            while col + 8 <= row_cols {
                let idx = off + col;
                let wv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                let giv = grad_input
                    .as_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .read_unaligned();
                grad_input
                    .as_mut_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .write_unaligned(giv + wv * g8);
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(wv + (g8 * rv) * lr8);
                col += 8;
            }
        }
        // Scalar tail for the remaining (< 8) columns of this row.
        while col < row_cols {
            let idx = off + col;
            let w_old = param[idx];
            grad_input[col] += w_old * g;
            param[idx] = w_old + lr * g * right[col];
            col += 1;
        }
    }
}
5659
5660#[inline(always)]
5661fn apply_adam_vec_update(
5662    param: &mut [f32],
5663    grad: &[f32],
5664    adam: &mut AdamTensorState,
5665    step: &AdamStep,
5666) {
5667    let n = param
5668        .len()
5669        .min(grad.len())
5670        .min(adam.m.len())
5671        .min(adam.v.len());
5672    if n == 0 {
5673        return;
5674    }
5675    apply_adam_vec_update_raw(
5676        &mut param[0..n],
5677        &grad[0..n],
5678        &mut adam.m.as_mut_slice()[0..n],
5679        &mut adam.v.as_mut_slice()[0..n],
5680        step,
5681    );
5682}
5683
/// Adam moment update + parameter step over flat slices.
///
/// Per element, the standard bias-corrected Adam recurrence:
///   `m = b1*m + (1-b1)*g`
///   `v = b2*v + (1-b2)*g^2`
///   `param += lr * (m / bias_corr1) / (sqrt(v / bias_corr2) + eps)`
/// The update is *added* to `param`; callers supply gradients with the sign
/// they want applied. Gradients are clamped to `±clip` when `step.clip > 0`
/// (scalar path only); otherwise the bulk is processed with 8-lane `f32x8`
/// SIMD and a scalar tail. All lengths are clamped to the shortest slice.
#[inline(always)]
fn apply_adam_vec_update_raw(
    param: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    step: &AdamStep,
) {
    let n = param.len().min(grad.len()).min(m.len()).min(v.len());
    if n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_b1 = 1.0 - b1;
    let one_b2 = 1.0 - b2;
    // Reciprocals hoisted so the loops multiply instead of divide.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    if do_clip {
        // Clipped variant stays scalar.
        for idx in 0..n {
            let g = grad[idx].clamp(-clip, clip);
            let mm = b1 * m[idx] + one_b1 * g;
            let vv = b2 * v[idx] + one_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
        }
        return;
    }
    let mut idx = 0usize;
    // SAFETY: the loop guarantees `idx + 8 <= n`, and `n` is the minimum of
    // all four slice lengths, so every 8-lane unaligned read/write below is
    // in bounds for `param`, `grad`, `m`, and `v`.
    unsafe {
        let b1v = f32x8::splat(b1);
        let b2v = f32x8::splat(b2);
        let one_b1v = f32x8::splat(one_b1);
        let one_b2v = f32x8::splat(one_b2);
        let inv_bc1v = f32x8::splat(inv_bc1);
        let inv_bc2v = f32x8::splat(inv_bc2);
        let lrv = f32x8::splat(step.lr);
        let epsv = f32x8::splat(step.eps);
        while idx + 8 <= n {
            let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mm = mv * b1v + gv * one_b1v;
            let vv2 = vv * b2v + (gv * gv) * one_b2v;
            m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
            v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
            let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
            param
                .as_mut_ptr()
                .add(idx)
                .cast::<f32x8>()
                .write_unaligned(pv + upd);
            idx += 8;
        }
    }
    // Scalar tail for the remaining (< 8) elements.
    while idx < n {
        let g = grad[idx];
        let mm = b1 * m[idx] + one_b1 * g;
        let vv = b2 * v[idx] + one_b2 * g * g;
        m[idx] = mm;
        v[idx] = vv;
        let m_hat = mm * inv_bc1;
        let v_hat = vv * inv_bc2;
        param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
        idx += 1;
    }
}
5757
/// Fused backward pass + Adam step for a row-major `rows x cols` head matrix.
///
/// Same single-pass structure as `fused_sgd_head_backward_update`: for each
/// row with a nonzero `left[row]`, accumulate the input gradient from the
/// *pre-update* weights (`grad_input[c] += param[idx] * left[row]`) and apply
/// a bias-corrected Adam step for the rank-1 gradient `left[row] * right[c]`,
/// updating the moment buffers `m`/`v` in place.
///
/// NOTE: rows with `left[row] == 0` are skipped entirely, so their Adam
/// moments are not decayed on this step (a lazy sparse-update shortcut);
/// the dense `apply_adam_outer_update_raw` below touches every row.
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn fused_adam_head_backward_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    grad_input: &mut [f32],
    m: &mut [f32],
    v: &mut [f32],
    step: &AdamStep,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len()).min(grad_input.len());
    let n = param.len().min(m.len()).min(v.len());
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_b1 = 1.0 - b1;
    let one_b2 = 1.0 - b2;
    // Reciprocals hoisted so the loops multiply instead of divide.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let b1v = f32x8::splat(b1);
    let b2v = f32x8::splat(b2);
    let one_b1v = f32x8::splat(one_b1);
    let one_b2v = f32x8::splat(one_b2);
    let inv_bc1v = f32x8::splat(inv_bc1);
    let inv_bc2v = f32x8::splat(inv_bc2);
    let epsv = f32x8::splat(step.eps);
    let lrv = f32x8::splat(step.lr);
    for row in 0..rows {
        let g = left[row];
        // Lazy update: zero-gradient rows contribute nothing to grad_input
        // and are skipped (their moments are left undecayed this step).
        if g == 0.0 {
            continue;
        }
        let off = row * cols;
        if off >= n {
            break;
        }
        let row_cols = cols.min(n - off);
        if do_clip {
            // Scalar clipped path; `w_old` is read before the update so the
            // input gradient uses the pre-step weight.
            for col in 0..row_cols {
                let idx = off + col;
                let w_old = param[idx];
                grad_input[col] += w_old * g;
                let gg = (g * right[col]).clamp(-clip, clip);
                let mm = b1 * m[idx] + one_b1 * gg;
                let vv = b2 * v[idx] + one_b2 * gg * gg;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] = w_old + step.lr * m_hat / (v_hat.sqrt() + step.eps);
            }
            continue;
        }
        let mut col = 0usize;
        // SAFETY: the loop guarantees `col + 8 <= row_cols`, with
        // `row_cols <= cols <= min(right.len(), grad_input.len())` and
        // `off + row_cols <= n == min(param.len(), m.len(), v.len())`, so
        // every 8-lane unaligned access below is in bounds.
        unsafe {
            let g8 = f32x8::splat(g);
            while col + 8 <= row_cols {
                let idx = off + col;
                let wv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                let giv = grad_input
                    .as_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .read_unaligned();
                grad_input
                    .as_mut_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .write_unaligned(giv + wv * g8);
                let gv = g8 * rv;
                let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mm = mv * b1v + gv * one_b1v;
                let vv2 = vv * b2v + (gv * gv) * one_b2v;
                m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
                let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(wv + upd);
                col += 8;
            }
        }
        // Scalar tail for the remaining (< 8) columns of this row.
        while col < row_cols {
            let idx = off + col;
            let w_old = param[idx];
            grad_input[col] += w_old * g;
            let gg = g * right[col];
            let mm = b1 * m[idx] + one_b1 * gg;
            let vv = b2 * v[idx] + one_b2 * gg * gg;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] = w_old + step.lr * m_hat / (v_hat.sqrt() + step.eps);
            col += 1;
        }
    }
}
5868
5869#[inline(always)]
5870fn apply_adam_outer_update(
5871    param: &mut [f32],
5872    rows: usize,
5873    cols: usize,
5874    left: &[f32],
5875    right: &[f32],
5876    adam: &mut AdamTensorState,
5877    step: &AdamStep,
5878) {
5879    let n = param.len().min(adam.m.len()).min(adam.v.len());
5880    if n == 0 {
5881        return;
5882    }
5883    apply_adam_outer_update_raw(
5884        &mut param[0..n],
5885        rows,
5886        cols,
5887        left,
5888        right,
5889        &mut adam.m.as_mut_slice()[0..n],
5890        &mut adam.v.as_mut_slice()[0..n],
5891        step,
5892    );
5893}
5894
/// Dense Adam update of a row-major `rows x cols` matrix from the rank-1
/// gradient `left ⊗ right` (no fused input-gradient accumulation).
///
/// Unlike `fused_adam_head_backward_update`, zero-gradient rows are NOT
/// skipped, so the moment buffers decay on every step for every visited
/// element. `rows`/`cols` and the working length are clamped to the slices
/// provided; the clipped path is scalar, the unclipped path is 8-lane
/// `f32x8` SIMD with a scalar tail.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn apply_adam_outer_update_raw(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    step: &AdamStep,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let n = param.len().min(m.len()).min(v.len());
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_b1 = 1.0 - b1;
    let one_b2 = 1.0 - b2;
    // Reciprocals hoisted so the loops multiply instead of divide.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let b1v = f32x8::splat(b1);
    let b2v = f32x8::splat(b2);
    let one_b1v = f32x8::splat(one_b1);
    let one_b2v = f32x8::splat(one_b2);
    let inv_bc1v = f32x8::splat(inv_bc1);
    let inv_bc2v = f32x8::splat(inv_bc2);
    let epsv = f32x8::splat(step.eps);
    let lrv = f32x8::splat(step.lr);
    for row in 0..rows {
        let g_row = left[row];
        let off = row * cols;
        if off >= n {
            break;
        }
        let row_cols = (n - off).min(cols);
        if do_clip {
            // Scalar clipped path: per-element gradient clamped to ±clip.
            for col in 0..row_cols {
                let idx = off + col;
                let g = (g_row * right[col]).clamp(-clip, clip);
                let mm = b1 * m[idx] + one_b1 * g;
                let vv = b2 * v[idx] + one_b2 * g * g;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
            }
            continue;
        }
        let mut col = 0usize;
        // SAFETY: the loop guarantees `col + 8 <= row_cols`, with
        // `row_cols <= cols <= right.len()` and `off + row_cols <= n ==
        // min(param.len(), m.len(), v.len())`, so every 8-lane unaligned
        // access below is in bounds.
        unsafe {
            let g8 = f32x8::splat(g_row);
            while col + 8 <= row_cols {
                let idx = off + col;
                let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                let gv = g8 * rv;
                let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mm = mv * b1v + gv * one_b1v;
                let vv2 = vv * b2v + (gv * gv) * one_b2v;
                m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
                let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(pv + upd);
                col += 8;
            }
        }
        // Scalar tail for the remaining (< 8) columns of this row.
        while col < row_cols {
            let idx = off + col;
            let g = g_row * right[col];
            let mm = b1 * m[idx] + one_b1 * g;
            let vv = b2 * v[idx] + one_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
            col += 1;
        }
    }
}
5988
/// Minimal deterministic 64-bit LCG used for reproducible weight
/// initialization. Not cryptographic; it only needs to be fast and to
/// produce the same stream for a given seed on every platform.
struct RwkvRng {
    state: u64,
}

impl RwkvRng {
    /// Seeds the generator; the XOR constant decorrelates small seeds.
    fn new(seed: u64) -> Self {
        let state = seed ^ 0x9E37_79B9_7F4A_7C15;
        Self { state }
    }

    /// Advances the LCG and returns the high 32 bits of the new state
    /// (the high bits of an LCG have the best statistical quality).
    #[inline]
    fn next_u32(&mut self) -> u32 {
        let next = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        self.state = next;
        (next >> 32) as u32
    }

    /// Uniform sample in `[0, 1]` derived from `next_u32`.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        // Exact power-of-two reciprocal, so the scaling is lossless.
        const INV: f32 = 1.0 / (u32::MAX as f32);
        self.next_u32() as f32 * INV
    }
}
6015
6016#[inline]
6017fn init_uniform(t: &mut Tensor1D, rng: &mut RwkvRng, scale: f32) {
6018    let s = t.as_mut_slice();
6019    for v in s {
6020        let r = rng.next_f32() - 0.5;
6021        *v = r * 2.0 * scale;
6022    }
6023}
6024
6025#[inline]
6026fn init_centered(t: &mut Tensor1D, rng: &mut RwkvRng, center: f32, scale: f32) {
6027    let s = t.as_mut_slice();
6028    for v in s {
6029        let r = rng.next_f32() - 0.5;
6030        *v = center + r * 2.0 * scale;
6031    }
6032}
6033
6034#[inline]
6035fn init_const(t: &mut Tensor1D, value: f32) {
6036    t.as_mut_slice().fill(value);
6037}
6038
6039#[cfg(test)]
6040mod tests {
6041    use super::*;
6042
    /// Smallest useful configuration (one layer, one 64-wide head, narrow
    /// low-rank projections) so finite-difference tests stay fast.
    fn test_cfg() -> Config {
        Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 1,
            num_heads: 1,
            head_dim: 64,
            intermediate_size: 64,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 8,
            a_low_rank: 8,
            v_low_rank: 8,
            g_low_rank: 8,
        }
    }
6059
6060    fn softmax_loss(logits: &[f32], target: u8) -> f64 {
6061        let max_logit = logits
6062            .iter()
6063            .copied()
6064            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
6065        let mut sum = 0.0f64;
6066        for &z in logits {
6067            sum += ((z - max_logit) as f64).exp();
6068        }
6069        let p = ((logits[target as usize] - max_logit) as f64).exp() / sum.max(1e-300);
6070        -p.max(1e-300).ln()
6071    }
6072
    /// Mean cross-entropy (nats per token) of `model` over the given
    /// `(input, target)` pairs, run from a fresh recurrent state.
    /// Returns 0.0 for an empty segment.
    fn segment_loss(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> f64 {
        if steps.is_empty() {
            return 0.0;
        }
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        let mut loss = 0.0f64;
        for &(input, target) in steps {
            let logits = model.forward(&mut scratch, input, &mut state);
            loss += softmax_loss(logits, target);
        }
        loss / (steps.len() as f64)
    }
6086
    /// Forward pass over `steps` with training traces captured, followed by a
    /// reverse-order (TBPTT-style) gradient accumulation pass. Returns the
    /// accumulated full-gradient state, averaged over the segment length.
    fn segment_grads(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> FullGradState {
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        let mut states = Vec::with_capacity(steps.len() + 1);
        let mut traces = Vec::with_capacity(steps.len());
        let mut pdfs = Vec::with_capacity(steps.len());
        // states[i] is the recurrent state *before* consuming token i.
        states.push(state.clone());
        for &(input, _) in steps {
            scratch.set_capture_train_trace(true);
            let logits = model.forward(&mut scratch, input, &mut state);
            let mut pdf = vec![0.0f64; cfg.vocab_size];
            super::super::super::softmax_pdf_floor_with_bias(logits, None, &mut pdf);
            pdfs.push(pdf);
            traces.push(TokenTrainTrace::from_scratch(&scratch));
            states.push(state.clone());
        }
        let mut grads = model.new_full_grad_state();
        let mut recurrent = model.new_recurrent_grad_state();
        // Train every parameter group except the head bias.
        let scope = TrainScopeMask {
            embed: true,
            pre_norm: true,
            attn_norm: true,
            ffn_norm: true,
            attn: true,
            ffn: true,
            head: true,
            bias: false,
        };
        // Average the per-token loss gradients over the segment.
        let grad_scale = 1.0f32 / (steps.len() as f32);
        // Backward sweep in reverse token order; states[idx + 1] is the
        // post-token state for token idx.
        for idx in (0..steps.len()).rev() {
            model
                .accumulate_token_step_gradients(
                    &mut scratch,
                    &traces[idx],
                    &states[idx + 1],
                    steps[idx].1,
                    &pdfs[idx],
                    grad_scale,
                    scope,
                    &mut grads,
                    None,
                    &mut recurrent,
                )
                .expect("segment gradient accumulation");
        }
        grads
    }
6134
    /// Identifies a single scalar weight probed by the finite-difference
    /// test; each variant targets a different parameter tensor of the model.
    #[derive(Clone, Copy, Debug)]
    enum Probe {
        Embed,
        LnOutW,
        AttnNormW,
        OProj,
        KProj,
        VProj,
        FfnKey,
    }
6145
    /// Reads the probed scalar from the model. The fixed indices pick
    /// arbitrary single elements; the `rkv_proj` offsets (1*64*64 and
    /// 2*64*64) appear to address the K and V sections of a fused R|K|V
    /// projection and assume the 64x64 single-head test config — keep in
    /// sync with `set_probe`/`probe_grad`.
    fn probe_value(model: &Model, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => model.embeddings[7],
            Probe::LnOutW => model.ln_out_w[5],
            Probe::AttnNormW => model.blocks[0].attn_norm_w[9],
            Probe::OProj => model.blocks[0].attn.o_proj[23],
            Probe::KProj => model.blocks[0].attn.rkv_proj[64 * 64 + 17],
            Probe::VProj => model.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29],
            Probe::FfnKey => model.blocks[0].ffn.key_w[11],
        }
    }
6157
    /// Writes `value` into the probed scalar (same index map as `probe_value`).
    fn set_probe(model: &mut Model, probe: Probe, value: f32) {
        match probe {
            Probe::Embed => model.embeddings[7] = value,
            Probe::LnOutW => model.ln_out_w[5] = value,
            Probe::AttnNormW => model.blocks[0].attn_norm_w[9] = value,
            Probe::OProj => model.blocks[0].attn.o_proj[23] = value,
            Probe::KProj => model.blocks[0].attn.rkv_proj[64 * 64 + 17] = value,
            Probe::VProj => model.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29] = value,
            Probe::FfnKey => model.blocks[0].ffn.key_w[11] = value,
        }
    }
6169
    /// Reads the accumulated gradient for the probed scalar
    /// (same index map as `probe_value`/`set_probe`).
    fn probe_grad(grads: &FullGradState, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => grads.embeddings[7],
            Probe::LnOutW => grads.ln_out_w[5],
            Probe::AttnNormW => grads.blocks[0].attn_norm_w[9],
            Probe::OProj => grads.blocks[0].attn.o_proj[23],
            Probe::KProj => grads.blocks[0].attn.rkv_proj[64 * 64 + 17],
            Probe::VProj => grads.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29],
            Probe::FfnKey => grads.blocks[0].ffn.key_w[11],
        }
    }
6181
6182    fn weighted_checksum(data: &[f32]) -> f64 {
6183        data.iter()
6184            .enumerate()
6185            .map(|(i, &v)| (i as f64 + 1.0) * (v as f64))
6186            .sum()
6187    }
6188
    /// Pins the documented default dimensions so an accidental change to
    /// `Config::default` is caught by CI.
    #[test]
    fn test_config_default() {
        let cfg = Config::default();
        assert_eq!(cfg.vocab_size, 256);
        assert_eq!(cfg.hidden_size, 256);
        assert_eq!(cfg.num_layers, 12);
        assert_eq!(cfg.num_heads, 4);
        assert_eq!(cfg.head_dim, 64);
    }
6198
    /// Golden-value regression test: runs a fixed token sequence through a
    /// seeded random model and compares weighted checksums of sampled logits
    /// and recurrent state against recorded constants. Any change to the
    /// forward math (beyond the FP tolerance) trips this test.
    #[test]
    fn test_forward_deterministic_snapshot() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            num_heads: 1,
            head_dim: 64,
            intermediate_size: 128,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 16,
            a_low_rank: 16,
            v_low_rank: 16,
            g_low_rank: 32,
        };
        cfg.validate().expect("valid test config");

        let model = Model::new_random(cfg.clone(), 0x1234_5678_9ABC_DEF0).expect("random model");
        let mut state = model.new_state();
        let mut scratch = ScratchBuffers::new(&cfg);
        let tokens = [0u32, 1, 7, 42, 255, 3, 128, 64, 17, 99];

        let mut probes = Vec::new();
        let mut last_logits = vec![0.0; 8];

        for &token in &tokens {
            let logits = model.forward(&mut scratch, token, &mut state);
            // Sample a handful of logit positions per step for the checksum.
            probes.push(logits[0]);
            probes.push(logits[1]);
            probes.push(logits[2]);
            probes.push(logits[42]);
            probes.push(logits[127]);
            probes.push(logits[255]);
            last_logits.copy_from_slice(&logits[0..8]);
        }

        let probe_checksum = weighted_checksum(&probes);
        let last_logits_checksum = weighted_checksum(&last_logits);
        let state_att_checksum = weighted_checksum(state.layers[0].att_state.as_slice());
        let state_prev_checksum = weighted_checksum(state.layers[1].att_x_prev.as_slice());
        let v_first_checksum = weighted_checksum(state.v_first.as_slice());

        // Golden values recorded from a known-good build.
        let expected_probe_checksum = 25.674_967_924_598_604_f64;
        let expected_last_logits_checksum = 0.679_873_816_668_987_3_f64;
        let expected_state_att_checksum = 129.962_464_237_222_32_f64;
        let expected_state_prev_checksum = -231.326_208_570_972_08_f64;
        let expected_v_first_checksum = -1.921_361_377_462_744_7_f64;

        // Loose tolerance to absorb minor FP differences across builds.
        let tol = 2e-4_f64;
        assert!(
            (probe_checksum - expected_probe_checksum).abs() <= tol,
            "probe_checksum={probe_checksum}"
        );
        assert!(
            (last_logits_checksum - expected_last_logits_checksum).abs() <= tol,
            "last_logits_checksum={last_logits_checksum}"
        );
        assert!(
            (state_att_checksum - expected_state_att_checksum).abs() <= tol,
            "state_att_checksum={state_att_checksum}"
        );
        assert!(
            (state_prev_checksum - expected_state_prev_checksum).abs() <= tol,
            "state_prev_checksum={state_prev_checksum}"
        );
        assert!(
            (v_first_checksum - expected_v_first_checksum).abs() <= tol,
            "v_first_checksum={v_first_checksum}"
        );
    }
6270
    /// Capturing a training trace must not perturb inference: logits and all
    /// recurrent state components must match bit-for-bit between a traced
    /// and an untraced forward pass over the same token stream.
    #[test]
    fn traced_and_untraced_forward_match_exactly() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            num_heads: 1,
            head_dim: 64,
            intermediate_size: 128,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 16,
            a_low_rank: 16,
            v_low_rank: 16,
            g_low_rank: 32,
        };
        cfg.validate().expect("valid test config");
        let model = Model::new_random(cfg.clone(), 0xCAFEBABE).expect("random model");
        let mut traced_state = model.new_state();
        let mut plain_state = model.new_state();
        let mut traced_scratch = ScratchBuffers::new(&cfg);
        let mut plain_scratch = ScratchBuffers::new(&cfg);
        traced_scratch.set_capture_train_trace(true);
        plain_scratch.set_capture_train_trace(false);

        let tokens = [3u32, 19, 77, 120, 255, 5, 88, 13, 144, 1, 200];
        for &token in &tokens {
            let traced_logits = model
                .forward(&mut traced_scratch, token, &mut traced_state)
                .to_vec();
            let plain_logits = model
                .forward(&mut plain_scratch, token, &mut plain_state)
                .to_vec();
            // Bitwise comparison (to_bits) — not approximate equality.
            for (a, b) in traced_logits.iter().zip(plain_logits.iter()) {
                assert_eq!(a.to_bits(), b.to_bits());
            }
            assert_eq!(traced_state.v_first_set, plain_state.v_first_set);
            for (&a, &b) in traced_state
                .v_first
                .as_slice()
                .iter()
                .zip(plain_state.v_first.as_slice())
            {
                assert_eq!(a.to_bits(), b.to_bits());
            }
            // Per-layer recurrent buffers must also agree exactly.
            for (tr_layer, plain_layer) in traced_state.layers.iter().zip(plain_state.layers.iter())
            {
                for (&a, &b) in tr_layer
                    .att_x_prev
                    .as_slice()
                    .iter()
                    .zip(plain_layer.att_x_prev.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                for (&a, &b) in tr_layer
                    .att_state
                    .as_slice()
                    .iter()
                    .zip(plain_layer.att_state.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                for (&a, &b) in tr_layer
                    .ffn_x_prev
                    .as_slice()
                    .iter()
                    .zip(plain_layer.ffn_x_prev.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
            }
        }
    }
6345
    /// Central finite differences around individual probed weights must
    /// agree with the analytic TBPTT gradients. The numeric slope is
    /// negated because the analytic gradients point in the loss-decreasing
    /// direction (the optimizers *add* them to the parameters).
    #[test]
    fn tbptt_segment_gradients_match_finite_difference() {
        let cfg = test_cfg();
        cfg.validate().expect("valid test config");
        let model = Model::new_random(cfg.clone(), 0xD00D_F00D).expect("random model");
        let steps = [(0u32, 1u8), (1, 2), (2, 3)];
        let grads = segment_grads(&model, &cfg, &steps);
        let eps = 1e-3f32;

        for probe in [
            Probe::Embed,
            Probe::LnOutW,
            Probe::AttnNormW,
            Probe::OProj,
            Probe::KProj,
            Probe::VProj,
            Probe::FfnKey,
        ] {
            let analytic = probe_grad(&grads, probe);

            // Perturb a single weight in each direction on clones of the model.
            let mut plus = model.clone();
            let base = probe_value(&plus, probe);
            set_probe(&mut plus, probe, base + eps);
            let loss_plus = segment_loss(&plus, &cfg, &steps);

            let mut minus = model.clone();
            set_probe(&mut minus, probe, base - eps);
            let loss_minus = segment_loss(&minus, &cfg, &steps);

            let numeric = -((loss_plus - loss_minus) / (2.0 * eps as f64)) as f32;
            // Relative tolerance with an absolute floor: f32 forward passes
            // make tight finite-difference agreement unrealistic.
            let tol = 5e-2f32.max(analytic.abs().max(numeric.abs()) * 8e-2);
            assert!(
                (analytic - numeric).abs() <= tol,
                "probe={probe:?} analytic={analytic} numeric={numeric} tol={tol}"
            );
        }
    }
6383
    /// End-to-end sanity check: one TBPTT SGD step over a short segment must
    /// strictly reduce that same segment's mean loss.
    #[test]
    fn tbptt_sgd_step_reduces_mean_segment_loss() {
        let cfg = test_cfg();
        cfg.validate().expect("valid test config");
        let mut model = Model::new_random(cfg.clone(), 0x1234_5678).expect("random model");
        let steps = [(0u32, 1u8), (1, 2), (2, 3), (3, 4)];
        let before = segment_loss(&model, &cfg, &steps);

        let mut scratch = ScratchBuffers::new(&cfg);
        let start_state = model.new_state();
        let mut live_state = model.new_state();
        let mut adam_t = 0usize;
        // Train everything except the head bias.
        let scope = TrainScopeMask {
            embed: true,
            pre_norm: true,
            attn_norm: true,
            ffn_norm: true,
            attn: true,
            ffn: true,
            head: true,
            bias: false,
        };

        model
            .online_train_segment_tbptt(
                &mut scratch,
                &start_state,
                &steps,
                scope,
                OptimizerKind::Sgd,
                1e-3,
                0.0,
                2,
                &mut adam_t,
                None,
                None,
                None,
                None,
                &mut live_state,
            )
            .expect("tbptt sgd step");

        let after = segment_loss(&model, &cfg, &steps);
        assert!(
            after < before,
            "expected SGD TBPTT step to reduce mean loss: before={before} after={after}"
        );
    }
6432}