infotheory/backends/mambazip/mamba1/
model.rs

1//! Deterministic CPU Mamba-1 single-token inference runtime.
2
use super::kernel;
use super::tensor::Tensor1D;
use super::weights::{WeightTensor, Weights};
use crate::backends::llm_policy::OptimizerKind;
use anyhow::{Context, Result, bail};
use serde_json::json;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use wide::f32x8;
13
/// Mamba-1 model configuration.
///
/// All dimensions must be nonzero; see [`Config::validate`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Byte vocabulary size.
    pub vocab_size: usize,
    /// Model hidden width.
    pub hidden_size: usize,
    /// Number of Mamba blocks.
    pub num_layers: usize,
    /// Expanded mixer width (`d_inner`).
    pub inner_size: usize,
    /// SSM state width (`d_state`).
    pub state_size: usize,
    /// Depthwise convolution kernel width (`d_conv`).
    pub conv_kernel: usize,
    /// Delta projection rank (`dt_rank`).
    pub dt_rank: usize,
    /// RMSNorm epsilon.
    pub layer_norm_eps: f32,
}
34
impl Default for Config {
    /// Compact byte-level defaults: 256-way vocab, 6 layers, 2x hidden expansion.
    fn default() -> Self {
        Self {
            vocab_size: 256,
            hidden_size: 256,
            num_layers: 6,
            inner_size: 512,
            state_size: 16,
            conv_kernel: 4,
            dt_rank: 16,
            layer_norm_eps: 1e-5,
        }
    }
}
49
50impl Config {
51    /// Validate shape invariants.
52    pub fn validate(&self) -> Result<()> {
53        if self.vocab_size == 0 {
54            bail!("mamba vocab_size must be > 0");
55        }
56        if self.hidden_size == 0 {
57            bail!("mamba hidden_size must be > 0");
58        }
59        if self.num_layers == 0 {
60            bail!("mamba num_layers must be > 0");
61        }
62        if self.inner_size == 0 {
63            bail!("mamba inner_size must be > 0");
64        }
65        if self.state_size == 0 {
66            bail!("mamba state_size must be > 0");
67        }
68        if self.conv_kernel == 0 {
69            bail!("mamba conv_kernel must be > 0");
70        }
71        if self.dt_rank == 0 {
72            bail!("mamba dt_rank must be > 0");
73        }
74        Ok(())
75    }
76}
77
#[derive(Clone)]
/// Per-layer recurrent buffers for single-token inference.
struct LayerState {
    conv: Tensor1D, // (inner * conv_kernel)
    // Write position into `conv` — presumably a ring-buffer cursor; confirm in kernel.
    conv_pos: usize,
    ssm: Tensor1D, // (inner * state)
}
84
85impl LayerState {
86    fn new(cfg: &Config) -> Self {
87        Self {
88            conv: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
89            conv_pos: 0,
90            ssm: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
91        }
92    }
93
94    fn reset(&mut self) {
95        self.conv.zero();
96        self.conv_pos = 0;
97        self.ssm.zero();
98    }
99}
100
/// Recurrent runtime state for a Mamba model.
#[derive(Clone)]
pub struct State {
    // One recurrent buffer set per Mamba block, indexed by layer.
    layers: Vec<LayerState>,
}
106
107impl State {
108    /// Create a zero-initialized state.
109    pub fn new(cfg: &Config) -> Self {
110        Self {
111            layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
112        }
113    }
114
115    /// Reset all recurrent buffers.
116    pub fn reset(&mut self) {
117        for l in &mut self.layers {
118            l.reset();
119        }
120    }
121}
122
#[derive(Clone)]
/// Weights for one Mamba block: norm plus the mixer's projections, depthwise
/// convolution, and SSM parameters. `Option` biases may be absent depending
/// on checkpoint layout.
struct LayerWeights {
    norm_w: Tensor1D,
    norm_b: Option<Tensor1D>,

    in_proj_w: Tensor1D, // (2*inner, hidden)
    in_proj_b: Option<Tensor1D>,

    conv_w: Tensor1D, // (inner, conv_kernel)
    conv_b: Option<Tensor1D>,

    x_proj_w: Tensor1D, // (dt_rank + 2*state, inner)
    x_proj_b: Option<Tensor1D>,

    dt_proj_w: Tensor1D, // (inner, dt_rank)
    dt_proj_b: Tensor1D,

    // Stored SSM diagonal in official parameterization and cached runtime form.
    a_log: Tensor1D, // (inner, state)
    a: Tensor1D,     // (inner, state), equals -exp(a_log)
    d: Tensor1D,     // (inner)

    out_proj_w: Tensor1D, // (hidden, inner)
    out_proj_b: Option<Tensor1D>,
}
148
#[derive(Clone)]
/// First (`m`) and second (`v`) Adam moment buffers for a single tensor.
struct AdamTensorState {
    m: Tensor1D,
    v: Tensor1D,
}
154
155impl AdamTensorState {
156    #[inline]
157    fn new(len: usize) -> Self {
158        Self {
159            m: Tensor1D::zeros(len),
160            v: Tensor1D::zeros(len),
161        }
162    }
163}
164
#[derive(Clone)]
/// Adam moments for every trainable tensor of one layer; `Option` entries
/// mirror the optional biases in `LayerWeights`.
struct LayerAdamState {
    norm_w: AdamTensorState,
    norm_b: Option<AdamTensorState>,
    in_proj_w: AdamTensorState,
    in_proj_b: Option<AdamTensorState>,
    conv_w: AdamTensorState,
    conv_b: Option<AdamTensorState>,
    x_proj_w: AdamTensorState,
    x_proj_b: Option<AdamTensorState>,
    dt_proj_w: AdamTensorState,
    dt_proj_b: AdamTensorState,
    a: AdamTensorState,
    d: AdamTensorState,
    out_proj_w: AdamTensorState,
    out_proj_b: Option<AdamTensorState>,
}
182
#[derive(Clone)]
/// Adam moments for full-parameter online Mamba training.
///
/// Mirrors `Model`'s tensor layout one-to-one; optional entries track the
/// corresponding optional bias tensors.
pub struct FullAdamState {
    embeddings: AdamTensorState,
    final_norm_w: AdamTensorState,
    final_norm_b: Option<AdamTensorState>,
    lm_head: AdamTensorState,
    lm_head_b: Option<AdamTensorState>,
    layers: Vec<LayerAdamState>,
}
193
#[derive(Clone, Copy, Debug, Default)]
/// Train-scope mask for Mamba full-parameter online updates.
///
/// Each flag independently enables one parameter group; the derived `Default`
/// disables everything.
pub struct TrainScopeMask {
    /// Train token embeddings.
    pub embed: bool,
    /// Train layer-normalization weights/biases.
    pub layer_norm: bool,
    /// Train convolutional mixer parameters.
    pub mixer_conv: bool,
    /// Train SSM/state-space mixer parameters.
    pub mixer_ssm: bool,
    /// Train projection matrices around the mixer.
    pub mixer_proj: bool,
    /// Train LM-head weights.
    pub head: bool,
    /// Train additive output-bias terms.
    pub bias: bool,
}
212
213impl TrainScopeMask {
214    #[inline]
215    /// Enable all train scopes.
216    pub fn all() -> Self {
217        Self {
218            embed: true,
219            layer_norm: true,
220            mixer_conv: true,
221            mixer_ssm: true,
222            mixer_proj: true,
223            head: true,
224            bias: true,
225        }
226    }
227
228    #[inline]
229    /// Returns whether any model parameters (excluding standalone output bias) are trainable.
230    pub fn trains_model_params(&self) -> bool {
231        self.embed
232            || self.layer_norm
233            || self.mixer_conv
234            || self.mixer_ssm
235            || self.mixer_proj
236            || self.head
237    }
238}
239
/// Immutable per-step Adam update parameters.
struct AdamStep {
    lr: f32,   // learning rate for this step
    clip: f32, // gradient clip threshold (NOTE(review): by-value vs by-norm — confirm at use site)
    b1: f32,   // first-moment decay (beta1)
    b2: f32,   // second-moment decay (beta2)
    eps: f32,  // denominator epsilon
    // Precomputed bias-correction factors — presumably the 1 - beta^t terms for
    // the current step count; confirm where AdamStep is constructed.
    bias_corr1: f32,
    bias_corr2: f32,
}
249
#[derive(Clone)]
/// Per-layer forward intermediates captured during a traced token step, plus
/// snapshots of the recurrent buffers as they were before the step
/// (`*_prev` fields). Sizes below follow `LayerTrainTrace::new`.
struct LayerTrainTrace {
    h_in: Tensor1D,         // (hidden)
    norm: Tensor1D,         // (hidden)
    xz: Tensor1D,           // (2 * inner)
    conv_pre: Tensor1D,     // (inner)
    conv_post: Tensor1D,    // (inner)
    conv_sigmoid: Tensor1D, // (inner)
    proj: Tensor1D,         // (dt_rank + 2 * state)
    dt_raw: Tensor1D,       // (inner)
    dt: Tensor1D,           // (inner)
    gate: Tensor1D,         // (inner)
    gate_sigmoid: Tensor1D, // (inner)
    y_pre: Tensor1D,        // (inner)
    y: Tensor1D,            // (inner)
    out: Tensor1D,          // (hidden)
    d_a: Tensor1D,          // (inner * state)
    ssm_prev: Tensor1D,     // (inner * state) SSM state before this token
    conv_prev: Tensor1D,    // (inner * conv_kernel) conv buffer before this token
    conv_pos_prev: usize,   // conv cursor before this token
}
271
272impl LayerTrainTrace {
273    fn new(cfg: &Config) -> Self {
274        Self {
275            h_in: Tensor1D::zeros(cfg.hidden_size),
276            norm: Tensor1D::zeros(cfg.hidden_size),
277            xz: Tensor1D::zeros(cfg.inner_size * 2),
278            conv_pre: Tensor1D::zeros(cfg.inner_size),
279            conv_post: Tensor1D::zeros(cfg.inner_size),
280            conv_sigmoid: Tensor1D::zeros(cfg.inner_size),
281            proj: Tensor1D::zeros(cfg.dt_rank + 2 * cfg.state_size),
282            dt_raw: Tensor1D::zeros(cfg.inner_size),
283            dt: Tensor1D::zeros(cfg.inner_size),
284            gate: Tensor1D::zeros(cfg.inner_size),
285            gate_sigmoid: Tensor1D::zeros(cfg.inner_size),
286            y_pre: Tensor1D::zeros(cfg.inner_size),
287            y: Tensor1D::zeros(cfg.inner_size),
288            out: Tensor1D::zeros(cfg.hidden_size),
289            d_a: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
290            ssm_prev: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
291            conv_prev: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
292            conv_pos_prev: 0,
293        }
294    }
295}
296
#[derive(Clone)]
/// Full forward trace for one token: token id, final pre-head norm output,
/// and every layer's intermediates.
struct TokenTrainTrace {
    token: usize,
    norm: Tensor1D,
    h_final: Tensor1D,
    layers: Vec<LayerTrainTrace>,
}
304
305impl TokenTrainTrace {
306    fn from_scratch(scratch: &ScratchBuffers) -> Self {
307        Self {
308            token: scratch.train_token,
309            norm: scratch.norm.clone(),
310            h_final: scratch.train_h_final.clone(),
311            layers: scratch.train_trace_layers.clone(),
312        }
313    }
314}
315
#[derive(Clone)]
/// Gradient accumulators for one layer; fields mirror `LayerWeights`
/// (with `a` holding the gradient for the SSM diagonal).
struct LayerGradState {
    norm_w: Tensor1D,
    norm_b: Option<Tensor1D>,
    in_proj_w: Tensor1D,
    in_proj_b: Option<Tensor1D>,
    conv_w: Tensor1D,
    conv_b: Option<Tensor1D>,
    x_proj_w: Tensor1D,
    x_proj_b: Option<Tensor1D>,
    dt_proj_w: Tensor1D,
    dt_proj_b: Tensor1D,
    a: Tensor1D,
    d: Tensor1D,
    out_proj_w: Tensor1D,
    out_proj_b: Option<Tensor1D>,
}
333
#[derive(Clone)]
/// Whole-model gradient accumulators; fields mirror `Model` plus per-layer grads.
struct FullGradState {
    embeddings: Tensor1D,
    final_norm_w: Tensor1D,
    final_norm_b: Option<Tensor1D>,
    lm_head: Tensor1D,
    lm_head_b: Option<Tensor1D>,
    layers: Vec<LayerGradState>,
}
343
#[derive(Clone)]
/// Gradients carried across timesteps for one layer's recurrent buffers.
/// NOTE(review): `*_next` presumably hold gradients arriving from the
/// following timestep — confirm against the backward pass.
struct LayerRecurrentGradState {
    ssm_next: Tensor1D,  // (inner * state)
    conv_next: Tensor1D, // (inner * conv_kernel)
}
349
350impl LayerRecurrentGradState {
351    fn new(cfg: &Config) -> Self {
352        Self {
353            ssm_next: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
354            conv_next: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
355        }
356    }
357}
358
#[derive(Clone)]
/// Per-layer recurrent gradient carriers (one entry per Mamba block).
struct RecurrentGradState {
    layers: Vec<LayerRecurrentGradState>,
}
363
364impl RecurrentGradState {
365    fn new(cfg: &Config) -> Self {
366        Self {
367            layers: (0..cfg.num_layers)
368                .map(|_| LayerRecurrentGradState::new(cfg))
369                .collect(),
370        }
371    }
372
373    fn zero(&mut self) {
374        for layer in &mut self.layers {
375            layer.ssm_next.zero();
376            layer.conv_next.zero();
377        }
378    }
379}
380
/// Mamba model weights and inference kernels.
#[derive(Clone)]
pub struct Model {
    cfg: Config,
    embeddings: Tensor1D, // (vocab, hidden)
    final_norm_w: Tensor1D, // (hidden)
    final_norm_b: Option<Tensor1D>, // (hidden); `None` for `new_random` models
    lm_head: Tensor1D, // (vocab, hidden)
    lm_head_b: Option<Tensor1D>, // (vocab); `None` for `new_random` models
    layers: Vec<LayerWeights>,
}
392
/// Preallocated temporary buffers for token forward passes.
///
/// Also carries gradient scratch and an optional per-token training trace so
/// online updates run without per-step allocation. Sizes below follow
/// `ScratchBuffers::new`.
#[derive(Clone)]
pub struct ScratchBuffers {
    // Forward activations.
    h: Tensor1D,      // (hidden)
    norm: Tensor1D,   // (hidden); LM-head input, see `lm_head_input`
    xz: Tensor1D,     // (2 * inner)
    conv: Tensor1D,   // (inner)
    proj: Tensor1D,   // (dt_rank + 2 * state)
    dt: Tensor1D,     // (inner)
    y: Tensor1D,      // (inner)
    out: Tensor1D,    // (hidden)
    logits: Tensor1D, // (vocab)
    // Backward scratch, each sized to the quantity its gradient corresponds to.
    grad_h: Tensor1D,
    grad_norm: Tensor1D,
    grad_xz: Tensor1D,
    grad_conv: Tensor1D,
    grad_conv_pre: Tensor1D,
    grad_proj: Tensor1D,
    grad_dt_raw: Tensor1D,
    grad_u: Tensor1D,
    grad_b: Tensor1D,
    grad_c: Tensor1D,
    grad_ssm_d: Tensor1D,
    grad_ssm_a: Tensor1D,
    grad_ssm_a_log: Tensor1D,
    grad_conv_w: Tensor1D,
    grad_conv_b: Tensor1D,
    grad_y: Tensor1D,
    grad_out: Tensor1D,
    grad_logits: Tensor1D,
    grad_residual: Tensor1D,
    // Per-token training trace, populated only when capture is enabled.
    train_trace_layers: Vec<LayerTrainTrace>,
    train_h_final: Tensor1D, // (hidden)
    train_token: usize,      // token id of the traced step
    train_trace_valid: bool, // trace fields describe the latest forward pass
    capture_train_trace: bool, // toggled via `set_capture_train_trace`
}
430
431impl ScratchBuffers {
432    /// Allocate scratch for config.
433    pub fn new(cfg: &Config) -> Self {
434        let mut train_trace_layers = Vec::with_capacity(cfg.num_layers);
435        for _ in 0..cfg.num_layers {
436            train_trace_layers.push(LayerTrainTrace::new(cfg));
437        }
438        Self {
439            h: Tensor1D::zeros(cfg.hidden_size),
440            norm: Tensor1D::zeros(cfg.hidden_size),
441            xz: Tensor1D::zeros(cfg.inner_size * 2),
442            conv: Tensor1D::zeros(cfg.inner_size),
443            proj: Tensor1D::zeros(cfg.dt_rank + 2 * cfg.state_size),
444            dt: Tensor1D::zeros(cfg.inner_size),
445            y: Tensor1D::zeros(cfg.inner_size),
446            out: Tensor1D::zeros(cfg.hidden_size),
447            logits: Tensor1D::zeros(cfg.vocab_size),
448            grad_h: Tensor1D::zeros(cfg.hidden_size),
449            grad_norm: Tensor1D::zeros(cfg.hidden_size),
450            grad_xz: Tensor1D::zeros(cfg.inner_size * 2),
451            grad_conv: Tensor1D::zeros(cfg.inner_size),
452            grad_conv_pre: Tensor1D::zeros(cfg.inner_size),
453            grad_proj: Tensor1D::zeros(cfg.dt_rank + 2 * cfg.state_size),
454            grad_dt_raw: Tensor1D::zeros(cfg.inner_size),
455            grad_u: Tensor1D::zeros(cfg.dt_rank),
456            grad_b: Tensor1D::zeros(cfg.state_size),
457            grad_c: Tensor1D::zeros(cfg.state_size),
458            grad_ssm_d: Tensor1D::zeros(cfg.inner_size),
459            grad_ssm_a: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
460            grad_ssm_a_log: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
461            grad_conv_w: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
462            grad_conv_b: Tensor1D::zeros(cfg.inner_size),
463            grad_y: Tensor1D::zeros(cfg.inner_size),
464            grad_out: Tensor1D::zeros(cfg.hidden_size),
465            grad_logits: Tensor1D::zeros(cfg.vocab_size),
466            grad_residual: Tensor1D::zeros(cfg.hidden_size),
467            train_trace_layers,
468            train_h_final: Tensor1D::zeros(cfg.hidden_size),
469            train_token: 0,
470            train_trace_valid: false,
471            capture_train_trace: false,
472        }
473    }
474
475    /// Final normalized hidden state consumed by LM head.
476    #[inline]
477    pub fn lm_head_input(&self) -> &[f32] {
478        self.norm.as_slice()
479    }
480
481    /// Output logits from the latest forward pass.
482    #[inline]
483    pub fn logits(&self) -> &[f32] {
484        self.logits.as_slice()
485    }
486
487    /// Restore LM-head input snapshot for reversible online updates.
488    #[inline]
489    pub fn set_lm_head_input(&mut self, value: &[f32]) {
490        self.norm.as_mut_slice().copy_from_slice(value);
491    }
492
493    /// Enable or disable per-token training trace capture.
494    #[inline]
495    pub fn set_capture_train_trace(&mut self, enabled: bool) {
496        self.capture_train_trace = enabled;
497        if !enabled {
498            self.train_trace_valid = false;
499        }
500    }
501
502    /// Whether the current scratch contains a valid full-trace for the latest forward pass.
503    #[inline]
504    pub fn has_train_trace(&self) -> bool {
505        self.train_trace_valid
506    }
507}
508
509impl Model {
510    /// Allocate zero-initialized Adam moments matching all trainable tensors.
511    pub fn new_full_adam_state(&self) -> FullAdamState {
512        let mut layers = Vec::with_capacity(self.layers.len());
513        for layer in &self.layers {
514            layers.push(LayerAdamState {
515                norm_w: AdamTensorState::new(layer.norm_w.len()),
516                norm_b: layer.norm_b.as_ref().map(|b| AdamTensorState::new(b.len())),
517                in_proj_w: AdamTensorState::new(layer.in_proj_w.len()),
518                in_proj_b: layer
519                    .in_proj_b
520                    .as_ref()
521                    .map(|b| AdamTensorState::new(b.len())),
522                conv_w: AdamTensorState::new(layer.conv_w.len()),
523                conv_b: layer.conv_b.as_ref().map(|b| AdamTensorState::new(b.len())),
524                x_proj_w: AdamTensorState::new(layer.x_proj_w.len()),
525                x_proj_b: layer
526                    .x_proj_b
527                    .as_ref()
528                    .map(|b| AdamTensorState::new(b.len())),
529                dt_proj_w: AdamTensorState::new(layer.dt_proj_w.len()),
530                dt_proj_b: AdamTensorState::new(layer.dt_proj_b.len()),
531                a: AdamTensorState::new(layer.a_log.len()),
532                d: AdamTensorState::new(layer.d.len()),
533                out_proj_w: AdamTensorState::new(layer.out_proj_w.len()),
534                out_proj_b: layer
535                    .out_proj_b
536                    .as_ref()
537                    .map(|b| AdamTensorState::new(b.len())),
538            });
539        }
540
541        FullAdamState {
542            embeddings: AdamTensorState::new(self.embeddings.len()),
543            final_norm_w: AdamTensorState::new(self.final_norm_w.len()),
544            final_norm_b: self
545                .final_norm_b
546                .as_ref()
547                .map(|b| AdamTensorState::new(b.len())),
548            lm_head: AdamTensorState::new(self.lm_head.len()),
549            lm_head_b: self
550                .lm_head_b
551                .as_ref()
552                .map(|b| AdamTensorState::new(b.len())),
553            layers,
554        }
555    }
556
557    /// Load a model from safetensors checkpoint.
558    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
559        let weights = Weights::load(path.as_ref()).with_context(|| {
560            format!(
561                "failed to load model weights from {}",
562                path.as_ref().display()
563            )
564        })?;
565
566        if weights.get("backbone.embedding.weight").is_some() {
567            Self::load_official(&weights)
568        } else {
569            Self::load_native(&weights)
570        }
571    }
572
573    /// Build a deterministic random model for online-mode workflows.
574    pub fn new_random(cfg: Config, seed: u64) -> Result<Self> {
575        cfg.validate()?;
576
577        let mut rng = MambaRng::new(seed);
578        let v = cfg.vocab_size;
579        let h = cfg.hidden_size;
580        let i = cfg.inner_size;
581        let s = cfg.state_size;
582        let k = cfg.conv_kernel;
583        let r = cfg.dt_rank;
584
585        let mut embeddings = Tensor1D::zeros(v * h);
586        init_uniform(&mut embeddings, &mut rng, 0.02);
587
588        let mut final_norm_w = Tensor1D::zeros(h);
589        init_const(&mut final_norm_w, 1.0);
590
591        let mut lm_head = Tensor1D::zeros(v * h);
592        init_uniform(&mut lm_head, &mut rng, 0.02);
593
594        let mut layers = Vec::with_capacity(cfg.num_layers);
595        for _ in 0..cfg.num_layers {
596            let mut norm_w = Tensor1D::zeros(h);
597            init_const(&mut norm_w, 1.0);
598
599            let mut in_proj_w = Tensor1D::zeros((2 * i) * h);
600            init_uniform(&mut in_proj_w, &mut rng, 0.02);
601            let mut in_proj_b = Tensor1D::zeros(2 * i);
602            init_const(&mut in_proj_b, 0.0);
603
604            let mut conv_w = Tensor1D::zeros(i * k);
605            init_uniform(&mut conv_w, &mut rng, 0.05);
606            let mut conv_b = Tensor1D::zeros(i);
607            init_const(&mut conv_b, 0.0);
608
609            let mut x_proj_w = Tensor1D::zeros((r + 2 * s) * i);
610            init_uniform(&mut x_proj_w, &mut rng, 0.02);
611
612            let mut dt_proj_w = Tensor1D::zeros(i * r);
613            init_uniform(&mut dt_proj_w, &mut rng, 0.02);
614            let mut dt_proj_b = Tensor1D::zeros(i);
615            init_const(&mut dt_proj_b, -2.0);
616
617            let mut a_log = Tensor1D::zeros(i * s);
618            init_const(&mut a_log, 0.0);
619            let a = a_from_a_log_tensor(&a_log);
620            let mut d = Tensor1D::zeros(i);
621            init_const(&mut d, 1.0);
622
623            let mut out_proj_w = Tensor1D::zeros(h * i);
624            init_uniform(&mut out_proj_w, &mut rng, 0.02);
625            let mut out_proj_b = Tensor1D::zeros(h);
626            init_const(&mut out_proj_b, 0.0);
627
628            layers.push(LayerWeights {
629                norm_w,
630                norm_b: None,
631                in_proj_w,
632                in_proj_b: Some(in_proj_b),
633                conv_w,
634                conv_b: Some(conv_b),
635                x_proj_w,
636                x_proj_b: None,
637                dt_proj_w,
638                dt_proj_b,
639                a_log,
640                a,
641                d,
642                out_proj_w,
643                out_proj_b: Some(out_proj_b),
644            });
645        }
646
647        Ok(Self {
648            cfg,
649            embeddings,
650            final_norm_w,
651            final_norm_b: None,
652            lm_head,
653            lm_head_b: None,
654            layers,
655        })
656    }
657
658    /// Save checkpoint to safetensors (native infotheory Mamba layout).
659    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<()> {
660        #[derive(Clone)]
661        struct TensorRec {
662            name: String,
663            shape: Vec<usize>,
664            data: Vec<f32>,
665        }
666
667        let c = self.cfg.hidden_size;
668        let v = self.cfg.vocab_size;
669        let i = self.cfg.inner_size;
670        let s = self.cfg.state_size;
671        let k = self.cfg.conv_kernel;
672        let r = self.cfg.dt_rank;
673
674        let mut recs: Vec<TensorRec> = Vec::new();
675
676        let mut push = |name: String, shape: Vec<usize>, t: &Tensor1D| {
677            recs.push(TensorRec {
678                name,
679                shape,
680                data: t.as_slice().to_vec(),
681            });
682        };
683
684        push(
685            "model.embeddings.weight".to_string(),
686            vec![v, c],
687            &self.embeddings,
688        );
689        push("model.norm.weight".to_string(), vec![c], &self.final_norm_w);
690        if let Some(b) = &self.final_norm_b {
691            push("model.norm.bias".to_string(), vec![c], b);
692        }
693        push("lm_head.weight".to_string(), vec![v, c], &self.lm_head);
694        if let Some(b) = &self.lm_head_b {
695            push("lm_head.bias".to_string(), vec![v], b);
696        }
697
698        for (idx, layer) in self.layers.iter().enumerate() {
699            let pfx = format!("model.layers.{idx}.mixer");
700            push(
701                format!("model.layers.{idx}.norm.weight"),
702                vec![c],
703                &layer.norm_w,
704            );
705            if let Some(b) = &layer.norm_b {
706                push(format!("model.layers.{idx}.norm.bias"), vec![c], b);
707            }
708            push(
709                format!("{pfx}.in_proj.weight"),
710                vec![2 * i, c],
711                &layer.in_proj_w,
712            );
713            if let Some(b) = &layer.in_proj_b {
714                push(format!("{pfx}.in_proj.bias"), vec![2 * i], b);
715            }
716            push(format!("{pfx}.conv1d.weight"), vec![i, 1, k], &layer.conv_w);
717            if let Some(b) = &layer.conv_b {
718                push(format!("{pfx}.conv1d.bias"), vec![i], b);
719            }
720            push(
721                format!("{pfx}.x_proj.weight"),
722                vec![r + 2 * s, i],
723                &layer.x_proj_w,
724            );
725            if let Some(b) = &layer.x_proj_b {
726                push(format!("{pfx}.x_proj.bias"), vec![r + 2 * s], b);
727            }
728            push(
729                format!("{pfx}.dt_proj.weight"),
730                vec![i, r],
731                &layer.dt_proj_w,
732            );
733            push(format!("{pfx}.dt_proj.bias"), vec![i], &layer.dt_proj_b);
734            push(format!("{pfx}.A_log"), vec![i, s], &layer.a_log);
735            push(format!("{pfx}.D"), vec![i], &layer.d);
736            push(
737                format!("{pfx}.out_proj.weight"),
738                vec![c, i],
739                &layer.out_proj_w,
740            );
741            if let Some(b) = &layer.out_proj_b {
742                push(format!("{pfx}.out_proj.bias"), vec![c], b);
743            }
744        }
745
746        recs.sort_by(|a, b| a.name.cmp(&b.name));
747
748        let mut offset = 0usize;
749        let mut header = serde_json::Map::new();
750        header.insert("__metadata__".to_string(), json!({}));
751        for rec in &recs {
752            let bytes = rec.data.len() * 4;
753            header.insert(
754                rec.name.clone(),
755                json!({
756                    "dtype": "F32",
757                    "shape": rec.shape,
758                    "data_offsets": [offset, offset + bytes],
759                }),
760            );
761            offset += bytes;
762        }
763
764        let header_bytes = serde_json::to_vec(&header)?;
765        let mut f = File::create(path)?;
766        f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
767        f.write_all(&header_bytes)?;
768        for rec in &recs {
769            for v in &rec.data {
770                f.write_all(&v.to_le_bytes())?;
771            }
772        }
773        Ok(())
774    }
775
776    /// Save full-parameter Adam moments for exact online-training continuation.
777    pub fn save_full_adam_safetensors<P: AsRef<Path>>(
778        &self,
779        adam: &FullAdamState,
780        path: P,
781    ) -> Result<()> {
782        #[derive(Clone)]
783        struct TensorRec {
784            name: String,
785            shape: Vec<usize>,
786            data: Vec<f32>,
787        }
788
789        let c = self.cfg.hidden_size;
790        let v = self.cfg.vocab_size;
791        let i = self.cfg.inner_size;
792        let s = self.cfg.state_size;
793        let k = self.cfg.conv_kernel;
794        let r = self.cfg.dt_rank;
795
796        let mut recs: Vec<TensorRec> = Vec::new();
797        let mut push_state = |name_prefix: &str, shape: Vec<usize>, st: &AdamTensorState| {
798            recs.push(TensorRec {
799                name: format!("{name_prefix}.m"),
800                shape: shape.clone(),
801                data: st.m.as_slice().to_vec(),
802            });
803            recs.push(TensorRec {
804                name: format!("{name_prefix}.v"),
805                shape,
806                data: st.v.as_slice().to_vec(),
807            });
808        };
809
810        push_state("opt.embeddings", vec![v, c], &adam.embeddings);
811        push_state("opt.final_norm.weight", vec![c], &adam.final_norm_w);
812        if let Some(b) = &adam.final_norm_b {
813            push_state("opt.final_norm.bias", vec![c], b);
814        }
815        push_state("opt.lm_head.weight", vec![v, c], &adam.lm_head);
816        if let Some(b) = &adam.lm_head_b {
817            push_state("opt.lm_head.bias", vec![v], b);
818        }
819
820        for (idx, layer) in adam.layers.iter().enumerate() {
821            let pfx = format!("opt.layers.{idx}");
822            push_state(&format!("{pfx}.norm.weight"), vec![c], &layer.norm_w);
823            if let Some(b) = &layer.norm_b {
824                push_state(&format!("{pfx}.norm.bias"), vec![c], b);
825            }
826            push_state(
827                &format!("{pfx}.in_proj.weight"),
828                vec![2 * i, c],
829                &layer.in_proj_w,
830            );
831            if let Some(b) = &layer.in_proj_b {
832                push_state(&format!("{pfx}.in_proj.bias"), vec![2 * i], b);
833            }
834            push_state(
835                &format!("{pfx}.conv1d.weight"),
836                vec![i, 1, k],
837                &layer.conv_w,
838            );
839            if let Some(b) = &layer.conv_b {
840                push_state(&format!("{pfx}.conv1d.bias"), vec![i], b);
841            }
842            push_state(
843                &format!("{pfx}.x_proj.weight"),
844                vec![r + 2 * s, i],
845                &layer.x_proj_w,
846            );
847            if let Some(b) = &layer.x_proj_b {
848                push_state(&format!("{pfx}.x_proj.bias"), vec![r + 2 * s], b);
849            }
850            push_state(
851                &format!("{pfx}.dt_proj.weight"),
852                vec![i, r],
853                &layer.dt_proj_w,
854            );
855            push_state(&format!("{pfx}.dt_proj.bias"), vec![i], &layer.dt_proj_b);
856            push_state(&format!("{pfx}.A_log"), vec![i, s], &layer.a);
857            push_state(&format!("{pfx}.D"), vec![i], &layer.d);
858            push_state(
859                &format!("{pfx}.out_proj.weight"),
860                vec![c, i],
861                &layer.out_proj_w,
862            );
863            if let Some(b) = &layer.out_proj_b {
864                push_state(&format!("{pfx}.out_proj.bias"), vec![c], b);
865            }
866        }
867
868        recs.sort_by(|a, b| a.name.cmp(&b.name));
869
870        let mut offset = 0usize;
871        let mut header = serde_json::Map::new();
872        header.insert("__metadata__".to_string(), json!({}));
873        for rec in &recs {
874            let bytes = rec.data.len() * 4;
875            header.insert(
876                rec.name.clone(),
877                json!({
878                    "dtype": "F32",
879                    "shape": rec.shape,
880                    "data_offsets": [offset, offset + bytes],
881                }),
882            );
883            offset += bytes;
884        }
885
886        let header_bytes = serde_json::to_vec(&header)?;
887        let mut f = File::create(path)?;
888        f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
889        f.write_all(&header_bytes)?;
890        for rec in &recs {
891            for v in &rec.data {
892                f.write_all(&v.to_le_bytes())?;
893            }
894        }
895        Ok(())
896    }
897
898    /// Load full-parameter Adam moments and validate tensor shapes.
899    pub fn load_full_adam_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<FullAdamState> {
900        let weights = Weights::load(path.as_ref()).with_context(|| {
901            format!(
902                "failed to load optimizer moments from {}",
903                path.as_ref().display()
904            )
905        })?;
906        let mut adam = self.new_full_adam_state();
907
908        let load_state = |name_prefix: &str, st: &mut AdamTensorState| -> Result<()> {
909            let m_name = format!("{name_prefix}.m");
910            let v_name = format!("{name_prefix}.v");
911            let m_t = weights
912                .require(&m_name)
913                .with_context(|| format!("missing optimizer tensor '{m_name}'"))?;
914            let v_t = weights
915                .require(&v_name)
916                .with_context(|| format!("missing optimizer tensor '{v_name}'"))?;
917            if m_t.data().len() != st.m.len() {
918                bail!(
919                    "optimizer tensor '{}' len {} != expected {}",
920                    m_name,
921                    m_t.data().len(),
922                    st.m.len()
923                );
924            }
925            if v_t.data().len() != st.v.len() {
926                bail!(
927                    "optimizer tensor '{}' len {} != expected {}",
928                    v_name,
929                    v_t.data().len(),
930                    st.v.len()
931                );
932            }
933            st.m.as_mut_slice().copy_from_slice(m_t.data());
934            st.v.as_mut_slice().copy_from_slice(v_t.data());
935            Ok(())
936        };
937
938        load_state("opt.embeddings", &mut adam.embeddings)?;
939        load_state("opt.final_norm.weight", &mut adam.final_norm_w)?;
940        if let Some(st) = adam.final_norm_b.as_mut() {
941            load_state("opt.final_norm.bias", st)?;
942        }
943        load_state("opt.lm_head.weight", &mut adam.lm_head)?;
944        if let Some(st) = adam.lm_head_b.as_mut() {
945            load_state("opt.lm_head.bias", st)?;
946        }
947
948        for (idx, layer) in adam.layers.iter_mut().enumerate() {
949            let pfx = format!("opt.layers.{idx}");
950            load_state(&format!("{pfx}.norm.weight"), &mut layer.norm_w)?;
951            if let Some(st) = layer.norm_b.as_mut() {
952                load_state(&format!("{pfx}.norm.bias"), st)?;
953            }
954            load_state(&format!("{pfx}.in_proj.weight"), &mut layer.in_proj_w)?;
955            if let Some(st) = layer.in_proj_b.as_mut() {
956                load_state(&format!("{pfx}.in_proj.bias"), st)?;
957            }
958            load_state(&format!("{pfx}.conv1d.weight"), &mut layer.conv_w)?;
959            if let Some(st) = layer.conv_b.as_mut() {
960                load_state(&format!("{pfx}.conv1d.bias"), st)?;
961            }
962            load_state(&format!("{pfx}.x_proj.weight"), &mut layer.x_proj_w)?;
963            if let Some(st) = layer.x_proj_b.as_mut() {
964                load_state(&format!("{pfx}.x_proj.bias"), st)?;
965            }
966            load_state(&format!("{pfx}.dt_proj.weight"), &mut layer.dt_proj_w)?;
967            load_state(&format!("{pfx}.dt_proj.bias"), &mut layer.dt_proj_b)?;
968            load_state(&format!("{pfx}.A_log"), &mut layer.a)?;
969            load_state(&format!("{pfx}.D"), &mut layer.d)?;
970            load_state(&format!("{pfx}.out_proj.weight"), &mut layer.out_proj_w)?;
971            if let Some(st) = layer.out_proj_b.as_mut() {
972                load_state(&format!("{pfx}.out_proj.bias"), st)?;
973            }
974        }
975
976        Ok(adam)
977    }
978
979    /// Access model config.
980    #[inline]
981    pub fn config(&self) -> &Config {
982        &self.cfg
983    }
984
985    /// Allocate a new recurrent state.
986    #[inline]
987    pub fn new_state(&self) -> State {
988        State::new(&self.cfg)
989    }
990
991    /// Immutable LM-head weights, row-major `(vocab, hidden)`.
992    #[inline]
993    pub fn lm_head_weights(&self) -> &[f32] {
994        self.lm_head.as_slice()
995    }
996
997    /// Mutable LM-head weights, row-major `(vocab, hidden)`.
998    #[inline]
999    pub fn lm_head_weights_mut(&mut self) -> &mut [f32] {
1000        self.lm_head.as_mut_slice()
1001    }
1002
1003    fn new_full_grad_state(&self) -> FullGradState {
1004        let mut layers = Vec::with_capacity(self.layers.len());
1005        for layer in &self.layers {
1006            layers.push(LayerGradState {
1007                norm_w: Tensor1D::zeros(layer.norm_w.len()),
1008                norm_b: layer.norm_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1009                in_proj_w: Tensor1D::zeros(layer.in_proj_w.len()),
1010                in_proj_b: layer.in_proj_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1011                conv_w: Tensor1D::zeros(layer.conv_w.len()),
1012                conv_b: layer.conv_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1013                x_proj_w: Tensor1D::zeros(layer.x_proj_w.len()),
1014                x_proj_b: layer.x_proj_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1015                dt_proj_w: Tensor1D::zeros(layer.dt_proj_w.len()),
1016                dt_proj_b: Tensor1D::zeros(layer.dt_proj_b.len()),
1017                a: Tensor1D::zeros(layer.a.len()),
1018                d: Tensor1D::zeros(layer.d.len()),
1019                out_proj_w: Tensor1D::zeros(layer.out_proj_w.len()),
1020                out_proj_b: layer.out_proj_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1021            });
1022        }
1023        FullGradState {
1024            embeddings: Tensor1D::zeros(self.embeddings.len()),
1025            final_norm_w: Tensor1D::zeros(self.final_norm_w.len()),
1026            final_norm_b: self.final_norm_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1027            lm_head: Tensor1D::zeros(self.lm_head.len()),
1028            lm_head_b: self.lm_head_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1029            layers,
1030        }
1031    }
1032
1033    fn new_recurrent_grad_state(&self) -> RecurrentGradState {
1034        RecurrentGradState::new(&self.cfg)
1035    }
1036
    /// Apply one optimizer step to every parameter group enabled in `scope`.
    ///
    /// Gradients arrive pre-accumulated in `grads`; `lr` and `clip` are
    /// forwarded to the vector-update helpers (their exact sign/clipping
    /// semantics live in `sgd_vec_update` / `apply_adam_vec_update`, defined
    /// elsewhere in this file). For `OptimizerKind::Adam` this advances the
    /// shared step counter `adam_t` and requires `model_adam` whenever model
    /// parameters are in scope. The `out_bias*` arguments update an output
    /// bias vector (and its Adam moments) that lives outside the model.
    ///
    /// The SSM `A` parameter is log-reparameterized as `a = -exp(a_log)`, so
    /// its gradient is converted to log-space via the chain rule
    /// (`g_log = g_a * a`) before the update, and `a` is re-derived afterward.
    #[allow(clippy::too_many_arguments)]
    fn apply_full_gradients(
        &mut self,
        grads: &FullGradState,
        scope: TrainScopeMask,
        optimizer: OptimizerKind,
        lr: f32,
        clip: f32,
        adam_t: &mut usize,
        model_adam: Option<&mut FullAdamState>,
        out_bias: Option<&mut [f32]>,
        out_bias_grad: Option<&[f32]>,
        out_bias_adam_m: Option<&mut [f32]>,
        out_bias_adam_v: Option<&mut [f32]>,
    ) -> Result<()> {
        let mut adam_step = None::<AdamStep>;
        let mut model_adam = model_adam;
        if matches!(optimizer, OptimizerKind::Adam) {
            // Advance the shared Adam step counter once per call and derive
            // the step-dependent bias-correction factors used by every
            // Adam update below.
            *adam_t = adam_t.saturating_add(1);
            let t = (*adam_t).max(1) as i32;
            let b1 = 0.9f32;
            let b2 = 0.999f32;
            adam_step = Some(AdamStep {
                lr,
                clip: clip.max(0.0),
                b1,
                b2,
                eps: 1e-8,
                bias_corr1: 1.0 - b1.powi(t),
                bias_corr2: 1.0 - b2.powi(t),
            });
            // Fail fast: the model-parameter branches below unwrap this state.
            if scope.trains_model_params() && model_adam.is_none() {
                bail!("mamba Adam full-training state is missing");
            }
        }

        // External output bias (stored outside the model weights); only
        // updated when both the bias and its gradient were supplied.
        if scope.bias
            && let (Some(bias), Some(grad)) = (out_bias, out_bias_grad)
        {
            match optimizer {
                OptimizerKind::Sgd => sgd_vec_update(bias, grad, lr, clip),
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let Some(m) = out_bias_adam_m else {
                        bail!("mamba Adam output-bias moments are missing");
                    };
                    let Some(v) = out_bias_adam_v else {
                        bail!("mamba Adam output-bias moments are missing");
                    };
                    apply_adam_vec_update_raw(bias, grad, m, v, cfg);
                }
            }
        }

        // LM head weights (+ optional bias).
        if scope.head {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.lm_head.as_mut_slice(),
                        grads.lm_head.as_slice(),
                        lr,
                        clip,
                    );
                    if let (Some(b), Some(gb)) = (self.lm_head_b.as_mut(), grads.lm_head_b.as_ref())
                    {
                        sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                    }
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.lm_head.as_mut_slice(),
                        grads.lm_head.as_slice(),
                        &mut adam.lm_head,
                        cfg,
                    );
                    if let (Some(b), Some(gb), Some(ab)) = (
                        self.lm_head_b.as_mut(),
                        grads.lm_head_b.as_ref(),
                        adam.lm_head_b.as_mut(),
                    ) {
                        apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                    }
                }
            }
        }

        // Final RMSNorm scale (+ optional bias). Per-layer norms are handled
        // in the layer loop below under the same `scope.layer_norm` flag.
        if scope.layer_norm {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.final_norm_w.as_mut_slice(),
                        grads.final_norm_w.as_slice(),
                        lr,
                        clip,
                    );
                    if let (Some(b), Some(gb)) =
                        (self.final_norm_b.as_mut(), grads.final_norm_b.as_ref())
                    {
                        sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                    }
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.final_norm_w.as_mut_slice(),
                        grads.final_norm_w.as_slice(),
                        &mut adam.final_norm_w,
                        cfg,
                    );
                    if let (Some(b), Some(gb), Some(ab)) = (
                        self.final_norm_b.as_mut(),
                        grads.final_norm_b.as_ref(),
                        adam.final_norm_b.as_mut(),
                    ) {
                        apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                    }
                }
            }
        }

        // Per-layer parameters, gated by the finer-grained scope flags
        // (norm / projections / depthwise conv / SSM parameters).
        for layer_idx in 0..self.cfg.num_layers {
            let layer = &mut self.layers[layer_idx];
            let grad = &grads.layers[layer_idx];
            match optimizer {
                OptimizerKind::Sgd => {
                    if scope.layer_norm {
                        sgd_vec_update(
                            layer.norm_w.as_mut_slice(),
                            grad.norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        if let (Some(b), Some(gb)) = (layer.norm_b.as_mut(), grad.norm_b.as_ref()) {
                            sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                        }
                    }
                    if scope.mixer_proj {
                        sgd_vec_update(
                            layer.in_proj_w.as_mut_slice(),
                            grad.in_proj_w.as_slice(),
                            lr,
                            clip,
                        );
                        if let (Some(b), Some(gb)) =
                            (layer.in_proj_b.as_mut(), grad.in_proj_b.as_ref())
                        {
                            sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                        }
                        sgd_vec_update(
                            layer.x_proj_w.as_mut_slice(),
                            grad.x_proj_w.as_slice(),
                            lr,
                            clip,
                        );
                        if let (Some(b), Some(gb)) =
                            (layer.x_proj_b.as_mut(), grad.x_proj_b.as_ref())
                        {
                            sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                        }
                        sgd_vec_update(
                            layer.dt_proj_w.as_mut_slice(),
                            grad.dt_proj_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            layer.dt_proj_b.as_mut_slice(),
                            grad.dt_proj_b.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            layer.out_proj_w.as_mut_slice(),
                            grad.out_proj_w.as_slice(),
                            lr,
                            clip,
                        );
                        if let (Some(b), Some(gb)) =
                            (layer.out_proj_b.as_mut(), grad.out_proj_b.as_ref())
                        {
                            sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                        }
                    }
                    if scope.mixer_conv {
                        sgd_vec_update(
                            layer.conv_w.as_mut_slice(),
                            grad.conv_w.as_slice(),
                            lr,
                            clip,
                        );
                        if let (Some(b), Some(gb)) = (layer.conv_b.as_mut(), grad.conv_b.as_ref()) {
                            sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                        }
                    }
                    if scope.mixer_ssm {
                        sgd_vec_update(layer.d.as_mut_slice(), grad.d.as_slice(), lr, clip);
                        // `a = -exp(a_log)`, so d(a)/d(a_log) = a and the
                        // log-space gradient is `grad.a * a`. Clipping here is
                        // applied manually per element (this path bypasses
                        // `sgd_vec_update`), then `a` is re-synced from `a_log`.
                        // NOTE(review): `lr * g` is *added* here; assumes
                        // `sgd_vec_update` uses the same ascent convention —
                        // confirm against its definition.
                        for idx in 0..layer.a_log.len().min(grad.a.len()) {
                            let mut g = grad.a[idx] * layer.a[idx];
                            if clip > 0.0 {
                                g = g.clamp(-clip, clip);
                            }
                            let new_log = layer.a_log[idx] + lr * g;
                            layer.a_log[idx] = new_log;
                            layer.a[idx] = -new_log.exp();
                        }
                    }
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam =
                        &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
                    if scope.layer_norm {
                        apply_adam_vec_update(
                            layer.norm_w.as_mut_slice(),
                            grad.norm_w.as_slice(),
                            &mut adam.norm_w,
                            cfg,
                        );
                        if let (Some(b), Some(gb), Some(ab)) = (
                            layer.norm_b.as_mut(),
                            grad.norm_b.as_ref(),
                            adam.norm_b.as_mut(),
                        ) {
                            apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                        }
                    }
                    if scope.mixer_proj {
                        apply_adam_vec_update(
                            layer.in_proj_w.as_mut_slice(),
                            grad.in_proj_w.as_slice(),
                            &mut adam.in_proj_w,
                            cfg,
                        );
                        if let (Some(b), Some(gb), Some(ab)) = (
                            layer.in_proj_b.as_mut(),
                            grad.in_proj_b.as_ref(),
                            adam.in_proj_b.as_mut(),
                        ) {
                            apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                        }
                        apply_adam_vec_update(
                            layer.x_proj_w.as_mut_slice(),
                            grad.x_proj_w.as_slice(),
                            &mut adam.x_proj_w,
                            cfg,
                        );
                        if let (Some(b), Some(gb), Some(ab)) = (
                            layer.x_proj_b.as_mut(),
                            grad.x_proj_b.as_ref(),
                            adam.x_proj_b.as_mut(),
                        ) {
                            apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                        }
                        apply_adam_vec_update(
                            layer.dt_proj_w.as_mut_slice(),
                            grad.dt_proj_w.as_slice(),
                            &mut adam.dt_proj_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            layer.dt_proj_b.as_mut_slice(),
                            grad.dt_proj_b.as_slice(),
                            &mut adam.dt_proj_b,
                            cfg,
                        );
                        apply_adam_vec_update(
                            layer.out_proj_w.as_mut_slice(),
                            grad.out_proj_w.as_slice(),
                            &mut adam.out_proj_w,
                            cfg,
                        );
                        if let (Some(b), Some(gb), Some(ab)) = (
                            layer.out_proj_b.as_mut(),
                            grad.out_proj_b.as_ref(),
                            adam.out_proj_b.as_mut(),
                        ) {
                            apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                        }
                    }
                    if scope.mixer_conv {
                        apply_adam_vec_update(
                            layer.conv_w.as_mut_slice(),
                            grad.conv_w.as_slice(),
                            &mut adam.conv_w,
                            cfg,
                        );
                        if let (Some(b), Some(gb), Some(ab)) = (
                            layer.conv_b.as_mut(),
                            grad.conv_b.as_ref(),
                            adam.conv_b.as_mut(),
                        ) {
                            apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                        }
                    }
                    if scope.mixer_ssm {
                        apply_adam_vec_update(
                            layer.d.as_mut_slice(),
                            grad.d.as_slice(),
                            &mut adam.d,
                            cfg,
                        );
                        // Same log-space reparameterization as the SGD path:
                        // convert the gradient wrt `a` into a gradient wrt
                        // `a_log` (factor `a`), then let the helper update
                        // `a_log` and re-sync `a = -exp(a_log)`.
                        let mut grad_log = vec![0.0f32; grad.a.len().min(layer.a.len())];
                        for idx in 0..grad_log.len() {
                            grad_log[idx] = grad.a[idx] * layer.a[idx];
                        }
                        apply_adam_vec_update_and_sync_neg_exp(
                            layer.a_log.as_mut_slice(),
                            layer.a.as_mut_slice(),
                            &grad_log,
                            &mut adam.a,
                            cfg,
                        );
                    }
                }
            }
        }

        // Token embedding table.
        if scope.embed {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.embeddings.as_mut_slice(),
                        grads.embeddings.as_slice(),
                        lr,
                        clip,
                    );
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.embeddings.as_mut_slice(),
                        grads.embeddings.as_slice(),
                        &mut adam.embeddings,
                        cfg,
                    );
                }
            }
        }

        Ok(())
    }
1382
1383    #[allow(clippy::too_many_arguments, clippy::needless_range_loop)]
1384    fn accumulate_token_step_gradients(
1385        &self,
1386        scratch: &mut ScratchBuffers,
1387        trace: &TokenTrainTrace,
1388        state_new: &State,
1389        symbol: u8,
1390        pdf: &[f64],
1391        grad_scale: f32,
1392        scope: TrainScopeMask,
1393        grads: &mut FullGradState,
1394        out_bias_grad: Option<&mut [f32]>,
1395        future: &mut RecurrentGradState,
1396    ) -> Result<()> {
1397        let c = self.cfg.hidden_size;
1398        let i = self.cfg.inner_size;
1399        let s = self.cfg.state_size;
1400        let r = self.cfg.dt_rank;
1401        let k = self.cfg.conv_kernel;
1402        let v = self.cfg.vocab_size.min(pdf.len());
1403        if v == 0 {
1404            return Ok(());
1405        }
1406
1407        scratch.grad_logits.zero();
1408        for tok in 0..v {
1409            let p = pdf[tok].clamp(1e-12, 1.0) as f32;
1410            let target = if tok == symbol as usize { 1.0 } else { 0.0 };
1411            scratch.grad_logits[tok] = (target - p) * grad_scale;
1412        }
1413
1414        if scope.bias
1415            && let Some(bias_grad) = out_bias_grad
1416        {
1417            add_vec_grad(&mut bias_grad[0..v], &scratch.grad_logits.as_slice()[0..v]);
1418        }
1419
1420        scratch.grad_h.zero();
1421        if scope.head {
1422            add_outer_grad(
1423                grads.lm_head.as_mut_slice(),
1424                v,
1425                c,
1426                &scratch.grad_logits.as_slice()[0..v],
1427                trace.norm.as_slice(),
1428            );
1429            if let Some(lm_head_b) = grads.lm_head_b.as_mut() {
1430                let n = v.min(lm_head_b.len());
1431                add_vec_grad(
1432                    &mut lm_head_b.as_mut_slice()[0..n],
1433                    &scratch.grad_logits.as_slice()[0..n],
1434                );
1435            }
1436        }
1437        for tok in 0..v {
1438            let g = scratch.grad_logits[tok];
1439            if g == 0.0 {
1440                continue;
1441            }
1442            let row_off = tok * c;
1443            for col in 0..c {
1444                scratch.grad_h[col] += self.lm_head[row_off + col] * g;
1445            }
1446        }
1447
1448        let needs_backprop = scope.embed
1449            || scope.layer_norm
1450            || scope.mixer_conv
1451            || scope.mixer_ssm
1452            || scope.mixer_proj;
1453        if !needs_backprop {
1454            return Ok(());
1455        }
1456
1457        rms_norm_backward(
1458            trace.h_final.as_slice(),
1459            self.final_norm_w.as_slice(),
1460            scratch.grad_h.as_slice(),
1461            self.cfg.layer_norm_eps,
1462            scratch.grad_norm.as_mut_slice(),
1463            scratch.grad_out.as_mut_slice(),
1464        );
1465        if scope.layer_norm {
1466            add_vec_grad(
1467                grads.final_norm_w.as_mut_slice(),
1468                scratch.grad_out.as_slice(),
1469            );
1470            if let Some(final_norm_b) = grads.final_norm_b.as_mut() {
1471                add_vec_grad(final_norm_b.as_mut_slice(), scratch.grad_h.as_slice());
1472            }
1473        }
1474        scratch
1475            .grad_h
1476            .as_mut_slice()
1477            .copy_from_slice(scratch.grad_norm.as_slice());
1478
1479        for layer_idx in (0..self.cfg.num_layers).rev() {
1480            let tr = &trace.layers[layer_idx];
1481            let st_new = &state_new.layers[layer_idx];
1482            let layer = &self.layers[layer_idx];
1483            let layer_grads = &mut grads.layers[layer_idx];
1484            let future_layer = &mut future.layers[layer_idx];
1485
1486            scratch
1487                .grad_out
1488                .as_mut_slice()
1489                .copy_from_slice(scratch.grad_h.as_slice());
1490
1491            unsafe {
1492                kernel::gemv_t(
1493                    layer.out_proj_w.as_ptr(),
1494                    scratch.grad_out.as_ptr(),
1495                    scratch.grad_y.as_mut_ptr(),
1496                    c,
1497                    i,
1498                );
1499            }
1500            if scope.mixer_proj {
1501                add_outer_grad(
1502                    layer_grads.out_proj_w.as_mut_slice(),
1503                    c,
1504                    i,
1505                    scratch.grad_out.as_slice(),
1506                    tr.y.as_slice(),
1507                );
1508                if let Some(out_proj_b) = layer_grads.out_proj_b.as_mut() {
1509                    add_vec_grad(out_proj_b.as_mut_slice(), scratch.grad_out.as_slice());
1510                }
1511            }
1512
1513            scratch
1514                .grad_residual
1515                .as_mut_slice()
1516                .copy_from_slice(scratch.grad_out.as_slice());
1517            scratch.grad_xz.zero();
1518            scratch.grad_b.zero();
1519            scratch.grad_c.zero();
1520            scratch.grad_ssm_d.zero();
1521            scratch.grad_ssm_a.zero();
1522            scratch.grad_dt_raw.zero();
1523            scratch.grad_conv.zero();
1524            scratch.grad_conv_pre.zero();
1525            scratch.grad_conv_w.zero();
1526            scratch.grad_conv_b.zero();
1527
1528            for ch in 0..i {
1529                let g_y = scratch.grad_y[ch];
1530                let g_y_pre = g_y * tr.gate[ch];
1531                let g_gate = g_y * tr.y_pre[ch];
1532                scratch.grad_xz[i + ch] =
1533                    g_gate * silu_grad_from_sigmoid(tr.xz[i + ch], tr.gate_sigmoid[ch]);
1534
1535                let conv = tr.conv_post[ch];
1536                let dt = tr.dt[ch];
1537                let xdt = conv * dt;
1538                let mut g_xdt = 0.0f32;
1539                let mut g_dt = 0.0f32;
1540
1541                scratch.grad_conv[ch] = g_y_pre * layer.d[ch];
1542                if scope.mixer_ssm {
1543                    scratch.grad_ssm_d[ch] = g_y_pre * conv;
1544                }
1545
1546                let row = ch * s;
1547                for j in 0..s {
1548                    let idx = row + j;
1549                    let c_j = tr.proj[r + s + j];
1550                    let b_j = tr.proj[r + j];
1551                    let s_prev = tr.ssm_prev[idx];
1552                    let s_new = st_new.ssm[idx];
1553                    let a_ij = layer.a[idx];
1554                    let d_a = tr.d_a[idx];
1555
1556                    let g_ssm_new = g_y_pre * c_j + future_layer.ssm_next[idx];
1557                    scratch.grad_c[j] += g_y_pre * s_new;
1558                    g_xdt += g_ssm_new * b_j;
1559                    scratch.grad_b[j] += g_ssm_new * xdt;
1560
1561                    let g_d_a = g_ssm_new * s_prev;
1562                    g_dt += g_d_a * d_a * a_ij;
1563                    if scope.mixer_ssm {
1564                        scratch.grad_ssm_a[idx] += g_d_a * d_a * dt;
1565                    }
1566                    future_layer.ssm_next[idx] = g_ssm_new * d_a;
1567                }
1568
1569                scratch.grad_conv[ch] += g_xdt * dt;
1570                g_dt += g_xdt * conv;
1571                let dt_pre = tr.dt_raw[ch] + layer.dt_proj_b[ch];
1572                scratch.grad_dt_raw[ch] = g_dt * sigmoid(dt_pre);
1573            }
1574
1575            if scope.mixer_ssm {
1576                add_vec_grad(layer_grads.d.as_mut_slice(), scratch.grad_ssm_d.as_slice());
1577                add_vec_grad(layer_grads.a.as_mut_slice(), scratch.grad_ssm_a.as_slice());
1578            }
1579
1580            unsafe {
1581                kernel::gemv_t(
1582                    layer.dt_proj_w.as_ptr(),
1583                    scratch.grad_dt_raw.as_ptr(),
1584                    scratch.grad_u.as_mut_ptr(),
1585                    i,
1586                    r,
1587                );
1588            }
1589            if scope.mixer_proj {
1590                add_outer_grad(
1591                    layer_grads.dt_proj_w.as_mut_slice(),
1592                    i,
1593                    r,
1594                    scratch.grad_dt_raw.as_slice(),
1595                    &tr.proj.as_slice()[0..r],
1596                );
1597                add_vec_grad(
1598                    layer_grads.dt_proj_b.as_mut_slice(),
1599                    scratch.grad_dt_raw.as_slice(),
1600                );
1601            }
1602
1603            for kk in 0..r {
1604                scratch.grad_proj[kk] = scratch.grad_u[kk];
1605            }
1606            for j in 0..s {
1607                scratch.grad_proj[r + j] = scratch.grad_b[j];
1608                scratch.grad_proj[r + s + j] = scratch.grad_c[j];
1609            }
1610
1611            unsafe {
1612                kernel::gemv_t(
1613                    layer.x_proj_w.as_ptr(),
1614                    scratch.grad_proj.as_ptr(),
1615                    scratch.grad_conv_pre.as_mut_ptr(),
1616                    r + 2 * s,
1617                    i,
1618                );
1619                kernel::add_inplace(
1620                    scratch.grad_conv.as_mut_ptr(),
1621                    scratch.grad_conv_pre.as_ptr(),
1622                    i,
1623                );
1624            }
1625            if scope.mixer_proj {
1626                add_outer_grad(
1627                    layer_grads.x_proj_w.as_mut_slice(),
1628                    r + 2 * s,
1629                    i,
1630                    scratch.grad_proj.as_slice(),
1631                    tr.conv_post.as_slice(),
1632                );
1633                if let Some(x_proj_b) = layer_grads.x_proj_b.as_mut() {
1634                    add_vec_grad(x_proj_b.as_mut_slice(), scratch.grad_proj.as_slice());
1635                }
1636            }
1637
1638            for ch in 0..i {
1639                scratch.grad_conv_pre[ch] = scratch.grad_conv[ch]
1640                    * silu_grad_from_sigmoid(tr.conv_pre[ch], tr.conv_sigmoid[ch]);
1641            }
1642
1643            for ch in 0..i {
1644                let g = scratch.grad_conv_pre[ch];
1645                let base = ch * k;
1646                let conv_future = &mut future_layer.conv_next.as_mut_slice()[base..base + k];
1647                let mut ring = tr.conv_pos_prev;
1648
1649                scratch.grad_xz[ch] += g * layer.conv_w[base];
1650                scratch.grad_xz[ch] += conv_future[tr.conv_pos_prev];
1651
1652                if scope.mixer_conv {
1653                    scratch.grad_conv_w[base] += g * tr.xz[ch];
1654                    if layer.conv_b.is_some() {
1655                        scratch.grad_conv_b[ch] += g;
1656                    }
1657                }
1658
1659                conv_future[tr.conv_pos_prev] = 0.0;
1660                for tap in 1..k {
1661                    ring = if ring == 0 { k - 1 } else { ring - 1 };
1662                    conv_future[ring] += g * layer.conv_w[base + tap];
1663                    if scope.mixer_conv {
1664                        scratch.grad_conv_w[base + tap] += g * tr.conv_prev[base + ring];
1665                    }
1666                }
1667            }
1668            if scope.mixer_conv {
1669                add_vec_grad(
1670                    layer_grads.conv_w.as_mut_slice(),
1671                    scratch.grad_conv_w.as_slice(),
1672                );
1673                if let Some(conv_b) = layer_grads.conv_b.as_mut() {
1674                    add_vec_grad(conv_b.as_mut_slice(), scratch.grad_conv_b.as_slice());
1675                }
1676            }
1677
1678            unsafe {
1679                kernel::gemv_t(
1680                    layer.in_proj_w.as_ptr(),
1681                    scratch.grad_xz.as_ptr(),
1682                    scratch.grad_norm.as_mut_ptr(),
1683                    2 * i,
1684                    c,
1685                );
1686            }
1687            if scope.mixer_proj {
1688                add_outer_grad(
1689                    layer_grads.in_proj_w.as_mut_slice(),
1690                    2 * i,
1691                    c,
1692                    scratch.grad_xz.as_slice(),
1693                    tr.norm.as_slice(),
1694                );
1695                if let Some(in_proj_b) = layer_grads.in_proj_b.as_mut() {
1696                    add_vec_grad(in_proj_b.as_mut_slice(), scratch.grad_xz.as_slice());
1697                }
1698            }
1699
1700            rms_norm_backward(
1701                tr.h_in.as_slice(),
1702                layer.norm_w.as_slice(),
1703                scratch.grad_norm.as_slice(),
1704                self.cfg.layer_norm_eps,
1705                scratch.grad_h.as_mut_slice(),
1706                scratch.grad_out.as_mut_slice(),
1707            );
1708            if scope.layer_norm {
1709                add_vec_grad(
1710                    layer_grads.norm_w.as_mut_slice(),
1711                    scratch.grad_out.as_slice(),
1712                );
1713                if let Some(norm_b) = layer_grads.norm_b.as_mut() {
1714                    add_vec_grad(norm_b.as_mut_slice(), scratch.grad_norm.as_slice());
1715                }
1716            }
1717
1718            for idx in 0..c {
1719                scratch.grad_h[idx] += scratch.grad_residual[idx];
1720            }
1721        }
1722
1723        if scope.embed {
1724            let tok = trace.token.min(self.cfg.vocab_size.saturating_sub(1));
1725            let row_off = tok * c;
1726            add_vec_grad(
1727                &mut grads.embeddings.as_mut_slice()[row_off..row_off + c],
1728                scratch.grad_h.as_slice(),
1729            );
1730        }
1731
1732        Ok(())
1733    }
1734
    #[allow(clippy::too_many_arguments)]
    /// Run one TBPTT training segment and write the resulting live state.
    ///
    /// Each entry of `steps` is `(input_token, target_symbol, pdf)`: the token
    /// fed forward, the symbol the loss targets, and the predicted PDF for
    /// that position (as consumed by `accumulate_token_step_gradients`).
    ///
    /// The segment runs in three phases:
    /// 1. A trace-free forward sweep that snapshots the recurrent state at the
    ///    start of every `replay_chunk`-sized chunk (gradient checkpointing).
    /// 2. A reverse sweep over the chunks: each chunk is replayed forward with
    ///    trace capture on, then backpropagated token-by-token in reverse.
    ///    `recurrent` is zeroed once before the sweep, so state gradients flow
    ///    across chunk boundaries.
    /// 3. One optimizer step over the accumulated gradients, then a final
    ///    forward replay with the updated weights to produce `live_state_out`.
    ///
    /// Per-token gradients are averaged via `grad_scale = 1 / steps.len()`.
    ///
    /// # Errors
    /// Propagates failures from gradient accumulation or the optimizer step.
    pub fn online_train_segment_tbptt(
        &mut self,
        scratch: &mut ScratchBuffers,
        start_state: &State,
        steps: &[(u32, u8, Vec<f64>)],
        scope: TrainScopeMask,
        optimizer: OptimizerKind,
        lr: f32,
        clip: f32,
        replay_chunk: usize,
        adam_t: &mut usize,
        model_adam: Option<&mut FullAdamState>,
        out_bias: Option<&mut [f32]>,
        out_bias_adam_m: Option<&mut [f32]>,
        out_bias_adam_v: Option<&mut [f32]>,
        live_state_out: &mut State,
    ) -> Result<()> {
        // Empty segment: nothing to train, live state is just the start state.
        if steps.is_empty() {
            *live_state_out = start_state.clone();
            return Ok(());
        }

        let grad_scale = 1.0f32 / (steps.len() as f32);
        // Clamp the checkpoint interval to [1, steps.len()].
        let chunk = replay_chunk.max(1).min(steps.len().max(1));
        let mut grads = self.new_full_grad_state();
        let mut recurrent = self.new_recurrent_grad_state();
        recurrent.zero();
        // Output-bias gradients are accumulated separately when a bias exists.
        let mut bias_grad = out_bias.as_deref().map(|b| vec![0.0f32; b.len()]);

        {
            // Phase 1: checkpointing forward sweep, no trace capture needed.
            let mut checkpoints = Vec::<State>::new();
            let mut checkpoint_state = start_state.clone();
            scratch.set_capture_train_trace(false);
            for chunk_start in (0..steps.len()).step_by(chunk) {
                checkpoints.push(checkpoint_state.clone());
                let chunk_end = (chunk_start + chunk).min(steps.len());
                for (input_token, _, _) in &steps[chunk_start..chunk_end] {
                    let _ = self.forward(scratch, *input_token, &mut checkpoint_state);
                }
            }

            // Phase 2: walk the chunks newest-first.
            for chunk_idx in (0..checkpoints.len()).rev() {
                let chunk_start = chunk_idx * chunk;
                let chunk_end = (chunk_start + chunk).min(steps.len());
                let mut state = checkpoints[chunk_idx].clone();
                // step_states[k] is the state *before* token k of the chunk;
                // one extra slot holds the state after the last token.
                let mut step_states = Vec::<State>::with_capacity(chunk_end - chunk_start + 1);
                let mut step_traces =
                    Vec::<TokenTrainTrace>::with_capacity(chunk_end - chunk_start);
                step_states.push(state.clone());

                // Replay the chunk forward, capturing a per-token trace.
                for (input_token, _, _) in &steps[chunk_start..chunk_end] {
                    scratch.set_capture_train_trace(true);
                    let _ = self.forward(scratch, *input_token, &mut state);
                    step_traces.push(TokenTrainTrace::from_scratch(scratch));
                    step_states.push(state.clone());
                }

                // Backpropagate through the chunk, newest token first.
                for local_idx in (0..step_traces.len()).rev() {
                    let (_, target_symbol, pdf) = &steps[chunk_start + local_idx];
                    self.accumulate_token_step_gradients(
                        scratch,
                        &step_traces[local_idx],
                        &step_states[local_idx + 1],
                        *target_symbol,
                        pdf,
                        grad_scale,
                        scope,
                        &mut grads,
                        bias_grad.as_deref_mut(),
                        &mut recurrent,
                    )?;
                }
            }
        }

        // Phase 3: a single parameter update for the whole segment.
        self.apply_full_gradients(
            &grads,
            scope,
            optimizer,
            lr,
            clip,
            adam_t,
            model_adam,
            out_bias,
            bias_grad.as_deref(),
            out_bias_adam_m,
            out_bias_adam_v,
        )?;

        // Recompute the live state with the *updated* weights so the caller's
        // decoding state matches the post-update model.
        scratch.set_capture_train_trace(false);
        *live_state_out = start_state.clone();
        for (input_token, _, _) in steps {
            let _ = self.forward(scratch, *input_token, live_state_out);
        }
        Ok(())
    }
1833
1834    /// Forward a single byte token and return logits for next symbol.
1835    #[inline(never)]
1836    pub fn forward<'a>(
1837        &'a self,
1838        scratch: &'a mut ScratchBuffers,
1839        token: u32,
1840        state: &mut State,
1841    ) -> &'a [f32] {
1842        if scratch.capture_train_trace {
1843            self.forward_impl::<true>(scratch, token, state)
1844        } else {
1845            self.forward_impl::<false>(scratch, token, state)
1846        }
1847    }
1848
    /// Monomorphized single-token forward pass.
    ///
    /// `CAPTURE` selects whether per-layer activations are recorded into
    /// `scratch.train_trace_layers` for a later backward pass. Returns the
    /// logits slice held in `scratch`; out-of-range tokens are clamped to the
    /// last embedding row.
    fn forward_impl<'a, const CAPTURE: bool>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
    ) -> &'a [f32] {
        // Shape shorthands used throughout: hidden, inner, state, dt-rank.
        let c = self.cfg.hidden_size;
        let i = self.cfg.inner_size;
        let s = self.cfg.state_size;
        let r = self.cfg.dt_rank;

        // Clamp so a bad token cannot index past the embedding table.
        let token_idx = (token as usize).min(self.cfg.vocab_size.saturating_sub(1));
        let emb_off = token_idx * c;
        if CAPTURE {
            scratch.train_token = token_idx;
            scratch.train_trace_valid = true;
        } else {
            // A non-capturing step invalidates any previously captured trace.
            scratch.train_trace_valid = false;
        }
        // h <- embedding row for the (clamped) token.
        scratch
            .h
            .as_mut_slice()
            .copy_from_slice(&self.embeddings.as_slice()[emb_off..emb_off + c]);

        for layer_idx in 0..self.cfg.num_layers {
            let layer = &self.layers[layer_idx];
            let st = &mut state.layers[layer_idx];
            if CAPTURE {
                // Snapshot the layer inputs and the pre-step recurrent state;
                // the backward pass reconstructs the step from these.
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.h_in.as_mut_slice().copy_from_slice(scratch.h.as_slice());
                tr.ssm_prev
                    .as_mut_slice()
                    .copy_from_slice(st.ssm.as_slice());
                tr.conv_prev
                    .as_mut_slice()
                    .copy_from_slice(st.conv.as_slice());
                tr.conv_pos_prev = st.conv_pos;
            }

            // Pre-mixer RMSNorm: norm <- rms_norm(h).
            rms_norm(
                scratch.h.as_slice(),
                layer.norm_w.as_slice(),
                layer.norm_b.as_ref().map(Tensor1D::as_slice),
                self.cfg.layer_norm_eps,
                scratch.norm.as_mut_slice(),
            );
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.norm
                    .as_mut_slice()
                    .copy_from_slice(scratch.norm.as_slice());
            }

            // xz <- in_proj(norm): entries [0, i) are x, [i, 2i) are the gate
            // pre-activation z (consumed as xz[i + ch] below).
            unsafe {
                // SAFETY: row-major dimensions are validated at load time.
                kernel::gemv(
                    layer.in_proj_w.as_ptr(),
                    scratch.norm.as_ptr(),
                    scratch.xz.as_mut_ptr(),
                    i * 2,
                    c,
                );
            }
            if let Some(bias) = &layer.in_proj_b {
                for (dst, &b) in scratch.xz.as_mut_slice().iter_mut().zip(bias.as_slice()) {
                    *dst += b;
                }
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.xz.as_mut_slice().copy_from_slice(scratch.xz.as_slice());
            }

            // Depthwise causal conv over the x half; reads and updates the
            // per-layer conv state in `st`.
            depthwise_conv_step(
                &scratch.xz.as_slice()[0..i],
                &layer.conv_w,
                layer.conv_b.as_ref(),
                self.cfg.conv_kernel,
                st,
                scratch.conv.as_mut_slice(),
            );
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.conv_pre
                    .as_mut_slice()
                    .copy_from_slice(scratch.conv.as_slice());
            }

            // SiLU on the conv output; capture mode also records the
            // post-activation and sigmoid needed by the backward pass.
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                for idx in 0..i {
                    let (post, sig) = silu_with_sigmoid(scratch.conv[idx]);
                    scratch.conv[idx] = post;
                    tr.conv_post[idx] = post;
                    tr.conv_sigmoid[idx] = sig;
                }
            } else {
                for idx in 0..i {
                    scratch.conv[idx] = silu(scratch.conv[idx]);
                }
            }

            // proj <- x_proj(conv): layout is [delta (r) | B (s) | C (s)].
            unsafe {
                // SAFETY: row-major dimensions are validated at load time.
                kernel::gemv(
                    layer.x_proj_w.as_ptr(),
                    scratch.conv.as_ptr(),
                    scratch.proj.as_mut_ptr(),
                    r + 2 * s,
                    i,
                );
            }
            if let Some(bias) = &layer.x_proj_b {
                for (dst, &b) in scratch.proj.as_mut_slice().iter_mut().zip(bias.as_slice()) {
                    *dst += b;
                }
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.proj
                    .as_mut_slice()
                    .copy_from_slice(scratch.proj.as_slice());
            }

            // dt (pre-softplus) <- dt_proj · proj[..r].
            unsafe {
                // SAFETY: row-major dimensions are validated at load time.
                kernel::gemv(
                    layer.dt_proj_w.as_ptr(),
                    scratch.proj.as_ptr(),
                    scratch.dt.as_mut_ptr(),
                    i,
                    r,
                );
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.dt_raw
                    .as_mut_slice()
                    .copy_from_slice(scratch.dt.as_slice());
            }

            // Borrow/raw-pointer setup for the selective-scan inner loops.
            let proj = scratch.proj.as_slice();
            let b_vec = &proj[r..r + s];
            let c_vec = &proj[r + s..r + 2 * s];
            let conv = scratch.conv.as_slice();
            let dt_raw = scratch.dt.as_slice();
            let xz = scratch.xz.as_slice();
            let d = layer.d.as_slice();
            let a = layer.a.as_slice();
            let dt_bias = layer.dt_proj_b.as_slice();
            let ssm = st.ssm.as_mut_slice();
            let b_ptr = b_vec.as_ptr();
            let c_ptr = c_vec.as_ptr();
            let a_ptr = a.as_ptr();
            let ssm_ptr = ssm.as_mut_ptr();

            // Selective scan: specialized kernel for the common s == 16 case,
            // generic per-state loop otherwise. Both paths compute, per channel:
            //   ssm = ssm * exp(dt * A) + (x * dt) * B;  y = D*x + ssm . C
            if s == 16 {
                for ch in 0..i {
                    let x_ch = conv[ch];
                    let dt_pre = dt_raw[ch] + dt_bias[ch];
                    let gate_pre = xz[i + ch];
                    let dt = softplus(dt_pre);
                    let (gate, gate_sigmoid) = silu_with_sigmoid(gate_pre);
                    let x_dt = x_ch * dt;

                    let ssm_row_off = ch * s;
                    let row_a = unsafe { a_ptr.add(ssm_row_off) };
                    let row_ssm = unsafe { ssm_ptr.add(ssm_row_off) };
                    // In capture mode the kernel also writes exp(dt*A) into the
                    // trace; otherwise it gets a null sink.
                    let trace_ptr = if CAPTURE {
                        unsafe {
                            scratch.train_trace_layers[layer_idx]
                                .d_a
                                .as_mut_ptr()
                                .add(ssm_row_off)
                        }
                    } else {
                        std::ptr::null_mut()
                    };
                    let mut y = d[ch] * x_ch;
                    y += unsafe {
                        selective_scan_state16::<CAPTURE>(
                            row_a, row_ssm, dt, x_dt, b_ptr, c_ptr, trace_ptr,
                        )
                    };
                    if CAPTURE {
                        let tr = &mut scratch.train_trace_layers[layer_idx];
                        tr.dt[ch] = dt;
                        tr.gate[ch] = gate;
                        tr.gate_sigmoid[ch] = gate_sigmoid;
                        tr.y_pre[ch] = y;
                    }
                    // SiLU gate from the z half of xz.
                    scratch.y[ch] = y * gate;
                    if CAPTURE {
                        scratch.train_trace_layers[layer_idx].y[ch] = scratch.y[ch];
                    }
                }
            } else {
                for ch in 0..i {
                    let x_ch = conv[ch];
                    let dt_pre = dt_raw[ch] + dt_bias[ch];
                    let gate_pre = xz[i + ch];
                    let dt = softplus(dt_pre);
                    let (gate, gate_sigmoid) = silu_with_sigmoid(gate_pre);
                    let x_dt = x_ch * dt;

                    let mut y = d[ch] * x_ch;
                    let ssm_row_off = ch * s;
                    let row_a = unsafe { a_ptr.add(ssm_row_off) };
                    let row_ssm = unsafe { ssm_ptr.add(ssm_row_off) };
                    let mut j = 0usize;
                    while j < s {
                        let prev = unsafe { *row_ssm.add(j) };
                        let d_a = (dt * unsafe { *row_a.add(j) }).exp();
                        if CAPTURE {
                            scratch.train_trace_layers[layer_idx].d_a[ssm_row_off + j] = d_a;
                        }
                        let next = prev * d_a + x_dt * unsafe { *b_ptr.add(j) };
                        unsafe { *row_ssm.add(j) = next };
                        y += next * unsafe { *c_ptr.add(j) };
                        j += 1;
                    }
                    if CAPTURE {
                        let tr = &mut scratch.train_trace_layers[layer_idx];
                        tr.dt[ch] = dt;
                        tr.gate[ch] = gate;
                        tr.gate_sigmoid[ch] = gate_sigmoid;
                        tr.y_pre[ch] = y;
                    }
                    // SiLU gate from the z half of xz.
                    scratch.y[ch] = y * gate;
                    if CAPTURE {
                        scratch.train_trace_layers[layer_idx].y[ch] = scratch.y[ch];
                    }
                }
            }

            // out <- out_proj(y), projecting back to the hidden width.
            unsafe {
                // SAFETY: row-major dimensions are validated at load time.
                kernel::gemv(
                    layer.out_proj_w.as_ptr(),
                    scratch.y.as_ptr(),
                    scratch.out.as_mut_ptr(),
                    c,
                    i,
                );
            }
            if let Some(bias) = &layer.out_proj_b {
                for (dst, &b) in scratch.out.as_mut_slice().iter_mut().zip(bias.as_slice()) {
                    *dst += b;
                }
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.out
                    .as_mut_slice()
                    .copy_from_slice(scratch.out.as_slice());
            }

            // Residual connection: h += mixer output.
            unsafe {
                // SAFETY: vector lengths match hidden size.
                kernel::add_inplace(scratch.h.as_mut_ptr(), scratch.out.as_ptr(), c);
            }
        }
        if CAPTURE {
            // The backward pass needs the pre-final-norm hidden state.
            scratch
                .train_h_final
                .as_mut_slice()
                .copy_from_slice(scratch.h.as_slice());
        }

        // Final RMSNorm before the LM head.
        rms_norm(
            scratch.h.as_slice(),
            self.final_norm_w.as_slice(),
            self.final_norm_b.as_ref().map(Tensor1D::as_slice),
            self.cfg.layer_norm_eps,
            scratch.norm.as_mut_slice(),
        );

        // logits <- lm_head(norm).
        unsafe {
            // SAFETY: row-major dimensions are validated at load time.
            kernel::gemv(
                self.lm_head.as_ptr(),
                scratch.norm.as_ptr(),
                scratch.logits.as_mut_ptr(),
                self.cfg.vocab_size,
                c,
            );
        }
        if let Some(bias) = &self.lm_head_b {
            for (dst, &b) in scratch
                .logits
                .as_mut_slice()
                .iter_mut()
                .zip(bias.as_slice())
            {
                *dst += b;
            }
        }

        scratch.logits.as_slice()
    }
2149
2150    /// Exact single-step (`bptt=1`) online training for all Mamba parameters.
2151    ///
2152    /// This consumes the latest forward trace captured in `scratch` and applies
2153    /// one gradient step using the externally provided PDF/target symbol.
2154    #[allow(clippy::too_many_arguments)]
2155    #[allow(clippy::needless_range_loop)]
2156    pub fn online_train_step_bptt1(
2157        &mut self,
2158        scratch: &mut ScratchBuffers,
2159        state: &State,
2160        symbol: u8,
2161        pdf: &[f64],
2162        scope: TrainScopeMask,
2163        optimizer: OptimizerKind,
2164        lr: f32,
2165        clip: f32,
2166        adam_t: &mut usize,
2167        model_adam: Option<&mut FullAdamState>,
2168        out_bias: Option<&mut [f32]>,
2169        out_bias_adam_m: Option<&mut [f32]>,
2170        out_bias_adam_v: Option<&mut [f32]>,
2171    ) -> Result<()> {
2172        if !scope.trains_model_params() && !scope.bias {
2173            return Ok(());
2174        }
2175        let needs_backprop = scope.embed
2176            || scope.layer_norm
2177            || scope.mixer_conv
2178            || scope.mixer_ssm
2179            || scope.mixer_proj;
2180        if needs_backprop && !scratch.train_trace_valid {
2181            bail!("mamba full training trace is missing; run one forward step first");
2182        }
2183        let c = self.cfg.hidden_size;
2184        let i = self.cfg.inner_size;
2185        let s = self.cfg.state_size;
2186        let r = self.cfg.dt_rank;
2187        let v = self.cfg.vocab_size.min(pdf.len());
2188        if v == 0 {
2189            return Ok(());
2190        }
2191
2192        let mut adam_cfg = None::<AdamStep>;
2193        let mut model_adam = model_adam;
2194        if matches!(optimizer, OptimizerKind::Adam) {
2195            *adam_t = adam_t.saturating_add(1);
2196            let t = (*adam_t).max(1) as i32;
2197            let b1 = 0.9f32;
2198            let b2 = 0.999f32;
2199            adam_cfg = Some(AdamStep {
2200                lr,
2201                clip,
2202                b1,
2203                b2,
2204                eps: 1e-8,
2205                bias_corr1: 1.0 - b1.powi(t),
2206                bias_corr2: 1.0 - b2.powi(t),
2207            });
2208            if scope.trains_model_params() && model_adam.is_none() {
2209                bail!("mamba Adam full-training state is missing");
2210            }
2211        }
2212
2213        // dL/d(logits + out_bias) for cross-entropy with one-hot target.
2214        scratch.grad_logits.zero();
2215        for tok in 0..v {
2216            let p = pdf[tok].clamp(1e-12, 1.0) as f32;
2217            let target = if tok == symbol as usize { 1.0 } else { 0.0 };
2218            let mut g = target - p;
2219            if clip > 0.0 {
2220                g = g.clamp(-clip, clip);
2221            }
2222            scratch.grad_logits[tok] = g;
2223        }
2224
2225        if scope.bias
2226            && let Some(bias) = out_bias
2227        {
2228            let n = bias.len().min(v);
2229            let grad = &scratch.grad_logits.as_slice()[..n];
2230            match optimizer {
2231                OptimizerKind::Sgd => {
2232                    for idx in 0..n {
2233                        bias[idx] += lr * grad[idx];
2234                    }
2235                }
2236                OptimizerKind::Adam => {
2237                    let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2238                    let Some(m) = out_bias_adam_m else {
2239                        bail!("mamba Adam output-bias moments are missing");
2240                    };
2241                    let Some(vv) = out_bias_adam_v else {
2242                        bail!("mamba Adam output-bias moments are missing");
2243                    };
2244                    if m.len() < n || vv.len() < n {
2245                        bail!("mamba Adam output-bias moments have invalid shape");
2246                    }
2247                    for idx in 0..n {
2248                        let g = grad[idx];
2249                        m[idx] = cfg.b1 * m[idx] + (1.0 - cfg.b1) * g;
2250                        vv[idx] = cfg.b2 * vv[idx] + (1.0 - cfg.b2) * g * g;
2251                        let m_hat = m[idx] / cfg.bias_corr1;
2252                        let v_hat = vv[idx] / cfg.bias_corr2;
2253                        bias[idx] += cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2254                    }
2255                }
2256            }
2257        }
2258
2259        // Head backward/update. When scope.head is enabled we fuse grad_h and
2260        // parameter update in one pass to reduce memory traffic.
2261        scratch.grad_h.zero();
2262        let norm_in = scratch.norm.as_slice();
2263        if scope.head {
2264            match optimizer {
2265                OptimizerKind::Sgd => {
2266                    let head = self.lm_head.as_mut_slice();
2267                    let norm_ptr = norm_in.as_ptr();
2268                    let grad_h_ptr = scratch.grad_h.as_mut_slice().as_mut_ptr();
2269                    for tok in 0..v {
2270                        let g = scratch.grad_logits[tok];
2271                        let row_off = tok * c;
2272                        let mut j = 0usize;
2273                        unsafe {
2274                            let g8 = f32x8::splat(g);
2275                            let lr8 = f32x8::splat(lr);
2276                            while j + 8 <= c {
2277                                let idx = row_off + j;
2278                                let wv = head.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2279                                let nv = norm_ptr.add(j).cast::<f32x8>().read_unaligned();
2280                                let ghv = grad_h_ptr.add(j).cast::<f32x8>().read_unaligned();
2281                                grad_h_ptr
2282                                    .add(j)
2283                                    .cast::<f32x8>()
2284                                    .write_unaligned(ghv + wv * g8);
2285                                head.as_mut_ptr()
2286                                    .add(idx)
2287                                    .cast::<f32x8>()
2288                                    .write_unaligned(wv + (g8 * nv) * lr8);
2289                                j += 8;
2290                            }
2291                        }
2292                        while j < c {
2293                            let idx = row_off + j;
2294                            let w_old = head[idx];
2295                            scratch.grad_h[j] += w_old * g;
2296                            head[idx] = w_old + lr * g * norm_in[j];
2297                            j += 1;
2298                        }
2299                    }
2300                    if let Some(b) = self.lm_head_b.as_mut() {
2301                        for tok in 0..v.min(b.len()) {
2302                            b[tok] += lr * scratch.grad_logits[tok];
2303                        }
2304                    }
2305                }
2306                OptimizerKind::Adam => {
2307                    let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2308                    let adam = model_adam.as_mut().expect("adam state exists");
2309                    let head = self.lm_head.as_mut_slice();
2310                    let hm = adam.lm_head.m.as_mut_slice();
2311                    let hv = adam.lm_head.v.as_mut_slice();
2312                    let norm_ptr = norm_in.as_ptr();
2313                    let grad_h_ptr = scratch.grad_h.as_mut_slice().as_mut_ptr();
2314                    let b1 = f32x8::splat(cfg.b1);
2315                    let b2 = f32x8::splat(cfg.b2);
2316                    let one_b1 = f32x8::splat(1.0 - cfg.b1);
2317                    let one_b2 = f32x8::splat(1.0 - cfg.b2);
2318                    let inv_bc1 = f32x8::splat(1.0 / cfg.bias_corr1);
2319                    let inv_bc2 = f32x8::splat(1.0 / cfg.bias_corr2);
2320                    let eps = f32x8::splat(cfg.eps);
2321                    let lr8 = f32x8::splat(cfg.lr);
2322                    for tok in 0..v {
2323                        let g = scratch.grad_logits[tok];
2324                        let row_off = tok * c;
2325                        let mut j = 0usize;
2326                        unsafe {
2327                            let g8 = f32x8::splat(g);
2328                            while j + 8 <= c {
2329                                let idx = row_off + j;
2330                                let wv = head.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2331                                let nv = norm_ptr.add(j).cast::<f32x8>().read_unaligned();
2332                                let ghv = grad_h_ptr.add(j).cast::<f32x8>().read_unaligned();
2333                                grad_h_ptr
2334                                    .add(j)
2335                                    .cast::<f32x8>()
2336                                    .write_unaligned(ghv + wv * g8);
2337
2338                                let gg = g8 * nv;
2339                                let hm_old = hm.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2340                                let hv_old = hv.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2341                                let m = hm_old * b1 + gg * one_b1;
2342                                let vv = hv_old * b2 + (gg * gg) * one_b2;
2343                                hm.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(m);
2344                                hv.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv);
2345                                let upd = ((m * inv_bc1) / ((vv * inv_bc2).sqrt() + eps)) * lr8;
2346                                head.as_mut_ptr()
2347                                    .add(idx)
2348                                    .cast::<f32x8>()
2349                                    .write_unaligned(wv + upd);
2350                                j += 8;
2351                            }
2352                        }
2353                        while j < c {
2354                            let idx = row_off + j;
2355                            let w_old = head[idx];
2356                            scratch.grad_h[j] += w_old * g;
2357                            let gg = g * norm_in[j];
2358                            let m = cfg.b1 * hm[idx] + (1.0 - cfg.b1) * gg;
2359                            let vv = cfg.b2 * hv[idx] + (1.0 - cfg.b2) * gg * gg;
2360                            hm[idx] = m;
2361                            hv[idx] = vv;
2362                            let m_hat = m / cfg.bias_corr1;
2363                            let v_hat = vv / cfg.bias_corr2;
2364                            head[idx] = w_old + cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2365                            j += 1;
2366                        }
2367                    }
2368                    if let (Some(b), Some(bm)) = (self.lm_head_b.as_mut(), adam.lm_head_b.as_mut())
2369                    {
2370                        let bm_m = bm.m.as_mut_slice();
2371                        let bm_v = bm.v.as_mut_slice();
2372                        for tok in 0..v.min(b.len()) {
2373                            let g = scratch.grad_logits[tok];
2374                            let m = cfg.b1 * bm_m[tok] + (1.0 - cfg.b1) * g;
2375                            let vv = cfg.b2 * bm_v[tok] + (1.0 - cfg.b2) * g * g;
2376                            bm_m[tok] = m;
2377                            bm_v[tok] = vv;
2378                            let m_hat = m / cfg.bias_corr1;
2379                            let v_hat = vv / cfg.bias_corr2;
2380                            b[tok] += cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2381                        }
2382                    }
2383                }
2384            }
2385        } else {
2386            let head = self.lm_head.as_slice();
2387            for tok in 0..v {
2388                let g = scratch.grad_logits[tok];
2389                let row_off = tok * c;
2390                for j in 0..c {
2391                    scratch.grad_h[j] += head[row_off + j] * g;
2392                }
2393            }
2394        }
2395
2396        if !needs_backprop {
2397            return Ok(());
2398        }
2399
2400        // Final RMSNorm backward.
2401        rms_norm_backward(
2402            scratch.train_h_final.as_slice(),
2403            self.final_norm_w.as_slice(),
2404            scratch.grad_h.as_slice(),
2405            self.cfg.layer_norm_eps,
2406            scratch.grad_norm.as_mut_slice(),
2407            scratch.grad_out.as_mut_slice(),
2408        );
2409        if scope.layer_norm {
2410            match optimizer {
2411                OptimizerKind::Sgd => {
2412                    for idx in 0..c {
2413                        self.final_norm_w[idx] += lr * scratch.grad_out[idx];
2414                    }
2415                    if let Some(b) = self.final_norm_b.as_mut() {
2416                        for idx in 0..c.min(b.len()) {
2417                            b[idx] += lr * scratch.grad_h[idx];
2418                        }
2419                    }
2420                }
2421                OptimizerKind::Adam => {
2422                    let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2423                    let adam = model_adam.as_mut().expect("adam state exists");
2424                    apply_adam_vec_update(
2425                        self.final_norm_w.as_mut_slice(),
2426                        scratch.grad_out.as_slice(),
2427                        &mut adam.final_norm_w,
2428                        cfg,
2429                    );
2430                    if let (Some(b), Some(bm)) =
2431                        (self.final_norm_b.as_mut(), adam.final_norm_b.as_mut())
2432                    {
2433                        apply_adam_vec_update(b.as_mut_slice(), scratch.grad_h.as_slice(), bm, cfg);
2434                    }
2435                }
2436            }
2437        }
2438        scratch
2439            .grad_h
2440            .as_mut_slice()
2441            .copy_from_slice(scratch.grad_norm.as_slice());
2442
2443        // Layerwise reverse-mode pass.
2444        for layer_idx in (0..self.cfg.num_layers).rev() {
2445            let tr = &scratch.train_trace_layers[layer_idx];
2446            let st_new = &state.layers[layer_idx];
2447            let layer = &mut self.layers[layer_idx];
2448
2449            scratch
2450                .grad_out
2451                .as_mut_slice()
2452                .copy_from_slice(scratch.grad_h.as_slice());
2453
2454            // grad_y = out_proj^T @ grad_out.
2455            unsafe {
2456                kernel::gemv_t(
2457                    layer.out_proj_w.as_ptr(),
2458                    scratch.grad_out.as_ptr(),
2459                    scratch.grad_y.as_mut_ptr(),
2460                    c,
2461                    i,
2462                );
2463            }
2464
2465            if scope.mixer_proj {
2466                match optimizer {
2467                    OptimizerKind::Sgd => {
2468                        for row in 0..c {
2469                            let g = scratch.grad_out[row];
2470                            let off = row * i;
2471                            for col in 0..i {
2472                                layer.out_proj_w[off + col] += lr * g * tr.y[col];
2473                            }
2474                        }
2475                        if let Some(b) = layer.out_proj_b.as_mut() {
2476                            for row in 0..c.min(b.len()) {
2477                                b[row] += lr * scratch.grad_out[row];
2478                            }
2479                        }
2480                    }
2481                    OptimizerKind::Adam => {
2482                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2483                        let adam_layer =
2484                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2485                        apply_adam_outer_update(
2486                            layer.out_proj_w.as_mut_slice(),
2487                            c,
2488                            i,
2489                            scratch.grad_out.as_slice(),
2490                            tr.y.as_slice(),
2491                            &mut adam_layer.out_proj_w,
2492                            cfg,
2493                        );
2494                        if let (Some(b), Some(bm)) =
2495                            (layer.out_proj_b.as_mut(), adam_layer.out_proj_b.as_mut())
2496                        {
2497                            apply_adam_vec_update(
2498                                b.as_mut_slice(),
2499                                scratch.grad_out.as_slice(),
2500                                bm,
2501                                cfg,
2502                            );
2503                        }
2504                    }
2505                }
2506            }
2507
2508            scratch
2509                .grad_residual
2510                .as_mut_slice()
2511                .copy_from_slice(scratch.grad_out.as_slice());
2512            scratch.grad_xz.zero();
2513            scratch.grad_b.zero();
2514            scratch.grad_c.zero();
2515
2516            // y path, selective scan, and SSM params.
2517            for ch in 0..i {
2518                let g_y = scratch.grad_y[ch];
2519                let gate = tr.gate[ch];
2520                let y_pre = tr.y_pre[ch];
2521                let g_y_pre = g_y * gate;
2522                let g_gate = g_y * y_pre;
2523                scratch.grad_xz[i + ch] =
2524                    g_gate * silu_grad_from_sigmoid(tr.xz[i + ch], tr.gate_sigmoid[ch]);
2525
2526                let conv = tr.conv_post[ch];
2527                let dt = tr.dt[ch];
2528                let xdt = conv * dt;
2529                let mut g_xdt = 0.0f32;
2530                let mut g_dt = 0.0f32;
2531
2532                let mut g_conv = g_y_pre * layer.d[ch];
2533                if scope.mixer_ssm {
2534                    scratch.grad_ssm_d[ch] = g_y_pre * conv;
2535                }
2536
2537                let row = ch * s;
2538                for j in 0..s {
2539                    let idx = row + j;
2540                    let c_j = tr.proj[r + s + j];
2541                    let b_j = tr.proj[r + j];
2542                    let s_prev = tr.ssm_prev[idx];
2543                    let s_new = st_new.ssm[idx];
2544                    let a_ij = layer.a[idx];
2545
2546                    let g_ssm_new = g_y_pre * c_j;
2547                    scratch.grad_c[j] += g_y_pre * s_new;
2548                    g_xdt += g_ssm_new * b_j;
2549                    scratch.grad_b[j] += g_ssm_new * xdt;
2550
2551                    let d_a = tr.d_a[idx];
2552                    let g_d_a = g_ssm_new * s_prev;
2553                    g_dt += g_d_a * d_a * a_ij;
2554                    if scope.mixer_ssm {
2555                        scratch.grad_ssm_a[idx] = g_d_a * d_a * dt;
2556                    }
2557                }
2558
2559                g_conv += g_xdt * dt;
2560                g_dt += g_xdt * conv;
2561                let dt_pre = tr.dt_raw[ch] + layer.dt_proj_b[ch];
2562                scratch.grad_dt_raw[ch] = g_dt * sigmoid(dt_pre);
2563                scratch.grad_conv[ch] = g_conv;
2564            }
2565
2566            if scope.mixer_ssm {
2567                match optimizer {
2568                    OptimizerKind::Sgd => {
2569                        if clip > 0.0 {
2570                            for idx in 0..i {
2571                                layer.d[idx] += lr * scratch.grad_ssm_d[idx].clamp(-clip, clip);
2572                            }
2573                            for idx in 0..(i * s) {
2574                                let g_log =
2575                                    (scratch.grad_ssm_a[idx] * layer.a[idx]).clamp(-clip, clip);
2576                                let new_log = layer.a_log[idx] + lr * g_log;
2577                                layer.a_log[idx] = new_log;
2578                                layer.a[idx] = -new_log.exp();
2579                            }
2580                        } else {
2581                            for idx in 0..i {
2582                                layer.d[idx] += lr * scratch.grad_ssm_d[idx];
2583                            }
2584                            for idx in 0..(i * s) {
2585                                let g_log = scratch.grad_ssm_a[idx] * layer.a[idx];
2586                                let new_log = layer.a_log[idx] + lr * g_log;
2587                                layer.a_log[idx] = new_log;
2588                                layer.a[idx] = -new_log.exp();
2589                            }
2590                        }
2591                    }
2592                    OptimizerKind::Adam => {
2593                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2594                        let adam_layer =
2595                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2596                        for idx in 0..(i * s) {
2597                            scratch.grad_ssm_a_log[idx] = scratch.grad_ssm_a[idx] * layer.a[idx];
2598                        }
2599                        apply_adam_vec_update(
2600                            layer.d.as_mut_slice(),
2601                            scratch.grad_ssm_d.as_slice(),
2602                            &mut adam_layer.d,
2603                            cfg,
2604                        );
2605                        apply_adam_vec_update_and_sync_neg_exp(
2606                            layer.a_log.as_mut_slice(),
2607                            layer.a.as_mut_slice(),
2608                            scratch.grad_ssm_a_log.as_slice(),
2609                            &mut adam_layer.a,
2610                            cfg,
2611                        );
2612                    }
2613                }
2614            }
2615
2616            // dt_proj: grad_u = W^T grad_dt_raw, and parameter grads.
2617            unsafe {
2618                kernel::gemv_t(
2619                    layer.dt_proj_w.as_ptr(),
2620                    scratch.grad_dt_raw.as_ptr(),
2621                    scratch.grad_u.as_mut_ptr(),
2622                    i,
2623                    r,
2624                );
2625            }
2626            if scope.mixer_proj {
2627                match optimizer {
2628                    OptimizerKind::Sgd => {
2629                        for ch in 0..i {
2630                            let g = scratch.grad_dt_raw[ch];
2631                            let off = ch * r;
2632                            for kk in 0..r {
2633                                layer.dt_proj_w[off + kk] += lr * g * tr.proj[kk];
2634                            }
2635                            layer.dt_proj_b[ch] += lr * g;
2636                        }
2637                    }
2638                    OptimizerKind::Adam => {
2639                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2640                        let adam_layer =
2641                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2642                        apply_adam_outer_update(
2643                            layer.dt_proj_w.as_mut_slice(),
2644                            i,
2645                            r,
2646                            scratch.grad_dt_raw.as_slice(),
2647                            &tr.proj.as_slice()[..r],
2648                            &mut adam_layer.dt_proj_w,
2649                            cfg,
2650                        );
2651                        apply_adam_vec_update(
2652                            layer.dt_proj_b.as_mut_slice(),
2653                            scratch.grad_dt_raw.as_slice(),
2654                            &mut adam_layer.dt_proj_b,
2655                            cfg,
2656                        );
2657                    }
2658                }
2659            }
2660
2661            for kk in 0..r {
2662                scratch.grad_proj[kk] = scratch.grad_u[kk];
2663            }
2664            for j in 0..s {
2665                scratch.grad_proj[r + j] = scratch.grad_b[j];
2666                scratch.grad_proj[r + s + j] = scratch.grad_c[j];
2667            }
2668
2669            // x_proj: grad_conv += W^T grad_proj.
2670            unsafe {
2671                kernel::gemv_t(
2672                    layer.x_proj_w.as_ptr(),
2673                    scratch.grad_proj.as_ptr(),
2674                    scratch.grad_conv_pre.as_mut_ptr(),
2675                    r + 2 * s,
2676                    i,
2677                );
2678                kernel::add_inplace(
2679                    scratch.grad_conv.as_mut_ptr(),
2680                    scratch.grad_conv_pre.as_ptr(),
2681                    i,
2682                );
2683            }
2684            if scope.mixer_proj {
2685                match optimizer {
2686                    OptimizerKind::Sgd => {
2687                        for row in 0..(r + 2 * s) {
2688                            let g = scratch.grad_proj[row];
2689                            let off = row * i;
2690                            for col in 0..i {
2691                                layer.x_proj_w[off + col] += lr * g * tr.conv_post[col];
2692                            }
2693                        }
2694                        if let Some(b) = layer.x_proj_b.as_mut() {
2695                            for row in 0..(r + 2 * s).min(b.len()) {
2696                                b[row] += lr * scratch.grad_proj[row];
2697                            }
2698                        }
2699                    }
2700                    OptimizerKind::Adam => {
2701                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2702                        let adam_layer =
2703                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2704                        apply_adam_outer_update(
2705                            layer.x_proj_w.as_mut_slice(),
2706                            r + 2 * s,
2707                            i,
2708                            scratch.grad_proj.as_slice(),
2709                            tr.conv_post.as_slice(),
2710                            &mut adam_layer.x_proj_w,
2711                            cfg,
2712                        );
2713                        if let (Some(b), Some(bm)) =
2714                            (layer.x_proj_b.as_mut(), adam_layer.x_proj_b.as_mut())
2715                        {
2716                            apply_adam_vec_update(
2717                                b.as_mut_slice(),
2718                                scratch.grad_proj.as_slice(),
2719                                bm,
2720                                cfg,
2721                            );
2722                        }
2723                    }
2724                }
2725            }
2726
2727            // conv + silu + in-proj(x branch)
2728            for ch in 0..i {
2729                scratch.grad_conv_pre[ch] = scratch.grad_conv[ch]
2730                    * silu_grad_from_sigmoid(tr.conv_pre[ch], tr.conv_sigmoid[ch]);
2731            }
2732
2733            for ch in 0..i {
2734                let g = scratch.grad_conv_pre[ch];
2735                let base = ch * self.cfg.conv_kernel;
2736                let w0 = layer.conv_w[base];
2737                scratch.grad_xz[ch] += g * w0;
2738                if scope.mixer_conv && self.cfg.conv_kernel == 4 {
2739                    let vals = match tr.conv_pos_prev {
2740                        0 => [
2741                            tr.xz[ch],
2742                            tr.conv_prev[base + 3],
2743                            tr.conv_prev[base + 2],
2744                            tr.conv_prev[base + 1],
2745                        ],
2746                        1 => [
2747                            tr.xz[ch],
2748                            tr.conv_prev[base],
2749                            tr.conv_prev[base + 3],
2750                            tr.conv_prev[base + 2],
2751                        ],
2752                        2 => [
2753                            tr.xz[ch],
2754                            tr.conv_prev[base + 1],
2755                            tr.conv_prev[base],
2756                            tr.conv_prev[base + 3],
2757                        ],
2758                        _ => [
2759                            tr.xz[ch],
2760                            tr.conv_prev[base + 2],
2761                            tr.conv_prev[base + 1],
2762                            tr.conv_prev[base],
2763                        ],
2764                    };
2765                    scratch.grad_conv_w[base] = g * vals[0];
2766                    scratch.grad_conv_w[base + 1] = g * vals[1];
2767                    scratch.grad_conv_w[base + 2] = g * vals[2];
2768                    scratch.grad_conv_w[base + 3] = g * vals[3];
2769                } else {
2770                    let mut ring = tr.conv_pos_prev;
2771                    for tap in 0..self.cfg.conv_kernel {
2772                        let val = if ring == tr.conv_pos_prev {
2773                            tr.xz[ch]
2774                        } else {
2775                            tr.conv_prev[base + ring]
2776                        };
2777                        if scope.mixer_conv {
2778                            scratch.grad_conv_w[base + tap] = g * val;
2779                        }
2780                        ring = if ring == 0 {
2781                            self.cfg.conv_kernel - 1
2782                        } else {
2783                            ring - 1
2784                        };
2785                    }
2786                }
2787                if scope.mixer_conv && layer.conv_b.is_some() {
2788                    scratch.grad_conv_b[ch] = g;
2789                }
2790            }
2791
2792            if scope.mixer_conv {
2793                match optimizer {
2794                    OptimizerKind::Sgd => {
2795                        if clip > 0.0 {
2796                            for idx in 0..layer.conv_w.len() {
2797                                layer.conv_w[idx] +=
2798                                    lr * scratch.grad_conv_w[idx].clamp(-clip, clip);
2799                            }
2800                        } else {
2801                            for idx in 0..layer.conv_w.len() {
2802                                layer.conv_w[idx] += lr * scratch.grad_conv_w[idx];
2803                            }
2804                        }
2805                        if let Some(bias) = layer.conv_b.as_mut() {
2806                            if clip > 0.0 {
2807                                for idx in 0..bias.len().min(i) {
2808                                    bias[idx] += lr * scratch.grad_conv_b[idx].clamp(-clip, clip);
2809                                }
2810                            } else {
2811                                for idx in 0..bias.len().min(i) {
2812                                    bias[idx] += lr * scratch.grad_conv_b[idx];
2813                                }
2814                            }
2815                        }
2816                    }
2817                    OptimizerKind::Adam => {
2818                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2819                        let adam_layer =
2820                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2821                        apply_adam_vec_update(
2822                            layer.conv_w.as_mut_slice(),
2823                            scratch.grad_conv_w.as_slice(),
2824                            &mut adam_layer.conv_w,
2825                            cfg,
2826                        );
2827                        if let (Some(bias), Some(bm)) =
2828                            (layer.conv_b.as_mut(), adam_layer.conv_b.as_mut())
2829                        {
2830                            apply_adam_vec_update(
2831                                bias.as_mut_slice(),
2832                                scratch.grad_conv_b.as_slice(),
2833                                bm,
2834                                cfg,
2835                            );
2836                        }
2837                    }
2838                }
2839            }
2840
2841            // in_proj backward: grad_norm = W^T grad_xz
2842            unsafe {
2843                kernel::gemv_t(
2844                    layer.in_proj_w.as_ptr(),
2845                    scratch.grad_xz.as_ptr(),
2846                    scratch.grad_norm.as_mut_ptr(),
2847                    2 * i,
2848                    c,
2849                );
2850            }
2851            if scope.mixer_proj {
2852                match optimizer {
2853                    OptimizerKind::Sgd => {
2854                        for row in 0..(2 * i) {
2855                            let g = scratch.grad_xz[row];
2856                            let off = row * c;
2857                            for col in 0..c {
2858                                layer.in_proj_w[off + col] += lr * g * tr.norm[col];
2859                            }
2860                        }
2861                        if let Some(b) = layer.in_proj_b.as_mut() {
2862                            for row in 0..(2 * i).min(b.len()) {
2863                                b[row] += lr * scratch.grad_xz[row];
2864                            }
2865                        }
2866                    }
2867                    OptimizerKind::Adam => {
2868                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2869                        let adam_layer =
2870                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2871                        apply_adam_outer_update(
2872                            layer.in_proj_w.as_mut_slice(),
2873                            2 * i,
2874                            c,
2875                            scratch.grad_xz.as_slice(),
2876                            tr.norm.as_slice(),
2877                            &mut adam_layer.in_proj_w,
2878                            cfg,
2879                        );
2880                        if let (Some(b), Some(bm)) =
2881                            (layer.in_proj_b.as_mut(), adam_layer.in_proj_b.as_mut())
2882                        {
2883                            apply_adam_vec_update(
2884                                b.as_mut_slice(),
2885                                scratch.grad_xz.as_slice(),
2886                                bm,
2887                                cfg,
2888                            );
2889                        }
2890                    }
2891                }
2892            }
2893
2894            // layer norm backward + residual combine.
2895            rms_norm_backward(
2896                tr.h_in.as_slice(),
2897                layer.norm_w.as_slice(),
2898                scratch.grad_norm.as_slice(),
2899                self.cfg.layer_norm_eps,
2900                scratch.grad_h.as_mut_slice(),
2901                scratch.grad_out.as_mut_slice(),
2902            );
2903            if scope.layer_norm {
2904                match optimizer {
2905                    OptimizerKind::Sgd => {
2906                        for idx in 0..c {
2907                            layer.norm_w[idx] += lr * scratch.grad_out[idx];
2908                        }
2909                        if let Some(b) = layer.norm_b.as_mut() {
2910                            for idx in 0..c.min(b.len()) {
2911                                b[idx] += lr * scratch.grad_norm[idx];
2912                            }
2913                        }
2914                    }
2915                    OptimizerKind::Adam => {
2916                        let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2917                        let adam_layer =
2918                            &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2919                        apply_adam_vec_update(
2920                            layer.norm_w.as_mut_slice(),
2921                            scratch.grad_out.as_slice(),
2922                            &mut adam_layer.norm_w,
2923                            cfg,
2924                        );
2925                        if let (Some(b), Some(bm)) =
2926                            (layer.norm_b.as_mut(), adam_layer.norm_b.as_mut())
2927                        {
2928                            apply_adam_vec_update(
2929                                b.as_mut_slice(),
2930                                scratch.grad_norm.as_slice(),
2931                                bm,
2932                                cfg,
2933                            );
2934                        }
2935                    }
2936                }
2937            }
2938
2939            for idx in 0..c {
2940                scratch.grad_h[idx] += scratch.grad_residual[idx];
2941            }
2942        }
2943
2944        if scope.embed {
2945            let tok = scratch
2946                .train_token
2947                .min(self.cfg.vocab_size.saturating_sub(1));
2948            let row_off = tok * c;
2949            match optimizer {
2950                OptimizerKind::Sgd => {
2951                    for j in 0..c {
2952                        let g = if clip > 0.0 {
2953                            scratch.grad_h[j].clamp(-clip, clip)
2954                        } else {
2955                            scratch.grad_h[j]
2956                        };
2957                        self.embeddings[row_off + j] += lr * g;
2958                    }
2959                }
2960                OptimizerKind::Adam => {
2961                    let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2962                    let adam = model_adam.as_mut().expect("adam state exists");
2963                    let pm = adam.embeddings.m.as_mut_slice();
2964                    let pv = adam.embeddings.v.as_mut_slice();
2965                    let emb = self.embeddings.as_mut_slice();
2966                    for j in 0..c {
2967                        let idx = row_off + j;
2968                        let g = scratch.grad_h[j];
2969                        pm[idx] = cfg.b1 * pm[idx] + (1.0 - cfg.b1) * g;
2970                        pv[idx] = cfg.b2 * pv[idx] + (1.0 - cfg.b2) * g * g;
2971                        let m_hat = pm[idx] / cfg.bias_corr1;
2972                        let v_hat = pv[idx] / cfg.bias_corr2;
2973                        emb[idx] += cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2974                    }
2975                }
2976            }
2977        }
2978
2979        Ok(())
2980    }
2981
2982    fn load_native(weights: &Weights) -> Result<Self> {
2983        let emb = weights.require("model.embeddings.weight")?;
2984        if emb.shape().len() != 2 {
2985            bail!("model.embeddings.weight must be rank-2");
2986        }
2987        let vocab_size = emb.shape()[0];
2988        let hidden_size = emb.shape()[1];
2989
2990        let num_layers = count_layers(weights, "model.layers.", "mixer.in_proj.weight")?;
2991        if num_layers == 0 {
2992            bail!("no Mamba layers found in native checkpoint");
2993        }
2994
2995        let first_in = weights.require("model.layers.0.mixer.in_proj.weight")?;
2996        let first_conv = weights.require("model.layers.0.mixer.conv1d.weight")?;
2997        let first_a = weights.require("model.layers.0.mixer.A_log")?;
2998        let first_dt = weights.require("model.layers.0.mixer.dt_proj.weight")?;
2999
3000        let inner_size =
3001            infer_in_proj_inner(first_in, hidden_size, "model.layers.0.mixer.in_proj.weight")?;
3002        let conv_kernel =
3003            infer_conv_kernel(first_conv, inner_size, "model.layers.0.mixer.conv1d.weight")?;
3004        let state_size = infer_state_size(first_a, inner_size, "model.layers.0.mixer.A_log")?;
3005        let dt_rank = infer_dt_rank(first_dt, inner_size, "model.layers.0.mixer.dt_proj.weight")?;
3006
3007        let cfg = Config {
3008            vocab_size,
3009            hidden_size,
3010            num_layers,
3011            inner_size,
3012            state_size,
3013            conv_kernel,
3014            dt_rank,
3015            layer_norm_eps: 1e-5,
3016        };
3017        cfg.validate()?;
3018
3019        let embeddings = tensor_from(emb)?;
3020        let final_norm_w = tensor_from(weights.require("model.norm.weight")?)?;
3021        let final_norm_b = optional_tensor_from(weights, "model.norm.bias")?;
3022        let lm_head = if let Some(t) = weights.get("lm_head.weight") {
3023            tensor_from(t)?
3024        } else {
3025            embeddings.clone()
3026        };
3027        let lm_head_b = optional_tensor_from(weights, "lm_head.bias")?;
3028
3029        let mut layers = Vec::with_capacity(num_layers);
3030        for idx in 0..num_layers {
3031            let root = format!("model.layers.{idx}");
3032            let mixer = format!("{root}.mixer");
3033
3034            let norm_w = tensor_from(weights.require(&format!("{root}.norm.weight"))?)?;
3035            let norm_b = optional_tensor_from(weights, &format!("{root}.norm.bias"))?;
3036
3037            let in_proj_w = tensor_from(weights.require(&format!("{mixer}.in_proj.weight"))?)?;
3038            let in_proj_b = optional_tensor_from(weights, &format!("{mixer}.in_proj.bias"))?;
3039
3040            let conv_w = tensor_from_conv(
3041                weights.require(&format!("{mixer}.conv1d.weight"))?,
3042                inner_size,
3043            )?;
3044            let conv_b = optional_tensor_from(weights, &format!("{mixer}.conv1d.bias"))?;
3045
3046            let x_proj_w = tensor_from(weights.require(&format!("{mixer}.x_proj.weight"))?)?;
3047            let x_proj_b = optional_tensor_from(weights, &format!("{mixer}.x_proj.bias"))?;
3048
3049            let dt_proj_w = tensor_from(weights.require(&format!("{mixer}.dt_proj.weight"))?)?;
3050            let dt_proj_b = tensor_from(weights.require(&format!("{mixer}.dt_proj.bias"))?)?;
3051
3052            let a_log = tensor_from(weights.require(&format!("{mixer}.A_log"))?)?;
3053            let a = a_from_a_log_tensor(&a_log);
3054            let d = tensor_from(weights.require(&format!("{mixer}.D"))?)?;
3055
3056            let out_proj_w = tensor_from(weights.require(&format!("{mixer}.out_proj.weight"))?)?;
3057            let out_proj_b = optional_tensor_from(weights, &format!("{mixer}.out_proj.bias"))?;
3058
3059            validate_layer_shapes(
3060                &cfg,
3061                idx,
3062                &norm_w,
3063                norm_b.as_ref(),
3064                &in_proj_w,
3065                in_proj_b.as_ref(),
3066                &conv_w,
3067                conv_b.as_ref(),
3068                &x_proj_w,
3069                x_proj_b.as_ref(),
3070                &dt_proj_w,
3071                &dt_proj_b,
3072                &a,
3073                &d,
3074                &out_proj_w,
3075                out_proj_b.as_ref(),
3076            )?;
3077
3078            layers.push(LayerWeights {
3079                norm_w,
3080                norm_b,
3081                in_proj_w,
3082                in_proj_b,
3083                conv_w,
3084                conv_b,
3085                x_proj_w,
3086                x_proj_b,
3087                dt_proj_w,
3088                dt_proj_b,
3089                a_log,
3090                a,
3091                d,
3092                out_proj_w,
3093                out_proj_b,
3094            });
3095        }
3096
3097        Ok(Self {
3098            cfg,
3099            embeddings,
3100            final_norm_w,
3101            final_norm_b,
3102            lm_head,
3103            lm_head_b,
3104            layers,
3105        })
3106    }
3107
    /// Build a model from an official `state_spaces/mamba`-style checkpoint.
    ///
    /// Every dimension (vocab, hidden, layer count, d_inner, d_state, d_conv,
    /// dt_rank) is inferred from tensor shapes rather than read from a config
    /// file; `layer_norm_eps` is fixed at 1e-5. If `lm_head.weight` is absent
    /// the embedding matrix is reused as the output projection (weight tying).
    fn load_official(weights: &Weights) -> Result<Self> {
        let emb = weights.require("backbone.embedding.weight")?;
        if emb.shape().len() != 2 {
            bail!("backbone.embedding.weight must be rank-2");
        }
        // Embedding is [vocab, hidden]; these two dims anchor all later
        // shape inference.
        let vocab_size = emb.shape()[0];
        let hidden_size = emb.shape()[1];

        let num_layers = count_layers(weights, "backbone.layers.", "mixer.in_proj.weight")?;
        if num_layers == 0 {
            bail!("no Mamba layers found in official checkpoint");
        }

        // Layer 0's tensors determine the remaining dims;
        // validate_layer_shapes below enforces that every layer agrees.
        let first_in = weights.require("backbone.layers.0.mixer.in_proj.weight")?;
        let first_conv = weights.require("backbone.layers.0.mixer.conv1d.weight")?;
        let first_a = weights.require("backbone.layers.0.mixer.A_log")?;
        let first_dt = weights.require("backbone.layers.0.mixer.dt_proj.weight")?;

        let inner_size = infer_in_proj_inner(
            first_in,
            hidden_size,
            "backbone.layers.0.mixer.in_proj.weight",
        )?;
        let conv_kernel = infer_conv_kernel(
            first_conv,
            inner_size,
            "backbone.layers.0.mixer.conv1d.weight",
        )?;
        let state_size = infer_state_size(first_a, inner_size, "backbone.layers.0.mixer.A_log")?;
        let dt_rank = infer_dt_rank(
            first_dt,
            inner_size,
            "backbone.layers.0.mixer.dt_proj.weight",
        )?;

        let cfg = Config {
            vocab_size,
            hidden_size,
            num_layers,
            inner_size,
            state_size,
            conv_kernel,
            dt_rank,
            layer_norm_eps: 1e-5,
        };
        cfg.validate()?;

        let embeddings = tensor_from(emb)?;
        let final_norm_w = tensor_from(weights.require("norm_f.weight")?)?;
        let final_norm_b = optional_tensor_from(weights, "norm_f.bias")?;
        // Tied output head: fall back to a copy of the embedding matrix.
        let lm_head = if let Some(t) = weights.get("lm_head.weight") {
            tensor_from(t)?
        } else {
            embeddings.clone()
        };
        let lm_head_b = optional_tensor_from(weights, "lm_head.bias")?;

        let mut layers = Vec::with_capacity(num_layers);
        for idx in 0..num_layers {
            let root = format!("backbone.layers.{idx}");
            let mixer = format!("{root}.mixer");

            let norm_w = tensor_from(weights.require(&format!("{root}.norm.weight"))?)?;
            let norm_b = optional_tensor_from(weights, &format!("{root}.norm.bias"))?;

            let in_proj_w = tensor_from(weights.require(&format!("{mixer}.in_proj.weight"))?)?;
            let in_proj_b = optional_tensor_from(weights, &format!("{mixer}.in_proj.bias"))?;

            // conv1d.weight may be [d, k] or the official [d, 1, k] layout;
            // tensor_from_conv normalizes both to a flat [d * k] buffer.
            let conv_w = tensor_from_conv(
                weights.require(&format!("{mixer}.conv1d.weight"))?,
                inner_size,
            )?;
            let conv_b = optional_tensor_from(weights, &format!("{mixer}.conv1d.bias"))?;

            let x_proj_w = tensor_from(weights.require(&format!("{mixer}.x_proj.weight"))?)?;
            let x_proj_b = optional_tensor_from(weights, &format!("{mixer}.x_proj.bias"))?;

            let dt_proj_w = tensor_from(weights.require(&format!("{mixer}.dt_proj.weight"))?)?;
            let dt_proj_b = tensor_from(weights.require(&format!("{mixer}.dt_proj.bias"))?)?;

            // A is stored in log space; materialize A = -exp(A_log) once at
            // load time so the scan never re-exponentiates.
            let a_log = tensor_from(weights.require(&format!("{mixer}.A_log"))?)?;
            let a = a_from_a_log_tensor(&a_log);
            let d = tensor_from(weights.require(&format!("{mixer}.D"))?)?;

            let out_proj_w = tensor_from(weights.require(&format!("{mixer}.out_proj.weight"))?)?;
            let out_proj_b = optional_tensor_from(weights, &format!("{mixer}.out_proj.bias"))?;

            validate_layer_shapes(
                &cfg,
                idx,
                &norm_w,
                norm_b.as_ref(),
                &in_proj_w,
                in_proj_b.as_ref(),
                &conv_w,
                conv_b.as_ref(),
                &x_proj_w,
                x_proj_b.as_ref(),
                &dt_proj_w,
                &dt_proj_b,
                &a,
                &d,
                &out_proj_w,
                out_proj_b.as_ref(),
            )?;

            layers.push(LayerWeights {
                norm_w,
                norm_b,
                in_proj_w,
                in_proj_b,
                conv_w,
                conv_b,
                x_proj_w,
                x_proj_b,
                dt_proj_w,
                dt_proj_b,
                a_log,
                a,
                d,
                out_proj_w,
                out_proj_b,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            final_norm_w,
            final_norm_b,
            lm_head,
            lm_head_b,
            layers,
        })
    }
3243}
3244
3245fn tensor_from(t: &WeightTensor) -> Result<Tensor1D> {
3246    Ok(Tensor1D::from_vec(t.data().to_vec()))
3247}
3248
3249fn a_from_a_log_tensor(a_log: &Tensor1D) -> Tensor1D {
3250    let mut out = a_log.as_slice().to_vec();
3251    for v in &mut out {
3252        *v = -v.exp();
3253    }
3254    Tensor1D::from_vec(out)
3255}
3256
3257fn optional_tensor_from(weights: &Weights, name: &str) -> Result<Option<Tensor1D>> {
3258    match weights.get(name) {
3259        Some(t) => Ok(Some(tensor_from(t)?)),
3260        None => Ok(None),
3261    }
3262}
3263
3264fn tensor_from_conv(t: &WeightTensor, inner_size: usize) -> Result<Tensor1D> {
3265    match t.shape() {
3266        [i, _k] if *i == inner_size => Ok(Tensor1D::from_vec(t.data().to_vec())),
3267        [i, one, k] if *i == inner_size && *one == 1 => {
3268            let mut out = Vec::with_capacity(inner_size * k);
3269            let src = t.data();
3270            for ch in 0..inner_size {
3271                let off = ch * k;
3272                out.extend_from_slice(&src[off..off + k]);
3273            }
3274            Ok(Tensor1D::from_vec(out))
3275        }
3276        other => bail!("unexpected conv1d weight shape {:?}", other),
3277    }
3278}
3279
3280fn count_layers(weights: &Weights, prefix: &str, suffix: &str) -> Result<usize> {
3281    let mut max_layer = None::<usize>;
3282    for name in weights.tensor_names() {
3283        let Some(rest) = name.strip_prefix(prefix) else {
3284            continue;
3285        };
3286        let Some((idx_s, tail)) = rest.split_once('.') else {
3287            continue;
3288        };
3289        if tail != suffix {
3290            continue;
3291        }
3292        let idx = idx_s
3293            .parse::<usize>()
3294            .with_context(|| format!("invalid layer index in tensor name '{name}'"))?;
3295        max_layer = Some(max_layer.map_or(idx, |m| m.max(idx)));
3296    }
3297    Ok(max_layer.map_or(0, |m| m + 1))
3298}
3299
3300fn infer_in_proj_inner(t: &WeightTensor, hidden: usize, name: &str) -> Result<usize> {
3301    let shape = t.shape();
3302    if shape.len() != 2 {
3303        bail!("{name} must be rank-2, got {:?}", shape);
3304    }
3305    if shape[1] != hidden {
3306        bail!("{name} expected cols={}, got {}", hidden, shape[1]);
3307    }
3308    if !shape[0].is_multiple_of(2) {
3309        bail!("{name} first dim {} must be 2*d_inner", shape[0]);
3310    }
3311    Ok(shape[0] / 2)
3312}
3313
3314fn infer_conv_kernel(t: &WeightTensor, inner: usize, name: &str) -> Result<usize> {
3315    let shape = t.shape();
3316    match shape {
3317        [i, k] if *i == inner => Ok(*k),
3318        [i, one, k] if *i == inner && *one == 1 => Ok(*k),
3319        _ => bail!("{name} shape {:?} incompatible with d_inner={inner}", shape),
3320    }
3321}
3322
3323fn infer_state_size(t: &WeightTensor, inner: usize, name: &str) -> Result<usize> {
3324    let shape = t.shape();
3325    if shape.len() != 2 {
3326        bail!("{name} must be rank-2, got {:?}", shape);
3327    }
3328    if shape[0] != inner {
3329        bail!("{name} expected rows={}, got {}", inner, shape[0]);
3330    }
3331    Ok(shape[1])
3332}
3333
3334fn infer_dt_rank(t: &WeightTensor, inner: usize, name: &str) -> Result<usize> {
3335    let shape = t.shape();
3336    if shape.len() != 2 {
3337        bail!("{name} must be rank-2, got {:?}", shape);
3338    }
3339    if shape[0] != inner {
3340        bail!("{name} expected rows={}, got {}", inner, shape[0]);
3341    }
3342    Ok(shape[1])
3343}
3344
3345#[allow(clippy::too_many_arguments)]
3346fn validate_layer_shapes(
3347    cfg: &Config,
3348    idx: usize,
3349    norm_w: &Tensor1D,
3350    norm_b: Option<&Tensor1D>,
3351    in_proj_w: &Tensor1D,
3352    in_proj_b: Option<&Tensor1D>,
3353    conv_w: &Tensor1D,
3354    conv_b: Option<&Tensor1D>,
3355    x_proj_w: &Tensor1D,
3356    x_proj_b: Option<&Tensor1D>,
3357    dt_proj_w: &Tensor1D,
3358    dt_proj_b: &Tensor1D,
3359    a: &Tensor1D,
3360    d: &Tensor1D,
3361    out_proj_w: &Tensor1D,
3362    out_proj_b: Option<&Tensor1D>,
3363) -> Result<()> {
3364    let c = cfg.hidden_size;
3365    let i = cfg.inner_size;
3366    let s = cfg.state_size;
3367    let k = cfg.conv_kernel;
3368    let r = cfg.dt_rank;
3369
3370    let check = |cond: bool, msg: String| -> Result<()> {
3371        if cond {
3372            Ok(())
3373        } else {
3374            bail!("layer {idx}: {msg}")
3375        }
3376    };
3377
3378    check(
3379        norm_w.len() == c,
3380        format!("norm.weight len {} != hidden {c}", norm_w.len()),
3381    )?;
3382    if let Some(b) = norm_b {
3383        check(
3384            b.len() == c,
3385            format!("norm.bias len {} != hidden {c}", b.len()),
3386        )?;
3387    }
3388
3389    check(
3390        in_proj_w.len() == (2 * i) * c,
3391        format!("in_proj.weight len {} != {}", in_proj_w.len(), (2 * i) * c),
3392    )?;
3393    if let Some(b) = in_proj_b {
3394        check(
3395            b.len() == 2 * i,
3396            format!("in_proj.bias len {} != {}", b.len(), 2 * i),
3397        )?;
3398    }
3399
3400    check(
3401        conv_w.len() == i * k,
3402        format!("conv1d.weight len {} != {}", conv_w.len(), i * k),
3403    )?;
3404    if let Some(b) = conv_b {
3405        check(b.len() == i, format!("conv1d.bias len {} != {i}", b.len()))?;
3406    }
3407
3408    check(
3409        x_proj_w.len() == (r + 2 * s) * i,
3410        format!(
3411            "x_proj.weight len {} != {}",
3412            x_proj_w.len(),
3413            (r + 2 * s) * i
3414        ),
3415    )?;
3416    if let Some(b) = x_proj_b {
3417        check(
3418            b.len() == r + 2 * s,
3419            format!("x_proj.bias len {} != {}", b.len(), r + 2 * s),
3420        )?;
3421    }
3422
3423    check(
3424        dt_proj_w.len() == i * r,
3425        format!("dt_proj.weight len {} != {}", dt_proj_w.len(), i * r),
3426    )?;
3427    check(
3428        dt_proj_b.len() == i,
3429        format!("dt_proj.bias len {} != {i}", dt_proj_b.len()),
3430    )?;
3431
3432    check(a.len() == i * s, format!("A len {} != {}", a.len(), i * s))?;
3433    check(d.len() == i, format!("D len {} != {i}", d.len()))?;
3434
3435    check(
3436        out_proj_w.len() == c * i,
3437        format!("out_proj.weight len {} != {}", out_proj_w.len(), c * i),
3438    )?;
3439    if let Some(b) = out_proj_b {
3440        check(
3441            b.len() == c,
3442            format!("out_proj.bias len {} != {c}", b.len()),
3443        )?;
3444    }
3445
3446    Ok(())
3447}
3448
/// RMSNorm: `out = input * rsqrt(mean(input^2) + eps) * weight (+ bias)`.
fn rms_norm(input: &[f32], weight: &[f32], bias: Option<&[f32]>, eps: f32, out: &mut [f32]) {
    debug_assert_eq!(input.len(), weight.len());
    debug_assert_eq!(input.len(), out.len());
    if let Some(b) = bias {
        debug_assert_eq!(b.len(), input.len());
    }

    // Mean of squares; `max(1)` guards the empty-slice division.
    let mut sum_sq = 0.0f32;
    for &x in input {
        sum_sq += x * x;
    }
    let inv = (sum_sq / input.len().max(1) as f32 + eps).sqrt().recip();

    match bias {
        Some(b) => {
            for (idx, &x) in input.iter().enumerate() {
                out[idx] = x * inv * weight[idx] + b[idx];
            }
        }
        None => {
            for (idx, &x) in input.iter().enumerate() {
                out[idx] = x * inv * weight[idx];
            }
        }
    }
}
3473
/// Backward pass of the bias-free RMSNorm path: overwrites `grad_input`
/// and `grad_weight` (no accumulation).
fn rms_norm_backward(
    input: &[f32],
    weight: &[f32],
    grad_out: &[f32],
    eps: f32,
    grad_input: &mut [f32],
    grad_weight: &mut [f32],
) {
    debug_assert_eq!(input.len(), weight.len());
    debug_assert_eq!(input.len(), grad_out.len());
    debug_assert_eq!(input.len(), grad_input.len());
    debug_assert_eq!(input.len(), grad_weight.len());

    let n = input.len().max(1) as f32;
    let mut sum_sq = 0.0f32;
    for &x in input {
        sum_sq += x * x;
    }
    let inv = (sum_sq / n + eps).sqrt().recip();

    // s = sum_i go_i * w_i * x_i — projection of the upstream gradient
    // onto the input, needed for the d(inv)/d(input) term.
    let mut s = 0.0f32;
    for (idx, ((&x, &w), &go)) in input.iter().zip(weight).zip(grad_out).enumerate() {
        grad_weight[idx] = go * x * inv;
        s += go * w * x;
    }
    let coeff = -s * inv * inv * inv / n;
    for (idx, ((&x, &w), &go)) in input.iter().zip(weight).zip(grad_out).enumerate() {
        grad_input[idx] = go * w * inv + x * coeff;
    }
}
3506
/// Elementwise accumulate `src` into `dst`, truncating to the shorter slice.
#[inline(always)]
fn add_vec_grad(dst: &mut [f32], src: &[f32]) {
    // `zip` stops at the shorter of the two slices.
    for (d, &s) in dst.iter_mut().zip(src) {
        *d += s;
    }
}
3514
/// SGD-style step: `param += lr * grad`, optionally with symmetric
/// per-element clipping when `clip > 0`. (The caller supplies gradients
/// whose sign already matches the desired update direction.)
#[inline(always)]
fn sgd_vec_update(param: &mut [f32], grad: &[f32], lr: f32, clip: f32) {
    if clip > 0.0 {
        for (p, g) in param.iter_mut().zip(grad) {
            *p += lr * g.clamp(-clip, clip);
        }
    } else {
        for (p, &g) in param.iter_mut().zip(grad) {
            *p += lr * g;
        }
    }
}
3528
/// Accumulate the outer product `left ⊗ right` into a row-major matrix:
/// `dst[row * cols + col] += left[row] * right[col]`, clamped so no slice
/// is ever indexed out of bounds.
#[inline(always)]
fn add_outer_grad(dst: &mut [f32], rows: usize, cols: usize, left: &[f32], right: &[f32]) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let total = dst.len();
    for (row, &l) in left.iter().take(rows).enumerate() {
        let off = row * cols;
        if off >= total {
            break;
        }
        // Last row may be partial when `dst` is shorter than rows * cols.
        let span = (total - off).min(cols);
        for (slot, &r) in dst[off..off + span].iter_mut().zip(right) {
            *slot += l * r;
        }
    }
}
3547
3548#[inline(always)]
3549fn apply_adam_vec_update_raw(
3550    param: &mut [f32],
3551    grad: &[f32],
3552    moment_m: &mut [f32],
3553    moment_v: &mut [f32],
3554    step: &AdamStep,
3555) {
3556    let n = param
3557        .len()
3558        .min(grad.len())
3559        .min(moment_m.len())
3560        .min(moment_v.len());
3561    if n == 0 {
3562        return;
3563    }
3564    let b1 = step.b1;
3565    let b2 = step.b2;
3566    let one_m_b1 = 1.0 - b1;
3567    let one_m_b2 = 1.0 - b2;
3568    let lr = step.lr;
3569    let eps = step.eps;
3570    let inv_bc1 = 1.0 / step.bias_corr1;
3571    let inv_bc2 = 1.0 / step.bias_corr2;
3572    if step.clip > 0.0 {
3573        let clip = step.clip;
3574        for idx in 0..n {
3575            let g = grad[idx].clamp(-clip, clip);
3576            let m = b1 * moment_m[idx] + one_m_b1 * g;
3577            let v = b2 * moment_v[idx] + one_m_b2 * g * g;
3578            moment_m[idx] = m;
3579            moment_v[idx] = v;
3580            let m_hat = m * inv_bc1;
3581            let v_hat = v * inv_bc2;
3582            param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
3583        }
3584    } else {
3585        for idx in 0..n {
3586            let g = grad[idx];
3587            let m = b1 * moment_m[idx] + one_m_b1 * g;
3588            let v = b2 * moment_v[idx] + one_m_b2 * g * g;
3589            moment_m[idx] = m;
3590            moment_v[idx] = v;
3591            let m_hat = m * inv_bc1;
3592            let v_hat = v * inv_bc2;
3593            param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
3594        }
3595    }
3596}
3597
/// Adam step over `param` using the tensor's persistent moment state,
/// SIMD-accelerated with `wide::f32x8`.
///
/// The clipped path stays scalar (per-element clamp). The unclipped path
/// processes 8 lanes at a time with a scalar tail for the `n % 8` leftover
/// elements. NOTE(review): the SIMD body applies `lr` after the division
/// (`(m_hat / denom) * lr`) while both scalar paths compute
/// `lr * m_hat / denom`; the results are mathematically equal but may
/// differ in the last ulp between vector and tail elements — confirm this
/// is acceptable for bit-exact reproducibility requirements.
#[inline(always)]
fn apply_adam_vec_update(
    param: &mut [f32],
    grad: &[f32],
    adam: &mut AdamTensorState,
    step: &AdamStep,
) {
    // Clamp the element count to the shortest slice so every index below
    // is in bounds for param, grad, and both moment buffers.
    let n = param
        .len()
        .min(grad.len())
        .min(adam.m.len())
        .min(adam.v.len());
    if n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_m_b1 = 1.0 - b1;
    let one_m_b2 = 1.0 - b2;
    let lr = step.lr;
    let eps = step.eps;
    // Precompute inverse bias corrections: multiply instead of divide.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let m = adam.m.as_mut_slice();
    let v = adam.v.as_mut_slice();
    if do_clip {
        // Scalar clipped path: clamp each gradient element first.
        for idx in 0..n {
            let g = grad[idx].clamp(-clip, clip);
            let mm = b1 * m[idx] + one_m_b1 * g;
            let vv = b2 * v[idx] + one_m_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
        }
    } else {
        let mut idx = 0usize;
        // SAFETY: the loop condition guarantees idx + 8 <= n, and n was
        // clamped above to the length of grad/m/v/param, so every
        // unaligned 8-lane read and write stays in bounds.
        unsafe {
            let b1v = f32x8::splat(b1);
            let b2v = f32x8::splat(b2);
            let one_b1v = f32x8::splat(one_m_b1);
            let one_b2v = f32x8::splat(one_m_b2);
            let inv_bc1v = f32x8::splat(inv_bc1);
            let inv_bc2v = f32x8::splat(inv_bc2);
            let lrv = f32x8::splat(lr);
            let epsv = f32x8::splat(eps);
            while idx + 8 <= n {
                let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mm = mv * b1v + gv * one_b1v;
                let vv2 = vv * b2v + (gv * gv) * one_b2v;
                m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);

                let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(pv + upd);
                idx += 8;
            }
        }
        // Scalar tail for the n % 8 leftover elements.
        while idx < n {
            let g = grad[idx];
            let mm = b1 * m[idx] + one_m_b1 * g;
            let vv = b2 * v[idx] + one_m_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
            idx += 1;
        }
    }
}
3679
/// Adam-update the log-parameterized tensor `param_log` and keep the
/// materialized value `param_value[i] = -exp(param_log[i])` in sync
/// (used to maintain `A = -exp(A_log)` after each step).
///
/// Clipped path is scalar; unclipped path is 8-lane SIMD with a scalar
/// tail. NOTE(review): as in `apply_adam_vec_update`, the SIMD body
/// applies `lr` after the division while the scalar paths apply it
/// before — mathematically equal, potentially last-ulp different.
#[inline(always)]
fn apply_adam_vec_update_and_sync_neg_exp(
    param_log: &mut [f32],
    param_value: &mut [f32],
    grad: &[f32],
    adam: &mut AdamTensorState,
    step: &AdamStep,
) {
    // Clamp to the shortest slice so all indexing below is in bounds.
    let n = param_log
        .len()
        .min(param_value.len())
        .min(grad.len())
        .min(adam.m.len())
        .min(adam.v.len());
    if n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_m_b1 = 1.0 - b1;
    let one_m_b2 = 1.0 - b2;
    let lr = step.lr;
    let eps = step.eps;
    // Precomputed inverse bias corrections.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let m = adam.m.as_mut_slice();
    let v = adam.v.as_mut_slice();
    if do_clip {
        for idx in 0..n {
            let g = grad[idx].clamp(-clip, clip);
            let mm = b1 * m[idx] + one_m_b1 * g;
            let vv = b2 * v[idx] + one_m_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            let new_log = param_log[idx] + lr * m_hat / (v_hat.sqrt() + eps);
            param_log[idx] = new_log;
            // Resynchronize the materialized value from the updated log.
            param_value[idx] = -new_log.exp();
        }
        return;
    }

    let mut idx = 0usize;
    // SAFETY: idx + 8 <= n and n is clamped to every slice's length, so
    // all unaligned 8-lane reads/writes (and the per-lane param_value
    // stores at idx + lane) stay in bounds.
    unsafe {
        let b1v = f32x8::splat(b1);
        let b2v = f32x8::splat(b2);
        let one_b1v = f32x8::splat(one_m_b1);
        let one_b2v = f32x8::splat(one_m_b2);
        let inv_bc1v = f32x8::splat(inv_bc1);
        let inv_bc2v = f32x8::splat(inv_bc2);
        let lrv = f32x8::splat(lr);
        let epsv = f32x8::splat(eps);
        while idx + 8 <= n {
            let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mm = mv * b1v + gv * one_b1v;
            let vv2 = vv * b2v + (gv * gv) * one_b2v;
            m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
            v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);

            let pv = param_log.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let new_log = pv + ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
            param_log
                .as_mut_ptr()
                .add(idx)
                .cast::<f32x8>()
                .write_unaligned(new_log);
            // exp has no f32x8 fast path here; fall back to scalar lanes.
            let lanes = new_log.to_array();
            for (lane, value) in lanes.iter().enumerate() {
                param_value[idx + lane] = -value.exp();
            }
            idx += 8;
        }
    }
    // Scalar tail for the n % 8 leftover elements.
    while idx < n {
        let g = grad[idx];
        let mm = b1 * m[idx] + one_m_b1 * g;
        let vv = b2 * v[idx] + one_m_b2 * g * g;
        m[idx] = mm;
        v[idx] = vv;
        let m_hat = mm * inv_bc1;
        let v_hat = vv * inv_bc2;
        let new_log = param_log[idx] + lr * m_hat / (v_hat.sqrt() + eps);
        param_log[idx] = new_log;
        param_value[idx] = -new_log.exp();
        idx += 1;
    }
}
3772
/// Fused Adam update for a rank-1 (outer-product) gradient.
///
/// The gradient `grad[row, col] = left[row] * right[col]` is formed on the
/// fly per element instead of materializing the full matrix. Row/column
/// counts are clamped to the shortest participating slice, and a final
/// partial row is handled when `param` is shorter than `rows * cols`.
/// Clipped rows run scalar; unclipped rows run 8 lanes at a time with a
/// scalar tail. NOTE(review): like the other SIMD Adam kernels, the
/// vector body applies `lr` after the division while the scalar paths
/// apply it before — mathematically equal, potentially last-ulp different.
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn apply_adam_outer_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    adam: &mut AdamTensorState,
    step: &AdamStep,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let n = param.len().min(adam.m.len()).min(adam.v.len());
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_m_b1 = 1.0 - b1;
    let one_m_b2 = 1.0 - b2;
    let lr = step.lr;
    let eps = step.eps;
    // Precomputed inverse bias corrections; splatted once for the SIMD path.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let m = adam.m.as_mut_slice();
    let v = adam.v.as_mut_slice();
    let b1v = f32x8::splat(b1);
    let b2v = f32x8::splat(b2);
    let one_b1v = f32x8::splat(one_m_b1);
    let one_b2v = f32x8::splat(one_m_b2);
    let inv_bc1v = f32x8::splat(inv_bc1);
    let inv_bc2v = f32x8::splat(inv_bc2);
    let epsv = f32x8::splat(eps);
    let lrv = f32x8::splat(lr);
    for row in 0..rows {
        let g_row = left[row];
        let off = row * cols;
        if off >= n {
            break;
        }
        // Last row may be partial when param is shorter than rows * cols.
        let row_cols = (n - off).min(cols);
        if do_clip {
            for col in 0..row_cols {
                let idx = off + col;
                let g = (g_row * right[col]).clamp(-clip, clip);
                let mm = b1 * m[idx] + one_m_b1 * g;
                let vv = b2 * v[idx] + one_m_b2 * g * g;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
            }
        } else {
            let mut col = 0usize;
            // SAFETY: col + 8 <= row_cols, so idx + 8 = off + col + 8 <= n,
            // and n is clamped to param/m/v lengths; right reads stay below
            // row_cols <= cols <= right.len(). All unaligned accesses are
            // therefore in bounds.
            unsafe {
                let g8 = f32x8::splat(g_row);
                while col + 8 <= row_cols {
                    let idx = off + col;
                    let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                    let gv = g8 * rv;

                    let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                    let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                    let mm = mv * b1v + gv * one_b1v;
                    let vv2 = vv * b2v + (gv * gv) * one_b2v;
                    m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                    v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);

                    let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                    let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                    param
                        .as_mut_ptr()
                        .add(idx)
                        .cast::<f32x8>()
                        .write_unaligned(pv + upd);
                    col += 8;
                }
            }
            // Scalar tail for the row's leftover columns.
            while col < row_cols {
                let idx = off + col;
                let g = g_row * right[col];
                let mm = b1 * m[idx] + one_m_b1 * g;
                let vv = b2 * v[idx] + one_m_b2 * g * g;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
                col += 1;
            }
        }
    }
}
3870
/// Single-token depthwise conv over the per-layer ring buffer.
///
/// Writes the newest sample into each channel's ring at `state.conv_pos`,
/// then accumulates `conv_kernel` taps walking the ring backwards in time.
/// NOTE(review): tap 0 of each channel's weights multiplies the *newest*
/// sample and tap `k-1` the oldest — the reverse of PyTorch `conv1d`
/// cross-correlation order, where the newest input pairs with the last
/// weight. The k=4 fast path uses the same convention, so the runtime is
/// self-consistent; confirm the intended weight orientation when loading
/// official checkpoints.
fn depthwise_conv_step(
    x: &[f32],
    conv_w: &Tensor1D,
    conv_b: Option<&Tensor1D>,
    conv_kernel: usize,
    state: &mut LayerState,
    out: &mut [f32],
) {
    // Fast path for the default Mamba kernel width.
    if conv_kernel == 4 {
        depthwise_conv_step_k4(x, conv_w, conv_b, state, out);
        return;
    }
    let inner = x.len();
    debug_assert_eq!(out.len(), inner);
    debug_assert_eq!(conv_w.len(), inner * conv_kernel);

    let pos = state.conv_pos;
    let conv_state = state.conv.as_mut_slice();
    let weight = conv_w.as_slice();

    for ch in 0..inner {
        // Each channel owns a conv_kernel-wide slot in the ring buffer.
        let base = ch * conv_kernel;
        conv_state[base + pos] = x[ch];

        let mut acc = conv_b.as_ref().map_or(0.0, |b| b[ch]);
        // Walk the ring from the newest sample (pos) backwards in time,
        // pairing tap t with the sample written t steps ago.
        let mut ring_idx = pos;
        for tap in 0..conv_kernel {
            acc += conv_state[base + ring_idx] * weight[base + tap];
            ring_idx = if ring_idx == 0 {
                conv_kernel - 1
            } else {
                ring_idx - 1
            };
        }
        out[ch] = acc;
    }

    // Advance the shared write cursor (same position for all channels).
    state.conv_pos = if pos + 1 == conv_kernel { 0 } else { pos + 1 };
}
3910
/// `conv_kernel == 4` specialization of the depthwise conv step: the ring
/// walk is fully unrolled into one arm per write position. NOTE(review):
/// matches the generic path's tap order — the newest sample multiplies
/// weight[0] and the oldest weight[3], which is the reverse of PyTorch
/// `conv1d` cross-correlation order; confirm the intended weight
/// orientation when loading official checkpoints.
#[inline(always)]
fn depthwise_conv_step_k4(
    x: &[f32],
    conv_w: &Tensor1D,
    conv_b: Option<&Tensor1D>,
    state: &mut LayerState,
    out: &mut [f32],
) {
    let inner = x.len();
    debug_assert_eq!(out.len(), inner);
    debug_assert_eq!(conv_w.len(), inner * 4);

    let pos = state.conv_pos;
    let conv_state = state.conv.as_mut_slice();
    let weight = conv_w.as_slice();

    for ch in 0..inner {
        let base = ch * 4;
        // Overwrite the oldest slot with the newest sample, then unroll the
        // ring walk: each arm pairs newest→oldest with weight[0..4].
        conv_state[base + pos] = x[ch];
        let acc = match pos {
            0 => {
                conv_state[base] * weight[base]
                    + conv_state[base + 3] * weight[base + 1]
                    + conv_state[base + 2] * weight[base + 2]
                    + conv_state[base + 1] * weight[base + 3]
            }
            1 => {
                conv_state[base + 1] * weight[base]
                    + conv_state[base] * weight[base + 1]
                    + conv_state[base + 3] * weight[base + 2]
                    + conv_state[base + 2] * weight[base + 3]
            }
            2 => {
                conv_state[base + 2] * weight[base]
                    + conv_state[base + 1] * weight[base + 1]
                    + conv_state[base] * weight[base + 2]
                    + conv_state[base + 3] * weight[base + 3]
            }
            _ => {
                conv_state[base + 3] * weight[base]
                    + conv_state[base + 2] * weight[base + 1]
                    + conv_state[base + 1] * weight[base + 2]
                    + conv_state[base] * weight[base + 3]
            }
        };
        out[ch] = acc + conv_b.as_ref().map_or(0.0, |b| b[ch]);
    }

    // Advance the write cursor modulo 4.
    state.conv_pos = (pos + 1) & 3;
}
3961
/// One selective-scan update over a 16-wide state row:
/// `state[j] = state[j] * exp(dt * A[j]) + x_dt * B[j]`, returning
/// `sum_j state[j] * C[j]`. When `CAPTURE` is set, each per-element decay
/// `exp(dt * A[j])` is also written to `trace_d_a` for the backward pass.
///
/// # Safety
/// Every pointer must be valid for 16 consecutive `f32` reads/writes
/// (`trace_d_a` only when `CAPTURE` is true).
#[inline(always)]
unsafe fn selective_scan_state16<const CAPTURE: bool>(
    row_a: *const f32,
    row_ssm: *mut f32,
    dt: f32,
    x_dt: f32,
    b_ptr: *const f32,
    c_ptr: *const f32,
    trace_d_a: *mut f32,
) -> f32 {
    let mut acc = 0.0f32;
    for j in 0..16 {
        let decay = (dt * *row_a.add(j)).exp();
        if CAPTURE {
            *trace_d_a.add(j) = decay;
        }
        let updated = *row_ssm.add(j) * decay + x_dt * *b_ptr.add(j);
        *row_ssm.add(j) = updated;
        acc += updated * *c_ptr.add(j);
    }
    acc
}
3987
/// SiLU (swish) activation: `x * sigmoid(x)` computed as `x / (1 + e^{-x})`.
#[inline(always)]
fn silu(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    x / denom
}
3992
/// Logistic sigmoid: `1 / (1 + e^{-x})`.
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
3997
/// SiLU plus its sigmoid in one shot, sharing the `1 + e^{-x}` denominator
/// (the sigmoid is reused by the backward pass).
#[inline(always)]
fn silu_with_sigmoid(x: f32) -> (f32, f32) {
    let denom = 1.0 + (-x).exp();
    let s = 1.0 / denom;
    (x / denom, s)
}
4003
/// Derivative of SiLU given a precomputed sigmoid `s = sigmoid(x)`:
/// `d/dx silu(x) = s * (1 + x * (1 - s))`.
#[inline(always)]
fn silu_grad_from_sigmoid(x: f32, s: f32) -> f32 {
    let inner = 1.0 + x * (1.0 - s);
    s * inner
}
4008
/// Softplus `ln(1 + e^x)`; above 20 the linear term dominates and `exp`
/// would overflow, so the identity `softplus(x) ≈ x` is used instead.
#[inline(always)]
fn softplus(x: f32) -> f32 {
    if x <= 20.0 { (1.0 + x.exp()).ln() } else { x }
}
4013
/// Minimal deterministic 64-bit LCG used for weight initialization.
struct MambaRng {
    state: u64,
}

impl MambaRng {
    /// Seed the generator; XOR with a golden-ratio constant decorrelates
    /// nearby seeds.
    fn new(seed: u64) -> Self {
        let state = seed ^ 0x9E37_79B9_7F4A_7C15;
        Self { state }
    }

    /// Advance the LCG and return the high 32 bits (better mixed than the
    /// low bits of an LCG).
    #[inline]
    fn next_u32(&mut self) -> u32 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(1);
        (self.state >> 32) as u32
    }

    /// Sample in `[0, 1]`, derived by scaling `next_u32`.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        self.next_u32() as f32 * (1.0 / (u32::MAX as f32))
    }
}
4040
4041#[inline]
4042fn init_uniform(t: &mut Tensor1D, rng: &mut MambaRng, scale: f32) {
4043    for v in t.as_mut_slice() {
4044        let r = rng.next_f32() - 0.5;
4045        *v = r * 2.0 * scale;
4046    }
4047}
4048
4049#[inline]
4050fn init_const(t: &mut Tensor1D, value: f32) {
4051    t.as_mut_slice().fill(value);
4052}
4053
#[cfg(test)]
mod tests {
    use super::*;

    /// Log-probability of `target` under the model's softmax after a single
    /// forward step from a fresh state. The log-sum-exp is accumulated in
    /// f64 with a max-logit shift for numerical stability, and the
    /// probability is clamped at 1e-30 before taking the log.
    fn target_log_prob(model: &Model, token: u32, target: u8) -> f32 {
        let mut state = model.new_state();
        let mut scratch = ScratchBuffers::new(model.config());
        let logits = model.forward(&mut scratch, token, &mut state);
        let mut max_logit = f32::NEG_INFINITY;
        for &logit in logits {
            max_logit = max_logit.max(logit);
        }
        let mut denom = 0.0f64;
        for &logit in logits {
            denom += ((logit - max_logit) as f64).exp();
        }
        let p = ((logits[target as usize] - max_logit) as f64).exp() / denom;
        p.max(1e-30).ln() as f32
    }

    /// Two independent states fed the same token sequence must produce
    /// bit-identical logits at every step (the runtime promises exact
    /// determinism, so equality is checked on `to_bits`, not with a
    /// tolerance).
    #[test]
    fn forward_is_deterministic_for_same_input_and_state() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            inner_size: 96,
            state_size: 8,
            conv_kernel: 4,
            dt_rank: 8,
            layer_norm_eps: 1e-5,
        };
        let model = Model::new_random(cfg.clone(), 1234).expect("random model");
        let mut s1 = model.new_state();
        let mut s2 = model.new_state();
        let mut b1 = ScratchBuffers::new(&cfg);
        let mut b2 = ScratchBuffers::new(&cfg);

        let seq = b"deterministic mamba";
        for &tok in seq {
            let l1 = model.forward(&mut b1, tok as u32, &mut s1).to_vec();
            let l2 = model.forward(&mut b2, tok as u32, &mut s2).to_vec();
            assert_eq!(l1.len(), l2.len());
            for (a, b) in l1.iter().zip(l2.iter()) {
                assert_eq!(a.to_bits(), b.to_bits());
            }
        }
    }

    /// Enabling train-trace capture must not perturb the forward pass:
    /// logits, conv state, SSM state, and the conv ring-buffer position all
    /// have to match bit-for-bit between traced and untraced runs.
    #[test]
    fn traced_and_untraced_forward_match_exactly() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            inner_size: 96,
            state_size: 8,
            conv_kernel: 4,
            dt_rank: 8,
            layer_norm_eps: 1e-5,
        };
        let model = Model::new_random(cfg.clone(), 4321).expect("random model");
        let mut traced_state = model.new_state();
        let mut plain_state = model.new_state();
        let mut traced_scratch = ScratchBuffers::new(&cfg);
        let mut plain_scratch = ScratchBuffers::new(&cfg);
        traced_scratch.set_capture_train_trace(true);
        plain_scratch.set_capture_train_trace(false);

        let seq = b"trace equivalence for mamba";
        for &tok in seq {
            let traced_logits = model
                .forward(&mut traced_scratch, tok as u32, &mut traced_state)
                .to_vec();
            let plain_logits = model
                .forward(&mut plain_scratch, tok as u32, &mut plain_state)
                .to_vec();
            for (a, b) in traced_logits.iter().zip(plain_logits.iter()) {
                assert_eq!(a.to_bits(), b.to_bits());
            }
            // Compare the full per-layer recurrent state, not just logits.
            for (tr_layer, plain_layer) in traced_state.layers.iter().zip(plain_state.layers.iter())
            {
                for (&a, &b) in tr_layer
                    .conv
                    .as_slice()
                    .iter()
                    .zip(plain_layer.conv.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                for (&a, &b) in tr_layer
                    .ssm
                    .as_slice()
                    .iter()
                    .zip(plain_layer.ssm.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                assert_eq!(tr_layer.conv_pos, plain_layer.conv_pos);
            }
        }
    }

    /// Finite-difference check of the embedding gradient for a single-token
    /// BPTT step. The analytic gradient is recovered from the SGD parameter
    /// delta as `(trained - base) / lr` and compared against a central
    /// difference of the target log-probability.
    #[test]
    fn online_embed_gradient_matches_finite_difference() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 16,
            num_layers: 2,
            inner_size: 24,
            state_size: 4,
            conv_kernel: 3,
            dt_rank: 4,
            layer_norm_eps: 1e-5,
        };
        let token = 7u32;
        let target = 19u8;
        let lr = 1e-3f32;
        let eps = 1e-3f32;

        let model = Model::new_random(cfg.clone(), 99).expect("random model");
        let mut state = model.new_state();
        let mut scratch = ScratchBuffers::new(&cfg);
        scratch.set_capture_train_trace(true);
        let logits = model.forward(&mut scratch, token, &mut state);

        // Softmax pdf in f64, max-shifted for stability.
        let mut pdf = vec![0.0f64; cfg.vocab_size];
        let mut max_logit = f32::NEG_INFINITY;
        for &logit in logits {
            max_logit = max_logit.max(logit);
        }
        let mut denom = 0.0f64;
        for &logit in logits {
            denom += ((logit - max_logit) as f64).exp();
        }
        for (idx, out) in pdf.iter_mut().enumerate() {
            *out = ((logits[idx] - max_logit) as f64).exp() / denom;
        }

        let base = model.clone();
        let mut trained = base.clone();
        let mut train_scratch = scratch.clone();
        // Scope mask restricts the update to the embedding table only.
        trained
            .online_train_step_bptt1(
                &mut train_scratch,
                &state,
                target,
                &pdf,
                TrainScopeMask {
                    embed: true,
                    ..TrainScopeMask::default()
                },
                OptimizerKind::Sgd,
                lr,
                0.0,
                &mut 0usize,
                None,
                None,
                None,
                None,
            )
            .expect("training step");

        // Probe one coordinate of the input token's embedding row.
        let param_idx = token as usize * cfg.hidden_size;
        let analytic = (trained.embeddings[param_idx] - base.embeddings[param_idx]) / lr;

        let mut plus = base.clone();
        plus.embeddings[param_idx] += eps;
        let mut minus = base.clone();
        minus.embeddings[param_idx] -= eps;
        let numeric = (target_log_prob(&plus, token, target)
            - target_log_prob(&minus, token, target))
            / (2.0 * eps);

        let diff = (analytic - numeric).abs();
        let scale = analytic.abs().max(numeric.abs()).max(1.0);
        assert!(
            diff <= 2e-2 * scale,
            "analytic={analytic} numeric={numeric} diff={diff}"
        );
    }

    /// Small single-layer config shared by the TBPTT tests below.
    fn test_cfg() -> Config {
        Config {
            vocab_size: 256,
            hidden_size: 32,
            num_layers: 1,
            inner_size: 48,
            state_size: 6,
            conv_kernel: 3,
            dt_rank: 6,
            layer_norm_eps: 1e-5,
        }
    }

    /// Cross-entropy `-ln p(target)` of a logit vector, computed in f64
    /// with max-shift and floor clamps to avoid log(0)/division by zero.
    fn softmax_loss(logits: &[f32], target: u8) -> f64 {
        let max_logit = logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
        let mut denom = 0.0f64;
        for &z in logits {
            denom += ((z - max_logit) as f64).exp();
        }
        let p = ((logits[target as usize] - max_logit) as f64).exp() / denom.max(1e-300);
        -p.max(1e-300).ln()
    }

    /// Full softmax distribution of a logit vector as f64 probabilities.
    fn softmax_pdf(logits: &[f32]) -> Vec<f64> {
        let mut pdf = vec![0.0f64; logits.len()];
        let max_logit = logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
        let mut denom = 0.0f64;
        for &z in logits {
            denom += ((z - max_logit) as f64).exp();
        }
        let inv = 1.0 / denom.max(1e-300);
        for (idx, out) in pdf.iter_mut().enumerate() {
            *out = ((logits[idx] - max_logit) as f64).exp() * inv;
        }
        pdf
    }

    /// Mean cross-entropy over a `(input, target)` segment, replayed from
    /// a fresh state. Returns 0.0 for an empty segment.
    fn segment_loss(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> f64 {
        if steps.is_empty() {
            return 0.0;
        }
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        let mut loss = 0.0f64;
        for &(input, target) in steps {
            let logits = model.forward(&mut scratch, input, &mut state);
            loss += softmax_loss(logits, target);
        }
        loss / (steps.len() as f64)
    }

    /// Runs a traced forward pass over the segment, then accumulates
    /// per-token gradients in reverse order (backward through time),
    /// carrying the recurrent gradient between steps. Gradients are scaled
    /// by 1/len to match the mean loss in `segment_loss`.
    fn segment_grads(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> FullGradState {
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        let mut states = Vec::with_capacity(steps.len() + 1);
        let mut traces = Vec::with_capacity(steps.len());
        let mut pdfs = Vec::with_capacity(steps.len());
        states.push(state.clone());
        for &(input, _) in steps {
            scratch.set_capture_train_trace(true);
            let logits = model.forward(&mut scratch, input, &mut state);
            pdfs.push(softmax_pdf(logits));
            traces.push(TokenTrainTrace::from_scratch(&scratch));
            states.push(state.clone());
        }
        let mut grads = model.new_full_grad_state();
        let mut recurrent = model.new_recurrent_grad_state();
        recurrent.zero();
        // All parameter groups except biases participate in the check.
        let scope = TrainScopeMask {
            embed: true,
            layer_norm: true,
            mixer_conv: true,
            mixer_ssm: true,
            mixer_proj: true,
            head: true,
            bias: false,
        };
        let grad_scale = 1.0f32 / (steps.len() as f32);
        for idx in (0..steps.len()).rev() {
            model
                .accumulate_token_step_gradients(
                    &mut scratch,
                    &traces[idx],
                    &states[idx + 1],
                    steps[idx].1,
                    &pdfs[idx],
                    grad_scale,
                    scope,
                    &mut grads,
                    None,
                    &mut recurrent,
                )
                .expect("segment gradient accumulation");
        }
        grads
    }

    /// One representative scalar parameter per weight group, used to probe
    /// the gradient of each group in the finite-difference test.
    #[derive(Clone, Copy, Debug)]
    enum Probe {
        Embed,
        FinalNormW,
        LayerNormW,
        InProjW,
        ConvW,
        SsmA,
        OutProjW,
        LmHead,
    }

    /// Reads the probed parameter's current value from the model.
    fn probe_value(model: &Model, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => model.embeddings[7],
            Probe::FinalNormW => model.final_norm_w[5],
            Probe::LayerNormW => model.layers[0].norm_w[9],
            Probe::InProjW => model.layers[0].in_proj_w[13],
            Probe::ConvW => model.layers[0].conv_w[4],
            Probe::SsmA => model.layers[0].a[11],
            Probe::OutProjW => model.layers[0].out_proj_w[17],
            Probe::LmHead => model.lm_head[23],
        }
    }

    /// Writes `value` into the probed parameter (indices mirror
    /// `probe_value`).
    fn set_probe(model: &mut Model, probe: Probe, value: f32) {
        match probe {
            Probe::Embed => model.embeddings[7] = value,
            Probe::FinalNormW => model.final_norm_w[5] = value,
            Probe::LayerNormW => model.layers[0].norm_w[9] = value,
            Probe::InProjW => model.layers[0].in_proj_w[13] = value,
            Probe::ConvW => model.layers[0].conv_w[4] = value,
            Probe::SsmA => model.layers[0].a[11] = value,
            Probe::OutProjW => model.layers[0].out_proj_w[17] = value,
            Probe::LmHead => model.lm_head[23] = value,
        }
    }

    /// Reads the accumulated gradient at the probed parameter (indices
    /// mirror `probe_value`).
    fn probe_grad(grads: &FullGradState, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => grads.embeddings[7],
            Probe::FinalNormW => grads.final_norm_w[5],
            Probe::LayerNormW => grads.layers[0].norm_w[9],
            Probe::InProjW => grads.layers[0].in_proj_w[13],
            Probe::ConvW => grads.layers[0].conv_w[4],
            Probe::SsmA => grads.layers[0].a[11],
            Probe::OutProjW => grads.layers[0].out_proj_w[17],
            Probe::LmHead => grads.layers[0].out_proj_w[17],
        }
    }

    /// Central-difference check of the full TBPTT segment gradient, one
    /// probed scalar per weight group. The numeric gradient is negated
    /// because `segment_grads` accumulates gradients of `+log p` while
    /// `segment_loss` is `-log p`.
    #[test]
    fn tbptt_segment_gradients_match_finite_difference() {
        let cfg = test_cfg();
        cfg.validate().expect("valid test config");
        let model = Model::new_random(cfg.clone(), 0xD00D_F00D).expect("random model");
        let steps = [(0u32, 1u8), (1, 2), (2, 3)];
        let grads = segment_grads(&model, &cfg, &steps);
        let eps = 1e-3f32;

        for probe in [
            Probe::Embed,
            Probe::FinalNormW,
            Probe::LayerNormW,
            Probe::InProjW,
            Probe::ConvW,
            Probe::SsmA,
            Probe::OutProjW,
            Probe::LmHead,
        ] {
            let analytic = probe_grad(&grads, probe);

            let mut plus = model.clone();
            let base = probe_value(&plus, probe);
            set_probe(&mut plus, probe, base + eps);
            let loss_plus = segment_loss(&plus, &cfg, &steps);

            let mut minus = model.clone();
            set_probe(&mut minus, probe, base - eps);
            let loss_minus = segment_loss(&minus, &cfg, &steps);

            let numeric = -((loss_plus - loss_minus) / (2.0 * eps as f64)) as f32;
            // Loose relative tolerance: f32 forward + f64 losses through a
            // recurrent scan accumulate noticeable finite-difference error.
            let tol = 6e-2f32.max(analytic.abs().max(numeric.abs()) * 1e-1);
            assert!(
                (analytic - numeric).abs() <= tol,
                "probe={probe:?} analytic={analytic} numeric={numeric} tol={tol}"
            );
        }
    }

    /// End-to-end sanity check: a single TBPTT SGD step on a short segment
    /// must strictly reduce the mean segment loss.
    #[test]
    fn tbptt_sgd_step_reduces_mean_segment_loss() {
        let cfg = test_cfg();
        cfg.validate().expect("valid test config");
        let mut model = Model::new_random(cfg.clone(), 0x1234_5678).expect("random model");
        let steps = [(0u32, 1u8), (1, 2), (2, 3), (3, 4)];
        let before = segment_loss(&model, &cfg, &steps);

        // Replay the segment once to collect per-step pdfs for training.
        let mut scratch = ScratchBuffers::new(&cfg);
        let start_state = model.new_state();
        let mut state = start_state.clone();
        let mut segment_steps = Vec::with_capacity(steps.len());
        for &(input, target) in &steps {
            let logits = model.forward(&mut scratch, input, &mut state);
            segment_steps.push((input, target, softmax_pdf(logits)));
        }

        let mut live_state = model.new_state();
        let mut adam_t = 0usize;
        let scope = TrainScopeMask {
            embed: true,
            layer_norm: true,
            mixer_conv: true,
            mixer_ssm: true,
            mixer_proj: true,
            head: true,
            bias: false,
        };

        model
            .online_train_segment_tbptt(
                &mut scratch,
                &start_state,
                &segment_steps,
                scope,
                OptimizerKind::Sgd,
                8e-4,
                0.0,
                2,
                &mut adam_t,
                None,
                None,
                None,
                None,
                &mut live_state,
            )
            .expect("tbptt sgd step");

        let after = segment_loss(&model, &cfg, &steps);
        assert!(
            after < before,
            "expected SGD TBPTT step to reduce mean loss: before={before} after={after}"
        );
    }
}